webscout 2.6__py3-none-any.whl → 2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- webscout/Local/_version.py +1 -1
- webscout/Local/formats.py +154 -88
- webscout/Local/model.py +4 -4
- webscout/Local/thread.py +166 -156
- webscout/Provider/BasedGPT.py +226 -0
- webscout/Provider/__init__.py +1 -0
- webscout/__init__.py +2 -2
- webscout/cli.py +39 -3
- webscout/version.py +1 -1
- webscout/webscout_search.py +1018 -40
- webscout/webscout_search_async.py +151 -839
- {webscout-2.6.dist-info → webscout-2.7.dist-info}/METADATA +35 -21
- {webscout-2.6.dist-info → webscout-2.7.dist-info}/RECORD +17 -16
- {webscout-2.6.dist-info → webscout-2.7.dist-info}/LICENSE.md +0 -0
- {webscout-2.6.dist-info → webscout-2.7.dist-info}/WHEEL +0 -0
- {webscout-2.6.dist-info → webscout-2.7.dist-info}/entry_points.txt +0 -0
- {webscout-2.6.dist-info → webscout-2.7.dist-info}/top_level.txt +0 -0
webscout/Local/thread.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
# thread.py
|
|
2
|
-
# https://github.com/ddh0/easy-llama/
|
|
3
1
|
from ._version import __version__, __llama_cpp_version__
|
|
4
2
|
|
|
5
3
|
"""Submodule containing the Thread class, used for interaction with a Model"""
|
|
@@ -10,15 +8,51 @@ from .model import Model, assert_model_is_loaded, _SupportsWriteAndFlush
|
|
|
10
8
|
from .utils import RESET_ALL, cls, print_verbose, truncate
|
|
11
9
|
from .samplers import SamplerSettings, DefaultSampling
|
|
12
10
|
from typing import Optional, Literal, Union
|
|
11
|
+
from .formats import AdvancedFormat
|
|
13
12
|
|
|
14
13
|
from .formats import blank as formats_blank
|
|
15
14
|
|
|
16
15
|
|
|
16
|
+
class Message(dict):
|
|
17
|
+
"""
|
|
18
|
+
A dictionary representing a single message within a Thread
|
|
19
|
+
|
|
20
|
+
Works just like a normal `dict`, but with a new method:
|
|
21
|
+
- `.as_string` - Return the full message string
|
|
22
|
+
|
|
23
|
+
Generally, messages have these keys:
|
|
24
|
+
- `role` - The role of the speaker: 'system', 'user', or 'bot'
|
|
25
|
+
- `prefix` - The text that prefixes the message content
|
|
26
|
+
- `content` - The actual content of the message
|
|
27
|
+
- `suffix` - The text that suffixes the message content
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
def __repr__(self) -> str:
|
|
31
|
+
return \
|
|
32
|
+
f"Message([" \
|
|
33
|
+
f"('role', {repr(self['role'])}), " \
|
|
34
|
+
f"('prefix', {repr(self['prefix'])}), " \
|
|
35
|
+
f"('content', {repr(self['content'])}), " \
|
|
36
|
+
f"('suffix', {repr(self['suffix'])})])"
|
|
37
|
+
|
|
38
|
+
def as_string(self):
|
|
39
|
+
"""Return the full message string"""
|
|
40
|
+
try:
|
|
41
|
+
return self['prefix'] + self['content'] + self['suffix']
|
|
42
|
+
except KeyError as e:
|
|
43
|
+
e.add_note(
|
|
44
|
+
"as_string: Message is missing one or more of the "
|
|
45
|
+
"required 'prefix', 'content', 'suffix' attributes - this is "
|
|
46
|
+
"unexpected"
|
|
47
|
+
)
|
|
48
|
+
raise e
|
|
49
|
+
|
|
50
|
+
|
|
17
51
|
class Thread:
|
|
18
52
|
"""
|
|
19
53
|
Provide functionality to facilitate easy interactions with a Model
|
|
20
54
|
|
|
21
|
-
This is just a brief overview of
|
|
55
|
+
This is just a brief overview of m.Thread.
|
|
22
56
|
To see a full description of each method and its parameters,
|
|
23
57
|
call help(Thread), or see the relevant docstring.
|
|
24
58
|
|
|
@@ -36,24 +70,26 @@ class Thread:
|
|
|
36
70
|
The following attributes are available:
|
|
37
71
|
- `.format` - The format being used for messages in this thread
|
|
38
72
|
- `.messages` - The list of messages in this thread
|
|
39
|
-
- `.model` - The `
|
|
73
|
+
- `.model` - The `m.Model` instance used by this thread
|
|
40
74
|
- `.sampler` - The SamplerSettings object used in this thread
|
|
41
75
|
"""
|
|
42
76
|
|
|
43
77
|
def __init__(
|
|
44
78
|
self,
|
|
45
79
|
model: Model,
|
|
46
|
-
format:
|
|
47
|
-
sampler: SamplerSettings = DefaultSampling
|
|
80
|
+
format: Union[dict, AdvancedFormat],
|
|
81
|
+
sampler: SamplerSettings = DefaultSampling,
|
|
82
|
+
messages: Optional[list[Message]] = None,
|
|
48
83
|
):
|
|
49
84
|
"""
|
|
50
85
|
Given a Model and a format, construct a Thread instance.
|
|
51
86
|
|
|
52
87
|
model: The Model to use for text generation
|
|
53
|
-
format: The format specifying how messages should be structured (see
|
|
88
|
+
format: The format specifying how messages should be structured (see m.formats)
|
|
54
89
|
|
|
55
|
-
The following
|
|
90
|
+
The following parameters are optional:
|
|
56
91
|
- sampler: The SamplerSettings object used to control text generation
|
|
92
|
+
- messages: A list of m.thread.Message objects to add to the Thread upon construction
|
|
57
93
|
"""
|
|
58
94
|
|
|
59
95
|
assert isinstance(model, Model), \
|
|
@@ -62,8 +98,8 @@ class Thread:
|
|
|
62
98
|
|
|
63
99
|
assert_model_is_loaded(model)
|
|
64
100
|
|
|
65
|
-
assert isinstance(format, dict), \
|
|
66
|
-
f"Thread: format should be dict, not {type(format)}"
|
|
101
|
+
assert isinstance(format, (dict, AdvancedFormat)), \
|
|
102
|
+
f"Thread: format should be dict or AdvancedFormat, not {type(format)}"
|
|
67
103
|
|
|
68
104
|
if any(k not in format.keys() for k in formats_blank.keys()):
|
|
69
105
|
raise KeyError(
|
|
@@ -87,12 +123,23 @@ class Thread:
|
|
|
87
123
|
'top_k'
|
|
88
124
|
]
|
|
89
125
|
), 'Thread: sampler is missing one or more required attributes'
|
|
126
|
+
|
|
127
|
+
self._messages: Optional[list[Message]] = messages
|
|
128
|
+
if self._messages is not None:
|
|
129
|
+
if not all(isinstance(msg, Message) for msg in self._messages):
|
|
130
|
+
raise TypeError(
|
|
131
|
+
"Thread: one or more messages provided to __init__() is "
|
|
132
|
+
"not an instance of m.thread.Message"
|
|
133
|
+
)
|
|
90
134
|
|
|
135
|
+
# Thread.messages is never empty, unless `messages` param is explicitly
|
|
136
|
+
# set to `[]` during construction
|
|
137
|
+
|
|
91
138
|
self.model: Model = model
|
|
92
|
-
self.format:
|
|
93
|
-
self.messages: list[
|
|
139
|
+
self.format: Union[dict, AdvancedFormat] = format
|
|
140
|
+
self.messages: list[Message] = [
|
|
94
141
|
self.create_message("system", self.format['system_content'])
|
|
95
|
-
]
|
|
142
|
+
] if self._messages is None else self._messages
|
|
96
143
|
self.sampler: SamplerSettings = sampler
|
|
97
144
|
|
|
98
145
|
if self.model.verbose:
|
|
@@ -100,13 +147,13 @@ class Thread:
|
|
|
100
147
|
print_verbose(f"model == {self.model}")
|
|
101
148
|
print_verbose(f"format['system_prefix'] == {truncate(repr(self.format['system_prefix']))}")
|
|
102
149
|
print_verbose(f"format['system_content'] == {truncate(repr(self.format['system_content']))}")
|
|
103
|
-
print_verbose(f"format['
|
|
150
|
+
print_verbose(f"format['system_suffix'] == {truncate(repr(self.format['system_suffix']))}")
|
|
104
151
|
print_verbose(f"format['user_prefix'] == {truncate(repr(self.format['user_prefix']))}")
|
|
105
152
|
print_verbose(f"format['user_content'] == {truncate(repr(self.format['user_content']))}")
|
|
106
|
-
print_verbose(f"format['
|
|
153
|
+
print_verbose(f"format['user_suffix'] == {truncate(repr(self.format['user_suffix']))}")
|
|
107
154
|
print_verbose(f"format['bot_prefix'] == {truncate(repr(self.format['bot_prefix']))}")
|
|
108
155
|
print_verbose(f"format['bot_content'] == {truncate(repr(self.format['bot_content']))}")
|
|
109
|
-
print_verbose(f"format['
|
|
156
|
+
print_verbose(f"format['bot_suffix'] == {truncate(repr(self.format['bot_suffix']))}")
|
|
110
157
|
print_verbose(f"format['stops'] == {truncate(repr(self.format['stops']))}")
|
|
111
158
|
print_verbose(f"sampler.temp == {self.sampler.temp}")
|
|
112
159
|
print_verbose(f"sampler.top_p == {self.sampler.top_p}")
|
|
@@ -118,29 +165,10 @@ class Thread:
|
|
|
118
165
|
|
|
119
166
|
|
|
120
167
|
def __repr__(self) -> str:
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
return repr_str
|
|
126
|
-
else:
|
|
127
|
-
for msg in self.messages:
|
|
128
|
-
if msg['role'] == 'user':
|
|
129
|
-
repr_str += "\nThread.add_message('user', " + repr(msg['content']) + ')'
|
|
130
|
-
elif msg['role'] == 'bot':
|
|
131
|
-
repr_str += "\nThread.add_message('bot', " + repr(msg['content']) + ')'
|
|
132
|
-
return repr_str
|
|
133
|
-
def save_conversation(self, filepath: str) -> None:
|
|
134
|
-
"""
|
|
135
|
-
Saves the conversation history to a JSON file.
|
|
136
|
-
|
|
137
|
-
filepath: The path to the file where the conversation should be saved.
|
|
138
|
-
"""
|
|
139
|
-
import json
|
|
140
|
-
|
|
141
|
-
data = [{'role': msg['role'], 'content': msg['content']} for msg in self.messages]
|
|
142
|
-
with open(filepath, 'w') as f:
|
|
143
|
-
json.dump(data, f, indent=4)
|
|
168
|
+
return \
|
|
169
|
+
f"Thread({repr(self.model)}, {repr(self.format)}, " + \
|
|
170
|
+
f"{repr(self.sampler)}, {repr(self.messages)})"
|
|
171
|
+
|
|
144
172
|
def __str__(self) -> str:
|
|
145
173
|
return self.as_string()
|
|
146
174
|
|
|
@@ -156,9 +184,9 @@ class Thread:
|
|
|
156
184
|
self,
|
|
157
185
|
role: Literal['system', 'user', 'bot'],
|
|
158
186
|
content: str
|
|
159
|
-
) ->
|
|
187
|
+
) -> Message:
|
|
160
188
|
"""
|
|
161
|
-
|
|
189
|
+
Construct a message using the format of this Thread
|
|
162
190
|
"""
|
|
163
191
|
|
|
164
192
|
assert role.lower() in ['system', 'user', 'bot'], \
|
|
@@ -168,34 +196,40 @@ class Thread:
|
|
|
168
196
|
f"create_message: content should be str, not {type(content)}"
|
|
169
197
|
|
|
170
198
|
if role.lower() == 'system':
|
|
171
|
-
return
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
199
|
+
return Message(
|
|
200
|
+
[
|
|
201
|
+
('role', 'system'),
|
|
202
|
+
('prefix', self.format['system_prefix']),
|
|
203
|
+
('content', content),
|
|
204
|
+
('suffix', self.format['system_suffix'])
|
|
205
|
+
]
|
|
206
|
+
)
|
|
177
207
|
|
|
178
208
|
elif role.lower() == 'user':
|
|
179
|
-
return
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
209
|
+
return Message(
|
|
210
|
+
[
|
|
211
|
+
('role', 'user'),
|
|
212
|
+
('prefix', self.format['user_prefix']),
|
|
213
|
+
('content', content),
|
|
214
|
+
('suffix', self.format['user_suffix'])
|
|
215
|
+
]
|
|
216
|
+
)
|
|
185
217
|
|
|
186
218
|
elif role.lower() == 'bot':
|
|
187
|
-
return
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
219
|
+
return Message(
|
|
220
|
+
[
|
|
221
|
+
('role', 'bot'),
|
|
222
|
+
('prefix', self.format['bot_prefix']),
|
|
223
|
+
('content', content),
|
|
224
|
+
('suffix', self.format['bot_suffix'])
|
|
225
|
+
]
|
|
226
|
+
)
|
|
193
227
|
|
|
194
228
|
def len_messages(self) -> int:
|
|
195
229
|
"""
|
|
196
230
|
Return the total length of all messages in this thread, in tokens.
|
|
197
231
|
|
|
198
|
-
|
|
232
|
+
Can also use `len(Thread)`."""
|
|
199
233
|
|
|
200
234
|
return self.model.get_length(self.as_string())
|
|
201
235
|
|
|
@@ -223,38 +257,35 @@ class Thread:
|
|
|
223
257
|
respecting the format and context length of this thread.
|
|
224
258
|
"""
|
|
225
259
|
|
|
226
|
-
messages = self.messages
|
|
227
|
-
|
|
228
|
-
context_len_budget = self.model.context_length
|
|
229
|
-
if len(messages) > 0:
|
|
230
|
-
sys_msg = messages[0]
|
|
231
|
-
sys_msg_str = (
|
|
232
|
-
sys_msg['prefix'] + sys_msg['content'] + sys_msg['postfix']
|
|
233
|
-
)
|
|
234
|
-
context_len_budget -= self.model.get_length(sys_msg_str)
|
|
235
|
-
else:
|
|
236
|
-
sys_msg_str = ''
|
|
237
|
-
|
|
238
260
|
inf_str = ''
|
|
261
|
+
sys_msg_str = ''
|
|
262
|
+
# whether to treat the first message as necessary to keep
|
|
263
|
+
sys_msg_flag = False
|
|
264
|
+
context_len_budget = self.model.context_length
|
|
239
265
|
|
|
240
|
-
#
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
266
|
+
# if at least 1 message is in history
|
|
267
|
+
if len(self.messages) >= 1:
|
|
268
|
+
# if first message has system role
|
|
269
|
+
if self.messages[0]['role'] == 'system':
|
|
270
|
+
sys_msg_flag = True
|
|
271
|
+
sys_msg = self.messages[0]
|
|
272
|
+
sys_msg_str = sys_msg.as_string()
|
|
273
|
+
context_len_budget -= self.model.get_length(sys_msg_str)
|
|
274
|
+
|
|
275
|
+
if sys_msg_flag:
|
|
276
|
+
iterator = reversed(self.messages[1:])
|
|
277
|
+
else:
|
|
278
|
+
iterator = reversed(self.messages)
|
|
279
|
+
|
|
280
|
+
for message in iterator:
|
|
281
|
+
msg_str = message.as_string()
|
|
282
|
+
context_len_budget -= self.model.get_length(msg_str)
|
|
248
283
|
if context_len_budget <= 0:
|
|
249
284
|
break
|
|
250
|
-
|
|
251
|
-
msg_str = (
|
|
252
|
-
message['prefix'] + message['content'] + message['postfix']
|
|
253
|
-
)
|
|
254
|
-
|
|
255
285
|
inf_str = msg_str + inf_str
|
|
256
286
|
|
|
257
|
-
|
|
287
|
+
if sys_msg_flag:
|
|
288
|
+
inf_str = sys_msg_str + inf_str
|
|
258
289
|
inf_str += self.format['bot_prefix']
|
|
259
290
|
|
|
260
291
|
return inf_str
|
|
@@ -277,20 +308,7 @@ class Thread:
|
|
|
277
308
|
self.add_message("bot", output)
|
|
278
309
|
|
|
279
310
|
return output
|
|
280
|
-
|
|
281
|
-
"""
|
|
282
|
-
Loads a conversation history from a JSON file.
|
|
283
|
-
|
|
284
|
-
filepath: The path to the file containing the conversation history.
|
|
285
|
-
"""
|
|
286
|
-
import json
|
|
287
|
-
|
|
288
|
-
with open(filepath, 'r') as f:
|
|
289
|
-
data = json.load(f)
|
|
290
|
-
|
|
291
|
-
self.messages = []
|
|
292
|
-
for item in data:
|
|
293
|
-
self.messages.append(self.create_message(item['role'], item['content']))
|
|
311
|
+
|
|
294
312
|
|
|
295
313
|
def _interactive_update_sampler(self) -> None:
|
|
296
314
|
"""Interactively update the sampler settings used in this Thread"""
|
|
@@ -373,7 +391,8 @@ class Thread:
|
|
|
373
391
|
prompt: str,
|
|
374
392
|
_dim_style: str,
|
|
375
393
|
_user_style: str,
|
|
376
|
-
_bot_style: str
|
|
394
|
+
_bot_style: str,
|
|
395
|
+
_special_style: str
|
|
377
396
|
) -> tuple:
|
|
378
397
|
"""
|
|
379
398
|
Receive input from the user, while handling multi-line input
|
|
@@ -420,15 +439,7 @@ class Thread:
|
|
|
420
439
|
|
|
421
440
|
elif command.lower() in ['str', 'string', 'as_string']:
|
|
422
441
|
print(f"\n{self.as_string()}\n")
|
|
423
|
-
|
|
424
|
-
elif command.lower() in ['save']:
|
|
425
|
-
print()
|
|
426
|
-
try:
|
|
427
|
-
filepath = input(f'Enter filepath to save conversation: ')
|
|
428
|
-
self.save_conversation(filepath)
|
|
429
|
-
print(f'[conversation saved to {filepath}]\n')
|
|
430
|
-
except Exception as e:
|
|
431
|
-
print(f'[error saving conversation: {e}]\n')
|
|
442
|
+
|
|
432
443
|
elif command.lower() in ['repr', 'save', 'backup']:
|
|
433
444
|
print(f"\n{repr(self)}\n")
|
|
434
445
|
|
|
@@ -438,14 +449,7 @@ class Thread:
|
|
|
438
449
|
del self.messages[-1]
|
|
439
450
|
assert len(self.messages) == (old_len - 1)
|
|
440
451
|
print('[removed last message]\n')
|
|
441
|
-
|
|
442
|
-
print()
|
|
443
|
-
try:
|
|
444
|
-
filepath = input(f'Enter filepath to load conversation: ')
|
|
445
|
-
self.load_conversation(filepath)
|
|
446
|
-
print(f'[conversation loaded from {filepath}]\n')
|
|
447
|
-
except Exception as e:
|
|
448
|
-
print(f'[error loading conversation: {e}]\n')
|
|
452
|
+
|
|
449
453
|
elif command.lower() in ['last', 'repeat']:
|
|
450
454
|
last_msg = self.messages[-1]
|
|
451
455
|
if last_msg['role'] == 'user':
|
|
@@ -468,33 +472,35 @@ class Thread:
|
|
|
468
472
|
|
|
469
473
|
elif command.lower() in ['help', '/?', '?']:
|
|
470
474
|
print()
|
|
471
|
-
print('reset
|
|
472
|
-
print('clear
|
|
473
|
-
print('context
|
|
474
|
-
print('print_stats
|
|
475
|
-
print('sampler
|
|
476
|
-
print('string
|
|
477
|
-
print('repr
|
|
478
|
-
print('remove
|
|
479
|
-
print('last
|
|
480
|
-
print('inference
|
|
481
|
-
print('reroll
|
|
482
|
-
print('exit
|
|
483
|
-
print('help
|
|
475
|
+
print('reset | restart -- Reset the thread to its original state')
|
|
476
|
+
print('clear | cls -- Clear the terminal')
|
|
477
|
+
print('context | ctx -- Get the context usage in tokens')
|
|
478
|
+
print('print_stats | stats -- Get the context usage stats')
|
|
479
|
+
print('sampler | settings -- Update the sampler settings')
|
|
480
|
+
print('string | str -- Print the message history as a string')
|
|
481
|
+
print('repr | save -- Print the representation of the thread')
|
|
482
|
+
print('remove | delete -- Remove the last message')
|
|
483
|
+
print('last | repeat -- Repeat the last message')
|
|
484
|
+
print('inference | inf -- Print the inference string')
|
|
485
|
+
print('reroll | swipe -- Regenerate the last message')
|
|
486
|
+
print('exit | quit -- Exit the interactive chat (can also use ^C)')
|
|
487
|
+
print('help | ? -- Show this screen')
|
|
484
488
|
print()
|
|
485
489
|
print("TIP: type < at the prompt and press ENTER to prefix the bot's next message.")
|
|
486
490
|
print(' for example, type "Sure!" to bypass refusals')
|
|
487
491
|
print()
|
|
488
|
-
print(
|
|
489
|
-
print(
|
|
492
|
+
print("TIP: type !! at the prompt and press ENTER to insert a system message")
|
|
493
|
+
print()
|
|
494
|
+
|
|
490
495
|
else:
|
|
491
496
|
print(f'\n[unknown command]\n')
|
|
492
497
|
|
|
493
|
-
|
|
498
|
+
# prefix the bot's next message
|
|
499
|
+
elif user_input == '<':
|
|
494
500
|
|
|
495
501
|
print()
|
|
496
502
|
try:
|
|
497
|
-
next_message_start = input(f'{
|
|
503
|
+
next_message_start = input(f'{RESET_ALL} < {_dim_style}')
|
|
498
504
|
|
|
499
505
|
except KeyboardInterrupt:
|
|
500
506
|
print(f'{RESET_ALL}\n')
|
|
@@ -503,25 +509,23 @@ class Thread:
|
|
|
503
509
|
else:
|
|
504
510
|
print()
|
|
505
511
|
return '', next_message_start
|
|
506
|
-
|
|
507
|
-
elif user_input.endswith('<'):
|
|
508
512
|
|
|
513
|
+
# insert a system message
|
|
514
|
+
elif user_input == '!!':
|
|
509
515
|
print()
|
|
510
516
|
|
|
511
|
-
msg = user_input.removesuffix('<')
|
|
512
|
-
self.add_message("user", msg)
|
|
513
|
-
|
|
514
517
|
try:
|
|
515
|
-
|
|
516
|
-
|
|
518
|
+
next_sys_msg = input(f'{RESET_ALL} !! {_special_style}')
|
|
519
|
+
|
|
517
520
|
except KeyboardInterrupt:
|
|
518
521
|
print(f'{RESET_ALL}\n')
|
|
519
522
|
continue
|
|
520
|
-
|
|
523
|
+
|
|
521
524
|
else:
|
|
522
525
|
print()
|
|
523
|
-
return
|
|
526
|
+
return next_sys_msg, -1
|
|
524
527
|
|
|
528
|
+
# concatenate multi-line input
|
|
525
529
|
else:
|
|
526
530
|
full_user_input += user_input
|
|
527
531
|
return full_user_input, None
|
|
@@ -548,6 +552,8 @@ class Thread:
|
|
|
548
552
|
Type `<` and press `ENTER` to prefix the bot's next message, for
|
|
549
553
|
example with `Sure!`.
|
|
550
554
|
|
|
555
|
+
Type `!!` at the prompt and press `ENTER` to insert a system message.
|
|
556
|
+
|
|
551
557
|
The following parameters are optional:
|
|
552
558
|
- color: Whether to use colored text to differentiate user / bot
|
|
553
559
|
- header: Header text to print at the start of the interaction
|
|
@@ -556,28 +562,29 @@ class Thread:
|
|
|
556
562
|
print()
|
|
557
563
|
|
|
558
564
|
# fresh import of color codes in case `color` param has changed
|
|
559
|
-
from .utils import USER_STYLE, BOT_STYLE, DIM_STYLE
|
|
565
|
+
from .utils import SPECIAL_STYLE, USER_STYLE, BOT_STYLE, DIM_STYLE
|
|
560
566
|
|
|
561
567
|
# disable color codes if explicitly disabled by `color` param
|
|
562
568
|
if not color:
|
|
569
|
+
SPECIAL_STYLE = ''
|
|
563
570
|
USER_STYLE = ''
|
|
564
571
|
BOT_STYLE = ''
|
|
565
572
|
DIM_STYLE = ''
|
|
566
|
-
SPECIAL_STYLE = ''
|
|
567
573
|
|
|
568
574
|
if header is not None:
|
|
569
575
|
print(f"{SPECIAL_STYLE}{header}{RESET_ALL}\n")
|
|
570
576
|
|
|
571
577
|
while True:
|
|
572
578
|
|
|
573
|
-
prompt = f"{RESET_ALL}
|
|
579
|
+
prompt = f"{RESET_ALL} > {USER_STYLE}"
|
|
574
580
|
|
|
575
581
|
try:
|
|
576
582
|
user_prompt, next_message_start = self._interactive_input(
|
|
577
583
|
prompt,
|
|
578
584
|
DIM_STYLE,
|
|
579
585
|
USER_STYLE,
|
|
580
|
-
BOT_STYLE
|
|
586
|
+
BOT_STYLE,
|
|
587
|
+
SPECIAL_STYLE
|
|
581
588
|
)
|
|
582
589
|
except KeyboardInterrupt:
|
|
583
590
|
print(f"{RESET_ALL}\n")
|
|
@@ -587,6 +594,11 @@ class Thread:
|
|
|
587
594
|
if user_prompt is None and next_message_start is None:
|
|
588
595
|
break
|
|
589
596
|
|
|
597
|
+
# insert a system message via `!!` prompt
|
|
598
|
+
if next_message_start == -1:
|
|
599
|
+
self.add_message('system', user_prompt)
|
|
600
|
+
continue
|
|
601
|
+
|
|
590
602
|
if next_message_start is not None:
|
|
591
603
|
try:
|
|
592
604
|
if stream:
|
|
@@ -648,19 +660,17 @@ class Thread:
|
|
|
648
660
|
Clear the list of messages, which resets the thread to its original
|
|
649
661
|
state
|
|
650
662
|
"""
|
|
651
|
-
self.messages: list[
|
|
663
|
+
self.messages: list[Message] = [
|
|
652
664
|
self.create_message("system", self.format['system_content'])
|
|
653
|
-
]
|
|
665
|
+
] if self._messages is None else self._messages
|
|
654
666
|
|
|
655
667
|
|
|
656
668
|
def as_string(self) -> str:
|
|
657
669
|
"""Return this thread's message history as a string"""
|
|
658
|
-
|
|
670
|
+
thread_string = ''
|
|
659
671
|
for msg in self.messages:
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
ret += msg['postfix']
|
|
663
|
-
return ret
|
|
672
|
+
thread_string += msg.as_string()
|
|
673
|
+
return thread_string
|
|
664
674
|
|
|
665
675
|
|
|
666
676
|
def print_stats(
|
|
@@ -677,4 +687,4 @@ class Thread:
|
|
|
677
687
|
print(f"{context_used_percentage}% of context used", file=file, flush=flush)
|
|
678
688
|
print(f"{len(self.messages)} messages", end=end, file=file, flush=flush)
|
|
679
689
|
if not flush:
|
|
680
|
-
file.flush()
|
|
690
|
+
file.flush()
|