lemonade-sdk 7.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lemonade-sdk might be problematic. Click here for more details.

Files changed (61) hide show
  1. lemonade/__init__.py +5 -0
  2. lemonade/api.py +125 -0
  3. lemonade/cache.py +85 -0
  4. lemonade/cli.py +135 -0
  5. lemonade/common/__init__.py +0 -0
  6. lemonade/common/analyze_model.py +26 -0
  7. lemonade/common/build.py +223 -0
  8. lemonade/common/cli_helpers.py +139 -0
  9. lemonade/common/exceptions.py +98 -0
  10. lemonade/common/filesystem.py +368 -0
  11. lemonade/common/labels.py +61 -0
  12. lemonade/common/onnx_helpers.py +176 -0
  13. lemonade/common/plugins.py +10 -0
  14. lemonade/common/printing.py +110 -0
  15. lemonade/common/status.py +490 -0
  16. lemonade/common/system_info.py +390 -0
  17. lemonade/common/tensor_helpers.py +83 -0
  18. lemonade/common/test_helpers.py +28 -0
  19. lemonade/profilers/__init__.py +1 -0
  20. lemonade/profilers/memory_tracker.py +257 -0
  21. lemonade/profilers/profiler.py +55 -0
  22. lemonade/sequence.py +363 -0
  23. lemonade/state.py +159 -0
  24. lemonade/tools/__init__.py +1 -0
  25. lemonade/tools/adapter.py +104 -0
  26. lemonade/tools/bench.py +284 -0
  27. lemonade/tools/huggingface_bench.py +267 -0
  28. lemonade/tools/huggingface_load.py +520 -0
  29. lemonade/tools/humaneval.py +258 -0
  30. lemonade/tools/llamacpp.py +261 -0
  31. lemonade/tools/llamacpp_bench.py +154 -0
  32. lemonade/tools/management_tools.py +273 -0
  33. lemonade/tools/mmlu.py +327 -0
  34. lemonade/tools/ort_genai/__init__.py +0 -0
  35. lemonade/tools/ort_genai/oga.py +1129 -0
  36. lemonade/tools/ort_genai/oga_bench.py +142 -0
  37. lemonade/tools/perplexity.py +146 -0
  38. lemonade/tools/prompt.py +228 -0
  39. lemonade/tools/quark/__init__.py +0 -0
  40. lemonade/tools/quark/quark_load.py +172 -0
  41. lemonade/tools/quark/quark_quantize.py +439 -0
  42. lemonade/tools/report/__init__.py +0 -0
  43. lemonade/tools/report/llm_report.py +203 -0
  44. lemonade/tools/report/table.py +739 -0
  45. lemonade/tools/server/__init__.py +0 -0
  46. lemonade/tools/server/serve.py +1354 -0
  47. lemonade/tools/server/tool_calls.py +146 -0
  48. lemonade/tools/tool.py +374 -0
  49. lemonade/version.py +1 -0
  50. lemonade_install/__init__.py +1 -0
  51. lemonade_install/install.py +774 -0
  52. lemonade_sdk-7.0.0.dist-info/METADATA +116 -0
  53. lemonade_sdk-7.0.0.dist-info/RECORD +61 -0
  54. lemonade_sdk-7.0.0.dist-info/WHEEL +5 -0
  55. lemonade_sdk-7.0.0.dist-info/entry_points.txt +4 -0
  56. lemonade_sdk-7.0.0.dist-info/licenses/LICENSE +201 -0
  57. lemonade_sdk-7.0.0.dist-info/licenses/NOTICE.md +21 -0
  58. lemonade_sdk-7.0.0.dist-info/top_level.txt +3 -0
  59. lemonade_server/cli.py +260 -0
  60. lemonade_server/model_manager.py +98 -0
  61. lemonade_server/server_models.json +142 -0
