webscout 2.6__py3-none-any.whl → 2.8__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of webscout might be problematic. See this release's page on the package registry for more details.

webscout/Local/thread.py CHANGED
@@ -1,5 +1,3 @@
1
- # thread.py
2
- # https://github.com/ddh0/easy-llama/
3
1
  from ._version import __version__, __llama_cpp_version__
4
2
 
5
3
  """Submodule containing the Thread class, used for interaction with a Model"""
@@ -10,15 +8,51 @@ from .model import Model, assert_model_is_loaded, _SupportsWriteAndFlush
10
8
  from .utils import RESET_ALL, cls, print_verbose, truncate
11
9
  from .samplers import SamplerSettings, DefaultSampling
12
10
  from typing import Optional, Literal, Union
11
+ from .formats import AdvancedFormat
13
12
 
14
13
  from .formats import blank as formats_blank
15
14
 
16
15
 
16
+ class Message(dict):
17
+ """
18
+ A dictionary representing a single message within a Thread
19
+
20
+ Works just like a normal `dict`, but a new method:
21
+ - `.as_string` - Return the full message string
22
+
23
+ Generally, messages have these keys:
24
+ - `role` - The role of the speaker: 'system', 'user', or 'bot'
25
+ - `prefix` - The text that prefixes the message content
26
+ - `content` - The actual content of the message
27
+ - `suffix` - The text that suffixes the message content
28
+ """
29
+
30
+ def __repr__(self) -> str:
31
+ return \
32
+ f"Message([" \
33
+ f"('role', {repr(self['role'])}), " \
34
+ f"('prefix', {repr(self['prefix'])}), " \
35
+ f"('content', {repr(self['content'])}), " \
36
+ f"('suffix', {repr(self['suffix'])})])"
37
+
38
+ def as_string(self):
39
+ """Return the full message string"""
40
+ try:
41
+ return self['prefix'] + self['content'] + self['suffix']
42
+ except KeyError as e:
43
+ e.add_note(
44
+ "as_string: Message is missing one or more of the "
45
+ "required 'prefix', 'content', 'suffix' attributes - this is "
46
+ "unexpected"
47
+ )
48
+ raise e
49
+
50
+
17
51
  class Thread:
18
52
  """
19
53
  Provide functionality to facilitate easy interactions with a Model
20
54
 
21
- This is just a brief overview of webscout.Local.Thread.
55
+ This is just a brief overview of m.Thread.
22
56
  To see a full description of each method and its parameters,
23
57
  call help(Thread), or see the relevant docstring.
24
58
 
@@ -36,24 +70,26 @@ class Thread:
36
70
  The following attributes are available:
37
71
  - `.format` - The format being used for messages in this thread
38
72
  - `.messages` - The list of messages in this thread
39
- - `.model` - The `webscout.Local.Model` instance used by this thread
73
+ - `.model` - The `m.Model` instance used by this thread
40
74
  - `.sampler` - The SamplerSettings object used in this thread
41
75
  """
42
76
 
43
77
  def __init__(
44
78
  self,
45
79
  model: Model,
46
- format: dict[str, Union[str, list]],
47
- sampler: SamplerSettings = DefaultSampling
80
+ format: Union[dict, AdvancedFormat],
81
+ sampler: SamplerSettings = DefaultSampling,
82
+ messages: Optional[list[Message]] = None,
48
83
  ):
49
84
  """
50
85
  Given a Model and a format, construct a Thread instance.
51
86
 
52
87
  model: The Model to use for text generation
53
- format: The format specifying how messages should be structured (see webscout.Local.formats)
88
+ format: The format specifying how messages should be structured (see m.formats)
54
89
 
55
- The following parameter is optional:
90
+ The following parameters are optional:
56
91
  - sampler: The SamplerSettings object used to control text generation
92
+ - messages: A list of m.thread.Message objects to add to the Thread upon construction
57
93
  """
58
94
 
59
95
  assert isinstance(model, Model), \
@@ -62,8 +98,8 @@ class Thread:
62
98
 
63
99
  assert_model_is_loaded(model)
64
100
 
65
- assert isinstance(format, dict), \
66
- f"Thread: format should be dict, not {type(format)}"
101
+ assert isinstance(format, (dict, AdvancedFormat)), \
102
+ f"Thread: format should be dict or AdvancedFormat, not {type(format)}"
67
103
 
68
104
  if any(k not in format.keys() for k in formats_blank.keys()):
69
105
  raise KeyError(
@@ -87,12 +123,23 @@ class Thread:
87
123
  'top_k'
88
124
  ]
