camel-ai 0.1.1__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Potentially problematic release.
This version of camel-ai might be problematic.
- camel/__init__.py +1 -11
- camel/agents/__init__.py +7 -5
- camel/agents/chat_agent.py +134 -86
- camel/agents/critic_agent.py +28 -17
- camel/agents/deductive_reasoner_agent.py +235 -0
- camel/agents/embodied_agent.py +92 -40
- camel/agents/knowledge_graph_agent.py +221 -0
- camel/agents/role_assignment_agent.py +27 -17
- camel/agents/task_agent.py +60 -34
- camel/agents/tool_agents/base.py +0 -1
- camel/agents/tool_agents/hugging_face_tool_agent.py +7 -4
- camel/configs/__init__.py +29 -0
- camel/configs/anthropic_config.py +73 -0
- camel/configs/base_config.py +22 -0
- camel/{configs.py → configs/openai_config.py} +37 -64
- camel/embeddings/__init__.py +2 -0
- camel/embeddings/base.py +3 -2
- camel/embeddings/openai_embedding.py +10 -5
- camel/embeddings/sentence_transformers_embeddings.py +65 -0
- camel/functions/__init__.py +18 -3
- camel/functions/google_maps_function.py +335 -0
- camel/functions/math_functions.py +7 -7
- camel/functions/open_api_function.py +380 -0
- camel/functions/open_api_specs/coursera/__init__.py +13 -0
- camel/functions/open_api_specs/coursera/openapi.yaml +82 -0
- camel/functions/open_api_specs/klarna/__init__.py +13 -0
- camel/functions/open_api_specs/klarna/openapi.yaml +87 -0
- camel/functions/open_api_specs/speak/__init__.py +13 -0
- camel/functions/open_api_specs/speak/openapi.yaml +151 -0
- camel/functions/openai_function.py +346 -42
- camel/functions/retrieval_functions.py +61 -0
- camel/functions/search_functions.py +100 -35
- camel/functions/slack_functions.py +275 -0
- camel/functions/twitter_function.py +484 -0
- camel/functions/weather_functions.py +36 -23
- camel/generators.py +65 -46
- camel/human.py +17 -11
- camel/interpreters/__init__.py +25 -0
- camel/interpreters/base.py +49 -0
- camel/{utils/python_interpreter.py → interpreters/internal_python_interpreter.py} +129 -48
- camel/interpreters/interpreter_error.py +19 -0
- camel/interpreters/subprocess_interpreter.py +190 -0
- camel/loaders/__init__.py +22 -0
- camel/{functions/base_io_functions.py → loaders/base_io.py} +38 -35
- camel/{functions/unstructured_io_fuctions.py → loaders/unstructured_io.py} +199 -110
- camel/memories/__init__.py +17 -7
- camel/memories/agent_memories.py +156 -0
- camel/memories/base.py +97 -32
- camel/memories/blocks/__init__.py +21 -0
- camel/memories/{chat_history_memory.py → blocks/chat_history_block.py} +34 -34
- camel/memories/blocks/vectordb_block.py +101 -0
- camel/memories/context_creators/__init__.py +3 -2
- camel/memories/context_creators/score_based.py +32 -20
- camel/memories/records.py +6 -5
- camel/messages/__init__.py +2 -2
- camel/messages/base.py +99 -16
- camel/messages/func_message.py +7 -4
- camel/models/__init__.py +6 -2
- camel/models/anthropic_model.py +146 -0
- camel/models/base_model.py +10 -3
- camel/models/model_factory.py +17 -11
- camel/models/open_source_model.py +25 -13
- camel/models/openai_audio_models.py +251 -0
- camel/models/openai_model.py +20 -13
- camel/models/stub_model.py +10 -5
- camel/prompts/__init__.py +7 -5
- camel/prompts/ai_society.py +21 -14
- camel/prompts/base.py +54 -47
- camel/prompts/code.py +22 -14
- camel/prompts/evaluation.py +8 -5
- camel/prompts/misalignment.py +26 -19
- camel/prompts/object_recognition.py +35 -0
- camel/prompts/prompt_templates.py +14 -8
- camel/prompts/role_description_prompt_template.py +16 -10
- camel/prompts/solution_extraction.py +9 -5
- camel/prompts/task_prompt_template.py +24 -21
- camel/prompts/translation.py +9 -5
- camel/responses/agent_responses.py +5 -2
- camel/retrievers/__init__.py +26 -0
- camel/retrievers/auto_retriever.py +330 -0
- camel/retrievers/base.py +69 -0
- camel/retrievers/bm25_retriever.py +140 -0
- camel/retrievers/cohere_rerank_retriever.py +108 -0
- camel/retrievers/vector_retriever.py +183 -0
- camel/societies/__init__.py +1 -1
- camel/societies/babyagi_playing.py +56 -32
- camel/societies/role_playing.py +188 -133
- camel/storages/__init__.py +18 -0
- camel/storages/graph_storages/__init__.py +23 -0
- camel/storages/graph_storages/base.py +82 -0
- camel/storages/graph_storages/graph_element.py +74 -0
- camel/storages/graph_storages/neo4j_graph.py +582 -0
- camel/storages/key_value_storages/base.py +1 -2
- camel/storages/key_value_storages/in_memory.py +1 -2
- camel/storages/key_value_storages/json.py +8 -13
- camel/storages/vectordb_storages/__init__.py +33 -0
- camel/storages/vectordb_storages/base.py +202 -0
- camel/storages/vectordb_storages/milvus.py +396 -0
- camel/storages/vectordb_storages/qdrant.py +373 -0
- camel/terminators/__init__.py +1 -1
- camel/terminators/base.py +2 -3
- camel/terminators/response_terminator.py +21 -12
- camel/terminators/token_limit_terminator.py +5 -3
- camel/toolkits/__init__.py +21 -0
- camel/toolkits/base.py +22 -0
- camel/toolkits/github_toolkit.py +245 -0
- camel/types/__init__.py +18 -6
- camel/types/enums.py +129 -15
- camel/types/openai_types.py +10 -5
- camel/utils/__init__.py +20 -13
- camel/utils/commons.py +170 -85
- camel/utils/token_counting.py +135 -15
- {camel_ai-0.1.1.dist-info → camel_ai-0.1.4.dist-info}/METADATA +123 -75
- camel_ai-0.1.4.dist-info/RECORD +119 -0
- {camel_ai-0.1.1.dist-info → camel_ai-0.1.4.dist-info}/WHEEL +1 -1
- camel/memories/context_creators/base.py +0 -72
- camel_ai-0.1.1.dist-info/RECORD +0 -75
camel/utils/commons.py
CHANGED

@@ -11,26 +11,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
-import inspect
+import importlib
 import os
+import platform
 import re
 import socket
 import time
 import zipfile
 from functools import wraps
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    TypeVar,
-    cast,
-)
+from typing import Any, Callable, List, Optional, Set, TypeVar, cast
 from urllib.parse import urlparse
 
+import pydantic
 import requests
 
 from camel.types import TaskType
@@ -38,9 +30,8 @@ from camel.types import TaskType
 F = TypeVar('F', bound=Callable[..., Any])
 
 
-def openai_api_key_required(func: F) -> F:
-    r"""Decorator that checks if the OpenAI API key is available in the
-    environment variables.
+def api_key_required(func: F) -> F:
+    r"""Decorator that checks if the API key is available either as an environment variable or passed directly.
 
     Args:
         func (callable): The function to be wrapped.
@@ -49,16 +40,22 @@ def openai_api_key_required(func: F) -> F:
         callable: The decorated function.
 
     Raises:
-        ValueError: If the OpenAI API key is not found in the
-            environment variables.
+        ValueError: If the API key is not found, either as an environment
+            variable or directly passed.
     """
 
     @wraps(func)
     def wrapper(self, *args, **kwargs):
-        if 'OPENAI_API_KEY' in os.environ:
+        if self.model_type.is_openai:
+            if not self._api_key and 'OPENAI_API_KEY' not in os.environ:
+                raise ValueError('OpenAI API key not found.')
+            return func(self, *args, **kwargs)
+        elif self.model_type.is_anthropic:
+            if 'ANTHROPIC_API_KEY' not in os.environ:
+                raise ValueError('Anthropic API key not found.')
             return func(self, *args, **kwargs)
         else:
-            raise ValueError('OpenAI API key not found.')
+            raise ValueError('Unsupported model type.')
 
     return cast(F, wrapper)
 
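The renamed `api_key_required` decorator now dispatches on the wrapped object's `model_type` and honors a directly passed key for OpenAI backends. A minimal usage sketch (`ToyBackend` is a hypothetical stand-in for camel's model backends, which expose `model_type` and `_api_key`; the import assumes the decorator is re-exported from `camel.utils`):

```python
import os
from typing import Optional

from camel.types import ModelType
from camel.utils import api_key_required  # assumed re-export


class ToyBackend:
    # Hypothetical stand-in for a camel model backend.
    def __init__(self, model_type: ModelType, api_key: Optional[str] = None):
        self.model_type = model_type
        self._api_key = api_key  # checked before falling back to the env var

    @api_key_required
    def run(self) -> str:
        return "backend called"


# With a key passed directly, no OPENAI_API_KEY env var is needed.
print(ToyBackend(ModelType.GPT_4, api_key="sk-test").run())

# Without either source, the wrapper raises ValueError.
os.environ.pop("OPENAI_API_KEY", None)
try:
    ToyBackend(ModelType.GPT_4).run()
except ValueError as e:
    print(e)  # OpenAI API key not found.
```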
@@ -116,12 +113,26 @@ def get_first_int(string: str) -> Optional[int]:
 
 
 def download_tasks(task: TaskType, folder_path: str) -> None:
+    r"""Downloads task-related files from a specified URL and extracts them.
+
+    This function downloads a zip file containing tasks based on the specified
+    `task` type from a predefined URL, saves it to `folder_path`, and then
+    extracts the contents of the zip file into the same folder. After
+    extraction, the zip file is deleted.
+
+    Args:
+        task (TaskType): An enum representing the type of task to download.
+        folder_path (str): The path of the folder where the zip file will be
+            downloaded and extracted.
+    """
     # Define the path to save the zip file
     zip_file_path = os.path.join(folder_path, "tasks.zip")
 
     # Download the zip file from the Google Drive link
-    response = requests.get("https://huggingface.co/datasets/camel-ai/"
-                            f"metadata/resolve/main/{task.value}_tasks.zip")
+    response = requests.get(
+        "https://huggingface.co/datasets/camel-ai/"
+        f"metadata/resolve/main/{task.value}_tasks.zip"
+    )
 
     # Save the zip file
     with open(zip_file_path, "wb") as f:
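The hunk above only adds the docstring and rewraps the URL; behavior is unchanged. A short usage sketch (assuming `download_tasks` remains exported from `camel.utils`):

```python
from camel.types import TaskType
from camel.utils import download_tasks  # assumed re-export

# Fetches ai_society_tasks.zip from the Hugging Face dataset repo,
# unpacks it into ./tasks, then deletes the zip.
download_tasks(TaskType.AI_SOCIETY, folder_path="./tasks")
```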
@@ -134,70 +145,6 @@ def download_tasks(task: TaskType, folder_path: str) -> None:
     os.remove(zip_file_path)
 
 
-def parse_doc(func: Callable) -> Dict[str, Any]:
-    r"""Parse the docstrings of a function to extract the function name,
-    description and parameters.
-
-    Args:
-        func (Callable): The function to be parsed.
-    Returns:
-        Dict[str, Any]: A dictionary with the function's name,
-            description, and parameters.
-    """
-
-    doc = inspect.getdoc(func)
-    if not doc:
-        raise ValueError(
-            f"Invalid function {func.__name__}: no docstring provided.")
-
-    properties = {}
-    required = []
-
-    parts = re.split(r'\n\s*\n', doc)
-    func_desc = parts[0].strip()
-
-    args_section = next((p for p in parts if 'Args:' in p), None)
-    if args_section:
-        args_descs: List[Tuple[str, str, str, ]] = re.findall(
-            r'(\w+)\s*\((\w+)\):\s*(.*)', args_section)
-        properties = {
-            name.strip(): {
-                'type': type,
-                'description': desc
-            }
-            for name, type, desc in args_descs
-        }
-        for name in properties:
-            required.append(name)
-
-    # Parameters from the function signature
-    sign_params = list(inspect.signature(func).parameters.keys())
-    if len(sign_params) != len(required):
-        raise ValueError(
-            f"Number of parameters in function signature ({len(sign_params)})"
-            f" does not match that in docstring ({len(required)}).")
-
-    for param in sign_params:
-        if param not in required:
-            raise ValueError(f"Parameter '{param}' in function signature"
-                             " is missing in the docstring.")
-
-    parameters = {
-        "type": "object",
-        "properties": properties,
-        "required": required,
-    }
-
-    # Construct the function dictionary
-    function_dict = {
-        "name": func.__name__,
-        "description": func_desc,
-        "parameters": parameters,
-    }
-
-    return function_dict
-
-
 def get_task_list(task_response: str) -> List[str]:
     r"""Parse the response of the Agent and return task list.
 
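`parse_doc` is removed outright; judging by the file list above, its docstring-to-schema role presumably moves into the much-expanded camel/functions/openai_function.py. For reference, a sketch of what the removed helper produced for a Google-style docstring:

```python
# Reference sketch of the removed parse_doc behavior (0.1.1 API).
def add(a: int, b: int) -> int:
    r"""Adds two numbers.

    Args:
        a (int): First addend.
        b (int): Second addend.
    """
    return a + b

# parse_doc(add) returned an OpenAI function-calling schema dict:
# {
#     "name": "add",
#     "description": "Adds two numbers.",
#     "parameters": {
#         "type": "object",
#         "properties": {
#             "a": {"type": "int", "description": "First addend."},
#             "b": {"type": "int", "description": "Second addend."},
#         },
#         "required": ["a", "b"],
#     },
# }
```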
@@ -241,3 +188,141 @@ def check_server_running(server_url: str) -> bool:
 
     # if the port is open, the result should be 0.
     return result == 0
+
+
+def dependencies_required(*required_modules: str) -> Callable[[F], F]:
+    r"""A decorator to ensure that specified Python modules
+    are available before a function executes.
+
+    Args:
+        required_modules (str): The required modules to be checked for
+            availability.
+
+    Returns:
+        Callable[[F], F]: The original function with the added check for
+            required module dependencies.
+
+    Raises:
+        ImportError: If any of the required modules are not available.
+
+    Example:
+        ::
+
+            @dependencies_required('numpy', 'pandas')
+            def data_processing_function():
+                # Function implementation...
+    """
+
+    def decorator(func: F) -> F:
+        @wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            missing_modules = [
+                m for m in required_modules if not is_module_available(m)
+            ]
+            if missing_modules:
+                raise ImportError(
+                    f"Missing required modules: {', '.join(missing_modules)}"
+                )
+            return func(*args, **kwargs)
+
+        return cast(F, wrapper)
+
+    return decorator
+
+
+def is_module_available(module_name: str) -> bool:
+    r"""Check if a module is available for import.
+
+    Args:
+        module_name (str): The name of the module to check for availability.
+
+    Returns:
+        bool: True if the module can be imported, False otherwise.
+    """
+    try:
+        importlib.import_module(module_name)
+        return True
+    except ImportError:
+        return False
+
+
+def api_keys_required(*required_keys: str) -> Callable[[F], F]:
+    r"""A decorator to check if the required API keys are
+    present in the environment variables.
+
+    Args:
+        required_keys (str): The required API keys to be checked.
+
+    Returns:
+        Callable[[F], F]: The original function with the added check
+            for required API keys.
+
+    Raises:
+        ValueError: If any of the required API keys are missing in the
+            environment variables.
+
+    Example:
+        ::
+
+            @api_keys_required('API_KEY_1', 'API_KEY_2')
+            def some_api_function():
+                # Function implementation...
+    """
+
+    def decorator(func: F) -> F:
+        @wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            missing_keys = [k for k in required_keys if k not in os.environ]
+            if missing_keys:
+                raise ValueError(f"Missing API keys: {', '.join(missing_keys)}")
+            return func(*args, **kwargs)
+
+        return cast(F, wrapper)
+
+    return decorator
+
+
+def get_system_information():
+    r"""Gathers information about the operating system.
+
+    Returns:
+        dict: A dictionary containing various pieces of OS information.
+    """
+    sys_info = {
+        "OS Name": os.name,
+        "System": platform.system(),
+        "Release": platform.release(),
+        "Version": platform.version(),
+        "Machine": platform.machine(),
+        "Processor": platform.processor(),
+        "Platform": platform.platform(),
+    }
+
+    return sys_info
+
+
+def to_pascal(snake: str) -> str:
+    """Convert a snake_case string to PascalCase.
+
+    Args:
+        snake (str): The snake_case string to be converted.
+
+    Returns:
+        str: The converted PascalCase string.
+    """
+    # Check if the string is already in PascalCase
+    if re.match(r'^[A-Z][a-zA-Z0-9]*([A-Z][a-zA-Z0-9]*)*$', snake):
+        return snake
+    # Remove leading and trailing underscores
+    snake = snake.strip('_')
+    # Replace multiple underscores with a single one
+    snake = re.sub('_+', '_', snake)
+    # Convert to PascalCase
+    return re.sub(
+        '_([0-9A-Za-z])',
+        lambda m: m.group(1).upper(),
+        snake.title(),
+    )
+
+
+PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
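Taken together, the new helpers let functions declare module and credential prerequisites up front. A combined usage sketch (names as added above; the imports assume re-exports from `camel.utils`, and the decorated function is hypothetical):

```python
import os

from camel.utils import api_keys_required, dependencies_required, to_pascal


@dependencies_required('requests')
@api_keys_required('WEATHER_API_KEY')
def fetch_weather(city: str) -> str:
    import requests  # guaranteed importable by dependencies_required

    # Hypothetical body, for illustration only.
    return f"would call the weather API for {city}"


os.environ['WEATHER_API_KEY'] = 'demo-key'
print(fetch_weather('Berlin'))  # raises ImportError/ValueError otherwise

print(to_pascal('snake_case_string'))  # SnakeCaseString
```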
camel/utils/token_counting.py
CHANGED

@@ -11,11 +11,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+
+from __future__ import annotations
+
+import base64
 from abc import ABC, abstractmethod
-from typing import List
+from io import BytesIO
+from math import ceil
+from typing import TYPE_CHECKING, List, Optional
+
+from anthropic import Anthropic
+from PIL import Image
 
-from camel.messages import OpenAIMessage
-from camel.types import ModelType
+from camel.types import ModelType, OpenAIImageDetailType, OpenAIImageType
+
+if TYPE_CHECKING:
+    from camel.messages import OpenAIMessage
+
+LOW_DETAIL_TOKENS = 85
+FIT_SQUARE_PIXELS = 2048
+SHORTEST_SIDE_PIXELS = 768
+SQUARE_PIXELS = 512
+SQUARE_TOKENS = 170
+EXTRA_TOKENS = 85
 
 
 def messages_to_prompt(messages: List[OpenAIMessage], model: ModelType) -> str:
@@ -45,8 +63,10 @@ def messages_to_prompt(messages: List[OpenAIMessage], model: ModelType) -> str:
             content = msg["content"]
             if content:
                 if not isinstance(content, str):
-                    raise ValueError("Currently multimodal context is not "
-                                     "supported by the token counter.")
+                    raise ValueError(
+                        "Currently multimodal context is not "
+                        "supported by the token counter."
+                    )
                 if i == 0:
                     ret += system_prompt + content
                 else:
@@ -64,8 +84,10 @@ def messages_to_prompt(messages: List[OpenAIMessage], model: ModelType) -> str:
             role = role_map[msg["role"]]
             content = msg["content"]
             if not isinstance(content, str):
-                raise ValueError("Currently multimodal context is not "
-                                 "supported by the token counter.")
+                raise ValueError(
+                    "Currently multimodal context is not "
+                    "supported by the token counter."
+                )
             if content:
                 ret += role + ": " + content + seps[i % 2]
             else:
@@ -85,6 +107,7 @@ def get_model_encoding(value_for_tiktoken: str):
         tiktoken.Encoding: Model encoding.
     """
     import tiktoken
+
     try:
         encoding = tiktoken.encoding_for_model(value_for_tiktoken)
     except KeyError:
@@ -111,7 +134,6 @@ class BaseTokenCounter(ABC):
 
 
 class OpenSourceTokenCounter(BaseTokenCounter):
-
     def __init__(self, model_type: ModelType, model_path: str):
         r"""Constructor for the token counter for open-source models.
 
@@ -126,6 +148,7 @@ class OpenSourceTokenCounter(BaseTokenCounter):
         # If a fast tokenizer is not available for a given model,
         # a normal Python-based tokenizer is returned instead.
         from transformers import AutoTokenizer
+
         try:
             tokenizer = AutoTokenizer.from_pretrained(
                 model_path,
@@ -136,10 +159,11 @@ class OpenSourceTokenCounter(BaseTokenCounter):
                 model_path,
                 use_fast=False,
             )
-        except:
+        except Exception:
            raise ValueError(
                f"Invalid `model_path` ({model_path}) is provided. "
-                "Tokenizer loading failed.")
+                "Tokenizer loading failed."
+            )
 
         self.tokenizer = tokenizer
         self.model_type = model_type
@@ -162,13 +186,11 @@ class OpenSourceTokenCounter(BaseTokenCounter):
 
 
 class OpenAITokenCounter(BaseTokenCounter):
-
     def __init__(self, model: ModelType):
         r"""Constructor for the token counter for OpenAI models.
 
         Args:
-            model (ModelType): Model type for which tokens will be
-                counted.
+            model (ModelType): Model type for which tokens will be counted.
         """
         self.model: str = model.value_for_tiktoken
 
@@ -192,7 +214,8 @@ class OpenAITokenCounter(BaseTokenCounter):
             "for information on how messages are converted to tokens. "
             "See https://platform.openai.com/docs/models/gpt-4"
            "or https://platform.openai.com/docs/models/gpt-3-5"
-            "for information about openai chat models.")
+            "for information about openai chat models."
+        )
 
         self.encoding = get_model_encoding(self.model)
 
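A short usage sketch of the counter on plain string messages, the only shape handled before the multimodal branch below (model choice arbitrary; the class lives in camel.utils.token_counting):

```python
from camel.types import ModelType
from camel.utils.token_counting import OpenAITokenCounter

counter = OpenAITokenCounter(ModelType.GPT_3_5_TURBO)
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
# Per-message overheads plus the 3-token assistant-reply primer.
print(counter.count_tokens_from_messages(messages))
```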
@@ -211,10 +234,107 @@ class OpenAITokenCounter(BaseTokenCounter):
         for message in messages:
             num_tokens += self.tokens_per_message
             for key, value in message.items():
-                num_tokens += len(self.encoding.encode(str(value)))
+                if not isinstance(value, list):
+                    num_tokens += len(self.encoding.encode(str(value)))
+                else:
+                    for item in value:
+                        if item["type"] == "text":
+                            num_tokens += len(
+                                self.encoding.encode(str(item["text"]))
+                            )
+                        elif item["type"] == "image_url":
+                            image_str: str = item["image_url"]["url"]
+                            detail = item["image_url"]["detail"]
+                            image_prefix_format = "data:image/{};base64,"
+                            image_prefix: Optional[str] = None
+                            for image_type in list(OpenAIImageType):
+                                # Find the correct image format
+                                image_prefix = image_prefix_format.format(
+                                    image_type.value
+                                )
+                                if image_prefix in image_str:
+                                    break
+                            assert isinstance(image_prefix, str)
+                            encoded_image = image_str.split(image_prefix)[1]
+                            image_bytes = BytesIO(
+                                base64.b64decode(encoded_image)
+                            )
+                            image = Image.open(image_bytes)
+                            num_tokens += count_tokens_from_image(
+                                image, OpenAIImageDetailType(detail)
+                            )
                 if key == "name":
                     num_tokens += self.tokens_per_name
 
         # every reply is primed with <|start|>assistant<|message|>
         num_tokens += 3
         return num_tokens
+
+
+class AnthropicTokenCounter(BaseTokenCounter):
+    def __init__(self, model_type: ModelType):
+        r"""Constructor for the token counter for Anthropic models.
+
+        Args:
+            model_type (ModelType): Model type for which tokens will be
+                counted.
+        """
+        self.model_type = model_type
+        self.client = Anthropic()
+        self.tokenizer = self.client.get_tokenizer()
+
+    def count_tokens_from_messages(self, messages: List[OpenAIMessage]) -> int:
+        r"""Count number of tokens in the provided message list using
+        loaded tokenizer specific for this type of model.
+
+        Args:
+            messages (List[OpenAIMessage]): Message list with the chat history
+                in OpenAI API format.
+
+        Returns:
+            int: Number of tokens in the messages.
+        """
+        prompt = messages_to_prompt(messages, self.model_type)
+
+        return self.client.count_tokens(prompt)
+
+
+def count_tokens_from_image(
+    image: Image.Image, detail: OpenAIImageDetailType
+) -> int:
+    r"""Count image tokens for OpenAI vision model. An :obj:`"auto"`
+    resolution model will be treated as :obj:`"high"`. All images with
+    :obj:`"low"` detail cost 85 tokens each. Images with :obj:`"high"` detail
+    are first scaled to fit within a 2048 x 2048 square, maintaining their
+    aspect ratio. Then, they are scaled such that the shortest side of the
+    image is 768px long. Finally, we count how many 512px squares the image
+    consists of. Each of those squares costs 170 tokens. Another 85 tokens are
+    always added to the final total. For more details please refer to `OpenAI
+    vision docs <https://platform.openai.com/docs/guides/vision>`_
+
+    Args:
+        image (PIL.Image.Image): Image to count number of tokens.
+        detail (OpenAIImageDetailType): Image detail type to count
+            number of tokens.
+
+    Returns:
+        int: Number of tokens for the image given a detail type.
+    """
+    if detail == OpenAIImageDetailType.LOW:
+        return LOW_DETAIL_TOKENS
+
+    width, height = image.size
+    if width > FIT_SQUARE_PIXELS or height > FIT_SQUARE_PIXELS:
+        scaling_factor = max(width, height) / FIT_SQUARE_PIXELS
+        width = int(width / scaling_factor)
+        height = int(height / scaling_factor)
+
+    scaling_factor = min(width, height) / SHORTEST_SIDE_PIXELS
+    scaled_width = int(width / scaling_factor)
+    scaled_height = int(height / scaling_factor)
+
+    h = ceil(scaled_height / SQUARE_PIXELS)
+    w = ceil(scaled_width / SQUARE_PIXELS)
+    total = EXTRA_TOKENS + SQUARE_TOKENS * h * w
+    return total