aient 1.0.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,241 @@
+ from enum import Enum
+ from types import TracebackType
+ from collections import defaultdict
+
+ import json
+ import httpx
+ import msgspec
+ from fake_useragent import UserAgent
+
+
+ class DuckChatException(httpx.HTTPError):
+     """Base exception class for duck_chat."""
+
+
+ class RatelimitException(DuckChatException):
+     """Raised when the rate limit is exceeded during API requests."""
+
+
+ class ConversationLimitException(DuckChatException):
+     """Raised when the conversation limit is reached during API requests to the AI endpoint."""
+
+
+ class ModelType(Enum):
+     claude = "claude-3-haiku-20240307"
+     llama = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
+     gpt4omini = "gpt-4o-mini"
+     mixtral = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+
+     @classmethod
+     def _missing_(cls, value):
+         if isinstance(value, str):
+             # Try an exact match first
+             for member in cls:
+                 if member.value == value:
+                     return member
+
+             # Fall back to a partial match
+             for member in cls:
+                 if value in member.value:
+                     return member
+
+         return None
+
+     def __new__(cls, *args):
+         obj = object.__new__(cls)
+         obj._value_ = args[0]
+         return obj
+
+     def __str__(self):
+         return self.value
+
+     def __repr__(self):
+         return f"ModelType({self.value!r})"
+
+
+ class Role(Enum):
+     user = "user"
+     assistant = "assistant"
+
+
+ class Message(msgspec.Struct):
+     role: Role
+     content: str
+
+     def get(self, key, default=None):
+         try:
+             return getattr(self, key)
+         except AttributeError:
+             return default
+
+     def __getitem__(self, key):
+         return getattr(self, key)
+
+
+ class History(msgspec.Struct):
+     model: ModelType
+     messages: list[Message]
+
+     def add_to_conversation(self, role: Role, message: str) -> None:
+         self.messages.append(Message(role, message))
+
+     def set_model(self, model_name: str) -> None:
+         self.model = ModelType(model_name)
+
+     def __getitem__(self, index: int) -> Message:
+         return self.messages[index]
+
+     def __len__(self) -> int:
+         return len(self.messages)
+
+
+ class UserHistory(msgspec.Struct):
+     user_history: dict[str, History] = msgspec.field(default_factory=dict)
+
+     def add_to_conversation(self, role: Role, message: str, convo_id: str = "default") -> None:
+         if convo_id not in self.user_history:
+             self.user_history[convo_id] = History(model=ModelType.claude, messages=[])
+         self.user_history[convo_id].add_to_conversation(role, message)
+
+     def get_history(self, convo_id: str = "default") -> History:
+         if convo_id not in self.user_history:
+             self.user_history[convo_id] = History(model=ModelType.claude, messages=[])
+         return self.user_history[convo_id]
+
+     def set_model(self, model_name: str, convo_id: str = "default") -> None:
+         self.get_history(convo_id).set_model(model_name)
+
+     def reset(self, convo_id: str = "default") -> None:
+         self.user_history[convo_id] = History(model=ModelType.claude, messages=[])
+
+     def get_all_convo_ids(self) -> list[str]:
+         return list(self.user_history.keys())
+
+     # Newly added method
+     def __getitem__(self, convo_id: str) -> History:
+         return self.get_history(convo_id)
+
+
+ class DuckChat:
+     def __init__(
+         self,
+         model: ModelType = ModelType.claude,
+         client: httpx.AsyncClient | None = None,
+         user_agent: UserAgent | str = UserAgent(min_version=120.0),
+     ) -> None:
+         if isinstance(user_agent, str):
+             self.user_agent = user_agent
+         else:
+             self.user_agent = user_agent.random
+
+         self._client = client or httpx.AsyncClient(
+             headers={
+                 "Host": "duckduckgo.com",
+                 "Accept": "text/event-stream",
+                 "Accept-Language": "en-US,en;q=0.5",
+                 "Accept-Encoding": "gzip, deflate, br",
+                 "Referer": "https://duckduckgo.com/",
+                 "User-Agent": self.user_agent,
+                 "DNT": "1",
+                 "Sec-GPC": "1",
+                 "Connection": "keep-alive",
+                 "Sec-Fetch-Dest": "empty",
+                 "Sec-Fetch-Mode": "cors",
+                 "Sec-Fetch-Site": "same-origin",
+                 "TE": "trailers",
+             }
+         )
+         self.vqd: list[str | None] = []
+         self.history = History(model, [])
+         self.conversation = UserHistory({"default": self.history})
+         self.__encoder = msgspec.json.Encoder()
+         self.__decoder = msgspec.json.Decoder()
+
+         self.tokens_usage = defaultdict(int)
+
+     async def __aenter__(self):
+         return self
+
+     async def __aexit__(
+         self,
+         exc_type: type[BaseException] | None = None,
+         exc_value: BaseException | None = None,
+         traceback: TracebackType | None = None,
+     ) -> None:
+         await self._client.aclose()
+
+     async def add_to_conversation(self, role: Role, message: str, convo_id: str = "default") -> None:
+         self.conversation.add_to_conversation(role, message, convo_id)
+
+     async def get_vqd(self) -> None:
+         """Get a new x-vqd-4 token."""
+         response = await self._client.get(
+             "https://duckduckgo.com/duckchat/v1/status", headers={"x-vqd-accept": "1"}
+         )
+         if response.status_code == 429:
+             try:
+                 err_message = self.__decoder.decode(response.content).get("type", "")
+             except Exception:
+                 raise DuckChatException(response.text)
+             else:
+                 raise RatelimitException(err_message)
+         # Validate the header before appending, so a missing token is actually caught
+         vqd = response.headers.get("x-vqd-4")
+         if not vqd:
+             raise DuckChatException("No x-vqd-4")
+         self.vqd.append(vqd)
+
+     async def process_sse_stream(self, convo_id: str = "default"):
+         # print("self.conversation[convo_id]", self.conversation[convo_id])
+         async with self._client.stream(
+             "POST",
+             "https://duckduckgo.com/duckchat/v1/chat",
+             headers={
+                 "Content-Type": "application/json",
+                 "x-vqd-4": self.vqd[-1],
+             },
+             content=self.__encoder.encode(self.conversation[convo_id]),
+         ) as response:
+             if response.status_code == 400:
+                 content = await response.aread()
+                 print("response.status_code", response.status_code, content)
+             if response.status_code == 429:
+                 raise RatelimitException("Rate limit exceeded")
+
+             async for line in response.aiter_lines():
+                 if line.startswith("data: "):
+                     yield line
+
+     async def ask_stream_async(self, query, convo_id, model, **kwargs):
+         """Get an answer from the chat AI."""
+         if not self.vqd:
+             await self.get_vqd()
+         await self.add_to_conversation(Role.user, query, convo_id)
+         self.conversation.set_model(model, convo_id)
+         full_response = ""
+         async for sse in self.process_sse_stream(convo_id):
+             # removeprefix, not lstrip: lstrip("data: ") strips characters, not the prefix
+             data = sse.removeprefix("data: ")
+             if data == "[DONE]":
+                 break
+             resp: dict = json.loads(data)
+             mess = resp.get("message")
+             if mess:
+                 yield mess
+                 full_response += mess
+         # await self.add_to_conversation(Role.assistant, full_response, convo_id)
+
+     async def reset(self, convo_id: str = "default") -> None:
+         self.conversation.reset(convo_id)
+
+     # async def reask_question(self, num: int) -> str:
+     #     """Get answer from chat AI"""
+
+     #     if num >= len(self.vqd):
+     #         num = len(self.vqd) - 1
+     #     self.vqd = self.vqd[:num]
+
+     #     if not self.history.messages:
+     #         return ""
+
+     #     if not self.vqd:
+     #         await self.get_vqd()
+     #         self.history.messages = [self.history.messages[0]]
+     #     else:
+     #         num = min(num, len(self.vqd))
+     #         self.history.messages = self.history.messages[: (num * 2 - 1)]
+     #     message = await self.get_answer()
+     #     self.add_to_conversation(Role.assistant, message)
+
+     #     return message
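
For orientation: DuckChat is an async context manager, and ask_stream_async is an async generator that fetches an x-vqd-4 token on first use and then yields message chunks parsed from DuckDuckGo's SSE stream. A minimal usage sketch follows; the import path is hypothetical, since this hunk omits the file name.

import asyncio

from aient.models.duck_chat import DuckChat  # hypothetical path; the hunk omits the file name


async def main() -> None:
    async with DuckChat() as chat:
        # First call fetches an x-vqd-4 token, then streams the reply chunks.
        async for chunk in chat.ask_stream_async(
            "Hello!", convo_id="default", model="gpt-4o-mini"
        ):
            print(chunk, end="", flush=True)


asyncio.run(main())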
aient/models/gemini.py ADDED
@@ -0,0 +1,357 @@
+ import os
+ import re
+ import json
+ import copy
+ import requests
+
+ from .base import BaseLLM
+ from ..core.utils import BaseAPI
+ from ..plugins import PLUGINS, get_tools_result_async, function_call_list
+ from ..utils.scripts import safe_get
+
+
+ class gemini(BaseLLM):
+     def __init__(
+         self,
+         api_key: str = None,
+         engine: str = os.environ.get("GPT_ENGINE") or "gemini-1.5-pro-latest",
+         api_url: str = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}",
+         system_prompt: str = "You are Gemini, a large language model trained by Google. Respond conversationally",
+         temperature: float = 0.5,
+         top_p: float = 0.7,
+         timeout: float = 20,
+         use_plugins: bool = True,
+         print_log: bool = False,
+     ):
+         url = api_url.format(model=engine, stream="streamGenerateContent", api_key=os.environ.get("GOOGLE_AI_API_KEY", api_key))
+         super().__init__(api_key, engine, url, system_prompt=system_prompt, timeout=timeout, temperature=temperature, top_p=top_p, use_plugins=use_plugins, print_log=print_log)
+         self.conversation: dict[str, list[dict]] = {
+             "default": [],
+         }
+
+     def add_to_conversation(
+         self,
+         message: str,
+         role: str,
+         convo_id: str = "default",
+         pass_history: int = 9999,
+         total_tokens: int = 0,
+         function_arguments: dict | str = "",
+     ) -> None:
+         """
+         Add a message to the conversation
+         """
+
+         if convo_id not in self.conversation:
+             self.reset(convo_id=convo_id)
+         # print("message", message)
+
+         if function_arguments:
+             self.conversation[convo_id].append(
+                 {
+                     "role": "model",
+                     "parts": [function_arguments]
+                 }
+             )
+             function_call_name = function_arguments["functionCall"]["name"]
+             self.conversation[convo_id].append(
+                 {
+                     "role": "function",
+                     "parts": [{
+                         "functionResponse": {
+                             "name": function_call_name,
+                             "response": {
+                                 "name": function_call_name,
+                                 "content": {
+                                     "result": message,
+                                 }
+                             }
+                         }
+                     }]
+                 }
+             )
+
+         else:
+             if isinstance(message, str):
+                 message = [{"text": message}]
+             self.conversation[convo_id].append({"role": role, "parts": message})
+
+         # Trim old turns: drop a user message together with its model reply,
+         # plus the function response that follows a functionCall reply.
+         history_len = len(self.conversation[convo_id])
+         history = pass_history
+         if pass_history < 2:
+             history = 2
+         while history_len > history:
+             mess_body = self.conversation[convo_id].pop(1)
+             history_len = history_len - 1
+             if mess_body.get("role") == "user":
+                 mess_body = self.conversation[convo_id].pop(1)
+                 history_len = history_len - 1
+                 if safe_get(mess_body, "parts", 0, "functionCall"):
+                     self.conversation[convo_id].pop(1)
+                     history_len = history_len - 1
+
+         if total_tokens:
+             self.tokens_usage[convo_id] += total_tokens
+
+     def reset(self, convo_id: str = "default", system_prompt: str = "You are Gemini, a large language model trained by Google. Respond conversationally") -> None:
+         """
+         Reset the conversation
+         """
+         self.system_prompt = system_prompt or self.system_prompt
+         self.conversation[convo_id] = list()
+
+     def ask_stream(
+         self,
+         prompt: str,
+         role: str = "user",
+         convo_id: str = "default",
+         model: str = "",
+         pass_history: int = 9999,
+         model_max_tokens: int = 4096,
+         system_prompt: str = None,
+         **kwargs,
+     ):
+         self.system_prompt = system_prompt or self.system_prompt
+         if convo_id not in self.conversation or pass_history <= 2:
+             self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
+         self.add_to_conversation(prompt, role, convo_id=convo_id, pass_history=pass_history)
+         # print(self.conversation[convo_id])
+
+         headers = {
+             "Content-Type": "application/json",
+         }
+
+         json_post = {
+             "contents": self.conversation[convo_id] if pass_history else [{
+                 "role": "user",
+                 "parts": [{"text": prompt}]
+             }],
+             "systemInstruction": {"parts": [{"text": self.system_prompt}]},
+             "safetySettings": [
+                 {
+                     "category": "HARM_CATEGORY_HARASSMENT",
+                     "threshold": "BLOCK_NONE"
+                 },
+                 {
+                     "category": "HARM_CATEGORY_HATE_SPEECH",
+                     "threshold": "BLOCK_NONE"
+                 },
+                 {
+                     "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                     "threshold": "BLOCK_NONE"
+                 },
+                 {
+                     "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                     "threshold": "BLOCK_NONE"
+                 }
+             ],
+         }
+
+         url = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}".format(model=model or self.engine, stream="streamGenerateContent", api_key=os.environ.get("GOOGLE_AI_API_KEY", self.api_key) or kwargs.get("api_key"))
+         self.api_url = BaseAPI(url)
+         url = self.api_url.source_api_url
+
+         if self.print_log:
+             print("url", url)
+             # Mask inline base64 payloads before logging the request body
+             replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(json_post)))
+             print(json.dumps(replaced_text, indent=4, ensure_ascii=False))
+
+         try:
+             response = self.session.post(
+                 url,
+                 headers=headers,
+                 json=json_post,
+                 timeout=kwargs.get("timeout", self.timeout),
+                 stream=True,
+             )
+         except requests.exceptions.ConnectionError:
+             print("Connection error: check the server status or your network connection.")
+             return
+         except requests.exceptions.ReadTimeout as e:
+             print(f"Request timed out: check your network connection or increase the timeout. {e}")
+             return
+         except Exception as e:
+             print(f"An unexpected error occurred: {e}")
+             return
+
+         if response.status_code != 200:
+             print(response.text)
+             raise Exception(f"{response.status_code} {response.reason} {response.text}")
+         response_role: str = "model"
+         full_response: str = ""
+         try:
+             for line in response.iter_lines():
+                 if not line:
+                     continue
+                 line = line.decode("utf-8")
+                 if line and '"text": "' in line:
+                     content = line.split('"text": "')[1][:-1]
+                     content = "\n".join(content.split("\\n"))
+                     content = content.encode('utf-8').decode('unicode-escape')
+                     full_response += content
+                     yield content
+         except requests.exceptions.ChunkedEncodingError as e:
+             print("Chunked Encoding Error occurred:", e)
+         except Exception as e:
+             print("An error occurred:", e)
+
+         self.add_to_conversation([{"text": full_response}], response_role, convo_id=convo_id, pass_history=pass_history)
+
+     async def ask_stream_async(
+         self,
+         prompt: str,
+         role: str = "user",
+         convo_id: str = "default",
+         model: str = "",
+         pass_history: int = 9999,
+         system_prompt: str = None,
+         language: str = "English",
+         function_arguments: dict | str = "",
+         total_tokens: int = 0,
+         **kwargs,
+     ):
+         self.system_prompt = system_prompt or self.system_prompt
+         if convo_id not in self.conversation or pass_history <= 2:
+             self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
+         self.add_to_conversation(prompt, role, convo_id=convo_id, total_tokens=total_tokens, function_arguments=function_arguments, pass_history=pass_history)
+         # print(self.conversation[convo_id])
+
+         headers = {
+             "Content-Type": "application/json",
+         }
+
+         json_post = {
+             "contents": self.conversation[convo_id] if pass_history else [{
+                 "role": "user",
+                 "parts": [{"text": prompt}]
+             }],
+             "systemInstruction": {"parts": [{"text": self.system_prompt}]},
+             "safetySettings": [
+                 {
+                     "category": "HARM_CATEGORY_HARASSMENT",
+                     "threshold": "BLOCK_NONE"
+                 },
+                 {
+                     "category": "HARM_CATEGORY_HATE_SPEECH",
+                     "threshold": "BLOCK_NONE"
+                 },
+                 {
+                     "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                     "threshold": "BLOCK_NONE"
+                 },
+                 {
+                     "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                     "threshold": "BLOCK_NONE"
+                 }
+             ],
+         }
+
+         plugins = kwargs.get("plugins", PLUGINS)
+         if any(plugins.values()) and self.use_plugins:
+             tools = {
+                 "tools": [
+                     {
+                         "function_declarations": []
+                     }
+                 ],
+                 "tool_config": {
+                     "function_calling_config": {
+                         "mode": "AUTO",
+                     },
+                 },
+             }
+             json_post.update(copy.deepcopy(tools))
+             for item in plugins.keys():
+                 try:
+                     if plugins[item]:
+                         json_post["tools"][0]["function_declarations"].append(function_call_list[item])
+                 except Exception:
+                     pass
+
+         url = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}".format(model=model or self.engine, stream="streamGenerateContent", api_key=os.environ.get("GOOGLE_AI_API_KEY", self.api_key) or kwargs.get("api_key"))
+         self.api_url = BaseAPI(url)
+         url = self.api_url.source_api_url
+
+         if self.print_log:
+             print("url", url)
+             replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(json_post)))
+             print(json.dumps(replaced_text, indent=4, ensure_ascii=False))
+
+         response_role: str = "model"
+         full_response: str = ""
+         function_full_response: str = "{"
+         function_response: str = ""  # initialized so the recursive call below cannot hit an unbound name
+         need_function_call = False
+         receiving_function_call = False
+         total_tokens = 0
+         try:
+             async with self.aclient.stream(
+                 "post",
+                 url,
+                 headers=headers,
+                 json=json_post,
+                 timeout=kwargs.get("timeout", self.timeout),
+             ) as response:
+                 if response.status_code != 200:
+                     error_content = await response.aread()
+                     error_message = error_content.decode('utf-8')
+                     raise Exception(f"{response.status_code}: {error_message}")
+                 try:
+                     async for line in response.aiter_lines():
+                         if not line:
+                             continue
+                         # print(line)
+                         if line and '"text": "' in line:
+                             content = line.split('"text": "')[1][:-1]
+                             content = "\n".join(content.split("\\n"))
+                             full_response += content
+                             yield content
+
+                         if line and '"totalTokenCount": ' in line:
+                             content = int(line.split('"totalTokenCount": ')[1])
+                             total_tokens = content
+
+                         if line and ('"functionCall": {' in line or receiving_function_call):
+                             receiving_function_call = True
+                             need_function_call = True
+                             if ']' in line:
+                                 receiving_function_call = False
+                                 continue
+
+                             function_full_response += line
+
+                 except requests.exceptions.ChunkedEncodingError as e:
+                     print("Chunked Encoding Error occurred:", e)
+                 except Exception as e:
+                     print("An error occurred:", e)
+
+         except Exception as e:
+             print(f"An unexpected error occurred: {e}")
+             return
+
+         if response.status_code != 200:
+             await response.aread()
+             print(response.text)
+             raise Exception(f"{response.status_code} {response.reason} {response.text}")
+         if self.print_log:
+             print("\n\rtotal_tokens", total_tokens)
+         if need_function_call:
+             # print(function_full_response)
+             function_call = json.loads(function_full_response)
+             print(json.dumps(function_call, indent=4, ensure_ascii=False))
+             function_call_name = function_call["functionCall"]["name"]
+             function_full_response = json.dumps(function_call["functionCall"]["args"])
+             function_call_max_tokens = 32000
+             print("\033[32m function_call", function_call_name, "max token:", function_call_max_tokens, "\033[0m")
+             async for chunk in get_tools_result_async(function_call_name, function_full_response, function_call_max_tokens, model or self.engine, gemini, kwargs.get('api_key', self.api_key), self.api_url, use_plugins=False, model=model or self.engine, add_message=self.add_to_conversation, convo_id=convo_id, language=language):
+                 if "function_response:" in chunk:
+                     function_response = chunk.replace("function_response:", "")
+                 else:
+                     yield chunk
+             response_role = "model"
+             async for chunk in self.ask_stream_async(function_response, response_role, convo_id=convo_id, function_name=function_call_name, total_tokens=total_tokens, model=model or self.engine, function_arguments=function_call, api_key=kwargs.get('api_key', self.api_key), plugins=kwargs.get("plugins", PLUGINS), system_prompt=system_prompt):
+                 yield chunk
+         else:
+             self.add_to_conversation([{"text": full_response}], response_role, convo_id=convo_id, total_tokens=total_tokens, pass_history=pass_history)