@@ -0,0 +1,490 @@
1
+ import os
2
+ import platform
3
+ import shutil
4
+ import sys
5
+ import math
6
+ import dataclasses
7
+ from typing import Callable, List, Union, Dict, Optional
8
+ import textwrap
9
+ import psutil
10
+ import torch
11
+ from lemonade.common import printing
12
+ from lemonade.state import State
13
+ import lemonade.common.build as build
14
+ import lemonade.common.filesystem as fs
15
+ import lemonade.common.analyze_model as analyze_model
16
+
17
+
18
+ def _pretty_print_key(key: str) -> str:
19
+ result = key.split("_")
20
+ result = [word.capitalize() for word in result]
21
+ result = " ".join(result)
22
+ return result
23
+
24
+
25
class PrettyFloat(float):
    """Float subclass whose repr always shows exactly three decimal places."""

    def __repr__(self):
        # Same format spec as the f-string form "{self:0.3f}"
        return format(self, "0.3f")
28
+
29
+
30
def parameters_to_size(parameters: int, byte_per_parameter: int = 4) -> str:
    """
    Convert a parameter count into a human-readable size string.

    Args:
        parameters: number of parameters.
        byte_per_parameter: number of bytes used to store one parameter.

    Returns:
        "0B" when the total size is zero; otherwise "<value> <unit>", where
        value is the size rounded to two decimals and unit is the largest
        power-of-1024 suffix (B ... YB) that does not exceed the size.
        Sizes beyond the YB range are reported in YB.
    """
    size_bytes = parameters * byte_per_parameter
    if size_bytes == 0:
        return "0B"
    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")

    # Pick the unit by direct comparison instead of floor(log(size, 1024)):
    # the logarithm raises on negative sizes, yields index -1 (i.e. "YB") for
    # sizes below one byte, overruns the table past YB, and is subject to
    # floating-point error at exact powers of 1024.
    magnitude = abs(size_bytes)
    i = 0
    while i < len(size_name) - 1 and magnitude >= 1024 ** (i + 1):
        i += 1

    p = math.pow(1024, i)
    s = round(size_bytes / p, 2)
    return f"{s} {size_name[i]}"
39
+
40
+
41
@dataclasses.dataclass
class BasicInfo:
    """
    Fields common to the model records in this module.

    `name` and `script_name` are required; every other field has a default.
    """

    name: str
    script_name: str
    file: str = ""
    line: int = 0
    params: int = 0
    depth: int = 0
    parent_hash: Optional[str] = None
    model_class: type = None
    # NOTE: `hash` holds the "model hash"; do not mix it up with the
    # "invocation hash"
    hash: Optional[str] = None
54
+
55
+
56
@dataclasses.dataclass
class SkipFields:
    """
    Per-field flags recording which parts of a model's status printout
    should be skipped. Two use cases are intended:
     - Incremental printout: a field is flagged once it has been printed.
     - Low verbosity: a field is flagged so that it is never printed.

    `previous_status_message` stores the last status message that was
    printed, so an identical message is not printed twice.
    """

    file_name: bool = False
    model_name: bool = False
    parameters: bool = False
    location: bool = False
    input_shape: bool = False
    build_dir: bool = False
    unique_input_shape: bool = False
    previous_status_message: Optional[str] = None