89
125
  ), 'Thread: sampler is missing one or more required attributes'
126
+
127
+ self._messages: Optional[list[Message]] = messages
128
+ if self._messages is not None:
129
+ if not all(isinstance(msg, Message) for msg in self._messages):
130
+ raise TypeError(
131
+ "Thread: one or more messages provided to __init__() is "
132
+ "not an instance of m.thread.Message"
133
+ )
90
134
 
135
+ # Thread.messages is never empty, unless `messages` param is explicity
136
+ # set to `[]` during construction
137
+
91
138
  self.model: Model = model
92
- self.format: dict[str, Union[str, list]] = format
93
- self.messages: list[dict[str, str]] = [
139
+ self.format: Union[dict, AdvancedFormat] = format
140
+ self.messages: list[Message] = [
94
141
  self.create_message("system", self.format['system_content'])
95
- ]
142
+ ] if self._messages is None else self._messages
96
143
  self.sampler: SamplerSettings = sampler
97
144
 
98
145
  if self.model.verbose:
@@ -100,13 +147,13 @@ class Thread:
100
147
  print_verbose(f"model == {self.model}")
101
148
  print_verbose(f"format['system_prefix'] == {truncate(repr(self.format['system_prefix']))}")
102
149
  print_verbose(f"format['system_content'] == {truncate(repr(self.format['system_content']))}")
103
- print_verbose(f"format['system_postfix'] == {truncate(repr(self.format['system_postfix']))}")
150
+ print_verbose(f"format['system_suffix'] == {truncate(repr(self.format['system_suffix']))}")
104
151
  print_verbose(f"format['user_prefix'] == {truncate(repr(self.format['user_prefix']))}")
105
152
  print_verbose(f"format['user_content'] == {truncate(repr(self.format['user_content']))}")
106
- print_verbose(f"format['user_postfix'] == {truncate(repr(self.format['user_postfix']))}")
153
+ print_verbose(f"format['user_suffix'] == {truncate(repr(self.format['user_suffix']))}")
107
154
  print_verbose(f"format['bot_prefix'] == {truncate(repr(self.format['bot_prefix']))}")
108
155
  print_verbose(f"format['bot_content'] == {truncate(repr(self.format['bot_content']))}")
109
- print_verbose(f"format['bot_postfix'] == {truncate(repr(self.format['bot_postfix']))}")
156
+ print_verbose(f"format['bot_suffix'] == {truncate(repr(self.format['bot_suffix']))}")
110
157
  print_verbose(f"format['stops'] == {truncate(repr(self.format['stops']))}")
111
158
  print_verbose(f"sampler.temp == {self.sampler.temp}")
112
159
  print_verbose(f"sampler.top_p == {self.sampler.top_p}")
@@ -118,29 +165,10 @@ class Thread:
118
165
 
119
166
 
120
167
  def __repr__(self) -> str:
121
- repr_str = f"Thread({repr(self.model)}, {repr(self.format)}, "
122
- repr_str += f"{repr(self.sampler)})"
123
- # system message is created from format, so not represented
124
- if len(self.messages) <= 1:
125
- return repr_str
126
- else:
127
- for msg in self.messages:
128
- if msg['role'] == 'user':
129
- repr_str += "\nThread.add_message('user', " + repr(msg['content']) + ')'
130
- elif msg['role'] == 'bot':
131
- repr_str += "\nThread.add_message('bot', " + repr(msg['content']) + ')'
132
- return repr_str
133
- def save_conversation(self, filepath: str) -> None:
134
- """
135
- Saves the conversation history to a JSON file.
136
-
137
- filepath: The path to the file where the conversation should be saved.
138
- """
139
- import json
140
-
141
- data = [{'role': msg['role'], 'content': msg['content']} for msg in self.messages]
142
- with open(filepath, 'w') as f:
143
- json.dump(data, f, indent=4)
168
+ return \
169
+ f"Thread({repr(self.model)}, {repr(self.format)}, " + \
170
+ f"{repr(self.sampler)}, {repr(self.messages)})"
171
+
144
172
  def __str__(self) -> str:
145
173
  return self.as_string()
146
174
 
@@ -156,9 +184,9 @@ class Thread:
156
184
  self,
157
185
  role: Literal['system', 'user', 'bot'],
158
186
  content: str
159
- ) -> dict[str, str]:
187
+ ) -> Message:
160
188
  """
161
- Create a message using the format of this Thread
189
+ Construct a message using the format of this Thread
162
190
  """
163
191
 
164
192
  assert role.lower() in ['system', 'user', 'bot'], \
