webscout 7.4-py3-none-any.whl → 7.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
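
A comparison like this one can be reproduced locally. The sketch below is illustrative only and is not part of either release: it assumes both wheels have already been downloaded (for example with `pip download webscout==7.4 --no-deps`), and uses only the standard-library zipfile module to list which files were added or removed between the two archives.

import zipfile

def wheel_contents(path: str) -> set[str]:
    # A .whl file is a zip archive; namelist() returns every file inside it.
    with zipfile.ZipFile(path) as wheel:
        return set(wheel.namelist())

old = wheel_contents("webscout-7.4-py3-none-any.whl")
new = wheel_contents("webscout-7.5-py3-none-any.whl")

print("added:  ", sorted(new - old))   # e.g. webscout/Provider/C4ai.py
print("removed:", sorted(old - new))   # e.g. webscout/Local/thread.py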

Potentially problematic release: this version of webscout might be problematic.

Files changed (42)
  1. webscout/Provider/C4ai.py +414 -0
  2. webscout/Provider/Cloudflare.py +18 -21
  3. webscout/Provider/DeepSeek.py +3 -32
  4. webscout/Provider/Deepinfra.py +30 -21
  5. webscout/Provider/GithubChat.py +362 -0
  6. webscout/Provider/HeckAI.py +20 -3
  7. webscout/Provider/HuggingFaceChat.py +462 -0
  8. webscout/Provider/Marcus.py +7 -50
  9. webscout/Provider/Netwrck.py +6 -53
  10. webscout/Provider/Phind.py +29 -3
  11. webscout/Provider/TTI/aiarta/__init__.py +2 -0
  12. webscout/Provider/TTI/aiarta/async_aiarta.py +482 -0
  13. webscout/Provider/TTI/aiarta/sync_aiarta.py +409 -0
  14. webscout/Provider/Venice.py +200 -200
  15. webscout/Provider/Youchat.py +1 -1
  16. webscout/Provider/__init__.py +13 -2
  17. webscout/Provider/akashgpt.py +8 -5
  18. webscout/Provider/copilot.py +416 -0
  19. webscout/Provider/flowith.py +181 -0
  20. webscout/Provider/granite.py +17 -53
  21. webscout/Provider/llamatutor.py +6 -46
  22. webscout/Provider/llmchat.py +7 -46
  23. webscout/Provider/multichat.py +29 -91
  24. webscout/exceptions.py +19 -9
  25. webscout/update_checker.py +55 -93
  26. webscout/version.py +1 -1
  27. webscout-7.5.dist-info/LICENSE.md +146 -0
  28. {webscout-7.4.dist-info → webscout-7.5.dist-info}/METADATA +5 -126
  29. {webscout-7.4.dist-info → webscout-7.5.dist-info}/RECORD +32 -33
  30. webscout/Local/__init__.py +0 -10
  31. webscout/Local/_version.py +0 -3
  32. webscout/Local/formats.py +0 -747
  33. webscout/Local/model.py +0 -1368
  34. webscout/Local/samplers.py +0 -125
  35. webscout/Local/thread.py +0 -539
  36. webscout/Local/ui.py +0 -401
  37. webscout/Local/utils.py +0 -388
  38. webscout/Provider/dgaf.py +0 -214
  39. webscout-7.4.dist-info/LICENSE.md +0 -211
  40. {webscout-7.4.dist-info → webscout-7.5.dist-info}/WHEEL +0 -0
  41. {webscout-7.4.dist-info → webscout-7.5.dist-info}/entry_points.txt +0 -0
  42. {webscout-7.4.dist-info → webscout-7.5.dist-info}/top_level.txt +0 -0
webscout/Local/samplers.py DELETED
@@ -1,125 +0,0 @@
- from typing import Optional
- from sys import maxsize
-
- from .utils import assert_type, NoneType
-
- MAX_TEMP = float(maxsize)
-
- class SamplerSettings:
-     """
-     Specifies sampling parameters for controlling text generation.
-
-     This class allows you to fine-tune the behavior of text generation
-     models by adjusting various sampling parameters. These settings are
-     passed as an optional parameter to functions like `Thread.__init__()`,
-     `Model.generate()`, `Model.stream()`, and `Model.stream_print()`.
-
-     If a parameter is unspecified, the default value from llama.cpp is used.
-     If all parameters are unspecified, the behavior is equivalent to
-     `DefaultSampling`.
-
-     Setting a parameter explicitly to `None` disables it. When all samplers
-     are disabled, it's equivalent to `NoSampling` (unmodified probability
-     distribution).
-
-     Attributes:
-         max_len_tokens (Optional[int]): Maximum number of tokens to generate.
-             Defaults to -1 (no limit).
-         top_k (Optional[int]): Number of highest probability tokens to consider.
-             Defaults to 40. Set to `None` to disable.
-         top_p (Optional[float]): Nucleus sampling threshold (0.0 - 1.0).
-             Defaults to 0.95. Set to `None` to disable.
-         min_p (Optional[float]): Minimum probability threshold (0.0 - 1.0).
-             Defaults to 0.05. Set to `None` to disable.
-         temp (Optional[float]): Temperature for sampling (0.0 - inf).
-             Defaults to 0.8. Set to `None` to disable.
-         frequency_penalty (Optional[float]): Penalty for repeating tokens.
-             Defaults to 0.0.
-         presence_penalty (Optional[float]): Penalty for generating new tokens.
-             Defaults to 0.0.
-         repeat_penalty (Optional[float]): Penalty for repeating token sequences.
-             Defaults to 1.0.
-
-     Presets:
-         - `GreedyDecoding`: Always chooses the most likely token.
-         - `DefaultSampling`: Uses default parameters from llama.cpp.
-         - `NoSampling`: Unmodified probability distribution (all parameters disabled).
-         - `ClassicSampling`: Reflects old llama.cpp defaults.
-         - `SemiSampling`: Halfway between DefaultSampling and SimpleSampling.
-         - `TikTokenSampling`: For models with large vocabularies.
-         - `LowMinPSampling`, `MinPSampling`, `StrictMinPSampling`: Use `min_p` as the only active sampler.
-         - `ContrastiveSearch`, `WarmContrastiveSearch`: Implement contrastive search.
-         - `RandomSampling`: Outputs completely random tokens (useless).
-         - `LowTempSampling`: Default sampling with reduced temperature.
-         - `HighTempSampling`: Default sampling with increased temperature.
-         - `LowTopPSampling`, `TopPSampling`, `StrictTopPSampling`: Use `top_p` as the primary sampler.
-         - `MidnightMiqu`: For sophosympatheia/Midnight-Miqu-70B-v1.5 model.
-         - `Llama3`: For meta-llama/Meta-Llama-3.1-8B-Instruct model.
-         - `Nemo`: For mistralai/Mistral-Nemo-Instruct-2407 model.
-     """
-
-     param_types: dict[str, tuple[type]] = {
-         'max_len_tokens': (int, NoneType),
-         'top_k': (int, NoneType),
-         'top_p': (float, NoneType),
-         'min_p': (float, NoneType),
-         'temp': (float, NoneType),
-         'frequency_penalty': (float, NoneType),
-         'presence_penalty': (float, NoneType),
-         'repeat_penalty': (float, NoneType)
-     }
-
-     def __init__(
-         self,
-         max_len_tokens: Optional[int] = -1,
-         top_k: Optional[int] = 40,
-         top_p: Optional[float] = 0.95,
-         min_p: Optional[float] = 0.05,
-         temp: Optional[float] = 0.8,
-         frequency_penalty: Optional[float] = 0.0,
-         presence_penalty: Optional[float] = 0.0,
-         repeat_penalty: Optional[float] = 1.0
-     ):
-         self.max_len_tokens = max_len_tokens if max_len_tokens is not None else -1
-         self.top_k = top_k if top_k is not None else -1
-         self.top_p = top_p if top_p is not None else 1.0
-         self.min_p = min_p if min_p is not None else 0.0
-         self.temp = temp if temp is not None else 1.0
-         self.frequency_penalty = frequency_penalty if frequency_penalty is not None else 0.0
-         self.presence_penalty = presence_penalty if presence_penalty is not None else 0.0
-         self.repeat_penalty = repeat_penalty if repeat_penalty is not None else 1.0
-
-         # Validate parameters using param_types dictionary
-         for param_name, param_value in self.__dict__.items():
-             assert_type(param_value, self.param_types[param_name],
-                         f'{param_name} parameter', 'SamplerSettings')
-
-     def __repr__(self) -> str:
-         params = ', '.join(
-             f'{name}={value}' for name, value in self.__dict__.items()
-         )
-         return f'SamplerSettings({params})'
-
- # Predefined sampler settings
- GreedyDecoding = SamplerSettings(temp=0.0)
- DefaultSampling = SamplerSettings()
- NoSampling = SimpleSampling = SamplerSettings(top_k=None, top_p=None, min_p=None, temp=None)
- ClassicSampling = SamplerSettings(min_p=None, repeat_penalty=1.1)
- SemiSampling = SamplerSettings(top_k=80, top_p=0.975, min_p=0.025, temp=0.9)
- TikTokenSampling = SamplerSettings(temp=0.65)
- LowMinPSampling = SamplerSettings(top_k=None, top_p=None, min_p=0.01, temp=None)
- MinPSampling = SamplerSettings(top_k=None, top_p=None, min_p=0.075, temp=None)
- StrictMinPSampling = SamplerSettings(top_k=None, top_p=None, min_p=0.2, temp=None)
- ContrastiveSearch = SamplerSettings(top_k=None, top_p=None, min_p=None, temp=0.0, presence_penalty=0.6)
- WarmContrastiveSearch = SamplerSettings(top_k=None, top_p=None, min_p=None, temp=0.0, presence_penalty=1.0)
- RandomSampling = SamplerSettings(top_k=None, top_p=None, min_p=None, temp=MAX_TEMP)
- LowTempSampling = SamplerSettings(temp=0.4)
- HighTempSampling = SamplerSettings(temp=1.1)
- LowTopPSampling = SamplerSettings(top_k=None, top_p=0.98, min_p=None, temp=None)
- TopPSampling = SamplerSettings(top_k=None, top_p=0.9, min_p=None, temp=None)
- StrictTopPSampling = SamplerSettings(top_k=None, top_p=0.7, min_p=None, temp=None)
-
- # Model-specific samplers
- MidnightMiqu = SamplerSettings(top_k=None, top_p=None, min_p=0.12, temp=1.0, repeat_penalty=1.05)  # sophosympatheia/Midnight-Miqu-70B-v1.5
- Llama3 = SamplerSettings(top_k=None, top_p=0.9, min_p=None, temp=0.6)  # meta-llama/Meta-Llama-3.1-8B-Instruct
- Nemo = MistralNemo = MistralSmall = SamplerSettings(top_k=None, top_p=0.85, min_p=None, temp=0.7)  # mistralai/Mistral-Nemo-Instruct-2407
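
For context on what this deletion removes, here is a minimal usage sketch of the deleted SamplerSettings API, inferred only from the docstring above. It runs against webscout 7.4, the last version that shipped webscout.Local; nothing here exists in 7.5.

from webscout.Local.samplers import SamplerSettings, TopPSampling

# Parameters set to None are disabled; unspecified ones keep llama.cpp defaults.
custom = SamplerSettings(top_k=None, top_p=0.9, min_p=None, temp=0.6)

# Presets are plain SamplerSettings instances, so either object could be passed
# wherever a sampler was accepted, e.g. Model.generate(..., sampler=custom).
print(custom)        # shows the effective values after None-substitution
print(TopPSampling)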
webscout/Local/thread.py DELETED
@@ -1,539 +0,0 @@
- import sys
- import time
- from typing import Optional, Literal, Union, Generator, Tuple, TextIO
- import uuid
-
- from .model import Model, assert_model_is_loaded, _SupportsWriteAndFlush
- from .utils import RESET_ALL, cls, print_verbose, truncate
- from .samplers import SamplerSettings, DefaultSampling
- from .formats import AdvancedFormat, blank as formats_blank
-
-
- class Message(dict):
-     """
-     Represents a single message within a Thread.
-
-     Inherits from `dict` and provides additional functionality:
-
-     - `as_string()`: Returns the full message string.
-
-     Typical message keys:
-     - `role`: The speaker's role ('system', 'user', 'bot').
-     - `prefix`: Text prefixing the content.
-     - `content`: The message content.
-     - `suffix`: Text suffixing the content.
-     """
-
-     def __repr__(self) -> str:
-         return (
-             f"Message(["
-             f"('role', {repr(self['role'])}), "
-             f"('prefix', {repr(self['prefix'])}), "
-             f"('content', {repr(self['content'])}), "
-             f"('suffix', {repr(self['suffix'])})])"
-         )
-
-     def as_string(self) -> str:
-         """Returns the full message string."""
-         try:
-             return self['prefix'] + self['content'] + self['suffix']
-         except KeyError as e:
-             e.add_note(
-                 "Message.as_string(): Missing 'prefix', 'content', or 'suffix' "
-                 "attribute. This is unexpected."
-             )
-             raise e
-
-
- class Thread:
-     """
-     Facilitates easy interactions with a Model.
-
-     Methods:
-     - `add_message()`: Appends a message to the thread's messages.
-     - `as_string()`: Returns the complete message history as a string.
-     - `create_message()`: Creates a message using the thread's format.
-     - `inference_str_from_messages()`: Generates an inference-ready string from messages.
-     - `interact()`: Starts an interactive chat session.
-     - `len_messages()`: Gets the total token length of all messages.
-     - `print_stats()`: Prints context usage statistics.
-     - `reset()`: Clears the message history.
-     - `send()`: Sends a message and receives a response.
-     - `warmup()`: Warms up the model by running a simple generation.
-
-     Attributes:
-     - `format`: The message format (see `webscout.AIutel.formats`).
-     - `messages`: The list of messages in the thread.
-     - `model`: The associated `webscout.AIutel.model.Model` instance.
-     - `sampler`: The `webscout.AIutel.samplers.SamplerSettings` for text generation.
-     - `tools`: A list of tools available for function calling.
-     - `uuid`: A unique identifier for the thread (UUID object).
-     """
-
-     def __init__(
-         self,
-         model: Model,
-         format: Union[dict, AdvancedFormat],
-         sampler: SamplerSettings = DefaultSampling,
-         messages: Optional[list[Message]] = None,
-     ):
-         """
-         Initializes a Thread instance.
-
-         Args:
-             model: The Model instance for text generation.
-             format: The message format (see `webscout.AIutel.formats`).
-             sampler: Sampler settings for controlling generation.
-             messages: Initial list of messages (optional).
-         """
-         assert isinstance(model, Model), \
-             f"Thread: model should be a webscout.AIutel.model.Model, not {type(model)}"
-         assert_model_is_loaded(model)
-
-         assert isinstance(format, (dict, AdvancedFormat)), \
-             f"Thread: format should be dict or AdvancedFormat, not {type(format)}"
-
-         if any(k not in format.keys() for k in formats_blank.keys()):
-             raise KeyError(
-                 "Thread: format is missing one or more required keys, see "
-                 "webscout.AIutel.formats.blank for an example"
-             )
-
-         assert isinstance(format['stops'], list), \
-             f"Thread: format['stops'] should be list, not {type(format['stops'])}"
-
-         assert all(
-             hasattr(sampler, attr) for attr in [
-                 'max_len_tokens', 'temp', 'top_p', 'min_p',
-                 'frequency_penalty', 'presence_penalty', 'repeat_penalty',
-                 'top_k'
-             ]
-         ), 'Thread: sampler is missing one or more required attributes'
-
-         self._messages: Optional[list[Message]] = messages
-         if self._messages is not None:
-             if not all(isinstance(msg, Message) for msg in self._messages):
-                 raise TypeError(
-                     "Thread: one or more messages provided to __init__() is "
-                     "not an instance of webscout.AIutel.thread.Message"
-                 )
-
-         self.model = model
-         self.format = format
-         self.messages: list[Message] = [
-             self.create_message("system", self.format['system_prompt'])
-         ] if self._messages is None else self._messages
-         self.sampler = sampler
-         self.tools = []
-         self.uuid = uuid.uuid4()  # Generate a UUID for the thread
-
-         if self.model.verbose:
-             print_verbose("New Thread instance with attributes:")
-             print_verbose(f"model == {self.model}")
-             print_verbose(f"format['system_prefix'] == {truncate(repr(self.format['system_prefix']))}")
-             print_verbose(f"format['system_prompt'] == {truncate(repr(self.format['system_prompt']))}")
-             print_verbose(f"format['system_suffix'] == {truncate(repr(self.format['system_suffix']))}")
-             print_verbose(f"format['user_prefix'] == {truncate(repr(self.format['user_prefix']))}")
-             # print_verbose(f"format['user_content'] == {truncate(repr(self.format['user_content']))}")
-             print_verbose(f"format['user_suffix'] == {truncate(repr(self.format['user_suffix']))}")
-             print_verbose(f"format['bot_prefix'] == {truncate(repr(self.format['bot_prefix']))}")
-             # print_verbose(f"format['bot_content'] == {truncate(repr(self.format['bot_content']))}")
-             print_verbose(f"format['bot_suffix'] == {truncate(repr(self.format['bot_suffix']))}")
-             print_verbose(f"format['stops'] == {truncate(repr(self.format['stops']))}")
-             print_verbose(f"sampler.temp == {self.sampler.temp}")
-             print_verbose(f"sampler.top_p == {self.sampler.top_p}")
-             print_verbose(f"sampler.min_p == {self.sampler.min_p}")
-             print_verbose(f"sampler.frequency_penalty == {self.sampler.frequency_penalty}")
-             print_verbose(f"sampler.presence_penalty == {self.sampler.presence_penalty}")
-             print_verbose(f"sampler.repeat_penalty == {self.sampler.repeat_penalty}")
-             print_verbose(f"sampler.top_k == {self.sampler.top_k}")
-
-     def add_tool(self, tool: dict) -> None:
-         """
-         Adds a tool to the Thread for function calling.
-
-         Args:
-             tool (dict): A dictionary describing the tool, containing
-                 'function' with 'name', 'description', and 'execute' keys.
-         """
-         self.tools.append(tool)
-         self.model.register_tool(tool['function']['name'], tool['function']['execute'])
-         self.messages[0]['content'] += f"\nYou have access to the following tool:\n{tool['function']['description']}"
-
-     def __repr__(self) -> str:
-         return (
-             f"Thread({repr(self.model)}, {repr(self.format)}, "
-             f"{repr(self.sampler)}, {repr(self.messages)})"
-         )
-
-     def __str__(self) -> str:
-         return self.as_string()
-
-     def __len__(self) -> int:
-         """Returns the total token length of all messages."""
-         return self.len_messages()
-
-     def create_message(self, role: Literal['system', 'user', 'bot'], content: str) -> Message:
-         """Constructs a message using the thread's format."""
-         assert role.lower() in ['system', 'user', 'bot'], \
-             f"Thread.create_message(): role should be 'system', 'user', or 'bot', not '{role.lower()}'"
-         assert isinstance(content, str), \
-             f"Thread.create_message(): content should be str, not {type(content)}"
-
-         message_data = {
-             'system': {
-                 'role': 'system',
-                 'prefix': self.format['system_prefix'],
-                 'content': content,
-                 'suffix': self.format['system_suffix']
-             },
-             'user': {
-                 'role': 'user',
-                 'prefix': self.format['user_prefix'],
-                 'content': content,
-                 'suffix': self.format['user_suffix']
-             },
-             'bot': {
-                 'role': 'bot',
-                 'prefix': self.format['bot_prefix'],
-                 'content': content,
-                 'suffix': self.format['bot_suffix']
-             }
-         }
-
-         return Message(message_data[role.lower()])
-
-     def len_messages(self) -> int:
-         """Returns the total length of all messages in tokens."""
-         return self.model.get_length(self.as_string())
-
-     def add_message(self, role: Literal['system', 'user', 'bot'], content: str) -> None:
-         """Appends a message to the thread's messages."""
-         self.messages.append(self.create_message(role, content))
-
-     def inference_str_from_messages(self) -> str:
-         """Constructs an inference-ready string from messages."""
-         inf_str = ''
-         sys_msg_str = ''
-         sys_msg_flag = False
-         context_len_budget = self.model.context_length
-
-         if len(self.messages) >= 1 and self.messages[0]['role'] == 'system':
-             sys_msg_flag = True
-             sys_msg = self.messages[0]
-             sys_msg_str = sys_msg.as_string()
-             context_len_budget -= self.model.get_length(sys_msg_str)
-
-         iterator = reversed(self.messages[1:]) if sys_msg_flag else reversed(self.messages)
-
-         for message in iterator:
-             msg_str = message.as_string()
-             context_len_budget -= self.model.get_length(msg_str)
-             if context_len_budget <= 0:
-                 break
-             inf_str = msg_str + inf_str
-
-         inf_str = sys_msg_str + inf_str if sys_msg_flag else inf_str
-         inf_str += self.format['bot_prefix']
-
-         return inf_str
-
-     def send(self, prompt: str) -> str:
-         """Sends a message and receives a response."""
-         self.add_message("user", prompt)
-         output = self.model.generate(
-             self.inference_str_from_messages(),
-             stops=self.format['stops'],
-             sampler=self.sampler
-         )
-         self.add_message("bot", output)
-         return output
-
-     def _interactive_update_sampler(self) -> None:
-         """Interactively updates sampler settings."""
-         print()
-         try:
-             for param_name in SamplerSettings.param_types:
-                 current_value = getattr(self.sampler, param_name)
-                 new_value = input(f'{param_name}: {current_value} -> ')
-                 try:
-                     if new_value.lower() == 'none':
-                         setattr(self.sampler, param_name, None)
-                     elif param_name in ('top_k', 'max_len_tokens'):
-                         setattr(self.sampler, param_name, int(new_value))
-                     else:
-                         setattr(self.sampler, param_name, float(new_value))
-                     print(f'webscout.AIutel: {param_name} updated')
-                 except ValueError:
-                     print(f'webscout.AIutel: {param_name} not updated (invalid input)')
-             print()
-         except KeyboardInterrupt:
-             print('\nwebscout.AIutel: Sampler settings not updated\n')
-
-     def _interactive_input(
-         self,
-         prompt: str,
-         _dim_style: str,
-         _user_style: str,
-         _bot_style: str,
-         _special_style: str
-     ) -> Tuple[Optional[str], Optional[str]]:
-         """Receives input from the user, handling multi-line input and commands."""
-         full_user_input = ''
-
-         while True:
-             try:
-                 user_input = input(prompt)
-             except KeyboardInterrupt:
-                 print(f"{RESET_ALL}\n")
-                 return None, None
-
-             if user_input.endswith('\\'):
-                 full_user_input += user_input[:-1] + '\n'
-             elif user_input == '!':
-                 print()
-                 try:
-                     command = input(f'{RESET_ALL} ! {_dim_style}')
-                 except KeyboardInterrupt:
-                     print('\n')
-                     continue
-
-                 if command == '':
-                     print('\n[No command]\n')
-                 elif command.lower() in ['reset', 'restart']:
-                     self.reset()
-                     print('\n[Thread reset]\n')
-                 elif command.lower() in ['cls', 'clear']:
-                     cls()
-                     print()
-                 elif command.lower() in ['ctx', 'context']:
-                     print(f"\n{self.len_messages()}\n")
-                 elif command.lower() in ['stats', 'print_stats']:
-                     print()
-                     self.print_stats()
-                     print()
-                 elif command.lower() in ['sampler', 'samplers', 'settings']:
-                     self._interactive_update_sampler()
-                 elif command.lower() in ['str', 'string', 'as_string']:
-                     print(f"\n{self.as_string()}\n")
-                 elif command.lower() in ['repr', 'save', 'backup']:
-                     print(f"\n{repr(self)}\n")
-                 elif command.lower() in ['remove', 'rem', 'delete', 'del']:
-                     print()
-                     if len(self.messages) > 1:  # Prevent deleting the system message
-                         old_len = len(self.messages)
-                         del self.messages[-1]
-                         assert len(self.messages) == (old_len - 1)
-                         print('[Removed last message]\n')
-                     else:
-                         print('[Cannot remove system message]\n')
-                 elif command.lower() in ['last', 'repeat']:
-                     if len(self.messages) > 1:
-                         last_msg = self.messages[-1]
-                         if last_msg['role'] == 'user':
-                             print(f"\n{_user_style}{last_msg['content']}{RESET_ALL}\n")
-                         elif last_msg['role'] == 'bot':
-                             print(f"\n{_bot_style}{last_msg['content']}{RESET_ALL}\n")
-                     else:
-                         print("\n[No previous message]\n")
-                 elif command.lower() in ['inf', 'inference', 'inf_str']:
-                     print(f'\n"""{self.inference_str_from_messages()}"""\n')
-                 elif command.lower() in ['reroll', 're-roll', 're', 'swipe']:
-                     if len(self.messages) > 1:
-                         old_len = len(self.messages)
-                         del self.messages[-1]
-                         assert len(self.messages) == (old_len - 1)
-                         return '', None
-                     else:
-                         print("\n[Cannot reroll system message]\n")
-                 elif command.lower() in ['exit', 'quit']:
-                     print(RESET_ALL)
-                     return None, None
-                 elif command.lower() in ['help', '/?', '?']:
-                     print(
-                         "\n"
-                         "reset | restart -- Reset the thread to its original state\n"
-                         "clear | cls -- Clear the terminal\n"
-                         "context | ctx -- Get the context usage in tokens\n"
-                         "print_stats | stats -- Get the context usage stats\n"
-                         "sampler | settings -- Update the sampler settings\n"
-                         "string | str -- Print the message history as a string\n"
-                         "repr | save -- Print the representation of the thread\n"
-                         "remove | delete -- Remove the last message\n"
-                         "last | repeat -- Repeat the last message\n"
-                         "inference | inf -- Print the inference string\n"
-                         "reroll | swipe -- Regenerate the last message\n"
-                         "exit | quit -- Exit the interactive chat (can also use ^C)\n"
-                         "help | ? -- Show this screen\n"
-                         "\n"
-                         "TIP: Type '<' at the prompt and press ENTER to prefix the bot's next message.\n"
-                         " For example, type 'Sure!' to bypass refusals\n"
-                         "\n"
-                         "TIP: Type '!!' at the prompt and press ENTER to insert a system message\n"
-                         "\n"
-                     )
-                 else:
-                     print('\n[Unknown command]\n')
-             elif user_input == '<':
-                 print()
-                 try:
-                     next_message_start = input(f'{RESET_ALL} < {_dim_style}')
-                 except KeyboardInterrupt:
-                     print(f'{RESET_ALL}\n')
-                     continue
-                 else:
-                     print()
-                     return '', next_message_start
-             elif user_input == '!!':
-                 print()
-                 try:
-                     next_sys_msg = input(f'{RESET_ALL} !! {_special_style}')
-                 except KeyboardInterrupt:
-                     print(f'{RESET_ALL}\n')
-                     continue
-                 else:
-                     print()
-                     return next_sys_msg, '-1'
-             else:
-                 full_user_input += user_input
-                 return full_user_input, None
-
-     def interact(
-         self,
-         color: bool = True,
-         header: Optional[str] = None,
-         stream: bool = True
-     ) -> None:
-         """
-         Starts an interactive chat session.
-
-         Allows for real-time interaction with the model, including
-         interrupting generation, regenerating responses, and using
-         commands.
-
-         Args:
-             color (bool, optional): Whether to use colored output. Defaults to True.
-             header (Optional[str], optional): Header text to display. Defaults to None.
-             stream (bool, optional): Whether to stream the response. Defaults to True.
-         """
-         print()
-         from .utils import SPECIAL_STYLE, USER_STYLE, BOT_STYLE, DIM_STYLE
-         if not color:
-             SPECIAL_STYLE = USER_STYLE = BOT_STYLE = DIM_STYLE = ''
-
-         if header is not None:
-             print(f"{SPECIAL_STYLE}{header}{RESET_ALL}\n")
-
-         while True:
-             prompt = f"{RESET_ALL} > {USER_STYLE}"
-             try:
-                 user_prompt, next_message_start = self._interactive_input(
-                     prompt, DIM_STYLE, USER_STYLE, BOT_STYLE, SPECIAL_STYLE
-                 )
-             except KeyboardInterrupt:
-                 print(f"{RESET_ALL}\n")
-                 return
-
-             if user_prompt is None and next_message_start is None:
-                 break
-
-             if next_message_start == '-1':
-                 self.add_message('system', user_prompt)
-                 continue
-
-             if next_message_start is not None:
-                 try:
-                     print(f"{BOT_STYLE}{next_message_start}", end='', flush=True)
-                     if stream:
-                         output = next_message_start + self.model.stream_print(
-                             self.inference_str_from_messages() + next_message_start,
-                             stops=self.format['stops'],
-                             sampler=self.sampler,
-                             end=''
-                         )
-                     else:
-                         output = next_message_start + self.model.generate(
-                             self.inference_str_from_messages() + next_message_start,
-                             stops=self.format['stops'],
-                             sampler=self.sampler
-                         )
-                         print(output, end='', flush=True)
-                 except KeyboardInterrupt:
-                     print(
-                         f"{DIM_STYLE} [Message not added to history; "
-                         "press ENTER to re-roll]\n"
-                     )
-                     continue
-                 else:
-                     self.add_message("bot", output)
-             else:
-                 print(BOT_STYLE, end='')
-                 if user_prompt != "":
-                     self.add_message("user", user_prompt)
-                 try:
-                     if stream:
-                         output = self.model.stream_print(
-                             self.inference_str_from_messages(),
-                             stops=self.format['stops'],
-                             sampler=self.sampler,
-                             end=''
-                         )
-                     else:
-                         output = self.model.generate(
-                             self.inference_str_from_messages(),
-                             stops=self.format['stops'],
-                             sampler=self.sampler
-                         )
-                         print(output, end='', flush=True)
-                 except KeyboardInterrupt:
-                     print(
-                         f"{DIM_STYLE} [Message not added to history; "
-                         "press ENTER to re-roll]\n"
-                     )
-                     continue
-                 else:
-                     self.add_message("bot", output)
-
-             if output.endswith("\n\n"):
-                 print(RESET_ALL, end='', flush=True)
-             elif output.endswith("\n"):
-                 print(RESET_ALL)
-             else:
-                 print(f"{RESET_ALL}\n")
-
-     def reset(self) -> None:
-         """Clears the message history, resetting the thread to its initial state."""
-         self.messages: list[Message] = [
-             self.create_message("system", self.format['system_prompt'])
-         ] if self._messages is None else self._messages
-
-     def as_string(self) -> str:
-         """Returns the thread's message history as a string."""
-         return ''.join(msg.as_string() for msg in self.messages)
-
-     def print_stats(
-         self,
-         end: str = '\n',
-         file: TextIO = sys.stdout,
-         flush: bool = True
-     ) -> None:
-         """Prints context usage statistics."""
-         thread_len_tokens = self.len_messages()
-         max_ctx_len = self.model.context_length
-         context_used_percentage = round((thread_len_tokens / max_ctx_len) * 100)
-         print(
-             f"{thread_len_tokens} / {max_ctx_len} tokens "
-             f"({context_used_percentage}% of context used), "
-             f"{len(self.messages)} messages",
-             end=end, file=file, flush=flush
-         )
-         if not flush:
-             file.flush()
-
-     def warmup(self):
-         """
-         Warms up the model by running a simple generation.
-         """
-         if self.model.verbose:
-             print_verbose("Warming up the model...")
-         self.model.generate("This is a warm-up prompt.")
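
As with samplers.py, a short sketch of the deleted Thread API may help readers assess the impact of this removal. It is inferred from the docstrings above and runs only against webscout 7.4. The model path and the Model constructor signature are assumptions for illustration; `blank` is the example format that the code above references via webscout.AIutel.formats.blank.

from webscout.Local.model import Model
from webscout.Local.thread import Thread
from webscout.Local.samplers import DefaultSampling
from webscout.Local.formats import blank  # minimal format referenced by the code above

model = Model("path/to/model.gguf")  # assumed constructor signature
thread = Thread(model, blank, sampler=DefaultSampling)

print(thread.send("Hello!"))  # one round trip: appends a user and a bot message
thread.print_stats()          # tokens used / context length / message count
thread.interact()             # interactive chat loop with '!' commands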