aient 1.0.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aient/models/claude.py ADDED
@@ -0,0 +1,640 @@
1
+ import os
2
+ import re
3
+ import json
4
+ import copy
5
+ import tiktoken
6
+ import requests
7
+
8
+ from .base import BaseLLM
9
+ from ..plugins import PLUGINS, get_tools_result_async, claude_tools_list
10
+ from ..utils.scripts import check_json, safe_get, async_generator_to_sync
11
+
12
class claudeConversation(dict):
    """Dict keyed by conversation id; renders a transcript in Anthropic's legacy prompt format."""

    def Conversation(self, index):
        """Return the conversation stored under *index* as one prompt string.

        Each message is rendered as ``role:content``; messages are separated
        by blank lines and the prompt ends with the ``Assistant:`` turn marker
        that the legacy completions API expects.
        """
        messages = super().__getitem__(index)
        rendered = (f"{msg['role']}:{msg['content']}" for msg in messages)
        return "\n\n" + "\n\n".join(rendered) + "\n\nAssistant:"
16
+
17
class claude(BaseLLM):
    """Client for Anthropic's legacy text-completions API (``/v1/complete``)."""

    def __init__(
        self,
        api_key: str,
        engine: str = os.environ.get("GPT_ENGINE") or "claude-2.1",
        api_url: str = "https://api.anthropic.com/v1/complete",
        system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
        temperature: float = 0.5,
        top_p: float = 0.7,
        timeout: float = 20,
        use_plugins: bool = True,
        print_log: bool = False,
    ):
        super().__init__(api_key, engine, api_url, system_prompt, timeout=timeout, temperature=temperature, top_p=top_p, use_plugins=use_plugins, print_log=print_log)
        # Per-conversation history, keyed by convo_id.
        self.conversation = claudeConversation()

    def add_to_conversation(
        self,
        message: str,
        role: str,
        convo_id: str = "default",
        pass_history: int = 9999,
        total_tokens: int = 0,
    ) -> None:
        """Append a message to the conversation and trim it to *pass_history* entries.

        Also accumulates *total_tokens* into ``self.tokens_usage[convo_id]``.
        """
        if convo_id not in self.conversation or pass_history <= 2:
            self.reset(convo_id=convo_id)
        self.conversation[convo_id].append({"role": role, "content": message})

        history_len = len(self.conversation[convo_id])
        history = pass_history
        if pass_history < 2:
            history = 2
        # Drop the oldest messages (but never index 0) until the history fits.
        while history_len > history:
            self.conversation[convo_id].pop(1)
            history_len = history_len - 1

        if total_tokens:
            self.tokens_usage[convo_id] += total_tokens

    def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
        """Reset the conversation identified by *convo_id* and optionally swap the system prompt."""
        self.conversation[convo_id] = claudeConversation()
        self.system_prompt = system_prompt or self.system_prompt

    def __truncate_conversation(self, convo_id: str = "default") -> None:
        """Pop old messages until the token count fits within ``self.truncate_limit``."""
        while True:
            if (
                self.get_token_count(convo_id) > self.truncate_limit
                and len(self.conversation[convo_id]) > 1
            ):
                # Don't remove the first message
                self.conversation[convo_id].pop(1)
            else:
                break

    def get_token_count(self, convo_id: str = "default") -> int:
        """Approximate the token count of the conversation using tiktoken's cl100k_base encoding."""
        # claude-2.1 is not known to tiktoken; map it to cl100k_base as an approximation.
        tiktoken.model.MODEL_TO_ENCODING["claude-2.1"] = "cl100k_base"
        encoding = tiktoken.encoding_for_model(self.engine)

        num_tokens = 0
        for message in self.conversation[convo_id]:
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            num_tokens += 5
            for key, value in message.items():
                if value:
                    num_tokens += len(encoding.encode(value))
                if key == "name":  # if there's a name, the role is omitted
                    num_tokens += 5  # role is always required and always 1 token
        num_tokens += 5  # every reply is primed with <im_start>assistant
        return num_tokens

    def ask_stream(
        self,
        prompt: str,
        role: str = "Human",
        convo_id: str = "default",
        model: str = "",
        pass_history: int = 9999,
        model_max_tokens: int = 4096,
        **kwargs,
    ):
        """Send *prompt* and yield completion text chunks as they stream back.

        The full response is appended to the conversation when the stream ends.
        Network failures are printed and swallowed (generator simply ends);
        non-200 responses raise.
        """
        if convo_id not in self.conversation or pass_history <= 2:
            self.reset(convo_id=convo_id)
        self.add_to_conversation(prompt, role, convo_id=convo_id, pass_history=pass_history)
        # self.__truncate_conversation(convo_id=convo_id)

        url = self.api_url
        headers = {
            "accept": "application/json",
            "anthropic-version": "2023-06-01",
            "content-type": "application/json",
            "x-api-key": f"{kwargs.get('api_key', self.api_key)}",
        }

        json_post = {
            "model": model or self.engine,
            "prompt": self.conversation.Conversation(convo_id) if pass_history else f"\n\nHuman:{prompt}\n\nAssistant:",
            "stream": True,
            "temperature": kwargs.get("temperature", self.temperature),
            "top_p": kwargs.get("top_p", self.top_p),
            "max_tokens_to_sample": model_max_tokens,
        }

        try:
            response = self.session.post(
                url,
                headers=headers,
                json=json_post,
                timeout=kwargs.get("timeout", self.timeout),
                stream=True,
            )
        except ConnectionError:
            print("连接错误,请检查服务器状态或网络连接。")
            return
        except requests.exceptions.ReadTimeout as e:
            # Bug fix: the message previously printed a literal "{e}" (missing f-prefix and `as e`).
            print(f"请求超时,请检查网络连接或增加超时时间。{e}")
            return
        except Exception as e:
            print(f"发生了未预料的错误: {e}")
            return

        if response.status_code != 200:
            print(response.text)
            raise BaseException(f"{response.status_code} {response.reason} {response.text}")
        response_role: str = "Assistant"
        full_response: str = ""
        for line in response.iter_lines():
            # Skip keep-alives and SSE event-name lines; only "data: {...}" payloads matter.
            if not line or line.decode("utf-8") == "event: completion" or line.decode("utf-8") == "event: ping" or line.decode("utf-8") == "data: {}":
                continue
            line = line.decode("utf-8")[6:]  # strip the "data: " prefix
            resp: dict = json.loads(line)
            content = resp.get("completion")
            if content:
                full_response += content
                yield content
        self.add_to_conversation(full_response, response_role, convo_id=convo_id, pass_history=pass_history)
