camel-ai 0.2.20a1__py3-none-any.whl → 0.2.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of camel-ai might be problematic. Click here for more details.

Files changed (42)
  1. camel/__init__.py +1 -1
  2. camel/agents/chat_agent.py +2 -3
  3. camel/agents/knowledge_graph_agent.py +1 -5
  4. camel/benchmarks/apibench.py +1 -5
  5. camel/benchmarks/nexus.py +1 -5
  6. camel/benchmarks/ragbench.py +2 -2
  7. camel/bots/telegram_bot.py +1 -5
  8. camel/configs/__init__.py +3 -0
  9. camel/configs/aiml_config.py +80 -0
  10. camel/datagen/__init__.py +3 -1
  11. camel/datagen/self_improving_cot.py +821 -0
  12. camel/datagen/self_instruct/self_instruct.py +1 -1
  13. camel/embeddings/openai_embedding.py +10 -1
  14. camel/interpreters/docker/Dockerfile +12 -0
  15. camel/interpreters/docker_interpreter.py +19 -1
  16. camel/interpreters/subprocess_interpreter.py +97 -6
  17. camel/loaders/__init__.py +2 -0
  18. camel/loaders/mineru_extractor.py +250 -0
  19. camel/models/__init__.py +2 -0
  20. camel/models/aiml_model.py +147 -0
  21. camel/models/base_model.py +54 -1
  22. camel/models/deepseek_model.py +0 -18
  23. camel/models/model_factory.py +3 -0
  24. camel/models/siliconflow_model.py +1 -1
  25. camel/societies/workforce/role_playing_worker.py +2 -4
  26. camel/societies/workforce/single_agent_worker.py +1 -6
  27. camel/societies/workforce/workforce.py +3 -9
  28. camel/toolkits/__init__.py +5 -0
  29. camel/toolkits/mineru_toolkit.py +178 -0
  30. camel/toolkits/reddit_toolkit.py +8 -38
  31. camel/toolkits/sympy_toolkit.py +816 -0
  32. camel/toolkits/whatsapp_toolkit.py +11 -32
  33. camel/types/enums.py +25 -1
  34. camel/utils/__init__.py +7 -2
  35. camel/utils/commons.py +198 -21
  36. camel/utils/deduplication.py +232 -0
  37. camel/utils/token_counting.py +0 -38
  38. {camel_ai-0.2.20a1.dist-info → camel_ai-0.2.22.dist-info}/METADATA +10 -13
  39. {camel_ai-0.2.20a1.dist-info → camel_ai-0.2.22.dist-info}/RECORD +42 -34
  40. {camel_ai-0.2.20a1.dist-info → camel_ai-0.2.22.dist-info}/WHEEL +1 -1
  41. /camel/datagen/{cotdatagen.py → cot_datagen.py} +0 -0
  42. {camel_ai-0.2.20a1.dist-info → camel_ai-0.2.22.dist-info}/LICENSE +0 -0
@@ -361,7 +361,7 @@ class SelfInstructPipeline:
361
361
  in JSON format.
362
362
  """
363
363
  with open(self.data_output_path, 'w') as f:
364
- json.dump(self.machine_tasks, f, indent=4)
364
+ json.dump(self.machine_tasks, f, indent=4, ensure_ascii=False)
365
365
 
366
366
  def generate(self):
367
367
  r"""Execute the entire pipeline to generate machine instructions
@@ -30,6 +30,8 @@ class OpenAIEmbedding(BaseEmbedding[str]):
30
30
  model_type (EmbeddingModelType, optional): The model type to be
31
31
  used for text embeddings.
32
32
  (default: :obj:`TEXT_EMBEDDING_3_SMALL`)
33
+ url (Optional[str], optional): The url to the OpenAI service.
34
+ (default: :obj:`None`)
33
35
  api_key (str, optional): The API key for authenticating with the
34
36
  OpenAI service. (default: :obj:`None`)
35
37
  dimensions (int, optional): The text embedding output dimensions.
@@ -49,6 +51,7 @@ class OpenAIEmbedding(BaseEmbedding[str]):
49
51
  model_type: EmbeddingModelType = (
50
52
  EmbeddingModelType.TEXT_EMBEDDING_3_SMALL
51
53
  ),
54
+ url: str | None = None,
52
55
  api_key: str | None = None,
