llama-benchy 0.1.1__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,28 @@
+ name: Python Tests
+
+ on:
+   push:
+     branches: [ "main" ]
+   pull_request:
+     branches: [ "main" ]
+
+ jobs:
+   test:
+     runs-on: ubuntu-latest
+     strategy:
+       matrix:
+         python-version: ["3.10", "3.11", "3.12"]
+
+     steps:
+     - uses: actions/checkout@v4
+     - name: Set up Python ${{ matrix.python-version }}
+       uses: actions/setup-python@v5
+       with:
+         python-version: ${{ matrix.python-version }}
+     - name: Install dependencies
+       run: |
+         python -m pip install --upgrade pip
+         pip install -e ".[dev]"
+     - name: Test with pytest
+       run: |
+         pytest
@@ -3,8 +3,8 @@
  venv/
  env/

- # GitHub workflows
- .github/
+ # Copilot instructions
+ .github/copilot-instructions.md

  # Python package artifacts
  __pycache__/
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: llama-benchy
- Version: 0.1.1
+ Version: 0.2.0
  Summary: llama-bench style benchmarking tool for all OpenAI-compatible LLM endpoints
  Author: eugr
  License: MIT License
@@ -41,6 +41,12 @@ Requires-Dist: openai
  Requires-Dist: requests
  Requires-Dist: tabulate
  Requires-Dist: transformers
+ Provides-Extra: dev
+ Requires-Dist: fastapi; extra == 'dev'
+ Requires-Dist: pydantic; extra == 'dev'
+ Requires-Dist: pytest; extra == 'dev'
+ Requires-Dist: pytest-asyncio; extra == 'dev'
+ Requires-Dist: uvicorn; extra == 'dev'
  Description-Content-Type: text/markdown

  # llama-benchy - llama-bench style benchmarking tool for all backends
@@ -73,12 +79,12 @@ As of January 2nd, 2026, I wasn't able to find any existing benchmarking tool th
  - Downloads a book from Project Gutenberg to use as source text for prompts to ensure better benchmarking of spec.decoding/MTP models.
  - Supports executing a command after each run (e.g., to clear cache).
  - Configurable latency measurement mode.
+ - Supports concurrent requests (`--concurrency`) to measure throughput under load.
+ - Can save results to file in Markdown, JSON, or CSV format.

  # Current Limitations

  - Evaluates against `/v1/chat/completions` endpoint only.
- - Doesn't measure throughput in concurrency mode (coming later).
- - Outputs results as a Markdown table only for now.

  ## Installation

@@ -213,6 +219,9 @@ Generally you don't need to disable prompt caching on the server, as a probabili
  - `--adapt-prompt`: Adapt prompt size based on warmup token usage delta (Default: True).
  - `--no-adapt-prompt`: Disable prompt size adaptation.
  - `--enable-prefix-caching`: Enable prefix caching performance measurement. When enabled (and depth > 0), it performs a two-step benchmark: first loading the context (reported as `ctx_pp`), then running the prompt with the cached context.
+ - `--concurrency`: List of concurrency levels (number of concurrent requests per test) (Default: [1]).
+ - `--save-result`: File to save results to.
+ - `--format`: Output format: 'md', 'json', 'csv' (Default: 'md').
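As a quick reference, the three flags added above combine with the existing required options; a hypothetical invocation (the endpoint URL, model name, and output file are placeholders):

```bash
llama-benchy \
  --base-url http://localhost:8000/v1 \
  --model my-model \
  --concurrency 1 4 8 \
  --save-result results.json \
  --format json
```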

  ### Metrics

@@ -228,6 +237,9 @@ The script attempts to estimate network or processing latency to provide "server

  #### Table Columns

+ - When `concurrency` > 1:
+   - **`t/s (total)`**: Total throughput across all concurrent requests.
+   - **`t/s (req)`**: Average throughput per individual request.
  - **`t/s` (Tokens per Second)**:
    - **For Prompt Processing (pp)**: Calculated as `Total Prompt Tokens / est_ppt`. This represents the prefill speed.
    - **For Token Generation (tg)**: Calculated as `(Total Generated Tokens - 1) / (Time of Last Token - Time of First Token)`. This represents the decode speed, excluding the first token latency.
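The two formulas above can be read as a small calculation. A minimal sketch in Python, with illustrative names and sample numbers (not llama-benchy's internal code):

```python
# Illustrative sketch of the pp/tg throughput formulas described above.
def prompt_processing_tps(total_prompt_tokens: int, est_ppt_seconds: float) -> float:
    # Prefill speed: prompt tokens divided by the estimated prompt processing time.
    return total_prompt_tokens / est_ppt_seconds

def token_generation_tps(total_generated_tokens: int,
                         first_token_time: float,
                         last_token_time: float) -> float:
    # Decode speed: the first token's latency is excluded from the window.
    return (total_generated_tokens - 1) / (last_token_time - first_token_time)

# Example: 2048 prompt tokens prefilled in ~1.9 s, then 32 tokens generated
# between t=2.00 s and t=2.62 s.
print(prompt_processing_tps(2048, 1.9))      # ~1078 t/s (pp)
print(token_generation_tps(32, 2.00, 2.62))  # ~50 t/s (tg)
```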
@@ -267,3 +279,31 @@ llama-benchy \
  ```

  This will run benchmarks for all combinations of pp (128, 256), tg (32, 64), and depth (0, 1024).
+
+ ## Development
+
+ ### Running Integration Tests
+
+ This repository includes a mock server and an integration test suite to verify `llama-benchy` logic without needing a real GPU server.
+
+ The mock server emulates:
+ - **Prompt Processing (PP):** ~1000 t/s drift-corrected.
+ - **Token Generation (TG):** ~50 t/s.
+ - **Prefix Caching:** Emulates cache hits by skipping processing time for cached prefixes (system messages).
+ - **OpenAI API Compatibility**: Serves `/v1/chat/completions` and `/v1/models`.
+
+ To run the integration tests:
+
+ ```bash
+ # Install development dependencies
+ uv sync --all-extras --dev
+
+ # Run tests
+ uv run pytest tests/test_mock_integration.py
+ ```
+
+ This test will:
+ 1. Spin up the mock server on port 8001.
+ 2. Run `llama-benchy` against it.
+ 3. Parse the JSON output.
+ 4. Verify that throughputs match the emulated speeds (PP ~1000, TG ~50) and that caching effectively increases effective throughput.
@@ -28,12 +28,12 @@ As of January 2nd, 2026, I wasn't able to find any existing benchmarking tool th
  - Downloads a book from Project Gutenberg to use as source text for prompts to ensure better benchmarking of spec.decoding/MTP models.
  - Supports executing a command after each run (e.g., to clear cache).
  - Configurable latency measurement mode.
+ - Supports concurrent requests (`--concurrency`) to measure throughput under load.
+ - Can save results to file in Markdown, JSON, or CSV format.

  # Current Limitations

  - Evaluates against `/v1/chat/completions` endpoint only.
- - Doesn't measure throughput in concurrency mode (coming later).
- - Outputs results as a Markdown table only for now.

  ## Installation

@@ -168,6 +168,9 @@ Generally you don't need to disable prompt caching on the server, as a probabili
  - `--adapt-prompt`: Adapt prompt size based on warmup token usage delta (Default: True).
  - `--no-adapt-prompt`: Disable prompt size adaptation.
  - `--enable-prefix-caching`: Enable prefix caching performance measurement. When enabled (and depth > 0), it performs a two-step benchmark: first loading the context (reported as `ctx_pp`), then running the prompt with the cached context.
+ - `--concurrency`: List of concurrency levels (number of concurrent requests per test) (Default: [1]).
+ - `--save-result`: File to save results to.
+ - `--format`: Output format: 'md', 'json', 'csv' (Default: 'md').

  ### Metrics

@@ -183,6 +186,9 @@ The script attempts to estimate network or processing latency to provide "server

  #### Table Columns

+ - When `concurrency` > 1:
+   - **`t/s (total)`**: Total throughput across all concurrent requests.
+   - **`t/s (req)`**: Average throughput per individual request.
  - **`t/s` (Tokens per Second)**:
    - **For Prompt Processing (pp)**: Calculated as `Total Prompt Tokens / est_ppt`. This represents the prefill speed.
    - **For Token Generation (tg)**: Calculated as `(Total Generated Tokens - 1) / (Time of Last Token - Time of First Token)`. This represents the decode speed, excluding the first token latency.
@@ -222,3 +228,31 @@ llama-benchy \
  ```

  This will run benchmarks for all combinations of pp (128, 256), tg (32, 64), and depth (0, 1024).
+
+ ## Development
+
+ ### Running Integration Tests
+
+ This repository includes a mock server and an integration test suite to verify `llama-benchy` logic without needing a real GPU server.
+
+ The mock server emulates:
+ - **Prompt Processing (PP):** ~1000 t/s drift-corrected.
+ - **Token Generation (TG):** ~50 t/s.
+ - **Prefix Caching:** Emulates cache hits by skipping processing time for cached prefixes (system messages).
+ - **OpenAI API Compatibility**: Serves `/v1/chat/completions` and `/v1/models`.
+
+ To run the integration tests:
+
+ ```bash
+ # Install development dependencies
+ uv sync --all-extras --dev
+
+ # Run tests
+ uv run pytest tests/test_mock_integration.py
+ ```
+
+ This test will:
+ 1. Spin up the mock server on port 8001.
+ 2. Run `llama-benchy` against it.
+ 3. Parse the JSON output.
+ 4. Verify that throughputs match the emulated speeds (PP ~1000, TG ~50) and that caching effectively increases effective throughput.
@@ -32,6 +32,15 @@ dependencies = [
      "aiohttp",
  ]

+ [project.optional-dependencies]
+ dev = [
+     "pytest",
+     "pytest-asyncio",
+     "fastapi",
+     "uvicorn",
+     "pydantic",
+ ]
+
  [project.scripts]
  llama-benchy = "llama_benchy.__main__:main"

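The new `dev` extra can be installed with either of the commands this release already uses elsewhere (pip in the CI workflow, uv in the README's Development section):

```bash
# Editable install with the dev extra (as in the CI workflow)
pip install -e ".[dev]"

# Or resolve all extras with uv (as in the Development section)
uv sync --all-extras --dev
```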
@@ -0,0 +1,45 @@
+ """
+ Main entry point for the llama-benchy CLI.
+ """
+
+ import asyncio
+ import datetime
+ from . import __version__
+ from .config import BenchmarkConfig
+ from .corpus import TokenizedCorpus
+ from .prompts import PromptGenerator
+ from .client import LLMClient
+ from .runner import BenchmarkRunner
+
+ async def main_async():
+     # 1. Parse Configuration
+     config = BenchmarkConfig.from_args()
+
+     # 2. Print Header
+     current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+     print(f"llama-benchy ({__version__})")
+     print(f"Date: {current_time}")
+     print(f"Benchmarking model: {config.model} at {config.base_url}")
+     print(f"Concurrency levels: {config.concurrency_levels}")
+
+     # 3. Prepare Data
+     corpus = TokenizedCorpus(config.book_url, config.tokenizer, config.model)
+     print(f"Total tokens available in text corpus: {len(corpus)}")
+
+     # 4. Initialize Components
+     prompt_gen = PromptGenerator(corpus)
+     client = LLMClient(config.base_url, config.api_key, config.served_model_name)
+     runner = BenchmarkRunner(config, client, prompt_gen)
+
+     # 5. Run Benchmark Suite
+     await runner.run_suite()
+
+     print(f"\nllama-benchy ({__version__})")
+     print(f"date: {current_time} | latency mode: {config.latency_mode}")
+
+ def main():
+     """Entry point for the CLI command."""
+     asyncio.run(main_async())
+
+ if __name__ == "__main__":
+     main()
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID

- __version__ = version = '0.1.1'
- __version_tuple__ = version_tuple = (0, 1, 1)
+ __version__ = version = '0.2.0'
+ __version_tuple__ = version_tuple = (0, 2, 0)

  __commit_id__ = commit_id = None
@@ -0,0 +1,199 @@
+ import time
+ import json
+ import codecs
+ import aiohttp
+ import asyncio
+ import numpy as np
+ from dataclasses import dataclass
+ from typing import Optional, List, Dict, Any
+
+ @dataclass
+ class RequestResult:
+     start_ts: float = 0.0
+     end_ts: float = 0.0
+     first_token_ts: Optional[float] = None
+     first_response_ts: Optional[float] = None
+     prompt_tokens: int = 0
+     total_tokens: int = 0
+     error: Optional[str] = None
+
+ class LLMClient:
+     def __init__(self, base_url: str, api_key: str, model_name: str):
+         self.base_url = base_url
+         self.api_key = api_key
+         self.model_name = model_name
+         self.headers = {"Authorization": f"Bearer {api_key}"}
+
+     async def measure_latency(self, session: aiohttp.ClientSession, mode: str = "api") -> float:
+         if mode == "none":
+             print("Skipping latency measurement (assuming 0 ms).")
+             return 0
+
+         print(f"Measuring latency using mode: {mode}...")
+         latencies = []
+
+         for _ in range(3):
+             start = time.perf_counter()
+             try:
+                 if mode == "api":
+                     async with session.get(f"{self.base_url}/models", headers=self.headers) as response:
+                         await response.read()
+                         latencies.append(time.perf_counter() - start)
+                 elif mode == "generation":
+                     payload = {
+                         "model": self.model_name,
+                         "messages": [{"role": "user", "content": "hello"}],
+                         "max_tokens": 1,
+                         "stream": True
+                     }
+                     async with session.post(f"{self.base_url}/chat/completions", json=payload, headers=self.headers) as response:
+                         async for _ in response.content:
+                             latencies.append(time.perf_counter() - start)
+                             break
+                         async for _ in response.content: pass
+             except Exception as e:
+                 print(f"Error measuring latency: {e}")
+
+         if latencies:
+             avg_latency = np.mean(latencies)
+             print(f"Average latency ({mode}): {avg_latency*1000:.2f} ms")
+             return avg_latency
+         return 0
+
+     async def warmup(self, session: aiohttp.ClientSession, tokenizer=None):
+         print("Warming up...")
+         warmup_text = "Warmup " * 10
+
+         delta_user = 0
+         delta_context = 0
+
+         # 1. User only
+         payload_user = {
+             "model": self.model_name,
+             "messages": [{"role": "user", "content": warmup_text}],
+             "max_tokens": 1
+         }
+
+         try:
+             async with session.post(f"{self.base_url}/chat/completions", json=payload_user, headers=self.headers) as response:
+                 response_json = await response.json()
+                 if tokenizer:
+                     if 'usage' in response_json:
+                         prompt_tokens = response_json['usage']['prompt_tokens']
+                         local_tokens = len(tokenizer.encode(warmup_text, add_special_tokens=False))
+                         delta_user = prompt_tokens - local_tokens
+                         print(f"Warmup (User only) complete. Delta: {delta_user} tokens (Server: {prompt_tokens}, Local: {local_tokens})")
+                     else:
+                         print("Warmup (User only) complete (no usage stats found).")
+                 else:
+                     print("Warmup complete.")
+
+             if tokenizer:
+                 # 2. Context Only
+                 payload_sys_empty = {
+                     "model": self.model_name,
+                     "messages": [
+                         {"role": "system", "content": warmup_text},
+                         {"role": "user", "content": ""}
+                     ],
+                     "max_tokens": 1
+                 }
+                 async with session.post(f"{self.base_url}/chat/completions", json=payload_sys_empty, headers=self.headers) as response:
+                     response_json = await response.json()
+                     if 'usage' in response_json:
+                         prompt_tokens = response_json['usage']['prompt_tokens']
+                         local_tokens = len(tokenizer.encode(warmup_text, add_special_tokens=False))
+                         delta_context = prompt_tokens - local_tokens
+                         print(f"Warmup (System+Empty) complete. Delta: {delta_context} tokens (Server: {prompt_tokens}, Local: {local_tokens})")
+                     else:
+                         delta_context = delta_user
+         except Exception as e:
+             print(f"Warmup failed: {e}")
+         return delta_user, delta_context
+
+     async def run_generation(
+         self,
+         session: aiohttp.ClientSession,
+         context_text: str,
+         prompt_text: str,
+         max_tokens: int,
+         no_cache: bool
+     ) -> RequestResult:
+
+         messages = []
+         if context_text:
+             messages.append({"role": "system", "content": context_text})
+         messages.append({"role": "user", "content": prompt_text})
+
+         result = RequestResult()
+
+         try:
+             payload = {
+                 "model": self.model_name,
+                 "messages": messages,
+                 "max_tokens": max_tokens,
+                 "stream": True,
+                 "stream_options": {"include_usage": True},
+             }
+
+             if no_cache:
+                 payload["cache_prompt"] = False
+
+             result.start_ts = time.perf_counter()
+
+             async with session.post(f"{self.base_url}/chat/completions", json=payload, headers=self.headers) as response:
+                 if response.status != 200:
+                     error_text = await response.text()
+                     result.error = f"HTTP {response.status}: {error_text}"
+                     print(result.error)
+                     return result
+
+                 decoder = codecs.getincrementaldecoder("utf-8")(errors='replace')
+                 buffer = ""
+
+                 async for chunk_bytes in response.content:
+                     chunk_time = time.perf_counter()
+                     decoded_chunk = decoder.decode(chunk_bytes, final=False)
+                     buffer += decoded_chunk
+
+                     while "\n" in buffer:
+                         line, buffer = buffer.split("\n", 1)
+                         line = line.strip()
+                         if not line:
+                             continue
+
+                         if line == 'data: [DONE]' or line == 'data:[DONE]':
+                             continue
+
+                         if line.startswith('data:'):
+                             try:
+                                 json_str = line[5:].strip()
+                                 chunk = json.loads(json_str)
+
+                                 if 'usage' in chunk:
+                                     result.prompt_tokens = chunk['usage'].get('prompt_tokens', 0)
+
+                                 if 'choices' in chunk and len(chunk['choices']) > 0:
+                                     if result.first_response_ts is None:
+                                         result.first_response_ts = chunk_time
+
+                                     delta = chunk['choices'][0].get('delta', {})
+                                     content = delta.get('content')
+                                     reasoning_content = delta.get('reasoning_content')
+                                     reasoning = delta.get('reasoning')
+
+                                     if content or reasoning_content or reasoning:
+                                         if result.first_token_ts is None:
+                                             result.first_token_ts = chunk_time
+
+                                         result.total_tokens += 1
+                             except json.JSONDecodeError:
+                                 continue
+
+                 result.end_ts = time.perf_counter()
+
+         except Exception as e:
+             print(f"Error during run: {e}")
+             result.error = str(e)
+
+         return result
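A minimal usage sketch of the new client module above, assuming a locally running OpenAI-compatible server; the base URL and model name are placeholders and error handling is kept to a minimum:

```python
import asyncio

import aiohttp

from llama_benchy.client import LLMClient

async def demo():
    # Placeholder endpoint and model name; any OpenAI-compatible server works.
    client = LLMClient("http://localhost:8000/v1", api_key="EMPTY", model_name="my-model")
    async with aiohttp.ClientSession() as session:
        latency = await client.measure_latency(session, mode="api")
        result = await client.run_generation(
            session,
            context_text="",            # no cached context
            prompt_text="Hello there",
            max_tokens=16,
            no_cache=True,              # also sends cache_prompt=false
        )
        if result.error is None:
            print(f"latency={latency * 1000:.1f} ms, tokens={result.total_tokens}")

asyncio.run(demo())
```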
@@ -0,0 +1,76 @@
+ from dataclasses import dataclass, field
+ from typing import List, Optional
+ import argparse
+ import os
+ from ._version import __version__
+
+ @dataclass
+ class BenchmarkConfig:
+     base_url: str
+     api_key: str
+     model: str
+     served_model_name: str
+     tokenizer: Optional[str]
+     pp_counts: List[int]
+     tg_counts: List[int]
+     depths: List[int]
+     num_runs: int
+     no_cache: bool
+     latency_mode: str
+     no_warmup: bool
+     adapt_prompt: bool
+     enable_prefix_caching: bool
+     book_url: str
+     post_run_cmd: Optional[str]
+     concurrency_levels: List[int]
+     save_result: Optional[str] = None
+     result_format: str = "md"
+
+     @classmethod
+     def from_args(cls):
+         parser = argparse.ArgumentParser(description="LLM Benchmark Script")
+         parser.add_argument('--version', action='version', version=f'%(prog)s {__version__}')
+         parser.add_argument("--base-url", type=str, required=True, help="OpenAI compatible endpoint URL")
+         parser.add_argument("--api-key", type=str, default="EMPTY", help="API Key for the endpoint")
+         parser.add_argument("--model", type=str, required=True, help="Model name to use for benchmarking")
+         parser.add_argument("--served-model-name", type=str, default=None, help="Model name used in API calls (defaults to --model if not specified)")
+         parser.add_argument("--tokenizer", type=str, default=None, help="HuggingFace tokenizer name (defaults to model name)")
+         parser.add_argument("--pp", type=int, nargs='+', required=False, default=[2048], help="List of prompt processing token counts - default: 2048")
+         parser.add_argument("--tg", type=int, nargs='+', required=False, default=[32], help="List of token generation counts - default: 32")
+         parser.add_argument("--depth", type=int, nargs='+', default=[0], help="List of context depths (previous conversation tokens) - default: 0")
+         parser.add_argument("--runs", type=int, default=3, help="Number of runs per test - default: 3")
+         parser.add_argument("--no-cache", action="store_true", help="Ensure unique requests to avoid prefix caching and send cache_prompt=false to the server")
+         parser.add_argument("--post-run-cmd", type=str, default=None, help="Command to execute after each test run")
+         parser.add_argument("--book-url", type=str, default="https://www.gutenberg.org/files/1661/1661-0.txt", help="URL of a book to use for text generation, defaults to Sherlock Holmes")
+         parser.add_argument("--latency-mode", type=str, default="api", choices=["api", "generation", "none"], help="Method to measure latency: 'api' (list models) - default, 'generation' (single token generation), or 'none' (skip latency measurement)")
+         parser.add_argument("--no-warmup", action="store_true", help="Skip warmup phase")
+         parser.add_argument("--adapt-prompt", action="store_true", default=True, help="Adapt prompt size based on warmup token usage delta (default: True)")
+         parser.add_argument("--no-adapt-prompt", action="store_false", dest="adapt_prompt", help="Disable prompt size adaptation")
+         parser.add_argument("--enable-prefix-caching", action="store_true", help="Enable prefix caching performance measurement")
+         parser.add_argument("--concurrency", type=int, nargs='+', default=[1], help="List of concurrency levels (number of concurrent requests per test) - default: [1]")
+         parser.add_argument("--save-result", type=str, help="File to save results to")
+         parser.add_argument("--format", type=str, default="md", choices=["md", "json", "csv"], help="Output format")
+
+         args = parser.parse_args()
+
+         return cls(
+             base_url=args.base_url,
+             api_key=args.api_key,
+             model=args.model,
+             served_model_name=args.served_model_name if args.served_model_name else args.model,
+             tokenizer=args.tokenizer,
+             pp_counts=args.pp,
+             tg_counts=args.tg,
+             depths=args.depth,
+             num_runs=args.runs,
+             no_cache=args.no_cache,
+             latency_mode=args.latency_mode,
+             no_warmup=args.no_warmup,
+             adapt_prompt=args.adapt_prompt,
+             enable_prefix_caching=args.enable_prefix_caching,
+             book_url=args.book_url,
+             post_run_cmd=args.post_run_cmd,
+             concurrency_levels=args.concurrency,
+             save_result=args.save_result,
+             result_format=args.format
+         )
@@ -0,0 +1,62 @@
+ import os
+ import hashlib
+ import requests
+ from transformers import AutoTokenizer
+
+ class TokenizedCorpus:
+     def __init__(self, book_url: str, tokenizer_name: str, model_name: str):
+         self.book_url = book_url
+         self.tokenizer = self._get_tokenizer(model_name, tokenizer_name)
+         self.tokens = self._load_data()
+
+     def _get_tokenizer(self, model_name, tokenizer_name=None):
+         try:
+             name = tokenizer_name if tokenizer_name else model_name
+             return AutoTokenizer.from_pretrained(name)
+         except Exception as e:
+             print(f"Error loading tokenizer: {e}")
+             print("Falling back to 'gpt2' tokenizer as approximation.")
+             return AutoTokenizer.from_pretrained("gpt2")
+
+     def _load_data(self):
+         try:
+             # Create cache directory if it doesn't exist
+             cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "llama-benchy")
+             os.makedirs(cache_dir, exist_ok=True)
+
+             # Generate hash of the URL for the filename
+             url_hash = hashlib.md5(self.book_url.encode()).hexdigest()
+             cache_file = os.path.join(cache_dir, f"{url_hash}.txt")
+
+             if os.path.exists(cache_file):
+                 print(f"Loading text from cache: {cache_file}")
+                 with open(cache_file, "r", encoding="utf-8") as f:
+                     text = f.read()
+             else:
+                 print(f"Downloading book from {self.book_url}...")
+                 response = requests.get(self.book_url)
+                 response.raise_for_status()
+                 text = response.text
+                 # Basic cleanup
+                 start_idx = text.find("*** START OF THE PROJECT GUTENBERG EBOOK")
+                 if start_idx != -1:
+                     text = text[start_idx:]
+
+                 # Save to cache
+                 with open(cache_file, "w", encoding="utf-8") as f:
+                     f.write(text)
+                 print(f"Saved text to cache: {cache_file}")
+
+             return self.tokenizer.encode(text, add_special_tokens=False)
+         except Exception as e:
+             print(f"Error downloading or processing book: {e}")
+             exit(1)
+
+     def get_tokenizer(self):
+         return self.tokenizer
+
+     def get_tokens(self):
+         return self.tokens
+
+     def __len__(self):
+         return len(self.tokens)
@@ -0,0 +1,54 @@
+ import uuid
+ import numpy as np
+ from typing import Tuple, List
+
+ from .corpus import TokenizedCorpus
+
+ class PromptGenerator:
+     def __init__(self, corpus: TokenizedCorpus):
+         self.corpus = corpus
+         self.tokenizer = corpus.get_tokenizer()
+         self.all_tokens = corpus.get_tokens()
+
+     def generate(self, prompt_tokens: int, context_tokens: int = 0, no_cache: bool = False) -> Tuple[str, str]:
+         """
+         Generates a single (context, prompt) pair.
+         """
+         suffix = ""
+         suffix_len = 0
+         if no_cache:
+             suffix = f" {uuid.uuid4()}"
+             suffix_len = len(self.tokenizer.encode(suffix, add_special_tokens=False))
+
+         # Adjust prompt tokens to fetch from text
+         text_prompt_tokens = max(0, prompt_tokens - suffix_len)
+
+         # Create a pool of tokens large enough
+         total_needed = text_prompt_tokens + context_tokens
+
+         # Create a local reference to tokens to potentially extend
+         current_tokens = self.all_tokens
+
+         if len(current_tokens) < total_needed:
+             # Repeat tokens if not enough
+             current_tokens = current_tokens * (total_needed // len(current_tokens) + 2)
+
+         # Pick a random start position
+         max_start = len(current_tokens) - total_needed
+         start_idx = np.random.randint(0, max_start)
+
+         selected_tokens = current_tokens[start_idx : start_idx + total_needed]
+
+         context_text = self.tokenizer.decode(selected_tokens[:context_tokens]) if context_tokens > 0 else ""
+         prompt_text = self.tokenizer.decode(selected_tokens[context_tokens:])
+
+         if no_cache:
+             prompt_text += suffix
+
+         return context_text, prompt_text
+
+     def generate_batch(self, batch_size: int, prompt_tokens: int, context_tokens: int = 0, no_cache: bool = False) -> List[Tuple[str, str]]:
+         """
+         Generates a batch of (context, prompt) pairs.
+         """
+         return [self.generate(prompt_tokens, context_tokens, no_cache) for _ in range(batch_size)]
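A minimal usage sketch of the corpus and prompt generator modules above; "gpt2" is a placeholder tokenizer/model id (any Hugging Face tokenizer works) and the book URL is the tool's documented default:

```python
from llama_benchy.corpus import TokenizedCorpus
from llama_benchy.prompts import PromptGenerator

# Placeholder tokenizer/model id; the book URL is llama-benchy's default.
corpus = TokenizedCorpus(
    book_url="https://www.gutenberg.org/files/1661/1661-0.txt",
    tokenizer_name="gpt2",
    model_name="gpt2",
)
gen = PromptGenerator(corpus)

# One (context, prompt) pair: 1024 context tokens plus a 256-token prompt,
# with a UUID suffix appended so the request defeats prefix caching.
context_text, prompt_text = gen.generate(prompt_tokens=256, context_tokens=1024, no_cache=True)

# A batch of four pairs, e.g. for a --concurrency 4 run.
pairs = gen.generate_batch(4, prompt_tokens=256, context_tokens=1024)
```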