webscout 2.2b0__py3-none-any.whl → 2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of webscout might be problematic. Click here for more details.

@@ -0,0 +1,641 @@
# thread.py
# https://github.com/ddh0/easy-llama/

"""Submodule containing the Thread class, used for interaction with a Model"""

# NOTE: the docstring must precede all imports, otherwise Python does not
# treat it as the module docstring (__doc__ would remain None). The original
# file imported ._version before the docstring, which silently broke this.

import sys

from ._version import __version__, __llama_cpp_version__

from .model import Model, assert_model_is_loaded, _SupportsWriteAndFlush
from .utils import RESET_ALL, cls, print_verbose, truncate
from .samplers import SamplerSettings, DefaultSampling
from .formats import blank as formats_blank

from typing import Optional, Literal, Union
+
17
class Thread:
    """
    Provide functionality to facilitate easy interactions with a Model

    This is just a brief overview of webscout.Local.Thread.
    To see a full description of each method and its parameters,
    call help(Thread), or see the relevant docstring.

    The following methods are available:
    - `.add_message()` - Add a message to `Thread.messages`
    - `.as_string()` - Return this thread's complete message history as a string
    - `.create_message()` - Create a message using the format of this thread
    - `.inference_str_from_messages()` - Using the list of messages, return a string suitable for inference
    - `.interact()` - Start an interactive, terminal-based chat session
    - `.len_messages()` - Get the total length of all messages in tokens
    - `.print_stats()` - Print stats about the context usage in this thread
    - `.reset()` - Clear the list of messages
    - `.send()` - Send a message in this thread

    The following attributes are available:
    - `.format` - The format being used for messages in this thread
    - `.messages` - The list of messages in this thread
    - `.model` - The `webscout.Local.Model` instance used by this thread
    - `.sampler` - The SamplerSettings object used in this thread
    """
42
+
43
def __init__(
    self,
    model: Model,
    format: dict[str, Union[str, list]],
    sampler: SamplerSettings = DefaultSampling
):
    """
    Given a Model and a format, construct a Thread instance.

    model: The Model to use for text generation
    format: The format specifying how messages should be structured
    (see webscout.Local.formats)

    The following parameter is optional:
    - sampler: The SamplerSettings object used to control text generation

    Raises KeyError if `format` is missing any key required by
    webscout.Local.formats.blank.
    """

    assert isinstance(model, Model), \
        "Thread: model should be an " + \
        f"instance of webscout.Local.Model, not {type(model)}"

    assert_model_is_loaded(model)

    assert isinstance(format, dict), \
        f"Thread: format should be dict, not {type(format)}"

    # every key of the blank reference format must also be present in
    # the user-supplied format
    if any(k not in format.keys() for k in formats_blank.keys()):
        raise KeyError(
            "Thread: format is missing one or more required keys, see " + \
            "webscout.Local.formats.blank for an example"
        )

    assert isinstance(format['stops'], list), \
        "Thread: format['stops'] should be list, not " + \
        f"{type(format['stops'])}"

    # duck-typed check: any object exposing these attributes can serve
    # as a sampler, not only SamplerSettings instances
    assert all(
        hasattr(sampler, attr) for attr in [
            'max_len_tokens',
            'temp',
            'top_p',
            'min_p',
            'frequency_penalty',
            'presence_penalty',
            'repeat_penalty',
            'top_k'
        ]
    ), 'Thread: sampler is missing one or more required attributes'

    self.model: Model = model
    self.format: dict[str, Union[str, list]] = format
    # a thread always starts with a single system message built from
    # the format's system_content
    self.messages: list[dict[str, str]] = [
        self.create_message("system", self.format['system_content'])
    ]
    self.sampler: SamplerSettings = sampler

    if self.model.verbose:
        print_verbose("new Thread instance with the following attributes:")
        print_verbose(f"model == {self.model}")
        # loop over the format keys instead of eleven copy-pasted
        # print_verbose lines; output is identical to the original
        for key in (
            'system_prefix', 'system_content', 'system_postfix',
            'user_prefix', 'user_content', 'user_postfix',
            'bot_prefix', 'bot_content', 'bot_postfix',
            'stops'
        ):
            print_verbose(
                f"format['{key}'] == {truncate(repr(self.format[key]))}"
            )
        # same for the sampler attributes (max_len_tokens was not
        # printed by the original, so it is not printed here either)
        for attr in (
            'temp', 'top_p', 'min_p', 'frequency_penalty',
            'presence_penalty', 'repeat_penalty', 'top_k'
        ):
            print_verbose(f"sampler.{attr} == {getattr(self.sampler, attr)}")
