lemonade-sdk 7.0.3__py3-none-any.whl → 8.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (55)
  1. lemonade/api.py +3 -3
  2. lemonade/cli.py +11 -17
  3. lemonade/common/build.py +0 -47
  4. lemonade/common/network.py +50 -0
  5. lemonade/common/status.py +2 -21
  6. lemonade/common/system_info.py +19 -4
  7. lemonade/profilers/memory_tracker.py +3 -1
  8. lemonade/tools/accuracy.py +3 -4
  9. lemonade/tools/adapter.py +1 -2
  10. lemonade/tools/{huggingface_bench.py → huggingface/bench.py} +2 -87
  11. lemonade/tools/huggingface/load.py +235 -0
  12. lemonade/tools/{huggingface_load.py → huggingface/utils.py} +87 -255
  13. lemonade/tools/humaneval.py +9 -3
  14. lemonade/tools/{llamacpp_bench.py → llamacpp/bench.py} +1 -1
  15. lemonade/tools/{llamacpp.py → llamacpp/load.py} +18 -2
  16. lemonade/tools/mmlu.py +7 -15
  17. lemonade/tools/{ort_genai/oga.py → oga/load.py} +31 -422
  18. lemonade/tools/oga/utils.py +423 -0
  19. lemonade/tools/perplexity.py +4 -3
  20. lemonade/tools/prompt.py +2 -1
  21. lemonade/tools/quark/quark_load.py +2 -1
  22. lemonade/tools/quark/quark_quantize.py +5 -5
  23. lemonade/tools/report/table.py +3 -3
  24. lemonade/tools/server/llamacpp.py +159 -34
  25. lemonade/tools/server/serve.py +169 -147
  26. lemonade/tools/server/static/favicon.ico +0 -0
  27. lemonade/tools/server/static/styles.css +568 -0
  28. lemonade/tools/server/static/webapp.html +439 -0
  29. lemonade/tools/server/tray.py +458 -0
  30. lemonade/tools/server/{port_utils.py → utils/port.py} +22 -3
  31. lemonade/tools/server/utils/system_tray.py +395 -0
  32. lemonade/tools/server/{instructions.py → webapp.py} +4 -10
  33. lemonade/version.py +1 -1
  34. lemonade_install/install.py +46 -28
  35. {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/METADATA +84 -22
  36. lemonade_sdk-8.0.0.dist-info/RECORD +70 -0
  37. lemonade_server/cli.py +182 -27
  38. lemonade_server/model_manager.py +192 -20
  39. lemonade_server/pydantic_models.py +9 -4
  40. lemonade_server/server_models.json +5 -3
  41. lemonade/common/analyze_model.py +0 -26
  42. lemonade/common/labels.py +0 -61
  43. lemonade/common/onnx_helpers.py +0 -176
  44. lemonade/common/plugins.py +0 -10
  45. lemonade/common/tensor_helpers.py +0 -83
  46. lemonade/tools/server/static/instructions.html +0 -262
  47. lemonade_sdk-7.0.3.dist-info/RECORD +0 -69
  48. /lemonade/tools/{ort_genai → oga}/__init__.py +0 -0
  49. /lemonade/tools/{ort_genai/oga_bench.py → oga/bench.py} +0 -0
  50. /lemonade/tools/server/{thread_utils.py → utils/thread.py} +0 -0
  51. {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/WHEEL +0 -0
  52. {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/entry_points.txt +0 -0
  53. {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/licenses/LICENSE +0 -0
  54. {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/licenses/NOTICE.md +0 -0
  55. {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/top_level.txt +0 -0
lemonade/tools/oga/utils.py ADDED
@@ -0,0 +1,423 @@
+ import os
+ import time
+ import json
+ import logging
+ from queue import Queue
+ from packaging.version import Version
+ import onnxruntime_genai as og
+ from transformers import AutoTokenizer
+ from lemonade.tools.adapter import (
+     ModelAdapter,
+     TokenizerAdapter,
+     PassthroughTokenizerResult,
+ )
+
+
+ class OrtGenaiTokenizer(TokenizerAdapter):
+     def __init__(self, model: og.Model, hf_tokenizer: AutoTokenizer):
+         super().__init__(hf_tokenizer)
+         # Initialize OGA tokenizer
+         self.tokenizer = og.Tokenizer(model)
+
+         # Placeholder value since some code will try to query it
+         # If we actually need this to return a proper value, then
+         # og.GeneratorParams.eos_token_id has it
+         self.eos_token_id = None
+
+     def __call__(self, prompt: str, return_tensors="np"):
+         tokens = self.tokenizer.encode(prompt)
+         return PassthroughTokenizerResult(tokens)
+
+     # pylint: disable=unused-argument
+     def decode(self, response, skip_special_tokens=True) -> str:
+         return self.tokenizer.decode(response)
+
+
+ class OrtGenaiStreamer:
+     def __init__(self, tokenizer: OrtGenaiTokenizer, timeout=None):
+         self.tokenizer = tokenizer
+         self.text_queue = Queue()
+         self.stop_signal = None
+         self.timeout = timeout
+
+     def add_text(self, text: str):
+         self.text_queue.put(text, timeout=self.timeout)
+
+     def done(self):
+         self.text_queue.put(self.stop_signal, timeout=self.timeout)
+
+     def __iter__(self):
+         return self
+
+     def __next__(self):
+         value = self.text_queue.get(timeout=self.timeout)
+         if value == self.stop_signal:
+             raise StopIteration()
+         else:
+             return value
+
+
+ class OrtGenaiModel(ModelAdapter):
+
+     def __init__(self, input_folder):
+         super().__init__()
+         self.model = og.Model(input_folder)
+         self.type = "ort-genai"
+         self.config = self.load_config(input_folder)
+
+     def load_config(self, input_folder):
+         rai_config_path = os.path.join(input_folder, "rai_config.json")
+         if os.path.exists(rai_config_path):
+             with open(rai_config_path, "r", encoding="utf-8") as f:
+                 max_prompt_length = json.load(f)["max_prompt_length"]["1.4.1"]
+         else:
+             max_prompt_length = None
+
+         config_path = os.path.join(input_folder, "genai_config.json")
+         if os.path.exists(config_path):
+             with open(config_path, "r", encoding="utf-8") as f:
+                 config_dict = json.load(f)
+                 if max_prompt_length:
+                     config_dict["max_prompt_length"] = max_prompt_length
+                 return config_dict
+         return None
+
+     def generate(
+         self,
+         input_ids,
+         max_new_tokens=512,
+         min_new_tokens=0,
+         do_sample=True,
+         top_k=50,
+         top_p=1.0,
+         temperature=0.7,
+         streamer: OrtGenaiStreamer = None,
+         pad_token_id=None,
+         stopping_criteria=None,
+         max_length=None,
+         random_seed=1,
+     ):
+         params = og.GeneratorParams(self.model)
+
+         prompt_length = len(input_ids)
+         max_prompt_length = self.config.get("max_prompt_length")
+         if max_prompt_length and prompt_length > max_prompt_length:
+             raise ValueError(
+                 f"This prompt (length {prompt_length}) exceeds the model's "
+                 f"maximum allowed prompt length ({max_prompt_length})."
+             )
+
+         # There is a breaking API change in OGA 0.6.0
+         # Determine whether we should use the old or new APIs
+         # This also supports 0.6.0.dev0, which evaluates to less than 0.6.0 in Version
+         use_oga_post_6_api = (
+             Version(og.__version__) >= Version("0.6.0") or "0.6.0" in og.__version__
+         )
+         use_oga_pre_6_api = not use_oga_post_6_api
+
+         if pad_token_id:
+             params.pad_token_id = pad_token_id
+
+         # Handle max_length and max_new_tokens
+         if max_length and max_new_tokens:
+             logging.warning(
+                 "Both max_length and max_new_tokens were provided. "
+                 "max_length will take precedence. "
+                 "When setting max_length, please explicitly set max_new_tokens to None."
+             )
+         max_length_to_use = None
+         if max_length:
+             max_length_to_use = max_length
+         elif max_new_tokens:
+             max_length_to_use = prompt_length + max_new_tokens
+
+         min_length = prompt_length + min_new_tokens
+
+         if use_oga_pre_6_api:
+             params.input_ids = input_ids
+
+         if random_seed is None:
+             random_seed = -1  # In og.Generator, -1 = seed with random device
+
+         if self.config and "search" in self.config:
+             search_config = self.config["search"]
+             params.set_search_options(
+                 do_sample=search_config.get("do_sample", do_sample),
+                 top_k=search_config.get("top_k", top_k),
+                 top_p=search_config.get("top_p", top_p),
+                 temperature=search_config.get("temperature", temperature),
+                 max_length=max_length_to_use,
+                 min_length=min_length,
+                 early_stopping=search_config.get("early_stopping", False),
+                 length_penalty=search_config.get("length_penalty", 1.0),
+                 num_beams=search_config.get("num_beams", 1),
+                 num_return_sequences=search_config.get("num_return_sequences", 1),
+                 repetition_penalty=search_config.get("repetition_penalty", 1.0),
+                 past_present_share_buffer=search_config.get(
+                     "past_present_share_buffer", True
+                 ),
+                 random_seed=random_seed,
+                 # Not currently supported by OGA
+                 # diversity_penalty=search_config.get('diversity_penalty', 0.0),
+                 # no_repeat_ngram_size=search_config.get('no_repeat_ngram_size', 0),
+             )
+         else:
+             params.set_search_options(
+                 do_sample=do_sample,
+                 top_k=top_k,
+                 top_p=top_p,
+                 temperature=temperature,
+                 max_length=max_length_to_use,
+                 min_length=min_length,
+                 random_seed=random_seed,
+             )
+         params.try_graph_capture_with_max_batch_size(1)
+
+         generator = og.Generator(self.model, params)
+
+         if streamer is None:
+             prompt_start_time = time.perf_counter()
+             if use_oga_post_6_api:
+                 generator.append_tokens(input_ids)
+             if use_oga_pre_6_api:
+                 generator.compute_logits()
+             generator.generate_next_token()
+             prompt_end_time = time.perf_counter()
+
+             self.time_to_first_token = prompt_end_time - prompt_start_time
+
+             if max_new_tokens > 1:
+
+                 token_gen_times = []
+                 while not generator.is_done():
+                     token_gen_start_time = time.perf_counter()
+                     if use_oga_pre_6_api:
+                         generator.compute_logits()
+                     generator.generate_next_token()
+                     token_gen_end_time = time.perf_counter()
+
+                     token_gen_times.append(token_gen_end_time - token_gen_start_time)
+
+                 if token_gen_times:
+                     # List will be empty if we generated 1 or 0 tokens, and we don't
+                     # want a divide-by-zero error in those cases
+                     avg_token_gen_latency_s = sum(token_gen_times) / len(
+                         token_gen_times
+                     )
+                     self.tokens_per_second = 1 / avg_token_gen_latency_s
+
+             return [generator.get_sequence(0)]
+         else:
+             if use_oga_post_6_api:
+                 generator.append_tokens(input_ids)
+             tokenizer_stream = streamer.tokenizer.tokenizer.create_stream()
+
+             stop_early = False
+
+             while not generator.is_done() and not stop_early:
+                 if use_oga_pre_6_api:
+                     generator.compute_logits()
+                 generator.generate_next_token()
+
+                 new_token = generator.get_next_tokens()[0]
+                 new_text = tokenizer_stream.decode(new_token)
+
+                 streamer.add_text(new_text)
+
+                 if stopping_criteria is not None:
+                     if stopping_criteria[0].stop_event.is_set():
+                         stop_early = True
+
+             streamer.done()
+
+     def _model_call(self, input_ids):
+         """
+         Run the model on input_ids and get logits.
+
+         This method directly accesses model logits rather than using the full generate pipeline for
+         several important reasons:
+         1. Purpose: We need raw logits from a single forward pass, while generate() is optimized for
+            producing multiple tokens through iterative inference
+         2. Efficiency: Direct access is more efficient for logprob calculations with no
+            sampling overhead
+         3. Precision: Logprob calculations require exact control over input-to-output mapping
+         4. Consistency: Similar approach used in both HF and OGA implementations
+
+         Args:
+             input_ids: Input token IDs
+
+         Returns:
+             Logits for each token in the sequence
+         """
+         import torch
+
+         # Setup generator params
+         params = og.GeneratorParams(self.model)
+
+         # Configure for a simple forward pass
+         params.set_search_options(
+             do_sample=False,
+             temperature=0.0,
+             max_length=len(input_ids),
+         )
+
+         # Initialize generator
+         generator = og.Generator(self.model, params)
+
+         # Feed tokens to model based on API version
+         generator.append_tokens(input_ids)
+
+         # Extract logits - this returns a list of logits tensors
+         logits = generator.get_output("logits")
+
+         # Convert to torch tensor for easier processing
+         return torch.tensor(logits[0])
+
+     def _select_cont_toks(self, logits, context_len, continuation_tokens):
+         """
+         Select and process logits for continuation tokens.
+
+         Args:
+             logits: Full sequence logits
+             context_len: Length of context tokens
+             continuation_tokens: List or tensor of continuation token IDs
+
+         Returns:
+             Log probabilities for continuation tokens
+         """
+         import torch
+
+         # Extract relevant logits for continuation prediction (shift by one)
+         cont_logits = logits[
+             context_len - 1 : context_len - 1 + len(continuation_tokens)
+         ]
+
+         # Convert to torch tensors if needed
+         if not isinstance(continuation_tokens, torch.Tensor):
+             continuation_tokens = torch.tensor(continuation_tokens, dtype=torch.long)
+
+         # Apply log softmax to get log probabilities
+         log_probs = torch.log_softmax(cont_logits, dim=-1)
+
+         # Get log probs for the specific continuation tokens
+         token_log_probs = torch.gather(
+             log_probs, 1, continuation_tokens.unsqueeze(-1)
+         ).squeeze(-1)
+
+         return token_log_probs
+
+     def compute_logprobs(
+         self, text, tokenizer, prompt_length=None, logprobs=None, echo=False
+     ):
+         """
+         Compute log probabilities for all tokens in the given text.
+
+         Args:
+             text: The full text to analyze (e.g., prompt + completion)
+             prompt_length: Number of tokens in the prompt. If provided and echo=False,
+                 only completion tokens after this position will be returned.
+             logprobs: If not None, return log probabilities. Value indicates how many top
+                 alternatives to return. If True but not an integer, defaults to 5 alternatives.
+             echo: If True, include logprobs for prompt tokens. If False, only return logprobs
+                 for completion tokens.
+
+         Returns:
+             - text_offset: Character offsets for each token in the text
+             - token_logprobs: Log probability for each token
+             - tokens: The actual tokens used
+             - top_logprobs: Top alternative log probabilities for each position
+         """
+         import torch
+
+         if tokenizer is None:
+             raise ValueError("Tokenizer is required for logprob calculation")
+
+         # Encode the full text
+         tokens = tokenizer(text).input_ids  # pylint: disable=E1102
+
+         # Track character offsets for each token
+         text_offset = []
+         start_idx = 0
+
+         token_strings = []
+         for token_id in tokens:
+             token_str = tokenizer.decode([token_id])
+             token_strings.append(token_str)
+
+             # Calculate character offsets for tokens - handles cases where tokens
+             # may not directly match in the original text due to encoding differences,
+             # special characters, or tokenization artifacts
+             try:
+                 pos = text[start_idx:].find(token_str)
+                 if pos != -1:
+                     text_offset.append(start_idx + pos)
+                     start_idx += pos + len(token_str)
+                 else:
+                     text_offset.append(start_idx)
+             except (TypeError, ValueError, UnicodeError):
+                 # Fallback to current position when matching fails due to encoding issues
+                 text_offset.append(start_idx)
+
+         # Get logits from model
+         logits = self._model_call(tokens)
+
+         # Calculate log probabilities for each token
+         all_log_probs = torch.log_softmax(logits, dim=-1)
+
+         # The first token doesn't have a conditional probability
+         # For tokens after the first, get the predicted probability
+         token_log_probs = []
+         top_logprobs_list = []
+
+         # For each position, get the actual token probability and top alternatives
+         for i in range(len(tokens)):
+             # Get previous token position logits
+             if i > 0:  # First token has no preceding context
+                 prev_logits = all_log_probs[i - 1]
+                 curr_token_id = tokens[i]
+                 # Get probability of the actual token that appeared
+                 token_logprob = prev_logits[curr_token_id].item()
+                 token_log_probs.append(token_logprob)
+
+                 # Get top-k alternatives if requested
+                 if logprobs is not None:
+                     num_alternatives = logprobs if isinstance(logprobs, int) else 5
+                     topk_values, topk_indices = torch.topk(
+                         prev_logits, min(num_alternatives, prev_logits.size(-1))
+                     )
+
+                     # Create dictionary of token: logprob
+                     position_logprobs = {}
+                     for val, idx in zip(topk_values.tolist(), topk_indices.tolist()):
+                         token_str = tokenizer.decode([idx])
+                         position_logprobs[token_str] = val
+
+                     top_logprobs_list.append(position_logprobs)
+             else:
+                 # For the first token, we don't have a conditional probability
+                 token_log_probs.append(None)
+                 top_logprobs_list.append({})
+
+         # If we don't want to echo prompt tokens, filter them out
+         if not echo and prompt_length is not None:
+             # Ensure prompt_length is within bounds
+             prompt_length = min(prompt_length, len(tokens))
+
+             # Filter results to only include completion tokens
+             if prompt_length < len(tokens):
+                 filtered_text_offset = text_offset[prompt_length:]
+                 filtered_token_logprobs = token_log_probs[prompt_length:]
+                 filtered_tokens = token_strings[prompt_length:]
+                 filtered_top_logprobs = top_logprobs_list[prompt_length:]
+
+                 return (
+                     filtered_text_offset,
+                     filtered_token_logprobs,
+                     filtered_tokens,
+                     filtered_top_logprobs,
+                 )
+             else:
+                 # No completion tokens
+                 return [], [], [], []
+
+         return text_offset, token_log_probs, token_strings, top_logprobs_list
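
For orientation, here is a minimal usage sketch (not part of the diff) showing how the new adapter classes above might fit together for streaming generation. It assumes the classes are importable from the new lemonade.tools.oga.utils module shown in the file list; the model folder path and Hugging Face checkpoint id are placeholders.

from threading import Thread

from transformers import AutoTokenizer
from lemonade.tools.oga.utils import OrtGenaiModel, OrtGenaiStreamer, OrtGenaiTokenizer

MODEL_DIR = "path/to/oga/model/folder"   # placeholder: folder containing genai_config.json
CHECKPOINT = "some-org/some-checkpoint"  # placeholder: matching HF tokenizer

model = OrtGenaiModel(MODEL_DIR)
tokenizer = OrtGenaiTokenizer(model.model, AutoTokenizer.from_pretrained(CHECKPOINT))

input_ids = tokenizer("Why is the sky blue?").input_ids
streamer = OrtGenaiStreamer(tokenizer)

# generate() pushes decoded text into the streamer's queue and blocks until done,
# so it runs on a worker thread while the caller iterates the streamer.
worker = Thread(
    target=model.generate,
    kwargs={"input_ids": input_ids, "streamer": streamer, "max_new_tokens": 64},
)
worker.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
worker.join()

# The same adapter also exposes OpenAI-style logprob data (requires torch):
offsets, token_logprobs, tokens, top_logprobs = model.compute_logprobs(
    "The capital of France is Paris", tokenizer, logprobs=5
)
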
lemonade/tools/perplexity.py CHANGED
@@ -1,8 +1,5 @@
  import os
  import argparse
- import pandas as pd
- import torch
- from datasets import load_dataset
  from lemonade.state import State
  from lemonade.tools import Tool
  import lemonade.common.printing as printing
@@ -41,6 +38,10 @@ class AccuracyPerplexity(Tool):
          state: State,
      ) -> State:

+         import pandas as pd
+         import torch
+         from datasets import load_dataset
+
          try:
              printing.log_info("Downloading dataset ...")
              dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
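
Several hunks in this release (perplexity.py above, and prompt.py and the Quark tools below) apply the same change: heavy dependencies such as pandas, torch, matplotlib, and datasets are no longer imported at module scope but inside the tool's run() method. A minimal sketch of the pattern, with illustrative names rather than actual lemonade APIs:

class ExampleTool:
    """Illustrative only: heavy imports happen when the tool runs, not when it is imported."""

    def run(self, state):
        # Deferred imports keep module import and CLI startup cheap for users
        # who never invoke this tool.
        import pandas as pd
        import torch

        frame = pd.DataFrame({"cuda_available": [torch.cuda.is_available()]})
        print(frame)
        return state
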
lemonade/tools/prompt.py CHANGED
@@ -1,6 +1,5 @@
  import argparse
  import os
- import matplotlib.pyplot as plt
  import lemonade.common.build as build
  import lemonade.common.printing as printing
  from lemonade.state import State
@@ -154,6 +153,8 @@ class LLMPrompt(Tool):
          random_seed: int = DEFAULT_RANDOM_SEED,
      ) -> State:

+         import matplotlib.pyplot as plt
+
          model: ModelAdapter = state.model
          tokenizer: TokenizerAdapter = state.tokenizer

lemonade/tools/quark/quark_load.py CHANGED
@@ -2,7 +2,6 @@ import argparse
  import os
  import sys

- import torch
  from lemonade.state import State
  from lemonade.tools import Tool
  import lemonade.common.printing as printing
@@ -101,6 +100,8 @@ class QuarkLoad(Tool):
              Exception: If an error occurs during the QuarkLoad process.
          """

+         import torch
+
          try:
              if os.path.isdir(DEFAULT_QUARK_DIR):
                  quark_llm_path = os.path.join(
lemonade/tools/quark/quark_quantize.py CHANGED
@@ -2,9 +2,6 @@ import argparse
  import os
  import sys
  from pathlib import Path
-
- import torch
- from transformers import AutoProcessor
  from lemonade.state import State
  from lemonade.tools import Tool
  import lemonade.common.printing as printing
@@ -319,8 +316,8 @@ class QuarkQuantize(Tool):
          - Optionally exporting, compiling, and evaluating the model.
          """

-         model = state.model.model
-         tokenizer = state.tokenizer
+         import torch
+         from transformers import AutoProcessor

          # Importing quark utils after adding to sys.path
          from llm_utils.data_preparation import get_calib_dataloader
@@ -328,6 +325,9 @@ class QuarkQuantize(Tool):
          from llm_ptq.configuration_preparation import get_config, get_export_config
          from quark.torch import ModelQuantizer, ModelExporter, save_params

+         model = state.model.model
+         tokenizer = state.tokenizer
+
          # 1. Load Model
          printing.log_info("Loading model ...")
          model_type = get_model_type(model)
lemonade/tools/report/table.py CHANGED
@@ -7,10 +7,10 @@ from tabulate import tabulate
  import lemonade.common.build as build
  import lemonade.common.filesystem as fs
  from lemonade.cache import Keys
- from lemonade.tools.huggingface_bench import HuggingfaceBench
- from lemonade.tools.llamacpp_bench import LlamaCppBench
+ from lemonade.tools.huggingface.bench import HuggingfaceBench
+ from lemonade.tools.llamacpp.bench import LlamaCppBench
  from lemonade.tools.mmlu import AccuracyMMLU
- from lemonade.tools.ort_genai.oga_bench import OgaBench
+ from lemonade.tools.oga.bench import OgaBench

  # List of python packages for which to log the version
  PYTHON_PACKAGES = ["onnxruntime", "transformers", "lemonade-sdk", "voe"]