lemonade-sdk 8.1.4-py3-none-any.whl → 8.2.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of lemonade-sdk might be problematic.

Files changed (53)
  1. lemonade/cache.py +6 -1
  2. lemonade/cli.py +47 -5
  3. lemonade/common/inference_engines.py +13 -4
  4. lemonade/common/status.py +4 -4
  5. lemonade/common/system_info.py +544 -1
  6. lemonade/profilers/agt_power.py +437 -0
  7. lemonade/profilers/hwinfo_power.py +429 -0
  8. lemonade/tools/accuracy.py +143 -48
  9. lemonade/tools/adapter.py +6 -1
  10. lemonade/tools/bench.py +26 -8
  11. lemonade/tools/flm/__init__.py +1 -0
  12. lemonade/tools/flm/utils.py +303 -0
  13. lemonade/tools/huggingface/bench.py +6 -1
  14. lemonade/tools/llamacpp/bench.py +146 -27
  15. lemonade/tools/llamacpp/load.py +30 -2
  16. lemonade/tools/llamacpp/utils.py +393 -33
  17. lemonade/tools/oga/bench.py +5 -26
  18. lemonade/tools/oga/load.py +60 -121
  19. lemonade/tools/oga/migration.py +403 -0
  20. lemonade/tools/report/table.py +76 -8
  21. lemonade/tools/server/flm.py +133 -0
  22. lemonade/tools/server/llamacpp.py +220 -553
  23. lemonade/tools/server/serve.py +684 -168
  24. lemonade/tools/server/static/js/chat.js +666 -342
  25. lemonade/tools/server/static/js/model-settings.js +24 -3
  26. lemonade/tools/server/static/js/models.js +597 -73
  27. lemonade/tools/server/static/js/shared.js +79 -14
  28. lemonade/tools/server/static/logs.html +191 -0
  29. lemonade/tools/server/static/styles.css +491 -66
  30. lemonade/tools/server/static/webapp.html +83 -31
  31. lemonade/tools/server/tray.py +158 -38
  32. lemonade/tools/server/utils/macos_tray.py +226 -0
  33. lemonade/tools/server/utils/{system_tray.py → windows_tray.py} +13 -0
  34. lemonade/tools/server/webapp.py +4 -1
  35. lemonade/tools/server/wrapped_server.py +559 -0
  36. lemonade/version.py +1 -1
  37. lemonade_install/install.py +54 -611
  38. {lemonade_sdk-8.1.4.dist-info → lemonade_sdk-8.2.2.dist-info}/METADATA +29 -72
  39. lemonade_sdk-8.2.2.dist-info/RECORD +83 -0
  40. lemonade_server/cli.py +145 -37
  41. lemonade_server/model_manager.py +521 -37
  42. lemonade_server/pydantic_models.py +28 -1
  43. lemonade_server/server_models.json +246 -92
  44. lemonade_server/settings.py +39 -39
  45. lemonade/tools/quark/__init__.py +0 -0
  46. lemonade/tools/quark/quark_load.py +0 -173
  47. lemonade/tools/quark/quark_quantize.py +0 -439
  48. lemonade_sdk-8.1.4.dist-info/RECORD +0 -77
  49. {lemonade_sdk-8.1.4.dist-info → lemonade_sdk-8.2.2.dist-info}/WHEEL +0 -0
  50. {lemonade_sdk-8.1.4.dist-info → lemonade_sdk-8.2.2.dist-info}/entry_points.txt +0 -0
  51. {lemonade_sdk-8.1.4.dist-info → lemonade_sdk-8.2.2.dist-info}/licenses/LICENSE +0 -0
  52. {lemonade_sdk-8.1.4.dist-info → lemonade_sdk-8.2.2.dist-info}/licenses/NOTICE.md +0 -0
  53. {lemonade_sdk-8.1.4.dist-info → lemonade_sdk-8.2.2.dist-info}/top_level.txt +0 -0
lemonade/tools/adapter.py CHANGED
@@ -10,11 +10,14 @@ class ModelAdapter(abc.ABC):
  """
  Self-benchmarking ModelAdapters can store their results in the
  tokens_per_second and time_to_first_token members.
+ ModelAdapters that run generate in a different process can store the
+ peak memory used (bytes) by that process in the peak_wset member.
  """
  self.tokens_per_second = None
  self.time_to_first_token = None
  self.prompt_tokens = None
  self.response_tokens = None
+ self.peak_wset = None

  self.type = "generic"

@@ -27,7 +30,9 @@ class ModelAdapter(abc.ABC):
  with recipe components, which themselves may not support a lot of arguments.

  The generate method should store prompt and response lengths (in tokens)
- in the prompt_tokens and response_tokens members.
+ in the prompt_tokens and response_tokens members. If a different process is used,
+ the generate method can also store the peak memory used by that process in the
+ peak_wset member.
  """

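Hypothetical illustration of the new peak_wset contract (not part of this diff): a ModelAdapter subclass that runs generation in a child process could record that process's peak memory as a sketch like the following. The subclass, the executable name, and the polling loop are placeholders, and the import path is assumed from this file's location.

# Hypothetical example, not part of this diff.
import subprocess
import time

import psutil

from lemonade.tools.adapter import ModelAdapter  # import path assumed from this file


class SubprocessAdapter(ModelAdapter):
    def generate(self, prompt, **kwargs):
        proc = subprocess.Popen(["my-generator", prompt])  # placeholder command
        child = psutil.Process(proc.pid)
        peak_bytes = 0
        while proc.poll() is None:
            try:
                info = child.memory_info()
                # peak_wset exists on Windows; fall back to current rss elsewhere
                peak_bytes = max(peak_bytes, getattr(info, "peak_wset", info.rss))
            except psutil.NoSuchProcess:
                break
            time.sleep(0.1)
        self.peak_wset = peak_bytes or None
        # A real adapter would also set prompt_tokens and response_tokens here.
        return ""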
lemonade/tools/bench.py CHANGED
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
  import argparse
  import os
  import platform
- import psutil
  from lemonade.state import State
  from lemonade.tools import Tool
  from lemonade.cache import Keys
@@ -29,7 +28,9 @@ class Bench(Tool, ABC):
  Keys.SECONDS_TO_FIRST_TOKEN,
  Keys.STD_DEV_SECONDS_TO_FIRST_TOKEN,
  Keys.TOKEN_GENERATION_TOKENS_PER_SECOND,
+ Keys.STD_DEV_TOKENS_PER_SECOND,
  Keys.PREFILL_TOKENS_PER_SECOND,
+ Keys.STD_DEV_PREFILL_TOKENS_PER_SECOND,
  Keys.PROMPT_TOKENS,
  Keys.RESPONSE_TOKENS,
  Keys.MAX_MEMORY_USED_GBYTE,
@@ -42,7 +43,9 @@ class Bench(Tool, ABC):
  self.mean_time_to_first_token_list = []
  self.std_dev_time_to_first_token_list = []
  self.prefill_tokens_per_second_list = []
+ self.std_dev_prefill_tokens_per_second_list = []
  self.token_generation_tokens_per_second_list = []
+ self.std_dev_token_generation_tokens_per_second_list = []
  self.max_memory_used_gb_list = []

  # Max memory used can only be measured on Windows systems
@@ -88,7 +91,7 @@ class Bench(Tool, ABC):
  default=[str(default_prompt_length)],
  metavar="PROMPT",
  help="Input one or more prompts to the LLM. Three formats are supported. "
- "1) integer: use a synthetic prompt with the specified length "
+ "1) integer: use a synthetic prompt with the specified token length "
  "2) str: use a user-provided prompt string "
  "3) path/to/prompt.txt: load the prompt from a text file. "
  f"(default: {default_prompt_length}) ",
@@ -190,11 +193,6 @@ class Bench(Tool, ABC):
  )
  self.first_run_prompt = False

- if self.save_max_memory_used:
- self.max_memory_used_gb_list.append(
- psutil.Process().memory_info().peak_wset / 1024**3
- )
-
  self.set_percent_progress(None)
  self.save_stats(state)

@@ -211,7 +209,10 @@
  output_tokens,
  **kwargs,
  ):
- pass
+ """
+ The run_prompt method should append the appropriate value to each of the per prompt
+ measurement statistics lists that are members of the Bench class.
+ """

  @staticmethod
  def get_item_or_list(lst):
@@ -246,10 +247,27 @@ class Bench(Tool, ABC):
  Keys.PREFILL_TOKENS_PER_SECOND,
  self.get_item_or_list(self.prefill_tokens_per_second_list),
  )
+ if not all(
+ element is None for element in self.std_dev_prefill_tokens_per_second_list
+ ):
+ state.save_stat(
+ Keys.STD_DEV_PREFILL_TOKENS_PER_SECOND,
+ self.get_item_or_list(self.std_dev_prefill_tokens_per_second_list),
+ )
  state.save_stat(
  Keys.TOKEN_GENERATION_TOKENS_PER_SECOND,
  self.get_item_or_list(self.token_generation_tokens_per_second_list),
  )
+ if not all(
+ element is None
+ for element in self.std_dev_token_generation_tokens_per_second_list
+ ):
+ state.save_stat(
+ Keys.STD_DEV_TOKENS_PER_SECOND,
+ self.get_item_or_list(
+ self.std_dev_token_generation_tokens_per_second_list
+ ),
+ )
  if self.save_max_memory_used:
  state.save_stat(
  Keys.MAX_MEMORY_USED_GBYTE,
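For context on the run_prompt contract documented above, the sketch below shows a hypothetical Bench subclass appending one value per prompt to each per-prompt statistics list, including the new std-dev lists, so that Bench.save_stats() can persist them. It is not part of this diff: the signature is abridged to the arguments visible here and the measurement values are placeholders.

# Hypothetical example, not from this diff.
import statistics

from lemonade.tools.bench import Bench


class ExampleBench(Bench):
    unique_name = "example-bench"

    def run_prompt(self, state, report_progress_fn, prompt, iterations,
                   warmup_iterations, output_tokens, **kwargs):
        ttft_samples = [0.41, 0.39, 0.40]   # placeholder seconds-to-first-token samples
        tps_samples = [52.0, 54.5, 53.1]    # placeholder tokens/second samples

        self.input_ids_len_list.append(len(prompt.split()))
        self.tokens_out_len_list.append(output_tokens)
        self.mean_time_to_first_token_list.append(statistics.mean(ttft_samples))
        self.std_dev_time_to_first_token_list.append(statistics.stdev(ttft_samples))
        self.prefill_tokens_per_second_list.append(500.0)          # placeholder
        self.std_dev_prefill_tokens_per_second_list.append(10.0)   # placeholder
        self.token_generation_tokens_per_second_list.append(statistics.mean(tps_samples))
        self.std_dev_token_generation_tokens_per_second_list.append(
            statistics.stdev(tps_samples)
        )
        if self.save_max_memory_used:
            self.max_memory_used_gb_list.append(None)
        report_progress_fn(1.0)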
lemonade/tools/flm/__init__.py ADDED
@@ -0,0 +1 @@
+ # FLM (FastFlowLM) utilities for Lemonade SDK
lemonade/tools/flm/utils.py ADDED
@@ -0,0 +1,303 @@
+ """
+ FLM (FastFlowLM) utilities for installation, version checking, and model management.
+ """
+
+ import os
+ import logging
+ import subprocess
+ import tempfile
+ import time
+ from typing import List, Optional
+
+ import requests
+ from packaging.version import Version, InvalidVersion
+
+
+ def get_flm_latest_version() -> Optional[str]:
+ """
+ Get and return the latest FLM version from "https://github.com/FastFlowLM/FastFlowLM/tags"
+ This uses the GitHub tags API.
+ """
+ url = "https://api.github.com/repos/FastFlowLM/FastFlowLM/tags"
+ try:
+ response = requests.get(url, timeout=10)
+ response.raise_for_status()
+ tags = response.json()
+ if not tags:
+ return None
+ # Tags are sorted in reverse chronological order; find the first that looks like a version
+ for tag in tags:
+ tag_name = tag.get("name", "")
+ # Accept tags of the form v0.9.10, 0.9.10, etc.
+ if tag_name.startswith("v"):
+ version_candidate = tag_name[1:]
+ else:
+ version_candidate = tag_name
+ try:
+ # validate it's a version string
+ _ = Version(version_candidate)
+ return version_candidate
+ except InvalidVersion:
+ continue
+ return None
+ except requests.exceptions.RequestException as e:
+ logging.debug("Error retrieving latest FLM version: %s", e)
+ return None
+
+
+ def check_flm_version() -> Optional[str]:
+ """
+ Check if FLM is installed and return version, or None if not available.
+ """
+ latest_version_str = get_flm_latest_version()
+ try:
+ result = subprocess.run(
+ ["flm", "version"],
+ capture_output=True,
+ text=True,
+ check=True,
+ encoding="utf-8",
+ errors="replace",
+ )
+
+ # Parse version from output like "FLM v0.9.4"
+ output = result.stdout.strip()
+ if output.startswith("FLM v"):
+ version_str = output[5:] # Remove "FLM v" prefix
+ return version_str, latest_version_str
+ return None, latest_version_str
+
+ except (subprocess.CalledProcessError, FileNotFoundError):
+ return None, latest_version_str
+
+
+ def refresh_environment():
+ """
+ Refresh PATH to pick up newly installed executables.
+ """
+ if os.name == "nt": # Windows
+ # On Windows, we need to refresh the PATH from registry
+ import winreg
+
+ try:
+ with winreg.OpenKey(
+ winreg.HKEY_LOCAL_MACHINE,
+ r"SYSTEM\CurrentControlSet\Control\Session Manager\Environment",
+ ) as key:
+ path_value, _ = winreg.QueryValueEx(key, "PATH")
+ os.environ["PATH"] = path_value + ";" + os.environ.get("PATH", "")
+ except Exception as e: # pylint: disable=broad-except
+ logging.debug("Could not refresh PATH from registry: %s", e)
+
+ # Also try to add common installation paths
+ common_paths = [
+ r"C:\Program Files\FLM",
+ r"C:\Program Files (x86)\FLM",
+ os.path.expanduser(r"~\AppData\Local\FLM"),
+ ]
+ for path in common_paths:
+ if os.path.exists(path) and path not in os.environ.get("PATH", ""):
+ os.environ["PATH"] = path + ";" + os.environ.get("PATH", "")
+
+
+ def install_flm():
+ """
+ Check if FLM is installed and at minimum version.
+ If not, download and run the GUI installer, then wait for completion.
+ """
+ # Check current FLM installation
+ current_version, latest_version = check_flm_version()
+
+ if (
+ current_version
+ and latest_version
+ and Version(current_version) == Version(latest_version)
+ ):
+ logging.info(
+ "FLM v%s is already installed and is up to date (latest version: v%s).",
+ current_version,
+ latest_version,
+ )
+ return
+
+ if current_version:
+ if not latest_version:
+ logging.info(
+ "Unable to detect the latest FLM version; continuing with installed FLM v%s.",
+ current_version,
+ )
+ return
+ logging.info(
+ "FLM v%s is installed but below latest version v%s. Upgrading...",
+ current_version,
+ latest_version,
+ )
+ verysilent = True
+ else:
+ logging.info("FLM not found. Installing FLM v%s or later...", latest_version)
+ verysilent = False
+
+ # Download the installer
+ # pylint: disable=line-too-long
+ installer_url = "https://github.com/FastFlowLM/FastFlowLM/releases/latest/download/flm-setup.exe"
+ installer_path = os.path.join(tempfile.gettempdir(), "flm-setup.exe")
+ installer_args = [installer_path, "/VERYSILENT"] if verysilent else [installer_path]
+
+ try:
+ # Remove existing installer if present
+ if os.path.exists(installer_path):
+ os.remove(installer_path)
+
+ logging.info("Downloading FLM installer...")
+ response = requests.get(installer_url, stream=True, timeout=30)
+ response.raise_for_status()
+
+ # Save installer to disk
+ with open(installer_path, "wb") as f:
+ for chunk in response.iter_content(chunk_size=8192):
+ f.write(chunk)
+ f.flush()
+ os.fsync(f.fileno())
+
+ logging.info("Downloaded FLM installer to %s", installer_path)
+
+ # Launch the installer GUI
+ logging.warning(
+ "Launching FLM installer GUI. Please complete the installation..."
+ if not verysilent
+ else "Installing FLM..."
+ )
+
+ # Launch installer and wait for it to complete
+ if os.name == "nt": # Windows
+ process = subprocess.Popen(installer_args, shell=True)
+ else:
+ process = subprocess.Popen(installer_args)
+
+ # Wait for installer to complete
+ process.wait()
+
+ if process.returncode != 0:
+ raise RuntimeError(
+ f"FLM installer failed with exit code {process.returncode}"
+ )
+
+ logging.info("FLM installer completed successfully")
+
+ # Refresh environment to pick up new PATH entries
+ refresh_environment()
+
+ # Wait a moment for system to update
+ time.sleep(2)
+
+ # Verify installation
+ max_retries = 10
+ for attempt in range(max_retries):
+ new_version, latest_version = check_flm_version()
+ if new_version and Version(new_version) == Version(latest_version):
+ logging.info("FLM v%s successfully installed and verified", new_version)
+ return
+
+ if attempt < max_retries - 1:
+ logging.debug(
+ "FLM not yet available in PATH, retrying... (attempt %d/%d)",
+ attempt + 1,
+ max_retries,
+ )
+ time.sleep(3)
+ refresh_environment()
+
+ # Final check failed
+ raise RuntimeError(
+ "FLM installation completed but 'flm' command is not available in PATH. "
+ "Please ensure FLM is properly installed and available in your system PATH."
+ )
+
+ except requests.RequestException as e:
+ raise RuntimeError(f"Failed to download FLM installer: {e}") from e
+ except Exception as e:
+ raise RuntimeError(f"FLM installation failed: {e}") from e
+ finally:
+ # Clean up installer file
+ if os.path.exists(installer_path):
+ try:
+ os.remove(installer_path)
+ except OSError:
+ pass # Ignore cleanup errors
+
+
+ def download_flm_model(config_checkpoint, _=None, do_not_upgrade=False) -> dict:
+ """
+ Downloads the FLM model for the given configuration.
+
+ Args:
+ config_checkpoint: name of the FLM model to install.
+ _: placeholder for `config_mmproj`, which is standard
+ for WrappedServer (see llamacpp/utils.py) .
+ do_not_upgrade: whether to re-download the model if it is already
+ available.
+ """
+
+ if do_not_upgrade:
+ command = ["flm", "pull", f"{config_checkpoint}"]
+ else:
+ command = ["flm", "pull", f"{config_checkpoint}", "--force"]
+
+ subprocess.run(command, check=True)
+
+
+ def get_flm_installed_models() -> List[str]:
+ """
+ Parse FLM model list and return installed model checkpoints.
+
+ Returns:
+ List of installed FLM model checkpoints (e.g., ["llama3.2:1b", "gemma3:4b"])
+ """
+ try:
+ result = subprocess.run(
+ ["flm", "list"],
+ capture_output=True,
+ text=True,
+ check=True,
+ encoding="utf-8",
+ errors="replace",
+ )
+
+ # Check if we got valid output
+ if not result.stdout:
+ return []
+
+ installed_checkpoints = []
+
+ lines = result.stdout.strip().split("\n")
+ for line in lines:
+ line = line.strip()
+ if line.startswith("- "):
+ # Remove the leading "- " and parse the model info
+ model_info = line[2:].strip()
+
+ # Check if model is installed (✅)
+ if model_info.endswith(" ✅"):
+ checkpoint = model_info[:-2].strip()
+ installed_checkpoints.append(checkpoint)
+
+ return installed_checkpoints
+
+ except (
+ subprocess.CalledProcessError,
+ FileNotFoundError,
+ AttributeError,
+ NotADirectoryError,
+ ):
+ # FLM not installed, not available, or output parsing failed
+ return []
+
+
+ def is_flm_available() -> bool:
+ """
+ Check if FLM is available and meets minimum version requirements.
+ """
+ current_version, latest_version = check_flm_version()
+ return current_version is not None and Version(current_version) == Version(
+ latest_version
+ )
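A usage sketch for the new FLM helpers follows. It is not part of the package; the import path is assumed from the file location lemonade/tools/flm/utils.py and the model name is illustrative. Note that, despite its Optional[str] annotation, check_flm_version() returns an (installed, latest) tuple, which is how install_flm() and is_flm_available() consume it.

# Usage sketch, not part of the package.
from lemonade.tools.flm.utils import (
    download_flm_model,
    get_flm_installed_models,
    install_flm,
    is_flm_available,
)

if not is_flm_available():
    # Downloads flm-setup.exe and runs it (/VERYSILENT when upgrading an existing
    # install); raises RuntimeError if "flm" never appears on PATH afterwards.
    install_flm()

if "llama3.2:1b" not in get_flm_installed_models():  # illustrative model name
    # do_not_upgrade=True issues a plain "flm pull"; False adds "--force"
    download_flm_model("llama3.2:1b", do_not_upgrade=True)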
lemonade/tools/huggingface/bench.py CHANGED
@@ -1,6 +1,7 @@
  import argparse
  import statistics
  from statistics import StatisticsError
+ import psutil
  from lemonade.state import State
  from lemonade.cache import Keys
  from lemonade.tools.bench import Bench
@@ -75,7 +76,7 @@ class HuggingfaceBench(Bench):
  warmup_iterations: int,
  output_tokens: int,
  num_beams: int = default_beams,
- ) -> State:
+ ):
  """
  We don't have access to the internal timings of generate(), so time to first
  token (TTFT, aka prefill latency) and token/s are calculated using the following formulae:
@@ -176,6 +177,10 @@ class HuggingfaceBench(Bench):
  self.token_generation_tokens_per_second_list.append(
  (mean_token_len - 1) / mean_decode_latency
  )
+ if self.save_max_memory_used:
+ self.max_memory_used_gb_list.append(
+ psutil.Process().memory_info().peak_wset / 1024**3
+ )


  # This file was originally licensed under Apache 2.0. It has been modified.
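The peak_wset field read above is only present in psutil's memory_info() result on Windows, which matches the "max memory used can only be measured on Windows systems" note in bench.py. An illustrative, portable variant (not part of the package) could fall back to rss:

# Illustrative only, not part of the package.
import psutil

info = psutil.Process().memory_info()
# peak_wset is Windows-only; rss (current resident set size) is available everywhere
peak_bytes = getattr(info, "peak_wset", info.rss)
peak_gb = peak_bytes / 1024**3
print(f"Peak (or current) memory: {peak_gb:.2f} GB")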
lemonade/tools/llamacpp/bench.py CHANGED
@@ -2,9 +2,15 @@ import argparse
  import statistics
  from statistics import StatisticsError
  from lemonade.state import State
- from lemonade.cache import Keys
+ from lemonade.tools.tool import Tool
  from lemonade.tools.llamacpp.utils import LlamaCppAdapter
- from lemonade.tools.bench import Bench
+ from lemonade.tools.bench import (
+ Bench,
+ default_prompt_length,
+ default_iterations,
+ default_output_tokens,
+ default_warmup_runs,
+ )


  class LlamaCppBench(Bench):
@@ -14,16 +20,6 @@ class LlamaCppBench(Bench):

  unique_name = "llamacpp-bench"

- def __init__(self):
- super().__init__()
-
- # Additional statistics generated by this bench tool
- self.status_stats.insert(
- self.status_stats.index(Keys.TOKEN_GENERATION_TOKENS_PER_SECOND) + 1,
- Keys.STD_DEV_TOKENS_PER_SECOND,
- )
- self.std_dev_token_generation_tokens_per_second_list = []
-
  @staticmethod
  def parser(add_help: bool = True) -> argparse.ArgumentParser:
  parser = __class__.helpful_parser(
@@ -33,8 +29,46 @@ class LlamaCppBench(Bench):

  parser = Bench.parser(parser)

+ parser.add_argument(
+ "--cli",
+ action="store_true",
+ help="Set this flag to use llama-cli.exe to benchmark model performance. "
+ "This executable will be called once per iteration. Otherwise, "
+ "llama-bench.exe is used by default. In this default behavior behavior, "
+ "the only valid prompt format is integer token lengths. Also, the "
+ "warmup-iterations parameter is ignored and the default value for number of "
+ "threads is 16.",
+ )
+
  return parser

+ def parse(self, state: State, args, known_only=True) -> argparse.Namespace:
+ """
+ Helper function to parse CLI arguments into the args expected by run()
+ """
+
+ # Call Tool parse method, NOT the Bench parse method
+ parsed_args = Tool.parse(self, state, args, known_only)
+
+ if parsed_args.cli:
+ parsed_args = super().parse(state, args, known_only)
+ else:
+ # Make sure prompts is a list of integers
+ if parsed_args.prompts is None:
+ parsed_args.prompts = [default_prompt_length]
+ prompt_ints = []
+ for prompt_item in parsed_args.prompts:
+ if prompt_item.isdigit():
+ prompt_ints.append(int(prompt_item))
+ else:
+ raise Exception(
+ f"When not using the --cli flag to {self.unique_name}, the prompt format "
+ "must be in integer format."
+ )
+ parsed_args.prompts = prompt_ints
+
+ return parsed_args
+
  def run_prompt(
  self,
  state: State,
@@ -43,7 +77,7 @@ class LlamaCppBench(Bench):
  iterations: int,
  warmup_iterations: int,
  output_tokens: int,
- ) -> State:
+ ):
  """
  Benchmark llama.cpp model that was loaded by LoadLlamaCpp.
  """
@@ -61,6 +95,7 @@ class LlamaCppBench(Bench):

  per_iteration_tokens_per_second = []
  per_iteration_time_to_first_token = []
+ per_iteration_peak_wset = []

  for iteration in range(iterations + warmup_iterations):
  try:
@@ -69,7 +104,10 @@
  model.time_to_first_token = None
  model.tokens_per_second = None
  raw_output, stderr = model.generate(
- prompt, max_new_tokens=output_tokens, return_raw=True
+ prompt,
+ max_new_tokens=output_tokens,
+ return_raw=True,
+ save_max_memory_used=self.save_max_memory_used,
  )

  if model.time_to_first_token is None or model.tokens_per_second is None:
@@ -85,6 +123,7 @@
  if iteration > warmup_iterations - 1:
  per_iteration_tokens_per_second.append(model.tokens_per_second)
  per_iteration_time_to_first_token.append(model.time_to_first_token)
+ per_iteration_peak_wset.append(model.peak_wset)

  report_progress_fn((iteration + 1) / (warmup_iterations + iterations))

@@ -115,21 +154,101 @@
  except StatisticsError:
  # Less than 2 measurements
  self.std_dev_token_generation_tokens_per_second_list.append(None)
+ if self.save_max_memory_used:
+ filtered_list = [
+ item for item in per_iteration_peak_wset if item is not None
+ ]
+ mean_gb_used = (
+ None
+ if len(filtered_list) == 0
+ else statistics.mean(filtered_list) / 1024**3
+ )
+ self.max_memory_used_gb_list.append(mean_gb_used)
+
+ def run_llama_bench_exe(self, state, prompts, iterations, output_tokens):
+
+ if prompts is None:
+ prompts = [default_prompt_length]
+ elif isinstance(prompts, int):
+ prompts = [prompts]
+
+ state.save_stat("prompts", prompts)
+ state.save_stat("iterations", iterations)
+ state.save_stat("output_tokens", output_tokens)

- def save_stats(self, state):
- super().save_stats(state)
-
- # Save additional statistics
- if not all(
- element is None
- for element in self.std_dev_token_generation_tokens_per_second_list
- ):
- state.save_stat(
- Keys.STD_DEV_TOKENS_PER_SECOND,
- self.get_item_or_list(
- self.std_dev_token_generation_tokens_per_second_list
- ),
+ counter = 0
+ report_progress_fn = lambda x: self.set_percent_progress(
+ 100 * (counter + x) / len(prompts)
+ )
+ self.first_run_prompt = True
+ for counter, prompt in enumerate(prompts):
+ report_progress_fn(0)
+
+ self.run_prompt_llama_bench_exe(
+ state,
+ prompt,
+ iterations,
+ output_tokens,
  )
+ self.first_run_prompt = False
+
+ self.set_percent_progress(None)
+ self.save_stats(state)
+ return state
+
+ def run_prompt_llama_bench_exe(self, state, prompt, iterations, output_tokens):
+
+ model: LlamaCppAdapter = state.model
+ prompt_length, pp_tps, pp_tps_sd, tg_tps, tg_tps_sd, peak_wset = (
+ model.benchmark(prompt, iterations, output_tokens)
+ )
+ self.input_ids_len_list.append(prompt_length)
+ self.prefill_tokens_per_second_list.append(pp_tps)
+ self.std_dev_prefill_tokens_per_second_list.append(pp_tps_sd)
+ self.mean_time_to_first_token_list.append(prompt_length / pp_tps)
+ self.token_generation_tokens_per_second_list.append(tg_tps)
+ self.std_dev_token_generation_tokens_per_second_list.append(tg_tps_sd)
+ self.tokens_out_len_list.append(output_tokens * iterations)
+ if self.save_max_memory_used:
+ if peak_wset is not None:
+ self.max_memory_used_gb_list.append(peak_wset / 1024**3)
+ else:
+ self.max_memory_used_gb_list.append(None)
+
+ def run(
+ self,
+ state: State,
+ prompts: list[str] = None,
+ iterations: int = default_iterations,
+ warmup_iterations: int = default_warmup_runs,
+ output_tokens: int = default_output_tokens,
+ cli: bool = False,
+ **kwargs,
+ ) -> State:
+ """
+ Args:
+ - prompts: List of input prompts used as starting points for LLM text generation
+ - iterations: Number of benchmarking samples to take; results are
+ reported as the median and mean of the samples.
+ - warmup_iterations: Subset of the iterations to treat as warmup,
+ and not included in the results.
+ - output_tokens: Number of new tokens LLM to create.
+ - cli: Use multiple calls to llama-cpp.exe instead of llama-bench.exe
+ - kwargs: Additional parameters used by bench tools
+ """
+
+ # Check that state has the attribute model and it is a LlamaCPP model
+ if not hasattr(state, "model") or not isinstance(state.model, LlamaCppAdapter):
+ raise Exception("Load model using llamacpp-load first.")
+
+ if cli:
+ state = super().run(
+ state, prompts, iterations, warmup_iterations, output_tokens, **kwargs
+ )
+ else:
+ state = self.run_llama_bench_exe(state, prompts, iterations, output_tokens)
+
+ return state


  # This file was originally licensed under Apache 2.0. It has been modified.
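In the default llama-bench.exe path above, time to first token is not timed directly; it is derived from the reported prefill throughput as prompt_length / pp_tps (see run_prompt_llama_bench_exe). A quick worked example with illustrative numbers, not measured values:

# Illustrative arithmetic only, not part of the package.
prompt_length = 256          # synthetic prompt length in tokens
pp_tps = 512.0               # prefill tokens/second reported by llama-bench
mean_ttft = prompt_length / pp_tps
print(f"Derived time to first token: {mean_ttft:.2f} s")  # 0.50 s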