+
119
+
120
def __repr__(self) -> str:
    """Return a string of code which can rebuild this thread's state"""
    result = (
        f"Thread({repr(self.model)}, {repr(self.format)}, "
        f"{repr(self.sampler)})"
    )
    # the system message is derived from the format, so it is not
    # represented; with only the system message present we are done
    if len(self.messages) <= 1:
        return result
    for msg in self.messages:
        role = msg['role']
        if role in ('user', 'bot'):
            result += (
                "\nThread.add_message('" + role + "', "
                + repr(msg['content']) + ')'
            )
    return result
133
+
134
def __str__(self) -> str:
    # str(Thread) is the thread's full formatted message history
    return self.as_string()
136
+
137
def __len__(self) -> int:
    """
    `len(Thread)` returns the length of the Thread in tokens

    To get the number of messages in the Thread, use `len(Thread.messages)`
    """
    # delegates to len_messages(), which tokenizes the whole history
    return self.len_messages()
144
+
145
def create_message(
    self,
    role: Literal['system', 'user', 'bot'],
    content: str
) -> dict[str, str]:
    """
    Build a single message dict using the format of this Thread.

    The returned dict carries the keys 'role', 'prefix', 'content' and
    'postfix', where prefix/postfix come from this thread's format for
    the given role. `role` is case-insensitive.
    """

    normalized = role.lower()

    assert normalized in ['system', 'user', 'bot'], \
        f"create_message: role should be 'system', 'user', or 'bot', not '{normalized}'"

    assert isinstance(content, str), \
        f"create_message: content should be str, not {type(content)}"

    # the format stores prefix/postfix under '<role>_prefix' and
    # '<role>_postfix', so a single lookup replaces the role branches
    return {
        "role": normalized,
        "prefix": self.format[f'{normalized}_prefix'],
        "content": content,
        "postfix": self.format[f'{normalized}_postfix']
    }
183
+
184
def len_messages(self) -> int:
    """
    Return the total length of all messages in this thread, in tokens.

    Equivalent to `len(Thread)`.
    """
    # tokenize the full rendered history and count the tokens
    full_history = self.as_string()
    return self.model.get_length(full_history)
191
+
192
def add_message(
    self,
    role: Literal['system', 'user', 'bot'],
    content: str
) -> None:
    """
    Build a message with `create_message` and append it to
    `Thread.messages`.

    `Thread.add_message(...)` is a shorthand for
    `Thread.messages.append(Thread.create_message(...))`
    """
    new_message = self.create_message(role=role, content=content)
    self.messages.append(new_message)
209
+
210
def inference_str_from_messages(self) -> str:
    """
    Using the list of messages, construct a string suitable for inference,
    respecting the format and context length of this thread.

    The system message (messages[0]) is always included; older user/bot
    messages are dropped first once the context budget is exhausted.
    The string ends with the format's bot_prefix to cue the model.
    """

    messages = self.messages

    # the budget starts at the model's full context length; the system
    # message is charged first because it is always kept
    context_len_budget = self.model.context_length
    if len(messages) > 0:
        sys_msg = messages[0]
        sys_msg_str = (
            sys_msg['prefix'] + sys_msg['content'] + sys_msg['postfix']
        )
        context_len_budget -= self.model.get_length(sys_msg_str)
    else:
        sys_msg_str = ''

    inf_str = ''

    # Start at most recent message and work backwards up the history
    # excluding system message. Once we exceed thread
    # max_context_length, break without including that message
    for message in reversed(messages[1:]):
        # build the formatted message once and measure that same string
        # (the original constructed it twice per iteration: once for
        # get_length and once again for msg_str)
        msg_str = (
            message['prefix'] + message['content'] + message['postfix']
        )
        context_len_budget -= self.model.get_length(msg_str)

        if context_len_budget <= 0:
            break

        inf_str = msg_str + inf_str

    # system message first, then history, then the bot prefix
    inf_str = sys_msg_str + inf_str
    inf_str += self.format['bot_prefix']

    return inf_str
