webscout 7.4-py3-none-any.whl → 7.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of webscout might be problematic.
- webscout/AIauto.py +5 -53
- webscout/AIutel.py +8 -318
- webscout/DWEBS.py +460 -489
- webscout/Extra/YTToolkit/YTdownloader.py +14 -53
- webscout/Extra/YTToolkit/transcriber.py +12 -13
- webscout/Extra/YTToolkit/ytapi/video.py +0 -1
- webscout/Extra/__init__.py +0 -1
- webscout/Extra/autocoder/autocoder_utiles.py +0 -4
- webscout/Extra/autocoder/rawdog.py +13 -41
- webscout/Extra/gguf.py +652 -428
- webscout/Extra/weather.py +178 -156
- webscout/Extra/weather_ascii.py +70 -17
- webscout/Litlogger/core/logger.py +1 -2
- webscout/Litlogger/handlers/file.py +1 -1
- webscout/Litlogger/styles/formats.py +0 -2
- webscout/Litlogger/utils/detectors.py +0 -1
- webscout/Provider/AISEARCH/DeepFind.py +0 -1
- webscout/Provider/AISEARCH/ISou.py +1 -1
- webscout/Provider/AISEARCH/felo_search.py +0 -1
- webscout/Provider/AllenAI.py +24 -9
- webscout/Provider/C4ai.py +432 -0
- webscout/Provider/ChatGPTGratis.py +24 -56
- webscout/Provider/Cloudflare.py +18 -21
- webscout/Provider/DeepSeek.py +27 -48
- webscout/Provider/Deepinfra.py +129 -53
- webscout/Provider/Gemini.py +1 -1
- webscout/Provider/GithubChat.py +362 -0
- webscout/Provider/Glider.py +25 -8
- webscout/Provider/HF_space/qwen_qwen2.py +2 -2
- webscout/Provider/HeckAI.py +38 -5
- webscout/Provider/HuggingFaceChat.py +462 -0
- webscout/Provider/Jadve.py +20 -5
- webscout/Provider/Marcus.py +7 -50
- webscout/Provider/Netwrck.py +43 -67
- webscout/Provider/PI.py +4 -2
- webscout/Provider/Perplexitylabs.py +26 -6
- webscout/Provider/Phind.py +29 -3
- webscout/Provider/PizzaGPT.py +10 -51
- webscout/Provider/TTI/AiForce/async_aiforce.py +4 -37
- webscout/Provider/TTI/AiForce/sync_aiforce.py +41 -38
- webscout/Provider/TTI/FreeAIPlayground/__init__.py +9 -9
- webscout/Provider/TTI/FreeAIPlayground/async_freeaiplayground.py +206 -206
- webscout/Provider/TTI/FreeAIPlayground/sync_freeaiplayground.py +192 -192
- webscout/Provider/TTI/MagicStudio/__init__.py +2 -0
- webscout/Provider/TTI/MagicStudio/async_magicstudio.py +111 -0
- webscout/Provider/TTI/MagicStudio/sync_magicstudio.py +109 -0
- webscout/Provider/TTI/PollinationsAI/async_pollinations.py +5 -24
- webscout/Provider/TTI/PollinationsAI/sync_pollinations.py +2 -22
- webscout/Provider/TTI/__init__.py +2 -3
- webscout/Provider/TTI/aiarta/__init__.py +2 -0
- webscout/Provider/TTI/aiarta/async_aiarta.py +482 -0
- webscout/Provider/TTI/aiarta/sync_aiarta.py +440 -0
- webscout/Provider/TTI/fastflux/__init__.py +22 -0
- webscout/Provider/TTI/fastflux/async_fastflux.py +257 -0
- webscout/Provider/TTI/fastflux/sync_fastflux.py +247 -0
- webscout/Provider/TTS/__init__.py +2 -2
- webscout/Provider/TTS/deepgram.py +12 -39
- webscout/Provider/TTS/elevenlabs.py +14 -40
- webscout/Provider/TTS/gesserit.py +11 -35
- webscout/Provider/TTS/murfai.py +13 -39
- webscout/Provider/TTS/parler.py +17 -40
- webscout/Provider/TTS/speechma.py +180 -0
- webscout/Provider/TTS/streamElements.py +17 -44
- webscout/Provider/TextPollinationsAI.py +39 -59
- webscout/Provider/Venice.py +217 -200
- webscout/Provider/WiseCat.py +27 -5
- webscout/Provider/Youchat.py +63 -36
- webscout/Provider/__init__.py +13 -8
- webscout/Provider/akashgpt.py +28 -10
- webscout/Provider/copilot.py +416 -0
- webscout/Provider/flowith.py +196 -0
- webscout/Provider/freeaichat.py +32 -45
- webscout/Provider/granite.py +17 -53
- webscout/Provider/koala.py +20 -5
- webscout/Provider/llamatutor.py +7 -47
- webscout/Provider/llmchat.py +36 -53
- webscout/Provider/multichat.py +92 -98
- webscout/Provider/talkai.py +1 -0
- webscout/Provider/turboseek.py +3 -0
- webscout/Provider/tutorai.py +2 -0
- webscout/Provider/typegpt.py +154 -64
- webscout/Provider/x0gpt.py +3 -1
- webscout/Provider/yep.py +102 -20
- webscout/__init__.py +3 -0
- webscout/cli.py +4 -40
- webscout/conversation.py +1 -10
- webscout/exceptions.py +19 -9
- webscout/litagent/__init__.py +2 -2
- webscout/litagent/agent.py +351 -20
- webscout/litagent/constants.py +34 -5
- webscout/litprinter/__init__.py +0 -3
- webscout/models.py +181 -0
- webscout/optimizers.py +1 -1
- webscout/prompt_manager.py +2 -8
- webscout/scout/core/scout.py +1 -4
- webscout/scout/core/search_result.py +1 -1
- webscout/scout/core/text_utils.py +1 -1
- webscout/scout/core.py +2 -5
- webscout/scout/element.py +1 -1
- webscout/scout/parsers/html_parser.py +1 -1
- webscout/scout/utils.py +0 -1
- webscout/swiftcli/__init__.py +1 -3
- webscout/tempid.py +1 -1
- webscout/update_checker.py +55 -95
- webscout/version.py +1 -1
- webscout/webscout_search_async.py +1 -2
- webscout/yep_search.py +297 -297
- webscout-7.6.dist-info/LICENSE.md +146 -0
- {webscout-7.4.dist-info → webscout-7.6.dist-info}/METADATA +104 -514
- {webscout-7.4.dist-info → webscout-7.6.dist-info}/RECORD +113 -120
- webscout/Extra/autollama.py +0 -231
- webscout/Local/__init__.py +0 -10
- webscout/Local/_version.py +0 -3
- webscout/Local/formats.py +0 -747
- webscout/Local/model.py +0 -1368
- webscout/Local/samplers.py +0 -125
- webscout/Local/thread.py +0 -539
- webscout/Local/ui.py +0 -401
- webscout/Local/utils.py +0 -388
- webscout/Provider/Amigo.py +0 -274
- webscout/Provider/Bing.py +0 -243
- webscout/Provider/DiscordRocks.py +0 -253
- webscout/Provider/TTI/blackbox/__init__.py +0 -4
- webscout/Provider/TTI/blackbox/async_blackbox.py +0 -212
- webscout/Provider/TTI/blackbox/sync_blackbox.py +0 -199
- webscout/Provider/TTI/deepinfra/__init__.py +0 -4
- webscout/Provider/TTI/deepinfra/async_deepinfra.py +0 -227
- webscout/Provider/TTI/deepinfra/sync_deepinfra.py +0 -199
- webscout/Provider/TTI/imgninza/__init__.py +0 -4
- webscout/Provider/TTI/imgninza/async_ninza.py +0 -214
- webscout/Provider/TTI/imgninza/sync_ninza.py +0 -209
- webscout/Provider/TTS/voicepod.py +0 -117
- webscout/Provider/dgaf.py +0 -214
- webscout-7.4.dist-info/LICENSE.md +0 -211
- {webscout-7.4.dist-info → webscout-7.6.dist-info}/WHEEL +0 -0
- {webscout-7.4.dist-info → webscout-7.6.dist-info}/entry_points.txt +0 -0
- {webscout-7.4.dist-info → webscout-7.6.dist-info}/top_level.txt +0 -0
webscout/Local/model.py
DELETED
@@ -1,1368 +0,0 @@
-
-
-import os
-import sys
-import uuid
-import numpy as np
-
-from .utils import (
-    _SupportsWriteAndFlush,
-    UnreachableException,
-    print_version_info,
-    QuickGGUFReader,
-    print_warning,
-    print_verbose,
-    assert_type,
-    NoneType,
-    truncate,
-    softmax
-)
-
-from llama_cpp import Llama, StoppingCriteriaList
-from typing import Generator, Optional
-from .samplers import SamplerSettings
-
-
-from webscout import exceptions
-
-
-class Model:
-    """
-    A high-level abstraction of a Llama model
-
-    The following methods are available:
-    - unload:
-        Unload the model from memory
-    - reload:
-        Re-load the model, optionally changing parameters
-    - load:
-        Load the model into memory
-    - is_loaded:
-        Return `True` if the model is fully loaded, `False` otherwise
-    - tokenize:
-        Tokenize the given text, from `str` to `list[int]`
-    - detokenize:
-        Detokenize the given text, from `list[int]` or `int` to `str`
-    - get_length:
-        Return the length of the given text as measured in tokens
-    - get_tokenization_mapping:
-        Return a mapping of token IDs to tokens for a given text
-    - print_tokenization_mapping:
-        Display the tokenization map for a given text
-    - generate:
-        Generate text from an input and return it all at once when finished
-    - stream:
-        Return a Generator that yields tokens as they are generated
-    - stream_print:
-        Stream tokens to a file as they are generated
-    - ingest:
-        Ingest the given text into the model's cache, reducing the latency of
-        future generations that start with the same text
-    - candidates:
-        Return a sorted list of candidates for the next token, along with
-        their normalized probabilities
-    - print_candidates:
-        Print a sorted list of candidates for the next token, along with
-        their normalized probabilities
-
-    The following attributes are available:
-    - verbose `bool`:
-        Whether the model was loaded with `verbose=True`
-    - metadata `dict`:
-        A dictionary containing the GGUF metadata of the model
-    - context_length `int`:
-        The currently loaded context length of the model, in tokens
-    - n_ctx `int`:
-        Alias to context_length
-    - llama `llama_cpp.Llama`:
-        The underlying Llama instance
-    - vocab `list[str]`:
-        A list of all tokens in the model's vocabulary
-    - bos_token `int`:
-        The beginning-of-sequence token ID
-    - eos_token `int`:
-        The end-of-sequence token ID
-    - eot_token `int`:
-        The end-of-turn token ID (or `None` if not found)
-    - nl_token `int`:
-        The newline token ID (or `None` if not found)
-    - prefix_token `int`:
-        The infill prefix token ID (or `None` if not found)
-    - middle_token `int`:
-        The infill middle token ID (or `None` if not found)
-    - suffix_token `int`:
-        The infill suffix token ID (or `None` if not found)
-    - cls_token `int`:
-        The classifier token ID (or `None` if not found)
-    - sep_token `int`:
-        The separator token ID (or `None` if not found)
-    - filename `str`:
-        The name of the file the model was loaded from
-    - n_ctx_train `int`:
-        The native context length of the model
-    - rope_freq_base_train `float`:
-        The native RoPE frequency base (theta) value
-    - rope_freq_base `float`:
-        The currently loaded RoPE frequency base (theta) value
-    - flash_attn `bool`:
-        Whether the model was loaded with Flash Attention enabled
-    - n_vocab `int`:
-        The number of tokens in the model's vocabulary
-    - n_layer `int`:
-        The number of layers in the model
-    - n_gpu_layers `int`:
-        The number of layers offloaded to the GPU (-1 for all layers)
-    - type_k `int`:
-        The GGML data type used for the `K` cache. 1 == f16, q8_0 otherwise
-    - type_v `int`:
-        The GGML data type used for the `V` cache. 1 == f16, q8_0 otherwise
-    - n_gqa `int`:
-        The GQA (Grouped-Query Attention) factor of the model
-    - uuid `uuid.UUID`:
-        A randomly generated UUID, unique to this specific model instance
-    """
-
-    def __init__(
-        self,
-        model_path: str,
-        context_length: Optional[int] = 2048,
-        n_gpu_layers: int = 0,
-        offload_kqv: bool = True,
-        flash_attn: bool = False,
-        quantize_kv_cache: bool = False,
-        verbose: bool = False,
-        **kwargs
-    ):
-        """
-        Given the path to a GGUF file, construct a Model instance.
-
-        The model must be in GGUF format.
-
-        The following parameters are optional:
-        - context_length:
-            The context length at which to load the model, in tokens
-        - n_gpu_layers:
-            The number of layers to be offloaded to the GPU
-        - offload_kqv:
-            Whether the KQV cache (context) should be offloaded
-        - flash_attn:
-            Whether to use Flash Attention
-        - quantize_kv_cache:
-            Whether to use q8_0 values for KV cache
-        - verbose:
-            Whether to print additional backend information. `bool`
-
-        The following additional keyword arguments are also accepted:
-        - do_not_load:
-            If `True`, construct the model instance but do not load it into
-            memory yet. Call `Model.load()` before using the model
-        - debug:
-            If `True`, print additional backend information from llama.cpp
-        """
-
-        assert_type(verbose, bool, 'verbose', 'Model')
-        assert_type(model_path, str, 'model_path', 'Model')
-        if not os.path.exists(model_path):
-            raise FileNotFoundError(
-                f"Model: the given model_path {model_path!r} does not exist"
-            )
-        if os.path.isdir(model_path):
-            raise IsADirectoryError(
-                f"Model: the given model_path {model_path!r} is a directory, "
-                "not a GGUF file"
-            )
-        assert_type(context_length, (int, NoneType), 'context_length', 'Model')
-        assert_type(n_gpu_layers, int, 'n_gpu_layers', 'Model')
-        assert_type(offload_kqv, bool, 'offload_kqv', 'Model')
-        assert_type(flash_attn, bool, 'flash_attn', 'Model')
-        assert_type(quantize_kv_cache, bool, 'quantize_kv_cache', 'Model')
-
-        # save __init__ parameters for __repr__
-        self._model_path = model_path
-        self._context_length = context_length
-        self._n_gpu_layers = n_gpu_layers
-        self._offload_kqv = offload_kqv
-        self._flash_attn = flash_attn
-        self._verbose = self.verbose = verbose
-        self._quantize_kv_cache = quantize_kv_cache
-
-        _kwargs_keys = kwargs.keys() # only read once
-
-        if '__uuid' not in _kwargs_keys:
-            self.uuid = uuid.uuid4()
-        else:
-            # Model.reload() passes this kwarg to preserve the UUID
-            self.uuid = kwargs.get('__uuid')
-
-        if 'do_not_load' in _kwargs_keys:
-            if kwargs.get('do_not_load') is True:
-                # only save __init__ params to be used later in self.load()
-                return
-
-        if verbose:
-            print_version_info(file=sys.stderr)
-
-        if sys.byteorder == 'big':
-            print_warning(
-                "host is big-endian, please ensure your GGUF file is also "
-                "big-endian"
-            )
-        elif sys.byteorder == 'little':
-            if verbose:
-                print_verbose(
-                    "host is little-endian"
-                )
-        else:
-            print_warning(
-                f"unexpected value for sys.byteorder: {sys.byteorder!r}, "
-                "expected 'little' for little-endian host or 'big' for "
-                "big-endian host"
-            )
-
-        self._model_file_size_bytes = os.stat(model_path).st_size
-        self.metadata = QuickGGUFReader.load_metadata(model_path)
-
-        _debug = False
-        if 'debug' in _kwargs_keys:
-            _debug = bool(kwargs.get('debug'))
-
-        if verbose and not _debug:
-            __class__._print_metadata(self.metadata)
-
-        n_ctx_train = None
-        rope_freq_base_train = None
-        n_layer = None
-        n_attn_heads = None
-        n_kv_heads = None
-        n_gqa = None
-
-        for key in self.metadata.keys():
-            if key.endswith('.context_length'):
-                n_ctx_train = int(self.metadata[key])
-            elif key.endswith('.rope.freq_base'):
-                rope_freq_base_train = float(self.metadata[key])
-            elif key.endswith('.block_count'):
-                n_layer = int(self.metadata[key])
-            elif key.endswith('.attention.head_count'):
-                n_attn_heads = int(self.metadata[key])
-            elif key.endswith('.attention.head_count_kv'):
-                n_kv_heads = int(self.metadata[key])
-
-        if n_layer is None:
-            exc = KeyError(
-                f"GGUF file metadata does not specify n_layer"
-            )
-            exc.add_note(
-                f"GGUF file is at {self._model_path!r}"
-            )
-            raise exc
-
-        if n_ctx_train is None:
-            exc = KeyError(
-                f"GGUF file metadata does not specify a context length"
-            )
-            exc.add_note(
-                f"GGUF file is at {self._model_path!r}"
-            )
-            raise exc
-
-        if n_attn_heads is not None and n_kv_heads is not None:
-            n_gqa = int(n_attn_heads / n_kv_heads)
-
-        if context_length <= 0:
-            context_length = None
-
-        rope_freq_base = __class__._calculate_rope_freq_base(
-            n_ctx_train,
-            context_length if context_length is not None else n_ctx_train,
-            rope_freq_base_train
-        )
-
-        if context_length is None:
-            if n_ctx_train > 32768:
-                print_warning(
-                    f"you did not specify a context length, and the native "
-                    f"context length of this model is very large "
-                    f"({n_ctx_train}). defaulting to 32768 to avoid "
-                    f"out-of-memory errors. you should specify a higher "
-                    f"context length if you need it"
-                )
-                self.context_length = self.n_ctx = 32768
-            else:
-                self.context_length = self.n_ctx = n_ctx_train
-
-        elif context_length <= n_ctx_train:
-            self.context_length = self.n_ctx = context_length
-
-        elif context_length > n_ctx_train:
-            print_warning(
-                f"you have specified a context length that is greater than "
-                f"the natively supported context length of this model "
-                f"({context_length} > {n_ctx_train}). the model will still "
-                f"work, but the quality of output may be subpar. consider "
-                f"decreasing the context length to {n_ctx_train} or lower "
-                f"for best results"
-            )
-            self.context_length = self.n_ctx = context_length
-
-        else:
-            raise UnreachableException
-
-        cpu_count = int(os.cpu_count()) # only read once
-
-        if n_gpu_layers < 0 or n_gpu_layers > n_layer:
-            n_gpu_layers = n_layer
-
-        if n_gpu_layers == n_layer:
-            # fully offloaded
-            n_batch = 1024
-        else:
-            # partially offloaded
-            n_batch = 512
-
-        # NOTE: the optimal n_threads value (for text generation) is equal
-        #       to the number of physical cores (for homogenous CPUs) or
-        #       to the number of performance cores (for heterogenous CPUs)
-        #
-        #       the optimal n_threads_batch value (for prompt eval) is equal
-        #       to the total number of logical cores, regardless of
-        #       their type
-
-        n_threads = max(cpu_count//2, 1)
-        n_threads_batch = cpu_count
-
-        if flash_attn and n_gpu_layers == 0:
-            flash_attn = False
-            print_warning(
-                "disabling flash_attn because n_gpu_layers == 0"
-            )
-
-        if quantize_kv_cache:
-            # use q8_0 for K, V
-            if flash_attn:
-                type_k = 8
-                type_v = 8
-                if verbose:
-                    print_verbose(
-                        "using q8_0 KV cache"
-                    )
-            else: # llama.cpp requires flash_attn for V quantization
-                type_k = 8
-                type_v = 1
-                if verbose:
-                    print_verbose(
-                        "using q8_0 K cache, f16 V cache"
-                    )
-                    print_verbose(
-                        "to quantize V cache, flash_attn must be enabled"
-                    )
-        else:
-            # use f16 for K, V (default)
-            type_k = 1
-            type_v = 1
-
-        # guard against models with no rope_freq_base
-        if rope_freq_base is None:
-            rope_freq_base = 0
-
-        if verbose:
-            print_verbose(
-                f"attempting to load model, offloading "
-                f"{n_gpu_layers}/{n_layer} layers..."
-            )
-
-        # llama.cpp needs -ngl set to `-1`, not just n_layer
-        if n_gpu_layers >= n_layer:
-            _llama_ngl = -1
-        else:
-            _llama_ngl = n_gpu_layers
-
-        self.llama = Llama(
-            model_path=model_path,
-            n_ctx=self.context_length,
-            n_gpu_layers=_llama_ngl,
-            use_mmap=True,
-            use_mlock=False,
-            logits_all=False,
-            n_batch=n_batch,
-            n_threads=n_threads,
-            n_threads_batch=n_threads_batch,
-            rope_freq_base=rope_freq_base,
-            mul_mat_q=True,
-            offload_kqv=offload_kqv,
-            flash_attn=flash_attn,
-            type_k=type_k,
-            type_v=type_v,
-            verbose=_debug
-        )
-
-        # NOTE: llama.cpp uses the nearest multiple of 32 as the actual
-        #       context length. here we update self.context_length to reflect
-        #       this
-        self.context_length = self.n_ctx = self.llama.n_ctx()
-
-        if self.n_ctx < 512:
-            print_warning(
-                f'the currently loaded context length is less than 512 tokens '
-                f'({self.n_ctx} < 512). sometimes this can cause problems in '
-                f'llama.cpp. consider increasing the context length to at '
-                f'least 512 tokens'
-            )
-
-        try:
-            self.vocab: list[str] = self.metadata['tokenizer.ggml.tokens']
-        except (KeyError, TypeError, ValueError):
-            print_warning(
-                "could not set Model.vocab, constructing manually..."
-            )
-            self.vocab = [
-                self.llama._model.detokenize([i], special=True).decode(
-                    'utf-8', errors='ignore'
-                ) for i in range(self.llama._model.n_vocab())
-            ]
-        try:
-            self.bos_token = int(self.metadata['tokenizer.ggml.bos_token_id'])
-        except (KeyError, TypeError, ValueError):
-            self.bos_token = int(self.llama._model.token_bos())
-            if self.bos_token < 0:
-                self.bos_token = None
-                print_warning(
-                    "could not set Model.bos_token, defaulting to None"
-                )
-        try:
-            self.eos_token = int(self.metadata['tokenizer.ggml.eos_token_id'])
-        except (KeyError, TypeError, ValueError):
-            self.eos_token = int(self.llama._model.token_eos())
-            if self.eos_token < 0:
-                self.eos_token = None
-                print_warning(
-                    "could not set Model.eos_token, defaulting to None"
-                )
-
-        # These special tokens are optional
-
-        self.eot_token = int(self.llama._model.token_eot())
-        if self.eot_token < 0:
-            self.eot_token = None
-
-        self.nl_token = int(self.llama._model.token_nl())
-        if self.nl_token < 0:
-            self.nl_token = None
-
-        self.prefix_token = int(self.llama._model.token_prefix())
-        if self.prefix_token < 0:
-            self.prefix_token = None
-
-        self.middle_token = int(self.llama._model.token_middle())
-        if self.middle_token < 0:
-            self.middle_token = None
-
-        self.suffix_token = int(self.llama._model.token_suffix())
-        if self.suffix_token < 0:
-            self.suffix_token = None
-
-        self.cls_token = int(self.llama._model.token_cls())
-        if self.cls_token < 0:
-            self.cls_token = None
-
-        self.sep_token = int(self.llama._model.token_sep())
-        if self.sep_token < 0:
-            self.sep_token = None
-
-        # Misc. attributes
-        _add_bos_token = self.llama._model.add_bos_token()
-        if _add_bos_token == 1:
-            self.add_bos_token = True
-        elif _add_bos_token == 0:
-            self.add_bos_token = False
-        else:
-            self.add_bos_token = None
-            print_warning(
-                "Model.add_bos_token is unknown, defaulting to None"
-            )
-
-        _add_eos_token = self.llama._model.add_eos_token()
-        if _add_eos_token == 1:
-            self.add_eos_token = True
-        elif _add_eos_token == 0:
-            self.add_eos_token = False
-        else:
-            self.add_eos_token = None
-            print_warning(
-                "Model.add_eos_token is unknown, defaulting to None"
-            )
-
-        self.filename: str = os.path.basename(model_path)
-        self.n_ctx_train: int = n_ctx_train
-        self.rope_freq_base_train: float = rope_freq_base_train
-        self.rope_freq_base: float = rope_freq_base
-        self.n_batch: int = n_batch
-        self.n_threads: int = n_threads
-        self.n_threads_batch: int = n_threads_batch
-        self.flash_attn: bool = flash_attn
-        self.n_embd = self.llama._model.n_embd()
-        self.n_params = self.llama._model.n_params()
-        self.bpw = (8*self._model_file_size_bytes)/self.n_params
-        self.n_vocab: int = len(self.vocab)
-        self.n_layer: int = n_layer
-        self.n_gpu_layers: int = n_gpu_layers
-        self.offload_kqv = offload_kqv
-        self.is_native: bool = self.context_length <= self.n_ctx_train
-        self.type_k: int = type_k
-        self.type_v: int = type_v
-        self.n_gqa: int = n_gqa
-
-        if verbose:
-            print_verbose(
-                f"{'new' if '__uuid' not in _kwargs_keys else 'reloaded'} "
-                f"Model instance with the following attributes:"
-            )
-            print_verbose(f" uuid == {self.uuid}")
-            print_verbose(f" filename == {self.filename}")
-            print_verbose(f" n_params == {self.n_params}")
-            print_verbose(
-                f" bpw == {self.bpw} "
-                f"({__class__._get_bpw_quality_hint(self.bpw)})"
-            )
-            print_verbose(f" n_gpu_layers == {self.n_gpu_layers}")
-            print_verbose(f" n_layer == {self.n_layer}")
-            print_verbose(f" offload_kqv == {self.offload_kqv}")
-            print_verbose(f" flash_attn == {self.flash_attn}")
-            print_verbose(f" n_gqa == {self.n_gqa}")
-            print_verbose(
-                f" type_k == {self.type_k} "
-                f"({'f16' if self.type_k == 1 else 'q8_0'})"
-            )
-            print_verbose(
-                f" type_v == {self.type_v} "
-                f"({'f16' if self.type_v == 1 else 'q8_0'})"
-            )
-            print_verbose(f" n_batch == {self.n_batch}")
-            print_verbose(
-                f" n_threads == {self.n_threads}/{cpu_count}"
-            )
-            print_verbose(
-                f" n_threads_batch == {self.n_threads_batch}/{cpu_count}"
-            )
-            print_verbose(f" n_ctx_train == {self.n_ctx_train}")
-            print_verbose(f" n_ctx == {self.n_ctx}")
-            print_verbose(f" rope_freq_base_train == {self.rope_freq_base_train}")
-            print_verbose(f" rope_freq_base == {self.rope_freq_base}")
-            print_verbose(f" n_embd == {self.n_embd}")
-            print_verbose(f" n_vocab == {self.n_vocab}")
-            print_verbose(f" bos_token == {self.bos_token}")
-            print_verbose(f" eos_token == {self.eos_token}")
-            if self.eot_token is not None:
-                print_verbose(f" eot_token == {self.eot_token}")
-            if self.nl_token is not None:
-                print_verbose(f" nl_token == {self.nl_token}")
-            if self.prefix_token is not None:
-                print_verbose(f" prefix_token == {self.prefix_token}")
-            if self.middle_token is not None:
-                print_verbose(f" middle_token == {self.middle_token}")
-            if self.suffix_token is not None:
-                print_verbose(f" suffix_token == {self.suffix_token}")
-            if self.cls_token is not None:
-                print_verbose(f" cls_token == {self.cls_token}")
-            if self.sep_token is not None:
-                print_verbose(f" sep_token == {self.sep_token}")
-            print_verbose(f" add_bos_token == {self.add_bos_token}")
-            print_verbose(f" add_eos_token == {self.add_eos_token}")
-
-
-    @staticmethod
-    def _calculate_rope_freq_base(
-        n_ctx_train: int,
-        n_ctx_load: int,
-        rope_freq_base_train: Optional[float]
-    ) -> float:
-        """
-        Returns the rope_freq_base (theta) value at which model should be loaded
-        """
-        assert_type(n_ctx_train, int, 'n_ctx_train', '_calculate_rope_freq_base')
-        assert_type(n_ctx_load, int, 'n_ctx_load', '_calculate_rope_freq_base')
-        assert_type(rope_freq_base_train, (float, NoneType),
-                    'rope_freq_base_train', '_calculate_rope_freq_base')
-
-        if n_ctx_load <= n_ctx_train:
-            if rope_freq_base_train is None:
-                return 0.0
-            else:
-                return rope_freq_base_train
-
-        if rope_freq_base_train is None or rope_freq_base_train == 0.0:
-            raise ValueError(
-                'unable to load model with greater than native '
-                f'context length ({n_ctx_load} > {n_ctx_train}) '
-                'because model does not specify rope_freq_base. '
-                f'try again with context_length <= {n_ctx_train}'
-            )
-
-        return ((n_ctx_load/n_ctx_train)**(2**(1/4)))*rope_freq_base_train
-
-        # traditional formula:
-        # return (n_ctx_load/n_ctx_train)*rope_freq_base_train
-        # experimental formula A:
-        # return ((n_ctx_load/n_ctx_train)**2)*rope_freq_base_train
-        # experimental formula B:
-        # return ((n_ctx_load/n_ctx_train)**(2**(1/4)))*rope_freq_base_train
-
-
-    @staticmethod
-    def _get_bpw_quality_hint(bpw: float) -> str:
-        if 0.0 < bpw < 2.0:
-            return 'terrible'
-        elif 2.0 <= bpw < 4.0:
-            return 'bad'
-        elif 4.0 <= bpw < 5.0:
-            return 'good'
-        elif 5.0 <= bpw < 16.0:
-            return 'great'
-        elif bpw >= 16.0:
-            return 'native'
-        else:
-            raise UnreachableException
-
-
-    @staticmethod
-    def _print_metadata(
-        metadata: dict,
-        file: _SupportsWriteAndFlush = sys.stderr
-    ) -> None:
-        max_len_key = max(len(k) for k in metadata.keys())
-        print(f'webscout.Local: read model metadata from GGUF file header:', file=file)
-        for k, v in metadata.items():
-            print(
-                f'webscout.Local: {k:<{max_len_key}} : {truncate(repr(v))}',
-                file=file
-            )
-
-
-    def __repr__(self) -> str:
-        return (
-            f"Model({self._model_path!r}, "
-            f"context_length={self._context_length}, "
-            f"n_gpu_layers={self._n_gpu_layers}, "
-            f"offload_kqv={self._offload_kqv}, "
-            f"flash_attn={self._flash_attn}, "
-            f"quantize_kv_cache={self._quantize_kv_cache}, "
-            f"verbose={self._verbose})"
-        )
-
-
-    def __sizeof__(self) -> int:
-        """Returns the size of the model file on disk, NOT the memory usage"""
-        return self._model_file_size_bytes
-
-
-    def __del__(self):
-        if self.is_loaded():
-            self.unload()
-
-
-    def __enter__(self):
-        if not self.is_loaded():
-            self.load()
-        return self
-
-
-    def __exit__(self, *_):
-        if self.is_loaded():
-            self.unload()
-
-
-    def __call__(
-        self,
-        prompt: str | list[int],
-        stops: Optional[list[str | int]] = None,
-        sampler: Optional[SamplerSettings] = None
-    ) -> str:
-        """
-        `Model(...)` is a shorthand for `Model.generate(...)`
-        """
-        return self.generate(prompt=prompt, stops=stops, sampler=sampler)
-
-
-    def __eq__(self, value: object, /) -> bool:
-        if not isinstance(value, __class__):
-            return NotImplemented
-        if not (hasattr(self, 'uuid') and hasattr(value, 'uuid')):
-            raise AttributeError(
-                "At least one of the models being compared is missing the "
-                "`.uuid` attribute"
-            )
-        return self.uuid == value.uuid
-
-
-    def __hash__(self, /) -> int:
-        return hash(self.uuid)
-
-
-    def unload(self):
-        """
-        Unload the model from memory
-
-        Does nothing if the model is not loaded
-        """
-        if not self.is_loaded():
-            if self.verbose:
-                print_verbose('model already unloaded')
-            return
-
-        if self.verbose:
-            print_verbose('unloading model...')
-
-        self.llama.close()
-
-        while hasattr(self, 'llama'):
-            delattr(self, 'llama')
-
-        if self.verbose:
-            print_verbose('model unloaded')
-
-
-    def reload(
-        self,
-        context_length: Optional[int] = None,
-        n_gpu_layers: Optional[int] = None,
-        offload_kqv: Optional[bool] = None,
-        flash_attn: Optional[bool] = None,
-        quantize_kv_cache: Optional[bool] = None,
-        verbose: Optional[bool] = None
-    ):
-        """
-        Re-load the model into memory using the specified parameters
-
-        Any parameters unspecified will be unchanged
-        """
-        __uuid = self.uuid
-        self.unload()
-        self.__init__(
-            model_path = self._model_path,
-            context_length = (
-                self._context_length if context_length is None
-                else context_length
-            ),
-            n_gpu_layers = (
-                self._n_gpu_layers if n_gpu_layers is None
-                else n_gpu_layers
-            ),
-            offload_kqv = (
-                self._offload_kqv if offload_kqv is None
-                else offload_kqv
-            ),
-            flash_attn = (
-                self._flash_attn if flash_attn is None
-                else flash_attn
-            ),
-            quantize_kv_cache = (
-                self._quantize_kv_cache if quantize_kv_cache is None
-                else quantize_kv_cache
-            ),
-            verbose = (
-                self._verbose if verbose is None
-                else verbose
-            ),
-            __uuid = __uuid # do not change UUID on reload
-        )
-        assert_model_is_loaded(self)
-
-
-    def load(self) -> None:
-        """
-        Load the model into memory
-
-        Does nothing if already loaded
-        """
-        if self.is_loaded():
-            if self.verbose:
-                print_verbose('model already loaded')
-        else:
-            self.reload()
-
-
-    def is_loaded(self) -> bool:
-        """
-        Return `True` if the model is fully loaded, `False` otherwise
-        """
-        try:
-            assert_model_is_loaded(self)
-        except exceptions.ModelUnloadedException:
-            return False
-        else:
-            return True
-
-
-    def tokenize(self, text: str) -> list[int]:
-        """
-        Tokenize the given text (from `str` to `list[int]`)
-        """
-        assert_type(text, str, 'text', 'tokenize')
-        assert_model_is_loaded(self)
-        tokens = self.llama._model.tokenize(
-            text.encode('utf-8'),
-            add_bos=(
-                self.add_bos_token if self.add_bos_token is not None
-                else True
-            ),
-            special=True
-        )
-        # remove duplicate BOS tokens at the start of the text
-        while len(tokens) >= 2 and tokens[0] == self.bos_token and tokens[1] == self.bos_token:
-            tokens.pop(0)
-            if self.verbose:
-                print_verbose("tokenize: removed duplicate BOS token")
-        # remove duplicate EOS tokens at the end of the text
-        while len(tokens) >= 2 and tokens[-1] == self.eos_token and tokens[-2] == self.eos_token:
-            tokens.pop(-1)
-            if self.verbose:
-                print_verbose("tokenize: removed duplicate EOS token")
-        return tokens
-
-
-    def detokenize(self, tokens: list[int] | int) -> str:
-        """
-        Detokenize the given text (from `int` or `list[int]` to `str`)
-        """
-        assert_type(tokens, (list, int), 'tokens', 'detokenize')
-        if isinstance(tokens, int):
-            tokens = [tokens] # handle single tokens
-        for tok_id in tokens:
-            if not 0 <= tok_id < self.n_vocab:
-                raise ValueError(
-                    f"detokenize: token id {tok_id} is out of range. "
-                    f"acceptable values for this model are between 0 and "
-                    f"{self.n_vocab-1} inclusive"
-                )
-        # remove duplicate BOS tokens at the start of the text
-        while len(tokens) >= 2 and tokens[0] == self.bos_token and tokens[1] == self.bos_token:
-            tokens.pop(0)
-            if self.verbose:
-                print_verbose("detokenize: removed duplicate BOS token")
-        # remove duplicate EOS tokens at the end of the text
-        while len(tokens) >= 2 and tokens[-1] == self.eos_token and tokens[-2] == self.eos_token:
-            tokens.pop(-1)
-            if self.verbose:
-                print_verbose("detokenize: removed duplicate EOS token")
-        assert_model_is_loaded(self)
-        return self.llama._model.detokenize(
-            tokens,
-            special=True
-        ).decode('utf-8', errors='ignore')
-
-
-    def get_length(self, text: str) -> int:
-        """
-        Return the length of the given text in as measured in tokens
-        """
-        return len(self.tokenize(text))
-
-
-    def get_tokenization_mapping(
-        self,
-        text: str
-    ) -> list[tuple[int, str]]:
-        """
-        Tokenize the given text and return a list of tuples where the first
-        item in the tuple is the token ID and the second item is the
-        corresponding text
-        """
-        token_id_list: list[int] = self.tokenize(text)
-
-        return list(
-            zip(
-                token_id_list,
-                [self.detokenize(tok_id) for tok_id in token_id_list]
-            )
-        )
-
-
-    def print_tokenization_mapping(self, text: str) -> None:
-        """
-        Tokenize the given text and display a mapping of each
-        token ID and its corresponding decoded text
-
-        This is meant to be equivalent to `llama.cpp/llama-tokenize`
-        """
-        token_mapping_list = self.get_tokenization_mapping(text)
-
-        for token_id, token_text in token_mapping_list:
-            print(f"{token_id:>7} -> '{token_text}'")
-        print(f"Total number of tokens: {len(token_mapping_list)}")
-
-
-    def generate(
-        self,
-        prompt: str | list[int],
-        stops: Optional[list[str | int]] = None,
-        sampler: Optional[SamplerSettings] = None
-    ) -> str:
-        """
-        Given a prompt, return a generated string.
-
-        prompt: The text from which to generate
-
-        The following parameters are optional:
-        - stops: A list of strings and/or token IDs at which to end the generation early
-        - sampler: The SamplerSettings object used to control text generation
-        """
-
-        stops = [] if stops is None else stops
-        assert_type(stops, list, 'stops', 'generate')
-        for item in stops:
-            assert_type(
-                item,
-                (str, int),
-                "some item in parameter 'stops'",
-                'generate'
-            )
-
-        sampler = SamplerSettings() if sampler is None else sampler
-
-        if sampler.temp < 0.0:
-            print_warning(
-                f'generate: using negative temperature value {sampler.temp}'
-            )
-
-        assert_type(prompt, (str, list), 'prompt', 'generate')
-        if isinstance(prompt, list):
-            prompt_tokens = prompt
-        else:
-            if self.verbose:
-                print_verbose(
-                    "generate: tokenizing prompt"
-                )
-            prompt_tokens = self.tokenize(prompt)
-
-        input_length = len(prompt_tokens)
-
-        if input_length > self.context_length:
-            print(f'webscout.Local: raw input: {prompt_tokens}')
-            raise exceptions.ExceededContextLengthException(
-                f"generate: length of input exceeds model's context length "
-                f"({input_length} > {self.context_length})"
-            )
-        elif input_length == self.context_length:
-            print(f'webscout.Local: raw input: {prompt_tokens}')
-            raise exceptions.ExceededContextLengthException(
-                f"generate: length of input is equal to model's context "
-                f"length ({input_length} == {self.context_length}). this "
-                f"leaves no room for any new tokens to be generated"
-            )
-        elif self.verbose:
-            print_verbose(
-                f"generate: received prompt with {input_length} tokens"
-            )
-
-        stop_strs: list[str] = [stop for stop in stops if isinstance(stop, str)]
-        stop_token_ids: list[int] = [tok_id for tok_id in stops if isinstance(tok_id, int)]
-        stopping_criteria = None
-        if stop_token_ids != []:
-            def stop_on_token_ids(tokens, *args, **kwargs):
-                return tokens[-1] in stop_token_ids
-            stopping_criteria = StoppingCriteriaList([stop_on_token_ids])
-
-        if self.verbose:
-            print_verbose(f'generate: using the following sampler settings:')
-            print_verbose(f'max_len_tokens == {sampler.max_len_tokens}')
-            print_verbose(f'top_k == {sampler.top_k}')
-            print_verbose(f'top_p == {sampler.top_p}')
-            print_verbose(f'min_p == {sampler.min_p}')
-            print_verbose(f'temp == {sampler.temp}')
-            print_verbose(f'frequency_penalty == {sampler.frequency_penalty}')
-            print_verbose(f'presence_penalty == {sampler.presence_penalty}')
-            print_verbose(f'repeat_penalty == {sampler.repeat_penalty}')
-
-        assert_model_is_loaded(self)
-        return self.llama.create_completion(
-            prompt=prompt_tokens,
-            max_tokens=sampler.max_len_tokens,
-            temperature=sampler.temp,
-            top_p=sampler.top_p,
-            min_p=sampler.min_p,
-            frequency_penalty=sampler.frequency_penalty,
-            presence_penalty=sampler.presence_penalty,
-            repeat_penalty=sampler.repeat_penalty,
-            top_k=sampler.top_k,
-            stop=stop_strs,
-            stopping_criteria=stopping_criteria
-        )['choices'][0]['text']
-
-
-    def stream(
-        self,
-        prompt: str | list[int],
-        stops: Optional[list[str | int]] = None,
-        sampler: Optional[SamplerSettings] = None
-    ) -> Generator:
-        """
-        Given a prompt, return a Generator that yields dicts containing tokens.
-
-        To get the token string itself, subscript the dict with:
-
-        `['choices'][0]['text']`
-
-        prompt: The text from which to generate
-
-        The following parameters are optional:
-        - stops: A list of strings and/or token IDs at which to end the generation early
-        - sampler: The SamplerSettings object used to control text generation
-        """
-
-        stops = [] if stops is None else stops
-        assert_type(stops, list, 'stops', 'stream')
-        for item in stops:
-            assert_type(
-                item,
-                (str, int),
-                "some item in parameter 'stops'",
-                'stream'
-            )
-
-        sampler = SamplerSettings() if sampler is None else sampler
-
-        if sampler.temp < 0.0:
-            print_warning(
-                f'stream: using negative temperature value {sampler.temp}'
-            )
-
-        assert_type(prompt, (str, list), 'prompt', 'stream')
-        if isinstance(prompt, list):
-            prompt_tokens = prompt
-        else:
-            if self.verbose:
-                print_verbose(
-                    "stream: tokenizing prompt"
-                )
-            prompt_tokens = self.tokenize(prompt)
-
-        input_length = len(prompt_tokens)
-
-        if input_length > self.context_length:
-            print(f'webscout.Local: raw input: {prompt_tokens}')
-            raise exceptions.ExceededContextLengthException(
-                f"stream: length of input exceeds model's context length "
-                f"({input_length} > {self.context_length})"
-            )
-        elif input_length == self.context_length:
-            print(f'webscout.Local: raw input: {prompt_tokens}')
-            raise exceptions.ExceededContextLengthException(
-                f"stream: length of input is equal to model's context "
-                f"length ({input_length} == {self.context_length}). this "
-                f"leaves no room for any new tokens to be generated"
-            )
-        elif self.verbose:
-            print_verbose(
-                f"stream: received prompt with {input_length} tokens"
-            )
-
-        stop_strs: list[str] = [stop for stop in stops if isinstance(stop, str)]
-        stop_token_ids: list[int] = [tok_id for tok_id in stops if isinstance(tok_id, int)]
-        stopping_criteria = None
-        if stop_token_ids != []:
-            def stop_on_token_ids(tokens, *args, **kwargs):
-                return tokens[-1] in stop_token_ids
-            stopping_criteria = StoppingCriteriaList([stop_on_token_ids])
-
-        if self.verbose:
-            print_verbose(f'stream: using the following sampler settings:')
-            print_verbose(f'max_len_tokens == {sampler.max_len_tokens}')
-            print_verbose(f'top_k == {sampler.top_k}')
-            print_verbose(f'top_p == {sampler.top_p}')
-            print_verbose(f'min_p == {sampler.min_p}')
-            print_verbose(f'temp == {sampler.temp}')
-            print_verbose(f'frequency_penalty == {sampler.frequency_penalty}')
-            print_verbose(f'presence_penalty == {sampler.presence_penalty}')
-            print_verbose(f'repeat_penalty == {sampler.repeat_penalty}')
-
-        assert_model_is_loaded(self)
-        return self.llama.create_completion(
-            prompt=prompt_tokens,
-            max_tokens=sampler.max_len_tokens,
-            temperature=sampler.temp,
-            top_p=sampler.top_p,
-            min_p=sampler.min_p,
-            frequency_penalty=sampler.frequency_penalty,
-            presence_penalty=sampler.presence_penalty,
-            repeat_penalty=sampler.repeat_penalty,
-            top_k=sampler.top_k,
-            stream=True,
-            stop=stop_strs,
-            stopping_criteria=stopping_criteria
-        )
-
-
-    def stream_print(
-        self,
-        prompt: str | list[int],
-        stops: Optional[list[str | int]] = None,
-        sampler: Optional[SamplerSettings] = None,
-        end: str = '\n',
-        file: _SupportsWriteAndFlush = None,
-        flush: bool = True
-    ) -> str:
-        """
-        Given a prompt, stream text to a file as it is generated, and return
-        the generated string. The returned string does not include the `end`
-        parameter.
-
-        prompt: The text from which to generate
-
-        The following parameters are optional:
-        - stops: A list of strings and/or token IDs at which to end the generation early
-        - sampler: The SamplerSettings object used to control text generation
-        - end: A string to print after the generated text
-        - file: The file where text should be printed
-        - flush: Whether to flush the stream after each token
-        """
-
-        token_generator = self.stream(
-            prompt=prompt,
-            stops=stops,
-            sampler=sampler
-        )
-
-        file = sys.stdout if file is None else file
-
-        response = ''
-        for i in token_generator:
-            tok = i['choices'][0]['text']
-            print(tok, end='', file=file, flush=flush)
-            response += tok
-
-        # print `end`, and always flush stream after generation is done
-        print(end, end='', file=file, flush=True)
-
-        return response
-
-
-    def ingest(self, text: str | list[int]) -> None:
-        """
-        Ingest the given text into the model's cache
-        """
-
-        assert_type(text, (str, list), 'prompt', 'stream')
-        if isinstance(text, list):
-            tokens = text
-        else:
-            if self.verbose:
-                print_verbose(
-                    "ingest: tokenizing text"
-                )
-            tokens = self.tokenize(text)
-
-        input_length = len(tokens)
-
-        if input_length > self.context_length:
-            print(f'webscout.Local: raw input: {tokens}')
-            raise exceptions.ExceededContextLengthException(
-                f"ingest: length of input exceeds model's context length "
-                f"({input_length} > {self.context_length})"
-            )
-        elif input_length == self.context_length:
-            print(f'webscout.Local: raw input: {tokens}')
-            raise exceptions.ExceededContextLengthException(
-                f"ingest: length of input is equal to model's context "
-                f"length ({input_length} == {self.context_length}). this "
-                f"leaves no room for any new tokens to be generated"
-            )
-        elif self.verbose:
-            print_verbose(
-                f"ingest: ingesting {input_length} tokens"
-            )
-
-        assert_model_is_loaded(self)
-        self.llama.create_completion(
-            prompt=tokens,
-            max_tokens=2,
-            temperature=0.0
-        )
-
-
-    def candidates(
-        self,
-        prompt: str,
-        k: int = 40,
-        temp: Optional[float] = None,
-        raw_token_ids: bool = False
-    ) -> list[tuple[str, np.floating]]:
-        """
-        Given prompt `str` and k `int`, return a sorted list of the
-        top k candidates for most likely next token, along with their
-        normalized probabilities (logprobs).
-
-        The following parameters are optional:
-        - temp: The temperature to apply to the distribution
-        - raw_token_ids: If `True`, return raw token IDs instead of text tokens
-
-        If parameter `k` is <= 0, the probabilities for all tokens in the
-        vocabulary will be returned. Vocabulary sizes are often in the
-        hundred-thousands.
-        """
-
-        assert_type(prompt, str, 'prompt', 'candidates')
-        assert_type(k, int, 'k', 'candidates')
-        assert_type(temp, (float, NoneType), 'temp', 'candidates')
-        assert_model_is_loaded(self)
-        if k <= 0:
-            k = self.n_vocab
-            if self.verbose:
-                print_verbose(
-                    f"candidates: k <= 0, using n_vocab ({self.n_vocab})"
-                )
-        if not 1 <= k <= self.n_vocab:
-            raise ValueError(
-                f"candidates: k should be between 1 and {self.n_vocab} "
-                f"inclusive"
-            )
-
-        prompt_tokens = self.tokenize(prompt)
-        input_length = len(prompt_tokens)
-
-        if input_length > self.context_length:
-            print(f'webscout.Local: raw input: {prompt_tokens}')
-            raise exceptions.ExceededContextLengthException(
-                f"candidates: length of input exceeds model's context length "
-                f"({input_length} > {self.context_length})"
-            )
-        elif input_length == self.context_length:
-            print(f'webscout.Local: raw input: {prompt_tokens}')
-            raise exceptions.ExceededContextLengthException(
-                f"candidates: length of input is equal to model's context "
-                f"length ({input_length} == {self.context_length}). this "
-                f"leaves no room for any new tokens to be generated"
-            )
-
-        # it is necessary to reset the model before calling llama.eval()
-        elif self.verbose:
-            print_verbose(
-                "candidates: reset model state..."
-            )
-        self.llama.reset()
-
-        if self.verbose:
-            print_verbose(
-                "candidates: eval..."
-            )
-        self.llama.eval(prompt_tokens) # single forward pass
-
-        scores = self.llama.scores[len(prompt_tokens) - 1]
-
-        # Get the top k indices based on raw scores
-        top_k_indices = np.argpartition(scores, -k)[-k:]
-
-        # Get the scores of the top k tokens
-        top_k_scores = scores[top_k_indices]
-
-        # Apply softmax to the top k scores
-        if self.verbose:
-            print_verbose(
-                f'candidates: compute softmax over {len(top_k_scores)} '
-                f'values...'
-            )
-        normalized_scores = softmax(z=top_k_scores, T=temp)
-
-        # consider only the top k tokens
-        logprobs = [
-            (
-                self.llama._model.detokenize(
-                    [tok_id], special=True
-                ).decode('utf-8', errors='ignore'),
-                normalized_scores[i]
-            ) for i, tok_id in enumerate(top_k_indices)
-        ] if not raw_token_ids else [
-            (
-                tok_id,
-                normalized_scores[i]
-            ) for i, tok_id in enumerate(top_k_indices)
-        ]
-
-        # sort by probability
-        logprobs.sort(key=lambda x: x[1], reverse=True)
-
-        return logprobs
-
-
-    def print_candidates(
-        self,
-        prompt: str,
-        k: int = 40,
-        temp: Optional[float] = None,
-        raw_token_ids: bool = False,
-        file: _SupportsWriteAndFlush = None,
-    ) -> None:
-        """
-        Given prompt `str` and k `int`, print a sorted list of the
-        top k candidates for most likely next token, along with their
-        normalized probabilities (logprobs).
-
-        The following parameters are optional:
-        - temp: The temperature to apply to the distribution
-        - raw_token_ids: If `True`, print raw token IDs instead of text tokens
-
-        If parameter `k` is <= 0, the probabilities for all tokens in the
-        vocabulary will be printed. Vocabulary sizes are often in the
-        hundred-thousands.
-        """
-        for _tuple in self.candidates(
-            prompt=prompt, k=k, temp=temp, raw_token_ids=raw_token_ids
-        ):
-            percent_as_string = f"{_tuple[1] * 100 :>7.3f}"
-            # do not print tokens with ~0.000% probability
-            if percent_as_string != " 0.000":
-                print(
-                    f"token {_tuple[0]!r:<32} has probability "
-                    f"{percent_as_string} %",
-                    file=sys.stdout if file is None else file,
-                )
-
-
-def assert_model_is_loaded(model) -> None:
-    """
-    Ensure the model is fully constructed, such that
-    `model.llama._model.model is not None` is guaranteed to be `True`
-
-    Raise ModelUnloadedException otherwise
-    """
-    try:
-        if model.llama._model.model is not None:
-            return
-    except AttributeError:
-        pass
-
-    if model is None:
-        exc = exceptions.ModelUnloadedException(
-            "model is None"
-        )
-    elif not hasattr(model, 'llama'):
-        exc = exceptions.ModelUnloadedException(
-            "webscout.Local.Model instance has no attribute 'llama'"
-        )
-    elif not hasattr(model.llama, '_model'):
-        exc = exceptions.ModelUnloadedException(
-            "llama_cpp.Llama instance has no attribute '_model'"
-        )
-    elif not hasattr(model.llama._model, 'model'):
-        exc = exceptions.ModelUnloadedException(
-            "llama_cpp._internals._LlamaModel instance has no attribute "
-            "'model'"
-        )
-    elif model.llama._model.model is None:
-        exc = exceptions.ModelUnloadedException(
-            "llama_cpp._internals._LlamaModel.model is None"
-        )
-    else:
-        raise UnreachableException
-
-    if not isinstance(model, Model):
-        exc.add_note(
-            'WARNING: `assert_model_is_loaded` was called on an object '
-            'that is NOT an instance of `webscout.Local.Model` '
-            f'(object had type {type(model)!r})'
-        )
-    else:
-        exc.add_note(
-            'Are you trying to use a model that has been unloaded?'
-        )
-
-    raise exc