lemonade-sdk 9.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. lemonade/__init__.py +5 -0
  2. lemonade/api.py +180 -0
  3. lemonade/cache.py +92 -0
  4. lemonade/cli.py +173 -0
  5. lemonade/common/__init__.py +0 -0
  6. lemonade/common/build.py +176 -0
  7. lemonade/common/cli_helpers.py +139 -0
  8. lemonade/common/exceptions.py +98 -0
  9. lemonade/common/filesystem.py +368 -0
  10. lemonade/common/inference_engines.py +408 -0
  11. lemonade/common/network.py +93 -0
  12. lemonade/common/printing.py +110 -0
  13. lemonade/common/status.py +471 -0
  14. lemonade/common/system_info.py +1411 -0
  15. lemonade/common/test_helpers.py +28 -0
  16. lemonade/profilers/__init__.py +1 -0
  17. lemonade/profilers/agt_power.py +437 -0
  18. lemonade/profilers/hwinfo_power.py +429 -0
  19. lemonade/profilers/memory_tracker.py +259 -0
  20. lemonade/profilers/profiler.py +58 -0
  21. lemonade/sequence.py +363 -0
  22. lemonade/state.py +159 -0
  23. lemonade/tools/__init__.py +1 -0
  24. lemonade/tools/accuracy.py +432 -0
  25. lemonade/tools/adapter.py +114 -0
  26. lemonade/tools/bench.py +302 -0
  27. lemonade/tools/flm/__init__.py +1 -0
  28. lemonade/tools/flm/utils.py +305 -0
  29. lemonade/tools/huggingface/bench.py +187 -0
  30. lemonade/tools/huggingface/load.py +235 -0
  31. lemonade/tools/huggingface/utils.py +359 -0
  32. lemonade/tools/humaneval.py +264 -0
  33. lemonade/tools/llamacpp/bench.py +255 -0
  34. lemonade/tools/llamacpp/load.py +222 -0
  35. lemonade/tools/llamacpp/utils.py +1260 -0
  36. lemonade/tools/management_tools.py +319 -0
  37. lemonade/tools/mmlu.py +319 -0
  38. lemonade/tools/oga/__init__.py +0 -0
  39. lemonade/tools/oga/bench.py +120 -0
  40. lemonade/tools/oga/load.py +804 -0
  41. lemonade/tools/oga/migration.py +403 -0
  42. lemonade/tools/oga/utils.py +462 -0
  43. lemonade/tools/perplexity.py +147 -0
  44. lemonade/tools/prompt.py +263 -0
  45. lemonade/tools/report/__init__.py +0 -0
  46. lemonade/tools/report/llm_report.py +203 -0
  47. lemonade/tools/report/table.py +899 -0
  48. lemonade/tools/server/__init__.py +0 -0
  49. lemonade/tools/server/flm.py +133 -0
  50. lemonade/tools/server/llamacpp.py +320 -0
  51. lemonade/tools/server/serve.py +2123 -0
  52. lemonade/tools/server/static/favicon.ico +0 -0
  53. lemonade/tools/server/static/index.html +279 -0
  54. lemonade/tools/server/static/js/chat.js +1059 -0
  55. lemonade/tools/server/static/js/model-settings.js +183 -0
  56. lemonade/tools/server/static/js/models.js +1395 -0
  57. lemonade/tools/server/static/js/shared.js +556 -0
  58. lemonade/tools/server/static/logs.html +191 -0
  59. lemonade/tools/server/static/styles.css +2654 -0
  60. lemonade/tools/server/static/webapp.html +321 -0
  61. lemonade/tools/server/tool_calls.py +153 -0
  62. lemonade/tools/server/tray.py +664 -0
  63. lemonade/tools/server/utils/macos_tray.py +226 -0
  64. lemonade/tools/server/utils/port.py +77 -0
  65. lemonade/tools/server/utils/thread.py +85 -0
  66. lemonade/tools/server/utils/windows_tray.py +408 -0
  67. lemonade/tools/server/webapp.py +34 -0
  68. lemonade/tools/server/wrapped_server.py +559 -0
  69. lemonade/tools/tool.py +374 -0
  70. lemonade/version.py +1 -0
  71. lemonade_install/__init__.py +1 -0
  72. lemonade_install/install.py +239 -0
  73. lemonade_sdk-9.1.1.dist-info/METADATA +276 -0
  74. lemonade_sdk-9.1.1.dist-info/RECORD +84 -0
  75. lemonade_sdk-9.1.1.dist-info/WHEEL +5 -0
  76. lemonade_sdk-9.1.1.dist-info/entry_points.txt +5 -0
  77. lemonade_sdk-9.1.1.dist-info/licenses/LICENSE +201 -0
  78. lemonade_sdk-9.1.1.dist-info/licenses/NOTICE.md +47 -0
  79. lemonade_sdk-9.1.1.dist-info/top_level.txt +3 -0
  80. lemonade_server/cli.py +805 -0
  81. lemonade_server/model_manager.py +758 -0
  82. lemonade_server/pydantic_models.py +159 -0
  83. lemonade_server/server_models.json +643 -0
  84. lemonade_server/settings.py +39 -0
