lemonade-sdk 9.1.1 (lemonade_sdk-9.1.1-py3-none-any.whl)

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
Files changed (84)
  1. lemonade/__init__.py +5 -0
  2. lemonade/api.py +180 -0
  3. lemonade/cache.py +92 -0
  4. lemonade/cli.py +173 -0
  5. lemonade/common/__init__.py +0 -0
  6. lemonade/common/build.py +176 -0
  7. lemonade/common/cli_helpers.py +139 -0
  8. lemonade/common/exceptions.py +98 -0
  9. lemonade/common/filesystem.py +368 -0
  10. lemonade/common/inference_engines.py +408 -0
  11. lemonade/common/network.py +93 -0
  12. lemonade/common/printing.py +110 -0
  13. lemonade/common/status.py +471 -0
  14. lemonade/common/system_info.py +1411 -0
  15. lemonade/common/test_helpers.py +28 -0
  16. lemonade/profilers/__init__.py +1 -0
  17. lemonade/profilers/agt_power.py +437 -0
  18. lemonade/profilers/hwinfo_power.py +429 -0
  19. lemonade/profilers/memory_tracker.py +259 -0
  20. lemonade/profilers/profiler.py +58 -0
  21. lemonade/sequence.py +363 -0
  22. lemonade/state.py +159 -0
  23. lemonade/tools/__init__.py +1 -0
  24. lemonade/tools/accuracy.py +432 -0
  25. lemonade/tools/adapter.py +114 -0
  26. lemonade/tools/bench.py +302 -0
  27. lemonade/tools/flm/__init__.py +1 -0
  28. lemonade/tools/flm/utils.py +305 -0
  29. lemonade/tools/huggingface/bench.py +187 -0
  30. lemonade/tools/huggingface/load.py +235 -0
  31. lemonade/tools/huggingface/utils.py +359 -0
  32. lemonade/tools/humaneval.py +264 -0
  33. lemonade/tools/llamacpp/bench.py +255 -0
  34. lemonade/tools/llamacpp/load.py +222 -0
  35. lemonade/tools/llamacpp/utils.py +1260 -0
  36. lemonade/tools/management_tools.py +319 -0
  37. lemonade/tools/mmlu.py +319 -0
  38. lemonade/tools/oga/__init__.py +0 -0
  39. lemonade/tools/oga/bench.py +120 -0
  40. lemonade/tools/oga/load.py +804 -0
  41. lemonade/tools/oga/migration.py +403 -0
  42. lemonade/tools/oga/utils.py +462 -0
  43. lemonade/tools/perplexity.py +147 -0
  44. lemonade/tools/prompt.py +263 -0
  45. lemonade/tools/report/__init__.py +0 -0
  46. lemonade/tools/report/llm_report.py +203 -0
  47. lemonade/tools/report/table.py +899 -0
  48. lemonade/tools/server/__init__.py +0 -0
  49. lemonade/tools/server/flm.py +133 -0
  50. lemonade/tools/server/llamacpp.py +320 -0
  51. lemonade/tools/server/serve.py +2123 -0
  52. lemonade/tools/server/static/favicon.ico +0 -0
  53. lemonade/tools/server/static/index.html +279 -0
  54. lemonade/tools/server/static/js/chat.js +1059 -0
  55. lemonade/tools/server/static/js/model-settings.js +183 -0
  56. lemonade/tools/server/static/js/models.js +1395 -0
  57. lemonade/tools/server/static/js/shared.js +556 -0
  58. lemonade/tools/server/static/logs.html +191 -0
  59. lemonade/tools/server/static/styles.css +2654 -0
  60. lemonade/tools/server/static/webapp.html +321 -0
  61. lemonade/tools/server/tool_calls.py +153 -0
  62. lemonade/tools/server/tray.py +664 -0
  63. lemonade/tools/server/utils/macos_tray.py +226 -0
  64. lemonade/tools/server/utils/port.py +77 -0
  65. lemonade/tools/server/utils/thread.py +85 -0
  66. lemonade/tools/server/utils/windows_tray.py +408 -0
  67. lemonade/tools/server/webapp.py +34 -0
  68. lemonade/tools/server/wrapped_server.py +559 -0
  69. lemonade/tools/tool.py +374 -0
  70. lemonade/version.py +1 -0
  71. lemonade_install/__init__.py +1 -0
  72. lemonade_install/install.py +239 -0
  73. lemonade_sdk-9.1.1.dist-info/METADATA +276 -0
  74. lemonade_sdk-9.1.1.dist-info/RECORD +84 -0
  75. lemonade_sdk-9.1.1.dist-info/WHEEL +5 -0
  76. lemonade_sdk-9.1.1.dist-info/entry_points.txt +5 -0
  77. lemonade_sdk-9.1.1.dist-info/licenses/LICENSE +201 -0
  78. lemonade_sdk-9.1.1.dist-info/licenses/NOTICE.md +47 -0
  79. lemonade_sdk-9.1.1.dist-info/top_level.txt +3 -0
  80. lemonade_server/cli.py +805 -0
  81. lemonade_server/model_manager.py +758 -0
  82. lemonade_server/pydantic_models.py +159 -0
  83. lemonade_server/server_models.json +643 -0
  84. lemonade_server/settings.py +39 -0
lemonade/tools/oga/utils.py
@@ -0,0 +1,462 @@
+ import os
+ import time
+ import json
+ import logging
+ from queue import Queue
+ from packaging.version import Version
+ import onnxruntime_genai as og
+ from transformers import AutoTokenizer
+ from lemonade.tools.adapter import (
+     ModelAdapter,
+     TokenizerAdapter,
+     PassthroughTokenizerResult,
+ )
+ from lemonade_install.install import _get_ryzenai_version_info
+
+
+ class OrtGenaiTokenizer(TokenizerAdapter):
+     def __init__(self, model: og.Model, hf_tokenizer: AutoTokenizer):
+         super().__init__(hf_tokenizer)
+         # Initialize OGA tokenizer
+         self.tokenizer = og.Tokenizer(model)
+
+         # Placeholder value since some code will try to query it
+         # If we actually need this to return a proper value, then
+         # og.GeneratorParams.eos_token_id has it
+         self.eos_token_id = None
+
+     def __call__(self, prompt: str, return_tensors="np"):
+         tokens = self.tokenizer.encode(prompt)
+         return PassthroughTokenizerResult(tokens)
+
+     # pylint: disable=unused-argument
+     def decode(self, response, skip_special_tokens=True) -> str:
+         return self.tokenizer.decode(response)
+
+
+ class OrtGenaiStreamer:
+     def __init__(self, tokenizer: OrtGenaiTokenizer, timeout=None):
+         self.tokenizer = tokenizer
+         self.text_queue = Queue()
+         self.stop_signal = None
+         self.timeout = timeout
+
+     def add_text(self, text: str):
+         self.text_queue.put(text, timeout=self.timeout)
+
+     def done(self):
+         self.text_queue.put(self.stop_signal, timeout=self.timeout)
+
+     def __iter__(self):
+         return self
+
+     def __next__(self):
+         value = self.text_queue.get(timeout=self.timeout)
+         if value == self.stop_signal:
+             raise StopIteration()
+         else:
+             return value
+
+
+ class OrtGenaiModel(ModelAdapter):
+
+     def __init__(self, input_folder):
+         super().__init__()
+         self.model = og.Model(input_folder)
+         self.type = "ort-genai"
+         self.config = self.load_config(input_folder)
+
+     def load_config(self, input_folder):
+         rai_config_path = os.path.join(input_folder, "rai_config.json")
+         max_prompt_length = None
+
+         try:
+             detected_version, _ = _get_ryzenai_version_info()
+
+             if os.path.exists(rai_config_path):
+                 with open(rai_config_path, "r", encoding="utf-8") as f:
+                     rai_config = json.load(f)
+                     if (
+                         "max_prompt_length" in rai_config
+                         and detected_version in rai_config["max_prompt_length"]
+                     ):
+                         max_prompt_length = rai_config["max_prompt_length"][
+                             detected_version
+                         ]
+         except:  # pylint: disable=bare-except
+             pass
+
+         config_path = os.path.join(input_folder, "genai_config.json")
+         if os.path.exists(config_path):
+             with open(config_path, "r", encoding="utf-8") as f:
+                 config_dict = json.load(f)
+                 config_dict["max_prompt_length"] = max_prompt_length
+                 return config_dict
+         return None
+
+     def generate(
+         self,
+         input_ids,
+         max_new_tokens=512,
+         min_new_tokens=0,
+         do_sample=True,
+         top_k=None,
+         top_p=None,
+         temperature=None,
+         repeat_penalty=None,
+         streamer: OrtGenaiStreamer = None,
+         pad_token_id=None,
+         stopping_criteria=None,
+         max_length=None,
+         random_seed=1,
+     ):
+         params = og.GeneratorParams(self.model)
+
+         # OGA models return a list of tokens (older versions) or 1d numpy array (newer versions)
+         prompt_length = len(input_ids)
+
+         max_prompt_length = self.config.get("max_prompt_length")
+         if max_prompt_length and prompt_length > max_prompt_length:
+             raise ValueError(
+                 f"This prompt (length {prompt_length}) exceeds the model's "
+                 f"maximum allowed prompt length ({max_prompt_length})."
+             )
+         self.prompt_tokens = prompt_length
+
+         # There is a breaking API change in OGA 0.6.0
+         # Determine whether we should use the old or new APIs
+         # This also supports 0.6.0.dev0, which evaluates to less than 0.6.0 in Version
+         use_oga_post_6_api = (
+             Version(og.__version__) >= Version("0.6.0") or "0.6.0" in og.__version__
+         )
+         use_oga_pre_6_api = not use_oga_post_6_api
+
+         if pad_token_id:
+             params.pad_token_id = pad_token_id
+
+         # Handle max_length and max_new_tokens
+         if max_length and max_new_tokens:
+             logging.warning(
+                 "Both max_length and max_new_tokens were provided. "
+                 "max_length will take precedence. "
+                 "When setting max_length, please explicitly set max_new_tokens to None."
+             )
+         max_length_to_use = None
+         if max_length:
+             max_length_to_use = max_length
+         elif max_new_tokens:
+             max_length_to_use = prompt_length + max_new_tokens
+
+         min_length = prompt_length + min_new_tokens
+
+         if use_oga_pre_6_api:
+             params.input_ids = input_ids
+
+         if random_seed is None:
+             random_seed = -1  # In og.Generator, -1 = seed with random device
+
+         # Get search config if available, otherwise use empty dict
+         # Thanks to the empty dict, if the model doesn't have a built-in search
+         # config, the .get() calls will all just use the default values
+         search_config = {}
+         if self.config and "search" in self.config:
+             search_config = self.config["search"]
+
+         # Apply parameter hierarchy: user provided > search config > defaults
+         default_top_k = 50
+         default_top_p = 1.0
+         default_temperature = 0.7
+         default_repetition_penalty = 1.0
+
+         top_k_to_use = (
+             top_k if top_k is not None else search_config.get("top_k", default_top_k)
+         )
+         top_p_to_use = (
+             top_p if top_p is not None else search_config.get("top_p", default_top_p)
+         )
+         temperature_to_use = (
+             temperature
+             if temperature is not None
+             else search_config.get("temperature", default_temperature)
+         )
+         # Map the llamacpp name, `repeat_penalty`, to the OGA name, `repetition_penalty`
+         repetition_penalty_to_use = (
+             repeat_penalty
+             if repeat_penalty is not None
+             else search_config.get("repetition_penalty", default_repetition_penalty)
+         )
+
+         # Set search options once with all parameters
+         params.set_search_options(
+             do_sample=search_config.get("do_sample", do_sample),
+             top_k=top_k_to_use,
+             top_p=top_p_to_use,
+             temperature=temperature_to_use,
+             repetition_penalty=repetition_penalty_to_use,
+             max_length=max_length_to_use,
+             min_length=min_length,
+             early_stopping=search_config.get("early_stopping", False),
+             length_penalty=search_config.get("length_penalty", 1.0),
+             num_beams=search_config.get("num_beams", 1),
+             num_return_sequences=search_config.get("num_return_sequences", 1),
+             past_present_share_buffer=search_config.get(
+                 "past_present_share_buffer", True
+             ),
+             random_seed=random_seed,
+             # Not currently supported by OGA
+             # diversity_penalty=search_config.get('diversity_penalty', 0.0),
+             # no_repeat_ngram_size=search_config.get('no_repeat_ngram_size', 0),
+         )
+         params.try_graph_capture_with_max_batch_size(1)
+
+         generator = og.Generator(self.model, params)
+
+         if streamer is None:
+             prompt_start_time = time.perf_counter()
+             if use_oga_post_6_api:
+                 generator.append_tokens(input_ids)
+             if use_oga_pre_6_api:
+                 generator.compute_logits()
+             generator.generate_next_token()
+             prompt_end_time = time.perf_counter()
+
+             self.time_to_first_token = prompt_end_time - prompt_start_time
+
+             if max_new_tokens > 1:
+
+                 token_gen_times = []
+                 while not generator.is_done():
+                     token_gen_start_time = time.perf_counter()
+                     if use_oga_pre_6_api:
+                         generator.compute_logits()
+                     generator.generate_next_token()
+                     token_gen_end_time = time.perf_counter()
+
+                     token_gen_times.append(token_gen_end_time - token_gen_start_time)
+
+                 if token_gen_times:
+                     # List will be empty if we generated 1 or 0 tokens, and we don't
+                     # want a divide-by-zero error in those cases
+                     avg_token_gen_latency_s = sum(token_gen_times) / len(
+                         token_gen_times
+                     )
+                     self.tokens_per_second = 1 / avg_token_gen_latency_s
+
+             response = generator.get_sequence(0)
+             self.response_tokens = len(response) - self.prompt_tokens
+             return [response]
+         else:
+             if use_oga_post_6_api:
+                 generator.append_tokens(input_ids)
+             tokenizer_stream = streamer.tokenizer.tokenizer.create_stream()
+             self.response_tokens = 0
+             stop_early = False
+
+             while not generator.is_done() and not stop_early:
+                 if use_oga_pre_6_api:
+                     generator.compute_logits()
+                 generator.generate_next_token()
+                 self.response_tokens += 1
+
+                 new_token = generator.get_next_tokens()[0]
+                 new_text = tokenizer_stream.decode(new_token)
+
+                 streamer.add_text(new_text)
+
+                 if stopping_criteria is not None:
+                     if stopping_criteria[0].stop_event.is_set():
+                         stop_early = True
+
+             streamer.done()
+
+     def _model_call(self, input_ids):
+         """
+         Run the model on input_ids and get logits.
+
+         This method directly accesses model logits rather than using the full generate pipeline for
+         several important reasons:
+         1. Purpose: We need raw logits from a single forward pass, while generate() is optimized for
+            producing multiple tokens through iterative inference
+         2. Efficiency: Direct access is more efficient for logprob calculations with no
+            sampling overhead
+         3. Precision: Logprob calculations require exact control over input-to-output mapping
+         4. Consistency: Similar approach used in both HF and OGA implementations
+
+         Args:
+             input_ids: Input token IDs
+
+         Returns:
+             Logits for each token in the sequence
+         """
+         import torch
+
+         # Setup generator params
+         params = og.GeneratorParams(self.model)
+
+         # Configure for a simple forward pass
+         params.set_search_options(
+             do_sample=False,
+             temperature=0.0,
+             max_length=len(input_ids),
+         )
+
+         # Initialize generator
+         generator = og.Generator(self.model, params)
+
+         # Feed tokens to model based on API version
+         generator.append_tokens(input_ids)
+
+         # Extract logits - this returns a list of logits tensors
+         logits = generator.get_output("logits")
+
+         # Convert to torch tensor for easier processing
+         return torch.tensor(logits[0])
+
+     def _select_cont_toks(self, logits, context_len, continuation_tokens):
+         """
+         Select and process logits for continuation tokens.
+
+         Args:
+             logits: Full sequence logits
+             context_len: Length of context tokens
+             continuation_tokens: List or tensor of continuation token IDs
+
+         Returns:
+             Log probabilities for continuation tokens
+         """
+         import torch
+
+         # Extract relevant logits for continuation prediction (shift by one)
+         cont_logits = logits[
+             context_len - 1 : context_len - 1 + len(continuation_tokens)
+         ]
+
+         # Convert to torch tensors if needed
+         if not isinstance(continuation_tokens, torch.Tensor):
+             continuation_tokens = torch.tensor(continuation_tokens, dtype=torch.long)
+
+         # Apply log softmax to get log probabilities
+         log_probs = torch.log_softmax(cont_logits, dim=-1)
+
+         # Get log probs for the specific continuation tokens
+         token_log_probs = torch.gather(
+             log_probs, 1, continuation_tokens.unsqueeze(-1)
+         ).squeeze(-1)
+
+         return token_log_probs
+
+     def compute_logprobs(
+         self, text, tokenizer, prompt_length=None, logprobs=None, echo=False
+     ):
+         """
+         Compute log probabilities for all tokens in the given text.
+
+         Args:
+             text: The full text to analyze (e.g., prompt + completion)
+             prompt_length: Number of tokens in the prompt. If provided and echo=False,
+                 only completion tokens after this position will be returned.
+             logprobs: If not None, return log probabilities. Value indicates how many top
+                 alternatives to return. If True but not an integer, defaults to 5 alternatives.
+             echo: If True, include logprobs for prompt tokens. If False, only return logprobs
+                 for completion tokens.
+
+         Returns:
+             - text_offset: Character offsets for each token in the text
+             - token_logprobs: Log probability for each token
+             - tokens: The actual tokens used
+             - top_logprobs: Top alternative log probabilities for each position
+         """
+         import torch
+
+         if tokenizer is None:
+             raise ValueError("Tokenizer is required for logprob calculation")
+
+         # Encode the full text
+         tokens = tokenizer(text).input_ids  # pylint: disable=E1102
+
+         # Track character offsets for each token
+         text_offset = []
+         start_idx = 0
+
+         token_strings = []
+         for token_id in tokens:
+             token_str = tokenizer.decode([token_id])
+             token_strings.append(token_str)
+
+             # Calculate character offsets for tokens - handles cases where tokens
+             # may not directly match in the original text due to encoding differences,
+             # special characters, or tokenization artifacts
+             try:
+                 pos = text[start_idx:].find(token_str)
+                 if pos != -1:
+                     text_offset.append(start_idx + pos)
+                     start_idx += pos + len(token_str)
+                 else:
+                     text_offset.append(start_idx)
+             except (TypeError, ValueError, UnicodeError):
+                 # Fallback to current position when matching fails due to encoding issues
+                 text_offset.append(start_idx)
+
+         # Get logits from model
+         logits = self._model_call(tokens)
+
+         # Calculate log probabilities for each token
+         all_log_probs = torch.log_softmax(logits, dim=-1)
+
+         # The first token doesn't have a conditional probability
+         # For tokens after the first, get the predicted probability
+         token_log_probs = []
+         top_logprobs_list = []
+
+         # For each position, get the actual token probability and top alternatives
+         for i in range(len(tokens)):
+             # Get previous token position logits
+             if i > 0:  # First token has no preceding context
+                 prev_logits = all_log_probs[i - 1]
+                 curr_token_id = tokens[i]
+                 # Get probability of the actual token that appeared
+                 token_logprob = prev_logits[curr_token_id].item()
+                 token_log_probs.append(token_logprob)
+
+                 # Get top-k alternatives if requested
+                 if logprobs is not None:
+                     num_alternatives = logprobs if isinstance(logprobs, int) else 5
+                     topk_values, topk_indices = torch.topk(
+                         prev_logits, min(num_alternatives, prev_logits.size(-1))
+                     )
+
+                     # Create dictionary of token: logprob
+                     position_logprobs = {}
+                     for val, idx in zip(topk_values.tolist(), topk_indices.tolist()):
+                         token_str = tokenizer.decode([idx])
+                         position_logprobs[token_str] = val
+
+                     top_logprobs_list.append(position_logprobs)
+             else:
+                 # For the first token, we don't have a conditional probability
+                 token_log_probs.append(None)
+                 top_logprobs_list.append({})
+
+         # If we don't want to echo prompt tokens, filter them out
+         if not echo and prompt_length is not None:
+             # Ensure prompt_length is within bounds
+             prompt_length = min(prompt_length, len(tokens))
+
+             # Filter results to only include completion tokens
+             if prompt_length < len(tokens):
+                 filtered_text_offset = text_offset[prompt_length:]
+                 filtered_token_logprobs = token_log_probs[prompt_length:]
+                 filtered_tokens = token_strings[prompt_length:]
+                 filtered_top_logprobs = top_logprobs_list[prompt_length:]
+
+                 return (
+                     filtered_text_offset,
+                     filtered_token_logprobs,
+                     filtered_tokens,
+                     filtered_top_logprobs,
+                 )
+             else:
+                 # No completion tokens
+                 return [], [], [], []
+
+         return text_offset, token_log_probs, token_strings, top_logprobs_list
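
For orientation, a minimal sketch of how OrtGenaiModel, OrtGenaiTokenizer, and OrtGenaiStreamer might be driven together for streaming generation. The model folder path and prompt are placeholders, and the sketch assumes the OGA model folder also contains Hugging Face tokenizer files; the import path follows the file layout listed above.

    from threading import Thread

    from transformers import AutoTokenizer
    from lemonade.tools.oga.utils import (
        OrtGenaiModel,
        OrtGenaiStreamer,
        OrtGenaiTokenizer,
    )

    model_dir = "path/to/oga-model"  # placeholder: folder with genai_config.json and tokenizer files
    model = OrtGenaiModel(model_dir)
    tokenizer = OrtGenaiTokenizer(model.model, AutoTokenizer.from_pretrained(model_dir))

    input_ids = tokenizer("Hello, my name is").input_ids

    # generate() pushes decoded text chunks into the streamer's queue, so it runs in a
    # worker thread while the main thread consumes the queue via the streamer's iterator.
    streamer = OrtGenaiStreamer(tokenizer)
    worker = Thread(
        target=model.generate,
        kwargs={"input_ids": input_ids, "streamer": streamer, "max_new_tokens": 64},
    )
    worker.start()
    for text_chunk in streamer:
        print(text_chunk, end="", flush=True)
    worker.join()

Without a streamer, generate() instead returns the full token sequence and records time_to_first_token and tokens_per_second on the adapter.
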
lemonade/tools/perplexity.py
@@ -0,0 +1,147 @@
+ import os
+ import argparse
+ from lemonade.state import State
+ from lemonade.tools import Tool
+ import lemonade.common.printing as printing
+ import lemonade.common.build as build
+
+
+ class AccuracyPerplexity(Tool):
+     """
+     Measure perplexity of an LLM using the Wikitext-2 dataset.
+
+     Required input state:
+         - state.model: instance that provides a __call__() method that returns
+           output.logits and supports model.config.max_position_embeddings
+         - state.tokenizer: instance of Hugging Face PretrainedTokenizer
+
+     Output state produced: None
+
+     See docs/dev_cli/perplexity.md for more details.
+     """
+
+     unique_name = "accuracy-perplexity"
+
+     def __init__(self):
+         super().__init__(monitor_message="Measuring perplexity")
+
+     @staticmethod
+     def parser(add_help: bool = True) -> argparse.ArgumentParser:
+         parser = __class__.helpful_parser(
+             short_description="Measure perplexity score",
+             add_help=add_help,
+         )
+         return parser
+
+     def run(
+         self,
+         state: State,
+     ) -> State:
+
+         import pandas as pd
+         import torch
+         from datasets import load_dataset
+
+         try:
+             printing.log_info("Downloading dataset ...")
+             dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
+         except Exception as e:  # pylint: disable=broad-except
+             printing.log_error(f"Error during dataset load: {e}")
+             raise e
+
+         tokenizer = state.tokenizer
+         model = state.model
+         # Tokenize the entire test dataset text, joining entries with double new lines
+         encodings = tokenizer("\n\n".join(dataset["text"]), return_tensors="pt")
+
+         # Retrieve the maximum input length that the model can handle
+         try:
+             max_length = model.config.max_position_embeddings
+         except AttributeError:
+             # Some LLMs do not have the config.max_position_embeddings attribute
+             # However, most LLMs support at least 2048 context length, so this
+             # try-except will allow a few more LLMs to work
+             max_length = 2048
+         # Set stride to half of the maximum input length for overlapping window processing
+         # Refer to docs/dev_cli/perplexity.md for more information on sliding window
+         stride = max_length // 2
+         # Determine the total sequence length of the tokenized input
+         seq_len = encodings.input_ids.size(1)
+
+         negative_log_likelihoods = []
+         summary_data = []
+         prev_end_location = 0
+
+         model_results_dir = os.path.join(
+             build.output_dir(state.cache_dir, state.build_name), "perplexity"
+         )
+
+         for begin_location in range(0, seq_len, stride):
+             end_location = min(begin_location + max_length, seq_len)
+             target_len = end_location - prev_end_location
+             input_ids = encodings.input_ids[:, begin_location:end_location]
+             target_ids = input_ids.clone()
+             target_ids[:, :-target_len] = -100
+
+             # Forward pass the model to get logits
+             with torch.no_grad():
+                 try:
+                     outputs = model(input_ids, labels=target_ids)
+                     logits = outputs.logits
+                 except Exception as e:  # pylint: disable=broad-except
+                     printing.log_error(
+                         f"Error during model forward pass execution: {e}"
+                     )
+
+             # Compute loss manually for visualization
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = target_ids[..., 1:].contiguous()
+             effective_token_count = (target_ids != -100).sum().item()
+             negative_log_likelihoods.append(
+                 (outputs.loss.item(), effective_token_count)
+             )
+
+             # Decode predicted and actual next words for the last token position
+             predictions = torch.argmax(shift_logits, dim=-1)
+             predicted_tokens = predictions[:, -1]
+             actual_tokens = shift_labels[:, -1]
+
+             predicted_words = tokenizer.batch_decode(
+                 predicted_tokens, skip_special_tokens=True
+             )
+             actual_words = tokenizer.batch_decode(
+                 actual_tokens, skip_special_tokens=True
+             )
+             context = tokenizer.decode(input_ids[0, :])
+
+             summary_data.append(
+                 {
+                     "Context": context[-stride:],
+                     "Predicted next word": predicted_words,
+                     "Actual next word": actual_words,
+                     "Loss for this window": outputs.loss.item(),
+                 }
+             )
+             prev_end_location = end_location
+
+         # Total loss calculation considering the number of tokens for each segment
+         total_loss = sum(loss * count for loss, count in negative_log_likelihoods)
+         total_tokens = sum(count for _, count in negative_log_likelihoods)
+
+         # Calculate average negative_log_likelihood and perplexity
+         average_negative_log_likelihood = total_loss / total_tokens
+         perplexity = torch.exp(torch.tensor(average_negative_log_likelihood))
+
+         # Save accuracy results to stats file
+         state.save_stat("perplexity_score", float(perplexity.item()))
+
+         # Save accuracy results to CSV file
+         summary_df = pd.DataFrame(summary_data)
+         summary_df.to_csv(
+             os.path.join(model_results_dir, "summary_results.csv"), index=False
+         )
+         return state
+
+
+ # This file was originally licensed under Apache 2.0. It has been modified.
+ # Modifications Copyright (c) 2025 AMD
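
The aggregation at the end of AccuracyPerplexity.run() is the standard token-weighted form: perplexity = exp(sum_i(loss_i * n_i) / sum_i(n_i)), where loss_i is the mean loss of window i and n_i its effective (non-masked) token count. A standalone sketch of just that reduction, with illustrative numbers only:

    import math

    # (window mean loss, effective token count) pairs; the values are made up for illustration
    windows = [(2.31, 1024), (2.05, 1024), (2.18, 512)]

    total_loss = sum(loss * count for loss, count in windows)
    total_tokens = sum(count for _, count in windows)
    perplexity = math.exp(total_loss / total_tokens)  # exp of the average negative log-likelihood
    print(f"{perplexity:.2f}")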