webscout 2.3b0__py3-none-any.whl → 2.5__py3-none-any.whl
This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries; it is provided for informational purposes only.
Potentially problematic release: this version of webscout might be problematic.
- webscout/Local/__init__.py +10 -0
- webscout/Local/_version.py +3 -0
- webscout/Local/formats.py +482 -0
- webscout/Local/model.py +702 -0
- webscout/Local/samplers.py +161 -0
- webscout/Local/thread.py +680 -0
- webscout/Local/utils.py +185 -0
- webscout/__init__.py +4 -5
- {webscout-2.3b0.dist-info → webscout-2.5.dist-info}/METADATA +6 -5
- {webscout-2.3b0.dist-info → webscout-2.5.dist-info}/RECORD +14 -7
- {webscout-2.3b0.dist-info → webscout-2.5.dist-info}/LICENSE.md +0 -0
- {webscout-2.3b0.dist-info → webscout-2.5.dist-info}/WHEEL +0 -0
- {webscout-2.3b0.dist-info → webscout-2.5.dist-info}/entry_points.txt +0 -0
- {webscout-2.3b0.dist-info → webscout-2.5.dist-info}/top_level.txt +0 -0
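The headline change is the new `webscout/Local` subpackage, in particular `model.py` (shown in full below), which wraps `llama_cpp.Llama` in a high-level `Model` class. The following is a minimal usage sketch inferred from the constructor and method docstrings in that file; the GGUF path is a placeholder, and importing from the submodule directly is an assumption, since the contents of `webscout/Local/__init__.py` are not shown in this diff.

```python
# Usage sketch based on the Model API in webscout/Local/model.py (below).
# The .gguf path is a placeholder; importing from the submodule directly is
# an assumption, as __init__.py's re-exports are not visible in this diff.
from webscout.Local.model import Model

# Model implements __enter__/__exit__, which call unload(), so the weights
# are freed when the block exits.
with Model('/path/to/model.gguf', context_length=4096, n_gpu_layers=0) as m:
    # one-shot generation; calling the instance is shorthand for m.generate(...)
    answer = m.generate('Q: What does GGUF stand for?\nA:', stops=['\n'])
    print(answer)

    # stream tokens to stdout as they are produced and keep the full string
    full_text = m.stream_print('Write a haiku about local inference.')
```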
webscout/Local/model.py
ADDED
@@ -0,0 +1,702 @@
from ._version import __version__, __llama_cpp_version__

"""Submodule containing the Model class to work with language models"""

import sys
import numpy as np

from .utils import (
    _SupportsWriteAndFlush,
    print_warning,
    print_verbose,
    GGUFReader,
    softmax
)

from .samplers import SamplerSettings, DefaultSampling
from llama_cpp import Llama, StoppingCriteriaList
from typing import Generator, Optional, Union
from os.path import isdir, exists
from heapq import nlargest

from os import cpu_count as os_cpu_count


class ModelUnloadedException(Exception):
    """Exception raised when trying to use a Model that has been unloaded"""
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)
        self.add_note('Are you trying to use a Model that has been unloaded?')

class Model:
    """
    A high-level abstraction of a llama model

    This is just a brief overview of webscout.Local.Model.
    To see a full description of each method and its parameters,
    call help(Model), or see the relevant docstring.

    The following methods are available:
    - `.generate()` - Generate text
    - `.get_length()` - Get the length of a given text in tokens
    - `.ingest()` - Ingest text into the model's cache
    - `.next_candidates()` - Get a list of the most likely next tokens (WIP)
    - `.stream()` - Return a Generator that can stream text as it is generated
    - `.stream_print()` - Print text as it is generated
    - `.trim()` - Trim a given text to the model's context length
    - `.unload()` - Unload the model from memory

    The following attributes are available:
    - `.bos_token` - The model's beginning-of-stream token ID
    - `.context_length` - The model's loaded context length
    - `.flash_attn` - Whether the model was loaded with `flash_attn=True`
    - `.eos_token` - The model's end-of-stream token ID
    - `.llama` - The underlying `llama_cpp.Llama` instance
    - `.metadata` - The GGUF metadata of the model
    - `.n_ctx_train` - The native context length of the model
    - `.rope_freq_base` - The model's loaded RoPE frequency base
    - `.rope_freq_base_train` - The model's native RoPE frequency base
    - `.tokens` - A list of all the tokens in the model's tokenizer
    - `.verbose` - Whether the model was loaded with `verbose=True`
    """
logits_all=False,
|
|
235
|
+
n_batch=n_batch,
|
|
236
|
+
n_threads=n_threads,
|
|
237
|
+
n_threads_batch=n_threads_batch,
|
|
238
|
+
rope_freq_base=rope_freq_base,
|
|
239
|
+
mul_mat_q=True,
|
|
240
|
+
offload_kqv=offload_kqv,
|
|
241
|
+
flash_attn=flash_attn,
|
|
242
|
+
# KV cache quantization
|
|
243
|
+
# use 1 for F16 (default), 8 for q8_0, 2 for q4_0, 3 for q4_1
|
|
244
|
+
#type_k=8,
|
|
245
|
+
#type_v=8,
|
|
246
|
+
verbose=verbose
|
|
247
|
+
)
|
|
248
|
+
|
|
249
|
+
# once model is loaded, replace metadata (as read using internal class)
|
|
250
|
+
# with metadata (as read using the more robust llama-cpp-python code)
|
|
251
|
+
self.metadata = self.llama.metadata
|
|
252
|
+
|
|
253
|
+
# expose these values because they may be useful / informative
|
|
254
|
+
self.n_ctx_train = n_ctx_train
|
|
255
|
+
self.rope_freq_base_train = rope_freq_base_train
|
|
256
|
+
self.rope_freq_base = rope_freq_base
|
|
257
|
+
self.flash_attn = flash_attn
|
|
258
|
+
|
|
259
|
+
if self.verbose:
|
|
260
|
+
print_verbose("new Model instance with the following attributes:")
|
|
261
|
+
print_verbose(f"model: {model_path}")
|
|
262
|
+
print_verbose(f"param: n_gpu_layers == {n_gpu_layers}")
|
|
263
|
+
print_verbose(f"param: offload_kqv == {offload_kqv}")
|
|
264
|
+
print_verbose(f"param: flash_attn == {flash_attn}")
|
|
265
|
+
print_verbose(f"param: n_batch == {n_batch}")
|
|
266
|
+
print_verbose(f"param: n_threads == {n_threads}")
|
|
267
|
+
print_verbose(f"param: n_threads_batch == {n_threads_batch}")
|
|
268
|
+
print_verbose(f" gguf: n_ctx_train == {n_ctx_train}")
|
|
269
|
+
print_verbose(f"param: self.context_length == {self.context_length}")
|
|
270
|
+
print_verbose(f" gguf: rope_freq_base_train == {rope_freq_base_train}")
|
|
271
|
+
print_verbose(f"param: rope_freq_base == {rope_freq_base}")
|
|
272
|
+
    def __repr__(self) -> str:
        return \
            f"Model({repr(self._model_path)}, " + \
            f"context_length={self._context_length}, " + \
            f"n_gpu_layers={self._n_gpu_layers}, " + \
            f"offload_kqv={self._offload_kqv}, "+ \
            f"flash_attn={self._flash_attn}, " + \
            f"verbose={self._verbose})"

    def __del__(self):
        self.unload()

    def __enter__(self):
        return self

    def __exit__(self, *_):
        self.unload()

    def __call__(
        self,
        prompt: Union[str, list[int]],
        stops: list[Union[str, int]] = [],
        sampler: SamplerSettings = DefaultSampling
    ) -> str:
        """
        `Model(...)` is a shorthand for `Model.generate(...)`
        """
        return self.generate(prompt, stops, sampler)

    def unload(self):
        """
        Unload the model from memory
        """
        # ref: llama_cpp._internals._LlamaModel.__del__()
        if not hasattr(self, 'llama'):
            # nothing can be done
            return
        try:
            if self.llama._model.model is not None:
                # actually unload the model from memory
                self.llama._model._llama_free_model(self.llama._model.model)
                self.llama._model.model = None
        except AttributeError:
            # broken or already being destroyed by GC, abort
            return
        if hasattr(self, 'llama'):
            delattr(self, 'llama')
        if self.verbose:
            print_verbose('Model unloaded')

    def trim(
        self,
        text: str,
        overwrite: Optional[str] = None
    ) -> str:

        """
        Trim the given text to the context length of this model,
        leaving room for two extra tokens.

        Optionally overwrite the oldest tokens with the text given in the
        `overwrite` parameter, which may be useful for keeping some
        information in context.

        Does nothing if the text is equal to or shorter than
        (context_length - 2).
        """
        assert_model_is_loaded(self)
        trim_length = self.context_length - 2
        tokens_list = self.llama.tokenize(
            text.encode("utf-8", errors="ignore")
        )

        if len(tokens_list) <= trim_length:
            if overwrite is not None:
                text[0 : len(overwrite)] = overwrite
            return text

        if len(tokens_list) > trim_length and overwrite is None:
            # cut to trim_length
            tokens_list = tokens_list[-trim_length:]
            return self.llama.detokenize(tokens_list).decode(
                "utf-8",
                errors="ignore"
            )

        if len(tokens_list) > trim_length and overwrite is not None:
            # cut to trim_length
            tokens_list = tokens_list[-trim_length:]
            overwrite_tokens = self.llama.tokenize(overwrite.encode(
                "utf-8",
                errors="ignore"
                )
            )
            # overwrite oldest tokens
            tokens_list[0 : len(overwrite_tokens)] = overwrite_tokens
            return self.llama.detokenize(tokens_list).decode(
                "utf-8",
                errors="ignore"
            )

    def get_length(self, text: str) -> int:
        """
        Return the length of the given text in tokens according to this model,
        including the appended BOS token.
        """
        assert_model_is_loaded(self)
        return len(self.llama.tokenize(
            text.encode(
                "utf-8",
                errors="ignore"
            )
        ))

    def generate(
        self,
        prompt: Union[str, list[int]],
        stops: list[Union[str, int]] = [],
        sampler: SamplerSettings = DefaultSampling
    ) -> str:
        """
        Given a prompt, return a generated string.

        prompt: The text from which to generate

        The following parameters are optional:
        - stops: A list of strings and/or token IDs at which to end the generation early
        - sampler: The SamplerSettings object used to control text generation
        """

        assert isinstance(prompt, (str, list)), \
            f"generate: prompt should be string or list[int], not {type(prompt)}"
        if isinstance(prompt, list):
            assert all(isinstance(tok, int) for tok in prompt), \
                "generate: some token in prompt is not an integer"
        assert isinstance(stops, list), \
            f"generate: parameter `stops` should be a list, not {type(stops)}"
        assert all(isinstance(item, (str, int)) for item in stops), \
            f"generate: some item in parameter `stops` is not a string or int"

        if self.verbose:
            print_verbose(f'using the following sampler settings for Model.generate:')
            print_verbose(f'max_len_tokens == {sampler.max_len_tokens}')
            print_verbose(f'temp == {sampler.temp}')
            print_verbose(f'top_p == {sampler.top_p}')
            print_verbose(f'min_p == {sampler.min_p}')
            print_verbose(f'frequency_penalty == {sampler.frequency_penalty}')
            print_verbose(f'presence_penalty == {sampler.presence_penalty}')
            print_verbose(f'repeat_penalty == {sampler.repeat_penalty}')
            print_verbose(f'top_k == {sampler.top_k}')

        # if any stop item is a token ID (int)
        if any(isinstance(stop, int) for stop in stops):
            # stop_strs is a list of all stopping strings
            stop_strs: list[str] = [stop for stop in stops if isinstance(stop, str)]
            # stop_token_ids is a list of all stop token IDs
            stop_token_ids: list[int] = [tok_id for tok_id in stops if isinstance(tok_id, int)]
            def stop_on_token_ids(tokens, *args, **kwargs):
                return tokens[-1] in stop_token_ids
            stopping_criteria = StoppingCriteriaList([stop_on_token_ids])
            assert_model_is_loaded(self)
            return self.llama.create_completion(
                prompt,
                max_tokens=sampler.max_len_tokens,
                temperature=sampler.temp,
                top_p=sampler.top_p,
                min_p=sampler.min_p,
                frequency_penalty=sampler.frequency_penalty,
                presence_penalty=sampler.presence_penalty,
                repeat_penalty=sampler.repeat_penalty,
                top_k=sampler.top_k,
                stop=stop_strs,
                stopping_criteria=stopping_criteria
            )['choices'][0]['text']

        # if stop items are only strings
        assert_model_is_loaded(self)
        return self.llama.create_completion(
            prompt,
            max_tokens=sampler.max_len_tokens,
            temperature=sampler.temp,
            top_p=sampler.top_p,
            min_p=sampler.min_p,
            frequency_penalty=sampler.frequency_penalty,
            presence_penalty=sampler.presence_penalty,
            repeat_penalty=sampler.repeat_penalty,
            top_k=sampler.top_k,
            stop=stops
        )['choices'][0]['text']
    def stream(
        self,
        prompt: Union[str, list[int]],
        stops: list[Union[str, int]] = [],
        sampler: SamplerSettings = DefaultSampling
    ) -> Generator:

        """
        Given a prompt, return a Generator that yields dicts containing tokens.

        To get the token string itself, subscript the dict with:

        `['choices'][0]['text']`

        prompt: The text from which to generate

        The following parameters are optional:
        - stops: A list of strings and/or token IDs at which to end the generation early
        - sampler: The SamplerSettings object used to control text generation
        """

        assert isinstance(prompt, (str, list)), \
            f"stream: prompt should be string or list[int], not {type(prompt)}"
        if isinstance(prompt, list):
            assert all(isinstance(tok, int) for tok in prompt), \
                "stream: some token in prompt is not an integer"
        assert isinstance(stops, list), \
            f"stream: parameter `stops` should be a list, not {type(stops)}"
        assert all(isinstance(item, (str, int)) for item in stops), \
            f"stream: some item in parameter `stops` is not a string or int"

        if self.verbose:
            print_verbose(f'using the following sampler settings for Model.stream:')
            print_verbose(f'max_len_tokens == {sampler.max_len_tokens}')
            print_verbose(f'temp == {sampler.temp}')
            print_verbose(f'top_p == {sampler.top_p}')
            print_verbose(f'min_p == {sampler.min_p}')
            print_verbose(f'frequency_penalty == {sampler.frequency_penalty}')
            print_verbose(f'presence_penalty == {sampler.presence_penalty}')
            print_verbose(f'repeat_penalty == {sampler.repeat_penalty}')
            print_verbose(f'top_k == {sampler.top_k}')

        # if any stop item is a token ID (int)
        if any(isinstance(stop, int) for stop in stops):
            # stop_strs is a list of all stopping strings
            stop_strs: list[str] = [stop for stop in stops if isinstance(stop, str)]
            # stop_token_ids is a list of all stop token IDs
            stop_token_ids: list[int] = [tok_id for tok_id in stops if isinstance(tok_id, int)]
            def stop_on_token_ids(tokens, *args, **kwargs):
                return tokens[-1] in stop_token_ids
            stopping_criteria = StoppingCriteriaList([stop_on_token_ids])
            assert_model_is_loaded(self)
            return self.llama.create_completion(
                prompt,
                max_tokens=sampler.max_len_tokens,
                temperature=sampler.temp,
                top_p=sampler.top_p,
                min_p=sampler.min_p,
                frequency_penalty=sampler.frequency_penalty,
                presence_penalty=sampler.presence_penalty,
                repeat_penalty=sampler.repeat_penalty,
                top_k=sampler.top_k,
                stream=True,
                stop=stop_strs,
                stopping_criteria=stopping_criteria
            )

        assert_model_is_loaded(self)
        return self.llama.create_completion(
            prompt,
            max_tokens=sampler.max_len_tokens,
            temperature=sampler.temp,
            top_p=sampler.top_p,
            min_p=sampler.min_p,
            frequency_penalty=sampler.frequency_penalty,
            presence_penalty=sampler.presence_penalty,
            repeat_penalty=sampler.repeat_penalty,
            top_k=sampler.top_k,
            stream=True,
            stop=stops
        )


    def stream_print(
        self,
        prompt: Union[str, list[int]],
        stops: list[Union[str, int]] = [],
        sampler: SamplerSettings = DefaultSampling,
        end: str = "\n",
        file: _SupportsWriteAndFlush = sys.stdout,
        flush: bool = True
    ) -> str:
        """
        Given a prompt, stream text as it is generated, and return the generated string.
        The returned string does not include the `end` parameter.

        `Model.stream_print(...)` is a shorthand for:

        ```
        s = Model.stream(prompt, stops=stops, sampler=sampler)
        for i in s:
            tok = i['choices'][0]['text']
            print(tok, end='', file=file, flush=flush)
        print(end, end='', file=file, flush=True)
        ```

        prompt: The text from which to generate

        The following parameters are optional:
        - stops: A list of strings and/or token IDs at which to end the generation early
        - sampler: The SamplerSettings object used to control text generation
        - end: A string to print after the generated text
        - file: The file where text should be printed
        - flush: Whether to flush the stream after each token
        """

        token_generator = self.stream(
            prompt=prompt,
            stops=stops,
            sampler=sampler
        )

        res = ''
        for i in token_generator:
            tok = i['choices'][0]['text']
            print(tok, end='', file=file, flush=flush)
            res += tok

        # print `end`, and always flush stream after generation is done
        print(end, end='', file=file, flush=True)

        return res


    def ingest(self, text: str) -> None:
        """
        Ingest the given text into the model's cache
        """

        assert_model_is_loaded(self)
        self.llama.create_completion(
            text,
            max_tokens=1,
            temperature=0.0
        )


    def candidates(
        self,
        prompt: str,
        k: int
    ) -> list[tuple[str, np.float64]]:
        """
        Given prompt `str` and k `int`, return a sorted list of the
        top k candidates for most likely next token, along with their
        normalized probabilities
        """

        assert isinstance(prompt, str), \
            f"next_candidates: prompt should be str, not {type(prompt)}"
        assert isinstance(k, int), \
            f"next_candidates: k should be int, not {type(k)}"
        assert 0 < k <= len(self.tokens), \
            f"next_candidates: k should be between 0 and {len(self.tokens)}"

        assert_model_is_loaded(self)
        prompt_tokens = self.llama.tokenize(prompt.encode('utf-8', errors='ignore'))
        self.llama.reset() # reset model state
        self.llama.eval(prompt_tokens)
        scores = self.llama.scores[len(prompt_tokens) - 1]

        # len(self.llama.scores) == self.context_length
        # len(self.llama.scores[i]) == len(self.tokens)

        # normalize scores with softmax
        # must normalize over all tokens in vocab, not just top k
        if self.verbose:
            print_verbose(f'calculating softmax over {len(scores)} values')
        normalized_scores: list[np.float64] = list(softmax(scores))

        # construct the final list
        i = 0
        token_probs_list: list[tuple[str, np.float64]] = []
        for tok_str in self.tokens:
            token_probs_list.append((tok_str, normalized_scores[i]))
            i += 1

        # return token_probs_list, sorted by probability, only top k
        return nlargest(k, token_probs_list, key=lambda x:x[1])


    def print_candidates(
        self,
        prompt: str,
        k: int,
        file: _SupportsWriteAndFlush = sys.stdout,
        flush: bool = False
    ) -> None:
        """
        Like `Model.candidates()`, but print the values instead
        of returning them
        """

        for _tuple in self.candidates(prompt, k):
            print(
                f"token '{_tuple[0]}' has probability {_tuple[1]}",
                file=file,
                flush=flush
            )

        # if flush is False, then so far file is not flushed, but it should
        # always be flushed at the end of printing
        if not flush:
            file.flush()


def assert_model_is_loaded(model: Model) -> None:
    """
    Ensure the Model is fully constructed, such that
    `Model.llama._model.model is not None` is guaranteed to be `True`

    Raise ModelUnloadedException otherwise
    """
    if not hasattr(model, 'llama'):
        raise ModelUnloadedException(
            "webscout.Local.Model instance has no attribute 'llama'"
        )
    if not hasattr(model.llama, '_model'):
        raise ModelUnloadedException(
            "llama_cpp.Llama instance has no attribute '_model'"
        )
    if not hasattr(model.llama._model, 'model'):
        raise ModelUnloadedException(
            "llama_cpp._internals._LlamaModel instance has no attribute 'model'"
        )
    if model.llama._model.model is None:
        raise ModelUnloadedException(
            "llama_cpp._internals._LlamaModel.model is None"
        )