QuLab-2.10.10-cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- qulab/__init__.py +33 -0
- qulab/__main__.py +4 -0
- qulab/cli/__init__.py +0 -0
- qulab/cli/commands.py +30 -0
- qulab/cli/config.py +170 -0
- qulab/cli/decorators.py +28 -0
- qulab/dicttree.py +523 -0
- qulab/executor/__init__.py +5 -0
- qulab/executor/analyze.py +188 -0
- qulab/executor/cli.py +434 -0
- qulab/executor/load.py +563 -0
- qulab/executor/registry.py +185 -0
- qulab/executor/schedule.py +543 -0
- qulab/executor/storage.py +615 -0
- qulab/executor/template.py +259 -0
- qulab/executor/utils.py +194 -0
- qulab/expression.py +827 -0
- qulab/fun.cp313-win_amd64.pyd +0 -0
- qulab/monitor/__init__.py +1 -0
- qulab/monitor/__main__.py +8 -0
- qulab/monitor/config.py +41 -0
- qulab/monitor/dataset.py +77 -0
- qulab/monitor/event_queue.py +54 -0
- qulab/monitor/mainwindow.py +234 -0
- qulab/monitor/monitor.py +115 -0
- qulab/monitor/ploter.py +123 -0
- qulab/monitor/qt_compat.py +16 -0
- qulab/monitor/toolbar.py +265 -0
- qulab/scan/__init__.py +2 -0
- qulab/scan/curd.py +221 -0
- qulab/scan/models.py +554 -0
- qulab/scan/optimize.py +76 -0
- qulab/scan/query.py +387 -0
- qulab/scan/record.py +603 -0
- qulab/scan/scan.py +1166 -0
- qulab/scan/server.py +450 -0
- qulab/scan/space.py +213 -0
- qulab/scan/utils.py +234 -0
- qulab/storage/__init__.py +0 -0
- qulab/storage/__main__.py +51 -0
- qulab/storage/backend/__init__.py +0 -0
- qulab/storage/backend/redis.py +204 -0
- qulab/storage/base_dataset.py +352 -0
- qulab/storage/chunk.py +60 -0
- qulab/storage/dataset.py +127 -0
- qulab/storage/file.py +273 -0
- qulab/storage/models/__init__.py +22 -0
- qulab/storage/models/base.py +4 -0
- qulab/storage/models/config.py +28 -0
- qulab/storage/models/file.py +89 -0
- qulab/storage/models/ipy.py +58 -0
- qulab/storage/models/models.py +88 -0
- qulab/storage/models/record.py +161 -0
- qulab/storage/models/report.py +22 -0
- qulab/storage/models/tag.py +93 -0
- qulab/storage/storage.py +95 -0
- qulab/sys/__init__.py +2 -0
- qulab/sys/chat.py +688 -0
- qulab/sys/device/__init__.py +3 -0
- qulab/sys/device/basedevice.py +255 -0
- qulab/sys/device/loader.py +86 -0
- qulab/sys/device/utils.py +79 -0
- qulab/sys/drivers/FakeInstrument.py +68 -0
- qulab/sys/drivers/__init__.py +0 -0
- qulab/sys/ipy_events.py +125 -0
- qulab/sys/net/__init__.py +0 -0
- qulab/sys/net/bencoder.py +205 -0
- qulab/sys/net/cli.py +169 -0
- qulab/sys/net/dhcp.py +543 -0
- qulab/sys/net/dhcpd.py +176 -0
- qulab/sys/net/kad.py +1142 -0
- qulab/sys/net/kcp.py +192 -0
- qulab/sys/net/nginx.py +194 -0
- qulab/sys/progress.py +190 -0
- qulab/sys/rpc/__init__.py +0 -0
- qulab/sys/rpc/client.py +0 -0
- qulab/sys/rpc/exceptions.py +96 -0
- qulab/sys/rpc/msgpack.py +1052 -0
- qulab/sys/rpc/msgpack.pyi +41 -0
- qulab/sys/rpc/router.py +35 -0
- qulab/sys/rpc/rpc.py +412 -0
- qulab/sys/rpc/serialize.py +139 -0
- qulab/sys/rpc/server.py +29 -0
- qulab/sys/rpc/socket.py +29 -0
- qulab/sys/rpc/utils.py +25 -0
- qulab/sys/rpc/worker.py +0 -0
- qulab/sys/rpc/zmq_socket.py +227 -0
- qulab/tools/__init__.py +0 -0
- qulab/tools/connection_helper.py +39 -0
- qulab/typing.py +2 -0
- qulab/utils.py +95 -0
- qulab/version.py +1 -0
- qulab/visualization/__init__.py +188 -0
- qulab/visualization/__main__.py +71 -0
- qulab/visualization/_autoplot.py +464 -0
- qulab/visualization/plot_circ.py +319 -0
- qulab/visualization/plot_layout.py +408 -0
- qulab/visualization/plot_seq.py +242 -0
- qulab/visualization/qdat.py +152 -0
- qulab/visualization/rot3d.py +23 -0
- qulab/visualization/widgets.py +86 -0
- qulab-2.10.10.dist-info/METADATA +110 -0
- qulab-2.10.10.dist-info/RECORD +107 -0
- qulab-2.10.10.dist-info/WHEEL +5 -0
- qulab-2.10.10.dist-info/entry_points.txt +2 -0
- qulab-2.10.10.dist-info/licenses/LICENSE +21 -0
- qulab-2.10.10.dist-info/top_level.txt +1 -0
qulab/sys/chat.py
ADDED
@@ -0,0 +1,688 @@
import dataclasses
import logging
import pickle
import re
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from random import shuffle
from typing import Any, List, Optional, TypedDict

import numpy as np
import openai
import tenacity
import tiktoken
from IPython import get_ipython
from IPython.display import Markdown, display
from openai.error import (APIConnectionError, APIError, RateLimitError,
                          ServiceUnavailableError, Timeout)
from scipy import spatial

logger = logging.getLogger(__name__)


class Message(TypedDict):
    """OpenAI Message object containing a role and the message content"""

    role: str
    content: str


DEFAULT_SYSTEM_PROMPT = 'You are a helpful assistant. Respond using markdown.'
DEFAULT_GPT_MODEL = "gpt-3.5-turbo"
EMBEDDING_MODEL = "text-embedding-ada-002"
EMBED_DIM = 1536


def token_limits(model: str = DEFAULT_GPT_MODEL) -> int:
    """Return the maximum number of tokens for a model."""
    return {
        "gpt-3.5-turbo": 4096,
        "gpt-4": 8192,
        "gpt-4-32k": 32768,
        "text-embedding-ada-002": 8191,
    }[model]


def count_message_tokens(messages: List[Message],
                         model: str = "gpt-3.5-turbo-0301") -> int:
    """
    Returns the number of tokens used by a list of messages.

    Args:
        messages (list): A list of messages, each of which is a dictionary
            containing the role and content of the message.
        model (str): The name of the model to use for tokenization.
            Defaults to "gpt-3.5-turbo-0301".

    Returns:
        int: The number of tokens used by the list of messages.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        logger.warn("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    if model == "gpt-3.5-turbo":
        # !Note: gpt-3.5-turbo may change over time.
        # Returning num tokens assuming gpt-3.5-turbo-0301.")
        return count_message_tokens(messages, model="gpt-3.5-turbo-0301")
    elif model == "gpt-4":
        # !Note: gpt-4 may change over time. Returning num tokens assuming gpt-4-0314.")
        return count_message_tokens(messages, model="gpt-4-0314")
    elif model == "gpt-3.5-turbo-0301":
        tokens_per_message = (
            4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        )
        tokens_per_name = -1  # if there's a name, the role is omitted
    elif model == "gpt-4-0314":
        tokens_per_message = 3
        tokens_per_name = 1
    else:
        raise NotImplementedError(
            f"num_tokens_from_messages() is not implemented for model {model}.\n"
            " See https://github.com/openai/openai-python/blob/main/chatml.md for"
            " information on how messages are converted to tokens.")
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens


def count_string_tokens(string: str, model_name: str) -> int:
    """
    Returns the number of tokens in a text string.

    Args:
        string (str): The text string.
        model_name (str): The name of the encoding to use. (e.g., "gpt-3.5-turbo")

    Returns:
        int: The number of tokens in the text string.
    """
    encoding = tiktoken.encoding_for_model(model_name)
    return len(encoding.encode(string))


def create_chat_message(role, content) -> Message:
    """
    Create a chat message with the given role and content.

    Args:
        role (str): The role of the message sender, e.g., "system", "user", or "assistant".
        content (str): The content of the message.

    Returns:
        dict: A dictionary containing the role and content of the message.
    """
    return {"role": role, "content": content}


def generate_context(prompt,
                     relevant_memory,
                     full_message_history,
                     model,
                     summary=None):
    current_context = [
        create_chat_message("system", prompt),
        create_chat_message(
            "system", f"The current time and date is {time.strftime('%c')}"),
        create_chat_message(
            "system",
            f"This reminds you of these events from your past:\n{relevant_memory}\n\n",
        ),
    ]
    if summary is not None:
        current_context.append(
            create_chat_message(
                "system",
                f"This is a summary of the conversation so far:\n{summary}\n\n"
            ))

    # Add messages from the full message history until we reach the token limit
    next_message_to_add_index = len(full_message_history) - 1
    insertion_index = len(current_context)
    # Count the currently used tokens
    current_tokens_used = count_message_tokens(current_context, model)
    return (
        next_message_to_add_index,
        current_tokens_used,
        insertion_index,
        current_context,
    )


@tenacity.retry(wait=tenacity.wait_exponential(multiplier=1, min=4, max=10),
                stop=tenacity.stop_after_attempt(5),
                retry=tenacity.retry_if_exception_type(
                    (RateLimitError, APIError, Timeout,
                     ServiceUnavailableError, APIConnectionError)))
def create_chat_completion(
    messages: List[Message],  # type: ignore
    model: Optional[str] = None,
    temperature: float = 0.9,
    max_tokens: Optional[int] = None,
) -> str:
    """Create a chat completion using the OpenAI API

    Args:
        messages (List[Message]): The messages to send to the chat completion
        model (str, optional): The model to use. Defaults to None.
        temperature (float, optional): The temperature to use. Defaults to 0.9.
        max_tokens (int, optional): The max tokens to use. Defaults to None.

    Returns:
        str: The response from the chat completion
    """
    response = openai.ChatCompletion.create(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    try:
        resp = response.choices[0].message["content"]
    except:
        try:
            return response.error.message
        except:
            logger.error(f"Error in create_chat_completion: {response}")
            raise
    return resp


@tenacity.retry(wait=tenacity.wait_exponential(multiplier=1, min=4, max=10),
                stop=tenacity.stop_after_attempt(5),
                retry=tenacity.retry_if_exception_type(
                    (RateLimitError, APIError, Timeout,
                     ServiceUnavailableError, APIConnectionError)))
def get_embedding(
    text: str,
    *_,
    model: str = EMBEDDING_MODEL,
    **kwargs,
) -> List[float]:
    """Create an embedding using the OpenAI API

    Args:
        text (str): The text to embed.
        kwargs: Other arguments to pass to the OpenAI API embedding creation call.

    Returns:
        List[float]: The embedding.
    """
    return openai.Embedding.create(
        model=model,
        input=[text],
        **kwargs,
    )["data"][0]["embedding"]


def chat_with_ai(prompt,
                 user_input,
                 full_message_history,
                 permanent_memory,
                 summary=None,
                 model=DEFAULT_GPT_MODEL,
                 token_limit=None):
    """Interact with the OpenAI API, sending the prompt, user input, message history,
    and permanent memory.

    Args:
        prompt (str): The prompt explaining the rules to the AI.
        user_input (str): The input from the user.
        full_message_history (list): The list of all messages sent between the
            user and the AI.
        permanent_memory (Obj): The memory object containing the permanent
            memory.
        summary (str): The summary of the conversation so far.
        model (str): The name of the model to use for tokenization.
        token_limit (int): The maximum number of tokens allowed in the API call.

    Returns:
        str: The AI's response.
    """

    # Reserve 1000 tokens for the response

    if token_limit is None:
        token_limit = token_limits(model)

    logger.debug(f"Token limit: {token_limit}")
    send_token_limit = token_limit - 1000
    if len(full_message_history) == 0:
        relevant_memory = ""
    else:
        recent_history = full_message_history[-5:]
        shuffle(recent_history)
        relevant_memories = permanent_memory.get_relevant(
            str(recent_history), 5)
        if relevant_memories:
            shuffle(relevant_memories)
        relevant_memory = str(relevant_memories)

    logger.debug(f"Memory Stats: {permanent_memory.get_stats()}")

    (
        next_message_to_add_index,
        current_tokens_used,
        insertion_index,
        current_context,
    ) = generate_context(prompt, relevant_memory, full_message_history, model,
                         summary)

    while current_tokens_used > 2500:
        # remove memories until we are under 2500 tokens
        relevant_memory = relevant_memory[:-1]
        (
            next_message_to_add_index,
            current_tokens_used,
            insertion_index,
            current_context,
        ) = generate_context(prompt, relevant_memory, full_message_history,
                             model, summary)

    current_tokens_used += count_message_tokens(
        [create_chat_message("user", user_input)],
        model)  # Account for user input (appended later)

    while next_message_to_add_index >= 0:
        # print (f"CURRENT TOKENS USED: {current_tokens_used}")
        message_to_add = full_message_history[next_message_to_add_index]

        tokens_to_add = count_message_tokens([message_to_add], model)
        if current_tokens_used + tokens_to_add > send_token_limit:
            break

        # Add the most recent message to the start of the current context,
        # after the two system prompts.
        current_context.insert(insertion_index,
                               full_message_history[next_message_to_add_index])

        # Count the currently used tokens
        current_tokens_used += tokens_to_add

        # Move to the next most recent message in the full message history
        next_message_to_add_index -= 1

    # Append user input, the length of this is accounted for above
    current_context.extend([create_chat_message("user", user_input)])

    # Calculate remaining tokens
    tokens_remaining = token_limit - current_tokens_used
    # assert tokens_remaining >= 0, "Tokens remaining is negative.

    # TODO: use a model defined elsewhere, so that model can contain
    # temperature and other settings we care about
    assistant_reply = create_chat_completion(
        model=model,
        messages=current_context,
        max_tokens=tokens_remaining,
    )

    # Update full message history
    full_message_history.append(create_chat_message("user", user_input))
    full_message_history.append(
        create_chat_message("assistant", assistant_reply))

    return assistant_reply


def create_default_embeddings():
    return np.zeros((0, EMBED_DIM)).astype(np.float32)


@dataclasses.dataclass
class CacheContent:
    texts: List[str] = dataclasses.field(default_factory=list)
    embeddings: np.ndarray = dataclasses.field(
        default_factory=create_default_embeddings)


class LocalCache():
    """A class that stores the memory in a local file"""

    def __init__(self, path=None) -> None:
        """Initialize a class instance

        Args:
            path: str

        Returns:
            None
        """
        if path is None:
            self.filename = None
        else:
            self.filename = Path(path)
            self.filename.touch(exist_ok=True)
        try:
            with open(self.filename, 'rb') as f:
                self.data = pickle.load(f)
        except:
            self.data = CacheContent()

    def add(self, text: str):
        """
        Add text to our list of texts, add embedding as row to our
        embeddings-matrix

        Args:
            text: str

        Returns: None
        """
        if "Command Error:" in text:
            return ""
        self.data.texts.append(text)

        embedding = get_embedding(text)

        vector = np.array(embedding).astype(np.float32)
        vector = vector[np.newaxis, :]
        self.data.embeddings = np.concatenate(
            [
                self.data.embeddings,
                vector,
            ],
            axis=0,
        )

        if self.filename is not None:
            with open(self.filename, "wb") as f:
                pickle.dump(self.data, f)
        return text

    def clear(self) -> str:
        """
        Clears the data in memory.

        Returns: A message indicating that the memory has been cleared.
        """
        self.data = CacheContent()
        return "Obliviated"

    def get(self, data: str) -> list[Any] | None:
        """
        Gets the data from the memory that is most relevant to the given data.

        Args:
            data: The data to compare to.

        Returns: The most relevant data.
        """
        return self.get_relevant(data, 1)

    def get_relevant(self, text: str, k: int) -> list[Any]:
        """ "
        matrix-vector mult to find score-for-each-row-of-matrix
        get indices for top-k winning scores
        return texts for those indices
        Args:
            text: str
            k: int

        Returns: List[str]
        """
        if self.data.embeddings.shape[0] == 0:
            return []
        embedding = get_embedding(text)

        scores = np.dot(self.data.embeddings, embedding)

        top_k_indices = np.argsort(scores)[-k:][::-1]

        return [self.data.texts[i] for i in top_k_indices]

    def get_stats(self) -> tuple[int, tuple[int, ...]]:
        """
        Returns: The stats of the local cache.
        """
        return len(self.data.texts), self.data.embeddings.shape


class Completion():

    def __init__(self,
                 system_prompt=DEFAULT_SYSTEM_PROMPT,
                 model=DEFAULT_GPT_MODEL):
        self.messages = [{"role": "system", "content": system_prompt}]
        self.title = 'untitled'
        self.last_time = datetime.now()
        self.completion = None
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.model = model

    def make_title(self):

        text = [
            f'{d["role"]} :\n"""\n{d["content"]}\n"""'
            for d in self.messages[1:]
        ]

        messages = [{
            "role": "system",
            "content": 'You are a helpful assistant.'
        }, {
            'role':
            "user",
            'content': ("Summarize the following conversation and give it a "
                        "title that captures its content, no longer than 100 "
                        "characters. The title must not contain characters "
                        "such as `?:*,<>\\/` that cannot be used in file "
                        "paths. Return only the title itself, with no extra "
                        "content and without a trailing period.\n" +
                        '\n\n'.join(text))
        }]
        completion = openai.ChatCompletion.create(model=self.model,
                                                  messages=messages)
        content = completion.choices[0].message['content']
        return f"{time.strftime('%Y%m%d%H%M')} {content}"

    def say(self, msg):
        self.last_time = datetime.now()
        self.messages.append({"role": "user", "content": msg})
        self.completion = openai.ChatCompletion.create(model=self.model,
                                                       messages=self.messages)
        self.total_tokens += self.completion.usage.total_tokens
        self.completion_tokens += self.completion.usage.completion_tokens
        self.prompt_tokens += self.completion.usage.prompt_tokens
        message = self.completion.choices[0].message
        self.messages.append({
            "role": message['role'],
            "content": message['content']
        })
        return message['content']

    def save(self):
        if self.title == 'untitled':
            self.title = self.make_title()

        filepath = Path.home() / 'chatGPT' / f"{self.title}.completion"
        filepath.parent.mkdir(parents=True, exist_ok=True)
        with open(filepath, 'wb') as f:
            pickle.dump(self, f)


class Conversation():

    def __init__(self,
                 system_prompt=DEFAULT_SYSTEM_PROMPT,
                 model=DEFAULT_GPT_MODEL):
        self.system_prompt = system_prompt
        self.summary = None
        self.history = []
        self.memory = LocalCache()
        self.title = None
        self.last_time = datetime.now()
        self.model = model
        self._pool = ThreadPoolExecutor()
        self._save_future = None

    def __del__(self):
        if self._save_future is not None:
            self._save_future.result()
        self._pool.shutdown()

    def _validate_title(self, title: str) -> str:
        title.replace('\\/:.*?%&#\"\'<>{}|\n\r\t_', ' ')
        title = title.strip()
        title = '_'.join(title.split())
        if len(title) > 70:
            title = title[:70]
        while title[-1] in ' .。,,-_':
            title = title[:-1]
        return title

    def make_title(self):
        messages = [{
            "role": "system",
            "content": 'You are a helpful assistant.'
        }]

        tokens = count_string_tokens(messages[0]['content'], self.model)

        query = ("Summarize the following conversation into a Chinese title "
                 "of no more than 100 characters. Note that the title must "
                 "be a valid file name; omit the trailing period. The result "
                 "must not contain any extra explanation or formatting.\n"
                 "Conversation:\n")

        text = []
        for msg in self.history:
            text.append(f'{msg["role"]} :\n<quote>{msg["content"]}</quote>')
            tokens += count_string_tokens(query + '\n'.join(text), self.model)
            if tokens > token_limits(self.model) - 500:
                text.pop()
                break
        messages.append({"role": "user", "content": query + '\n'.join(text)})

        try:
            self.last_time = datetime.now()
            content = create_chat_completion(messages, self.model)
            title = self._validate_title(content)
            return f"{time.strftime('%Y%m%d%H%M%S')} {title}"
        except:
            return f"{time.strftime('%Y%m%d%H%M%S')} untitled"

    def ask(self, query):
        self.last_time = datetime.now()

        reply = chat_with_ai(self.system_prompt, query, self.history,
                             self.memory, self.summary, self.model,
                             token_limits(self.model))
        return reply

    def _save(self):
        if len(self.history) == 0:
            return

        if self.title is None:
            self.title = self.make_title()

        filepath = Path.home() / 'chatGPT' / f"{self.title}.conversation"
        filepath.parent.mkdir(parents=True, exist_ok=True)
        with open(filepath, 'wb') as f:
            pickle.dump(self, f)

    def save(self):
        self._save_future = self._pool.submit(self._save)
        return self._save_future

    def __getstate__(self):
        state = self.__dict__.copy()
        del state['_pool']
        del state['_save_future']
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._save_future = None


ipy = get_ipython()

current_completion = Conversation()


def chat(line, cell):
    global current_completion
    if line:
        args = line.split()
        current_completion.save()
        model = DEFAULT_GPT_MODEL
        if args[0] in ['gpt-4', 'gpt-3.5', 'gpt-3.5-turbo']:
            model = args[0]
            if model == 'gpt-3.5':
                model = 'gpt-3.5-turbo'
        if len(args) > 1:
            prompt = ' '.join(args[1:])
        else:
            prompt = DEFAULT_SYSTEM_PROMPT
        current_completion = Conversation(system_prompt=prompt, model=model)
        if args[0] in ['end', 'save', 'bye']:
            return
    content = current_completion.ask(cell)
    display(Markdown(content))
    ipy.set_next_input('%%chat\n')


def autosave_completion():
    global current_completion
    if (datetime.now() - current_completion.last_time).seconds > 300 and len(
            current_completion.history) >= 3:
        current_completion.save()
    elif len(current_completion.history) > 7:
        current_completion.save()


def load_chat(index):
    global current_completion

    filepath = Path.home() / 'chatGPT'
    if not filepath.exists():
        return
    for i, f in enumerate(
            sorted(filepath.glob('*.conversation'),
                   key=lambda f: f.stat().st_mtime,
                   reverse=True)):
        if i == index:
            if current_completion is not None:
                current_completion.save().result()
            with open(f, 'rb') as f:
                current_completion = pickle.load(f)
            break


def show_chat(index=None):
    if index is not None:
        load_chat(index)
    messages = current_completion.history
    for msg in messages:
        display(Markdown(f"**{msg['role']}**\n\n{msg['content']}"))
    ipy.set_next_input('%%chat\n')


def list_chat():
    filepath = Path.home() / 'chatGPT'
    if not filepath.exists():
        return
    rows = ["|index|title|length|time|", "|:---:|:---:|:---:|:---:|"]
    for i, f in enumerate(
            sorted(filepath.glob('*.conversation'),
                   key=lambda f: f.stat().st_mtime,
                   reverse=True)):
        with open(f, 'rb') as f:
            completion = pickle.load(f)
        rows.append(
            f"|{i}|{completion.title}|{len(completion.history)}|{completion.last_time}|"
        )
    display(Markdown('\n'.join(rows)))


if ipy is not None:
    ipy.register_magic_function(chat, 'cell', magic_name='chat')
    ipy.events.register('post_run_cell', autosave_completion)
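The module above is self-registering: importing it inside IPython adds a `%%chat` cell magic and a `post_run_cell` autosave hook, while `Conversation` and `LocalCache` can also be driven from plain Python. The sketch below is illustrative only and is not part of the package; it assumes a valid API key for the legacy `openai` 0.x client that the module imports (`openai.ChatCompletion`, `openai.error`), network access, and an example file name `chat_memory.pkl` chosen here for demonstration.

# Illustrative usage sketch (not shipped with qulab); requires an OpenAI key
# for the legacy openai 0.x client. The memory file name is an example only.
import openai

from qulab.sys.chat import Conversation, LocalCache

openai.api_key = "sk-..."  # placeholder, supply your own key

# Long-term memory backed by a local pickle file: add() embeds the text via the
# embeddings endpoint, get_relevant() ranks stored texts by dot-product score.
memory = LocalCache("chat_memory.pkl")
memory.add("The user is calibrating a transmon qubit.")
print(memory.get_relevant("qubit calibration", k=1))

# A Conversation keeps the message history, pulls relevant memory into the
# context window, and pickles itself to ~/chatGPT/<title>.conversation on save().
conv = Conversation(system_prompt="You are a lab assistant.",
                    model="gpt-3.5-turbo")
print(conv.ask("Summarize today's measurement plan in one paragraph."))
conv.save()  # submitted to a background thread; returns a Future

Inside IPython the same conversation is normally driven through the registered cell magic: a cell beginning with `%%chat gpt-4 You are a helpful assistant.` starts a new conversation with that model and system prompt, and `list_chat()` / `show_chat(i)` browse and reload previously saved conversations.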