aient 1.0.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aient/models/groq.py ADDED
@@ -0,0 +1,268 @@
+ import os
+ import json
+ import requests
+ import tiktoken
+
+ from .base import BaseLLM
+
+ class groq(BaseLLM):
+     def __init__(
+         self,
+         api_key: str = None,
+         engine: str = os.environ.get("GPT_ENGINE") or "llama3-70b-8192",
+         api_url: str = "https://api.groq.com/openai/v1/chat/completions",
+         system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
+         temperature: float = 0.5,
+         top_p: float = 1,
+         timeout: float = 20,
+     ):
+         super().__init__(api_key, engine, api_url, system_prompt, timeout=timeout, temperature=temperature, top_p=top_p)
+         self.api_url = api_url
+
+     def add_to_conversation(
+         self,
+         message: str,
+         role: str,
+         convo_id: str = "default",
+         pass_history: int = 9999,
+         total_tokens: int = 0,
+     ) -> None:
+         """
+         Add a message to the conversation
+         """
+         if convo_id not in self.conversation or pass_history <= 2:
+             self.reset(convo_id=convo_id)
+         self.conversation[convo_id].append({"role": role, "content": message})
+
+         # Keep at most `pass_history` messages (floor of 2), always preserving
+         # the system message at index 0 by popping the oldest entry after it
+         history_len = len(self.conversation[convo_id])
+         history = pass_history
+         if pass_history < 2:
+             history = 2
+         while history_len > history:
+             self.conversation[convo_id].pop(1)
+             history_len = history_len - 1
+
+         if total_tokens:
+             self.tokens_usage[convo_id] += total_tokens
+
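A minimal sketch of the trimming rule above (a hypothetical illustration, not part of the package), assuming BaseLLM keeps the system message at index 0; because of the floor of 2, the system message plus the latest turn always survive:

    # hypothetical illustration of the pop(1) loop with pass_history=3
    convo = [
        {"role": "system", "content": "sys"},
        {"role": "user", "content": "q1"},
        {"role": "assistant", "content": "a1"},
        {"role": "user", "content": "q2"},
    ]
    history = max(3, 2)
    while len(convo) > history:
        convo.pop(1)   # drop the oldest non-system message
    # convo is now: system, a1, q2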
+     def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
+         """
+         Reset the conversation
+         """
+         self.conversation[convo_id] = list()
+         self.system_prompt = system_prompt or self.system_prompt
+
+     def __truncate_conversation(self, convo_id: str = "default") -> None:
+         """
+         Truncate the conversation
+         """
+         while True:
+             if (
+                 self.get_token_count(convo_id) > self.truncate_limit
+                 and len(self.conversation[convo_id]) > 1
+             ):
+                 # Don't remove the first message
+                 self.conversation[convo_id].pop(1)
+             else:
+                 break
+
+     def get_token_count(self, convo_id: str = "default") -> int:
+         """
+         Get token count
+         """
+         # tiktoken.model.MODEL_TO_ENCODING["mixtral-8x7b-32768"] = "cl100k_base"
+         encoding = tiktoken.get_encoding("cl100k_base")
+
+         num_tokens = 0
+         for message in self.conversation[convo_id]:
+             # every message follows <im_start>{role/name}\n{content}<im_end>\n
+             num_tokens += 5
+             for key, value in message.items():
+                 if value:
+                     num_tokens += len(encoding.encode(value))
+                 if key == "name":  # if there's a name, the role is omitted
+                     num_tokens += 5  # count the name's framing overhead in its place
+         num_tokens += 5  # every reply is primed with <im_start>assistant
+         return num_tokens
+
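To make the estimate concrete, a rough worked example of the same formula (the per-message overhead of 5 is this module's approximation, not a documented constant):

    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")
    messages = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hi"},
    ]
    total = sum(
        5 + len(enc.encode(m["role"])) + len(enc.encode(m["content"]))
        for m in messages
    ) + 5  # every reply is primed with <im_start>assistant
    print(total)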
+     def ask_stream(
+         self,
+         prompt: str,
+         role: str = "user",
+         convo_id: str = "default",
+         model: str = "",
+         pass_history: int = 9999,
+         model_max_tokens: int = 1024,
+         system_prompt: str = None,
+         **kwargs,
+     ):
+         self.system_prompt = system_prompt or self.system_prompt
+         if convo_id not in self.conversation or pass_history <= 2:
+             self.reset(convo_id=convo_id)
+         self.add_to_conversation(prompt, role, convo_id=convo_id, pass_history=pass_history)
+         # self.__truncate_conversation(convo_id=convo_id)
+         # print(self.conversation[convo_id])
+
+         url = self.api_url
+         headers = {
+             "Authorization": f"Bearer {kwargs.get('GROQ_API_KEY', self.api_key)}",
+             "Content-Type": "application/json",
+         }
+
+         self.conversation[convo_id][0] = {"role": "system", "content": self.system_prompt}
+         json_post = {
+             "messages": self.conversation[convo_id] if pass_history else [{
+                 "role": "user",
+                 "content": prompt
+             }],
+             "model": model or self.engine,
+             "temperature": kwargs.get("temperature", self.temperature),
+             "max_tokens": model_max_tokens,
+             "top_p": kwargs.get("top_p", self.top_p),
+             "stop": None,
+             "stream": True,
+         }
+         # print("json_post", json_post)
+         # print(os.environ.get("GPT_ENGINE"), model, self.engine)
+
+         try:
+             response = self.session.post(
+                 url,
+                 headers=headers,
+                 json=json_post,
+                 timeout=kwargs.get("timeout", self.timeout),
+                 stream=True,
+             )
+         except requests.exceptions.ConnectionError:
+             print("Connection error: check the server status or your network connection.")
+             return
+         except requests.exceptions.ReadTimeout as e:
+             print(f"Request timed out: check your network connection or increase the timeout. {e}")
+             return
+         except Exception as e:
+             print(f"An unexpected error occurred: {e}")
+             return
+
+         if response.status_code != 200:
+             print(response.text)
+             raise BaseException(f"{response.status_code} {response.reason} {response.text}")
+         response_role: str = "assistant"
+         full_response: str = ""
+         for line in response.iter_lines():
+             if not line:
+                 continue
+             # Remove the SSE "data: " prefix
+             # print(line.decode("utf-8"))
+             if line.decode("utf-8")[:6] == "data: ":
+                 line = line.decode("utf-8")[6:]
+             else:
+                 # Not an SSE chunk: treat it as a complete (non-streamed) response
+                 print(line.decode("utf-8"))
+                 full_response = json.loads(line.decode("utf-8"))["choices"][0]["message"]["content"]
+                 yield full_response
+                 break
+             if line == "[DONE]":
+                 break
+             resp: dict = json.loads(line)
+             # print("resp", resp)
+             choices = resp.get("choices")
+             if not choices:
+                 continue
+             delta = choices[0].get("delta")
+             if not delta:
+                 continue
+             if "role" in delta:
+                 response_role = delta["role"]
+             if "content" in delta and delta["content"]:
+                 content = delta["content"]
+                 full_response += content
+                 yield content
+         self.add_to_conversation(full_response, response_role, convo_id=convo_id, pass_history=pass_history)
+
+     async def ask_stream_async(
+         self,
+         prompt: str,
+         role: str = "user",
+         convo_id: str = "default",
+         model: str = "",
+         pass_history: int = 9999,
+         model_max_tokens: int = 1024,
+         system_prompt: str = None,
+         **kwargs,
+     ):
+         self.system_prompt = system_prompt or self.system_prompt
+         if convo_id not in self.conversation or pass_history <= 2:
+             self.reset(convo_id=convo_id)
+         self.add_to_conversation(prompt, role, convo_id=convo_id, pass_history=pass_history)
+         # self.__truncate_conversation(convo_id=convo_id)
+         # print(self.conversation[convo_id])
+
+         url = self.api_url
+         headers = {
+             "Authorization": f"Bearer {os.environ.get('GROQ_API_KEY', self.api_key) or kwargs.get('api_key')}",
+             "Content-Type": "application/json",
+         }
+
+         self.conversation[convo_id][0] = {"role": "system", "content": self.system_prompt}
+         json_post = {
+             "messages": self.conversation[convo_id] if pass_history else [{
+                 "role": "user",
+                 "content": prompt
+             }],
+             "model": model or self.engine,
+             "temperature": kwargs.get("temperature", self.temperature),
+             "max_tokens": model_max_tokens,
+             "top_p": kwargs.get("top_p", self.top_p),
+             "stop": None,
+             "stream": True,
+         }
+         # print("json_post", json_post)
+         # print(os.environ.get("GPT_ENGINE"), model, self.engine)
+
+         response_role: str = "assistant"
+         full_response: str = ""
+         try:
+             async with self.aclient.stream(
+                 "post",
+                 url,
+                 headers=headers,
+                 json=json_post,
+                 timeout=kwargs.get("timeout", self.timeout),
+             ) as response:
+                 if response.status_code != 200:
+                     await response.aread()
+                     print(response.text)
+                     raise BaseException(f"{response.status_code} {response.reason_phrase} {response.text}")
+                 async for line in response.aiter_lines():
+                     if not line:
+                         continue
+                     # Remove the SSE "data: " prefix (slice, not lstrip: lstrip
+                     # strips a character set and can eat the start of the payload)
+                     # print(line)
+                     if line[:6] == "data: ":
+                         line = line[6:]
+                     else:
+                         # Not an SSE chunk: treat it as a complete (non-streamed) response
+                         full_response = json.loads(line)["choices"][0]["message"]["content"]
+                         yield full_response
+                         break
+                     if line == "[DONE]":
+                         break
+                     resp: dict = json.loads(line)
+                     # print("resp", resp)
+                     choices = resp.get("choices")
+                     if not choices:
+                         continue
+                     delta = choices[0].get("delta")
+                     if not delta:
+                         continue
+                     if "role" in delta:
+                         response_role = delta["role"]
+                     if "content" in delta and delta["content"]:
+                         content = delta["content"]
+                         full_response += content
+                         yield content
+         except Exception as e:
+             print(f"An unexpected error occurred: {e}")
+             import traceback
+             traceback.print_exc()
+             return
+
+         self.add_to_conversation(full_response, response_role, convo_id=convo_id, pass_history=pass_history)
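A minimal usage sketch for the class above (assuming a GROQ_API_KEY environment variable and that BaseLLM wires up self.session; the model name is illustrative):

    import os
    from aient.models.groq import groq

    bot = groq(api_key=os.environ["GROQ_API_KEY"], engine="llama3-70b-8192")
    for chunk in bot.ask_stream("Explain SSE in one sentence."):
        print(chunk, end="", flush=True)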
aient/models/vertex.py ADDED
@@ -0,0 +1,420 @@
+ import os
+ import re
+ import json
+ import requests
+
+
+ from .base import BaseLLM
+ from ..core.utils import BaseAPI
+
+ import copy
+ from ..plugins import PLUGINS, get_tools_result_async, function_call_list
+ from ..utils.scripts import safe_get
+
+ import time
+ import httpx
+ import base64
+ from cryptography.hazmat.primitives import hashes
+ from cryptography.hazmat.primitives.asymmetric import padding
+ from cryptography.hazmat.primitives.serialization import load_pem_private_key
+
+ def create_jwt(client_email, private_key):
+     # JWT Header
+     header = json.dumps({
+         "alg": "RS256",
+         "typ": "JWT"
+     }).encode()
+
+     # JWT Payload
+     now = int(time.time())
+     payload = json.dumps({
+         "iss": client_email,
+         "scope": "https://www.googleapis.com/auth/cloud-platform",
+         "aud": "https://oauth2.googleapis.com/token",
+         "exp": now + 3600,
+         "iat": now
+     }).encode()
+
+     # Encode header and payload
+     segments = [
+         base64.urlsafe_b64encode(header).rstrip(b'='),
+         base64.urlsafe_b64encode(payload).rstrip(b'=')
+     ]
+
+     # Create signature
+     signing_input = b'.'.join(segments)
+     private_key = load_pem_private_key(private_key.encode(), password=None)
+     signature = private_key.sign(
+         signing_input,
+         padding.PKCS1v15(),
+         hashes.SHA256()
+     )
+
+     segments.append(base64.urlsafe_b64encode(signature).rstrip(b'='))
+     return b'.'.join(segments).decode()
+
+ def get_access_token(client_email, private_key):
+     jwt = create_jwt(client_email, private_key)
+
+     with httpx.Client() as client:
+         response = client.post(
+             "https://oauth2.googleapis.com/token",
+             data={
+                 "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
+                 "assertion": jwt
+             },
+             headers={'Content-Type': "application/x-www-form-urlencoded"}
+         )
+         response.raise_for_status()
+         return response.json()["access_token"]
+
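A short sketch of driving this OAuth flow directly (the env var names match the ones ask_stream_async reads below; the PEM private key must contain real newlines):

    import os

    token = get_access_token(
        os.environ["VERTEX_CLIENT_EMAIL"],
        os.environ["VERTEX_PRIVATE_KEY"],
    )
    # The token is valid for up to one hour (exp = iat + 3600 in the JWT above)
    headers = {"Authorization": f"Bearer {token}"}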
+ # https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference?hl=zh-cn#python
+ class vertex(BaseLLM):
+     def __init__(
+         self,
+         api_key: str = None,
+         engine: str = os.environ.get("GPT_ENGINE") or "gemini-1.5-pro-latest",
+         api_url: str = "https://us-central1-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/us-central1/publishers/google/models/{MODEL_ID}:{stream}",
+         system_prompt: str = "You are Gemini, a large language model trained by Google. Respond conversationally",
+         project_id: str = os.environ.get("VERTEX_PROJECT_ID", None),
+         temperature: float = 0.5,
+         top_p: float = 0.7,
+         timeout: float = 20,
+         use_plugins: bool = True,
+         print_log: bool = False,
+     ):
+         url = api_url.format(PROJECT_ID=os.environ.get("VERTEX_PROJECT_ID", project_id), MODEL_ID=engine, stream="streamGenerateContent")
+         super().__init__(api_key, engine, url, system_prompt=system_prompt, timeout=timeout, temperature=temperature, top_p=top_p, use_plugins=use_plugins, print_log=print_log)
+         self.conversation: dict[str, list[dict]] = {
+             "default": [],
+         }
+
+     def add_to_conversation(
+         self,
+         message: str,
+         role: str,
+         convo_id: str = "default",
+         pass_history: int = 9999,
+         total_tokens: int = 0,
+         function_arguments: dict | str = "",
+     ) -> None:
+         """
+         Add a message to the conversation
+         """
+
+         if convo_id not in self.conversation or pass_history <= 2:
+             self.reset(convo_id=convo_id)
+         # print("message", message)
+
+         if function_arguments:
+             # Record the model's function call, then the matching function response
+             self.conversation[convo_id].append(
+                 {
+                     "role": "model",
+                     "parts": [function_arguments]
+                 }
+             )
+             function_call_name = function_arguments["functionCall"]["name"]
+             self.conversation[convo_id].append(
+                 {
+                     "role": "function",
+                     "parts": [{
+                         "functionResponse": {
+                             "name": function_call_name,
+                             "response": {
+                                 "name": function_call_name,
+                                 "content": {
+                                     "result": message,
+                                 }
+                             }
+                         }
+                     }]
+                 }
+             )
+
+         else:
+             if isinstance(message, str):
+                 message = [{"text": message}]
+             self.conversation[convo_id].append({"role": role, "parts": message})
+
+         # Trim history, dropping a user turn together with the model reply
+         # (and any function-call exchange) that followed it
+         history_len = len(self.conversation[convo_id])
+         history = pass_history
+         if pass_history < 2:
+             history = 2
+         while history_len > history:
+             mess_body = self.conversation[convo_id].pop(1)
+             history_len = history_len - 1
+             if mess_body.get("role") == "user":
+                 mess_body = self.conversation[convo_id].pop(1)
+                 history_len = history_len - 1
+                 if safe_get(mess_body, "parts", 0, "functionCall"):
+                     self.conversation[convo_id].pop(1)
+                     history_len = history_len - 1
+
+         if total_tokens:
+             self.tokens_usage[convo_id] += total_tokens
+
+     def reset(self, convo_id: str = "default", system_prompt: str = "You are Gemini, a large language model trained by Google. Respond conversationally") -> None:
+         """
+         Reset the conversation
+         """
+         self.system_prompt = system_prompt or self.system_prompt
+         self.conversation[convo_id] = list()
+
+     def ask_stream(
+         self,
+         prompt: str,
+         role: str = "user",
+         convo_id: str = "default",
+         model: str = "",
+         pass_history: int = 9999,
+         model_max_tokens: int = 4096,
+         systemprompt: str = None,
+         **kwargs,
+     ):
+         self.system_prompt = systemprompt or self.system_prompt
+         if convo_id not in self.conversation or pass_history <= 2:
+             self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
+         self.add_to_conversation(prompt, role, convo_id=convo_id, pass_history=pass_history)
+         # print(self.conversation[convo_id])
+
+         headers = {
+             "Content-Type": "application/json",
+         }
+
+         json_post = {
+             "contents": self.conversation[convo_id] if pass_history else [{
+                 "role": "user",
+                 "parts": [{"text": prompt}]  # Gemini expects "parts", not OpenAI-style "content"
+             }],
+             "systemInstruction": {"parts": [{"text": self.system_prompt}]},
+             "safetySettings": [
+                 {
+                     "category": "HARM_CATEGORY_HARASSMENT",
+                     "threshold": "BLOCK_NONE"
+                 },
+                 {
+                     "category": "HARM_CATEGORY_HATE_SPEECH",
+                     "threshold": "BLOCK_NONE"
+                 },
+                 {
+                     "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                     "threshold": "BLOCK_NONE"
+                 },
+                 {
+                     "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                     "threshold": "BLOCK_NONE"
+                 }
+             ],
+         }
+         if self.print_log:
+             # Mask inline base64 payloads before logging
+             replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(json_post)))
+             print(json.dumps(replaced_text, indent=4, ensure_ascii=False))
+
+         url = self.api_url.format(model=model or self.engine, stream="streamGenerateContent", api_key=self.api_key)
+
+         try:
+             response = self.session.post(
+                 url,
+                 headers=headers,
+                 json=json_post,
+                 timeout=kwargs.get("timeout", self.timeout),
+                 stream=True,
+             )
+         except requests.exceptions.ConnectionError:
+             print("Connection error: check the server status or your network connection.")
+             return
+         except requests.exceptions.ReadTimeout as e:
+             print(f"Request timed out: check your network connection or increase the timeout. {e}")
+             return
+         except Exception as e:
+             print(f"An unexpected error occurred: {e}")
+             return
+
+         if response.status_code != 200:
+             print(response.text)
+             raise BaseException(f"{response.status_code} {response.reason} {response.text}")
+         response_role: str = "model"
+         full_response: str = ""
+         try:
+             # Scan the streamed JSON line by line for "text" fields instead of
+             # parsing the whole response array
+             for line in response.iter_lines():
+                 if not line:
+                     continue
+                 line = line.decode("utf-8")
+                 if line and '"text": "' in line:
+                     content = line.split('"text": "')[1][:-1]
+                     content = "\n".join(content.split("\\n"))
+                     full_response += content
+                     yield content
+         except requests.exceptions.ChunkedEncodingError as e:
+             print("Chunked Encoding Error occurred:", e)
+         except Exception as e:
+             print("An error occurred:", e)
+
+         self.add_to_conversation([{"text": full_response}], response_role, convo_id=convo_id, pass_history=pass_history)
+
+     async def ask_stream_async(
+         self,
+         prompt: str,
+         role: str = "user",
+         convo_id: str = "default",
+         model: str = "",
+         pass_history: int = 9999,
+         systemprompt: str = None,
+         language: str = "English",
+         function_arguments: dict | str = "",
+         total_tokens: int = 0,
+         **kwargs,
+     ):
+         self.system_prompt = systemprompt or self.system_prompt
+         if convo_id not in self.conversation or pass_history <= 2:
+             self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
+         self.add_to_conversation(prompt, role, convo_id=convo_id, total_tokens=total_tokens, function_arguments=function_arguments, pass_history=pass_history)
+         # print(self.conversation[convo_id])
+
+         client_email = os.environ.get("VERTEX_CLIENT_EMAIL")
+         private_key = os.environ.get("VERTEX_PRIVATE_KEY")
+         access_token = get_access_token(client_email, private_key)
+         headers = {
+             'Authorization': f"Bearer {access_token}",
+             "Content-Type": "application/json",
+         }
+
+         json_post = {
+             "contents": self.conversation[convo_id] if pass_history else [{
+                 "role": "user",
+                 "parts": [{"text": prompt}]  # Gemini expects "parts", not OpenAI-style "content"
+             }],
+             "system_instruction": {"parts": [{"text": self.system_prompt}]},
+             # "safety_settings": [
+             #     {
+             #         "category": "HARM_CATEGORY_HARASSMENT",
+             #         "threshold": "BLOCK_NONE"
+             #     },
+             #     {
+             #         "category": "HARM_CATEGORY_HATE_SPEECH",
+             #         "threshold": "BLOCK_NONE"
+             #     },
+             #     {
+             #         "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+             #         "threshold": "BLOCK_NONE"
+             #     },
+             #     {
+             #         "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+             #         "threshold": "BLOCK_NONE"
+             #     }
+             # ],
+             "generationConfig": {
+                 "temperature": self.temperature,
+                 "max_output_tokens": 8192,
+                 "top_k": 40,
+                 "top_p": 0.95
+             },
+         }
+
+         plugins = kwargs.get("plugins", PLUGINS)
+         if any(plugins.values()) and self.use_plugins:
+             tools = {
+                 "tools": [
+                     {
+                         "function_declarations": [
+
+                         ]
+                     }
+                 ],
+                 "tool_config": {
+                     "function_calling_config": {
+                         "mode": "AUTO",
+                     },
+                 },
+             }
+             json_post.update(copy.deepcopy(tools))
+             for item in plugins.keys():
+                 try:
+                     if plugins[item]:
+                         json_post["tools"][0]["function_declarations"].append(function_call_list[item])
+                 except KeyError:
+                     pass
+
+         if self.print_log:
+             # Mask inline base64 payloads before logging
+             replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(json_post)))
+             print(json.dumps(replaced_text, indent=4, ensure_ascii=False))
+
+         url = "https://us-central1-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/us-central1/publishers/google/models/{MODEL_ID}:{stream}".format(PROJECT_ID=os.environ.get("VERTEX_PROJECT_ID"), MODEL_ID=model, stream="streamGenerateContent")
+         self.api_url = BaseAPI(url)
+         url = self.api_url.source_api_url
+
+         response_role: str = "model"
+         full_response: str = ""
+         function_full_response: str = "{"
+         function_response: str = ""
+         need_function_call = False
+         receiving_function_call = False
+         total_tokens = 0
+         try:
+             async with self.aclient.stream(
+                 "post",
+                 url,
+                 headers=headers,
+                 json=json_post,
+                 timeout=kwargs.get("timeout", self.timeout),
+             ) as response:
+                 if response.status_code != 200:
+                     error_content = await response.aread()
+                     error_message = error_content.decode('utf-8')
+                     raise BaseException(f"{response.status_code}: {error_message}")
+                 try:
+                     async for line in response.aiter_lines():
+                         if not line:
+                             continue
+                         # print(line)
+                         if line and '"text": "' in line:
+                             content = line.split('"text": "')[1][:-1]
+                             content = "\n".join(content.split("\\n"))
+                             full_response += content
+                             yield content
+
+                         if line and '"totalTokenCount": ' in line:
+                             content = int(line.split('"totalTokenCount": ')[1])
+                             total_tokens = content
+
+                         if line and ('"functionCall": {' in line or receiving_function_call):
+                             receiving_function_call = True
+                             need_function_call = True
+                             if ']' in line:
+                                 receiving_function_call = False
+                                 continue
+
+                             function_full_response += line
+
+                 except requests.exceptions.ChunkedEncodingError as e:
+                     print("Chunked Encoding Error occurred:", e)
+                 except Exception as e:
+                     print("An error occurred:", e)
+
+         except Exception as e:
+             print(f"An unexpected error occurred: {e}")
+             return
+
+         if response.status_code != 200:
+             await response.aread()
+             print(response.text)
+             raise BaseException(f"{response.status_code} {response.reason_phrase} {response.text}")
+         if self.print_log:
+             print("\n\rtotal_tokens", total_tokens)
+         if need_function_call:
+             # print(function_full_response)
+             function_call = json.loads(function_full_response)
+             print(json.dumps(function_call, indent=4, ensure_ascii=False))
+             function_call_name = function_call["functionCall"]["name"]
+             function_full_response = json.dumps(function_call["functionCall"]["args"])
+             function_call_max_tokens = 32000
+             print("\033[32m function_call", function_call_name, "max token:", function_call_max_tokens, "\033[0m")
+             async for chunk in get_tools_result_async(function_call_name, function_full_response, function_call_max_tokens, model or self.engine, vertex, kwargs.get('api_key', self.api_key), self.api_url, use_plugins=False, model=model or self.engine, add_message=self.add_to_conversation, convo_id=convo_id, language=language):
+                 if "function_response:" in chunk:
+                     function_response = chunk.replace("function_response:", "")
+                 else:
+                     yield chunk
+             response_role = "model"
+             async for chunk in self.ask_stream_async(function_response, response_role, convo_id=convo_id, function_name=function_call_name, total_tokens=total_tokens, model=model or self.engine, function_arguments=function_call, api_key=kwargs.get('api_key', self.api_key), plugins=kwargs.get("plugins", PLUGINS)):
+                 yield chunk
+         else:
+             self.add_to_conversation([{"text": full_response}], response_role, convo_id=convo_id, total_tokens=total_tokens, pass_history=pass_history)
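A minimal usage sketch for the class above (assuming VERTEX_PROJECT_ID, VERTEX_CLIENT_EMAIL, and VERTEX_PRIVATE_KEY are set, and that BaseLLM supplies the aclient httpx client used by ask_stream_async):

    import asyncio
    from aient.models.vertex import vertex

    async def main():
        bot = vertex(engine="gemini-1.5-pro-latest")
        async for chunk in bot.ask_stream_async("Say hello in one sentence."):
            print(chunk, end="", flush=True)

    asyncio.run(main())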