lemonade-sdk 9.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lemonade/__init__.py +5 -0
- lemonade/api.py +180 -0
- lemonade/cache.py +92 -0
- lemonade/cli.py +173 -0
- lemonade/common/__init__.py +0 -0
- lemonade/common/build.py +176 -0
- lemonade/common/cli_helpers.py +139 -0
- lemonade/common/exceptions.py +98 -0
- lemonade/common/filesystem.py +368 -0
- lemonade/common/inference_engines.py +408 -0
- lemonade/common/network.py +93 -0
- lemonade/common/printing.py +110 -0
- lemonade/common/status.py +471 -0
- lemonade/common/system_info.py +1411 -0
- lemonade/common/test_helpers.py +28 -0
- lemonade/profilers/__init__.py +1 -0
- lemonade/profilers/agt_power.py +437 -0
- lemonade/profilers/hwinfo_power.py +429 -0
- lemonade/profilers/memory_tracker.py +259 -0
- lemonade/profilers/profiler.py +58 -0
- lemonade/sequence.py +363 -0
- lemonade/state.py +159 -0
- lemonade/tools/__init__.py +1 -0
- lemonade/tools/accuracy.py +432 -0
- lemonade/tools/adapter.py +114 -0
- lemonade/tools/bench.py +302 -0
- lemonade/tools/flm/__init__.py +1 -0
- lemonade/tools/flm/utils.py +305 -0
- lemonade/tools/huggingface/bench.py +187 -0
- lemonade/tools/huggingface/load.py +235 -0
- lemonade/tools/huggingface/utils.py +359 -0
- lemonade/tools/humaneval.py +264 -0
- lemonade/tools/llamacpp/bench.py +255 -0
- lemonade/tools/llamacpp/load.py +222 -0
- lemonade/tools/llamacpp/utils.py +1260 -0
- lemonade/tools/management_tools.py +319 -0
- lemonade/tools/mmlu.py +319 -0
- lemonade/tools/oga/__init__.py +0 -0
- lemonade/tools/oga/bench.py +120 -0
- lemonade/tools/oga/load.py +804 -0
- lemonade/tools/oga/migration.py +403 -0
- lemonade/tools/oga/utils.py +462 -0
- lemonade/tools/perplexity.py +147 -0
- lemonade/tools/prompt.py +263 -0
- lemonade/tools/report/__init__.py +0 -0
- lemonade/tools/report/llm_report.py +203 -0
- lemonade/tools/report/table.py +899 -0
- lemonade/tools/server/__init__.py +0 -0
- lemonade/tools/server/flm.py +133 -0
- lemonade/tools/server/llamacpp.py +320 -0
- lemonade/tools/server/serve.py +2123 -0
- lemonade/tools/server/static/favicon.ico +0 -0
- lemonade/tools/server/static/index.html +279 -0
- lemonade/tools/server/static/js/chat.js +1059 -0
- lemonade/tools/server/static/js/model-settings.js +183 -0
- lemonade/tools/server/static/js/models.js +1395 -0
- lemonade/tools/server/static/js/shared.js +556 -0
- lemonade/tools/server/static/logs.html +191 -0
- lemonade/tools/server/static/styles.css +2654 -0
- lemonade/tools/server/static/webapp.html +321 -0
- lemonade/tools/server/tool_calls.py +153 -0
- lemonade/tools/server/tray.py +664 -0
- lemonade/tools/server/utils/macos_tray.py +226 -0
- lemonade/tools/server/utils/port.py +77 -0
- lemonade/tools/server/utils/thread.py +85 -0
- lemonade/tools/server/utils/windows_tray.py +408 -0
- lemonade/tools/server/webapp.py +34 -0
- lemonade/tools/server/wrapped_server.py +559 -0
- lemonade/tools/tool.py +374 -0
- lemonade/version.py +1 -0
- lemonade_install/__init__.py +1 -0
- lemonade_install/install.py +239 -0
- lemonade_sdk-9.1.1.dist-info/METADATA +276 -0
- lemonade_sdk-9.1.1.dist-info/RECORD +84 -0
- lemonade_sdk-9.1.1.dist-info/WHEEL +5 -0
- lemonade_sdk-9.1.1.dist-info/entry_points.txt +5 -0
- lemonade_sdk-9.1.1.dist-info/licenses/LICENSE +201 -0
- lemonade_sdk-9.1.1.dist-info/licenses/NOTICE.md +47 -0
- lemonade_sdk-9.1.1.dist-info/top_level.txt +3 -0
- lemonade_server/cli.py +805 -0
- lemonade_server/model_manager.py +758 -0
- lemonade_server/pydantic_models.py +159 -0
- lemonade_server/server_models.json +643 -0
- lemonade_server/settings.py +39 -0
|
@@ -0,0 +1,471 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import platform
|
|
3
|
+
import shutil
|
|
4
|
+
import sys
|
|
5
|
+
import math
|
|
6
|
+
import dataclasses
|
|
7
|
+
from typing import Callable, List, Union, Dict, Optional
|
|
8
|
+
import textwrap
|
|
9
|
+
import psutil
|
|
10
|
+
from lemonade.common import printing
|
|
11
|
+
from lemonade.state import State
|
|
12
|
+
import lemonade.common.build as build
|
|
13
|
+
import lemonade.common.filesystem as fs
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _pretty_print_key(key: str) -> str:
|
|
17
|
+
result = key.split("_")
|
|
18
|
+
result = [word.capitalize() for word in result]
|
|
19
|
+
result = " ".join(result)
|
|
20
|
+
return result
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class PrettyFloat(float):
    """A float whose repr is always rendered with three decimal places."""

    def __repr__(self):
        return format(self, "0.3f")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def parameters_to_size(parameters: int, byte_per_parameter: int = 4) -> str:
    """
    Convert a parameter count into a human-readable memory size.

    Args:
        parameters: number of parameters.
        byte_per_parameter: bytes used to store each parameter (default 4).

    Returns:
        A string such as "1.5 GB", where the number is rounded to two decimal
        places. Returns "0B" when the size is zero.

    Raises:
        ValueError: if the resulting size is negative.
    """
    size_bytes = parameters * byte_per_parameter
    if size_bytes == 0:
        return "0B"
    if size_bytes < 0:
        raise ValueError(f"Size must be non-negative, got {size_bytes} bytes")

    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")

    # Pick the largest unit that keeps the value >= 1. Exact comparisons avoid
    # math.log landing just below an exact power of 1024. Clamping to the last
    # unit avoids indexing past the tuple for enormous sizes. Sizes below one
    # byte stay in "B" instead of wrapping around to index -1.
    i = 0
    while i < len(size_name) - 1 and size_bytes >= 1024 ** (i + 1):
        i += 1

    s = round(size_bytes / 1024**i, 2)
    return f"{s} {size_name[i]}"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclasses.dataclass