@@ -168,34 +196,40 @@ class Thread:
168
196
  f"create_message: content should be str, not {type(content)}"
169
197
 
170
198
  if role.lower() == 'system':
171
- return {
172
- "role": "system",
173
- "prefix": self.format['system_prefix'],
174
- "content": content,
175
- "postfix": self.format['system_postfix']
176
- }
199
+ return Message(
200
+ [
201
+ ('role', 'system'),
202
+ ('prefix', self.format['system_prefix']),
203
+ ('content', content),
204
+ ('suffix', self.format['system_suffix'])
205
+ ]
206
+ )
177
207
 
178
208
  elif role.lower() == 'user':
179
- return {
180
- "role": "user",
181
- "prefix": self.format['user_prefix'],
182
- "content": content,
183
- "postfix": self.format['user_postfix']
184
- }
209
+ return Message(
210
+ [
211
+ ('role', 'user'),
212
+ ('prefix', self.format['user_prefix']),
213
+ ('content', content),
214
+ ('suffix', self.format['user_suffix'])
215
+ ]
216
+ )
185
217
 
186
218
  elif role.lower() == 'bot':
187
- return {
188
- "role": "bot",
189
- "prefix": self.format['bot_prefix'],
190
- "content": content,
191
- "postfix": self.format['bot_postfix']
192
- }
219
+ return Message(
220
+ [
221
+ ('role', 'bot'),
222
+ ('prefix', self.format['bot_prefix']),
223
+ ('content', content),
224
+ ('suffix', self.format['bot_suffix'])
225
+ ]
226
+ )
193
227
 
194
228
  def len_messages(self) -> int:
195
229
  """
196
230
  Return the total length of all messages in this thread, in tokens.
197
231
 
198
- Equivalent to `len(Thread)`."""
232
+ Can also use `len(Thread)`."""
199
233
 
200
234
  return self.model.get_length(self.as_string())
201
235
 
@@ -223,38 +257,35 @@ class Thread:
223
257
  respecting the format and context length of this thread.
224
258
  """
225
259
 
226
- messages = self.messages
227
-
228
- context_len_budget = self.model.context_length
229
- if len(messages) > 0:
230
- sys_msg = messages[0]
231
- sys_msg_str = (
232
- sys_msg['prefix'] + sys_msg['content'] + sys_msg['postfix']
233
- )
234
- context_len_budget -= self.model.get_length(sys_msg_str)
235
- else:
236
- sys_msg_str = ''
237
-
238
260
  inf_str = ''
261
+ sys_msg_str = ''
262
+ # whether to treat the first message as necessary to keep
263
+ sys_msg_flag = False
264
+ context_len_budget = self.model.context_length
239
265
 
240
- # Start at most recent message and work backwards up the history
241
- # excluding system message. Once we exceed thread
242
- # max_context_length, break without including that message
243
- for message in reversed(messages[1:]):
244
- context_len_budget -= self.model.get_length(
245
- message['prefix'] + message['content'] + message['postfix']
246
- )
247
-
266
+ # if at least 1 message is history
267
+ if len(self.messages) >= 1:
268
+ # if first message has system role
269
+ if self.messages[0]['role'] == 'system':
270
+ sys_msg_flag = True
271
+ sys_msg = self.messages[0]
272
+ sys_msg_str = sys_msg.as_string()
273
+ context_len_budget -= self.model.get_length(sys_msg_str)
274
+
275
+ if sys_msg_flag:
276
+ iterator = reversed(self.messages[1:])
277
+ else:
278
+ iterator = reversed(self.messages)
279
+
280
+ for message in iterator:
281
+ msg_str = message.as_string()
282
+ context_len_budget -= self.model.get_length(msg_str)
248
283
  if context_len_budget <= 0:
249
284
  break
250
-
251
- msg_str = (
252
- message['prefix'] + message['content'] + message['postfix']
253
- )
254
-
255
285
  inf_str = msg_str + inf_str
256
286
 
257
- inf_str = sys_msg_str + inf_str
287
+ if sys_msg_flag:
288
+ inf_str = sys_msg_str + inf_str
258
289
  inf_str += self.format['bot_prefix']
259
290
 
260
291
  return inf_str
@@ -277,20 +308,7 @@ class Thread:
277
308
  self.add_message("bot", output)
278
309
 
279
310
  return output
280
- def load_conversation(self, filepath: str) -> None:
281
- """
282
- Loads a conversation history from a JSON file.
283
-
284
- filepath: The path to the file containing the conversation history.
285
- """
286
- import json
287
-
288
- with open(filepath, 'r') as f:
289
- data = json.load(f)
290
-
291
- self.messages = []
292
- for item in data:
293
- self.messages.append(self.create_message(item['role'], item['content']))
311
+
294
312
 
