webscout-2.2-py3-none-any.whl → webscout-2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of webscout has been flagged as possibly problematic.

@@ -0,0 +1,702 @@
+ from ._version import __version__, __llama_cpp_version__
+
+ """Submodule containing the Model class to work with language models"""
+
+ import sys
+ import numpy as np
+
+ from .utils import (
+     _SupportsWriteAndFlush,
+     print_warning,
+     print_verbose,
+     GGUFReader,
+     softmax
+ )
+
+ from .samplers import SamplerSettings, DefaultSampling
+ from llama_cpp import Llama, StoppingCriteriaList
+ from typing import Generator, Optional, Union
+ from os.path import isdir, exists
+ from heapq import nlargest
+
+ from os import cpu_count as os_cpu_count
+
+
+ class ModelUnloadedException(Exception):
+     """Exception raised when trying to use a Model that has been unloaded"""
+     def __init__(self, message):
+         self.message = message
+         super().__init__(self.message)
+         self.add_note('Are you trying to use a Model that has been unloaded?')
+
+ class Model:
+     """
+     A high-level abstraction of a llama model
+
+     This is just a brief overview of webscout.Local.Model.
+     To see a full description of each method and its parameters,
+     call help(Model), or see the relevant docstring.
+
+     The following methods are available:
+     - `.generate()` - Generate text
+     - `.get_length()` - Get the length of a given text in tokens
+     - `.ingest()` - Ingest text into the model's cache
+     - `.next_candidates()` - Get a list of the most likely next tokens (WIP)
+     - `.stream()` - Return a Generator that can stream text as it is generated
+     - `.stream_print()` - Print text as it is generated
+     - `.trim()` - Trim a given text to the model's context length
+     - `.unload()` - Unload the model from memory
+
+     The following attributes are available:
+     - `.bos_token` - The model's beginning-of-stream token ID
+     - `.context_length` - The model's loaded context length
+     - `.flash_attn` - Whether the model was loaded with `flash_attn=True`
+     - `.eos_token` - The model's end-of-stream token ID
+     - `.llama` - The underlying `llama_cpp.Llama` instance
+     - `.metadata` - The GGUF metadata of the model
+     - `.n_ctx_train` - The native context length of the model
+     - `.rope_freq_base` - The model's loaded RoPE frequency base
+     - `.rope_freq_base_train` - The model's native RoPE frequency base
+     - `.tokens` - A list of all the tokens in the model's tokenizer
+     - `.verbose` - Whether the model was loaded with `verbose=True`
+     """
+
+     def __init__(
+         self,
+         model_path: str,
+         context_length: Optional[int] = None,
+         n_gpu_layers: int = 0,
+         offload_kqv: bool = True,
+         flash_attn: bool = False,
+         verbose: bool = False
+     ):
+         """
+         Given the path to a GGUF file, construct a Model instance.
+
+         The model must be in GGUF format.
+
+         The following parameters are optional:
+         - context_length: The context length at which to load the model, in tokens
+         - n_gpu_layers: The number of layers to be offloaded to the GPU
+         - offload_kqv: Whether the KQV cache (context) should be offloaded
+         - flash_attn: Whether to use Flash Attention
+         - verbose: Whether to print additional backend information
+         """
+
+         if verbose:
+             print_verbose(f"webscout.Local package version: {__version__}")
+             print_verbose(f"llama_cpp package version: {__llama_cpp_version__}")
+
+         assert isinstance(model_path, str), \
+             f"Model: model_path should be a string, not {type(model_path)}"
+         assert exists(model_path), \
+             f"Model: the given model_path '{model_path}' does not exist"
+         assert not isdir(model_path), \
+             f"Model: the given model_path '{model_path}' is a directory, not a GGUF file"
+         assert isinstance(context_length, (int, type(None))), \
+             f"Model: context_length should be int or None, not {type(context_length)}"
+         assert isinstance(flash_attn, bool), \
+             f"Model: flash_attn should be bool (True or False), not {type(flash_attn)}"
+
+         # save __init__ parameters for __repr__
+         self._model_path = model_path
+         self._context_length = context_length
+         self._n_gpu_layers = n_gpu_layers
+         self._offload_kqv = offload_kqv
+         self._flash_attn = flash_attn
+         self._verbose = self.verbose = verbose
+
+         # if context_length <= 0, use n_ctx_train
+         if isinstance(context_length, int) and context_length <= 0:
+             context_length = None
+
+         # this does not use Llama.metadata because we want to use GGUF
+         # metadata to determine some parameters of the Llama instance
+         # before it is created
+         self.metadata = GGUFReader.load_metadata(self, model_path)
+         metadata_keys = self.metadata.keys() # only read once
+
+         n_ctx_train = None
+         for key in metadata_keys:
+             if key.endswith('.context_length'):
+                 n_ctx_train = self.metadata[key]
+                 break
+
+         if n_ctx_train is None:
+             raise KeyError(
+                 "GGUF file does not specify a context length"
+             )
+
+         rope_freq_base_train = None
+         for key in metadata_keys:
+             if key.endswith('.rope.freq_base'):
+                 rope_freq_base_train = self.metadata[key]
+                 break
+
+         if rope_freq_base_train is None and context_length is not None:
+             if context_length > n_ctx_train:
+                 raise ValueError(
+                     'unable to load model with greater than native ' + \
+                     f'context length ({context_length} > {n_ctx_train}) ' + \
+                     'because model does not specify freq_base. ' + \
+                     f'try again with `context_length={n_ctx_train}`'
+                 )
+
+         if rope_freq_base_train is None or context_length is None or \
+             context_length <= n_ctx_train:
+             # no need to do context scaling, load model normally
+
+             if context_length is None:
+                 self.context_length = n_ctx_train
+             else:
+                 self.context_length = context_length
+             rope_freq_base = rope_freq_base_train
+
+         elif context_length > n_ctx_train:
+             # multiply rope_freq_base according to requested context length
+             # because context length > n_ctx_train and rope freq base is known
+
+             rope_freq_base = (context_length/n_ctx_train)*rope_freq_base_train
+             self.context_length = context_length
+
+             if self.verbose:
+                 print_verbose(
+                     'chosen context length is greater than native context '
+                     f'length ({context_length} > {n_ctx_train}), '
+                     'rope_freq_base will be changed from '
+                     f'{rope_freq_base_train} to {rope_freq_base}'
+                 )
+
+             if 2 <= context_length/n_ctx_train < 4:
+                 print_warning(
+                     'loading model with 2x native context length or more, '
+                     'expect small loss of quality'
+                 )
+
+             elif 4 <= context_length/n_ctx_train < 8:
+                 print_warning(
+                     'loading model with 4x native context length or more, '
+                     'expect moderate loss of quality'
+                 )
+
+             elif context_length/n_ctx_train >= 8:
+                 print_warning(
+                     'loading model with 8x native context length or more, '
+                     'expect SIGNIFICANT loss of quality'
+                 )
+
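# A worked example of the RoPE scaling above, using illustrative numbers rather
# than values read from any particular GGUF file:
n_ctx_train = 4096              # native context length from GGUF metadata
rope_freq_base_train = 10000.0  # native RoPE frequency base from GGUF metadata
context_length = 8192           # requested context length (2x native)
rope_freq_base = (context_length / n_ctx_train) * rope_freq_base_train
assert rope_freq_base == 20000.0  # and the "2x native context length" warning fires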
+         try:
+             self.tokens: list[str] = self.metadata['tokenizer.ggml.tokens']
+         except KeyError:
+             print_warning(
+                 "could not set Model.tokens, defaulting to None"
+             )
+             self.tokens = None
+         try:
+             self.bos_token: int = self.metadata['tokenizer.ggml.bos_token_id']
+         except KeyError:
+             print_warning(
+                 "could not set Model.bos_token, defaulting to None"
+             )
+             self.bos_token = None
+         try:
+             self.eos_token: int = self.metadata['tokenizer.ggml.eos_token_id']
+         except KeyError:
+             print_warning(
+                 "could not set Model.eos_token, defaulting to None"
+             )
+             self.eos_token = None
+
+         cpu_count = os_cpu_count()
+
+         # these values for n_threads and n_threads_batch are
+         # known to be optimal for most systems
+         n_batch = 512 # can this be optimized?
+         n_threads = max(cpu_count//2, 1)
+         n_threads_batch = cpu_count
+
+         if flash_attn and n_gpu_layers == 0:
+             print_warning(
+                 "disabling flash_attn because n_gpu_layers == 0"
+             )
+             flash_attn = False
+
+         # guard against models with no rope_freq_base
+         if rope_freq_base is None:
+             rope_freq_base = 0
+
+         self.llama: Llama = Llama(
+             model_path=model_path,
+             n_ctx=self.context_length,
+             n_gpu_layers=n_gpu_layers,
+             use_mmap=True,
+             use_mlock=False,
+             logits_all=False,
+             n_batch=n_batch,
+             n_threads=n_threads,
+             n_threads_batch=n_threads_batch,
+             rope_freq_base=rope_freq_base,
+             mul_mat_q=True,
+             offload_kqv=offload_kqv,
+             flash_attn=flash_attn,
+             # KV cache quantization
+             # use 1 for F16 (default), 8 for q8_0, 2 for q4_0, 3 for q4_1
+             #type_k=8,
+             #type_v=8,
+             verbose=verbose
+         )
+
+         # once model is loaded, replace metadata (as read using internal class)
+         # with metadata (as read using the more robust llama-cpp-python code)
+         self.metadata = self.llama.metadata
+
+         # expose these values because they may be useful / informative
+         self.n_ctx_train = n_ctx_train
+         self.rope_freq_base_train = rope_freq_base_train
+         self.rope_freq_base = rope_freq_base
+         self.flash_attn = flash_attn
+
+         if self.verbose:
+             print_verbose("new Model instance with the following attributes:")
+             print_verbose(f"model: {model_path}")
+             print_verbose(f"param: n_gpu_layers == {n_gpu_layers}")
+             print_verbose(f"param: offload_kqv == {offload_kqv}")
+             print_verbose(f"param: flash_attn == {flash_attn}")
+             print_verbose(f"param: n_batch == {n_batch}")
+             print_verbose(f"param: n_threads == {n_threads}")
+             print_verbose(f"param: n_threads_batch == {n_threads_batch}")
+             print_verbose(f" gguf: n_ctx_train == {n_ctx_train}")
+             print_verbose(f"param: self.context_length == {self.context_length}")
+             print_verbose(f" gguf: rope_freq_base_train == {rope_freq_base_train}")
+             print_verbose(f"param: rope_freq_base == {rope_freq_base}")
+
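# A minimal construction sketch based on the parameters documented above; the
# import path and the GGUF filename are assumptions for illustration only:
from webscout.Local import Model

model = Model(
    'my-model.Q4_K_M.gguf',  # hypothetical path to a GGUF file
    context_length=None,     # None -> load at the model's native n_ctx_train
    n_gpu_layers=0,          # CPU only; raise this to offload layers to the GPU
    flash_attn=False,
    verbose=True             # prints the attribute summary shown above
)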
+     def __repr__(self) -> str:
+         return \
+             f"Model({repr(self._model_path)}, " + \
+             f"context_length={self._context_length}, " + \
+             f"n_gpu_layers={self._n_gpu_layers}, " + \
+             f"offload_kqv={self._offload_kqv}, "+ \
+             f"flash_attn={self._flash_attn}, " + \
+             f"verbose={self._verbose})"
+
+     def __del__(self):
+         self.unload()
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, *_):
+         self.unload()
+
+     def __call__(
+         self,
+         prompt: Union[str, list[int]],
+         stops: list[Union[str, int]] = [],
+         sampler: SamplerSettings = DefaultSampling
+     ) -> str:
+         """
+         `Model(...)` is a shorthand for `Model.generate(...)`
+         """
+         return self.generate(prompt, stops, sampler)
+
+     def unload(self):
+         """
+         Unload the model from memory
+         """
+         # ref: llama_cpp._internals._LlamaModel.__del__()
+         if not hasattr(self, 'llama'):
+             # nothing can be done
+             return
+         try:
+             if self.llama._model.model is not None:
+                 # actually unload the model from memory
+                 self.llama._model._llama_free_model(self.llama._model.model)
+                 self.llama._model.model = None
+         except AttributeError:
+             # broken or already being destroyed by GC, abort
+             return
+         if hasattr(self, 'llama'):
+             delattr(self, 'llama')
+         if self.verbose:
+             print_verbose('Model unloaded')
+
+     def trim(
+         self,
+         text: str,
+         overwrite: Optional[str] = None
+     ) -> str:
+
+         """
+         Trim the given text to the context length of this model,
+         leaving room for two extra tokens.
+
+         Optionally overwrite the oldest tokens with the text given in the
+         `overwrite` parameter, which may be useful for keeping some
+         information in context.
+
+         Does nothing if the text is equal to or shorter than
+         (context_length - 2).
+         """
+         assert_model_is_loaded(self)
+         trim_length = self.context_length - 2
+         tokens_list = self.llama.tokenize(
+             text.encode("utf-8", errors="ignore")
+         )
+
+         if len(tokens_list) <= trim_length:
+             if overwrite is not None:
+                 # str does not support slice assignment; rebuild the string
+                 # with the overwrite text in place of its oldest characters
+                 text = overwrite + text[len(overwrite):]
+             return text
+
+         if len(tokens_list) > trim_length and overwrite is None:
+             # cut to trim_length
+             tokens_list = tokens_list[-trim_length:]
+             return self.llama.detokenize(tokens_list).decode(
+                 "utf-8",
+                 errors="ignore"
+             )
+
+         if len(tokens_list) > trim_length and overwrite is not None:
+             # cut to trim_length
+             tokens_list = tokens_list[-trim_length:]
+             overwrite_tokens = self.llama.tokenize(overwrite.encode(
+                 "utf-8",
+                 errors="ignore"
+                 )
+             )
+             # overwrite oldest tokens
+             tokens_list[0 : len(overwrite_tokens)] = overwrite_tokens
+             return self.llama.detokenize(tokens_list).decode(
+                 "utf-8",
+                 errors="ignore"
+             )
+
+     def get_length(self, text: str) -> int:
+         """
+         Return the length of the given text in tokens according to this model,
+         including the appended BOS token.
+         """
+         assert_model_is_loaded(self)
+         return len(self.llama.tokenize(
+             text.encode(
+                 "utf-8",
+                 errors="ignore"
+             )
+         ))
+
+     def generate(
+         self,
+         prompt: Union[str, list[int]],
+         stops: list[Union[str, int]] = [],
+         sampler: SamplerSettings = DefaultSampling
+     ) -> str:
+         """
+         Given a prompt, return a generated string.
+
+         prompt: The text from which to generate
+
+         The following parameters are optional:
+         - stops: A list of strings and/or token IDs at which to end the generation early
+         - sampler: The SamplerSettings object used to control text generation
+         """
+
+         assert isinstance(prompt, (str, list)), \
+             f"generate: prompt should be string or list[int], not {type(prompt)}"
+         if isinstance(prompt, list):
+             assert all(isinstance(tok, int) for tok in prompt), \
+                 "generate: some token in prompt is not an integer"
+         assert isinstance(stops, list), \
+             f"generate: parameter `stops` should be a list, not {type(stops)}"
+         assert all(isinstance(item, (str, int)) for item in stops), \
+             f"generate: some item in parameter `stops` is not a string or int"
+
+         if self.verbose:
+             print_verbose(f'using the following sampler settings for Model.generate:')
+             print_verbose(f'max_len_tokens == {sampler.max_len_tokens}')
+             print_verbose(f'temp == {sampler.temp}')
+             print_verbose(f'top_p == {sampler.top_p}')
+             print_verbose(f'min_p == {sampler.min_p}')
+             print_verbose(f'frequency_penalty == {sampler.frequency_penalty}')
+             print_verbose(f'presence_penalty == {sampler.presence_penalty}')
+             print_verbose(f'repeat_penalty == {sampler.repeat_penalty}')
+             print_verbose(f'top_k == {sampler.top_k}')
+
+         # if any stop item is a token ID (int)
+         if any(isinstance(stop, int) for stop in stops):
+             # stop_strs is a list of all stopping strings
+             stop_strs: list[str] = [stop for stop in stops if isinstance(stop, str)]
+             # stop_token_ids is a list of all stop token IDs
+             stop_token_ids: list[int] = [tok_id for tok_id in stops if isinstance(tok_id, int)]
+             def stop_on_token_ids(tokens, *args, **kwargs):
+                 return tokens[-1] in stop_token_ids
+             stopping_criteria = StoppingCriteriaList([stop_on_token_ids])
+             assert_model_is_loaded(self)
+             return self.llama.create_completion(
+                 prompt,
+                 max_tokens=sampler.max_len_tokens,
+                 temperature=sampler.temp,
+                 top_p=sampler.top_p,
+                 min_p=sampler.min_p,
+                 frequency_penalty=sampler.frequency_penalty,
+                 presence_penalty=sampler.presence_penalty,
+                 repeat_penalty=sampler.repeat_penalty,
+                 top_k=sampler.top_k,
+                 stop=stop_strs,
+                 stopping_criteria=stopping_criteria
+             )['choices'][0]['text']
+
+         # if stop items are only strings
+         assert_model_is_loaded(self)
+         return self.llama.create_completion(
+             prompt,
+             max_tokens=sampler.max_len_tokens,
+             temperature=sampler.temp,
+             top_p=sampler.top_p,
+             min_p=sampler.min_p,
+             frequency_penalty=sampler.frequency_penalty,
+             presence_penalty=sampler.presence_penalty,
+             repeat_penalty=sampler.repeat_penalty,
+             top_k=sampler.top_k,
+             stop=stops
+         )['choices'][0]['text']
+
+
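# Sketch of calling generate() with mixed stop values, assuming `model` is the
# loaded Model instance from the construction sketch above; string stops are
# passed through `stop=`, while token-ID stops go through StoppingCriteriaList:
output = model.generate(
    "Q: What is a llama?\nA:",
    stops=["\nQ:", model.eos_token],
)
print(output)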
+     def stream(
+         self,
+         prompt: Union[str, list[int]],
+         stops: list[Union[str, int]] = [],
+         sampler: SamplerSettings = DefaultSampling
+     ) -> Generator:
+
+         """
+         Given a prompt, return a Generator that yields dicts containing tokens.
+
+         To get the token string itself, subscript the dict with:
+
+         `['choices'][0]['text']`
+
+         prompt: The text from which to generate
+
+         The following parameters are optional:
+         - stops: A list of strings and/or token IDs at which to end the generation early
+         - sampler: The SamplerSettings object used to control text generation
+         """
+
+         assert isinstance(prompt, (str, list)), \
+             f"stream: prompt should be string or list[int], not {type(prompt)}"
+         if isinstance(prompt, list):
+             assert all(isinstance(tok, int) for tok in prompt), \
+                 "stream: some token in prompt is not an integer"
+         assert isinstance(stops, list), \
+             f"stream: parameter `stops` should be a list, not {type(stops)}"
+         assert all(isinstance(item, (str, int)) for item in stops), \
+             f"stream: some item in parameter `stops` is not a string or int"
+
+         if self.verbose:
+             print_verbose(f'using the following sampler settings for Model.stream:')
+             print_verbose(f'max_len_tokens == {sampler.max_len_tokens}')
+             print_verbose(f'temp == {sampler.temp}')
+             print_verbose(f'top_p == {sampler.top_p}')
+             print_verbose(f'min_p == {sampler.min_p}')
+             print_verbose(f'frequency_penalty == {sampler.frequency_penalty}')
+             print_verbose(f'presence_penalty == {sampler.presence_penalty}')
+             print_verbose(f'repeat_penalty == {sampler.repeat_penalty}')
+             print_verbose(f'top_k == {sampler.top_k}')
+
+         # if any stop item is a token ID (int)
+         if any(isinstance(stop, int) for stop in stops):
+             # stop_strs is a list of all stopping strings
+             stop_strs: list[str] = [stop for stop in stops if isinstance(stop, str)]
+             # stop_token_ids is a list of all stop token IDs
+             stop_token_ids: list[int] = [tok_id for tok_id in stops if isinstance(tok_id, int)]
+             def stop_on_token_ids(tokens, *args, **kwargs):
+                 return tokens[-1] in stop_token_ids
+             stopping_criteria = StoppingCriteriaList([stop_on_token_ids])
+             assert_model_is_loaded(self)
+             return self.llama.create_completion(
+                 prompt,
+                 max_tokens=sampler.max_len_tokens,
+                 temperature=sampler.temp,
+                 top_p=sampler.top_p,
+                 min_p=sampler.min_p,
+                 frequency_penalty=sampler.frequency_penalty,
+                 presence_penalty=sampler.presence_penalty,
+                 repeat_penalty=sampler.repeat_penalty,
+                 top_k=sampler.top_k,
+                 stream=True,
+                 stop=stop_strs,
+                 stopping_criteria=stopping_criteria
+             )
+
+         assert_model_is_loaded(self)
+         return self.llama.create_completion(
+             prompt,
+             max_tokens=sampler.max_len_tokens,
+             temperature=sampler.temp,
+             top_p=sampler.top_p,
+             min_p=sampler.min_p,
+             frequency_penalty=sampler.frequency_penalty,
+             presence_penalty=sampler.presence_penalty,
+             repeat_penalty=sampler.repeat_penalty,
+             top_k=sampler.top_k,
+             stream=True,
+             stop=stops
+         )
+
+
+     def stream_print(
+         self,
+         prompt: Union[str, list[int]],
+         stops: list[Union[str, int]] = [],
+         sampler: SamplerSettings = DefaultSampling,
+         end: str = "\n",
+         file: _SupportsWriteAndFlush = sys.stdout,
+         flush: bool = True
+     ) -> str:
+         """
+         Given a prompt, stream text as it is generated, and return the generated string.
+         The returned string does not include the `end` parameter.
+
+         `Model.stream_print(...)` is a shorthand for:
+
+         ```
+         s = Model.stream(prompt, stops=stops, sampler=sampler)
+         for i in s:
+             tok = i['choices'][0]['text']
+             print(tok, end='', file=file, flush=flush)
+         print(end, end='', file=file, flush=True)
+         ```
+
+         prompt: The text from which to generate
+
+         The following parameters are optional:
+         - stops: A list of strings and/or token IDs at which to end the generation early
+         - sampler: The SamplerSettings object used to control text generation
+         - end: A string to print after the generated text
+         - file: The file where text should be printed
+         - flush: Whether to flush the stream after each token
+         """
+
+         token_generator = self.stream(
+             prompt=prompt,
+             stops=stops,
+             sampler=sampler
+         )
+
+         res = ''
+         for i in token_generator:
+             tok = i['choices'][0]['text']
+             print(tok, end='', file=file, flush=flush)
+             res += tok
+
+         # print `end`, and always flush stream after generation is done
+         print(end, end='', file=file, flush=True)
+
+         return res
+
+
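# Usage sketch for stream() and stream_print(), again assuming `model` is the
# loaded Model from the earlier sketch; tokens are printed to stdout as they
# arrive and the full generated string (without `end`) is also returned:
story = model.stream_print(
    "Once upon a time,",
    stops=["\n\n"],
    end="\n",
)

# the lower-level stream() Generator yields the same dicts that stream_print() prints
for chunk in model.stream("The three laws of robotics are"):
    print(chunk['choices'][0]['text'], end='', flush=True)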
+     def ingest(self, text: str) -> None:
+         """
+         Ingest the given text into the model's cache
+         """
+
+         assert_model_is_loaded(self)
+         self.llama.create_completion(
+             text,
+             max_tokens=1,
+             temperature=0.0
+         )
+
+
+     def candidates(
+         self,
+         prompt: str,
+         k: int
+     ) -> list[tuple[str, np.float64]]:
+         """
+         Given prompt `str` and k `int`, return a sorted list of the
+         top k candidates for most likely next token, along with their
+         normalized probabilities
+         """
+
+         assert isinstance(prompt, str), \
+             f"next_candidates: prompt should be str, not {type(prompt)}"
+         assert isinstance(k, int), \
+             f"next_candidates: k should be int, not {type(k)}"
+         assert 0 < k <= len(self.tokens), \
+             f"next_candidates: k should be between 0 and {len(self.tokens)}"
+
+         assert_model_is_loaded(self)
+         prompt_tokens = self.llama.tokenize(prompt.encode('utf-8', errors='ignore'))
+         self.llama.reset() # reset model state
+         self.llama.eval(prompt_tokens)
+         scores = self.llama.scores[len(prompt_tokens) - 1]
+
+         # len(self.llama.scores) == self.context_length
+         # len(self.llama.scores[i]) == len(self.tokens)
+
+         # normalize scores with softmax
+         # must normalize over all tokens in vocab, not just top k
+         if self.verbose:
+             print_verbose(f'calculating softmax over {len(scores)} values')
+         normalized_scores: list[np.float64] = list(softmax(scores))
+
+         # construct the final list
+         i = 0
+         token_probs_list: list[tuple[str, np.float64]] = []
+         for tok_str in self.tokens:
+             token_probs_list.append((tok_str, normalized_scores[i]))
+             i += 1
+
+         # return token_probs_list, sorted by probability, only top k
+         return nlargest(k, token_probs_list, key=lambda x:x[1])
+
+
+     def print_candidates(
+         self,
+         prompt: str,
+         k: int,
+         file: _SupportsWriteAndFlush = sys.stdout,
+         flush: bool = False
+     ) -> None:
+         """
+         Like `Model.candidates()`, but print the values instead
+         of returning them
+         """
+
+         for _tuple in self.candidates(prompt, k):
+             print(
+                 f"token '{_tuple[0]}' has probability {_tuple[1]}",
+                 file=file,
+                 flush=flush
+             )
+
+         # if flush is False, then so far file is not flushed, but it should
+         # always be flushed at the end of printing
+         if not flush:
+             file.flush()
+
+
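# Sketch of inspecting next-token probabilities with candidates(), assuming the
# same loaded `model`; the scores are softmax-normalized over the whole
# vocabulary, so the probabilities sum to 1.0 across all tokens, not just the top k:
for tok_str, prob in model.candidates("The capital of France is", k=5):
    print(f"{tok_str!r}: {prob:.4f}")

# print_candidates() formats the same information for you
model.print_candidates("The capital of France is", k=5)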
+ def assert_model_is_loaded(model: Model) -> None:
+     """
+     Ensure the Model is fully constructed, such that
+     `Model.llama._model.model is not None` is guaranteed to be `True`
+
+     Raise ModelUnloadedException otherwise
+     """
+     if not hasattr(model, 'llama'):
+         raise ModelUnloadedException(
+             "webscout.Local.Model instance has no attribute 'llama'"
+         )
+     if not hasattr(model.llama, '_model'):
+         raise ModelUnloadedException(
+             "llama_cpp.Llama instance has no attribute '_model'"
+         )
+     if not hasattr(model.llama._model, 'model'):
+         raise ModelUnloadedException(
+             "llama_cpp._internals._LlamaModel instance has no attribute 'model'"
+         )
+     if model.llama._model.model is None:
+         raise ModelUnloadedException(
+             "llama_cpp._internals._LlamaModel.model is None"
+         )
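# Because __enter__ returns the instance and __exit__ calls unload(), a Model can
# also be used as a context manager; a minimal sketch with a hypothetical path:
from webscout.Local import Model  # assumed import path, as in the earlier sketch

with Model('my-model.Q4_K_M.gguf') as m:
    print(m.get_length("hello world"))  # token count, including the BOS token
# on exit, unload() frees the underlying llama_cpp weights; calling most methods
# on `m` afterwards raises ModelUnloadedException via assert_model_is_loaded()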