251
+
252
+
253
def send(self, prompt: str) -> str:
    """
    Send a message in this thread. This adds your message and the bot's
    response to the list of messages.

    Returns a string containing the response to your message.
    """

    self.add_message("user", prompt)
    # generate against the full (context-trimmed) history
    response = self.model.generate(
        self.inference_str_from_messages(),
        stops=self.format['stops'],
        sampler=self.sampler
    )
    self.add_message("bot", response)
    return response
270
+
271
+
272
def _interactive_update_sampler(self) -> None:
    """Interactively update the sampler settings used in this Thread

    Prompts for every sampler field in turn; empty or non-numeric input
    leaves that field unchanged. ^C at any prompt aborts the whole
    update with no changes applied.
    """
    # (attribute name, type used to parse the user's input) — replaces
    # eight copy-pasted input/try/cast/print blocks; prompts and output
    # are identical to the original
    fields = [
        ('max_len_tokens', int),
        ('temp', float),
        ('top_p', float),
        ('min_p', float),
        ('frequency_penalty', float),
        ('presence_penalty', float),
        ('repeat_penalty', float),
        ('top_k', int)
    ]
    print()
    raw_values = {}
    try:
        # collect all raw answers first, so a ^C anywhere aborts the
        # entire update (matching the original control flow)
        for name, _ in fields:
            raw_values[name] = input(
                f'{name}: {getattr(self.sampler, name)} -> '
            )
    except KeyboardInterrupt:
        print('\nwebscout.Local: sampler settings not updated\n')
        return
    print()

    # apply each field; a ValueError (including from empty input)
    # silently keeps the old value
    for name, caster in fields:
        try:
            setattr(self.sampler, name, caster(raw_values[name]))
        except ValueError:
            pass
        else:
            print(f'webscout.Local: {name} updated')
    print()
