lemonade-sdk 7.0.0 (lemonade_sdk-7.0.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of lemonade-sdk has been flagged as potentially problematic.

Files changed (61)
  1. lemonade/__init__.py +5 -0
  2. lemonade/api.py +125 -0
  3. lemonade/cache.py +85 -0
  4. lemonade/cli.py +135 -0
  5. lemonade/common/__init__.py +0 -0
  6. lemonade/common/analyze_model.py +26 -0
  7. lemonade/common/build.py +223 -0
  8. lemonade/common/cli_helpers.py +139 -0
  9. lemonade/common/exceptions.py +98 -0
  10. lemonade/common/filesystem.py +368 -0
  11. lemonade/common/labels.py +61 -0
  12. lemonade/common/onnx_helpers.py +176 -0
  13. lemonade/common/plugins.py +10 -0
  14. lemonade/common/printing.py +110 -0
  15. lemonade/common/status.py +490 -0
  16. lemonade/common/system_info.py +390 -0
  17. lemonade/common/tensor_helpers.py +83 -0
  18. lemonade/common/test_helpers.py +28 -0
  19. lemonade/profilers/__init__.py +1 -0
  20. lemonade/profilers/memory_tracker.py +257 -0
  21. lemonade/profilers/profiler.py +55 -0
  22. lemonade/sequence.py +363 -0
  23. lemonade/state.py +159 -0
  24. lemonade/tools/__init__.py +1 -0
  25. lemonade/tools/adapter.py +104 -0
  26. lemonade/tools/bench.py +284 -0
  27. lemonade/tools/huggingface_bench.py +267 -0
  28. lemonade/tools/huggingface_load.py +520 -0
  29. lemonade/tools/humaneval.py +258 -0
  30. lemonade/tools/llamacpp.py +261 -0
  31. lemonade/tools/llamacpp_bench.py +154 -0
  32. lemonade/tools/management_tools.py +273 -0
  33. lemonade/tools/mmlu.py +327 -0
  34. lemonade/tools/ort_genai/__init__.py +0 -0
  35. lemonade/tools/ort_genai/oga.py +1129 -0
  36. lemonade/tools/ort_genai/oga_bench.py +142 -0
  37. lemonade/tools/perplexity.py +146 -0
  38. lemonade/tools/prompt.py +228 -0
  39. lemonade/tools/quark/__init__.py +0 -0
  40. lemonade/tools/quark/quark_load.py +172 -0
  41. lemonade/tools/quark/quark_quantize.py +439 -0
  42. lemonade/tools/report/__init__.py +0 -0
  43. lemonade/tools/report/llm_report.py +203 -0
  44. lemonade/tools/report/table.py +739 -0
  45. lemonade/tools/server/__init__.py +0 -0
  46. lemonade/tools/server/serve.py +1354 -0
  47. lemonade/tools/server/tool_calls.py +146 -0
  48. lemonade/tools/tool.py +374 -0
  49. lemonade/version.py +1 -0
  50. lemonade_install/__init__.py +1 -0
  51. lemonade_install/install.py +774 -0
  52. lemonade_sdk-7.0.0.dist-info/METADATA +116 -0
  53. lemonade_sdk-7.0.0.dist-info/RECORD +61 -0
  54. lemonade_sdk-7.0.0.dist-info/WHEEL +5 -0
  55. lemonade_sdk-7.0.0.dist-info/entry_points.txt +4 -0
  56. lemonade_sdk-7.0.0.dist-info/licenses/LICENSE +201 -0
  57. lemonade_sdk-7.0.0.dist-info/licenses/NOTICE.md +21 -0
  58. lemonade_sdk-7.0.0.dist-info/top_level.txt +3 -0
  59. lemonade_server/cli.py +260 -0
  60. lemonade_server/model_manager.py +98 -0
  61. lemonade_server/server_models.json +142 -0
lemonade/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from lemonade.version import __version__
+
+ from .state import load_state, State
+
+ from .cli import main as lemonadecli
lemonade/api.py ADDED
@@ -0,0 +1,125 @@
+ # pylint: disable=no-member
+
+ from typing import Tuple, Dict
+ from lemonade.state import State
+ import lemonade.common.printing as printing
+ import lemonade.cache as cache
+ from lemonade.tools.adapter import ModelAdapter, TokenizerAdapter
+
+
+ class NotSupported(Exception):
+     """
+     Indicates that a checkpoint/recipe pair are not supported
+     together at this time.
+     """
+
+     def __init__(self, msg):
+         super().__init__(msg)
+         printing.log_error(msg)
+
+
+ def _raise_not_supported(recipe, checkpoint):
+     raise NotSupported(
+         f"Recipe {recipe} does not have support for checkpoint {checkpoint}"
+     )
+
+
+ def _make_state(recipe, checkpoint) -> State:
+     return State(cache_dir=cache.DEFAULT_CACHE_DIR, build_name=f"{checkpoint}_{recipe}")
+
+
+ def from_pretrained(
+     checkpoint: str,
+     recipe: str = "hf-cpu",
+ ) -> Tuple[ModelAdapter, TokenizerAdapter]:
+     """
+     Load an LLM and the corresponding tokenizer using a lemonade recipe.
+
+     Args:
+         - checkpoint: huggingface checkpoint that defines the LLM
+         - recipe: defines the implementation and hardware used for the LLM
+
+     Recipe choices:
+         - hf-cpu: Huggingface Transformers implementation for CPU with max-perf settings
+         - hf-dgpu: Huggingface Transformers implementation on dGPU (via device="cuda")
+         - oga-cpu: CPU implementation based on onnxruntime-genai
+         - oga-igpu: DirectML implementation for iGPU based on onnxruntime-genai-directml
+         - oga-hybrid: AMD Ryzen AI Hybrid implementation based on onnxruntime-genai
+
+     Returns:
+         - model: LLM instance with a generate() method that invokes the recipe
+         - tokenizer: tokenizer instance compatible with the model, which supports
+             the encode (call) and decode() methods.
+     """
+
+     if recipe == "hf-cpu":
+         # Huggingface Transformers recipe for CPU
+         # Huggingface supports all checkpoints, so there is nothing to check for
+
+         import torch
+         from lemonade.tools.huggingface_load import HuggingfaceLoad
+
+         state = _make_state(recipe, checkpoint)
+
+         state = HuggingfaceLoad().run(
+             state,
+             input=checkpoint,
+             dtype=torch.bfloat16,
+         )
+
+         return state.model, state.tokenizer
+
+     elif recipe == "hf-dgpu":
+         # Huggingface Transformers recipe for discrete GPU (Nvidia, Instinct, Radeon)
+
+         import torch
+         from lemonade.tools.huggingface_load import HuggingfaceLoad
+
+         state = _make_state(recipe, checkpoint)
+
+         state = HuggingfaceLoad().run(
+             state,
+             input=checkpoint,
+             dtype=torch.bfloat16,
+             device="cuda",
+         )
+
+         return state.model, state.tokenizer
+
+     elif recipe.startswith("oga-"):
+         import lemonade.tools.ort_genai.oga as oga
+
+         # Make sure the user chose a supported runtime, e.g., oga-cpu
+         user_backend = recipe.split("oga-")[1]
+         supported_backends = ["cpu", "igpu", "npu", "hybrid"]
+         supported_recipes = [f"oga-{backend}" for backend in supported_backends]
+         if recipe not in supported_recipes:
+             raise NotSupported(
+                 "Selected OGA recipe is not supported. "
+                 f"The supported OGA recipes are: {supported_recipes}"
+             )
+
+         backend_to_dtype = {
+             "cpu": "int4",
+             "igpu": "int4",
+             "hybrid": "int4",
+             "npu": "int4",
+         }
+
+         state = _make_state(recipe, checkpoint)
+
+         state = oga.OgaLoad().run(
+             state,
+             input=checkpoint,
+             device=user_backend,
+             dtype=backend_to_dtype[user_backend],
+         )
+
+         return state.model, state.tokenizer
+
+     else:
+         _raise_not_supported(recipe, checkpoint)
+
+
+ # This file was originally licensed under Apache 2.0. It has been modified.
+ # Modifications Copyright (c) 2025 AMD
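
For context, here is a minimal usage sketch of from_pretrained() based on the docstring above. The checkpoint name, the .input_ids attribute, and the max_new_tokens argument are illustrative assumptions, not part of this release.

from lemonade.api import from_pretrained

# Load a model/tokenizer pair with the default CPU recipe
# (any Hugging Face checkpoint works here; this one is only an example).
model, tokenizer = from_pretrained("facebook/opt-125m", recipe="hf-cpu")

# The tokenizer encodes via __call__ and decodes via decode(); the model
# exposes a generate() method that invokes the selected recipe.
input_ids = tokenizer("Hello, my name is").input_ids  # attribute assumed
response = model.generate(input_ids, max_new_tokens=32)  # kwarg assumed
print(tokenizer.decode(response[0]))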
lemonade/cache.py ADDED
@@ -0,0 +1,85 @@
+ import os
+ from datetime import datetime, timezone
+
+ # Allow an environment variable to override the default
+ # location for the build cache
+ if os.environ.get("LEMONADE_CACHE_DIR"):
+     DEFAULT_CACHE_DIR = os.path.expanduser(os.environ.get("LEMONADE_CACHE_DIR"))
+ else:
+     DEFAULT_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "lemonade")
+
+
+ def checkpoint_to_model_name(checkpoint_name: str) -> str:
+     """
+     Get the model's name by stripping the author's name from the checkpoint name
+     """
+
+     return checkpoint_name.split("/")[1]
+
+
+ def get_timestamp() -> str:
+     """
+     Get a timestamp string in the format:
+     <year>y_<month>m_<day>d_<hour>h_<minute>m_<second>s
+     """
+     # Get the current time in GMT
+     current_time = datetime.now(timezone.utc)
+
+     # Format the timestamp string
+     timestamp = current_time.strftime("%Yy_%mm_%dd_%Hh_%Mm_%Ss")
+     return timestamp
+
+
+ def build_name(input_name):
+     """
+     Name the lemonade build by concatenating these two factors:
+         1. Sanitize the input name (typically a model checkpoint name) by
+            replacing any `/` characters with `_`.
+         2. Timestamp to ensure that builds in the same cache will not
+            collide in the same build directory.
+
+     If the input_name is a local folder, then we don't know the
+     model checkpoint name, so we use "local_model"
+     """
+
+     if os.path.isdir(input_name):
+         input_name_sanitized = "local_model"
+     else:
+         # Sanitize the input name
+         input_name_sanitized = input_name.replace("/", "_")
+
+     # Get the formatted timestamp string
+     timestamp = get_timestamp()
+
+     return f"{input_name_sanitized}_{timestamp}"
+
+
+ class Keys:
+     MODEL = "model"
+     PER_ITERATION_LATENCY = "per_iteration_latency"
+     MEAN_LATENCY = "mean_latency"
+     STD_DEV_LATENCY = "std_dev_latency"
+     TOKEN_GENERATION_TOKENS_PER_SECOND = "token_generation_tokens_per_second"
+     STD_DEV_TOKENS_PER_SECOND = "std_dev_tokens_per_second"
+     SECONDS_TO_FIRST_TOKEN = "seconds_to_first_token"
+     PREFILL_TOKENS_PER_SECOND = "prefill_tokens_per_second"
+     STD_DEV_SECONDS_TO_FIRST_TOKEN = "std_dev_seconds_to_first_token"
+     CHECKPOINT = "checkpoint"
+     DTYPE = "dtype"
+     PROMPT = "prompt"
+     PROMPT_TOKENS = "prompt_tokens"
+     PROMPT_TEMPLATE = "prompt_template"
+     RESPONSE = "response"
+     RESPONSE_TOKENS = "response_tokens"
+     RESPONSE_LENGTHS_HISTOGRAM = "response_lengths_histogram"
+     CACHE_DIR = "cache_dir"
+     DEVICE = "device"
+     LOCAL_MODEL_FOLDER = "local_model_folder"
+     MEMORY_USAGE_PLOT = "memory_usage_plot"
+     MAX_MEMORY_USED_GB = "max_memory_used_GB"
+     MAX_MEMORY_USED_GBYTE = "max_memory_used_gbyte"
+     RYZEN_AI_VERSION_INFO = "ryzen_ai_version_info"
+
+
+ # This file was originally licensed under Apache 2.0. It has been modified.
+ # Modifications Copyright (c) 2025 AMD
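
As a quick illustration of the naming scheme above (the checkpoint string and the timestamp shown in the comment are made up for the example):

import lemonade.cache as cache

# A Hugging Face-style checkpoint has its "/" replaced with "_" and a UTC
# timestamp appended, e.g. "facebook_opt-125m_2025y_01m_01d_00h_00m_00s".
print(cache.build_name("facebook/opt-125m"))

# Builds land under <cache_dir>/builds/<build_name>; the default cache
# directory honors the LEMONADE_CACHE_DIR environment variable override.
print(cache.DEFAULT_CACHE_DIR)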
lemonade/cli.py ADDED
@@ -0,0 +1,135 @@
+ import os
+ from lemonade.version import __version__ as version_number
+ from lemonade.tools import FirstTool, NiceHelpFormatter
+ from lemonade.profilers.memory_tracker import MemoryTracker
+ import lemonade.common.filesystem as fs
+ import lemonade.common.cli_helpers as cli
+ from lemonade.sequence import Sequence
+ from lemonade.tools.management_tools import Cache, Version, SystemInfo
+ from lemonade.state import State
+
+ from lemonade.tools.huggingface_load import HuggingfaceLoad
+
+ from lemonade.tools.huggingface_bench import HuggingfaceBench
+ from lemonade.tools.ort_genai.oga_bench import OgaBench
+ from lemonade.tools.llamacpp_bench import LlamaCppBench
+ from lemonade.tools.llamacpp import LoadLlamaCpp
+
+ import lemonade.cache as cache
+ from lemonade.tools.mmlu import AccuracyMMLU
+ from lemonade.tools.humaneval import AccuracyHumaneval
+ from lemonade.tools.perplexity import AccuracyPerplexity
+ from lemonade.tools.prompt import LLMPrompt
+ from lemonade.tools.quark.quark_load import QuarkLoad
+ from lemonade.tools.quark.quark_quantize import QuarkQuantize
+ from lemonade.tools.report.llm_report import LemonadeReport
+ from lemonade.tools.server.serve import Server
+
+
+ def main():
+
+     # List the available tools
+     tools = [
+         HuggingfaceLoad,
+         LoadLlamaCpp,
+         LlamaCppBench,
+         AccuracyMMLU,
+         AccuracyHumaneval,
+         AccuracyPerplexity,
+         LLMPrompt,
+         HuggingfaceBench,
+         OgaBench,
+         QuarkQuantize,
+         QuarkLoad,
+         LemonadeReport,
+         Server,
+         # Inherited from lemonade
+         Cache,
+         Version,
+         SystemInfo,
+     ]
+
+     # Import onnxruntime-genai recipes
+     try:
+         from lemonade.tools.ort_genai.oga import OgaLoad
+
+         tools = tools + [OgaLoad]
+
+     except ModuleNotFoundError:
+         pass
+
+     # List the available profilers
+     profilers = [MemoryTracker]
+
+     # Define the argument parser
+     parser = cli.CustomArgumentParser(
+         description=f"""Tools for evaluating and deploying LLMs (v{version_number}).
+
+ Read this to learn the command syntax:
+ https://github.com/lemonade-sdk/lemonade/blob/main/docs/README.md""",
+         formatter_class=NiceHelpFormatter,
+     )
+
+     parser.add_argument(
+         "-i",
+         "--input",
+         help="The input that will be evaluated by the starting tool "
+         "(e.g., huggingface checkpoint)",
+     )
+
+     parser.add_argument(
+         "-d",
+         "--cache-dir",
+         help="Cache directory where tool results are "
+         f"stored (default: {cache.DEFAULT_CACHE_DIR})",
+         required=False,
+         default=cache.DEFAULT_CACHE_DIR,
+     )
+
+     for profiler in profilers:
+         profiler.add_arguments_to_parser(parser)
+
+     global_args, tool_instances, evaluation_tools = cli.parse_tools(
+         parser, tools, cli_name="lemonade"
+     )
+
+     profiler_instances = [
+         profiler(global_args[profiler.unique_name])
+         for profiler in profilers
+         if global_args.get(profiler.unique_name, None) is not None
+     ]
+
+     if len(evaluation_tools) > 0:
+         if not issubclass(evaluation_tools[0], FirstTool):
+             parser.error(
+                 "The first tool in the sequence needs to be one "
+                 "of the 'tools that can start a sequence.' Use "
+                 "`lemonade -h` to see that list of tools."
+             )
+         # Run the evaluation tools as a build
+         sequence = Sequence(tools=tool_instances, profilers=profiler_instances)
+
+         # Forward the selected input to the first tool in the sequence
+         first_tool_args = next(iter(sequence.tools.values()))
+         first_tool_args.append("--input")
+         first_tool_args.append(global_args["input"])
+
+         state = State(
+             cache_dir=os.path.abspath(global_args["cache_dir"]),
+             build_name=cache.build_name(global_args["input"]),
+             sequence_info=sequence.info,
+         )
+         sequence.launch(state)
+     else:
+         # Run the management tools
+         for management_tool, argv in tool_instances.items():
+             # Support "~" in the cache_dir argument
+             parsed_cache_dir = os.path.expanduser(global_args[fs.Keys.CACHE_DIR])
+             management_tool.parse_and_run(parsed_cache_dir, argv)
+
+
+ if __name__ == "__main__":
+     main()
+
+ # This file was originally licensed under Apache 2.0. It has been modified.
+ # Modifications Copyright (c) 2025 AMD
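
The same build flow that main() wires up can be sketched programmatically. This is a minimal sketch that assumes a tool_instances mapping already produced by cli.parse_tools(); the helper name run_build is hypothetical and not part of this release.

import os
import lemonade.cache as cache
from lemonade.sequence import Sequence
from lemonade.state import State

def run_build(tool_instances, checkpoint, cache_dir=cache.DEFAULT_CACHE_DIR):
    # Mirror of main(): build a Sequence, create a State keyed by the
    # sanitized checkpoint name, then launch the sequence.
    sequence = Sequence(tools=tool_instances, profilers=[])
    state = State(
        cache_dir=os.path.abspath(cache_dir),
        build_name=cache.build_name(checkpoint),
        sequence_info=sequence.info,
    )
    sequence.launch(state)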
lemonade/common/__init__.py ADDED
File without changes
lemonade/common/analyze_model.py ADDED
@@ -0,0 +1,26 @@
+ import numpy as np
+ import torch
+ import onnx
+
+
+ def count_parameters(model: torch.nn.Module) -> int:
+     """
+     Returns the number of parameters of a given model
+     """
+     if isinstance(model, (torch.nn.Module, torch.jit.ScriptModule)):
+         return sum([parameter.numel() for _, parameter in model.named_parameters()])
+     elif isinstance(model, str) and model.endswith(".onnx"):
+         onnx_model = onnx.load(model)
+         return int(
+             sum(
+                 np.prod(tensor.dims, dtype=np.int64)
+                 for tensor in onnx_model.graph.initializer
+                 if tensor.name not in onnx_model.graph.input
+             )
+         )
+     else:
+         return None
+
+
+ # This file was originally licensed under Apache 2.0. It has been modified.
+ # Modifications Copyright (c) 2025 AMD
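
A small sanity-check sketch for count_parameters(); the layer below is just an arbitrary example module:

import torch
from lemonade.common.analyze_model import count_parameters

# A Linear(4, 4) layer has 16 weights + 4 biases = 20 parameters.
print(count_parameters(torch.nn.Linear(4, 4)))  # 20

# A path ending in ".onnx" is also accepted, in which case the initializer
# tensors of the ONNX graph are counted instead.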
lemonade/common/build.py ADDED
@@ -0,0 +1,223 @@
+ import os
+ import logging
+ import sys
+ import traceback
+ from typing import Dict
+ import hashlib
+ import psutil
+ import yaml
+ import torch
+ import numpy as np
+ import lemonade.common.exceptions as exp
+
+ state_file_name = "state.yaml"
+
+
+ def load_yaml(file_path) -> Dict:
+     with open(file_path, "r", encoding="utf8") as stream:
+         try:
+             return yaml.load(stream, Loader=yaml.FullLoader)
+         except yaml.YAMLError as e:
+             raise exp.IOError(
+                 f"Failed while trying to open {file_path}. "
+                 f"The exception that triggered this was:\n{e}"
+             )
+
+
+ def builds_dir(cache_dir):
+     """
+     Each build stores stats, logs, and other files in a build directory.
+     All build directories are located at:
+     <cache_dir>/builds
+     """
+     return os.path.join(cache_dir, "builds")
+
+
+ def output_dir(cache_dir, build_name):
+     """
+     Each build stores stats, logs, and other files in an output directory
+     located at:
+     <builds_dir>/<build_name>
+     """
+     path = os.path.join(builds_dir(cache_dir), build_name)
+     return path
+
+
+ def state_file(cache_dir, build_name):
+     path = os.path.join(output_dir(cache_dir, build_name), state_file_name)
+     return path
+
+
+ class FunctionStatus:
+     """
+     Status values that are assigned to tools, builds, benchmarks, and other
+     functionality to help the user understand whether that function completed
+     successfully or not.
+     """
+
+     # SUCCESSFUL means the tool/build/benchmark completed successfully.
+     SUCCESSFUL = "successful"
+
+     # ERROR means the tool/build/benchmark failed and threw some error that
+     # was caught by lemonade. You should proceed by looking at the build
+     # logs to see what happened.
+
+     ERROR = "error"
+
+     # TIMEOUT means the tool/build/benchmark failed because it exceeded the timeout
+     # set for the lemonade command.
+     TIMEOUT = "timeout"
+
+     # KILLED means the build/benchmark failed because the system killed it. This can
+     # happen because of an out-of-memory (OOM), system shutdown, etc.
+     # You should proceed by re-running the build and keeping an eye on it to observe
+     # why it is being killed (e.g., watch the RAM utilization to diagnose an OOM).
+     KILLED = "killed"
+
+     # The NOT_STARTED status is applied to all tools/builds/benchmarks at startup.
+     # It will be replaced by one of the other status values if the tool/build/benchmark
+     # has a chance to start running.
+     # A value of NOT_STARTED in the report CSV indicates that the tool/build/benchmark
+     # never had a chance to start because lemonade exited before that functionality had
+     # a chance to start running.
+     NOT_STARTED = "not_started"
+
+     # INCOMPLETE indicates that a tool/build/benchmark started running and did not complete.
+     # Each tool, build, and benchmark are marked as INCOMPLETE when they start running.
+     # If you open the lemonade_stats.yaml file while the tool/build/benchmark
+     # is still running, the status will show as INCOMPLETE. If the tool/build/benchmark
+     # is killed without the chance to do any stats cleanup, the status will continue to
+     # show as INCOMPLETE in lemonade_stats.yaml.
+     # When the report CSV is created, any instance of an INCOMPLETE tool/build/benchmark
+     # status will be replaced by KILLED.
+     INCOMPLETE = "incomplete"
+
+
+ # Create a unique ID from this run by hashing pid + process start time
+ def unique_id():
+     pid = os.getpid()
+     p = psutil.Process(pid)
+     start_time = p.create_time()
+     return hashlib.sha256(f"{pid}{start_time}".encode()).hexdigest()
+
+
+ def get_shapes_and_dtypes(inputs: dict):
+     """
+     Return the shape and data type of each value in the inputs dict
+     """
+     shapes = {}
+     dtypes = {}
+     for key in sorted(inputs):
+         value = inputs[key]
+         if isinstance(
+             value,
+             (list, tuple),
+         ):
+             for v, i in zip(value, range(len(value))):
+                 if isinstance(v, (list, tuple)):
+                     # Handle nested lists/tuples, for example past_key_values
+                     # in an LLM that has KV-caching enabled
+                     for v2, i2 in zip(v, range(len(v))):
+                         subsubkey = f"{key}[{i}][{i2}]"
+                         shapes[subsubkey] = np.array(v2).shape
+                         dtypes[subsubkey] = np.array(v2).dtype.name
+                 else:
+                     # Handle single list/tuple
+                     subkey = f"{key}[{i}]"
+                     shapes[subkey] = np.array(v).shape
+                     dtypes[subkey] = np.array(v).dtype.name
+         elif torch.is_tensor(value):
+             shapes[key] = np.array(value.detach()).shape
+             dtypes[key] = np.array(value.detach()).dtype.name
+         elif isinstance(value, np.ndarray):
+             shapes[key] = value.shape
+             dtypes[key] = value.dtype.name
+         elif isinstance(value, (bool, int, float)):
+             shapes[key] = (1,)
+             dtypes[key] = type(value).__name__
+         elif value is None:
+             pass
+         else:
+             raise exp.Error(
+                 "One of the provided inputs contains the unsupported "
+                 f'type {type(value)} at key "{key}".'
+             )
+
+     return shapes, dtypes
+
+
+ class Logger:
+     """
+     Redirects stdout to file (and console if needed)
+     """
+
+     def __init__(
+         self,
+         initial_message: str,
+         log_path: str = None,
+     ):
+         self.debug = os.environ.get("LEMONADE_BUILD_DEBUG") == "True"
+         self.terminal = sys.stdout
+         self.terminal_err = sys.stderr
+         self.log_path = log_path
+
+         # Create the empty logfile
+         with open(log_path, "w", encoding="utf-8") as f:
+             f.write(f"{initial_message}\n")
+
+         # Disable any existing loggers so that we can capture all
+         # outputs to a logfile
+         self.root_logger = logging.getLogger()
+         self.handlers = [handler for handler in self.root_logger.handlers]
+         for handler in self.handlers:
+             self.root_logger.removeHandler(handler)
+
+         # Send any logger outputs to the logfile
+         if not self.debug:
+             self.file_handler = logging.FileHandler(filename=log_path)
+             self.file_handler.setLevel(logging.INFO)
+             formatter = logging.Formatter(
+                 "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+             )
+             self.file_handler.setFormatter(formatter)
+             self.root_logger.addHandler(self.file_handler)
+
+     def __enter__(self):
+         sys.stdout = self
+         sys.stderr = self
+
+     def __exit__(self, _exc_type, _exc_value, _exc_tb):
+         # Ensure we also capture the traceback as part of the logger when exceptions happen
+         if _exc_type:
+             traceback.print_exception(_exc_type, _exc_value, _exc_tb)
+
+         # Stop redirecting stdout/stderr
+         sys.stdout = self.terminal
+         sys.stderr = self.terminal_err
+
+         # Remove the logfile logging handler
+         if not self.debug:
+             self.file_handler.close()
+             self.root_logger.removeHandler(self.file_handler)
+
+         # Restore any pre-existing loggers
+         for handler in self.handlers:
+             self.root_logger.addHandler(handler)
+
+     def write(self, message):
+         if self.log_path is not None:
+             with open(self.log_path, "a", encoding="utf-8") as f:
+                 f.write(message)
+         if self.debug or self.log_path is None:
+             self.terminal.write(message)
+             self.terminal.flush()
+             self.terminal_err.write(message)
+             self.terminal_err.flush()
+
+     def flush(self):
+         # needed for python 3 compatibility.
+         pass
+
+
+ # This file was originally licensed under Apache 2.0. It has been modified.
+ # Modifications Copyright (c) 2025 AMD
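
Finally, a minimal sketch of the Logger context manager defined above; the log path and messages are illustrative assumptions, not values used by this package:

from lemonade.common.build import Logger

# While the block is active, stdout/stderr and logging output are captured
# into the logfile (unless the LEMONADE_BUILD_DEBUG="True" env var is set).
with Logger("Starting build", log_path="/tmp/lemonade_example.log"):
    print("this line is written to the logfile")

# On exit, the original stdout/stderr streams and any pre-existing logging
# handlers are restored.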