aient 1.1.58-py3-none-any.whl → 1.1.60-py3-none-any.whl

This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
aient/models/claude.py DELETED
@@ -1,573 +0,0 @@
- import os
- import re
- import json
- import copy
- import requests
-
- from .base import BaseLLM
- from ..plugins import PLUGINS, get_tools_result_async, claude_tools_list
- from ..utils.scripts import check_json, safe_get, async_generator_to_sync
-
- class claudeConversation(dict):
-     def Conversation(self, index):
-         conversation_list = super().__getitem__(index)
-         return "\n\n" + "\n\n".join([f"{item['role']}:{item['content']}" for item in conversation_list]) + "\n\nAssistant:"
-
- class claude(BaseLLM):
-     def __init__(
-         self,
-         api_key: str,
-         engine: str = os.environ.get("GPT_ENGINE") or "claude-2.1",
-         api_url: str = "https://api.anthropic.com/v1/complete",
-         system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
-         temperature: float = 0.5,
-         top_p: float = 0.7,
-         timeout: float = 20,
-         use_plugins: bool = True,
-         print_log: bool = False,
-     ):
-         super().__init__(api_key, engine, api_url, system_prompt, timeout=timeout, temperature=temperature, top_p=top_p, use_plugins=use_plugins, print_log=print_log)
-         # self.api_url = api_url
-         self.conversation = claudeConversation()
-
-     def add_to_conversation(
-         self,
-         message: str,
-         role: str,
-         convo_id: str = "default",
-         pass_history: int = 9999,
-         total_tokens: int = 0,
-     ) -> None:
-         """
-         Add a message to the conversation
-         """
-
-         if convo_id not in self.conversation or pass_history <= 2:
-             self.reset(convo_id=convo_id)
-         self.conversation[convo_id].append({"role": role, "content": message})
-
-         history_len = len(self.conversation[convo_id])
-         history = pass_history
-         if pass_history < 2:
-             history = 2
-         while history_len > history:
-             self.conversation[convo_id].pop(1)
-             history_len = history_len - 1
-
-         if total_tokens:
-             self.tokens_usage[convo_id] += total_tokens
-
-     def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
-         """
-         Reset the conversation
-         """
-         self.conversation[convo_id] = claudeConversation()
-         self.system_prompt = system_prompt or self.system_prompt
-
-     def ask_stream(
-         self,
-         prompt: str,
-         role: str = "Human",
-         convo_id: str = "default",
-         model: str = "",
-         pass_history: int = 9999,
-         model_max_tokens: int = 4096,
-         **kwargs,
-     ):
-         if convo_id not in self.conversation or pass_history <= 2:
-             self.reset(convo_id=convo_id)
-         self.add_to_conversation(prompt, role, convo_id=convo_id, pass_history=pass_history)
-         # self.__truncate_conversation(convo_id=convo_id)
-         # print(self.conversation[convo_id])
-
-         url = self.api_url
-         headers = {
-             "accept": "application/json",
-             "anthropic-version": "2023-06-01",
-             "content-type": "application/json",
-             "x-api-key": f"{kwargs.get('api_key', self.api_key)}",
-         }
-
-         json_post = {
-             "model": model or self.engine,
-             "prompt": self.conversation.Conversation(convo_id) if pass_history else f"\n\nHuman:{prompt}\n\nAssistant:",
-             "stream": True,
-             "temperature": kwargs.get("temperature", self.temperature),
-             "top_p": kwargs.get("top_p", self.top_p),
-             "max_tokens_to_sample": model_max_tokens,
-         }
-
-         try:
-             response = self.session.post(
-                 url,
-                 headers=headers,
-                 json=json_post,
-                 timeout=kwargs.get("timeout", self.timeout),
-                 stream=True,
-             )
-         except ConnectionError:
-             print("连接错误,请检查服务器状态或网络连接。")
-             return
-         except requests.exceptions.ReadTimeout:
-             print("请求超时,请检查网络连接或增加超时时间。{e}")
-             return
-         except Exception as e:
-             print(f"发生了未预料的错误: {e}")
-             return
-
-         if response.status_code != 200:
-             print(response.text)
-             raise BaseException(f"{response.status_code} {response.reason} {response.text}")
-         response_role: str = "Assistant"
-         full_response: str = ""
-         for line in response.iter_lines():
-             if not line or line.decode("utf-8") == "event: completion" or line.decode("utf-8") == "event: ping" or line.decode("utf-8") == "data: {}":
-                 continue
-             line = line.decode("utf-8")[6:]
-             # print(line)
-             resp: dict = json.loads(line)
-             content = resp.get("completion")
-             if content:
-                 full_response += content
-                 yield content
-         self.add_to_conversation(full_response, response_role, convo_id=convo_id, pass_history=pass_history)
-
- class claude3(BaseLLM):
-     def __init__(
-         self,
-         api_key: str = None,
-         engine: str = os.environ.get("GPT_ENGINE") or "claude-3-5-sonnet-20241022",
-         api_url: str = (os.environ.get("CLAUDE_API_URL") or "https://api.anthropic.com/v1/messages"),
-         system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
-         temperature: float = 0.5,
-         timeout: float = 20,
-         top_p: float = 0.7,
-         use_plugins: bool = True,
-         print_log: bool = False,
-     ):
-         super().__init__(api_key, engine, api_url, system_prompt, timeout=timeout, temperature=temperature, top_p=top_p, use_plugins=use_plugins, print_log=print_log)
-         self.conversation: dict[str, list[dict]] = {
-             "default": [],
-         }
-
-     def add_to_conversation(
-         self,
-         message: str,
-         role: str,
-         convo_id: str = "default",
-         pass_history: int = 9999,
-         total_tokens: int = 0,
-         tools_id= "",
-         function_name: str = "",
-         function_full_response: str = "",
-     ) -> None:
-         """
-         Add a message to the conversation
-         """
-
-         if convo_id not in self.conversation or pass_history <= 2:
-             self.reset(convo_id=convo_id)
-         if role == "user" or (role == "assistant" and function_full_response == ""):
-             if type(message) == list:
-                 self.conversation[convo_id].append({
-                     "role": role,
-                     "content": message
-                 })
-             if type(message) == str:
-                 self.conversation[convo_id].append({
-                     "role": role,
-                     "content": [{
-                         "type": "text",
-                         "text": message
-                     }]
-                 })
-         elif role == "assistant" and function_full_response:
-             print("function_full_response", function_full_response)
-             function_dict = {
-                 "type": "tool_use",
-                 "id": f"{tools_id}",
-                 "name": f"{function_name}",
-                 "input": json.loads(function_full_response)
-                 # "input": json.dumps(function_full_response, ensure_ascii=False)
-             }
-             self.conversation[convo_id].append({"role": role, "content": [function_dict]})
-             function_dict = {
-                 "type": "tool_result",
-                 "tool_use_id": f"{tools_id}",
-                 "content": f"{message}",
-                 # "is_error": true
-             }
-             self.conversation[convo_id].append({"role": "user", "content": [function_dict]})
-
-         conversation_len = len(self.conversation[convo_id]) - 1
-         message_index = 0
-         while message_index < conversation_len:
-             if self.conversation[convo_id][message_index]["role"] == self.conversation[convo_id][message_index + 1]["role"]:
-                 self.conversation[convo_id][message_index]["content"] += self.conversation[convo_id][message_index + 1]["content"]
-                 self.conversation[convo_id].pop(message_index + 1)
-                 conversation_len = conversation_len - 1
-             else:
-                 message_index = message_index + 1
-
-         history_len = len(self.conversation[convo_id])
-         history = pass_history
-         if pass_history < 2:
-             history = 2
-         while history_len > history:
-             mess_body = self.conversation[convo_id].pop(1)
-             history_len = history_len - 1
-             if mess_body.get("role") == "user":
-                 mess_body = self.conversation[convo_id].pop(1)
-                 history_len = history_len - 1
-                 if safe_get(mess_body, "content", 0, "type") == "tool_use":
-                     self.conversation[convo_id].pop(1)
-                     history_len = history_len - 1
-
-         if total_tokens:
-             self.tokens_usage[convo_id] += total_tokens
-
-     def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
-         """
-         Reset the conversation
-         """
-         self.conversation[convo_id] = list()
-         self.system_prompt = system_prompt or self.system_prompt
-
-     def ask_stream(
-         self,
-         prompt: str,
-         role: str = "user",
-         convo_id: str = "default",
-         model: str = "",
-         pass_history: int = 9999,
-         model_max_tokens: int = 4096,
-         tools_id: str = "",
-         total_tokens: int = 0,
-         function_name: str = "",
-         function_full_response: str = "",
-         language: str = "English",
-         system_prompt: str = None,
-         **kwargs,
-     ):
-         self.add_to_conversation(prompt, role, convo_id=convo_id, tools_id=tools_id, total_tokens=total_tokens, function_name=function_name, function_full_response=function_full_response, pass_history=pass_history)
-         # self.__truncate_conversation(convo_id=convo_id)
-         # print(self.conversation[convo_id])
-
-         url = self.api_url.source_api_url
-         now_model = model or self.engine
-         headers = {
-             "content-type": "application/json",
-             "x-api-key": f"{kwargs.get('api_key', self.api_key)}",
-             "anthropic-version": "2023-06-01",
-             "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" if "claude-3-5-sonnet" in now_model else "tools-2024-05-16",
-         }
-
-         json_post = {
-             "model": now_model,
-             "messages": self.conversation[convo_id] if pass_history else [{
-                 "role": "user",
-                 "content": prompt
-             }],
-             "temperature": kwargs.get("temperature", self.temperature),
-             "top_p": kwargs.get("top_p", self.top_p),
-             "max_tokens": 8192 if "claude-3-5-sonnet" in now_model else model_max_tokens,
-             "stream": True,
-         }
-         json_post["system"] = system_prompt or self.system_prompt
-         plugins = kwargs.get("plugins", PLUGINS)
-         if all(value == False for value in plugins.values()) == False and self.use_plugins:
-             json_post.update(copy.deepcopy(claude_tools_list["base"]))
-             for item in plugins.keys():
-                 try:
-                     if plugins[item]:
-                         json_post["tools"].append(claude_tools_list[item])
-                 except:
-                     pass
-
-         if self.print_log:
-             replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(json_post)))
-             print(json.dumps(replaced_text, indent=4, ensure_ascii=False))
-
-         try:
-             response = self.session.post(
-                 url,
-                 headers=headers,
-                 json=json_post,
-                 timeout=kwargs.get("timeout", self.timeout),
-                 stream=True,
-             )
-         except ConnectionError:
-             print("连接错误,请检查服务器状态或网络连接。")
-             return
-         except requests.exceptions.ReadTimeout:
-             print("请求超时,请检查网络连接或增加超时时间。{e}")
-             return
-         except Exception as e:
-             print(f"发生了未预料的错误: {e}")
-             return
-
-         if response.status_code != 200:
-             print(response.text)
-             raise BaseException(f"{response.status_code} {response.reason} {response.text}")
-         response_role: str = "assistant"
-         full_response: str = ""
-         need_function_call: bool = False
-         function_call_name: str = ""
-         function_full_response: str = ""
-         total_tokens = 0
-         tools_id = ""
-         for line in response.iter_lines():
-             if not line or line.decode("utf-8")[:6] == "event:" or line.decode("utf-8") == "data: {}":
-                 continue
-             # print(line.decode("utf-8"))
-             # if "tool_use" in line.decode("utf-8"):
-             #     tool_input = json.loads(line.decode("utf-8")["content"][1]["input"])
-             # else:
-             #     line = line.decode("utf-8")[6:]
-             line = line.decode("utf-8")
-             line = line.lstrip("data: ")
-             # print(line)
-             resp: dict = json.loads(line)
-             if resp.get("error"):
-                 print("error:", resp["error"])
-                 raise BaseException(f"{resp['error']}")
-
-             message = resp.get("message")
-             if message:
-                 usage = message.get("usage")
-                 input_tokens = usage.get("input_tokens", 0)
-                 # output_tokens = usage.get("output_tokens", 0)
-                 output_tokens = 0
-                 total_tokens = total_tokens + input_tokens + output_tokens
-
-             usage = resp.get("usage")
-             if usage:
-                 input_tokens = usage.get("input_tokens", 0)
-                 output_tokens = usage.get("output_tokens", 0)
-                 total_tokens = total_tokens + input_tokens + output_tokens
-
-             # print("\n\rtotal_tokens", total_tokens)
-
-             tool_use = resp.get("content_block")
-             if tool_use and "tool_use" == tool_use['type']:
-                 # print("tool_use", tool_use)
-                 tools_id = tool_use["id"]
-                 need_function_call = True
-                 if "name" in tool_use:
-                     function_call_name = tool_use["name"]
-             delta = resp.get("delta")
-             # print("delta", delta)
-             if not delta:
-                 continue
-             if "text" in delta:
-                 content = delta["text"]
-                 full_response += content
-                 yield content
-             if "partial_json" in delta:
-                 function_call_content = delta["partial_json"]
-                 function_full_response += function_call_content
-
-         # print("function_full_response", function_full_response)
-         # print("function_call_name", function_call_name)
-         # print("need_function_call", need_function_call)
-         if self.print_log:
-             print("\n\rtotal_tokens", total_tokens)
-         if need_function_call:
-             function_full_response = check_json(function_full_response)
-             print("function_full_response", function_full_response)
-             function_response = ""
-             function_call_max_tokens = int(self.truncate_limit / 2)
-
-             # function_response = yield from get_tools_result(function_call_name, function_full_response, function_call_max_tokens, self.engine, claude3, kwargs.get('api_key', self.api_key), self.api_url, use_plugins=False, model=model or self.engine, add_message=self.add_to_conversation, convo_id=convo_id, language=language)
-
-             async def run_async():
-                 nonlocal function_response
-                 async for chunk in get_tools_result_async(
-                     function_call_name, function_full_response, function_call_max_tokens,
-                     model or self.engine, claude3, kwargs.get('api_key', self.api_key),
-                     self.api_url, use_plugins=False, model=model or self.engine,
-                     add_message=self.add_to_conversation, convo_id=convo_id, language=language
-                 ):
-                     if "function_response:" in chunk:
-                         function_response = chunk.replace("function_response:", "")
-                     else:
-                         yield chunk
-
-             # 使用封装后的函数
-             for chunk in async_generator_to_sync(run_async()):
-                 yield chunk
-
-             response_role = "assistant"
-             if self.conversation[convo_id][-1]["role"] == "function" and self.conversation[convo_id][-1]["name"] == "get_search_results":
-                 mess = self.conversation[convo_id].pop(-1)
-             yield from self.ask_stream(function_response, response_role, convo_id=convo_id, function_name=function_call_name, total_tokens=total_tokens, model=model or self.engine, tools_id=tools_id, function_full_response=function_full_response, api_key=kwargs.get('api_key', self.api_key), plugins=kwargs.get("plugins", PLUGINS), system_prompt=system_prompt)
-         else:
-             if self.conversation[convo_id][-1]["role"] == "function" and self.conversation[convo_id][-1]["name"] == "get_search_results":
-                 mess = self.conversation[convo_id].pop(-1)
-             self.add_to_conversation(full_response, response_role, convo_id=convo_id, total_tokens=total_tokens, pass_history=pass_history)
-             self.function_calls_counter = {}
-             if pass_history <= 2 and len(self.conversation[convo_id]) >= 2 and ("You are a translation engine" in self.conversation[convo_id][-2]["content"] or (type(self.conversation[convo_id][-2]["content"]) == list and "You are a translation engine" in self.conversation[convo_id][-2]["content"][0]["text"])):
-                 self.conversation[convo_id].pop(-1)
-                 self.conversation[convo_id].pop(-1)
-
-     async def ask_stream_async(
-         self,
-         prompt: str,
-         role: str = "user",
-         convo_id: str = "default",
-         model: str = "",
-         pass_history: int = 9999,
-         model_max_tokens: int = 4096,
-         tools_id: str = "",
-         total_tokens: int = 0,
-         function_name: str = "",
-         function_full_response: str = "",
-         language: str = "English",
-         system_prompt: str = None,
-         **kwargs,
-     ):
-         self.add_to_conversation(prompt, role, convo_id=convo_id, tools_id=tools_id, total_tokens=total_tokens, function_name=function_name, function_full_response=function_full_response, pass_history=pass_history)
-         # self.__truncate_conversation(convo_id=convo_id)
-         # print(self.conversation[convo_id])
-
-         url = self.api_url.source_api_url
-         now_model = model or self.engine
-         headers = {
-             "content-type": "application/json",
-             "x-api-key": f"{kwargs.get('api_key', self.api_key)}",
-             "anthropic-version": "2023-06-01",
-             "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" if "claude-3-5-sonnet" in now_model else "tools-2024-05-16",
-         }
-
-         json_post = {
-             "model": now_model,
-             "messages": self.conversation[convo_id] if pass_history else [{
-                 "role": "user",
-                 "content": prompt
-             }],
-             "temperature": kwargs.get("temperature", self.temperature),
-             "top_p": kwargs.get("top_p", self.top_p),
-             "max_tokens": 8192 if "claude-3-5-sonnet" in now_model else model_max_tokens,
-             "stream": True,
-         }
-         json_post["system"] = system_prompt or self.system_prompt
-         plugins = kwargs.get("plugins", PLUGINS)
-         if all(value == False for value in plugins.values()) == False and self.use_plugins:
-             json_post.update(copy.deepcopy(claude_tools_list["base"]))
-             for item in plugins.keys():
-                 try:
-                     if plugins[item]:
-                         json_post["tools"].append(claude_tools_list[item])
-                 except:
-                     pass
-
-         if self.print_log:
-             replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(json_post)))
-             print(json.dumps(replaced_text, indent=4, ensure_ascii=False))
-
-         try:
-             response = self.session.post(
-                 url,
-                 headers=headers,
-                 json=json_post,
-                 timeout=kwargs.get("timeout", self.timeout),
-                 stream=True,
-             )
-         except ConnectionError:
-             print("连接错误,请检查服务器状态或网络连接。")
-             return
-         except requests.exceptions.ReadTimeout:
-             print("请求超时,请检查网络连接或增加超时时间。{e}")
-             return
-         except Exception as e:
-             print(f"发生了未预料的错误: {e}")
-             return
-
-         if response.status_code != 200:
-             print(response.text)
-             raise BaseException(f"{response.status_code} {response.reason} {response.text}")
-         response_role: str = "assistant"
-         full_response: str = ""
-         need_function_call: bool = False
-         function_call_name: str = ""
-         function_full_response: str = ""
-         total_tokens = 0
-         tools_id = ""
-         for line in response.iter_lines():
-             if not line or line.decode("utf-8")[:6] == "event:" or line.decode("utf-8") == "data: {}":
-                 continue
-             # print(line.decode("utf-8"))
-             # if "tool_use" in line.decode("utf-8"):
-             #     tool_input = json.loads(line.decode("utf-8")["content"][1]["input"])
-             # else:
-             #     line = line.decode("utf-8")[6:]
-             line = line.decode("utf-8")[5:]
-             if line.startswith(" "):
-                 line = line[1:]
-             # print(line)
-             resp: dict = json.loads(line)
-             if resp.get("error"):
-                 print("error:", resp["error"])
-                 raise BaseException(f"{resp['error']}")
-
-             message = resp.get("message")
-             if message:
-                 usage = message.get("usage")
-                 input_tokens = usage.get("input_tokens", 0)
-                 # output_tokens = usage.get("output_tokens", 0)
-                 output_tokens = 0
-                 total_tokens = total_tokens + input_tokens + output_tokens
-
-             usage = resp.get("usage")
-             if usage:
-                 input_tokens = usage.get("input_tokens", 0)
-                 output_tokens = usage.get("output_tokens", 0)
-                 total_tokens = total_tokens + input_tokens + output_tokens
-                 if self.print_log:
-                     print("\n\rtotal_tokens", total_tokens)
-
-             tool_use = resp.get("content_block")
-             if tool_use and "tool_use" == tool_use['type']:
-                 # print("tool_use", tool_use)
-                 tools_id = tool_use["id"]
-                 need_function_call = True
-                 if "name" in tool_use:
-                     function_call_name = tool_use["name"]
-             delta = resp.get("delta")
-             # print("delta", delta)
-             if not delta:
-                 continue
-             if "text" in delta:
-                 content = delta["text"]
-                 full_response += content
-                 yield content
-             if "partial_json" in delta:
-                 function_call_content = delta["partial_json"]
-                 function_full_response += function_call_content
-         # print("function_full_response", function_full_response)
-         # print("function_call_name", function_call_name)
-         # print("need_function_call", need_function_call)
-         if need_function_call:
-             function_full_response = check_json(function_full_response)
-             print("function_full_response", function_full_response)
-             function_response = ""
-             function_call_max_tokens = int(self.truncate_limit / 2)
-             async for chunk in get_tools_result_async(function_call_name, function_full_response, function_call_max_tokens, self.engine, claude3, kwargs.get('api_key', self.api_key), self.api_url, use_plugins=False, model=model or self.engine, add_message=self.add_to_conversation, convo_id=convo_id, language=language):
-                 if "function_response:" in chunk:
-                     function_response = chunk.replace("function_response:", "")
-                 else:
-                     yield chunk
-             response_role = "assistant"
-             if self.conversation[convo_id][-1]["role"] == "function" and self.conversation[convo_id][-1]["name"] == "get_search_results":
-                 mess = self.conversation[convo_id].pop(-1)
-             async for chunk in self.ask_stream_async(function_response, response_role, convo_id=convo_id, function_name=function_call_name, total_tokens=total_tokens, model=model or self.engine, tools_id=tools_id, function_full_response=function_full_response, api_key=kwargs.get('api_key', self.api_key), plugins=kwargs.get("plugins", PLUGINS), system_prompt=system_prompt):
-                 yield chunk
-             # yield from self.ask_stream(function_response, response_role, convo_id=convo_id, function_name=function_call_name, total_tokens=total_tokens, tools_id=tools_id, function_full_response=function_full_response)
-         else:
-             if self.conversation[convo_id][-1]["role"] == "function" and self.conversation[convo_id][-1]["name"] == "get_search_results":
-                 mess = self.conversation[convo_id].pop(-1)
-             self.add_to_conversation(full_response, response_role, convo_id=convo_id, total_tokens=total_tokens, pass_history=pass_history)
-             self.function_calls_counter = {}
-             if pass_history <= 2 and len(self.conversation[convo_id]) >= 2 and ("You are a translation engine" in self.conversation[convo_id][-2]["content"] or (type(self.conversation[convo_id][-2]["content"]) == list and "You are a translation engine" in self.conversation[convo_id][-2]["content"][0]["text"])):
-                 self.conversation[convo_id].pop(-1)
-                 self.conversation[convo_id].pop(-1)