class BasicInfo:
    """Identifying information shared by models and their unique invocations."""

    name: str
    script_name: str
    # Source file path; "" when there is no file on disk
    file: str = ""
    line: int = 0
    params: int = 0
    # Nesting depth; 0 is the top level and larger values indent the printout
    depth: int = 0
    # Hash of the parent; None at the top level
    parent_hash: Optional[str] = None
    model_class: Optional[type] = None
    # This is the "model hash", not to be confused with the
    # "invocation hash"
    hash: Optional[str] = None
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclasses.dataclass
class SkipFields:
    """
    Flags recording which parts of a model's status printout to omit.

    Two situations rely on this:
    - Incremental printout: parts that have already been printed.
    - Low verbosity: parts that should never be printed.
    """

    # Heading: script file name and model name
    file_name: bool = False
    model_name: bool = False
    # "Location:" line
    location: bool = False
    # "Input Shape:" line
    input_shape: bool = False
    # "Build dir:" line
    build_dir: bool = False
    # "With input shape N" line
    unique_input_shape: bool = False
    # Last status message printed; the same message is not printed twice
    previous_status_message: Optional[str] = None
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@dataclasses.dataclass
class UniqueInvocationInfo(BasicInfo):
    """
    Refers to unique static model invocations
    (i.e. models executed with unique input shapes).

    Besides the invocation metadata, this class prints its own status block.
    The ``skip`` field records which parts of that block have already been
    printed, so that repeated calls to ``print`` stay incremental.
    """

    invocation_hash: Optional[str] = None
    traceback: Optional[List[str]] = None
    inputs: Optional[dict] = None
    input_shapes: Optional[dict] = None
    executed: int = 0
    exec_time: float = 0.0
    status_message: str = ""
    extra_status: Optional[str] = ""
    is_target: bool = False
    auto_selected: bool = False
    status_message_color: printing.Colors = printing.Colors.ENDC
    traceback_message_color: printing.Colors = printing.Colors.FAIL
    stats_keys: List[str] = dataclasses.field(default_factory=list)
    forward_function_pointer: Optional[Callable] = None
    original_forward_function: Optional[Callable] = None
    # Fields specific to printing status
    skip: Optional[SkipFields] = None  # always replaced in __post_init__
    extension: Optional[str] = None
    indent: Optional[str] = None

    def __post_init__(self):
        # Give every invocation its own fresh SkipFields instance
        self.skip = SkipFields()

    def _print_heading(
        self,
        exec_time_formatted: str,
        print_file_name: bool,
        model_visited: bool,
        multiple_unique_invocations: bool,
    ):
        """
        Print the script file heading and the model name line.

        Does nothing once the file name or model name has been marked as
        printed in ``self.skip``.
        """
        if self.skip.file_name or self.skip.model_name:
            return

        if print_file_name:
            print(f"{self.script_name}{self.extension}:")

        # Print information about the model (only applies to scripts, not ONNX
        # or GGUF files, nor LLMs, which have no extension)
        if not (
            self.extension in [".onnx", ".gguf"]
            or self.extension == build.state_file_name
            or self.extension == ""
        ):
            if self.depth == 0 and multiple_unique_invocations:
                if not model_visited:
                    printing.logn(f"{self.indent}{self.name}")
            else:
                printing.log(f"{self.indent}{self.name}")
                printing.logn(
                    f" (executed {self.executed}x{exec_time_formatted})",
                    c=printing.Colors.OKGREEN,
                )

        self.skip.file_name = True
        self.skip.model_name = True

    def _print_location(self):
        """Print the "Location:" line for top-level invocations that have a file."""
        if self.skip.location or self.file == "":
            return

        if self.depth == 0:
            print(f"{self.indent}\tLocation:\t{self.file}", end="")
            if self.extension in [".onnx", ".gguf"]:
                print()
            else:
                print(f", line {self.line}")
            self.skip.location = True

    def _print_unique_input_shape(
        self,
        exec_time_formatted: str,
        invocation_idx: int,
        multiple_unique_invocations: bool,
    ):
        """
        Print the "With input shape N" line.

        Only printed for top-level models that have several unique invocations.
        ``invocation_idx`` is zero-based and is shown as one-based.
        """
        if self.skip.unique_input_shape:
            return

        if self.depth == 0 and multiple_unique_invocations:
            printing.logn(
                f"\n{self.indent}\tWith input shape {invocation_idx+1} "
                f"(executed {self.executed}x{exec_time_formatted})",
                c=printing.Colors.OKGREEN,
            )

        self.skip.unique_input_shape = True

    def _print_input_shape(self):
        """Print the "Input Shape:" line, omitting inputs whose shape is ()."""
        if self.skip.input_shape or self.input_shapes is None:
            return

        # Prepare input shape to be printed: drop empty shapes and the braces
        input_shape = {
            key: value
            for key, value in dict(self.input_shapes).items()
            if value != ()
        }
        input_shape = str(input_shape).replace("{", "").replace("}", "")

        print(f"{self.indent}\tInput Shape:\t{input_shape}")

        self.skip.input_shape = True

    def _print_build_dir(self, cache_dir: str, build_name: str):
        """Print the "Build dir:" line, for the target invocation only."""
        if self.skip.build_dir or not self.is_target:
            return

        print(f"{self.indent}\tBuild dir:\t{build.output_dir(cache_dir, build_name)}")

        self.skip.build_dir = True

    def _print_peak_memory(self):
        """Print this process's peak working set in GB (Windows only)."""
        # psutil only exposes ``peak_wset`` on Windows
        if platform.system() == "Windows":
            print(
                f"{self.indent}\tPeak memory:\t"
                f"{psutil.Process().memory_info().peak_wset / 1024**3:,.3f} GB"
            )

    def _print_status(self, cache_dir: str, build_name: str):
        """
        Print the status message and the invocation's stats.

        For the target invocation, every key in ``self.stats_keys`` that is
        present in the build's stats is printed. After that, the traceback (or
        a hint on how to show it) is printed when one was recorded.
        """
        stats = fs.Stats(cache_dir, build_name)
        if self.skip.previous_status_message:
            if self.skip.previous_status_message == self.status_message:
                # This is a special case for skipping: we only want to skip
                # printing the outcome if we have already printed that
                # exact message already.
                return
            else:
                # Print some whitespace to help the status stand out
                print()

        printing.log(f"{self.indent}\tStatus:\t\t")
        printing.logn(
            f"{self.status_message}",
            c=self.status_message_color,
        )
        if self.is_target:
            # Get the maximum key length to figure out the padding
            # needed to align the values
            max_key_len = max(
                (len(_pretty_print_key(key)) for key in self.stats_keys), default=0
            )

            screen_width = shutil.get_terminal_size().columns
            wrap_screen_width = screen_width - 2

            for key in self.stats_keys:
                nice_key = _pretty_print_key(key)
                try:
                    value = stats.stats[key]
                except KeyError:
                    # Ignore any keys that are missing because that means the
                    # evaluation did not produce them
                    continue

                if isinstance(value, float):
                    value = PrettyFloat(value)
                elif isinstance(value, list):
                    value = [
                        PrettyFloat(v) if isinstance(v, float) else v for v in value
                    ]

                # Tools may provide a unit of measurement for their status
                # stats, whose key name should follow the format
                # "STATUS_STATS_KEY_units"
                units = stats.stats.get(key + "_units")
                if units is None:
                    units = ""

                if self.extension == "":
                    value_tabs = " " * ((max_key_len - len(nice_key)) + 1)
                    # NOTE(review): the "\t" after the indent is counted as 8
                    # columns, while each tab inside self.indent counts as 1.
                    # This is only exact when the indent is empty (depth 0).
                    hanging_indent = (
                        len(self.indent) + 8 + len(nice_key) + 1 + len(value_tabs)
                    )
                    hanging_indent_str = " " * hanging_indent
                    if (
                        isinstance(value, list)
                        and len(value) > 0
                        and all(isinstance(item, str) for item in value)
                    ):
                        # Value is a list of strings, so output each one starting
                        # on its own line
                        printing.logn(f"{self.indent}\t{nice_key}:{value_tabs}[")
                        for line_counter, text in enumerate(value):
                            lines = textwrap.wrap(
                                "'" + text + "'",
                                width=wrap_screen_width,
                                initial_indent=hanging_indent_str,
                                subsequent_indent=hanging_indent_str,
                            )
                            if line_counter + 1 < len(value):
                                # Not the last text item in the list, so add a comma
                                lines[-1] = lines[-1] + ","
                            for line in lines:
                                printing.logn(line)
                        printing.logn(f"{' ' * hanging_indent}] {units}")
                    else:
                        # Wrap value as needed
                        status_str = (
                            f"{self.indent}\t{nice_key}:{value_tabs}{value} {units}"
                        )
                        lines = textwrap.wrap(
                            status_str,
                            width=wrap_screen_width,
                            subsequent_indent=hanging_indent_str,
                        )
                        for line in lines:
                            printing.logn(line)
                else:
                    printing.logn(f"{self.indent}\t\t\t{nice_key}:\t{value} {units}")

        # NOTE(review): this block was placed at method level while
        # reconstructing the whitespace-stripped source. Confirm whether it
        # belongs under `if self.is_target` (identical behavior whenever
        # is_target is True).
        if self.traceback is not None:
            if os.environ.get("LEMONADE_TRACEBACK") != "False":
                for line in self.traceback:
                    # The final "\n"-separated segment is dropped. This removes
                    # the empty tail of newline-terminated lines, but it would
                    # also drop a last line that has no trailing newline.
                    for subline in line.split("\n")[:-1]:
                        print(f"{self.indent}\t{subline}")

            else:
                printing.logn(
                    f"{self.indent}\t\t\tTo see the full stack trace, "
                    "rerun with `export LEMONADE_TRACEBACK=True`.\n",
                    c=self.status_message_color,
                )
        else:
            print()

        self.skip.previous_status_message = self.status_message

    def print(
        self,
        build_name: str,
        cache_dir: str,
        print_file_name: bool = False,
        invocation_idx: int = 0,
        model_visited: bool = False,
        multiple_unique_invocations: bool = False,
    ):
        """
        Print information about a given model or submodel.

        Args:
            build_name: name of the build whose stats are printed.
            cache_dir: cache directory that contains the build.
            print_file_name: whether to print the script file heading.
            invocation_idx: zero-based index of this invocation within its model.
            model_visited: True if another invocation of the same model was
                already printed. Per-model details like the location are then
                not repeated for top-level models.
            multiple_unique_invocations: True if the model has more than one
                unique invocation.
        """
        # ONNX/GGUF files and LLMs (no extension) do not print a model-name
        # line, so they need one less level of indentation
        if self.extension in [".onnx", ".gguf"] or self.extension == "":
            self.indent = "\t" * (2 * self.depth)
        else:
            self.indent = "\t" * (2 * self.depth + 1)

        exec_time_formatted = (
            "" if self.exec_time == 0 else f" - {self.exec_time:.2f}s"
        )

        self._print_heading(
            exec_time_formatted,
            print_file_name,
            model_visited,
            multiple_unique_invocations,
        )
        if self.depth != 0 or not model_visited:
            # Print this information only once per model
            self._print_location()
        self._print_unique_input_shape(
            exec_time_formatted, invocation_idx, multiple_unique_invocations
        )
        self._print_input_shape()
        self._print_build_dir(cache_dir=cache_dir, build_name=build_name)
        self._print_peak_memory()
        self._print_status(cache_dir=cache_dir, build_name=build_name)

        print()
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
@dataclasses.dataclass
class ModelInfo(BasicInfo):
    """A model, together with every unique invocation recorded for it."""

    model: Optional[str] = None
    old_forward: Optional[Callable] = None
    # Unique invocations of this model, keyed by hash
    unique_invocations: Optional[Dict[str, UniqueInvocationInfo]] = (
        dataclasses.field(default_factory=dict)
    )
    last_unique_invocation_executed: Optional[str] = None
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
def recursive_print(
    models_found: Dict[str, ModelInfo],
    build_name: str,
    cache_dir: str,
    parent_model_hash: Union[str, None] = None,
    parent_invocation_hash: Union[str, None] = None,
    script_names_visited: Optional[List[str]] = None,
) -> None:
    """
    Print status for every executed invocation under the given parent, then
    recurse into each printed invocation's children.

    Args:
        models_found: model hash -> ModelInfo for every known model.
        build_name: name of the build whose stats are printed.
        cache_dir: cache directory that contains the build.
        parent_model_hash: only models whose ``parent_hash`` equals this are
            printed at this level (None selects top-level models).
        parent_invocation_hash: only invocations whose ``parent_hash`` equals
            this are printed at this level.
        script_names_visited: names of scripts whose file heading has already
            been printed. The list is shared with the recursive calls. A new
            list is created when it is omitted (the legacy ``False`` default is
            also accepted).
    """
    # Previously the argument was always discarded and replaced with [], so
    # recursive calls never shared this state
    if script_names_visited is None or script_names_visited is False:
        script_names_visited = []

    for model_hash, model_info in models_found.items():
        model_visited = False
        invocation_idx = 0
        multiple_unique_invocations = len(model_info.unique_invocations) > 1

        for invocation_hash, unique_invocation in model_info.unique_invocations.items():
            if not (
                parent_model_hash == model_info.parent_hash
                and unique_invocation.executed > 0
                and unique_invocation.parent_hash == parent_invocation_hash
            ):
                continue

            # Only top-level models print the script heading. Record a script
            # only once that heading is actually printed, so that a submodel
            # from the same script cannot suppress a later top-level heading.
            print_file_name = (
                model_info.depth == 0
                and model_info.script_name not in script_names_visited
            )
            if print_file_name:
                script_names_visited.append(model_info.script_name)

            # In this verbosity mode we want to print all of the information
            # every time, so reset SkipFields
            # NOTE: to introduce a new lower-verbosity mode, set some members
            # of SkipFields to True to skip them
            unique_invocation.skip = SkipFields()

            unique_invocation.print(
                build_name=build_name,
                cache_dir=cache_dir,
                print_file_name=print_file_name,
                invocation_idx=invocation_idx,
                model_visited=model_visited,
                multiple_unique_invocations=multiple_unique_invocations,
            )
            model_visited = True
            invocation_idx += 1

            recursive_print(
                models_found,
                build_name,
                cache_dir,
                parent_model_hash=model_hash,
                parent_invocation_hash=invocation_hash,
                script_names_visited=script_names_visited,
            )
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
def stop_logger_forward() -> None:
    """
    Stop forwarding stdout and stderr to file.

    Any stream wrapper that exposes the original stream (``terminal`` for
    stdout, ``terminal_err`` for stderr) is replaced by that original stream.
    Streams without the attribute are left untouched.
    """
    for stream_name, original_attr in (
        ("stdout", "terminal"),
        ("stderr", "terminal_err"),
    ):
        current = getattr(sys, stream_name)
        if hasattr(current, original_attr):
            setattr(sys, stream_name, getattr(current, original_attr))
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
def add_to_state(
    state: State,
    name: str,
    model: str,
    extension: str = "",
    input_shapes: Optional[Dict] = None,
):
    """
    Register ``model`` as the single, already-executed target in ``state``.

    Sets ``state.invocation_info`` to a UniqueInvocationInfo marked as the
    target. Sets ``state.models_found`` to a one-entry dict keyed "the_model"
    whose ModelInfo holds that invocation.

    Args:
        state: build state to update. ``state.model_hash`` is used as the hash
            when it is set and truthy; otherwise 0 is used.
        name: model name, or a path to a model file on disk.
        model: the model reference stored on the ModelInfo.
        extension: file extension that selects the printing layout ("" for
            LLMs; ".onnx"/".gguf" for model files).
        input_shapes: optional mapping of input names to shapes.
    """
    model_hash = state.model_hash if vars(state).get("model_hash") else 0

    if os.path.exists(name):
        # A path on disk: display the cleaned file name, remember the path
        file_name = fs.clean_file_name(name)
        file = name
    else:
        file_name = name
        file = ""

    # These fields previously received the builtin ``input`` function by
    # mistake instead of the model's name/script/file
    state.invocation_info = UniqueInvocationInfo(
        name=name,
        script_name=file_name,
        file=file,
        input_shapes=input_shapes,
        hash=model_hash,
        is_target=True,
        extension=extension,
        executed=1,
    )
    state.models_found = {
        "the_model": ModelInfo(
            model=model,
            name=name,
            script_name=file_name,
            file=file,
            unique_invocations={model_hash: state.invocation_info},
            hash=model_hash,
        )
    }
    state.invocation_info.params = state.models_found["the_model"].params
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
# This file was originally licensed under Apache 2.0. It has been modified.
|
|
471
|
+
# Modifications Copyright (c) 2025 AMD
|