webscout 2.2__py3-none-any.whl → 2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of webscout might be problematic. Click here for more details.
- webscout/AIauto.py +54 -13
- webscout/Local/__init__.py +10 -0
- webscout/Local/_version.py +3 -0
- webscout/Local/formats.py +482 -0
- webscout/Local/model.py +702 -0
- webscout/Local/samplers.py +161 -0
- webscout/Local/thread.py +641 -0
- webscout/Local/utils.py +171 -0
- webscout/__init__.py +96 -94
- webscout/async_providers.py +9 -9
- webscout/g4f.py +2 -2
- {webscout-2.2.dist-info → webscout-2.3.dist-info}/METADATA +6 -3
- {webscout-2.2.dist-info → webscout-2.3.dist-info}/RECORD +17 -10
- webscout-2.3.dist-info/entry_points.txt +3 -0
- webscout-2.2.dist-info/entry_points.txt +0 -6
- {webscout-2.2.dist-info → webscout-2.3.dist-info}/LICENSE.md +0 -0
- {webscout-2.2.dist-info → webscout-2.3.dist-info}/WHEEL +0 -0
- {webscout-2.2.dist-info → webscout-2.3.dist-info}/top_level.txt +0 -0
webscout/Local/thread.py
ADDED
|
@@ -0,0 +1,641 @@
|
|
|
1
|
+
# thread.py
|
|
2
|
+
# https://github.com/ddh0/easy-llama/
|
|
3
|
+
from ._version import __version__, __llama_cpp_version__
|
|
4
|
+
|
|
5
|
+
"""Submodule containing the Thread class, used for interaction with a Model"""
|
|
6
|
+
|
|
7
|
+
import sys
|
|
8
|
+
|
|
9
|
+
from .model import Model, assert_model_is_loaded, _SupportsWriteAndFlush
|
|
10
|
+
from .utils import RESET_ALL, cls, print_verbose, truncate
|
|
11
|
+
from .samplers import SamplerSettings, DefaultSampling
|
|
12
|
+
from typing import Optional, Literal, Union
|
|
13
|
+
|
|
14
|
+
from .formats import blank as formats_blank
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Thread:
    """
    Provide functionality to facilitate easy interactions with a Model

    This is just a brief overview of webscout.Local.Thread.
    To see a full description of each method and its parameters,
    call help(Thread), or see the relevant docstring.

    The following methods are available:
    - `.add_message()` - Add a message to `Thread.messages`
    - `.as_string()` - Return this thread's complete message history as a string
    - `.create_message()` - Create a message using the format of this thread
    - `.inference_str_from_messages()` - Using the list of messages, return a string suitable for inference
    - `.interact()` - Start an interactive, terminal-based chat session
    - `.len_messages()` - Get the total length of all messages in tokens
    - `.print_stats()` - Print stats about the context usage in this thread
    - `.reset()` - Clear the list of messages
    - `.send()` - Send a message in this thread

    The following attributes are available:
    - `.format` - The format being used for messages in this thread
    - `.messages` - The list of messages in this thread
    - `.model` - The `webscout.Local.Model` instance used by this thread
    - `.sampler` - The SamplerSettings object used in this thread
    """

    # Every sampler setting a Thread relies on, paired with the type used to
    # parse a new value typed in by the user (see `_interactive_update_sampler`).
    # Order matters: it is the order of validation, verbose output and prompts.
    _SAMPLER_FIELDS = (
        ('max_len_tokens', int),
        ('temp', float),
        ('top_p', float),
        ('min_p', float),
        ('frequency_penalty', float),
        ('presence_penalty', float),
        ('repeat_penalty', float),
        ('top_k', int),
    )

    def __init__(
        self,
        model: Model,
        format: dict[str, Union[str, list]],
        sampler: SamplerSettings = DefaultSampling
    ):
        """
        Given a Model and a format, construct a Thread instance.

        model: The Model to use for text generation
        format: The format specifying how messages should be structured (see webscout.Local.formats)

        The following parameter is optional:
        - sampler: The SamplerSettings object used to control text generation.
          When omitted, this Thread gets its own copy of `DefaultSampling`,
          so changing this Thread's sampler settings (for example with the
          interactive `!sampler` command) does not change the defaults used
          by other Threads.

        Raises:
        - AssertionError: if `model`, `format`, `format['stops']` or `sampler`
          have the wrong type or are missing required attributes
        - KeyError: if `format` lacks any key present in `formats.blank`

        `assert_model_is_loaded(model)` is also called to verify the model.
        """
        from copy import copy

        assert isinstance(model, Model), \
            "Thread: model should be an " + \
            f"instance of webscout.Local.Model, not {type(model)}"

        assert_model_is_loaded(model)

        assert isinstance(format, dict), \
            f"Thread: format should be dict, not {type(format)}"

        if any(k not in format for k in formats_blank):
            raise KeyError(
                "Thread: format is missing one or more required keys, see " + \
                "webscout.Local.formats.blank for an example"
            )

        assert isinstance(format['stops'], list), \
            "Thread: format['stops'] should be list, not " + \
            f"{type(format['stops'])}"

        assert all(
            hasattr(sampler, attr) for attr, _ in self._SAMPLER_FIELDS
        ), 'Thread: sampler is missing one or more required attributes'

        # `DefaultSampling` is one shared object used as the default argument.
        # Settings on `self.sampler` are mutated in place (see
        # `_interactive_update_sampler`), which would otherwise silently change
        # the defaults of every other Thread.
        if sampler is DefaultSampling:
            sampler = copy(sampler)

        self.model: Model = model
        self.format: dict[str, Union[str, list]] = format
        self.messages: list[dict[str, str]] = [
            self.create_message("system", self.format['system_content'])
        ]
        self.sampler: SamplerSettings = sampler

        if self.model.verbose:
            print_verbose("new Thread instance with the following attributes:")
            print_verbose(f"model == {self.model}")
            for key in (
                'system_prefix', 'system_content', 'system_postfix',
                'user_prefix', 'user_content', 'user_postfix',
                'bot_prefix', 'bot_content', 'bot_postfix',
                'stops'
            ):
                print_verbose(f"format[{key!r}] == {truncate(repr(self.format[key]))}")
            for attr, _ in self._SAMPLER_FIELDS:
                print_verbose(f"sampler.{attr} == {getattr(self.sampler, attr)}")

    def __repr__(self) -> str:
        repr_str = f"Thread({repr(self.model)}, {repr(self.format)}, "
        repr_str += f"{repr(self.sampler)})"
        # system message is created from format, so not represented
        for msg in self.messages:
            if msg['role'] in ('user', 'bot'):
                repr_str += f"\nThread.add_message({msg['role']!r}, {msg['content']!r})"
        return repr_str

    def __str__(self) -> str:
        return self.as_string()

    def __len__(self) -> int:
        """
        `len(Thread)` returns the length of the Thread in tokens

        To get the number of messages in the Thread, use `len(Thread.messages)`
        """
        return self.len_messages()

    def create_message(
        self,
        role: Literal['system', 'user', 'bot'],
        content: str
    ) -> dict[str, str]:
        """
        Create a message using the format of this Thread

        - role: 'system', 'user' or 'bot' (case-insensitive)
        - content: the text of the message

        Returns a dict with the keys 'role' (lower-cased), 'prefix',
        'content' and 'postfix'. The prefix and postfix are taken from
        `self.format['<role>_prefix']` and `self.format['<role>_postfix']`.
        """
        role_name = role.lower()

        assert role_name in ('system', 'user', 'bot'), \
            f"create_message: role should be 'system', 'user', or 'bot', not '{role_name}'"

        assert isinstance(content, str), \
            f"create_message: content should be str, not {type(content)}"

        return {
            "role": role_name,
            "prefix": self.format[f'{role_name}_prefix'],
            "content": content,
            "postfix": self.format[f'{role_name}_postfix']
        }

    def len_messages(self) -> int:
        """
        Return the total length of all messages in this thread, in tokens.

        Equivalent to `len(Thread)`."""

        return self.model.get_length(self.as_string())

    def add_message(
        self,
        role: Literal['system', 'user', 'bot'],
        content: str
    ) -> None:
        """
        Create a message and append it to `Thread.messages`.

        `Thread.add_message(...)` is a shorthand for
        `Thread.messages.append(Thread.create_message(...))`
        """
        self.messages.append(
            self.create_message(
                role=role,
                content=content
            )
        )

    def inference_str_from_messages(self) -> str:
        """
        Using the list of messages, construct a string suitable for inference,
        respecting the format and context length of this thread.

        If the first message has role 'system', it is always kept. The other
        messages are added newest-first while they still fit in the token
        budget: `model.context_length` minus the system message and minus the
        trailing `format['bot_prefix']`. Older messages that do not fit are
        dropped. The result always ends with `format['bot_prefix']`.
        """
        messages = self.messages
        bot_prefix = self.format['bot_prefix']

        context_len_budget = self.model.context_length
        # The bot prefix is appended after the history, so it has to fit too
        if bot_prefix:
            context_len_budget -= self.model.get_length(bot_prefix)

        if messages and messages[0]['role'] == 'system':
            sys_msg = messages[0]
            sys_msg_str = (
                sys_msg['prefix'] + sys_msg['content'] + sys_msg['postfix']
            )
            context_len_budget -= self.model.get_length(sys_msg_str)
            history = messages[1:]
        else:
            # No system message at the front (e.g. the list was edited by
            # hand), so every message is subject to truncation
            sys_msg_str = ''
            history = messages

        # Start at the most recent message and work backwards up the history.
        # Once the budget is exhausted, stop without including that message.
        kept_newest_first = []
        for message in reversed(history):
            msg_str = message['prefix'] + message['content'] + message['postfix']
            context_len_budget -= self.model.get_length(msg_str)
            if context_len_budget <= 0:
                break
            kept_newest_first.append(msg_str)

        return sys_msg_str + ''.join(reversed(kept_newest_first)) + bot_prefix

    def send(self, prompt: str) -> str:
        """
        Send a message in this thread. This adds your message and the bot's
        response to the list of messages.

        Returns a string containing the response to your message.
        """

        self.add_message("user", prompt)
        output = self.model.generate(
            self.inference_str_from_messages(),
            stops=self.format['stops'],
            sampler=self.sampler
        )
        self.add_message("bot", output)

        return output

    def _interactive_update_sampler(self) -> None:
        """
        Interactively update the sampler settings used in this Thread.

        Every setting is prompted for first. Pressing ^C during the prompts
        aborts without changing anything. Afterwards, each answer that parses
        as the setting's type (int or float) replaces the current value. Any
        other answer, including an empty one, leaves that setting unchanged.
        """
        print()
        try:
            answers = [
                input(f'{name}: {getattr(self.sampler, name)} -> ')
                for name, _ in self._SAMPLER_FIELDS
            ]
        except KeyboardInterrupt:
            print('\nwebscout.Local: sampler settings not updated\n')
            return
        print()

        for (name, parse), answer in zip(self._SAMPLER_FIELDS, answers):
            try:
                new_value = parse(answer)
            except ValueError:
                continue
            setattr(self.sampler, name, new_value)
            print(f'webscout.Local: {name} updated')
        print()

    def _interactive_input(
        self,
        prompt: str,
        _dim_style: str,
        _user_style: str,
        _bot_style: str
    ) -> tuple:
        """
        Receive input from the user, while handling multi-line input
        and commands.

        A line ending in a backslash is continued on the next line. A lone
        `!` opens the command prompt. A lone `<`, or a message ending in `<`,
        asks for text that the bot's next message must start with.

        Returns a 2-tuple `(user_prompt, next_message_start)`:
        - `(text, None)`: send `text` as a user message. `''` means generate
          again without adding a user message.
        - `('', start)`: the bot's next message should begin with `start`
        - `(None, None)`: the user asked to exit

        KeyboardInterrupt or EOFError raised at the main prompt propagate to
        the caller.
        """
        full_user_input = ''  # may become multiline

        while True:
            user_input = input(prompt)

            if user_input.endswith('\\'):
                full_user_input += user_input[:-1] + '\n'

            elif user_input == '!':

                print()
                try:
                    command = input(f'{RESET_ALL} ! {_dim_style}')
                except KeyboardInterrupt:
                    print('\n')
                    continue

                cmd = command.strip().lower()

                if cmd == '':
                    print('\n[no command]\n')

                elif cmd in ['reset', 'restart']:
                    self.reset()
                    print('\n[thread reset]\n')

                elif cmd in ['cls', 'clear']:
                    cls()
                    print()

                elif cmd in ['ctx', 'context']:
                    print(f"\n{self.len_messages()}\n")

                elif cmd in ['stats', 'print_stats']:
                    print()
                    self.print_stats()
                    print()

                elif cmd in ['sampler', 'samplers', 'settings']:
                    self._interactive_update_sampler()

                elif cmd in ['str', 'string', 'as_string']:
                    print(f"\n{self.as_string()}\n")

                elif cmd in ['repr', 'save', 'backup']:
                    print(f"\n{repr(self)}\n")

                elif cmd in ['remove', 'rem', 'delete', 'del']:
                    print()
                    # never delete the system message; it anchors the format
                    if self.messages and self.messages[-1]['role'] != 'system':
                        del self.messages[-1]
                        print('[removed last message]\n')
                    else:
                        print('[no message to remove]\n')

                elif cmd in ['last', 'repeat']:
                    if not self.messages:
                        print('\n[no messages]\n')
                        continue
                    last_msg = self.messages[-1]
                    if last_msg['role'] == 'user':
                        print(f"\n{_user_style}{last_msg['content']}{RESET_ALL}\n")
                    elif last_msg['role'] == 'bot':
                        print(f"\n{_bot_style}{last_msg['content']}{RESET_ALL}\n")

                elif cmd in ['inf', 'inference', 'inf_str']:
                    print(f'\n"""{self.inference_str_from_messages()}"""\n')

                elif cmd in ['reroll', 're-roll', 're', 'swipe']:
                    last_role = self.messages[-1]['role'] if self.messages else None
                    if last_role == 'bot':
                        # discard the bot's answer and generate a new one
                        del self.messages[-1]
                        return '', None
                    if last_role == 'user':
                        # no answer stored yet (e.g. generation was
                        # interrupted): just answer the pending message
                        return '', None
                    print('\n[nothing to re-roll]\n')

                elif cmd in ['exit', 'quit']:
                    print(RESET_ALL)
                    return None, None

                elif cmd in ['help', '/?', '?']:
                    print()
                    print('reset / restart -- Reset the thread to its original state')
                    print('clear / cls -- Clear the terminal')
                    print('context / ctx -- Get the context usage in tokens')
                    print('print_stats / stats -- Get the context usage stats')
                    print('sampler / settings -- Update the sampler settings')
                    print('string / str -- Print the message history as a string')
                    print('repr / save -- Print the representation of the thread')
                    print('remove / delete -- Remove the last message')
                    print('last / repeat -- Repeat the last message')
                    print('inference / inf -- Print the inference string')
                    print('reroll / swipe -- Regenerate the last message')
                    print('exit / quit -- Exit the interactive chat (can also use ^C)')
                    print('help / ? -- Show this screen')
                    print()
                    print("TIP: type < at the prompt and press ENTER to prefix the bot's next message.")
                    print('     for example, type "Sure!" to bypass refusals')
                    print()

                else:
                    print('\n[unknown command]\n')

            elif user_input == '<':  # the next bot message will start with...

                print()
                try:
                    next_message_start = input(f'{_dim_style}  < ')

                except KeyboardInterrupt:
                    print(f'{RESET_ALL}\n')
                    continue

                else:
                    print()
                    return '', next_message_start

            elif user_input.endswith('<'):

                print()

                msg = user_input.removesuffix('<')
                self.add_message("user", msg)

                try:
                    next_message_start = input(f'{_dim_style}  < ')

                except KeyboardInterrupt:
                    print(f'{RESET_ALL}\n')
                    continue

                else:
                    print()
                    return '', next_message_start

            else:
                full_user_input += user_input
                return full_user_input, None

    def interact(
        self,
        color: bool = True,
        header: Optional[str] = None,
        stream: bool = True
    ) -> None:
        """
        Start an interactive chat session using this Thread.

        While text is being generated, press `^C` to interrupt the bot.
        Then you have the option to press `ENTER` to re-roll, or to simply type
        another message.

        At the prompt, press `^C` (or send EOF) to end the chat session.

        Type `!` and press `ENTER` to enter a basic command prompt. For a list
        of commands, type `help` at this prompt.

        Type `<` and press `ENTER` to prefix the bot's next message, for
        example with `Sure!`.

        The following parameters are optional:
        - color: Whether to use colored text to differentiate user / bot
        - header: Header text to print at the start of the interaction
        - stream: Whether to stream text as it is generated
        """
        print()

        # fresh import of color codes in case `color` param has changed
        from .utils import USER_STYLE, BOT_STYLE, DIM_STYLE, SPECIAL_STYLE

        # disable color codes if explicitly disabled by `color` param
        if not color:
            USER_STYLE = ''
            BOT_STYLE = ''
            DIM_STYLE = ''
            SPECIAL_STYLE = ''

        if header is not None:
            print(f"{SPECIAL_STYLE}{header}{RESET_ALL}\n")

        while True:

            prompt = f"{RESET_ALL}  > {USER_STYLE}"

            try:
                user_prompt, next_message_start = self._interactive_input(
                    prompt,
                    DIM_STYLE,
                    USER_STYLE,
                    BOT_STYLE
                )
            except (KeyboardInterrupt, EOFError):
                print(f"{RESET_ALL}\n")
                return

            # got 'exit' or 'quit' command
            if user_prompt is None and next_message_start is None:
                break

            if next_message_start is not None:
                # The bot's reply is forced to begin with this text. When
                # streaming, echo it now; otherwise it is printed as part of
                # `output` after generation.
                forced_start = next_message_start
                if stream:
                    print(f"{BOT_STYLE}{forced_start}", end='', flush=True)
                else:
                    print(f"{BOT_STYLE}", end='', flush=True)
            else:
                forced_start = ''
                print(BOT_STYLE)
                # '' means re-roll: generate again without a new user message
                if user_prompt != "":
                    self.add_message("user", user_prompt)

            try:
                inference_str = self.inference_str_from_messages() + forced_start
                if stream:
                    output = forced_start + self.model.stream_print(
                        inference_str,
                        stops=self.format['stops'],
                        sampler=self.sampler,
                        end=''
                    )
                else:
                    output = forced_start + self.model.generate(
                        inference_str,
                        stops=self.format['stops'],
                        sampler=self.sampler
                    )
                    print(output, end='', flush=True)
            except KeyboardInterrupt:
                print(f"{DIM_STYLE} [message not added to history; press ENTER to re-roll]\n")
                continue

            self.add_message("bot", output)

            # leave exactly one blank line before the next prompt
            if output.endswith("\n\n"):
                print(RESET_ALL, end = '', flush=True)
            elif output.endswith("\n"):
                print(RESET_ALL)
            else:
                print(f"{RESET_ALL}\n")

    def reset(self) -> None:
        """
        Clear the list of messages, which resets the thread to its original
        state (a single system message built from `format['system_content']`)
        """
        self.messages = [
            self.create_message("system", self.format['system_content'])
        ]

    def as_string(self) -> str:
        """Return this thread's message history as a string"""
        return ''.join(
            msg['prefix'] + msg['content'] + msg['postfix']
            for msg in self.messages
        )

    def print_stats(
        self,
        end: str = '\n',
        file: Optional[_SupportsWriteAndFlush] = None,
        flush: bool = True
    ) -> None:
        """
        Print stats about the context usage in this thread

        - end: text printed after the last line (default: newline)
        - file: stream to write to. `None` (the default) means `sys.stdout`,
          looked up at call time so redirections of stdout are respected.
        - flush: whether to flush `file` once all stats have been written
        """
        if file is None:
            file = sys.stdout
        thread_len_tokens = self.len_messages()
        max_ctx_len = self.model.context_length
        context_used_percentage = round((thread_len_tokens / max_ctx_len) * 100)
        print(f"{thread_len_tokens} / {max_ctx_len} tokens", file=file)
        print(f"{context_used_percentage}% of context used", file=file)
        print(f"{len(self.messages)} messages", end=end, file=file)
        if flush:
            file.flush()