73
+
74
+
75
@dataclasses.dataclass
class UniqueInvocationInfo(BasicInfo):
    """
    Refers to unique static model invocations
    (i.e. models executed with unique input shapes)

    In addition to the invocation data, instances carry the state used while
    printing a status report: `skip` (which fields to omit), `extension`, and
    `indent` (recomputed on every call to `print()`).
    """

    invocation_hash: Optional[str] = None
    traceback: Optional[List[str]] = None
    inputs: Optional[dict] = None
    input_shapes: Optional[dict] = None
    executed: int = 0
    exec_time: float = 0.0
    status_message: str = ""
    extra_status: Optional[str] = ""
    is_target: bool = False
    auto_selected: bool = False
    status_message_color: printing.Colors = printing.Colors.ENDC
    traceback_message_color: printing.Colors = printing.Colors.FAIL
    stats_keys: List[str] = dataclasses.field(default_factory=list)
    forward_function_pointer: Optional[Callable] = None
    original_forward_function: Optional[Callable] = None
    # Fields specific to printing status
    skip: Optional[SkipFields] = None
    extension: Optional[str] = None
    indent: Optional[str] = None

    def __post_init__(self):
        # Honor a SkipFields instance supplied by the caller; only create a
        # fresh one when none was given.
        if self.skip is None:
            self.skip = SkipFields()

    def _print_heading(
        self,
        exec_time_formatted: str,
        print_file_name: bool,
        model_visited: bool,
        multiple_unique_invocations: bool,
    ):
        """
        Print the file-name line and/or the model-name line, unless either
        has already been flagged as skipped. Flags both as skipped afterwards.
        """
        if self.skip.file_name or self.skip.model_name:
            return

        if print_file_name:
            print(f"{self.script_name}{self.extension}:")

        # Print invocation about the model (only applies to scripts, not ONNX files or
        # LLMs, which have no extension)
        if self.extension not in (".onnx", build.state_file_name, ""):
            if self.depth == 0 and multiple_unique_invocations:
                # With several unique invocations the execution count is not
                # printed here, and the name is printed only on the first visit
                if not model_visited:
                    printing.logn(f"{self.indent}{self.name}")
            else:
                printing.log(f"{self.indent}{self.name}")
                printing.logn(
                    f" (executed {self.executed}x{exec_time_formatted})",
                    c=printing.Colors.OKGREEN,
                )

        self.skip.file_name = True
        self.skip.model_name = True

    def _print_location(self):
        """Print the file (and line, for non-ONNX) of a depth-0 model."""
        if self.skip.location or self.file == "":
            return

        if self.depth == 0:
            print(f"{self.indent}\tLocation:\t{self.file}", end="")
            if self.extension == ".onnx":
                # No line number is printed for ONNX files; just end the line
                print()
            else:
                print(f", line {self.line}")
            self.skip.location = True

    def _print_parameters(self):
        """Print the parameter count along with its human-readable size."""
        if self.skip.parameters or self.params is None:
            return

        # Display number of parameters and size
        parameters_size = parameters_to_size(self.params)
        print(f"{self.indent}\tParameters:\t{self.params:,} ({parameters_size})")

        self.skip.parameters = True

    def _print_unique_input_shape(
        self,
        exec_time_formatted: str,
        invocation_idx: int,
        multiple_unique_invocations: bool,
    ):
        """
        For depth-0 models with several unique invocations, print a
        sub-heading giving the 1-based invocation index and execution count.
        """
        if self.skip.unique_input_shape:
            return

        if self.depth == 0 and multiple_unique_invocations:
            printing.logn(
                f"\n{self.indent}\tWith input shape {invocation_idx+1} "
                f"(executed {self.executed}x{exec_time_formatted})",
                c=printing.Colors.OKGREEN,
            )

        self.skip.unique_input_shape = True

    def _print_input_shape(self):
        """Print the input shapes, omitting entries whose shape is `()`."""
        if self.skip.input_shape or self.input_shapes is None:
            return

        # Prepare input shape to be printed
        input_shape = {
            key: value for key, value in dict(self.input_shapes).items() if value != ()
        }
        input_shape = str(input_shape).replace("{", "").replace("}", "")

        print(f"{self.indent}\tInput Shape:\t{input_shape}")

        self.skip.input_shape = True

    def _print_build_dir(self, cache_dir: str, build_name: str):
        """Print the build output directory; only done for the target model."""
        if self.skip.build_dir or not self.is_target:
            return

        print(f"{self.indent}\tBuild dir:\t{build.output_dir(cache_dir, build_name)}")

        self.skip.build_dir = True

    def _print_peak_memory(self):
        """Print the current process's peak working set in GB (Windows only)."""
        if platform.system() == "Windows":
            print(
                f"{self.indent}\tPeak memory:\t"
                f"{psutil.Process().memory_info().peak_wset / 1024**3:,.3f} GB"
            )

    def _print_status(self, cache_dir: str, build_name: str):
        """
        Print the status message, the stats listed in `stats_keys` (target
        model only), and the traceback if one was recorded.

        Nothing is printed when the status message is identical to the one
        printed previously.
        """
        stats = fs.Stats(cache_dir, build_name)
        if self.skip.previous_status_message:
            if self.skip.previous_status_message == self.status_message:
                # This is a special case for skipping: we only want to skip
                # printing the outcome if we have already printed that
                # exact message already.
                return

            # Print some whitespace to help the status stand out
            print()

        printing.log(f"{self.indent}\tStatus:\t\t")
        printing.logn(
            f"{self.status_message}",
            c=self.status_message_color,
        )

        if self.is_target:
            # Get the maximum key length to figure out the amount of
            # padding needed to align the values
            max_key_len = max(
                (len(_pretty_print_key(key)) for key in self.stats_keys),
                default=0,
            )

            screen_width = shutil.get_terminal_size().columns
            wrap_screen_width = screen_width - 2

            for key in self.stats_keys:
                # Only the lookup itself is guarded, so that an unrelated
                # KeyError raised while formatting is not silently swallowed
                try:
                    value = stats.stats[key]
                except KeyError:
                    # Ignore any keys that are missing because that means the
                    # evaluation did not produce them
                    continue

                nice_key = _pretty_print_key(key)
                if isinstance(value, float):
                    value = PrettyFloat(value)
                elif isinstance(value, list):
                    value = [
                        PrettyFloat(v) if isinstance(v, float) else v for v in value
                    ]

                # Tools may provide a unit of measurement for their status
                # stats, whose key name should follow the format
                # "STATUS_STATS_KEY_units"
                units = stats.stats.get(key + "_units")
                units = units if units is not None else ""

                if self.extension != "":
                    printing.logn(f"{self.indent}\t\t\t{nice_key}:\t{value} {units}")
                    continue

                # No extension: align values in a column and wrap long values
                # with a hanging indent under the value column
                value_tabs = " " * ((max_key_len - len(nice_key)) + 1)
                hanging_indent = (
                    len(self.indent) + 8 + len(nice_key) + 1 + len(value_tabs)
                )
                hanging_indent_str = " " * hanging_indent

                if (
                    isinstance(value, list)
                    and len(value) > 0
                    and all(isinstance(item, str) for item in value)
                ):
                    # Value is a list of strings, so output each one starting
                    # on its own line
                    printing.logn(f"{self.indent}\t{nice_key}:{value_tabs}[")
                    for line_counter, text in enumerate(value):
                        lines = textwrap.wrap(
                            "'" + text + "'",
                            width=wrap_screen_width,
                            initial_indent=hanging_indent_str,
                            subsequent_indent=hanging_indent_str,
                        )
                        if line_counter + 1 < len(value):
                            # Not the last text item in the list, so add a comma
                            lines[-1] = lines[-1] + ","
                        for line in lines:
                            printing.logn(line)
                    printing.logn(f"{' ' * hanging_indent}] {units}")
                else:
                    # Wrap value as needed
                    status_str = (
                        f"{self.indent}\t{nice_key}:{value_tabs}{value} {units}"
                    )
                    lines = textwrap.wrap(
                        status_str,
                        width=wrap_screen_width,
                        subsequent_indent=hanging_indent_str,
                    )
                    for line in lines:
                        printing.logn(line)

        if self.traceback is not None:
            if os.environ.get("LEMONADE_TRACEBACK") != "False":
                for line in self.traceback:
                    # splitlines() keeps entries that lack a trailing newline,
                    # which split("\n")[:-1] would drop entirely
                    for subline in line.splitlines():
                        print(f"{self.indent}\t{subline}")

            else:
                printing.logn(
                    f"{self.indent}\t\t\tTo see the full stack trace, "
                    "rerun with `export LEMONADE_TRACEBACK=True`.\n",
                    c=self.status_message_color,
                )
        else:
            print()

        self.skip.previous_status_message = self.status_message

    def print(
        self,
        build_name: str,
        cache_dir: str,
        print_file_name: bool = False,
        invocation_idx: int = 0,
        model_visited: bool = False,
        multiple_unique_invocations: bool = False,
    ):
        """
        Print information about a given model or submodel.

        Args:
            build_name: build whose stats and output directory are reported.
            cache_dir: cache directory containing the build.
            print_file_name: whether to print the "<script_name><extension>:" line.
            invocation_idx: 0-based index of this invocation (shown 1-based).
            model_visited: whether this model has already been printed once.
            multiple_unique_invocations: whether the model has more than one
                unique invocation.
        """

        # Scripts get one extra level of indentation compared to ONNX files
        # and extension-less entries
        if self.extension in (".onnx", ""):
            self.indent = "\t" * (2 * self.depth)
        else:
            self.indent = "\t" * (2 * self.depth + 1)

        if self.exec_time == 0:
            exec_time_formatted = ""
        else:
            exec_time_formatted = f" - {self.exec_time:.2f}s"

        self._print_heading(
            exec_time_formatted,
            print_file_name,
            model_visited,
            multiple_unique_invocations,
        )
        if self.depth != 0 or not model_visited:
            # Print this information only once per model
            self._print_location()
            self._print_parameters()
        self._print_unique_input_shape(
            exec_time_formatted, invocation_idx, multiple_unique_invocations
        )
        self._print_input_shape()
        self._print_build_dir(cache_dir=cache_dir, build_name=build_name)
        self._print_peak_memory()
        self._print_status(cache_dir=cache_dir, build_name=build_name)

        print()