295
313
  def _interactive_update_sampler(self) -> None:
296
314
  """Interactively update the sampler settings used in this Thread"""
@@ -373,7 +391,8 @@ class Thread:
373
391
  prompt: str,
374
392
  _dim_style: str,
375
393
  _user_style: str,
376
- _bot_style: str
394
+ _bot_style: str,
395
+ _special_style: str
377
396
  ) -> tuple:
378
397
  """
379
398
  Recive input from the user, while handling multi-line input
@@ -420,15 +439,7 @@ class Thread:
420
439
 
421
440
  elif command.lower() in ['str', 'string', 'as_string']:
422
441
  print(f"\n{self.as_string()}\n")
423
-
424
- elif command.lower() in ['save']:
425
- print()
426
- try:
427
- filepath = input(f'Enter filepath to save conversation: ')
428
- self.save_conversation(filepath)
429
- print(f'[conversation saved to {filepath}]\n')
430
- except Exception as e:
431
- print(f'[error saving conversation: {e}]\n')
442
+
432
443
  elif command.lower() in ['repr', 'save', 'backup']:
433
444
  print(f"\n{repr(self)}\n")
434
445
 
@@ -438,14 +449,7 @@ class Thread:
438
449
  del self.messages[-1]
439
450
  assert len(self.messages) == (old_len - 1)
440
451
  print('[removed last message]\n')
441
- elif command.lower() in ['load']:
442
- print()
443
- try:
444
- filepath = input(f'Enter filepath to load conversation: ')
445
- self.load_conversation(filepath)
446
- print(f'[conversation loaded from {filepath}]\n')
447
- except Exception as e:
448
- print(f'[error loading conversation: {e}]\n')
452
+
449
453
  elif command.lower() in ['last', 'repeat']:
450
454
  last_msg = self.messages[-1]
451
455
  if last_msg['role'] == 'user':
@@ -468,33 +472,35 @@ class Thread:
468
472
 
469
473
  elif command.lower() in ['help', '/?', '?']:
470
474
  print()
471
- print('reset / restart -- Reset the thread to its original state')
472
- print('clear / cls -- Clear the terminal')
473
- print('context / ctx -- Get the context usage in tokens')
474
- print('print_stats / stats -- Get the context usage stats')
475
- print('sampler / settings -- Update the sampler settings')
476
- print('string / str -- Print the message history as a string')
477
- print('repr / save -- Print the representation of the thread')
478
- print('remove / delete -- Remove the last message')
479
- print('last / repeat -- Repeat the last message')
480
- print('inference / inf -- Print the inference string')
481
- print('reroll / swipe -- Regenerate the last message')
482
- print('exit / quit -- Exit the interactive chat (can also use ^C)')
483
- print('help / ? -- Show this screen')
475
+ print('reset | restart -- Reset the thread to its original state')
476
+ print('clear | cls -- Clear the terminal')
477
+ print('context | ctx -- Get the context usage in tokens')
478
+ print('print_stats | stats -- Get the context usage stats')
479
+ print('sampler | settings -- Update the sampler settings')
480
+ print('string | str -- Print the message history as a string')
481
+ print('repr | save -- Print the representation of the thread')
482
+ print('remove | delete -- Remove the last message')
483
+ print('last | repeat -- Repeat the last message')
484
+ print('inference | inf -- Print the inference string')
485
+ print('reroll | swipe -- Regenerate the last message')
486
+ print('exit | quit -- Exit the interactive chat (can also use ^C)')
487
+ print('help | ? -- Show this screen')
484
488
  print()
485
489
  print("TIP: type < at the prompt and press ENTER to prefix the bot's next message.")
486
490
  print(' for example, type "Sure!" to bypass refusals')
487
491
  print()
488
- print('save -- Save the conversation to a JSON file')
489
- print('load -- Load a conversation from a JSON file')
492
+ print("TIP: type !! at the prompt and press ENTER to insert a system message")
493
+ print()
494
+
490
495
  else:
491
496
  print(f'\n[unknown command]\n')
492
497
 
493
- elif user_input == '<': # the next bot message will start with...
498
+ # prefix the bot's next message
499
+ elif user_input == '<':
494
500
 
495
501
  print()
496
502
  try:
497
- next_message_start = input(f'{_dim_style} < ')
503
+ next_message_start = input(f'{RESET_ALL} < {_dim_style}')
498
504
 
499
505
  except KeyboardInterrupt:
500
506
  print(f'{RESET_ALL}\n')
@@ -503,25 +509,23 @@ class Thread:
503
509
  else:
504
510
  print()
505
511
  return '', next_message_start
506
-
507
- elif user_input.endswith('<'):
508
512
 
513
+ # insert a system message
514
+ elif user_input == '!!':
509
515
  print()
510
516
 
511
- msg = user_input.removesuffix('<')
512
- self.add_message("user", msg)
513
-
514
517
  try:
515
- next_message_start = input(f'{_dim_style} < ')
516
-
518
+ next_sys_msg = input(f'{RESET_ALL} !! {_special_style}')
519
+
517
520
  except KeyboardInterrupt:
518
521
  print(f'{RESET_ALL}\n')
519
522
  continue
520
-
523
+
521
524
  else:
522
525
  print()
523
- return '', next_message_start
526
+ return next_sys_msg, -1
524
527
 
528
+ # concatenate multi-line input
525
529
  else:
526
530
  full_user_input += user_input
527
531
  return full_user_input, None
@@ -548,6 +552,8 @@ class Thread:
548
552
  Type `<` and press `ENTER` to prefix the bot's next message, for
549
553
  example with `Sure!`.
550
554
 
555
+ Type `!!` at the prompt and press `ENTER` to insert a system message.
556
+
551
557
  The following parameters are optional:
552
558
  - color: Whether to use colored text to differentiate user / bot
553
559
  - header: Header text to print at the start of the interaction
@@ -556,28 +562,29 @@ class Thread:
556
562
  print()
557
563
 
558
564
  # fresh import of color codes in case `color` param has changed
559
- from .utils import USER_STYLE, BOT_STYLE, DIM_STYLE, SPECIAL_STYLE
565
+ from .utils import SPECIAL_STYLE, USER_STYLE, BOT_STYLE, DIM_STYLE
560
566
 
561
567
  # disable color codes if explicitly disabled by `color` param
562
568
  if not color:
569
+ SPECIAL_STYLE = ''
563
570
  USER_STYLE = ''
564
571
  BOT_STYLE = ''
565
572
  DIM_STYLE = ''
566
- SPECIAL_STYLE = ''
567
573
 
568
574
  if header is not None:
569
575
  print(f"{SPECIAL_STYLE}{header}{RESET_ALL}\n")
570
576
 
571
577
  while True:
572
578
 
573
- prompt = f"{RESET_ALL} >>> {USER_STYLE}"
579
+ prompt = f"{RESET_ALL} > {USER_STYLE}"
574
580
 
575
581
  try:
576
582
  user_prompt, next_message_start = self._interactive_input(
577
583
  prompt,
578
584
  DIM_STYLE,
579
585
  USER_STYLE,
580
- BOT_STYLE
586
+ BOT_STYLE,
587
+ SPECIAL_STYLE
581
588
  )
582
589
  except KeyboardInterrupt:
583
590
  print(f"{RESET_ALL}\n")
@@ -587,6 +594,11 @@ class Thread:
587
594
  if user_prompt is None and next_message_start is None:
588
595
  break
589
596
 
597
+ # insert a system message via `!!` prompt
598
+ if next_message_start == -1:
599
+ self.add_message('system', user_prompt)
600
+ continue
601
+
590
602
  if next_message_start is not None:
591
603
  try:
592
604
  if stream:
@@ -648,19 +660,17 @@ class Thread:
648
660
  Clear the list of messages, which resets the thread to its original
649
661
  state
650
662
  """
651
- self.messages: list[dict[str, str]] = [
663
+ self.messages: list[Message] = [
652
664
  self.create_message("system", self.format['system_content'])
653
- ]
665
+ ] if self._messages is None else self._messages
654
666
 
655
667
 
656
668
  def as_string(self) -> str:
657
669
  """Return this thread's message history as a string"""
658
- ret = ''
670
+ thread_string = ''
659
671
  for msg in self.messages:
660
- ret += msg['prefix']
661
- ret += msg['content']
662
- ret += msg['postfix']
663
- return ret
672
+ thread_string += msg.as_string()
673
+ return thread_string
664
674
 
665
675
 
666
676
  def print_stats(
@@ -677,4 +687,4 @@ class Thread:
677
687
  print(f"{context_used_percentage}% of context used", file=file, flush=flush)
678
688
  print(f"{len(self.messages)} messages", end=end, file=file, flush=flush)
679
689
  if not flush:
680
- file.flush()
690
+ file.flush()