@@ -0,0 +1,58 @@
1
+ import abc
2
+
3
+
4
class Profiler(abc.ABC):
    """
    Base class for profilers that gather data while a tool sequence runs.

    Hooks, in the order their docstrings describe them being called:
    ``start`` before the sequence, ``tool_starting`` / ``tool_stopping`` around
    each tool, ``stop`` after the sequence, and finally ``generate_results``.
    Subclasses must implement ``add_arguments_to_parser``, ``start`` and
    ``generate_results``; the remaining hooks are optional no-ops.
    """

    # Identifier of the profiler; per the add_arguments_to_parser contract it
    # is also used to build the CLI flag f"--{unique_name}". Set by subclasses.
    unique_name: str

    def __init__(self, parser_arg_value=None):
        # NOTE(review): presumably the value parsed for this profiler's CLI
        # argument - confirm against the code that instantiates profilers.
        self.parser_arg_value = parser_arg_value
        # Statistics that will be displayed to the CLI user
        self.status_stats = []

    @staticmethod
    @abc.abstractmethod
    def add_arguments_to_parser(parser):
        """
        Register this profiler's command-line argument on ``parser``.

        Implementations should use f"--{unique_name}" as the argument name.
        """

    @abc.abstractmethod
    def start(self, build_dir):
        """
        Begin gathering data. Called once, before the tool sequence starts.

        ``build_dir`` is a directory where profiling data may be stored.
        """

    def tool_starting(self, tool_name):
        """
        Optional hook: the tool named ``tool_name`` is about to run.
        The default implementation does nothing.
        """

    def tool_stopping(self):
        """
        Optional hook: the tool that was running has finished.
        The default implementation does nothing.
        """

    def stop(self):
        """
        Optional hook: the tool sequence has finished, so stop gathering data.
        The default implementation does nothing.
        """

    @abc.abstractmethod
    def generate_results(self, state, timestamp, start_times):
        """
        Write this profiler's output files.

        Args:
            state: build state, used to gather build info and to write stats.
            timestamp: string that may be used to name files in the current
                working directory.
            start_times: dict mapping each tool name to the time that tool
                started. It also holds a "warmup" entry (a time before the
                first tool started) and a "cool down" entry (the time the last
                tool ended).
        """