361
+
362
+
363
@dataclasses.dataclass
class ModelInfo(BasicInfo):
    """
    Record for a model together with its unique invocations, which are
    stored in a dict of UniqueInvocationInfo keyed by string.
    """

    model: torch.nn.Module = None
    old_forward: Optional[Callable] = None
    unique_invocations: Optional[Dict[str, UniqueInvocationInfo]] = dataclasses.field(
        default_factory=dict
    )
    last_unique_invocation_executed: Optional[str] = None

    def __post_init__(self):
        # `params` is always recomputed from the model, replacing whatever
        # value was passed to the constructor
        self.params = analyze_model.count_parameters(self.model)
374
+
375
+
376
def recursive_print(
    models_found: Dict[str, ModelInfo],
    build_name: str,
    cache_dir: str,
    parent_model_hash: Union[str, None] = None,
    parent_invocation_hash: Union[str, None] = None,
    script_names_visited: Optional[List[str]] = None,
) -> None:
    """
    Print every executed invocation whose parent matches the given
    model/invocation hashes, then recurse to print its children.

    Args:
        models_found: models to print, keyed by model hash.
        build_name: build name forwarded to each invocation's `print()`.
        cache_dir: cache directory forwarded to each invocation's `print()`.
        parent_model_hash: only models whose `parent_hash` equals this value
            are printed at this level (None selects the top level).
        parent_invocation_hash: only invocations whose `parent_hash` equals
            this value are printed at this level.
        script_names_visited: script names already seen; shared across the
            recursion and updated in place. A new list is created when
            omitted (None, or the legacy default False).
    """
    # Keep the caller's list so that visited names are shared across the
    # recursion instead of being reset at every level
    if script_names_visited is None or script_names_visited is False:
        script_names_visited = []

    for model_hash, model_info in models_found.items():
        model_visited = False
        invocation_idx = 0
        for invocation_hash, unique_invocation in model_info.unique_invocations.items():
            is_child_of_parent = (
                parent_model_hash == model_info.parent_hash
                and unique_invocation.executed > 0
                and unique_invocation.parent_hash == parent_invocation_hash
            )
            if not is_child_of_parent:
                continue

            # The file name is printed the first time a script is seen, and
            # only for depth-0 models
            print_file_name = False
            if model_info.script_name not in script_names_visited:
                script_names_visited.append(model_info.script_name)
                if model_info.depth == 0:
                    print_file_name = True

            # In this verbosity mode we want to print all of the information
            # every time, so reset SkipFields
            # NOTE: to introduce a new lower-verbosity mode, set some members
            # of SkipFields to False to skip them
            unique_invocation.skip = SkipFields()

            unique_invocation.print(
                build_name=build_name,
                cache_dir=cache_dir,
                print_file_name=print_file_name,
                invocation_idx=invocation_idx,
                model_visited=model_visited,
                multiple_unique_invocations=len(model_info.unique_invocations) > 1,
            )
            model_visited = True
            invocation_idx += 1

            # Print the submodels that were invoked from this invocation
            recursive_print(
                models_found,
                build_name,
                cache_dir,
                parent_model_hash=model_hash,
                parent_invocation_hash=invocation_hash,
                script_names_visited=script_names_visited,
            )