53
56
  dimensions: int | NotGiven = NOT_GIVEN,
54
57
  ) -> None:
@@ -61,7 +64,13 @@ class OpenAIEmbedding(BaseEmbedding[str]):
61
64
  assert isinstance(dimensions, int)
62
65
  self.output_dim = dimensions
63
66
  self._api_key = api_key or os.environ.get("OPENAI_API_KEY")
64
- self.client = OpenAI(timeout=180, max_retries=3, api_key=self._api_key)
67
+ self._url = url or os.environ.get("OPENAI_API_BASE_URL")
68
+ self.client = OpenAI(
69
+ timeout=180,
70
+ max_retries=3,
71
+ base_url=self._url,
72
+ api_key=self._api_key,
73
+ )
65
74
 
66
75
  def embed_list(
67
76
  self,
@@ -0,0 +1,12 @@
1
+ FROM python:3.9-slim
2
+
3
+ # Install R and required dependencies
4
+ RUN apt-get update && apt-get install -y \
5
+ r-base \
6
+ && rm -rf /var/lib/apt/lists/*
7
+
8
+ # Set working directory
9
+ WORKDIR /workspace
10
+
11
+ # Keep container running
12
+ CMD ["tail", "-f", "/dev/null"]
@@ -52,11 +52,13 @@ class DockerInterpreter(BaseInterpreter):
52
52
  _CODE_EXECUTE_CMD_MAPPING: ClassVar[Dict[str, str]] = {
53
53
  "python": "python {file_name}",
54
54
  "bash": "bash {file_name}",
55
+ "r": "Rscript {file_name}",
55
56
  }
56
57
 
57
58
  _CODE_EXTENSION_MAPPING: ClassVar[Dict[str, str]] = {
58
59
  "python": "py",
59
60
  "bash": "sh",
61
+ "r": "R",
60
62
  }
61
63
 
62
64
  _CODE_TYPE_MAPPING: ClassVar[Dict[str, str]] = {
@@ -67,6 +69,8 @@ class DockerInterpreter(BaseInterpreter):
67
69
  "shell": "bash",
68
70
  "bash": "bash",
69
71
  "sh": "bash",
72
+ "r": "r",
73
+ "R": "r",
70
74
  }
71
75
 
72
76
  def __init__(
@@ -104,8 +108,22 @@ class DockerInterpreter(BaseInterpreter):
104
108
  import docker
105
109
 
106
110
  client = docker.from_env()
111
+
112
+ # Build custom image with Python and R
113
+ dockerfile_path = Path(__file__).parent / "docker"
114
+ image_tag = "camel-interpreter:latest"
115
+ try:
116
+ client.images.get(image_tag)
117
+ except docker.errors.ImageNotFound:
118
+ logger.info("Building custom interpreter image...")
119
+ client.images.build(
120
+ path=str(dockerfile_path),
121
+ tag=image_tag,
122
+ rm=True,
123
+ )
124
+
107
125
  self._container = client.containers.run(
108
- "python:3.10",
126
+ image_tag,
109
127
  detach=True,
110
128
  name=f"camel-interpreter-{uuid.uuid4()}",
111
129
  command="tail -f /dev/null",
@@ -48,11 +48,13 @@ class SubprocessInterpreter(BaseInterpreter):
48
48
  _CODE_EXECUTE_CMD_MAPPING: ClassVar[Dict[str, str]] = {
49
49
  "python": "python {file_name}",
50
50
  "bash": "bash {file_name}",
51
+ "r": "Rscript {file_name}",
51
52
  }
52
53
 
53
54
  _CODE_EXTENSION_MAPPING: ClassVar[Dict[str, str]] = {
54
55
  "python": "py",
55
56
  "bash": "sh",
57
+ "r": "R",
56
58
  }
57
59
 
58
60
  _CODE_TYPE_MAPPING: ClassVar[Dict[str, str]] = {
@@ -63,6 +65,8 @@ class SubprocessInterpreter(BaseInterpreter):
63
65
  "shell": "bash",
64
66
  "bash": "bash",
65
67
  "sh": "bash",
68
+ "r": "r",
69
+ "R": "r",
66
70
  }
67
71
 
68
72
  def __init__(
@@ -98,15 +102,91 @@ class SubprocessInterpreter(BaseInterpreter):
98
102
  if not file.is_file():
99
103
  raise RuntimeError(f"{file} is not a file.")
100
104
  code_type = self._check_code_type(code_type)
101
- cmd = shlex.split(
102
- self._CODE_EXECUTE_CMD_MAPPING[code_type].format(
103
- file_name=str(file)
105
+ if self._CODE_TYPE_MAPPING[code_type] == "python":
106
+ # For Python code, use ast to analyze and modify the code
107
+ import ast
108
+
109
+ import astor
110
+
111
+ with open(file, 'r') as f:
112
+ source = f.read()
113
+
114
+ # Parse the source code
115
+ try:
116
+ tree = ast.parse(source)
117
+ # Get the last node
118
+ if tree.body:
119
+ last_node = tree.body[-1]
120
+ # Handle expressions that would normally not produce output
121
+ # For example: In a REPL, typing '1 + 2' should show '3'
122
+
123
+ if isinstance(last_node, ast.Expr):
124
+ # Only wrap in print(repr()) if it's not already a
125
+ # print call
126
+ if not (
127
+ isinstance(last_node.value, ast.Call)
128
+ and isinstance(last_node.value.func, ast.Name)
129
+ and last_node.value.func.id == 'print'
130
+ ):
131
+ # Transform the AST to wrap the expression in print
132
+ # (repr())
133
+ # Example transformation:
134
+ # Before: x + y
135
+ # After: print(repr(x + y))
136
+ tree.body[-1] = ast.Expr(
137
+ value=ast.Call(
138
+ # Create print() function call
139
+ func=ast.Name(id='print', ctx=ast.Load()),
140
+ args=[
141
+ ast.Call(
142
+ # Create repr() function call
143
+ func=ast.Name(
144
+ id='repr', ctx=ast.Load()
145
+ ),
146
+ # Pass the original expression as
147
+ # argument to repr()
148
+ args=[last_node.value],
149
+ keywords=[],
150
+ )
151
+ ],
152
+ keywords=[],
153
+ )
154
+ )
155
+ # Fix missing source locations
156
+ ast.fix_missing_locations(tree)
157
+ # Convert back to source
158
+ modified_source = astor.to_source(tree)
159
+ # Create a temporary file with the modified source
160
+ temp_file = self._create_temp_file(modified_source, "py")
161
+ cmd = shlex.split(f"python {temp_file!s}")
162
+ except SyntaxError:
163
+ # If parsing fails, run the original file
164
+ cmd = shlex.split(
165
+ self._CODE_EXECUTE_CMD_MAPPING[code_type].format(
166
+ file_name=str(file)
167
+ )
168
+ )
169
+ else:
170
+ # For non-Python code, use standard execution
171
+ cmd = shlex.split(
172
+ self._CODE_EXECUTE_CMD_MAPPING[code_type].format(
173
+ file_name=str(file)
174
+ )
104
175
  )
105
- )
176
+
106
177
  proc = subprocess.Popen(
107
178
  cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
108
179
  )
109
180
  stdout, stderr = proc.communicate()
181
+ return_code = proc.returncode
182
+
183
+ # Clean up temporary file if it was created
184
+ if (
185
+ self._CODE_TYPE_MAPPING[code_type] == "python"
186
+ and 'temp_file' in locals()
187
+ ):
188
+ temp_file.unlink()
189
+
110
190
  if self.print_stdout and stdout:
111
191
  print("======stdout======")
112
192
  print(Fore.GREEN + stdout + Fore.RESET)
@@ -115,8 +195,19 @@ class SubprocessInterpreter(BaseInterpreter):
115
195
  print("======stderr======")
116
196
  print(Fore.RED + stderr + Fore.RESET)
117
197
  print("==================")
118
- exec_result = f"{stdout}"
119
- exec_result += f"(stderr: {stderr})" if stderr else ""
198
+
199
+ # Build the execution result
200
+ exec_result = ""
201
+ if stdout:
202
+ exec_result += stdout
203
+ if stderr:
204
+ exec_result += f"(stderr: {stderr})"
205
+ if return_code != 0:
206
+ error_msg = f"(Execution failed with return code {return_code})"
207
+ if not stderr:
208
+ exec_result += error_msg
209
+ elif error_msg not in stderr:
210
+ exec_result += error_msg
120
211
  return exec_result
121
212
 
122
213
  def run(
camel/loaders/__init__.py CHANGED
@@ -17,6 +17,7 @@ from .base_io import File, create_file, create_file_from_raw_bytes
17
17
  from .chunkr_reader import ChunkrReader
18
18
  from .firecrawl_reader import Firecrawl
19
19
  from .jina_url_reader import JinaURLReader
20
+ from .mineru_extractor import MinerU
20
21
  from .panda_reader import PandaReader
21
22
  from .unstructured_io import UnstructuredIO
22
23
 
@@ -30,4 +31,5 @@ __all__ = [
30
31
  'Apify',
31
32
  'ChunkrReader',
32
33
  'PandaReader',
34
+ 'MinerU',
33
35
  ]
@@ -0,0 +1,250 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ import os
16
+ import time
17
+ from typing import Dict, List, Optional, Union
18
+
19
+ import requests
20
+
21
+ from camel.utils import api_keys_required
22
+
23
+
24
class MinerU:
    r"""Document extraction service supporting OCR, formula recognition
    and tables.

    Args:
        api_key (str, optional): Authentication key for MinerU API service.
            If not provided, will use MINERU_API_KEY environment variable.
            (default: :obj:`None`)
        api_url (str, optional): Base URL endpoint for the MinerU API service.
            (default: :obj:`"https://mineru.net/api/v4"`)

    Note:
        - Single file size limit: 200MB
        - Page limit per file: 600 pages
        - Daily high-priority parsing quota: 2000 pages
        - Some URLs (GitHub, AWS) may timeout due to network restrictions
    """

    @api_keys_required(
        [
            ("api_key", "MINERU_API_KEY"),
        ]
    )
    def __init__(
        self,
        api_key: Optional[str] = None,
        api_url: Optional[str] = "https://mineru.net/api/v4",
        is_ocr: bool = False,
        enable_formula: bool = False,
        enable_table: bool = True,
        layout_model: str = "doclayout_yolo",
        language: str = "en",
    ) -> None:
        r"""Initialize MinerU extractor.

        Args:
            api_key (str, optional): Authentication key for MinerU API service.
                If not provided, will use MINERU_API_KEY environment variable.
            api_url (str, optional): Base URL endpoint for MinerU API service.
                (default: "https://mineru.net/api/v4")
            is_ocr (bool, optional): Enable optical character recognition.
                (default: :obj:`False`)
            enable_formula (bool, optional): Enable formula recognition.
                (default: :obj:`False`)
            enable_table (bool, optional): Enable table detection, extraction.
                (default: :obj:`True`)
            layout_model (str, optional): Model for document layout detection.
                Options are 'doclayout_yolo' or 'layoutlmv3'.
                (default: :obj:`"doclayout_yolo"`)
            language (str, optional): Primary language of the document.
                (default: :obj:`"en"`)
        """
        self._api_key = api_key or os.environ.get("MINERU_API_KEY")
        self._api_url = api_url
        self._headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
            "Accept": "*/*",
        }
        self.is_ocr = is_ocr
        self.enable_formula = enable_formula
        self.enable_table = enable_table
        self.layout_model = layout_model
        self.language = language

    def extract_url(self, url: str) -> Dict:
        r"""Extract content from a URL document.

        Args:
            url (str): Document URL to extract content from.

        Returns:
            Dict: Task identifier for tracking extraction progress.

        Raises:
            RuntimeError: If the API request fails or returns an error.
        """
        endpoint = f"{self._api_url}/extract/task"
        payload = {"url": url}

        try:
            response = requests.post(
                endpoint,
                headers=self._headers,
                json=payload,
            )
            response.raise_for_status()
            return response.json()["data"]
        except Exception as e:
            # Chain the cause so HTTP / JSON errors remain debuggable.
            raise RuntimeError(f"Failed to extract URL: {e}") from e

    def batch_extract_urls(
        self,
        files: List[Dict[str, Union[str, bool]]],
    ) -> str:
        r"""Extract content from multiple document URLs in batch.

        Args:
            files (List[Dict[str, Union[str, bool]]]): List of document
                configurations. Each document requires 'url' and optionally
                'is_ocr' and 'data_id' parameters.

        Returns:
            str: Batch identifier for tracking extraction progress.

        Raises:
            RuntimeError: If the API request fails or returns an error.
        """
        endpoint = f"{self._api_url}/extract/task/batch"
        payload = {"files": files}

        try:
            response = requests.post(
                endpoint,
                headers=self._headers,
                json=payload,
            )
            response.raise_for_status()
            return response.json()["data"]["batch_id"]
        except Exception as e:
            raise RuntimeError(f"Failed to batch extract URLs: {e}") from e

    def get_task_status(self, task_id: str) -> Dict:
        r"""Retrieve status of a single extraction task.

        Args:
            task_id (str): Unique identifier of the extraction task.

        Returns:
            Dict: Current task status and results if completed.

        Raises:
            RuntimeError: If the API request fails or returns an error.
        """
        endpoint = f"{self._api_url}/extract/task/{task_id}"

        try:
            response = requests.get(endpoint, headers=self._headers)
            response.raise_for_status()
            return response.json()["data"]
        except Exception as e:
            raise RuntimeError(f"Failed to get task status: {e}") from e

    def get_batch_status(self, batch_id: str) -> Dict:
        r"""Retrieve status of a batch extraction task.

        Args:
            batch_id (str): Unique identifier of the batch extraction task.

        Returns:
            Dict: Current status and results for all documents in the batch.

        Raises:
            RuntimeError: If the API request fails or returns an error.
        """
        endpoint = f"{self._api_url}/extract-results/batch/{batch_id}"

        try:
            response = requests.get(endpoint, headers=self._headers)
            response.raise_for_status()
            return response.json()["data"]
        except Exception as e:
            raise RuntimeError(f"Failed to get batch status: {e}") from e

    def wait_for_completion(
        self,
        task_id: str,
        is_batch: bool = False,
        timeout: float = 100,
        check_interval: float = 5,
    ) -> Dict:
        r"""Monitor task until completion or timeout.

        Args:
            task_id (str): Unique identifier of the task or batch.
            is_batch (bool, optional): Indicates if task is a batch operation.
                (default: :obj:`False`)
            timeout (float, optional): Maximum wait time in seconds.
                (default: :obj:`100`)
            check_interval (float, optional): Time between status checks in
                seconds. (default: :obj:`5`)

        Returns:
            Dict: Final task status and extraction results.

        Raises:
            TimeoutError: If task exceeds specified timeout duration.
            RuntimeError: If task fails or encounters processing error.
        """
        start_time = time.time()
        while True:
            if time.time() - start_time > timeout:
                raise TimeoutError(
                    f"Task {task_id} timed out after {timeout}s"
                )

            try:
                status = (
                    self.get_batch_status(task_id)
                    if is_batch
                    else self.get_task_status(task_id)
                )

                if is_batch:
                    # Check batch status: every sub-task must reach 'done';
                    # a single 'failed' sub-task fails the whole batch.
                    all_done = True
                    failed_tasks = []
                    for result in status.get('extract_result', []):
                        if result.get('state') == 'failed':
                            failed_tasks.append(
                                f"{result.get('data_id')}:"
                                f" {result.get('err_msg')}"
                            )
                        elif result.get('state') != 'done':
                            all_done = False
                            break

                    if failed_tasks:
                        raise RuntimeError(
                            f"Batch tasks failed: {'; '.join(failed_tasks)}"
                        )
                    if all_done:
                        return status
                else:
                    # Check single task status
                    state = status.get('state')
                    if state == 'failed':
                        raise RuntimeError(
                            f"Task failed: {status.get('err_msg')}"
                        )
                    elif state == 'done':
                        return status

            except Exception as e:
                # RuntimeErrors already carry a meaningful message; anything
                # else is wrapped with its cause preserved for debugging.
                if not isinstance(e, RuntimeError):
                    raise RuntimeError(f"Error checking status: {e}") from e
                raise

            time.sleep(check_interval)