56
+
57
+
58
+ # Copyright (c) 2025 AMD
lemonade/sequence.py ADDED
@@ -0,0 +1,363 @@
1
+ import sys
2
+ import time
3
+ import os
4
+ import platform
5
+ import copy
6
+ from datetime import datetime
7
+ from typing import List, Dict, Optional
8
+ import pytz
9
+ import psutil
10
+ import lemonade.common.printing as printing
11
+ import lemonade.common.exceptions as exp
12
+ import lemonade.common.build as build
13
+ from lemonade.common.system_info import get_system_info_dict
14
+ import lemonade.common.filesystem as fs
15
+ import lemonade.common.status as status
16
+ from lemonade.tools.tool import Tool
17
+ from lemonade.profilers.profiler import Profiler
18
+ from lemonade.state import State
19
+
20
+
21
+ def _rewind_stdout(lines: int = 1):
22
+ """
23
+ Helper function for the command line monitor. Moves the cursor up a
24
+ certain number of lines in the terminal, corresponding to the
25
+ status line for a Tool, so that we can update the status of
26
+ that Tool.
27
+ """
28
+ rewind_stdout_one_line = "\033[1A"
29
+ rewind_multiple_lines = rewind_stdout_one_line * lines
30
+ print(rewind_multiple_lines, end="")
31
+ sys.stdout.flush()
32
+
33
+
34
class Sequence:
    """
    Helper class to launch and manage build tools.

    A Sequence holds an ordered mapping of Tool instances to the argv list each
    tool is run with, plus an optional list of Profilers that are started
    before the first tool and stopped after the last one.
    """

    def __init__(
        self,
        tools: Dict[Tool, List[str]],
        profilers: Optional[List[Profiler]] = None,
    ):
        """
        Args:
            tools: mapping of Tool instance -> argv for that tool. Tools run in
                the mapping's iteration order.
            profilers: profilers that observe the run. Defaults to none.

        Raises:
            ValueError: if two tools share the same ``unique_name``.
        """

        self.tools = tools
        self.profilers = [] if profilers is None else profilers

        # Make sure all the tool names are unique
        self.tool_names = [tool.__class__.unique_name for tool in self.tools.keys()]

        if len(self.tool_names) != len(set(self.tool_names)):
            msg = f"""
            All tools in a Sequence must have unique unique_names, however Sequence
            received duplicates in the list of names: {self.tool_names}
            """
            raise ValueError(msg)

        # Save the process (used to get memory usage)
        self.process = psutil.Process()

    def show_monitor(self, state: State, verbosity: bool):
        """
        Displays the monitor on the terminal. The purpose of the monitor
        is to show the status of each tool (success, failure, not started yet,
        or in-progress).

        Nothing is printed when ``verbosity`` is False. In particular, the cursor
        is only rewound when the status lines were actually printed.
        """

        if verbosity:
            print()

            printing.logn(
                f'Building "{state.build_name}"',
                c=printing.Colors.BOLD,
            )

            for tool in self.tools:
                tool.status_line(successful=None, verbosity=True)

            # Jump back up to the first status line so tools can redraw it
            _rewind_stdout(len(self.tools))

    def _advance_cursor(self, current_tool_name: str):
        """
        Advance the cursor below the monitor so that a message can be printed
        without overwriting the status lines of the remaining tools.
        """
        tool_depth_in_sequence = len(self.tool_names) - self.tool_names.index(
            current_tool_name
        )
        # A negative count yields "" (str * negative == ""), so the print()
        # below always emits at least one newline.
        stdout_lines_to_advance = tool_depth_in_sequence - 2
        cursor_down = "\n" * stdout_lines_to_advance

        print(cursor_down)

    def _get_mem_usage_str(self) -> str:
        """
        Returns a string with memory usage for the current process
        (non-swapped physical memory). In Windows OS, the peak memory used in the
        process is also included.

        Example: '1.100 GB (1.638 GB peak)'
        """
        mem_info = self.process.memory_info()
        mem_info_str = f"{mem_info.rss / 1024 ** 3:,.3f} GB"
        if platform.system() == "Windows":
            # peak_wset is only available from psutil on Windows
            mem_info_str += f" ({mem_info.peak_wset / 1024 ** 3:,.3f} GB peak)"
        return mem_info_str

    def _save_initial_stats(self, state: State, stats_to_save: Optional[Dict]):
        """Write the build-level and per-tool stats recorded before any tool runs."""

        # Save build name to stats so it shows up on reports
        state.save_stat(fs.Keys.BUILD_NAME, state.build_name)

        # Indicate that the build is running. If the build fails for any reason,
        # we will try to catch the exception and note it in the stats.
        # If a concluded build still has this INCOMPLETE status, this means
        # there was an uncaught exception.
        state.save_stat(fs.Keys.BUILD_STATUS, build.FunctionStatus.INCOMPLETE)

        # Save a timestamp so that we know the order of builds within a cache
        pacific_tz = pytz.timezone("America/Los_Angeles")
        state.save_stat(fs.Keys.TIMESTAMP, datetime.now(pacific_tz))

        # Save the system information used for this build
        state.save_stat(fs.Keys.SYSTEM_INFO, get_system_info_dict())

        # Collect telemetry for the build
        state.save_stat(fs.Keys.SELECTED_SEQUENCE_OF_TOOLS, self.tool_names)

        # At the beginning of a sequence no tool has started
        for tool in self.tools:
            state.save_stat(tool.status_key, build.FunctionStatus.NOT_STARTED)
            state.save_stat(tool.duration_key, "-")
            state.save_stat(tool.memory_key, "-")

        # Save any additional stats passed in via arguments
        if stats_to_save:
            for stat_key, stat_value in stats_to_save.items():
                state.save_stat(stat_key, stat_value)

        # Save initial memory as a build statistic
        state.save_stat(f"{fs.Keys.TOOL_MEMORY}:__init__", self._get_mem_usage_str())

    def launch(
        self,
        state: State,
        lean_cache: bool = False,
        monitor: Optional[bool] = None,
        stats_to_save: Optional[Dict] = None,
    ) -> Optional[State]:
        """
        Executes the sequence of tools.

        Args:
            state: build state given to the first tool; each tool returns the
                state that is handed to the next one.
            lean_cache: if True, remove build artifacts from the build's output
                directory once the sequence finishes.
            monitor: whether to show the terminal monitor. None means the
                monitor is on unless the LEMONADE_BUILD_MONITOR environment
                variable is exactly "False".
            stats_to_save: extra key/value pairs written to the build stats
                before any tool runs.

        Returns:
            The final State on success. None if a tool raised exp.SkipBuild; in
            that case the stats file is restored to its pre-launch contents and
            state.save() is not called.

        Raises:
            exp.Error: if ``state`` already reports a successful build.
            Exception: the first exception raised by a tool is re-raised after
                stats, profiler results and the state have been saved (or
                immediately, when LEMONADE_DEBUG is "true").
        """

        current_time = datetime.now()
        timestamp = current_time.strftime("%Y-%m-%d-%H%M%S")
        start_times = {"warmup": time.time()}

        # Allow monitor to be globally disabled by an environment variable
        if monitor is None:
            monitor_setting = os.environ.get("LEMONADE_BUILD_MONITOR") != "False"
        else:
            monitor_setting = monitor

        # Validate before starting profilers or drawing the monitor, so that
        # this error cannot leave profilers running
        if state.build_status == build.FunctionStatus.SUCCESSFUL:
            msg = """
            build_model() is running a build on a model that already built successfully, which
            should not happen because the build should have loaded from cache or rebuilt from scratch.
            If you are using custom tools and Sequences then you have some debugging to do. Otherwise,
            please file an issue at https://github.com/lemonade-sdk/lemonade/issues
            """
            raise exp.Error(msg)

        # Create a build directory in the cache
        fs.make_build_dir(state.cache_dir, state.build_name)

        # Start profilers
        build_dir = build.output_dir(state.cache_dir, state.build_name)
        for profiler in self.profilers:
            profiler.start(build_dir)

        self.show_monitor(state, monitor_setting)

        # Keep a copy of any stats we loaded from disk, in case we need to
        # restore them later
        saved_stats = copy.deepcopy(fs.Stats(state.cache_dir, state.build_name).stats)

        self._save_initial_stats(state, stats_to_save)

        # Run the build
        saved_exception = None
        skip_exception = None
        skipped_tool_name = None
        for tool, argv in self.tools.items():

            start_time = time.time()
            start_times[tool.unique_name] = start_time

            # Inform profiler of name of tool about to start
            for profiler in self.profilers:
                profiler.tool_starting(tool.unique_name)

            try:

                # Set status as incomplete, since tool just started
                state.save_stat(tool.status_key, build.FunctionStatus.INCOMPLETE)

                # Collect telemetry about the tool
                state.current_build_tool = tool.unique_name

                # Run the tool
                state = tool.parse_and_run(state, argv, monitor_setting)

                # Save the state so that it can be assessed for a cache hit
                state.save()

            except exp.SkipBuild as e:
                # SkipBuild is a special exception, which means that a build
                # was loaded from disk, then we realized we want to skip it.
                # The original stats are restored after the loop rather than
                # here: the `finally` clause below writes duration/memory stats,
                # and restoring here would let it modify the restored file.
                skip_exception = e
                skipped_tool_name = tool.unique_name
                break

            # Broad exception is desirable as we want to capture
            # all exceptions (including those we can't anticipate)
            except Exception as e:  # pylint: disable=broad-except

                if os.environ.get("LEMONADE_DEBUG", "").lower() == "true":
                    # It may be useful to raise the exception here, since
                    # if any of the subsequent lines of code raise another
                    # exception it will be very hard to root cause e.
                    raise

                # Update tool and build status
                state.save_stat(tool.status_key, build.FunctionStatus.ERROR)
                state.save_stat(fs.Keys.BUILD_STATUS, build.FunctionStatus.ERROR)

                # Save the log file for the failed tool to stats for easy reference
                stats = fs.Stats(state.cache_dir, state.build_name)
                stats.save_eval_error_log(tool.logfile_path)

                # Advance the cursor below the monitor so
                # we can print a message
                self._advance_cursor(tool.unique_name)

                if vars(state).get("invocation_info"):
                    state.invocation_info.status_message = f"Error: {e}"
                    state.invocation_info.status_message_color = (
                        printing.Colors.WARNING
                    )
                else:
                    printing.log_error(e)

                # We will raise this exception after we capture as many statistics
                # about the build as possible
                saved_exception = e

                # Don't run any more tools
                break

            else:
                # Update tool Status
                state.save_stat(tool.status_key, build.FunctionStatus.SUCCESSFUL)
                state.current_build_tool = None

            finally:
                # Store tool duration
                execution_time = time.time() - start_time
                state.save_stat(tool.duration_key, execution_time)

                # Store current memory and peak working memory
                state.save_stat(tool.memory_key, self._get_mem_usage_str())

                # Inform profilers that tool has finished
                for profiler in self.profilers:
                    profiler.tool_stopping()

        start_times["cool down"] = time.time()

        # Tell the profilers to stop gathering data (also on the skip path,
        # which previously returned without ever stopping them)
        for profiler in self.profilers:
            profiler.stop()

        if skip_exception is not None:
            # Restore the prior stats so the original stats and state of the
            # build are preserved. state.save() is deliberately not called.
            fs.save_yaml(
                saved_stats, fs.Stats(state.cache_dir, state.build_name).file
            )

            # Advance the cursor below the monitor so
            # we can print a message
            self._advance_cursor(skipped_tool_name)
            printing.log_warning(str(skip_exception))
            return None

        if saved_exception is None:
            state.build_status = build.FunctionStatus.SUCCESSFUL
            state.save_stat(fs.Keys.BUILD_STATUS, build.FunctionStatus.SUCCESSFUL)
            if vars(state).get("invocation_info"):
                state.invocation_info.status_message = (
                    f"Successful build! {state.invocation_info.extra_status}"
                )
                state.invocation_info.status_message_color = printing.Colors.OKGREEN

        # Generate profiler output
        for profiler in self.profilers:
            profiler.generate_results(state, timestamp, start_times)

        if vars(state).get("models_found") and vars(state).get("invocation_info"):

            # Present status statistics from the tools
            for tool in self.tools:
                state.invocation_info.stats_keys += tool.status_stats

            # Present status statistics from the profilers
            for profiler in self.profilers:
                state.invocation_info.stats_keys += profiler.status_stats

            print()

            status.recursive_print(
                models_found=state.models_found,
                build_name=state.build_name,
                cache_dir=state.cache_dir,
                parent_model_hash=None,
                parent_invocation_hash=None,
                script_names_visited=[],
            )

        if lean_cache:
            printing.log_info("Removing build artifacts...")
            fs.clean_output_dir(state.cache_dir, state.build_name)

        state.save()

        if saved_exception is not None:
            raise saved_exception

        printing.log_success(
            f"\n Saved to **{build.output_dir(state.cache_dir, state.build_name)}**"
        )

        return state

    def status_line(self, verbosity):
        """
        Print a status line in the monitor for every tool in the sequence
        """
        for tool in self.tools:
            tool.status_line(successful=None, verbosity=verbosity)

    @property
    def info(self) -> Dict[str, List[str]]:
        """
        Return a dictionary of tool_name:argv for the sequence
        """

        return {tool.__class__.unique_name: argv for tool, argv in self.tools.items()}
