lemonade-sdk 7.0.4__py3-none-any.whl → 8.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lemonade-sdk might be problematic. Click here for more details.
- lemonade/api.py +3 -3
- lemonade/cli.py +11 -17
- lemonade/common/build.py +0 -47
- lemonade/common/network.py +50 -0
- lemonade/common/status.py +2 -21
- lemonade/common/system_info.py +19 -4
- lemonade/profilers/memory_tracker.py +3 -1
- lemonade/tools/accuracy.py +3 -4
- lemonade/tools/adapter.py +1 -2
- lemonade/tools/{huggingface_bench.py → huggingface/bench.py} +2 -87
- lemonade/tools/huggingface/load.py +235 -0
- lemonade/tools/{huggingface_load.py → huggingface/utils.py} +87 -255
- lemonade/tools/humaneval.py +9 -3
- lemonade/tools/{llamacpp_bench.py → llamacpp/bench.py} +1 -1
- lemonade/tools/{llamacpp.py → llamacpp/load.py} +18 -2
- lemonade/tools/mmlu.py +7 -15
- lemonade/tools/{ort_genai/oga.py → oga/load.py} +31 -422
- lemonade/tools/oga/utils.py +423 -0
- lemonade/tools/perplexity.py +4 -3
- lemonade/tools/prompt.py +2 -1
- lemonade/tools/quark/quark_load.py +2 -1
- lemonade/tools/quark/quark_quantize.py +5 -5
- lemonade/tools/report/table.py +3 -3
- lemonade/tools/server/llamacpp.py +188 -45
- lemonade/tools/server/serve.py +184 -146
- lemonade/tools/server/static/favicon.ico +0 -0
- lemonade/tools/server/static/styles.css +568 -0
- lemonade/tools/server/static/webapp.html +439 -0
- lemonade/tools/server/tray.py +458 -0
- lemonade/tools/server/{port_utils.py → utils/port.py} +22 -3
- lemonade/tools/server/utils/system_tray.py +395 -0
- lemonade/tools/server/{instructions.py → webapp.py} +4 -10
- lemonade/version.py +1 -1
- lemonade_install/install.py +46 -28
- lemonade_sdk-8.0.1.dist-info/METADATA +179 -0
- lemonade_sdk-8.0.1.dist-info/RECORD +70 -0
- lemonade_server/cli.py +182 -27
- lemonade_server/model_manager.py +192 -20
- lemonade_server/pydantic_models.py +9 -4
- lemonade_server/server_models.json +5 -3
- lemonade/common/analyze_model.py +0 -26
- lemonade/common/labels.py +0 -61
- lemonade/common/onnx_helpers.py +0 -176
- lemonade/common/plugins.py +0 -10
- lemonade/common/tensor_helpers.py +0 -83
- lemonade/tools/server/static/instructions.html +0 -262
- lemonade_sdk-7.0.4.dist-info/METADATA +0 -113
- lemonade_sdk-7.0.4.dist-info/RECORD +0 -69
- /lemonade/tools/{ort_genai → oga}/__init__.py +0 -0
- /lemonade/tools/{ort_genai/oga_bench.py → oga/bench.py} +0 -0
- /lemonade/tools/server/{thread_utils.py → utils/thread.py} +0 -0
- {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/WHEEL +0 -0
- {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/entry_points.txt +0 -0
- {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/licenses/LICENSE +0 -0
- {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/licenses/NOTICE.md +0 -0
- {lemonade_sdk-7.0.4.dist-info → lemonade_sdk-8.0.1.dist-info}/top_level.txt +0 -0
lemonade/api.py
CHANGED
|
@@ -57,7 +57,7 @@ def from_pretrained(
|
|
|
57
57
|
# Huggingface supports all checkpoints, so there is nothing to check for
|
|
58
58
|
|
|
59
59
|
import torch
|
|
60
|
-
from lemonade.tools.
|
|
60
|
+
from lemonade.tools.huggingface.load import HuggingfaceLoad
|
|
61
61
|
|
|
62
62
|
state = _make_state(recipe, checkpoint)
|
|
63
63
|
|
|
@@ -73,7 +73,7 @@ def from_pretrained(
|
|
|
73
73
|
# Huggingface Transformers recipe for discrete GPU (Nvidia, Instinct, Radeon)
|
|
74
74
|
|
|
75
75
|
import torch
|
|
76
|
-
from lemonade.tools.
|
|
76
|
+
from lemonade.tools.huggingface.load import HuggingfaceLoad
|
|
77
77
|
|
|
78
78
|
state = _make_state(recipe, checkpoint)
|
|
79
79
|
|
|
@@ -87,7 +87,7 @@ def from_pretrained(
|
|
|
87
87
|
return state.model, state.tokenizer
|
|
88
88
|
|
|
89
89
|
elif recipe.startswith("oga-"):
|
|
90
|
-
import lemonade.tools.
|
|
90
|
+
import lemonade.tools.oga.load as oga
|
|
91
91
|
|
|
92
92
|
# Make sure the user chose a supported runtime, e.g., oga-cpu
|
|
93
93
|
user_backend = recipe.split("oga-")[1]
|
lemonade/cli.py
CHANGED
|
@@ -1,4 +1,8 @@
|
|
|
1
1
|
import os
|
|
2
|
+
|
|
3
|
+
# pylint: disable=C0413
|
|
4
|
+
# Prevent HF warnings from showing on every import
|
|
5
|
+
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
|
|
2
6
|
from lemonade.version import __version__ as version_number
|
|
3
7
|
from lemonade.tools import FirstTool, NiceHelpFormatter
|
|
4
8
|
from lemonade.profilers.memory_tracker import MemoryTracker
|
|
@@ -8,12 +12,12 @@ from lemonade.sequence import Sequence
|
|
|
8
12
|
from lemonade.tools.management_tools import Cache, Version, SystemInfo
|
|
9
13
|
from lemonade.state import State
|
|
10
14
|
|
|
11
|
-
from lemonade.tools.
|
|
12
|
-
|
|
13
|
-
from lemonade.tools.
|
|
14
|
-
from lemonade.tools.
|
|
15
|
-
from lemonade.tools.
|
|
16
|
-
from lemonade.tools.llamacpp import LoadLlamaCpp
|
|
15
|
+
from lemonade.tools.huggingface.load import HuggingfaceLoad
|
|
16
|
+
from lemonade.tools.huggingface.bench import HuggingfaceBench
|
|
17
|
+
from lemonade.tools.oga.load import OgaLoad
|
|
18
|
+
from lemonade.tools.oga.bench import OgaBench
|
|
19
|
+
from lemonade.tools.llamacpp.bench import LlamaCppBench
|
|
20
|
+
from lemonade.tools.llamacpp.load import LoadLlamaCpp
|
|
17
21
|
|
|
18
22
|
import lemonade.cache as cache
|
|
19
23
|
from lemonade.tools.mmlu import AccuracyMMLU
|
|
@@ -24,7 +28,6 @@ from lemonade.tools.prompt import LLMPrompt
|
|
|
24
28
|
from lemonade.tools.quark.quark_load import QuarkLoad
|
|
25
29
|
from lemonade.tools.quark.quark_quantize import QuarkQuantize
|
|
26
30
|
from lemonade.tools.report.llm_report import LemonadeReport
|
|
27
|
-
from lemonade.tools.server.serve import Server
|
|
28
31
|
|
|
29
32
|
|
|
30
33
|
def main():
|
|
@@ -40,26 +43,17 @@ def main():
|
|
|
40
43
|
LMEvalHarness,
|
|
41
44
|
LLMPrompt,
|
|
42
45
|
HuggingfaceBench,
|
|
46
|
+
OgaLoad,
|
|
43
47
|
OgaBench,
|
|
44
48
|
QuarkQuantize,
|
|
45
49
|
QuarkLoad,
|
|
46
50
|
LemonadeReport,
|
|
47
|
-
Server,
|
|
48
51
|
# Inherited from lemonade
|
|
49
52
|
Cache,
|
|
50
53
|
Version,
|
|
51
54
|
SystemInfo,
|
|
52
55
|
]
|
|
53
56
|
|
|
54
|
-
# Import onnxruntime-genai recipes
|
|
55
|
-
try:
|
|
56
|
-
from lemonade.tools.ort_genai.oga import OgaLoad
|
|
57
|
-
|
|
58
|
-
tools = tools + [OgaLoad]
|
|
59
|
-
|
|
60
|
-
except ModuleNotFoundError:
|
|
61
|
-
pass
|
|
62
|
-
|
|
63
57
|
# List the available profilers
|
|
64
58
|
profilers = [MemoryTracker]
|
|
65
59
|
|
lemonade/common/build.py
CHANGED
|
@@ -6,8 +6,6 @@ from typing import Dict
|
|
|
6
6
|
import hashlib
|
|
7
7
|
import psutil
|
|
8
8
|
import yaml
|
|
9
|
-
import torch
|
|
10
|
-
import numpy as np
|
|
11
9
|
import lemonade.common.exceptions as exp
|
|
12
10
|
|
|
13
11
|
state_file_name = "state.yaml"
|
|
@@ -101,51 +99,6 @@ def unique_id():
|
|
|
101
99
|
return hashlib.sha256(f"{pid}{start_time}".encode()).hexdigest()
|
|
102
100
|
|
|
103
101
|
|
|
104
|
-
def get_shapes_and_dtypes(inputs: dict):
|
|
105
|
-
"""
|
|
106
|
-
Return the shape and data type of each value in the inputs dict
|
|
107
|
-
"""
|
|
108
|
-
shapes = {}
|
|
109
|
-
dtypes = {}
|
|
110
|
-
for key in sorted(inputs):
|
|
111
|
-
value = inputs[key]
|
|
112
|
-
if isinstance(
|
|
113
|
-
value,
|
|
114
|
-
(list, tuple),
|
|
115
|
-
):
|
|
116
|
-
for v, i in zip(value, range(len(value))):
|
|
117
|
-
if isinstance(v, (list, tuple)):
|
|
118
|
-
# Handle nested lists/tuples, for example past_key_values
|
|
119
|
-
# in an LLM that has KV-caching enabled
|
|
120
|
-
for v2, i2 in zip(v, range(len(v))):
|
|
121
|
-
subsubkey = f"{key}[{i}][{i2}]"
|
|
122
|
-
shapes[subsubkey] = np.array(v2).shape
|
|
123
|
-
dtypes[subsubkey] = np.array(v2).dtype.name
|
|
124
|
-
else:
|
|
125
|
-
# Handle single list/tuple
|
|
126
|
-
subkey = f"{key}[{i}]"
|
|
127
|
-
shapes[subkey] = np.array(v).shape
|
|
128
|
-
dtypes[subkey] = np.array(v).dtype.name
|
|
129
|
-
elif torch.is_tensor(value):
|
|
130
|
-
shapes[key] = np.array(value.detach()).shape
|
|
131
|
-
dtypes[key] = np.array(value.detach()).dtype.name
|
|
132
|
-
elif isinstance(value, np.ndarray):
|
|
133
|
-
shapes[key] = value.shape
|
|
134
|
-
dtypes[key] = value.dtype.name
|
|
135
|
-
elif isinstance(value, (bool, int, float)):
|
|
136
|
-
shapes[key] = (1,)
|
|
137
|
-
dtypes[key] = type(value).__name__
|
|
138
|
-
elif value is None:
|
|
139
|
-
pass
|
|
140
|
-
else:
|
|
141
|
-
raise exp.Error(
|
|
142
|
-
"One of the provided inputs contains the unsupported "
|
|
143
|
-
f' type {type(value)} at key "{key}".'
|
|
144
|
-
)
|
|
145
|
-
|
|
146
|
-
return shapes, dtypes
|
|
147
|
-
|
|
148
|
-
|
|
149
102
|
class Logger:
|
|
150
103
|
"""
|
|
151
104
|
Redirects stdout to file (and console if needed)
|
lemonade/common/network.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Optional
|
|
3
|
+
import socket
|
|
4
|
+
from huggingface_hub import model_info
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def is_offline():
|
|
8
|
+
"""
|
|
9
|
+
Check if the system is offline by attempting to connect to huggingface.co.
|
|
10
|
+
|
|
11
|
+
Returns:
|
|
12
|
+
bool: True if the system is offline (cannot connect to huggingface.co),
|
|
13
|
+
False otherwise.
|
|
14
|
+
"""
|
|
15
|
+
if os.environ.get("LEMONADE_OFFLINE"):
|
|
16
|
+
return True
|
|
17
|
+
try:
|
|
18
|
+
socket.gethostbyname("huggingface.co")
|
|
19
|
+
return False
|
|
20
|
+
except socket.gaierror:
|
|
21
|
+
return True
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_base_model(checkpoint: str) -> Optional[str]:
|
|
25
|
+
"""
|
|
26
|
+
Get the base model information for a given checkpoint from the Hugging Face Hub.
|
|
27
|
+
Will auto-detect if we're offline and skip the network call in that case.
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
checkpoint: The model checkpoint to query
|
|
31
|
+
|
|
32
|
+
Returns:
|
|
33
|
+
The base model name if found, or None if not found or error occurs
|
|
34
|
+
"""
|
|
35
|
+
# Skip network call in offline mode
|
|
36
|
+
if is_offline():
|
|
37
|
+
return None
|
|
38
|
+
|
|
39
|
+
try:
|
|
40
|
+
info = model_info(checkpoint)
|
|
41
|
+
if info.cardData and "base_model" in info.cardData:
|
|
42
|
+
if info.cardData["base_model"] is not None:
|
|
43
|
+
# This is a derived model
|
|
44
|
+
return info.cardData["base_model"]
|
|
45
|
+
else:
|
|
46
|
+
# This is itself a base model
|
|
47
|
+
return [checkpoint]
|
|
48
|
+
except Exception: # pylint: disable=broad-except
|
|
49
|
+
pass
|
|
50
|
+
return None
|
lemonade/common/status.py
CHANGED
|
@@ -7,12 +7,10 @@ import dataclasses
|
|
|
7
7
|
from typing import Callable, List, Union, Dict, Optional
|
|
8
8
|
import textwrap
|
|
9
9
|
import psutil
|
|
10
|
-
import torch
|
|
11
10
|
from lemonade.common import printing
|
|
12
11
|
from lemonade.state import State
|
|
13
12
|
import lemonade.common.build as build
|
|
14
13
|
import lemonade.common.filesystem as fs
|
|
15
|
-
import lemonade.common.analyze_model as analyze_model
|
|
16
14
|
|
|
17
15
|
|
|
18
16
|
def _pretty_print_key(key: str) -> str:
|
|
@@ -64,7 +62,6 @@ class SkipFields:
|
|
|
64
62
|
|
|
65
63
|
file_name: bool = False
|
|
66
64
|
model_name: bool = False
|
|
67
|
-
parameters: bool = False
|
|
68
65
|
location: bool = False
|
|
69
66
|
input_shape: bool = False
|
|
70
67
|
build_dir: bool = False
|
|
@@ -147,18 +144,6 @@ class UniqueInvocationInfo(BasicInfo):
|
|
|
147
144
|
print(f", line {self.line}")
|
|
148
145
|
self.skip.location = True
|
|
149
146
|
|
|
150
|
-
def _print_parameters(self):
|
|
151
|
-
if self.skip.parameters or self.params is None:
|
|
152
|
-
return
|
|
153
|
-
|
|
154
|
-
# Display number of parameters and size
|
|
155
|
-
parameters_size = parameters_to_size(self.params)
|
|
156
|
-
print(
|
|
157
|
-
f"{self.indent}\tParameters:\t{'{:,}'.format(self.params)} ({parameters_size})"
|
|
158
|
-
)
|
|
159
|
-
|
|
160
|
-
self.skip.parameters = True
|
|
161
|
-
|
|
162
147
|
def _print_unique_input_shape(
|
|
163
148
|
self,
|
|
164
149
|
exec_time_formatted: str,
|
|
@@ -348,7 +333,6 @@ class UniqueInvocationInfo(BasicInfo):
|
|
|
348
333
|
if (self.depth == 0 and not model_visited) or (self.depth != 0):
|
|
349
334
|
# Print this information only once per model
|
|
350
335
|
self._print_location()
|
|
351
|
-
self._print_parameters()
|
|
352
336
|
self._print_unique_input_shape(
|
|
353
337
|
exec_time_formatted, invocation_idx, multiple_unique_invocations
|
|
354
338
|
)
|
|
@@ -362,16 +346,13 @@ class UniqueInvocationInfo(BasicInfo):
|
|
|
362
346
|
|
|
363
347
|
@dataclasses.dataclass
|
|
364
348
|
class ModelInfo(BasicInfo):
|
|
365
|
-
model:
|
|
349
|
+
model: str = None
|
|
366
350
|
old_forward: Union[Callable, None] = None
|
|
367
351
|
unique_invocations: Union[Dict[str, UniqueInvocationInfo], None] = (
|
|
368
352
|
dataclasses.field(default_factory=dict)
|
|
369
353
|
)
|
|
370
354
|
last_unique_invocation_executed: Union[str, None] = None
|
|
371
355
|
|
|
372
|
-
def __post_init__(self):
|
|
373
|
-
self.params = analyze_model.count_parameters(self.model)
|
|
374
|
-
|
|
375
356
|
|
|
376
357
|
def recursive_print(
|
|
377
358
|
models_found: Dict[str, ModelInfo],
|
|
@@ -447,7 +428,7 @@ def stop_logger_forward() -> None:
|
|
|
447
428
|
def add_to_state(
|
|
448
429
|
state: State,
|
|
449
430
|
name: str,
|
|
450
|
-
model:
|
|
431
|
+
model: str,
|
|
451
432
|
extension: str = "",
|
|
452
433
|
input_shapes: Optional[Dict] = None,
|
|
453
434
|
):
|
lemonade/common/system_info.py
CHANGED
|
@@ -3,6 +3,7 @@ import importlib.metadata
|
|
|
3
3
|
import platform
|
|
4
4
|
import re
|
|
5
5
|
import subprocess
|
|
6
|
+
import ctypes
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
class SystemInfo(ABC):
|
|
@@ -184,11 +185,25 @@ class WindowsSystemInfo(SystemInfo):
|
|
|
184
185
|
str: Windows power setting.
|
|
185
186
|
"""
|
|
186
187
|
try:
|
|
187
|
-
|
|
188
|
-
|
|
188
|
+
# Capture output as bytes
|
|
189
|
+
out_bytes = subprocess.check_output(["powercfg", "/getactivescheme"])
|
|
190
|
+
|
|
191
|
+
# Get system's OEM code page (e.g., cp437, cp850)
|
|
192
|
+
oem_cp = "cp" + str(ctypes.windll.kernel32.GetOEMCP())
|
|
193
|
+
|
|
194
|
+
# Decode using detected OEM code page
|
|
195
|
+
out = out_bytes.decode(oem_cp)
|
|
196
|
+
|
|
197
|
+
# Extract power scheme name from parentheses
|
|
198
|
+
match = re.search(r"\((.*?)\)", out)
|
|
199
|
+
if match:
|
|
200
|
+
return match.group(1)
|
|
201
|
+
return "Power scheme name not found in output"
|
|
202
|
+
|
|
189
203
|
except subprocess.CalledProcessError:
|
|
190
|
-
|
|
191
|
-
|
|
204
|
+
return "Windows power setting not found (command failed)"
|
|
205
|
+
except Exception as e: # pylint: disable=broad-except
|
|
206
|
+
return f"Error retrieving power setting: {str(e)}"
|
|
192
207
|
|
|
193
208
|
def get_dict(self) -> dict:
|
|
194
209
|
"""
|
lemonade/profilers/memory_tracker.py
CHANGED
|
@@ -2,7 +2,6 @@ import os
|
|
|
2
2
|
import time
|
|
3
3
|
import textwrap
|
|
4
4
|
from multiprocessing import Process, Queue
|
|
5
|
-
import matplotlib.pyplot as plt
|
|
6
5
|
import psutil
|
|
7
6
|
import yaml
|
|
8
7
|
import lemonade.common.filesystem as fs
|
|
@@ -101,6 +100,9 @@ class MemoryTracker(Profiler):
|
|
|
101
100
|
self.tracking_active = False
|
|
102
101
|
|
|
103
102
|
def generate_results(self, state, timestamp, _):
|
|
103
|
+
|
|
104
|
+
import matplotlib.pyplot as plt
|
|
105
|
+
|
|
104
106
|
if self.tracker_process is None:
|
|
105
107
|
return
|
|
106
108
|
|
lemonade/tools/accuracy.py
CHANGED
|
@@ -7,15 +7,11 @@ import sys
|
|
|
7
7
|
import time
|
|
8
8
|
from typing import Optional
|
|
9
9
|
|
|
10
|
-
import requests
|
|
11
|
-
|
|
12
10
|
from lemonade.state import State
|
|
13
11
|
from lemonade.tools import Tool
|
|
14
12
|
import lemonade.common.printing as printing
|
|
15
13
|
import lemonade.common.build as build
|
|
16
14
|
|
|
17
|
-
from lemonade.tools.server.thread_utils import ServerRunner
|
|
18
|
-
|
|
19
15
|
|
|
20
16
|
def is_port_in_use(port, host="localhost"):
|
|
21
17
|
"""
|
|
@@ -193,6 +189,9 @@ class LMEvalHarness(Tool):
|
|
|
193
189
|
output_path: Optional[str] = None,
|
|
194
190
|
) -> State:
|
|
195
191
|
|
|
192
|
+
import requests
|
|
193
|
+
from lemonade.tools.server.utils.thread import ServerRunner
|
|
194
|
+
|
|
196
195
|
model = state.model
|
|
197
196
|
tokenizer = state.tokenizer
|
|
198
197
|
|
lemonade/tools/adapter.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import abc
|
|
2
|
-
from transformers import AutoTokenizer
|
|
3
2
|
|
|
4
3
|
|
|
5
4
|
class ModelAdapter(abc.ABC):
|
|
@@ -31,7 +30,7 @@ class TokenizerAdapter(abc.ABC):
|
|
|
31
30
|
Base class for adapting an LLM's tokenizer to work with lemonade's standard tools
|
|
32
31
|
"""
|
|
33
32
|
|
|
34
|
-
def __init__(self, tokenizer
|
|
33
|
+
def __init__(self, tokenizer=None):
|
|
35
34
|
self.auto_tokenizer = tokenizer
|
|
36
35
|
|
|
37
36
|
@abc.abstractmethod
|
lemonade/tools/{huggingface_bench.py → huggingface/bench.py}
RENAMED
|
@@ -1,10 +1,6 @@
|
|
|
1
1
|
import argparse
|
|
2
|
-
from typing import List, Tuple
|
|
3
|
-
import time
|
|
4
2
|
import statistics
|
|
5
3
|
from statistics import StatisticsError
|
|
6
|
-
from contextlib import nullcontext
|
|
7
|
-
import torch
|
|
8
4
|
from lemonade.state import State
|
|
9
5
|
from lemonade.cache import Keys
|
|
10
6
|
from lemonade.tools.bench import Bench
|
|
@@ -12,89 +8,6 @@ from lemonade.tools.bench import Bench
|
|
|
12
8
|
default_beams = 1
|
|
13
9
|
|
|
14
10
|
|
|
15
|
-
def benchmark_huggingface_llm(
|
|
16
|
-
model: torch.nn.Module,
|
|
17
|
-
tokenizer,
|
|
18
|
-
input_ids,
|
|
19
|
-
dtype,
|
|
20
|
-
num_beams: int,
|
|
21
|
-
target_output_tokens: int,
|
|
22
|
-
iterations: int,
|
|
23
|
-
warmup_iterations: int,
|
|
24
|
-
report_progress_fn,
|
|
25
|
-
) -> List[Tuple[float, int]]:
|
|
26
|
-
|
|
27
|
-
amp_enabled = True if (dtype == torch.float16 or dtype == torch.bfloat16) else False
|
|
28
|
-
# The "if amp_enabled else nullcontext()" is to get around a bug in PyTorch 2.1
|
|
29
|
-
# where torch.cpu.amp.autocast(enabled=False) does nothing
|
|
30
|
-
with (
|
|
31
|
-
torch.cpu.amp.autocast(enabled=amp_enabled, dtype=dtype)
|
|
32
|
-
if amp_enabled
|
|
33
|
-
else nullcontext()
|
|
34
|
-
):
|
|
35
|
-
|
|
36
|
-
per_iteration_result = []
|
|
37
|
-
tokens_out_len_list = []
|
|
38
|
-
|
|
39
|
-
# Early stopping is only a valid parameter with multiple beams
|
|
40
|
-
early_stopping = num_beams > 1
|
|
41
|
-
|
|
42
|
-
with torch.no_grad(), torch.inference_mode():
|
|
43
|
-
# Don't capture time for warmup
|
|
44
|
-
for count in range(warmup_iterations):
|
|
45
|
-
outputs = model.generate(
|
|
46
|
-
input_ids,
|
|
47
|
-
num_beams=num_beams,
|
|
48
|
-
max_new_tokens=target_output_tokens,
|
|
49
|
-
min_new_tokens=target_output_tokens,
|
|
50
|
-
early_stopping=early_stopping,
|
|
51
|
-
pad_token_id=tokenizer.eos_token_id,
|
|
52
|
-
)
|
|
53
|
-
tokens_out_len_list.append(outputs.shape[1] - input_ids.shape[1])
|
|
54
|
-
report_progress_fn((count + 1) / (warmup_iterations + iterations))
|
|
55
|
-
|
|
56
|
-
for count in range(iterations):
|
|
57
|
-
# CUDA synchronization is required prior to GPU benchmarking
|
|
58
|
-
# This has no negative effect on CPU-only benchmarks, and is more robust than
|
|
59
|
-
# checking `model.device == "cuda"` since it applies to multi-GPU environments
|
|
60
|
-
# Synchronization is done before collecting the start time because this will
|
|
61
|
-
# ensure that the GPU has finished initialization tasks such as loading weights
|
|
62
|
-
if torch.cuda.is_available():
|
|
63
|
-
torch.cuda.synchronize()
|
|
64
|
-
start_time = time.perf_counter()
|
|
65
|
-
|
|
66
|
-
outputs = model.generate(
|
|
67
|
-
input_ids,
|
|
68
|
-
num_beams=num_beams,
|
|
69
|
-
max_new_tokens=target_output_tokens,
|
|
70
|
-
min_new_tokens=target_output_tokens,
|
|
71
|
-
early_stopping=early_stopping,
|
|
72
|
-
pad_token_id=tokenizer.eos_token_id,
|
|
73
|
-
)
|
|
74
|
-
|
|
75
|
-
if torch.cuda.is_available():
|
|
76
|
-
torch.cuda.synchronize()
|
|
77
|
-
end_time = time.perf_counter()
|
|
78
|
-
|
|
79
|
-
latency = end_time - start_time
|
|
80
|
-
|
|
81
|
-
token_len = outputs.shape[1] - input_ids.shape[1]
|
|
82
|
-
tokens_out_len_list.append(token_len)
|
|
83
|
-
|
|
84
|
-
# Only count an iteration if it produced enough tokens
|
|
85
|
-
if token_len >= target_output_tokens:
|
|
86
|
-
per_iteration_result.append((latency, token_len))
|
|
87
|
-
|
|
88
|
-
report_progress_fn(
|
|
89
|
-
(warmup_iterations + count + 1) / (warmup_iterations + iterations)
|
|
90
|
-
)
|
|
91
|
-
|
|
92
|
-
if not per_iteration_result:
|
|
93
|
-
raise Bench.not_enough_tokens(target_output_tokens)
|
|
94
|
-
|
|
95
|
-
return per_iteration_result, tokens_out_len_list
|
|
96
|
-
|
|
97
|
-
|
|
98
11
|
class HuggingfaceBench(Bench):
|
|
99
12
|
"""
|
|
100
13
|
Benchmarks the performance of the generate() method of an LLM loaded from
|
|
@@ -171,6 +84,8 @@ class HuggingfaceBench(Bench):
|
|
|
171
84
|
tokens_per_second = (new_tokens - 1) / (execution_latency - prefill_latency)
|
|
172
85
|
"""
|
|
173
86
|
|
|
87
|
+
from lemonade.tools.huggingface.utils import benchmark_huggingface_llm
|
|
88
|
+
|
|
174
89
|
if self.first_run_prompt:
|
|
175
90
|
if vars(state).get(Keys.MODEL) is None:
|
|
176
91
|
raise ValueError(
|