camel/models/__init__.py CHANGED
@@ -11,6 +11,7 @@
11
11
  # See the License for the specific language governing permissions and
12
12
  # limitations under the License.
13
13
  # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from .aiml_model import AIMLModel
14
15
  from .anthropic_model import AnthropicModel
15
16
  from .azure_openai_model import AzureOpenAIModel
16
17
  from .base_model import BaseModelBackend
@@ -72,4 +73,5 @@ __all__ = [
72
73
  'FishAudioModel',
73
74
  'InternLMModel',
74
75
  'MoonshotModel',
76
+ 'AIMLModel',
75
77
  ]
@@ -0,0 +1,147 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ import os
15
+ from typing import Any, Dict, List, Optional, Union
16
+
17
+ from openai import OpenAI, Stream
18
+
19
+ from camel.configs import AIML_API_PARAMS, AIMLConfig
20
+ from camel.messages import OpenAIMessage
21
+ from camel.models.base_model import BaseModelBackend
22
+ from camel.types import (
23
+ ChatCompletion,
24
+ ChatCompletionChunk,
25
+ ModelType,
26
+ )
27
+ from camel.utils import (
28
+ BaseTokenCounter,
29
+ OpenAITokenCounter,
30
+ api_keys_required,
31
+ )
32
+
33
+
34
class AIMLModel(BaseModelBackend):
    r"""AIML API in a unified BaseModelBackend interface.

    Args:
        model_type (Union[ModelType, str]): Model for which a backend is
            created.
        model_config_dict (Optional[Dict[str, Any]], optional): A dictionary
            that will be fed into OpenAI client. If :obj:`None`,
            :obj:`AIMLConfig().as_dict()` will be used.
            (default: :obj:`None`)
        api_key (Optional[str], optional): The API key for authenticating with
            the AIML service. (default: :obj:`None`)
        url (Optional[str], optional): The URL to the AIML service. If
            not provided, :obj:`https://api.aimlapi.com/v1` will be used.
            (default: :obj:`None`)
        token_counter (Optional[BaseTokenCounter], optional): Token counter to
            use for the model. If not provided, :obj:`OpenAITokenCounter(
            ModelType.GPT_4O_MINI)` will be used.
            (default: :obj:`None`)
    """

    @api_keys_required(
        [
            ("api_key", 'AIML_API_KEY'),
        ]
    )
    def __init__(
        self,
        model_type: Union[ModelType, str],
        model_config_dict: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        token_counter: Optional[BaseTokenCounter] = None,
    ) -> None:
        # Fall back to the default AIML config when none is supplied.
        config = (
            AIMLConfig().as_dict()
            if model_config_dict is None
            else model_config_dict
        )
        # Resolve credentials and endpoint from the environment when absent.
        resolved_key = api_key or os.environ.get("AIML_API_KEY")
        resolved_url = url or os.environ.get(
            "AIML_API_BASE_URL",
            "https://api.aimlapi.com/v1",
        )
        super().__init__(
            model_type, config, resolved_key, resolved_url, token_counter
        )
        self._client = OpenAI(
            timeout=180,
            max_retries=3,
            api_key=self._api_key,
            base_url=self._url,
        )

    def run(
        self,
        messages: List[OpenAIMessage],
    ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
        r"""Runs inference of OpenAI chat completion.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat history
                in OpenAI API format.

        Returns:
            Union[ChatCompletion, Stream[ChatCompletionChunk]]:
                `ChatCompletion` in the non-stream mode, or
                `Stream[ChatCompletionChunk]` in the stream mode.
        """
        # Work on a copy so the stored configuration is never mutated.
        request_config = dict(self.model_config_dict)

        # Handle special case for tools parameter: normalize None to [].
        if request_config.get('tools') is None:
            request_config['tools'] = []

        return self._client.chat.completions.create(
            messages=messages, model=self.model_type, **request_config
        )

    @property
    def token_counter(self) -> BaseTokenCounter:
        r"""Initialize the token counter for the model backend.

        Returns:
            BaseTokenCounter: The token counter following the model's
                tokenization style.
        """
        # Lazily create and cache the default counter on first access.
        counter = self._token_counter
        if not counter:
            counter = OpenAITokenCounter(ModelType.GPT_4O_MINI)
            self._token_counter = counter
        return counter

    def check_model_config(self):
        r"""Check whether the model configuration contains any
        unexpected arguments to AIML API.

        Raises:
            ValueError: If the model configuration dictionary contains any
                unexpected arguments to AIML API.
        """
        for name in self.model_config_dict:
            if name in AIML_API_PARAMS:
                continue
            raise ValueError(
                f"Unexpected argument `{name}` is "
                "input into AIML model backend."
            )

    @property
    def stream(self) -> bool:
        r"""Returns whether the model is in stream mode, which sends partial
        results each time.

        Returns:
            bool: Whether the model is in stream mode.
        """
        return self.model_config_dict.get('stream', False)