435
+
436
+
437
def stop_logger_forward() -> None:
    """
    Stop forwarding stdout and stderr to file.

    If the current `sys.stdout` has a `terminal` attribute, that attribute
    becomes the new `sys.stdout`; likewise `sys.stderr` is replaced by its
    `terminal_err` attribute when present. Streams lacking the attribute are
    left untouched.
    """
    for stream_name, target_attr in (
        ("stdout", "terminal"),
        ("stderr", "terminal_err"),
    ):
        current_stream = getattr(sys, stream_name)
        if hasattr(current_stream, target_attr):
            setattr(sys, stream_name, getattr(current_stream, target_attr))
445
+
446
+
447
def add_to_state(
    state: State,
    name: str,
    model: Union[str, torch.nn.Module],
    extension: str = "",
    input_shapes: Optional[Dict] = None,
):
    """
    Populate `state.invocation_info` and `state.models_found` with a single
    target model so that its status can be printed.

    Args:
        state: state object to update in place.
        name: model name; if it is an existing filesystem path it is also
            recorded as the model's file.
        model: the model object (or string) stored in the ModelInfo record.
        extension: extension stored on the invocation info.
        input_shapes: optional input shapes stored on the invocation info.
    """
    # Reuse the model hash already stored on the state, falling back to 0
    if vars(state).get("model_hash"):
        model_hash = state.model_hash
    else:
        model_hash = 0

    # When `name` is a path on disk, derive the script name from it and keep
    # the path as the file location; otherwise there is no file to report
    if os.path.exists(name):
        file_name = fs.clean_file_name(name)
        file = name
    else:
        file_name = name
        file = ""

    # NOTE: these fields previously received the builtin `input` function
    # instead of the `name` argument
    state.invocation_info = UniqueInvocationInfo(
        name=name,
        script_name=file_name,
        file=file,
        input_shapes=input_shapes,
        hash=model_hash,
        is_target=True,
        extension=extension,
        executed=1,
    )
    state.models_found = {
        "the_model": ModelInfo(
            model=model,
            name=name,
            script_name=name,
            file=name,
            unique_invocations={model_hash: state.invocation_info},
            hash=model_hash,
        )
    }
    # ModelInfo computes the parameter count in its __post_init__; mirror it
    # onto the invocation info
    state.invocation_info.params = state.models_found["the_model"].params
487
+
488
+
489
+ # This file was originally licensed under Apache 2.0. It has been modified.
490
+ # Modifications Copyright (c) 2025 AMD