aient 1.1.58__py3-none-any.whl → 1.1.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aient/models/groq.py DELETED
@@ -1,234 +0,0 @@
1
- import os
2
- import json
3
- import requests
4
-
5
- from .base import BaseLLM
6
-
7
class groq(BaseLLM):
    """Chat client for Groq's OpenAI-compatible chat-completions endpoint.

    Conversations are stored per ``convo_id`` in ``self.conversation`` as
    OpenAI-style ``{"role": ..., "content": ...}`` dicts. Index 0 of every
    conversation is the system-message slot: ``reset`` creates it and the
    ``ask_stream*`` methods refresh it before each request.
    """

    def __init__(
        self,
        api_key: str = None,
        engine: str = os.environ.get("GPT_ENGINE") or "llama3-70b-8192",
        api_url: str = "https://api.groq.com/openai/v1/chat/completions",
        system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
        temperature: float = 0.5,
        top_p: float = 1,
        timeout: float = 20,
    ):
        """Create a Groq client.

        Args:
            api_key: bearer token sent in the ``Authorization`` header.
            engine: default model name. The ``GPT_ENGINE`` environment
                variable default is read once, when the module is imported.
            api_url: chat-completions endpoint used verbatim for requests.
            system_prompt: default system message.
            temperature: default sampling temperature.
            top_p: default nucleus-sampling value.
            timeout: default request timeout in seconds.
        """
        super().__init__(api_key, engine, api_url, system_prompt, timeout=timeout, temperature=temperature, top_p=top_p)
        # The request methods post to this plain URL string directly.
        self.api_url = api_url

    def add_to_conversation(
        self,
        message: str,
        role: str,
        convo_id: str = "default",
        pass_history: int = 9999,
        total_tokens: int = 0,
    ) -> None:
        """Append a message to ``convo_id`` and trim the history.

        The conversation is reset first if it does not exist yet or if
        ``pass_history <= 2`` (history effectively disabled). Afterwards the
        oldest entries after the system slot (index 1 onward) are dropped
        until at most ``max(pass_history, 2)`` entries remain.

        Args:
            message: message text.
            role: OpenAI role, e.g. ``"user"`` or ``"assistant"``.
            convo_id: conversation key.
            pass_history: maximum number of entries to keep.
            total_tokens: if non-zero, added to ``self.tokens_usage[convo_id]``.
        """
        if convo_id not in self.conversation or pass_history <= 2:
            self.reset(convo_id=convo_id)
        conversation = self.conversation[convo_id]
        conversation.append({"role": role, "content": message})

        limit = max(pass_history, 2)
        # Index 0 is the system slot, so trimming starts at index 1.
        while len(conversation) > limit:
            conversation.pop(1)

        if total_tokens:
            self.tokens_usage[convo_id] += total_tokens

    def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
        """Restart ``convo_id`` so it holds only the system message.

        Args:
            convo_id: conversation key.
            system_prompt: when given, replaces ``self.system_prompt``.
        """
        self.system_prompt = system_prompt or self.system_prompt
        # The ask_stream* methods overwrite index 0 with the system message
        # before sending, so a fresh conversation must already contain that
        # slot; an empty list would let the first user message be overwritten.
        self.conversation[convo_id] = [{"role": "system", "content": self.system_prompt}]

    def _prepare_request(self, prompt, role, convo_id, model, pass_history, model_max_tokens, system_prompt, kwargs):
        """Record ``prompt`` in the history and return the JSON request body."""
        self.system_prompt = system_prompt or self.system_prompt
        if convo_id not in self.conversation or pass_history <= 2:
            self.reset(convo_id=convo_id)
        self.add_to_conversation(prompt, role, convo_id=convo_id, pass_history=pass_history)
        # Refresh the system slot in case the prompt changed since the reset.
        self.conversation[convo_id][0] = {"role": "system", "content": self.system_prompt}

        if pass_history:
            messages = self.conversation[convo_id]
        else:
            messages = [{"role": "user", "content": prompt}]
        return {
            "messages": messages,
            "model": model or self.engine,
            "temperature": kwargs.get("temperature", self.temperature),
            "max_tokens": model_max_tokens,
            "top_p": kwargs.get("top_p", self.top_p),
            "stop": None,
            "stream": True,
        }

    def ask_stream(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        model: str = "",
        pass_history: int = 9999,
        model_max_tokens: int = 1024,
        system_prompt: str = None,
        **kwargs,
    ):
        """Send ``prompt`` using blocking I/O and yield the reply incrementally.

        The API key is taken from ``kwargs["GROQ_API_KEY"]``, falling back to
        ``self.api_key``. ``temperature``, ``top_p`` and ``timeout`` may be
        overridden through ``kwargs``. Transport errors are printed and end the
        generator without yielding.

        Yields:
            str: content fragments from the SSE stream. If the server answers
            with a plain (non-SSE) JSON body, its whole message content is
            yielded once instead.

        Raises:
            BaseException: when the server responds with a non-200 status.
        """
        json_post = self._prepare_request(prompt, role, convo_id, model, pass_history, model_max_tokens, system_prompt, kwargs)
        headers = {
            "Authorization": f"Bearer {kwargs.get('GROQ_API_KEY', self.api_key)}",
            "Content-Type": "application/json",
        }

        try:
            response = self.session.post(
                self.api_url,
                headers=headers,
                json=json_post,
                timeout=kwargs.get("timeout", self.timeout),
                stream=True,
            )
        except (ConnectionError, requests.exceptions.ConnectionError):
            # requests' ConnectionError does not derive from the builtin one.
            print("连接错误,请检查服务器状态或网络连接。")
            return
        except requests.exceptions.ReadTimeout as e:
            print(f"请求超时,请检查网络连接或增加超时时间。{e}")
            return
        except Exception as e:
            print(f"发生了未预料的错误: {e}")
            return

        if response.status_code != 200:
            print(response.text)
            raise BaseException(f"{response.status_code} {response.reason} {response.text}")

        response_role = "assistant"
        full_response = ""
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            line = raw_line.decode("utf-8")
            if not line.startswith("data: "):
                # Not an SSE event: treat it as a complete, non-streamed JSON reply.
                print(line)
                full_response = json.loads(line)["choices"][0]["message"]["content"]
                yield full_response
                break
            line = line[6:]
            if line == "[DONE]":
                break
            resp = json.loads(line)
            choices = resp.get("choices")
            if not choices:
                continue
            delta = choices[0].get("delta")
            if not delta:
                continue
            if "role" in delta:
                response_role = delta["role"]
            content = delta.get("content")
            if content:
                full_response += content
                yield content
        self.add_to_conversation(full_response, response_role, convo_id=convo_id, pass_history=pass_history)

    async def ask_stream_async(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        model: str = "",
        pass_history: int = 9999,
        model_max_tokens: int = 1024,
        system_prompt: str = None,
        **kwargs,
    ):
        """Async counterpart of ``ask_stream`` using ``self.aclient``.

        The API key is the ``GROQ_API_KEY`` environment variable if set, else
        ``self.api_key``; if that result is empty, ``kwargs["api_key"]`` is used.
        Unexpected errors are printed with a traceback, and the generator then
        ends without recording the reply.

        Yields:
            str: content fragments, or one complete message for a non-SSE reply.

        Raises:
            BaseException: when the server responds with a non-200 status.
        """
        json_post = self._prepare_request(prompt, role, convo_id, model, pass_history, model_max_tokens, system_prompt, kwargs)
        headers = {
            "Authorization": f"Bearer {os.environ.get('GROQ_API_KEY', self.api_key) or kwargs.get('api_key')}",
            "Content-Type": "application/json",
        }

        response_role = "assistant"
        full_response = ""
        try:
            async with self.aclient.stream(
                "post",
                self.api_url,
                headers=headers,
                json=json_post,
                timeout=kwargs.get("timeout", self.timeout),
            ) as response:
                if response.status_code != 200:
                    await response.aread()
                    print(response.text)
                    # httpx responses expose the status text as ``reason_phrase``
                    # (there is no ``.reason``); an AttributeError here would be
                    # swallowed by the handler below and hide the HTTP error.
                    reason = getattr(response, "reason_phrase", "")
                    raise BaseException(f"{response.status_code} {reason} {response.text}")
                async for line in response.aiter_lines():
                    if not line:
                        continue
                    if not line.startswith("data: "):
                        # Not an SSE event: treat it as a complete, non-streamed JSON reply.
                        full_response = json.loads(line)["choices"][0]["message"]["content"]
                        yield full_response
                        break
                    # Slice the prefix off; str.lstrip would strip a *character set*.
                    line = line[6:]
                    if line == "[DONE]":
                        break
                    resp = json.loads(line)
                    choices = resp.get("choices")
                    if not choices:
                        continue
                    delta = choices[0].get("delta")
                    if not delta:
                        continue
                    if "role" in delta:
                        response_role = delta["role"]
                    content = delta.get("content")
                    if content:
                        full_response += content
                        yield content
        except Exception as e:
            print(f"发生了未预料的错误: {e}")
            import traceback
            traceback.print_exc()
            return

        self.add_to_conversation(full_response, response_role, convo_id=convo_id, pass_history=pass_history)