346
+
347
+
348
def _interactive_input(
    self,
    prompt: str,
    _dim_style: str,
    _user_style: str,
    _bot_style: str
) -> tuple:
    """
    Receive input from the user, while handling multi-line input
    and commands.

    Returns a 2-tuple `(user_input, next_message_start)`:
    - `(str, None)`  -- a complete (possibly multi-line) user message
    - `('', str)`    -- the bot's next message should begin with `str`
    - `('', None)`   -- the caller should re-roll the last message
    - `(None, None)` -- the user asked to exit the chat session
    """
    full_user_input = '' # may become multiline

    while True:
        user_input = input(prompt)

        # a trailing backslash continues the message on the next line
        if user_input.endswith('\\'):
            full_user_input += user_input[:-1] + '\n'

        # a lone '!' opens a one-shot command prompt
        elif user_input == '!':

            print()
            try:
                command = input(f'{RESET_ALL}  ! {_dim_style}')
            except KeyboardInterrupt:
                # ^C at the command prompt cancels the command
                print('\n')
                continue

            if command == '':
                print(f'\n[no command]\n')

            elif command.lower() in ['reset', 'restart']:
                self.reset()
                print(f'\n[thread reset]\n')

            elif command.lower() in ['cls', 'clear']:
                cls()
                print()

            elif command.lower() in ['ctx', 'context']:
                # current context usage in tokens
                print(f"\n{self.len_messages()}\n")

            elif command.lower() in ['stats', 'print_stats']:
                print()
                self.print_stats()
                print()

            elif command.lower() in ['sampler', 'samplers', 'settings']:
                self._interactive_update_sampler()

            elif command.lower() in ['str', 'string', 'as_string']:
                print(f"\n{self.as_string()}\n")

            elif command.lower() in ['repr', 'save', 'backup']:
                # repr(self) can be pasted back to rebuild the thread
                print(f"\n{repr(self)}\n")

            elif command.lower() in ['remove', 'rem', 'delete', 'del']:
                print()
                old_len = len(self.messages)
                del self.messages[-1]
                assert len(self.messages) == (old_len - 1)
                print('[removed last message]\n')

            elif command.lower() in ['last', 'repeat']:
                # re-print the last message in its role's color
                last_msg = self.messages[-1]
                if last_msg['role'] == 'user':
                    print(f"\n{_user_style}{last_msg['content']}{RESET_ALL}\n")
                elif last_msg['role'] == 'bot':
                    print(f"\n{_bot_style}{last_msg['content']}{RESET_ALL}\n")

            elif command.lower() in ['inf', 'inference', 'inf_str']:
                print(f'\n"""{self.inference_str_from_messages()}"""\n')

            elif command.lower() in ['reroll', 're-roll', 're', 'swipe']:
                # drop the last message and signal the caller to re-generate
                old_len = len(self.messages)
                del self.messages[-1]
                assert len(self.messages) == (old_len - 1)
                return '', None

            elif command.lower() in ['exit', 'quit']:
                print(RESET_ALL)
                return None, None

            elif command.lower() in ['help', '/?', '?']:
                print()
                print('reset / restart -- Reset the thread to its original state')
                print('clear / cls -- Clear the terminal')
                print('context / ctx -- Get the context usage in tokens')
                print('print_stats / stats -- Get the context usage stats')
                print('sampler / settings -- Update the sampler settings')
                print('string / str -- Print the message history as a string')
                print('repr / save -- Print the representation of the thread')
                print('remove / delete -- Remove the last message')
                print('last / repeat -- Repeat the last message')
                print('inference / inf -- Print the inference string')
                print('reroll / swipe -- Regenerate the last message')
                print('exit / quit -- Exit the interactive chat (can also use ^C)')
                print('help / ? -- Show this screen')
                print()
                print("TIP: type < at the prompt and press ENTER to prefix the bot's next message.")
                print('    for example, type "Sure!" to bypass refusals')
                print()

            else:
                print(f'\n[unknown command]\n')

        elif user_input == '<': # the next bot message will start with...

            print()
            try:
                next_message_start = input(f'{_dim_style}  < ')

            except KeyboardInterrupt:
                print(f'{RESET_ALL}\n')
                continue

            else:
                print()
                return '', next_message_start

        # 'message<' sends the message AND prompts for the bot prefix
        elif user_input.endswith('<'):

            print()

            msg = user_input.removesuffix('<')
            self.add_message("user", msg)

            try:
                next_message_start = input(f'{_dim_style}  < ')

            except KeyboardInterrupt:
                print(f'{RESET_ALL}\n')
                continue

            else:
                print()
                return '', next_message_start

        else:
            # ordinary input finishes the (possibly multi-line) message
            full_user_input += user_input
            return full_user_input, None