168
+
169
class claude3(BaseLLM):
    """Client for Anthropic's Messages API (``/v1/messages``) with tool-use (plugin) support."""

    def __init__(
        self,
        api_key: str = None,
        engine: str = os.environ.get("GPT_ENGINE") or "claude-3-5-sonnet-20241022",
        api_url: str = (os.environ.get("CLAUDE_API_URL") or "https://api.anthropic.com/v1/messages"),
        system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
        temperature: float = 0.5,
        timeout: float = 20,
        top_p: float = 0.7,
        use_plugins: bool = True,
        print_log: bool = False,
    ):
        super().__init__(api_key, engine, api_url, system_prompt, timeout=timeout, temperature=temperature, top_p=top_p, use_plugins=use_plugins, print_log=print_log)
        # Per-conversation message lists in Messages-API format, keyed by convo_id.
        self.conversation: dict[str, list[dict]] = {
            "default": [],
        }

    def add_to_conversation(
        self,
        message: str,
        role: str,
        convo_id: str = "default",
        pass_history: int = 9999,
        total_tokens: int = 0,
        tools_id: str = "",
        function_name: str = "",
        function_full_response: str = "",
    ) -> None:
        """Append a message (or a tool_use/tool_result pair) and trim the history.

        Plain user/assistant text goes in as a ``text`` content block; when
        *function_full_response* is set for an assistant turn, a ``tool_use``
        block plus the matching ``tool_result`` user message are appended.
        Adjacent same-role messages are merged, as the Messages API requires
        strictly alternating roles.
        """
        if convo_id not in self.conversation or pass_history <= 2:
            self.reset(convo_id=convo_id)
        if role == "user" or (role == "assistant" and function_full_response == ""):
            if isinstance(message, list):
                self.conversation[convo_id].append({
                    "role": role,
                    "content": message
                })
            if isinstance(message, str):
                self.conversation[convo_id].append({
                    "role": role,
                    "content": [{
                        "type": "text",
                        "text": message
                    }]
                })
        elif role == "assistant" and function_full_response:
            print("function_full_response", function_full_response)
            function_dict = {
                "type": "tool_use",
                "id": f"{tools_id}",
                "name": f"{function_name}",
                "input": json.loads(function_full_response)
            }
            self.conversation[convo_id].append({"role": role, "content": [function_dict]})
            function_dict = {
                "type": "tool_result",
                "tool_use_id": f"{tools_id}",
                "content": f"{message}",
                # "is_error": true
            }
            self.conversation[convo_id].append({"role": "user", "content": [function_dict]})

        # Merge consecutive messages with the same role (API requires alternation).
        conversation_len = len(self.conversation[convo_id]) - 1
        message_index = 0
        while message_index < conversation_len:
            if self.conversation[convo_id][message_index]["role"] == self.conversation[convo_id][message_index + 1]["role"]:
                self.conversation[convo_id][message_index]["content"] += self.conversation[convo_id][message_index + 1]["content"]
                self.conversation[convo_id].pop(message_index + 1)
                conversation_len = conversation_len - 1
            else:
                message_index = message_index + 1

        history_len = len(self.conversation[convo_id])
        history = pass_history
        if pass_history < 2:
            history = 2
        # Trim oldest messages; when a user message is dropped, also drop its
        # assistant reply, and a trailing tool_result if that reply was a tool_use.
        while history_len > history:
            mess_body = self.conversation[convo_id].pop(1)
            history_len = history_len - 1
            if mess_body.get("role") == "user":
                mess_body = self.conversation[convo_id].pop(1)
                history_len = history_len - 1
                if safe_get(mess_body, "content", 0, "type") == "tool_use":
                    self.conversation[convo_id].pop(1)
                    history_len = history_len - 1

        if total_tokens:
            self.tokens_usage[convo_id] += total_tokens

    def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
        """Reset the conversation identified by *convo_id* and optionally swap the system prompt."""
        self.conversation[convo_id] = list()
        self.system_prompt = system_prompt or self.system_prompt

    def __truncate_conversation(self, convo_id: str = "default") -> None:
        """Pop old messages until the token count fits within ``self.truncate_limit``."""
        while True:
            if (
                self.get_token_count(convo_id) > self.truncate_limit
                and len(self.conversation[convo_id]) > 1
            ):
                # Don't remove the first message
                self.conversation[convo_id].pop(1)
            else:
                break

    def get_token_count(self, convo_id: str = "default") -> int:
        """Approximate the token count of the conversation using tiktoken's cl100k_base encoding."""
        # claude models are not known to tiktoken; map to cl100k_base as an approximation.
        tiktoken.model.MODEL_TO_ENCODING["claude-2.1"] = "cl100k_base"
        encoding = tiktoken.encoding_for_model(self.engine)

        num_tokens = 0
        for message in self.conversation[convo_id]:
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            num_tokens += 5
            for key, value in message.items():
                # Bug fix: content here is a list of content blocks, which
                # encoding.encode() cannot take — only count string values.
                if isinstance(value, str) and value:
                    num_tokens += len(encoding.encode(value))
                if key == "name":  # if there's a name, the role is omitted
                    num_tokens += 5  # role is always required and always 1 token
        num_tokens += 5  # every reply is primed with <im_start>assistant
        return num_tokens

    def ask_stream(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        model: str = "",
        pass_history: int = 9999,
        model_max_tokens: int = 4096,
        tools_id: str = "",
        total_tokens: int = 0,
        function_name: str = "",
        function_full_response: str = "",
        language: str = "English",
        system_prompt: str = None,
        **kwargs,
    ):
        """Stream a reply from the Messages API, transparently executing tool calls.

        Yields text chunks. If the model requests a tool, the tool runs (its
        output may also be yielded) and this method recurses with the tool
        result. The final response is stored in the conversation history.
        """
        self.add_to_conversation(prompt, role, convo_id=convo_id, tools_id=tools_id, total_tokens=total_tokens, function_name=function_name, function_full_response=function_full_response, pass_history=pass_history)
        # self.__truncate_conversation(convo_id=convo_id)

        # NOTE(review): despite the `api_url: str` annotation, api_url is used
        # as an object with a .source_api_url attribute — presumably wrapped by
        # BaseLLM; confirm there.
        url = self.api_url.source_api_url
        now_model = model or self.engine
        headers = {
            "content-type": "application/json",
            "x-api-key": f"{kwargs.get('api_key', self.api_key)}",
            "anthropic-version": "2023-06-01",
            "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" if "claude-3-5-sonnet" in now_model else "tools-2024-05-16",
        }

        json_post = {
            "model": now_model,
            "messages": self.conversation[convo_id] if pass_history else [{
                "role": "user",
                "content": prompt
            }],
            "temperature": kwargs.get("temperature", self.temperature),
            "top_p": kwargs.get("top_p", self.top_p),
            "max_tokens": 8192 if "claude-3-5-sonnet" in now_model else model_max_tokens,
            "stream": True,
        }
        json_post["system"] = system_prompt or self.system_prompt
        plugins = kwargs.get("plugins", PLUGINS)
        # Attach tool definitions only when at least one plugin is enabled.
        if any(plugins.values()) and self.use_plugins:
            json_post.update(copy.deepcopy(claude_tools_list["base"]))
            for item in plugins.keys():
                try:
                    if plugins[item]:
                        json_post["tools"].append(claude_tools_list[item])
                except Exception:
                    pass

        if self.print_log:
            # Mask base64 payloads (e.g. inline images) before logging the request body.
            replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(json_post)))
            print(json.dumps(replaced_text, indent=4, ensure_ascii=False))

        try:
            response = self.session.post(
                url,
                headers=headers,
                json=json_post,
                timeout=kwargs.get("timeout", self.timeout),
                stream=True,
            )
        except ConnectionError:
            print("连接错误,请检查服务器状态或网络连接。")
            return
        except requests.exceptions.ReadTimeout as e:
            # Bug fix: the message previously printed a literal "{e}" (missing f-prefix and `as e`).
            print(f"请求超时,请检查网络连接或增加超时时间。{e}")
            return
        except Exception as e:
            print(f"发生了未预料的错误: {e}")
            return

        if response.status_code != 200:
            print(response.text)
            raise BaseException(f"{response.status_code} {response.reason} {response.text}")
        response_role: str = "assistant"
        full_response: str = ""
        need_function_call: bool = False
        function_call_name: str = ""
        function_full_response: str = ""
        total_tokens = 0
        tools_id = ""
        for line in response.iter_lines():
            decoded = line.decode("utf-8") if line else ""
            if not decoded or decoded[:6] == "event:" or decoded == "data: {}":
                continue
            # Bug fix: str.lstrip("data: ") strips *characters* from that set,
            # not the prefix string; removeprefix strips exactly "data:".
            line = decoded.removeprefix("data:").lstrip()
            resp: dict = json.loads(line)
            if resp.get("error"):
                print("error:", resp["error"])
                raise BaseException(f"{resp['error']}")

            # message_start event carries the prompt's input token usage.
            message = resp.get("message")
            if message:
                usage = message.get("usage")
                input_tokens = usage.get("input_tokens", 0)
                output_tokens = 0  # output tokens are taken from the final usage event instead
                total_tokens = total_tokens + input_tokens + output_tokens

            usage = resp.get("usage")
            if usage:
                input_tokens = usage.get("input_tokens", 0)
                output_tokens = usage.get("output_tokens", 0)
                total_tokens = total_tokens + input_tokens + output_tokens

            # content_block_start with a tool_use block signals a tool call.
            tool_use = resp.get("content_block")
            if tool_use and "tool_use" == tool_use['type']:
                tools_id = tool_use["id"]
                need_function_call = True
                if "name" in tool_use:
                    function_call_name = tool_use["name"]
            delta = resp.get("delta")
            if not delta:
                continue
            if "text" in delta:
                content = delta["text"]
                full_response += content
                yield content
            if "partial_json" in delta:
                # Tool-call arguments stream in as partial JSON fragments.
                function_call_content = delta["partial_json"]
                function_full_response += function_call_content

        if self.print_log:
            print("\n\rtotal_tokens", total_tokens)
        if need_function_call:
            function_full_response = check_json(function_full_response)
            print("function_full_response", function_full_response)
            function_response = ""
            function_call_max_tokens = int(self.truncate_limit / 2)

            async def run_async():
                nonlocal function_response
                async for chunk in get_tools_result_async(
                    function_call_name, function_full_response, function_call_max_tokens,
                    model or self.engine, claude3, kwargs.get('api_key', self.api_key),
                    self.api_url, use_plugins=False, model=model or self.engine,
                    add_message=self.add_to_conversation, convo_id=convo_id, language=language
                ):
                    # The tool helper marks its final answer with a sentinel prefix.
                    if "function_response:" in chunk:
                        function_response = chunk.replace("function_response:", "")
                    else:
                        yield chunk

            # Drain the async tool-call generator from this synchronous generator.
            for chunk in async_generator_to_sync(run_async()):
                yield chunk

            response_role = "assistant"
            if self.conversation[convo_id][-1]["role"] == "function" and self.conversation[convo_id][-1]["name"] == "get_search_results":
                mess = self.conversation[convo_id].pop(-1)
            # Feed the tool result back to the model and stream its final answer.
            yield from self.ask_stream(function_response, response_role, convo_id=convo_id, function_name=function_call_name, total_tokens=total_tokens, model=model or self.engine, tools_id=tools_id, function_full_response=function_full_response, api_key=kwargs.get('api_key', self.api_key), plugins=kwargs.get("plugins", PLUGINS), system_prompt=system_prompt)
        else:
            if self.conversation[convo_id][-1]["role"] == "function" and self.conversation[convo_id][-1]["name"] == "get_search_results":
                mess = self.conversation[convo_id].pop(-1)
            self.add_to_conversation(full_response, response_role, convo_id=convo_id, total_tokens=total_tokens, pass_history=pass_history)
            self.function_calls_counter = {}
            # Translation requests are one-shot: drop them from the history.
            if pass_history <= 2 and len(self.conversation[convo_id]) >= 2 and ("You are a translation engine" in self.conversation[convo_id][-2]["content"] or (isinstance(self.conversation[convo_id][-2]["content"], list) and "You are a translation engine" in self.conversation[convo_id][-2]["content"][0]["text"])):
                self.conversation[convo_id].pop(-1)
                self.conversation[convo_id].pop(-1)

    async def ask_stream_async(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        model: str = "",
        pass_history: int = 9999,
        model_max_tokens: int = 4096,
        tools_id: str = "",
        total_tokens: int = 0,
        function_name: str = "",
        function_full_response: str = "",
        language: str = "English",
        system_prompt: str = None,
        **kwargs,
    ):
        """Async counterpart of :meth:`ask_stream`; same contract, yields text chunks."""
        self.add_to_conversation(prompt, role, convo_id=convo_id, tools_id=tools_id, total_tokens=total_tokens, function_name=function_name, function_full_response=function_full_response, pass_history=pass_history)
        # self.__truncate_conversation(convo_id=convo_id)

        url = self.api_url.source_api_url
        now_model = model or self.engine
        headers = {
            "content-type": "application/json",
            "x-api-key": f"{kwargs.get('api_key', self.api_key)}",
            "anthropic-version": "2023-06-01",
            "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" if "claude-3-5-sonnet" in now_model else "tools-2024-05-16",
        }

        json_post = {
            "model": now_model,
            "messages": self.conversation[convo_id] if pass_history else [{
                "role": "user",
                "content": prompt
            }],
            "temperature": kwargs.get("temperature", self.temperature),
            "top_p": kwargs.get("top_p", self.top_p),
            "max_tokens": 8192 if "claude-3-5-sonnet" in now_model else model_max_tokens,
            "stream": True,
        }
        json_post["system"] = system_prompt or self.system_prompt
        plugins = kwargs.get("plugins", PLUGINS)
        # Attach tool definitions only when at least one plugin is enabled.
        if any(plugins.values()) and self.use_plugins:
            json_post.update(copy.deepcopy(claude_tools_list["base"]))
            for item in plugins.keys():
                try:
                    if plugins[item]:
                        json_post["tools"].append(claude_tools_list[item])
                except Exception:
                    pass

        if self.print_log:
            # Mask base64 payloads (e.g. inline images) before logging the request body.
            replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(json_post)))
            print(json.dumps(replaced_text, indent=4, ensure_ascii=False))

        try:
            # NOTE(review): this posts with the synchronous requests session even
            # though the method is async — the event loop blocks during I/O.
            response = self.session.post(
                url,
                headers=headers,
                json=json_post,
                timeout=kwargs.get("timeout", self.timeout),
                stream=True,
            )
        except ConnectionError:
            print("连接错误,请检查服务器状态或网络连接。")
            return
        except requests.exceptions.ReadTimeout as e:
            # Bug fix: the message previously printed a literal "{e}" (missing f-prefix and `as e`).
            print(f"请求超时,请检查网络连接或增加超时时间。{e}")
            return
        except Exception as e:
            print(f"发生了未预料的错误: {e}")
            return

        if response.status_code != 200:
            print(response.text)
            raise BaseException(f"{response.status_code} {response.reason} {response.text}")
        response_role: str = "assistant"
        full_response: str = ""
        need_function_call: bool = False
        function_call_name: str = ""
        function_full_response: str = ""
        total_tokens = 0
        tools_id = ""
        for line in response.iter_lines():
            decoded = line.decode("utf-8") if line else ""
            if not decoded or decoded[:6] == "event:" or decoded == "data: {}":
                continue
            # Bug fix (consistency with ask_stream): strip exactly the "data:"
            # prefix instead of slicing a fixed width and trimming one space.
            line = decoded.removeprefix("data:").lstrip()
            resp: dict = json.loads(line)
            if resp.get("error"):
                print("error:", resp["error"])
                raise BaseException(f"{resp['error']}")

            # message_start event carries the prompt's input token usage.
            message = resp.get("message")
            if message:
                usage = message.get("usage")
                input_tokens = usage.get("input_tokens", 0)
                output_tokens = 0  # output tokens are taken from the final usage event instead
                total_tokens = total_tokens + input_tokens + output_tokens

            usage = resp.get("usage")
            if usage:
                input_tokens = usage.get("input_tokens", 0)
                output_tokens = usage.get("output_tokens", 0)
                total_tokens = total_tokens + input_tokens + output_tokens
                if self.print_log:
                    print("\n\rtotal_tokens", total_tokens)

            # content_block_start with a tool_use block signals a tool call.
            tool_use = resp.get("content_block")
            if tool_use and "tool_use" == tool_use['type']:
                tools_id = tool_use["id"]
                need_function_call = True
                if "name" in tool_use:
                    function_call_name = tool_use["name"]
            delta = resp.get("delta")
            if not delta:
                continue
            if "text" in delta:
                content = delta["text"]
                full_response += content
                yield content
            if "partial_json" in delta:
                # Tool-call arguments stream in as partial JSON fragments.
                function_call_content = delta["partial_json"]
                function_full_response += function_call_content

        if need_function_call:
            function_full_response = check_json(function_full_response)
            print("function_full_response", function_full_response)
            function_response = ""
            function_call_max_tokens = int(self.truncate_limit / 2)
            async for chunk in get_tools_result_async(function_call_name, function_full_response, function_call_max_tokens, self.engine, claude3, kwargs.get('api_key', self.api_key), self.api_url, use_plugins=False, model=model or self.engine, add_message=self.add_to_conversation, convo_id=convo_id, language=language):
                # The tool helper marks its final answer with a sentinel prefix.
                if "function_response:" in chunk:
                    function_response = chunk.replace("function_response:", "")
                else:
                    yield chunk
            response_role = "assistant"
            if self.conversation[convo_id][-1]["role"] == "function" and self.conversation[convo_id][-1]["name"] == "get_search_results":
                mess = self.conversation[convo_id].pop(-1)
            # Feed the tool result back to the model and stream its final answer.
            async for chunk in self.ask_stream_async(function_response, response_role, convo_id=convo_id, function_name=function_call_name, total_tokens=total_tokens, model=model or self.engine, tools_id=tools_id, function_full_response=function_full_response, api_key=kwargs.get('api_key', self.api_key), plugins=kwargs.get("plugins", PLUGINS), system_prompt=system_prompt):
                yield chunk
        else:
            if self.conversation[convo_id][-1]["role"] == "function" and self.conversation[convo_id][-1]["name"] == "get_search_results":
                mess = self.conversation[convo_id].pop(-1)
            self.add_to_conversation(full_response, response_role, convo_id=convo_id, total_tokens=total_tokens, pass_history=pass_history)
            self.function_calls_counter = {}
            # Translation requests are one-shot: drop them from the history.
            if pass_history <= 2 and len(self.conversation[convo_id]) >= 2 and ("You are a translation engine" in self.conversation[convo_id][-2]["content"] or (isinstance(self.conversation[convo_id][-2]["content"], list) and "You are a translation engine" in self.conversation[convo_id][-2]["content"][0]["text"])):
                self.conversation[convo_id].pop(-1)
                self.conversation[convo_id].pop(-1)