aient/models/vertex.py DELETED
@@ -1,420 +0,0 @@
1
- import os
2
- import re
3
- import json
4
- import requests
5
-
6
-
7
- from .base import BaseLLM
8
- from ..core.utils import BaseAPI
9
-
10
- import copy
11
- from ..plugins import PLUGINS, get_tools_result_async, function_call_list
12
- from ..utils.scripts import safe_get
13
-
14
- import time
15
- import httpx
16
- import base64
17
- from cryptography.hazmat.primitives import hashes
18
- from cryptography.hazmat.primitives.asymmetric import padding
19
- from cryptography.hazmat.primitives.serialization import load_pem_private_key
20
-
21
def create_jwt(client_email, private_key):
    """Build an RS256-signed JWT assertion for Google's OAuth2 token endpoint.

    The claims request the ``cloud-platform`` scope, target
    ``https://oauth2.googleapis.com/token`` and expire 3600 seconds after
    issuance.

    Args:
        client_email: service-account e-mail placed in the ``iss`` claim.
        private_key: unencrypted PEM private key as a ``str``.

    Returns:
        The compact ``header.payload.signature`` token as a ``str``.
    """
    def b64url(raw):
        # JWTs use base64url without "=" padding.
        return base64.urlsafe_b64encode(raw).rstrip(b'=')

    header_json = json.dumps({"alg": "RS256", "typ": "JWT"}).encode()
    issued_at = int(time.time())
    claims_json = json.dumps({
        "iss": client_email,
        "scope": "https://www.googleapis.com/auth/cloud-platform",
        "aud": "https://oauth2.googleapis.com/token",
        "exp": issued_at + 3600,
        "iat": issued_at,
    }).encode()

    signing_input = b64url(header_json) + b'.' + b64url(claims_json)
    signer = load_pem_private_key(private_key.encode(), password=None)
    signature = signer.sign(signing_input, padding.PKCS1v15(), hashes.SHA256())
    return (signing_input + b'.' + b64url(signature)).decode()