489
+
490
+
491
def interact(
    self,
    color: bool = True,
    header: Optional[str] = None,
    stream: bool = True
) -> None:
    """
    Start an interactive chat session using this Thread.

    While text is being generated, press `^C` to interrupt the bot.
    Then you have the option to press `ENTER` to re-roll, or to simply type
    another message.

    At the prompt, press `^C` to end the chat session.

    Type `!` and press `ENTER` to enter a basic command prompt. For a list
    of commands, type `help` at this prompt.

    Type `<` and press `ENTER` to prefix the bot's next message, for
    example with `Sure!`.

    The following parameters are optional:
    - color: Whether to use colored text to differentiate user / bot
    - header: Header text to print at the start of the interaction
    - stream: Whether to stream text as it is generated
    """
    print()

    # fresh import of color codes in case `color` param has changed
    from .utils import USER_STYLE, BOT_STYLE, DIM_STYLE, SPECIAL_STYLE

    # disable color codes if explicitly disabled by `color` param
    # (these rebind the imported names locally; the module is untouched)
    if not color:
        USER_STYLE = ''
        BOT_STYLE = ''
        DIM_STYLE = ''
        SPECIAL_STYLE = ''

    if header is not None:
        print(f"{SPECIAL_STYLE}{header}{RESET_ALL}\n")

    while True:

        # input prompt shown to the user, in the user's color
        prompt = f"{RESET_ALL} > {USER_STYLE}"

        try:
            user_prompt, next_message_start = self._interactive_input(
                prompt,
                DIM_STYLE,
                USER_STYLE,
                BOT_STYLE
            )
        except KeyboardInterrupt:
            # ^C at the prompt ends the chat session
            print(f"{RESET_ALL}\n")
            return

        # got 'exit' or 'quit' command
        if user_prompt is None and next_message_start is None:
            break

        # next_message_start is set when the user typed '<' to force
        # the start of the bot's next message
        if next_message_start is not None:
            try:
                if stream:
                    # stream_print prints tokens as they arrive
                    print(f"{BOT_STYLE}{next_message_start}", end='', flush=True)
                    output = next_message_start + self.model.stream_print(
                        self.inference_str_from_messages() + next_message_start,
                        stops=self.format['stops'],
                        sampler=self.sampler,
                        end=''
                    )
                else:
                    print(f"{BOT_STYLE}", end='', flush=True)
                    output = next_message_start + self.model.generate(
                        self.inference_str_from_messages() + next_message_start,
                        stops=self.format['stops'],
                        sampler=self.sampler
                    )
                    print(output, end='', flush=True)
            except KeyboardInterrupt:
                # ^C during generation discards the partial message
                print(f"{DIM_STYLE} [message not added to history; press ENTER to re-roll]\n")
                continue
            else:
                self.add_message("bot", output)
        else:
            print(BOT_STYLE)
            # '' means re-roll: skip adding a new user message
            if user_prompt != "":
                self.add_message("user", user_prompt)
            try:
                if stream:
                    output = self.model.stream_print(
                        self.inference_str_from_messages(),
                        stops=self.format['stops'],
                        sampler=self.sampler,
                        end=''
                    )
                else:
                    output = self.model.generate(
                        self.inference_str_from_messages(),
                        stops=self.format['stops'],
                        sampler=self.sampler
                    )
                    print(output, end='', flush=True)
            except KeyboardInterrupt:
                print(f"{DIM_STYLE} [message not added to history; press ENTER to re-roll]\n")
                continue
            else:
                self.add_message("bot", output)

        # keep terminal spacing consistent regardless of how the bot's
        # message ended
        if output.endswith("\n\n"):
            print(RESET_ALL, end = '', flush=True)
        elif output.endswith("\n"):
            print(RESET_ALL)
        else:
            print(f"{RESET_ALL}\n")
605
+
606
+
607
def reset(self) -> None:
    """
    Clear the list of messages, which resets the thread to its original
    state
    """
    # rebuild the initial system message from the format, exactly as
    # the constructor does
    system_message = self.create_message(
        "system", self.format['system_content']
    )
    self.messages = [system_message]
615
+
616
+
617
def as_string(self) -> str:
    """Return this thread's message history as a string"""
    # each message renders as prefix + content + postfix, concatenated
    # in history order
    return ''.join(
        msg['prefix'] + msg['content'] + msg['postfix']
        for msg in self.messages
    )
625
+
626
+
627
def print_stats(
    self,
    end: str = '\n',
    file: '_SupportsWriteAndFlush' = sys.stdout,
    flush: bool = True
) -> None:
    """Print stats about the context usage in this thread"""
    used_tokens = self.len_messages()
    ctx_len = self.model.context_length
    pct_used = round((used_tokens / ctx_len) * 100)
    print(f"{used_tokens} / {ctx_len} tokens", file=file, flush=flush)
    print(f"{pct_used}% of context used", file=file, flush=flush)
    print(f"{len(self.messages)} messages", end=end, file=file, flush=flush)
    if not flush:
        # caller asked print() not to flush each line; do a single
        # final flush so the stats still appear promptly
        file.flush()