lemonade-sdk 9.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lemonade/__init__.py +5 -0
- lemonade/api.py +180 -0
- lemonade/cache.py +92 -0
- lemonade/cli.py +173 -0
- lemonade/common/__init__.py +0 -0
- lemonade/common/build.py +176 -0
- lemonade/common/cli_helpers.py +139 -0
- lemonade/common/exceptions.py +98 -0
- lemonade/common/filesystem.py +368 -0
- lemonade/common/inference_engines.py +408 -0
- lemonade/common/network.py +93 -0
- lemonade/common/printing.py +110 -0
- lemonade/common/status.py +471 -0
- lemonade/common/system_info.py +1411 -0
- lemonade/common/test_helpers.py +28 -0
- lemonade/profilers/__init__.py +1 -0
- lemonade/profilers/agt_power.py +437 -0
- lemonade/profilers/hwinfo_power.py +429 -0
- lemonade/profilers/memory_tracker.py +259 -0
- lemonade/profilers/profiler.py +58 -0
- lemonade/sequence.py +363 -0
- lemonade/state.py +159 -0
- lemonade/tools/__init__.py +1 -0
- lemonade/tools/accuracy.py +432 -0
- lemonade/tools/adapter.py +114 -0
- lemonade/tools/bench.py +302 -0
- lemonade/tools/flm/__init__.py +1 -0
- lemonade/tools/flm/utils.py +305 -0
- lemonade/tools/huggingface/bench.py +187 -0
- lemonade/tools/huggingface/load.py +235 -0
- lemonade/tools/huggingface/utils.py +359 -0
- lemonade/tools/humaneval.py +264 -0
- lemonade/tools/llamacpp/bench.py +255 -0
- lemonade/tools/llamacpp/load.py +222 -0
- lemonade/tools/llamacpp/utils.py +1260 -0
- lemonade/tools/management_tools.py +319 -0
- lemonade/tools/mmlu.py +319 -0
- lemonade/tools/oga/__init__.py +0 -0
- lemonade/tools/oga/bench.py +120 -0
- lemonade/tools/oga/load.py +804 -0
- lemonade/tools/oga/migration.py +403 -0
- lemonade/tools/oga/utils.py +462 -0
- lemonade/tools/perplexity.py +147 -0
- lemonade/tools/prompt.py +263 -0
- lemonade/tools/report/__init__.py +0 -0
- lemonade/tools/report/llm_report.py +203 -0
- lemonade/tools/report/table.py +899 -0
- lemonade/tools/server/__init__.py +0 -0
- lemonade/tools/server/flm.py +133 -0
- lemonade/tools/server/llamacpp.py +320 -0
- lemonade/tools/server/serve.py +2123 -0
- lemonade/tools/server/static/favicon.ico +0 -0
- lemonade/tools/server/static/index.html +279 -0
- lemonade/tools/server/static/js/chat.js +1059 -0
- lemonade/tools/server/static/js/model-settings.js +183 -0
- lemonade/tools/server/static/js/models.js +1395 -0
- lemonade/tools/server/static/js/shared.js +556 -0
- lemonade/tools/server/static/logs.html +191 -0
- lemonade/tools/server/static/styles.css +2654 -0
- lemonade/tools/server/static/webapp.html +321 -0
- lemonade/tools/server/tool_calls.py +153 -0
- lemonade/tools/server/tray.py +664 -0
- lemonade/tools/server/utils/macos_tray.py +226 -0
- lemonade/tools/server/utils/port.py +77 -0
- lemonade/tools/server/utils/thread.py +85 -0
- lemonade/tools/server/utils/windows_tray.py +408 -0
- lemonade/tools/server/webapp.py +34 -0
- lemonade/tools/server/wrapped_server.py +559 -0
- lemonade/tools/tool.py +374 -0
- lemonade/version.py +1 -0
- lemonade_install/__init__.py +1 -0
- lemonade_install/install.py +239 -0
- lemonade_sdk-9.1.1.dist-info/METADATA +276 -0
- lemonade_sdk-9.1.1.dist-info/RECORD +84 -0
- lemonade_sdk-9.1.1.dist-info/WHEEL +5 -0
- lemonade_sdk-9.1.1.dist-info/entry_points.txt +5 -0
- lemonade_sdk-9.1.1.dist-info/licenses/LICENSE +201 -0
- lemonade_sdk-9.1.1.dist-info/licenses/NOTICE.md +47 -0
- lemonade_sdk-9.1.1.dist-info/top_level.txt +3 -0
- lemonade_server/cli.py +805 -0
- lemonade_server/model_manager.py +758 -0
- lemonade_server/pydantic_models.py +159 -0
- lemonade_server/server_models.json +643 -0
- lemonade_server/settings.py +39 -0
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class Profiler(abc.ABC):
    """
    Abstract base class for profilers that observe a sequence of tools.

    Expected call order: start() once before the first tool,
    tool_starting()/tool_stopping() around each tool, stop() once at the end,
    and finally generate_results() to produce output files.
    """

    # Name used for the profiler's command-line flag; subclasses provide it
    unique_name: str

    def __init__(self, parser_arg_value=None):
        # Value supplied for this profiler's parser argument (None by default)
        self.parser_arg_value = parser_arg_value
        # Statistics that will be displayed to the CLI user
        self.status_stats = []

    @staticmethod
    @abc.abstractmethod
    def add_arguments_to_parser(parser):
        """
        Register this profiler's command-line argument on ``parser``.

        Implementations are expected to use f"--{unique_name}" as the flag,
        where ``unique_name`` is the subclass's class attribute.
        """

    @abc.abstractmethod
    def start(self, build_dir):
        """
        Begin gathering data; called before the tool sequence starts.

        ``build_dir`` is a directory the profiler may use to store its data.
        """

    def tool_starting(self, tool_name):
        """
        Hook invoked with the name of the tool that is about to run.
        The default implementation does nothing.
        """

    def tool_stopping(self):
        """
        Hook invoked after the current tool has finished.
        The default implementation does nothing.
        """

    def stop(self):
        """
        Stop gathering data; called once the tool sequence has finished.
        The default implementation does nothing.
        """

    @abc.abstractmethod
    def generate_results(self, state, timestamp, start_times):
        """
        Produce the profiler's output files.

        Args:
            state: build state, usable for gathering build info and writing stats.
            timestamp: string suitable for naming files in the current
                working directory.
            start_times: dict mapping tool names to the time each tool started.
                It also contains a "warmup" key, recorded before the first tool,
                and a "cool down" key holding the time the last tool ended.
        """
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# Copyright (c) 2025 AMD
|
lemonade/sequence.py
ADDED
|
@@ -0,0 +1,363 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import time
|
|
3
|
+
import os
|
|
4
|
+
import platform
|
|
5
|
+
import copy
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from typing import List, Dict, Optional
|
|
8
|
+
import pytz
|
|
9
|
+
import psutil
|
|
10
|
+
import lemonade.common.printing as printing
|
|
11
|
+
import lemonade.common.exceptions as exp
|
|
12
|
+
import lemonade.common.build as build
|
|
13
|
+
from lemonade.common.system_info import get_system_info_dict
|
|
14
|
+
import lemonade.common.filesystem as fs
|
|
15
|
+
import lemonade.common.status as status
|
|
16
|
+
from lemonade.tools.tool import Tool
|
|
17
|
+
from lemonade.profilers.profiler import Profiler
|
|
18
|
+
from lemonade.state import State
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _rewind_stdout(lines: int = 1):
|
|
22
|
+
"""
|
|
23
|
+
Helper function for the command line monitor. Moves the cursor up a
|
|
24
|
+
certain number of lines in the terminal, corresponding to the
|
|
25
|
+
status line for a Tool, so that we can update the status of
|
|
26
|
+
that Tool.
|
|
27
|
+
"""
|
|
28
|
+
rewind_stdout_one_line = "\033[1A"
|
|
29
|
+
rewind_multiple_lines = rewind_stdout_one_line * lines
|
|
30
|
+
print(rewind_multiple_lines, end="")
|
|
31
|
+
sys.stdout.flush()
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class Sequence:
    """
    Helper class to launch and manage build tools.

    A Sequence holds a mapping of Tool instances to the argv each tool is
    parsed with (iteration order is execution order), plus optional Profilers
    that are notified around the run.
    """

    def __init__(
        self,
        tools: Dict[Tool, List[str]],
        profilers: List[Profiler] = None,
    ):
        """
        Args:
            tools: mapping of Tool instance -> argv passed to
                Tool.parse_and_run().
            profilers: optional list of Profilers; defaults to an empty list.

        Raises:
            ValueError: if two tools share the same class-level unique_name.
        """

        self.tools = tools
        self.profilers = [] if profilers is None else profilers

        # Make sure all the tool names are unique
        self.tool_names = [tool.__class__.unique_name for tool in self.tools.keys()]

        if len(self.tool_names) != len(set(self.tool_names)):
            msg = f"""
            All tools in a Sequence must have unique unique_names, however Sequence
            received duplicates in the list of names: {self.tool_names}
            """
            raise ValueError(msg)

        # Save the process (used to get memory usage)
        self.process = psutil.Process()

    def show_monitor(self, state: State, verbosity: bool):
        """
        Displays the monitor on the terminal. The purpose of the monitor
        is to show the status of each tool (success, failure, not started yet,
        or in-progress).

        When ``verbosity`` is falsy nothing is printed.
        """

        if verbosity:
            print()

            printing.logn(
                f'Building "{state.build_name}"',
                c=printing.Colors.BOLD,
            )

            for tool in self.tools:
                tool.status_line(successful=None, verbosity=True)

            # Move the cursor back to the first tool's status line so the
            # tools can update their own lines in place
            _rewind_stdout(len(self.tools))

    def _advance_cursor(self, current_tool_name: str):
        """
        Print newlines so that subsequent output appears below the monitor,
        given that the cursor currently sits on ``current_tool_name``'s line.
        """
        tool_depth_in_sequence = len(self.tool_names) - self.tool_names.index(
            current_tool_name
        )
        # print() adds one more newline of its own; a negative count yields ""
        stdout_lines_to_advance = tool_depth_in_sequence - 2
        cursor_down = "\n" * stdout_lines_to_advance

        print(cursor_down)

    def _get_mem_usage_str(self) -> str:
        """
        Returns a string with memory usage for the current process
        (non-swapped physical memory). In Windows OS, the peak memory used in the
        process is also included.

        Example: '1.100 GB (1.638 GB peak)'
        """
        mem_info = self.process.memory_info()
        mem_info_str = f"{mem_info.rss / 1024 ** 3:,.3f} GB"
        if platform.system() == "Windows":
            # peak_wset is only reported by psutil on Windows
            mem_info_str += f" ({mem_info.peak_wset / 1024 ** 3:,.3f} GB peak)"
        return mem_info_str

    def launch(
        self,
        state: State,
        lean_cache: bool = False,
        monitor: Optional[bool] = None,
        stats_to_save: Optional[Dict] = None,
    ) -> State:
        """
        Executes the sequence of tools.

        Args:
            state: build State; each tool's parse_and_run() returns the State
                handed to the next tool.
            lean_cache: if True, remove build artifacts with
                fs.clean_output_dir() once the build finishes.
            monitor: whether to show the terminal monitor. None means "enabled
                unless the LEMONADE_BUILD_MONITOR environment variable is the
                string 'False'".
            stats_to_save: extra stats written before any tool runs.

        Returns:
            The final State. If a tool raises exp.SkipBuild, the stats file is
            restored to its pre-launch contents, state.save() is not called,
            and None is returned.

        Raises:
            exp.Error: if state.build_status is already SUCCESSFUL.
            Exception: the first exception raised by a tool is re-raised after
                build statistics and profiler results have been recorded.
        """

        current_time = datetime.now()
        timestamp = current_time.strftime("%Y-%m-%d-%H%M%S")
        start_times = {"warmup": time.time()}

        # Allow monitor to be globally disabled by an environment variable
        if monitor is None:
            if os.environ.get("LEMONADE_BUILD_MONITOR") == "False":
                monitor_setting = False
            else:
                monitor_setting = True
        else:
            monitor_setting = monitor

        # Sanity check before any side effects, in particular before profilers
        # are started, so that raising here cannot leave them running
        if state.build_status == build.FunctionStatus.SUCCESSFUL:
            msg = """
            build_model() is running a build on a model that already built successfully, which
            should not happen because the build should have loaded from cache or rebuilt from scratch.
            If you are using custom tools and Sequences then you have some debugging to do. Otherwise,
            please file an issue at https://github.com/lemonade-sdk/lemonade/issues
            """
            raise exp.Error(msg)

        # Create a build directory in the cache
        fs.make_build_dir(state.cache_dir, state.build_name)

        # Start profilers
        build_dir = build.output_dir(state.cache_dir, state.build_name)
        for profiler in self.profilers:
            profiler.start(build_dir)

        self.show_monitor(state, monitor_setting)

        # Keep a copy of any stats we loaded from disk, in case we need to
        # restore them later
        saved_stats = copy.deepcopy(fs.Stats(state.cache_dir, state.build_name).stats)

        # Save build name to stats so it shows up on reports
        state.save_stat(fs.Keys.BUILD_NAME, state.build_name)

        # Indicate that the build is running. If the build fails for any reason,
        # we will try to catch the exception and note it in the stats.
        # If a concluded build still has a status of "running", this means
        # there was an uncaught exception.
        state.save_stat(fs.Keys.BUILD_STATUS, build.FunctionStatus.INCOMPLETE)

        # Save a timestamp so that we know the order of builds within a cache
        pacific_tz = pytz.timezone("America/Los_Angeles")
        state.save_stat(
            fs.Keys.TIMESTAMP,
            datetime.now(pacific_tz),
        )

        # Save the system information used for this build
        system_info = get_system_info_dict()
        state.save_stat(
            fs.Keys.SYSTEM_INFO,
            system_info,
        )

        # Collect telemetry for the build
        state.save_stat(
            fs.Keys.SELECTED_SEQUENCE_OF_TOOLS,
            self.tool_names,
        )

        # At the beginning of a sequence no tool has started
        for tool in self.tools:
            state.save_stat(tool.status_key, build.FunctionStatus.NOT_STARTED)
            state.save_stat(tool.duration_key, "-")
            state.save_stat(tool.memory_key, "-")

        # Save any additional stats passed in via arguments
        if stats_to_save:
            for stat_key, stat_value in stats_to_save.items():
                state.save_stat(stat_key, stat_value)

        # Save initial memory as a build statistic
        state.save_stat(f"{fs.Keys.TOOL_MEMORY}:__init__", self._get_mem_usage_str())

        # Run the build
        saved_exception = None
        skip_exception = None
        skipped_tool_name = None
        for tool, argv in self.tools.items():

            start_time = time.time()
            start_times[tool.unique_name] = start_time

            # Inform profiler of name of tool about to start
            for profiler in self.profilers:
                profiler.tool_starting(tool.unique_name)

            try:

                # Set status as incomplete, since tool just started
                state.save_stat(tool.status_key, build.FunctionStatus.INCOMPLETE)

                # Collect telemetry about the tool
                state.current_build_tool = tool.unique_name

                # Run the tool
                state = tool.parse_and_run(state, argv, monitor_setting)

                # Save the state so that it can be assessed for a cache hit
                state.save()

            except exp.SkipBuild as e:
                # SkipBuild is a special exception, which means that a build
                # was loaded from disk, then we realized we want to skip it.
                # In order to preserve the original stats and state of the build,
                # we need to restore the stats file to what it was at the beginning
                # of this function call. We also need to avoid calling state.save().
                # The restore is deferred until after the loop: doing it here
                # would let the `finally` clause below write duration/memory
                # stats on top of the restored file.
                skip_exception = e
                skipped_tool_name = tool.unique_name
                break

            # Broad exception is desirable as we want to capture
            # all exceptions (including those we can't anticipate)
            except Exception as e:  # pylint: disable=broad-except

                if os.environ.get("LEMONADE_DEBUG", "").lower() == "true":
                    # It may be useful to raise the exception here, since
                    # if any of the subsequent lines of code raise another
                    # exception it will be very hard to root cause e.
                    raise e

                # Update tool and build status
                state.save_stat(tool.status_key, build.FunctionStatus.ERROR)
                state.save_stat(fs.Keys.BUILD_STATUS, build.FunctionStatus.ERROR)

                # Save the log file for the failed tool to stats for easy reference
                stats = fs.Stats(state.cache_dir, state.build_name)
                stats.save_eval_error_log(tool.logfile_path)

                # Advance the cursor below the monitor so
                # we can print a message
                self._advance_cursor(tool.unique_name)

                if vars(state).get("invocation_info"):
                    state.invocation_info.status_message = f"Error: {e}"
                    state.invocation_info.status_message_color = printing.Colors.WARNING
                else:
                    printing.log_error(e)

                # We will raise this exception after we capture as many statistics
                # about the build as possible
                saved_exception = e

                # Don't run any more tools
                break

            else:
                # Update tool Status
                state.save_stat(tool.status_key, build.FunctionStatus.SUCCESSFUL)
                state.current_build_tool = None

            finally:
                # A skipped build must leave the stats file exactly as it was,
                # so don't record duration/memory for the skipped tool
                if skip_exception is None:
                    # Store tool duration
                    execution_time = time.time() - start_time
                    state.save_stat(tool.duration_key, execution_time)

                    # Store current memory and peak working memory
                    state.save_stat(tool.memory_key, self._get_mem_usage_str())

                # Inform profilers that tool has finished
                for profiler in self.profilers:
                    profiler.tool_stopping()

        start_times["cool down"] = time.time()

        # Tell the profilers to stop gathering data. This also runs on the
        # SkipBuild path so that profilers started above are never left running.
        for profiler in self.profilers:
            profiler.stop()

        if skip_exception is not None:
            # Restore the prior stats (after all stat writes have happened)
            fs.save_yaml(
                saved_stats, fs.Stats(state.cache_dir, state.build_name).file
            )

            # Advance the cursor below the monitor so
            # we can print a message
            self._advance_cursor(skipped_tool_name)
            printing.log_warning(str(skip_exception))
            return None

        if saved_exception is None:
            state.build_status = build.FunctionStatus.SUCCESSFUL
            state.save_stat(fs.Keys.BUILD_STATUS, build.FunctionStatus.SUCCESSFUL)
            if vars(state).get("invocation_info"):
                state.invocation_info.status_message = (
                    f"Successful build! {state.invocation_info.extra_status}"
                )
                state.invocation_info.status_message_color = printing.Colors.OKGREEN

        # Generate profiler output
        for profiler in self.profilers:
            profiler.generate_results(state, timestamp, start_times)

        if vars(state).get("models_found") and vars(state).get("invocation_info"):

            # Present status statistics from the tools
            for tool in self.tools:
                state.invocation_info.stats_keys += tool.status_stats

            # Present status statistics from the profilers
            for profiler in self.profilers:
                state.invocation_info.stats_keys += profiler.status_stats

            print()

            status.recursive_print(
                models_found=state.models_found,
                build_name=state.build_name,
                cache_dir=state.cache_dir,
                parent_model_hash=None,
                parent_invocation_hash=None,
                script_names_visited=[],
            )

        if lean_cache:
            printing.log_info("Removing build artifacts...")
            fs.clean_output_dir(state.cache_dir, state.build_name)

        state.save()

        if saved_exception is not None:
            raise saved_exception

        printing.log_success(
            f"\n    Saved to **{build.output_dir(state.cache_dir, state.build_name)}**"
        )

        return state

    def status_line(self, verbosity):
        """
        Print a status line in the monitor for every tool in the sequence
        """
        for tool in self.tools:
            tool.status_line(successful=None, verbosity=verbosity)

    @property
    def info(self) -> Dict[str, Dict]:
        """
        Return a dictionary of tool_name:argv for the sequence
        """

        return {tool.__class__.unique_name: argv for tool, argv in self.tools.items()}
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
# This file was originally licensed under Apache 2.0. It has been modified.
|
|
363
|
+
# Modifications Copyright (c) 2025 AMD
|
lemonade/state.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
from typing import Dict, Optional, Any
|
|
4
|
+
import yaml
|
|
5
|
+
import lemonade.common.build as build
|
|
6
|
+
import lemonade.common.filesystem as fs
|
|
7
|
+
from lemonade.version import __version__ as lemonade_version
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _is_nice_to_write(value):
|
|
11
|
+
"""
|
|
12
|
+
Checks whether a value is nice to write to YAML.
|
|
13
|
+
Returns True if the value is a string, int, float, bool, list, dict, or tuple.
|
|
14
|
+
Returns False otherwise.
|
|
15
|
+
"""
|
|
16
|
+
if isinstance(value, (str, int, float, bool)):
|
|
17
|
+
return True
|
|
18
|
+
elif isinstance(value, list) or isinstance(value, tuple):
|
|
19
|
+
# Check if all elements in the list are nice to write
|
|
20
|
+
return all(_is_nice_to_write(item) for item in value)
|
|
21
|
+
elif isinstance(value, dict):
|
|
22
|
+
# Check if all values in the dictionary are nice to write
|
|
23
|
+
return all(_is_nice_to_write(item) for item in value.values())
|
|
24
|
+
return False
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _sanitize_for_yaml(input_dict: Dict) -> Dict:
    """
    Return a new dict holding only the entries of ``input_dict`` whose values
    pass _is_nice_to_write(); the input is not modified.
    """
    return {key: value for key, value in input_dict.items() if _is_nice_to_write(value)}
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class State:
    """
    Carries build state from the user's initial arguments, through each build
    Tool in the Sequence, and finally to disk, where it is used to assess
    cache hits.

    A State starts out with the members shared by every build, filled with
    reasonable defaults.

    Tool developers may attach any additional members by plain attribute access:
        1. get: `my_variable = state.attribute_name`
        2. set: `state.attribute_name = my_variable`

    State.save() writes a state.yaml file and load_state() reads it back. Members
    of any type may be stored, but only YAML-safe ones (str, int, bool, float,
    list, dict, tuple) are persisted.
    """

    def __init__(
        self,
        cache_dir: str,
        build_name: Optional[str] = None,
        sequence_info: Dict[str, Dict] = None,
        **kwargs,
    ):
        # Without an explicit name, fall back to the file name of the running script
        resolved_name = (
            os.path.basename(sys.argv[0]) if build_name is None else build_name
        )

        # Expand "~" so home-relative cache directories work
        self.cache_dir = os.path.expanduser(cache_dir)
        self.build_name = resolved_name
        self.sequence_info = sequence_info
        self.lemonade_version = lemonade_version
        self.build_status = build.FunctionStatus.NOT_STARTED
        self.downcast_applied = False
        self.uid = build.unique_id()
        self.results = None

        # Extra keyword arguments become members as-is (written directly into
        # __dict__, and able to override the defaults assigned above)
        self.__dict__.update(kwargs)

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Plain attribute assignment, e.g. `state.new_member = value`, is how
        Tool developers add new members to State.
        """
        return super().__setattr__(name, value)

    def save_stat(self, key: str, value):
        """
        Record a statistic in this build's stats YAML file.
        """
        fs.Stats(self.cache_dir, self.build_name).save_stat(key, value)

    def save_sub_stat(self, parent_key: str, key: str, value):
        """
        Record a statistic nested under ``parent_key`` in this build's stats
        YAML file.
        """
        fs.Stats(self.cache_dir, self.build_name).save_sub_stat(parent_key, key, value)

    def save(self):
        """
        Write every YAML-friendly member to the build's state.yaml file.

        Members such as `model` and `inputs` are usually skipped, since they
        tend to hold non-YAML-friendly objects like `torch.nn.Module` or
        `torch.tensor`.
        """
        serializable = _sanitize_for_yaml(vars(self))

        # Make sure the build directory exists in the cache before writing
        fs.make_build_dir(self.cache_dir, self.build_name)

        target_path = build.state_file(self.cache_dir, self.build_name)
        with open(target_path, "w", encoding="utf8") as outfile:
            yaml.dump(serializable, outfile)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def load_state(
    cache_dir=None,
    build_name=None,
    state_path=None,
) -> State:
    """
    Read a state.yaml file corresponding to a specific build in a specific
    cache, and use its contents to initialize a State instance.

    Args:
        cache_dir: cache directory holding the build (used together with
            build_name).
        build_name: name of the build (used together with cache_dir).
        state_path: explicit path to a state.yaml file. When given it takes
            precedence, and cache_dir/build_name are ignored.

    Returns:
        A State built from the file's top-level keys.

    Raises:
        ValueError: if neither state_path nor both cache_dir and build_name are
            provided, or if the file does not contain a YAML mapping.
    """
    from collections.abc import Mapping

    if state_path is not None:
        file_path = state_path
    elif build_name is not None and cache_dir is not None:
        file_path = build.state_file(cache_dir, build_name)
    else:
        raise ValueError(
            "This function requires either build_name and cache_dir to be set, "
            "or state_path to be set"
        )

    state_dict = build.load_yaml(file_path)

    # An empty (None) or non-mapping document would otherwise fail below with
    # an opaque "argument after ** must be a mapping" TypeError
    if not isinstance(state_dict, Mapping):
        raise ValueError(
            f"Expected a YAML mapping in state file {file_path}, "
            f"got {type(state_dict).__name__}"
        )

    return State(**state_dict)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
# This file was originally licensed under Apache 2.0. It has been modified.
|
|
159
|
+
# Modifications Copyright (c) 2025 AMD
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .tool import Tool, FirstTool, NiceHelpFormatter
|