360
+
361
+
362
+ # This file was originally licensed under Apache 2.0. It has been modified.
363
+ # Modifications Copyright (c) 2025 AMD
lemonade/state.py ADDED
@@ -0,0 +1,159 @@
1
+ import os
2
+ import sys
3
+ from typing import Dict, Optional, Any
4
+ import yaml
5
+ import lemonade.common.build as build
6
+ import lemonade.common.filesystem as fs
7
+ from lemonade.version import __version__ as lemonade_version
8
+
9
+
10
+ def _is_nice_to_write(value):
11
+ """
12
+ Checks whether a value is nice to write to YAML.
13
+ Returns True if the value is a string, int, float, bool, list, dict, or tuple.
14
+ Returns False otherwise.
15
+ """
16
+ if isinstance(value, (str, int, float, bool)):
17
+ return True
18
+ elif isinstance(value, list) or isinstance(value, tuple):
19
+ # Check if all elements in the list are nice to write
20
+ return all(_is_nice_to_write(item) for item in value)
21
+ elif isinstance(value, dict):
22
+ # Check if all values in the dictionary are nice to write
23
+ return all(_is_nice_to_write(item) for item in value.values())
24
+ return False
25
+
26
+
27
+ def _sanitize_for_yaml(input_dict: Dict) -> Dict:
28
+ """
29
+ Creates a new dictionary containing only nice-to-write values
30
+ from the original dictionary.
31
+ """
32
+ result = {}
33
+ for key, value in input_dict.items():
34
+ if _is_nice_to_write(value):
35
+ result[key] = value
36
+ return result
37
+
38
+
39
class State:
    """
    Carries build state from the user's initial arguments, through every
    build Tool in the Sequence, and finally to disk, where it is used to
    assess cache hits.

    The constructor sets the members every build shares, with reasonable
    defaults where appropriate.

    Tool developers may attach any extra member with plain attribute syntax:
        1. get: `my_variable = state.attribute_name`
        2. set: `state.attribute_name = my_variable`

    State is persisted as a state.yaml file by State.save() and read back by
    load_state(). Members of any type may be stored, but only YAML-friendly
    ones (str, int, bool, float, list, dict, tuple) are written to disk.
    """

    def __init__(
        self,
        cache_dir: str,
        build_name: Optional[str] = None,
        sequence_info: Dict[str, Dict] = None,
        **kwargs,
    ):
        # "~" in cache_dir is expanded to the user's home directory. When no
        # build_name is given, the basename of the running script is used.
        self.cache_dir = os.path.expanduser(cache_dir)
        self.build_name = (
            os.path.basename(sys.argv[0]) if build_name is None else build_name
        )
        self.sequence_info = sequence_info
        self.lemonade_version = lemonade_version
        self.build_status = build.FunctionStatus.NOT_STARTED
        self.downcast_applied = False
        self.uid = build.unique_id()
        self.results = None

        # Extra keyword arguments become members too. They are applied last,
        # so values loaded from disk override the defaults above.
        self.__dict__.update(kwargs)

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Plain attribute assignment; kept explicit to document that tools may
        add new members simply via `state.new_member = value`.
        """
        return super().__setattr__(name, value)

    def save_stat(self, key: str, value):
        """
        Record a statistic under ``key`` in the build's stats YAML file.
        """
        fs.Stats(self.cache_dir, self.build_name).save_stat(key, value)

    def save_sub_stat(self, parent_key: str, key: str, value):
        """
        Record a statistic under ``key``, nested below ``parent_key``, in the
        build's stats YAML file.
        """
        fs.Stats(self.cache_dir, self.build_name).save_sub_stat(parent_key, key, value)

    def save(self):
        """
        Write every YAML-friendly member to the build's state.yaml file.

        Members such as `model` and `inputs` are usually skipped, because they
        typically hold non-YAML types (e.g. `torch.nn.Module`, `torch.tensor`).
        """
        state_to_save = _sanitize_for_yaml(vars(self))

        # The build directory must exist before the state file can be written
        fs.make_build_dir(self.cache_dir, self.build_name)

        state_path = build.state_file(self.cache_dir, self.build_name)
        with open(state_path, "w", encoding="utf8") as outfile:
            yaml.dump(state_to_save, outfile)
131
+
132
+
133
def load_state(
    cache_dir=None,
    build_name=None,
    state_path=None,
) -> State:
    """
    Read a state.yaml file corresponding to a specific build in a specific
    cache, and use its contents to initialize a State instance.

    Args:
        cache_dir: cache directory containing the build; used with build_name.
        build_name: name of the build whose state.yaml should be read.
        state_path: explicit path to a state.yaml file. When given, it takes
            precedence and cache_dir/build_name are ignored.

    Returns:
        A State whose members are the key/value pairs stored in the file.

    Raises:
        ValueError: if state_path is None and cache_dir and build_name are not
            both set.
    """

    if state_path is not None:
        file_path = state_path
    elif build_name is not None and cache_dir is not None:
        file_path = build.state_file(cache_dir, build_name)
    else:
        # The previous message claimed that passing both forms was an error,
        # but state_path has always silently taken precedence; describe the
        # actual requirement instead.
        raise ValueError(
            "This function requires either state_path to be set, or both "
            "build_name and cache_dir to be set"
        )

    state_dict = build.load_yaml(file_path)

    return State(**state_dict)
156
+
157
+
158
+ # This file was originally licensed under Apache 2.0. It has been modified.
159
+ # Modifications Copyright (c) 2025 AMD
@@ -0,0 +1 @@
1
+ from .tool import Tool, FirstTool, NiceHelpFormatter