lemonade-sdk 7.0.3__py3-none-any.whl → 8.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of lemonade-sdk might be problematic.
Files changed (55)
  1. lemonade/api.py +3 -3
  2. lemonade/cli.py +11 -17
  3. lemonade/common/build.py +0 -47
  4. lemonade/common/network.py +50 -0
  5. lemonade/common/status.py +2 -21
  6. lemonade/common/system_info.py +19 -4
  7. lemonade/profilers/memory_tracker.py +3 -1
  8. lemonade/tools/accuracy.py +3 -4
  9. lemonade/tools/adapter.py +1 -2
  10. lemonade/tools/{huggingface_bench.py → huggingface/bench.py} +2 -87
  11. lemonade/tools/huggingface/load.py +235 -0
  12. lemonade/tools/{huggingface_load.py → huggingface/utils.py} +87 -255
  13. lemonade/tools/humaneval.py +9 -3
  14. lemonade/tools/{llamacpp_bench.py → llamacpp/bench.py} +1 -1
  15. lemonade/tools/{llamacpp.py → llamacpp/load.py} +18 -2
  16. lemonade/tools/mmlu.py +7 -15
  17. lemonade/tools/{ort_genai/oga.py → oga/load.py} +31 -422
  18. lemonade/tools/oga/utils.py +423 -0
  19. lemonade/tools/perplexity.py +4 -3
  20. lemonade/tools/prompt.py +2 -1
  21. lemonade/tools/quark/quark_load.py +2 -1
  22. lemonade/tools/quark/quark_quantize.py +5 -5
  23. lemonade/tools/report/table.py +3 -3
  24. lemonade/tools/server/llamacpp.py +159 -34
  25. lemonade/tools/server/serve.py +169 -147
  26. lemonade/tools/server/static/favicon.ico +0 -0
  27. lemonade/tools/server/static/styles.css +568 -0
  28. lemonade/tools/server/static/webapp.html +439 -0
  29. lemonade/tools/server/tray.py +458 -0
  30. lemonade/tools/server/{port_utils.py → utils/port.py} +22 -3
  31. lemonade/tools/server/utils/system_tray.py +395 -0
  32. lemonade/tools/server/{instructions.py → webapp.py} +4 -10
  33. lemonade/version.py +1 -1
  34. lemonade_install/install.py +46 -28
  35. {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/METADATA +84 -22
  36. lemonade_sdk-8.0.0.dist-info/RECORD +70 -0
  37. lemonade_server/cli.py +182 -27
  38. lemonade_server/model_manager.py +192 -20
  39. lemonade_server/pydantic_models.py +9 -4
  40. lemonade_server/server_models.json +5 -3
  41. lemonade/common/analyze_model.py +0 -26
  42. lemonade/common/labels.py +0 -61
  43. lemonade/common/onnx_helpers.py +0 -176
  44. lemonade/common/plugins.py +0 -10
  45. lemonade/common/tensor_helpers.py +0 -83
  46. lemonade/tools/server/static/instructions.html +0 -262
  47. lemonade_sdk-7.0.3.dist-info/RECORD +0 -69
  48. /lemonade/tools/{ort_genai → oga}/__init__.py +0 -0
  49. /lemonade/tools/{ort_genai/oga_bench.py → oga/bench.py} +0 -0
  50. /lemonade/tools/server/{thread_utils.py → utils/thread.py} +0 -0
  51. {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/WHEEL +0 -0
  52. {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/entry_points.txt +0 -0
  53. {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/licenses/LICENSE +0 -0
  54. {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/licenses/NOTICE.md +0 -0
  55. {lemonade_sdk-7.0.3.dist-info → lemonade_sdk-8.0.0.dist-info}/top_level.txt +0 -0
lemonade/tools/{llamacpp.py → llamacpp/load.py} RENAMED
@@ -7,11 +7,12 @@ import lemonade.common.status as status
 from lemonade.tools import FirstTool
 from lemonade.tools.adapter import PassthroughTokenizer, ModelAdapter
 from lemonade.cache import Keys
-from lemonade.tools.huggingface_load import get_base_model
 
 
 class LlamaCppAdapter(ModelAdapter):
-    def __init__(self, model, output_tokens, context_size, threads, executable):
+    def __init__(
+        self, model, output_tokens, context_size, threads, executable, lib_dir=None
+    ):
         super().__init__()
 
         self.model = os.path.normpath(model)
@@ -19,6 +20,7 @@ class LlamaCppAdapter(ModelAdapter):
         self.context_size = context_size
         self.threads = threads
         self.executable = os.path.normpath(executable)
+        self.lib_dir = lib_dir
 
     def generate(
         self,
@@ -78,6 +80,15 @@ class LlamaCppAdapter(ModelAdapter):
         cmd = [str(m) for m in cmd]
 
         try:
+            # Set up environment with library path for Linux
+            env = os.environ.copy()
+            if self.lib_dir and os.name != "nt":  # Not Windows
+                current_ld_path = env.get("LD_LIBRARY_PATH", "")
+                if current_ld_path:
+                    env["LD_LIBRARY_PATH"] = f"{self.lib_dir}:{current_ld_path}"
+                else:
+                    env["LD_LIBRARY_PATH"] = self.lib_dir
+
             process = subprocess.Popen(
                 cmd,
                 stdout=subprocess.PIPE,
@@ -85,6 +96,7 @@ class LlamaCppAdapter(ModelAdapter):
                 universal_newlines=True,
                 encoding="utf-8",
                 errors="replace",
+                env=env,
             )
 
             raw_output, stderr = process.communicate(timeout=600)
@@ -208,11 +220,14 @@ class LoadLlamaCpp(FirstTool):
         output_tokens: int = 512,
         model_binary: Optional[str] = None,
         executable: str = None,
+        lib_dir: Optional[str] = None,
     ) -> State:
         """
         Load a llama.cpp model
         """
 
+        from lemonade.common.network import get_base_model
+
         if executable is None:
             raise Exception(f"{self.__class__.unique_name} requires an executable path")
 
@@ -241,6 +256,7 @@ class LoadLlamaCpp(FirstTool):
             context_size=context_size,
             threads=threads,
             executable=executable,
+            lib_dir=lib_dir,
         )
         state.tokenizer = PassthroughTokenizer()
 
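The new lib_dir argument above exists so the spawned llama.cpp binary can find its bundled shared libraries on Linux: the directory is prepended to LD_LIBRARY_PATH in the environment handed to the subprocess. A minimal standalone sketch of that pattern (run_with_lib_dir and the example paths are illustrative, not lemonade-sdk APIs):

import os
import subprocess


def run_with_lib_dir(cmd, lib_dir=None):
    """Run cmd, optionally exposing lib_dir to the dynamic linker on non-Windows hosts."""
    env = os.environ.copy()
    if lib_dir and os.name != "nt":
        current = env.get("LD_LIBRARY_PATH", "")
        # Prepend so the bundled libraries take precedence over system-wide copies
        env["LD_LIBRARY_PATH"] = f"{lib_dir}:{current}" if current else lib_dir
    return subprocess.run(cmd, capture_output=True, text=True, env=env, check=True).stdout


# Illustrative usage (hypothetical paths):
# run_with_lib_dir(["./llama-cli", "--version"], lib_dir="/opt/llamacpp/lib")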
lemonade/tools/mmlu.py CHANGED
@@ -4,9 +4,6 @@ import tarfile
 from pathlib import Path
 from typing import List, Optional
 import subprocess
-import numpy as np
-import pandas as pd
-import requests
 from lemonade.state import State
 from lemonade.tools import Tool
 import lemonade.common.printing as printing
@@ -84,6 +81,9 @@ class AccuracyMMLU(Tool):
         tests: List[str] = None,
     ) -> State:
 
+        import numpy as np
+        import pandas as pd
+
         if data_dir:
             data_dir_to_use = data_dir
         else:
@@ -224,18 +224,6 @@ class AccuracyMMLU(Tool):
         return state
 
 
-def _list_tests(data_dir):
-    """Lists all available tests based on the files in the test data directory."""
-    test_files = [
-        f for f in os.listdir(os.path.join(data_dir, "test")) if f.endswith("_test.csv")
-    ]
-    print(
-        "Available tests:",
-        *[f.replace("_test.csv", "") for f in sorted(test_files)],
-        sep="\n",
-    )
-
-
 def _format_subject(subject):
     """Formats a subject string by replacing underscores with spaces."""
     return " ".join(subject.split("_"))
@@ -243,6 +231,8 @@ def _format_subject(subject):
 
 def _safe_read_csv(path):
     """Safely reads a CSV file and returns a DataFrame."""
+    import pandas as pd
+
     try:
         return pd.read_csv(path, header=None)
     except FileNotFoundError:
@@ -292,6 +282,8 @@ def download_and_extract_dataset(data_cache_dir: str, dataset_url: str):
     Download the dataset from the given URL and extract it into the target directory.
     """
 
+    import requests
+
     # Create the directory if it does not exist
     Path(data_cache_dir).mkdir(parents=True, exist_ok=True)
 
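The mmlu.py changes above move numpy, pandas, and requests from module scope into the functions that actually use them, so simply importing the module (for example, to render CLI help) no longer pays for those heavy dependencies. A standalone sketch of the same deferred-import pattern (function names are illustrative, not lemonade-sdk APIs):

def load_results(csv_path):
    """Read a results CSV; pandas is imported only when this function actually runs."""
    import pandas as pd  # deferred import keeps module import time low

    return pd.read_csv(csv_path, header=None)


def fetch_archive(url, dest_path):
    """Download an archive; requests is likewise imported on demand."""
    import requests

    response = requests.get(url, timeout=60)
    response.raise_for_status()
    with open(dest_path, "wb") as f:
        f.write(response.content)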
lemonade/tools/{ort_genai/oga.py → oga/load.py} RENAMED
@@ -10,28 +10,16 @@
 
 import argparse
 import os
-import time
 import json
 import shutil
-import logging
 from fnmatch import fnmatch
-from queue import Queue
 import subprocess
-from packaging.version import Version
-from huggingface_hub import snapshot_download
-import onnxruntime_genai as og
-import onnxruntime_genai.models.builder as model_builder
-from transformers import AutoTokenizer
+
+
 from lemonade.state import State
 from lemonade.tools import FirstTool
 import lemonade.common.status as status
 import lemonade.common.printing as printing
-from lemonade.tools.huggingface_load import get_base_model, is_offline
-from lemonade.tools.adapter import (
-    ModelAdapter,
-    TokenizerAdapter,
-    PassthroughTokenizerResult,
-)
 from lemonade.cache import Keys
 from lemonade_install.install import (
     get_ryzen_ai_version_info,
@@ -57,414 +45,16 @@ execution_providers = {
 }
 
 
-class OrtGenaiTokenizer(TokenizerAdapter):
-    def __init__(self, model: og.Model, hf_tokenizer: AutoTokenizer):
-        super().__init__(hf_tokenizer)
-        # Initialize OGA tokenizer
-        self.tokenizer = og.Tokenizer(model)
-
-        # Placeholder value since some code will try to query it
-        # If we actually need this to return a proper value, then
-        # og.GeneratorParams.eos_token_id has it
-        self.eos_token_id = None
-
-    def __call__(self, prompt: str, return_tensors="np"):
-        tokens = self.tokenizer.encode(prompt)
-        return PassthroughTokenizerResult(tokens)
-
-    # pylint: disable=unused-argument
-    def decode(self, response, skip_special_tokens=True) -> str:
-        return self.tokenizer.decode(response)
-
-
-class OrtGenaiStreamer:
-    def __init__(self, tokenizer: OrtGenaiTokenizer, timeout=None):
-        self.tokenizer = tokenizer
-        self.text_queue = Queue()
-        self.stop_signal = None
-        self.timeout = timeout
-
-    def add_text(self, text: str):
-        self.text_queue.put(text, timeout=self.timeout)
-
-    def done(self):
-        self.text_queue.put(self.stop_signal, timeout=self.timeout)
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        value = self.text_queue.get(timeout=self.timeout)
-        if value == self.stop_signal:
-            raise StopIteration()
-        else:
-            return value
-
-
-class OrtGenaiModel(ModelAdapter):
-
-    def __init__(self, input_folder):
-        super().__init__()
-        self.model = og.Model(input_folder)
-        self.type = "ort-genai"
-        self.config = self.load_config(input_folder)
-
-    def load_config(self, input_folder):
-        rai_config_path = os.path.join(input_folder, "rai_config.json")
-        if os.path.exists(rai_config_path):
-            with open(rai_config_path, "r", encoding="utf-8") as f:
-                max_prompt_length = json.load(f)["max_prompt_length"]["1.4.1"]
-        else:
-            max_prompt_length = None
-
-        config_path = os.path.join(input_folder, "genai_config.json")
-        if os.path.exists(config_path):
-            with open(config_path, "r", encoding="utf-8") as f:
-                config_dict = json.load(f)
-            if max_prompt_length:
-                config_dict["max_prompt_length"] = max_prompt_length
-            return config_dict
-        return None
-
-    def generate(
-        self,
-        input_ids,
-        max_new_tokens=512,
-        min_new_tokens=0,
-        do_sample=True,
-        top_k=50,
-        top_p=1.0,
-        temperature=0.7,
-        streamer: OrtGenaiStreamer = None,
-        pad_token_id=None,
-        stopping_criteria=None,
-        max_length=None,
-        random_seed=1,
-    ):
-        params = og.GeneratorParams(self.model)
-
-        prompt_length = len(input_ids)
-        max_prompt_length = self.config.get("max_prompt_length")
-        if max_prompt_length and prompt_length > max_prompt_length:
-            raise ValueError(
-                f"This prompt (length {prompt_length}) exceeds the model's "
-                f"maximum allowed prompt length ({max_prompt_length})."
-            )
-
-        # There is a breaking API change in OGA 0.6.0
-        # Determine whether we should use the old or new APIs
-        # This also supports 0.6.0.dev0, which evaluates to less than 0.6.0 in Version
-        use_oga_post_6_api = (
-            Version(og.__version__) >= Version("0.6.0") or "0.6.0" in og.__version__
-        )
-        use_oga_pre_6_api = not use_oga_post_6_api
-
-        if pad_token_id:
-            params.pad_token_id = pad_token_id
-
-        # Handle max_length and max_new_tokens
-        if max_length and max_new_tokens:
-            logging.warning(
-                "Both max_length and max_new_tokens were provided. "
-                "max_length will take precedence. "
-                "When setting max_length, please explicitly set max_new_tokens to None."
-            )
-        max_length_to_use = None
-        if max_length:
-            max_length_to_use = max_length
-        elif max_new_tokens:
-            max_length_to_use = prompt_length + max_new_tokens
-
-        min_length = prompt_length + min_new_tokens
-
-        if use_oga_pre_6_api:
-            params.input_ids = input_ids
-
-        if random_seed is None:
-            random_seed = -1  # In og.Generator, -1 = seed with random device
-
-        if self.config and "search" in self.config:
-            search_config = self.config["search"]
-            params.set_search_options(
-                do_sample=search_config.get("do_sample", do_sample),
-                top_k=search_config.get("top_k", top_k),
-                top_p=search_config.get("top_p", top_p),
-                temperature=search_config.get("temperature", temperature),
-                max_length=max_length_to_use,
-                min_length=min_length,
-                early_stopping=search_config.get("early_stopping", False),
-                length_penalty=search_config.get("length_penalty", 1.0),
-                num_beams=search_config.get("num_beams", 1),
-                num_return_sequences=search_config.get("num_return_sequences", 1),
-                repetition_penalty=search_config.get("repetition_penalty", 1.0),
-                past_present_share_buffer=search_config.get(
-                    "past_present_share_buffer", True
-                ),
-                random_seed=random_seed,
-                # Not currently supported by OGA
-                # diversity_penalty=search_config.get('diversity_penalty', 0.0),
-                # no_repeat_ngram_size=search_config.get('no_repeat_ngram_size', 0),
-            )
-        else:
-            params.set_search_options(
-                do_sample=do_sample,
-                top_k=top_k,
-                top_p=top_p,
-                temperature=temperature,
-                max_length=max_length_to_use,
-                min_length=min_length,
-                random_seed=random_seed,
-            )
-        params.try_graph_capture_with_max_batch_size(1)
-
-        generator = og.Generator(self.model, params)
-
-        if streamer is None:
-            prompt_start_time = time.perf_counter()
-            if use_oga_post_6_api:
-                generator.append_tokens(input_ids)
-            if use_oga_pre_6_api:
-                generator.compute_logits()
-            generator.generate_next_token()
-            prompt_end_time = time.perf_counter()
-
-            self.time_to_first_token = prompt_end_time - prompt_start_time
-
-            if max_new_tokens > 1:
-
-                token_gen_times = []
-                while not generator.is_done():
-                    token_gen_start_time = time.perf_counter()
-                    if use_oga_pre_6_api:
-                        generator.compute_logits()
-                    generator.generate_next_token()
-                    token_gen_end_time = time.perf_counter()
-
-                    token_gen_times.append(token_gen_end_time - token_gen_start_time)
-
-                if token_gen_times:
-                    # List will be empty if we generated 1 or 0 tokens, and we don't
-                    # want a divide-by-zero error in those cases
-                    avg_token_gen_latency_s = sum(token_gen_times) / len(
-                        token_gen_times
-                    )
-                    self.tokens_per_second = 1 / avg_token_gen_latency_s
-
-            return [generator.get_sequence(0)]
-        else:
-            if use_oga_post_6_api:
-                generator.append_tokens(input_ids)
-            tokenizer_stream = streamer.tokenizer.tokenizer.create_stream()
-
-            stop_early = False
-
-            while not generator.is_done() and not stop_early:
-                if use_oga_pre_6_api:
-                    generator.compute_logits()
-                generator.generate_next_token()
-
-                new_token = generator.get_next_tokens()[0]
-                new_text = tokenizer_stream.decode(new_token)
-
-                streamer.add_text(new_text)
-
-                if stopping_criteria is not None:
-                    if stopping_criteria[0].stop_event.is_set():
-                        stop_early = True
-
-            streamer.done()
-
-    def _model_call(self, input_ids):
-        """
-        Run the model on input_ids and get logits.
-
-        This method directly accesses model logits rather than using the full generate pipeline for
-        several important reasons:
-        1. Purpose: We need raw logits from a single forward pass, while generate() is optimized for
-           producing multiple tokens through iterative inference
-        2. Efficiency: Direct access is more efficient for logprob calculations with no
-           sampling overhead
-        3. Precision: Logprob calculations require exact control over input-to-output mapping
-        4. Consistency: Similar approach used in both HF and OGA implementations
-
-        Args:
-            input_ids: Input token IDs
-
-        Returns:
-            Logits for each token in the sequence
-        """
-        import torch
-
-        # Setup generator params
-        params = og.GeneratorParams(self.model)
-
-        # Configure for a simple forward pass
-        params.set_search_options(
-            do_sample=False,
-            temperature=0.0,
-            max_length=len(input_ids),
-        )
-
-        # Initialize generator
-        generator = og.Generator(self.model, params)
-
-        # Feed tokens to model based on API version
-        generator.append_tokens(input_ids)
-
-        # Extract logits - this returns a list of logits tensors
-        logits = generator.get_output("logits")
-
-        # Convert to torch tensor for easier processing
-        return torch.tensor(logits[0])
-
-    def _select_cont_toks(self, logits, context_len, continuation_tokens):
-        """
-        Select and process logits for continuation tokens.
-
-        Args:
-            logits: Full sequence logits
-            context_len: Length of context tokens
-            continuation_tokens: List or tensor of continuation token IDs
-
-        Returns:
-            Log probabilities for continuation tokens
-        """
-        import torch
-
-        # Extract relevant logits for continuation prediction (shift by one)
-        cont_logits = logits[
-            context_len - 1 : context_len - 1 + len(continuation_tokens)
-        ]
-
-        # Convert to torch tensors if needed
-        if not isinstance(continuation_tokens, torch.Tensor):
-            continuation_tokens = torch.tensor(continuation_tokens, dtype=torch.long)
-
-        # Apply log softmax to get log probabilities
-        log_probs = torch.log_softmax(cont_logits, dim=-1)
-
-        # Get log probs for the specific continuation tokens
-        token_log_probs = torch.gather(
-            log_probs, 1, continuation_tokens.unsqueeze(-1)
-        ).squeeze(-1)
-
-        return token_log_probs
-
-    def compute_logprobs(
-        self, text, tokenizer, prompt_length=None, logprobs=None, echo=False
-    ):
-        """
-        Compute log probabilities for all tokens in the given text.
-
-        Args:
-            text: The full text to analyze (e.g., prompt + completion)
-            prompt_length: Number of tokens in the prompt. If provided and echo=False,
-                only completion tokens after this position will be returned.
-            logprobs: If not None, return log probabilities. Value indicates how many top
-                alternatives to return. If True but not an integer, defaults to 5 alternatives.
-            echo: If True, include logprobs for prompt tokens. If False, only return logprobs
-                for completion tokens.
-
-        Returns:
-            - text_offset: Character offsets for each token in the text
-            - token_logprobs: Log probability for each token
-            - tokens: The actual tokens used
-            - top_logprobs: Top alternative log probabilities for each position
-        """
-        import torch
-
-        if tokenizer is None:
-            raise ValueError("Tokenizer is required for logprob calculation")
-
-        # Encode the full text
-        tokens = tokenizer(text).input_ids  # pylint: disable=E1102
-
-        # Track character offsets for each token
-        text_offset = []
-        start_idx = 0
-
-        token_strings = []
-        for token_id in tokens:
-            token_str = tokenizer.decode([token_id])
-            token_strings.append(token_str)
-
-            # Calculate character offsets for tokens - handles cases where tokens
-            # may not directly match in the original text due to encoding differences,
-            # special characters, or tokenization artifacts
-            try:
-                pos = text[start_idx:].find(token_str)
-                if pos != -1:
-                    text_offset.append(start_idx + pos)
-                    start_idx += pos + len(token_str)
-                else:
-                    text_offset.append(start_idx)
-            except (TypeError, ValueError, UnicodeError):
-                # Fallback to current position when matching fails due to encoding issues
-                text_offset.append(start_idx)
-
-        # Get logits from model
-        logits = self._model_call(tokens)
-
-        # Calculate log probabilities for each token
-        all_log_probs = torch.log_softmax(logits, dim=-1)
-
-        # The first token doesn't have a conditional probability
-        # For tokens after the first, get the predicted probability
-        token_log_probs = []
-        top_logprobs_list = []
-
-        # For each position, get the actual token probability and top alternatives
-        for i in range(len(tokens)):
-            # Get previous token position logits
-            if i > 0:  # First token has no preceding context
-                prev_logits = all_log_probs[i - 1]
-                curr_token_id = tokens[i]
-                # Get probability of the actual token that appeared
-                token_logprob = prev_logits[curr_token_id].item()
-                token_log_probs.append(token_logprob)
-
-                # Get top-k alternatives if requested
-                if logprobs is not None:
-                    num_alternatives = logprobs if isinstance(logprobs, int) else 5
-                    topk_values, topk_indices = torch.topk(
-                        prev_logits, min(num_alternatives, prev_logits.size(-1))
-                    )
-
-                    # Create dictionary of token: logprob
-                    position_logprobs = {}
-                    for val, idx in zip(topk_values.tolist(), topk_indices.tolist()):
-                        token_str = tokenizer.decode([idx])
-                        position_logprobs[token_str] = val
-
-                    top_logprobs_list.append(position_logprobs)
-            else:
-                # For the first token, we don't have a conditional probability
-                token_log_probs.append(None)
-                top_logprobs_list.append({})
-
-        # If we don't want to echo prompt tokens, filter them out
-        if not echo and prompt_length is not None:
-            # Ensure prompt_length is within bounds
-            prompt_length = min(prompt_length, len(tokens))
-
-            # Filter results to only include completion tokens
-            if prompt_length < len(tokens):
-                filtered_text_offset = text_offset[prompt_length:]
-                filtered_token_logprobs = token_log_probs[prompt_length:]
-                filtered_tokens = token_strings[prompt_length:]
-                filtered_top_logprobs = top_logprobs_list[prompt_length:]
-
-                return (
-                    filtered_text_offset,
-                    filtered_token_logprobs,
-                    filtered_tokens,
-                    filtered_top_logprobs,
-                )
-            else:
-                # No completion tokens
-                return [], [], [], []
-
-        return text_offset, token_log_probs, token_strings, top_logprobs_list
+def import_error_heler(e: Exception):
+    """
+    Print a helpful message in the event of an import error
+    """
+    raise ImportError(
+        f"{e}\n Please install lemonade-sdk with "
+        "one of the llm-oga extras, for example:\n"
+        "pip install lemonade-sdk[llm-oga-cpu]\n"
+        "See https://lemonade_server.ai/install_options.html for details"
    )
 
 
 class OgaLoad(FirstTool):
@@ -624,6 +214,8 @@ class OgaLoad(FirstTool):
         files that have locally been quantized/converted to OGA format and any other
         models that have been manually added by the user.
         """
+        from huggingface_hub import snapshot_download
+
         if subfolder is None:
             subfolder = f"{execution_providers[device]}-{dtype}"
             subfolder += (
@@ -749,6 +341,12 @@ class OgaLoad(FirstTool):
         Uses OGA model builder to quantize safetensors format model and convert to ONNX
         format. The model files are saved to the full_model_path folder.
         """
+
+        try:
+            import onnxruntime_genai.models.builder as model_builder
+        except ImportError as e:
+            import_error_heler(e)
+
         printing.log_info(f"Building {checkpoint} for {device} using {dtype}")
         extra_options = {}
         if int4_block_size is not None:
@@ -837,6 +435,14 @@ class OgaLoad(FirstTool):
         Loads the OGA model from local folder and then loads the tokenizer.
         Will auto-detect if we're offline.
         """
+
+        try:
+            from transformers import AutoTokenizer
+            from lemonade.tools.oga.utils import OrtGenaiModel, OrtGenaiTokenizer
+            from lemonade.common.network import is_offline
+        except ImportError as e:
+            import_error_heler(e)
+
         try:
             state.model = OrtGenaiModel(full_model_path)
         except Exception as e:
@@ -945,6 +551,9 @@ class OgaLoad(FirstTool):
         trust_remote_code=False,
         subfolder: str = None,
     ) -> State:
+        from huggingface_hub import snapshot_download
+        from lemonade.common.network import get_base_model, is_offline
+
         # Auto-detect offline status
         offline = is_offline()
         if offline:
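The oga/load.py changes above move the OrtGenai* adapters into lemonade/tools/oga/utils.py and wrap the onnxruntime-genai imports in try/except blocks that route failures through import_error_heler, so users who installed lemonade-sdk without an llm-oga extra get install guidance instead of a bare ModuleNotFoundError. A minimal sketch of that guarded, lazy-import pattern (the helper and loader names below are illustrative, not lemonade-sdk APIs; og.Model mirrors the constructor used in the removed code):

def _missing_oga_backend(e):
    """Re-raise an import failure with actionable install guidance."""
    raise ImportError(
        f"{e}\nPlease install lemonade-sdk with one of the llm-oga extras, "
        "for example: pip install lemonade-sdk[llm-oga-cpu]"
    )


def load_oga_model(model_dir):
    """Load an ONNX Runtime GenAI model, importing the backend only when needed."""
    try:
        import onnxruntime_genai as og  # heavy, optional dependency
    except ImportError as e:
        _missing_oga_backend(e)
    return og.Model(model_dir)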