55
-
56
def get_access_token(client_email, private_key):
    """Exchange a service-account JWT for a Google OAuth2 access token.

    Args:
        client_email: service-account e-mail, used as the JWT issuer.
        private_key: PEM private key used to sign the JWT.

    Returns:
        The ``access_token`` field of the token endpoint's JSON response.

    Raises:
        httpx.HTTPStatusError: if the token endpoint returns an error status.
    """
    form = {
        "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
        "assertion": create_jwt(client_email, private_key),
    }
    with httpx.Client() as http:
        token_response = http.post(
            "https://oauth2.googleapis.com/token",
            data=form,
            headers={'Content-Type': "application/x-www-form-urlencoded"},
        )
        token_response.raise_for_status()
        return token_response.json()["access_token"]
70
-
71
- # https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference?hl=zh-cn#python
72
class vertex(BaseLLM):
    """Gemini chat client for Google Cloud Vertex AI.

    Conversations are kept per ``convo_id`` in ``self.conversation`` as Gemini
    ``contents`` entries (``{"role": ..., "parts": [...]}``). The system prompt
    is not stored in the history; it is sent with every request as a separate
    system-instruction field.
    """

    def __init__(
        self,
        api_key: str = None,
        engine: str = os.environ.get("GPT_ENGINE") or "gemini-1.5-pro-latest",
        api_url: str = "https://us-central1-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/us-central1/publishers/google/models/{MODEL_ID}:{stream}",
        system_prompt: str = "You are Gemini, a large language model trained by Google. Respond conversationally",
        project_id: str = os.environ.get("VERTEX_PROJECT_ID", None),
        temperature: float = 0.5,
        top_p: float = 0.7,
        timeout: float = 20,
        use_plugins: bool = True,
        print_log: bool = False,
    ):
        """Create a Vertex AI client.

        ``api_url`` is a template with ``{PROJECT_ID}``, ``{MODEL_ID}`` and
        ``{stream}`` placeholders, filled in here. The ``VERTEX_PROJECT_ID``
        environment variable, when set, takes precedence over ``project_id``.
        The ``GPT_ENGINE``/``VERTEX_PROJECT_ID`` argument defaults are read
        once, at import time.
        """
        url = api_url.format(
            PROJECT_ID=os.environ.get("VERTEX_PROJECT_ID", project_id),
            MODEL_ID=engine,
            stream="streamGenerateContent",
        )
        super().__init__(api_key, engine, url, system_prompt=system_prompt, timeout=timeout, temperature=temperature, top_p=top_p, use_plugins=use_plugins, print_log=print_log)
        self.conversation: dict[str, list[dict]] = {
            "default": [],
        }

    def add_to_conversation(
        self,
        message: str,
        role: str,
        convo_id: str = "default",
        pass_history: int = 9999,
        total_tokens: int = 0,
        function_arguments: str = "",
    ) -> None:
        """Append a turn to ``convo_id`` and trim old turns.

        Args:
            message: text (``str``) or an already-built list of Gemini parts.
                When ``function_arguments`` is given, this is instead the tool
                result reported back to the model.
            role: Gemini role for a plain message (e.g. ``"user"``/``"model"``).
            convo_id: conversation key.
            pass_history: maximum number of entries to keep (at least 2).
                Values ``<= 2`` restart the conversation before appending.
            total_tokens: if non-zero, added to ``self.tokens_usage[convo_id]``.
            function_arguments: despite the ``str`` annotation, the parsed
                ``{"functionCall": {...}}`` part dict. When truthy, a model
                ``functionCall`` turn plus a matching ``functionResponse`` turn
                are appended instead of a plain message.
        """
        if convo_id not in self.conversation or pass_history <= 2:
            # Pass the current prompt explicitly: reset()'s default argument is
            # the stock Gemini prompt and would otherwise replace a custom one.
            self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
        conversation = self.conversation[convo_id]

        if function_arguments:
            conversation.append({"role": "model", "parts": [function_arguments]})
            function_call_name = function_arguments["functionCall"]["name"]
            conversation.append({
                "role": "function",
                "parts": [{
                    "functionResponse": {
                        "name": function_call_name,
                        "response": {
                            "name": function_call_name,
                            "content": {
                                "result": message,
                            },
                        },
                    },
                }],
            })
        else:
            if isinstance(message, str):
                message = [{"text": message}]
            conversation.append({"role": role, "parts": message})

        history_len = len(conversation)
        history = max(pass_history, 2)
        # Drop whole exchanges after index 0: a user turn, the model turn that
        # answered it, and, if that model turn was a function call, the
        # function response that followed it.
        while history_len > history:
            mess_body = conversation.pop(1)
            history_len -= 1
            if mess_body.get("role") == "user":
                mess_body = conversation.pop(1)
                history_len -= 1
                if safe_get(mess_body, "parts", 0, "functionCall"):
                    conversation.pop(1)
                    history_len -= 1

        if total_tokens:
            self.tokens_usage[convo_id] += total_tokens

    def reset(self, convo_id: str = "default", system_prompt: str = "You are Gemini, a large language model trained by Google. Respond conversationally") -> None:
        """Empty ``convo_id`` and set the system prompt.

        Note that the default ``system_prompt`` argument is the stock Gemini
        prompt, so calling ``reset()`` without it replaces any custom prompt.
        """
        self.system_prompt = system_prompt or self.system_prompt
        self.conversation[convo_id] = list()

    @staticmethod
    def _single_user_turn(prompt):
        """Wrap ``prompt`` (text or a list of parts) as one Gemini user turn."""
        parts = [{"text": prompt}] if isinstance(prompt, str) else prompt
        return {"role": "user", "parts": parts}

    @staticmethod
    def _decode_text_fragment(fragment: str) -> str:
        """Decode the body of a JSON string literal cut out of a streamed line.

        Uses JSON decoding so that escapes such as ``\\"`` and ``\\uXXXX`` are
        handled. If the fragment is not a complete string body, it falls back
        to unescaping newlines only.
        """
        try:
            return json.loads(f'"{fragment}"')
        except ValueError:
            return "\n".join(fragment.split("\\n"))

    def _log_request(self, json_post) -> None:
        """Pretty-print the request body with inline base64 payloads masked."""
        replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(json_post)))
        print(json.dumps(replaced_text, indent=4, ensure_ascii=False))

    def ask_stream(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        model: str = "",
        pass_history: int = 9999,
        model_max_tokens: int = 4096,
        systemprompt: str = None,
        **kwargs,
    ):
        """Send ``prompt`` using blocking I/O and yield reply text fragments.

        ``model_max_tokens`` is accepted for interface compatibility but is
        unused. Transport errors are printed and end the generator.
        NOTE(review): unlike ``ask_stream_async``, no ``Authorization`` header
        is sent here. Confirm which endpoint ``self.api_url`` points to.

        Raises:
            BaseException: when the server responds with a non-200 status.
        """
        self.system_prompt = systemprompt or self.system_prompt
        if convo_id not in self.conversation or pass_history <= 2:
            self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
        self.add_to_conversation(prompt, role, convo_id=convo_id, pass_history=pass_history)

        headers = {
            "Content-Type": "application/json",
        }
        json_post = {
            # Gemini turns carry "parts"; an OpenAI-style "content" key is invalid here.
            "contents": self.conversation[convo_id] if pass_history else [self._single_user_turn(prompt)],
            "systemInstruction": {"parts": [{"text": self.system_prompt}]},
            "safetySettings": [
                {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
                {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
                {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
            ],
        }
        if self.print_log:
            self._log_request(json_post)

        url = self.api_url.format(model=model or self.engine, stream="streamGenerateContent", api_key=self.api_key)

        try:
            response = self.session.post(
                url,
                headers=headers,
                json=json_post,
                timeout=kwargs.get("timeout", self.timeout),
                stream=True,
            )
        except (ConnectionError, requests.exceptions.ConnectionError):
            # requests' ConnectionError does not derive from the builtin one.
            print("连接错误,请检查服务器状态或网络连接。")
            return
        except requests.exceptions.ReadTimeout as e:
            print(f"请求超时,请检查网络连接或增加超时时间。{e}")
            return
        except Exception as e:
            print(f"发生了未预料的错误: {e}")
            return

        if response.status_code != 200:
            print(response.text)
            raise BaseException(f"{response.status_code} {response.reason} {response.text}")

        response_role = "model"
        full_response = ""
        try:
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                line = raw_line.decode("utf-8")
                # The streamed JSON is pretty-printed; pick out the "text" lines.
                if '"text": "' in line:
                    content = self._decode_text_fragment(line.split('"text": "')[1][:-1])
                    full_response += content
                    yield content
        except requests.exceptions.ChunkedEncodingError as e:
            print("Chunked Encoding Error occurred:", e)
        except Exception as e:
            print("An error occurred:", e)

        self.add_to_conversation([{"text": full_response}], response_role, convo_id=convo_id, pass_history=pass_history)

    async def ask_stream_async(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        model: str = "",
        pass_history: int = 9999,
        systemprompt: str = None,
        language: str = "English",
        function_arguments: str = "",
        total_tokens: int = 0,
        **kwargs,
    ):
        """Send ``prompt`` to Vertex AI and yield the reply asynchronously.

        The request is authorized with a service-account token built from the
        ``VERTEX_CLIENT_EMAIL``/``VERTEX_PRIVATE_KEY`` environment variables.
        If the model answers with a function call, the matching plugin is run
        via ``get_tools_result_async``. Its output chunks are yielded, and this
        method then recurses with the tool result to obtain the final answer.

        Args:
            language: forwarded to the plugin runner.
            function_arguments: the previous ``functionCall`` part when
                reporting a tool result (see ``add_to_conversation``).
            total_tokens: usage to record for the incoming turn.
            **kwargs: ``plugins``, ``api_key`` and ``timeout`` overrides.

        Yields:
            str: text fragments and plugin output chunks.

        Raises:
            BaseException: when the server responds with a non-200 status.
        """
        self.system_prompt = systemprompt or self.system_prompt
        if convo_id not in self.conversation or pass_history <= 2:
            self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
        self.add_to_conversation(prompt, role, convo_id=convo_id, total_tokens=total_tokens, function_arguments=function_arguments, pass_history=pass_history)

        # NOTE(review): get_access_token performs a blocking HTTP request on
        # every call, which stalls the event loop; consider caching the token.
        client_email = os.environ.get("VERTEX_CLIENT_EMAIL")
        private_key = os.environ.get("VERTEX_PRIVATE_KEY")
        access_token = get_access_token(client_email, private_key)
        headers = {
            'Authorization': f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

        json_post = {
            # Gemini turns carry "parts"; an OpenAI-style "content" key is invalid here.
            "contents": self.conversation[convo_id] if pass_history else [self._single_user_turn(prompt)],
            "system_instruction": {"parts": [{"text": self.system_prompt}]},
            "generationConfig": {
                "temperature": self.temperature,
                "max_output_tokens": 8192,
                "top_k": 40,
                "top_p": 0.95,
            },
        }

        plugins = kwargs.get("plugins", PLUGINS)
        if not all(value == False for value in plugins.values()) and self.use_plugins:
            declarations = []
            for item, enabled in plugins.items():
                if not enabled:
                    continue
                try:
                    declarations.append(function_call_list[item])
                except KeyError:
                    # Enabled plugin without a declaration: skip it, as before.
                    pass
            json_post["tools"] = [{"function_declarations": declarations}]
            json_post["tool_config"] = {"function_calling_config": {"mode": "AUTO"}}

        if self.print_log:
            self._log_request(json_post)

        model_name = model or self.engine
        url = "https://us-central1-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/us-central1/publishers/google/models/{MODEL_ID}:{stream}".format(
            PROJECT_ID=os.environ.get("VERTEX_PROJECT_ID"),
            MODEL_ID=model_name,
            stream="streamGenerateContent",
        )
        self.api_url = BaseAPI(url)
        url = self.api_url.source_api_url

        response_role = "model"
        full_response = ""
        # The functionCall object is reassembled line by line into this buffer.
        function_full_response = "{"
        need_function_call = False
        receiving_function_call = False
        total_tokens = 0
        try:
            async with self.aclient.stream(
                "post",
                url,
                headers=headers,
                json=json_post,
                timeout=kwargs.get("timeout", self.timeout),
            ) as response:
                if response.status_code != 200:
                    error_content = await response.aread()
                    error_message = error_content.decode('utf-8')
                    raise BaseException(f"{response.status_code}: {error_message}")
                try:
                    async for line in response.aiter_lines():
                        if not line:
                            continue
                        if '"text": "' in line:
                            content = self._decode_text_fragment(line.split('"text": "')[1][:-1])
                            full_response += content
                            yield content

                        if '"totalTokenCount": ' in line:
                            # Tolerate a trailing comma when more usage fields follow.
                            total_tokens = int(line.split('"totalTokenCount": ')[1].strip().rstrip(","))

                        if '"functionCall": {' in line or receiving_function_call:
                            receiving_function_call = True
                            need_function_call = True
                            if ']' in line:
                                # End of the "parts" array: the call object is complete.
                                receiving_function_call = False
                                continue
                            function_full_response += line
                except requests.exceptions.ChunkedEncodingError as e:
                    print("Chunked Encoding Error occurred:", e)
                except Exception as e:
                    print("An error occurred:", e)
        except Exception as e:
            print(f"发生了未预料的错误: {e}")
            return

        if self.print_log:
            print("\n\rtotal_tokens", total_tokens)

        if not need_function_call:
            self.add_to_conversation([{"text": full_response}], response_role, convo_id=convo_id, total_tokens=total_tokens, pass_history=pass_history)
            return

        function_call = json.loads(function_full_response)
        print(json.dumps(function_call, indent=4, ensure_ascii=False))
        function_call_name = function_call["functionCall"]["name"]
        function_full_response = json.dumps(function_call["functionCall"]["args"])
        function_call_max_tokens = 32000
        print("\033[32m function_call", function_call_name, "max token:", function_call_max_tokens, "\033[0m")
        # Default in case the plugin never reports a "function_response:" chunk.
        function_response = ""
        async for chunk in get_tools_result_async(function_call_name, function_full_response, function_call_max_tokens, model or self.engine, vertex, kwargs.get('api_key', self.api_key), self.api_url, use_plugins=False, model=model or self.engine, add_message=self.add_to_conversation, convo_id=convo_id, language=language):
            if "function_response:" in chunk:
                function_response = chunk.replace("function_response:", "")
            else:
                yield chunk
        response_role = "model"
        async for chunk in self.ask_stream_async(function_response, response_role, convo_id=convo_id, function_name=function_call_name, total_tokens=total_tokens, model=model or self.engine, function_arguments=function_call, api_key=kwargs.get('api_key', self.api_key), plugins=kwargs.get("plugins", PLUGINS), language=language):
            yield chunk
File without changes