camel-ai 0.1.1__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of camel-ai might be problematic.

Files changed (117)
  1. camel/__init__.py +1 -11
  2. camel/agents/__init__.py +7 -5
  3. camel/agents/chat_agent.py +134 -86
  4. camel/agents/critic_agent.py +28 -17
  5. camel/agents/deductive_reasoner_agent.py +235 -0
  6. camel/agents/embodied_agent.py +92 -40
  7. camel/agents/knowledge_graph_agent.py +221 -0
  8. camel/agents/role_assignment_agent.py +27 -17
  9. camel/agents/task_agent.py +60 -34
  10. camel/agents/tool_agents/base.py +0 -1
  11. camel/agents/tool_agents/hugging_face_tool_agent.py +7 -4
  12. camel/configs/__init__.py +29 -0
  13. camel/configs/anthropic_config.py +73 -0
  14. camel/configs/base_config.py +22 -0
  15. camel/{configs.py → configs/openai_config.py} +37 -64
  16. camel/embeddings/__init__.py +2 -0
  17. camel/embeddings/base.py +3 -2
  18. camel/embeddings/openai_embedding.py +10 -5
  19. camel/embeddings/sentence_transformers_embeddings.py +65 -0
  20. camel/functions/__init__.py +18 -3
  21. camel/functions/google_maps_function.py +335 -0
  22. camel/functions/math_functions.py +7 -7
  23. camel/functions/open_api_function.py +380 -0
  24. camel/functions/open_api_specs/coursera/__init__.py +13 -0
  25. camel/functions/open_api_specs/coursera/openapi.yaml +82 -0
  26. camel/functions/open_api_specs/klarna/__init__.py +13 -0
  27. camel/functions/open_api_specs/klarna/openapi.yaml +87 -0
  28. camel/functions/open_api_specs/speak/__init__.py +13 -0
  29. camel/functions/open_api_specs/speak/openapi.yaml +151 -0
  30. camel/functions/openai_function.py +346 -42
  31. camel/functions/retrieval_functions.py +61 -0
  32. camel/functions/search_functions.py +100 -35
  33. camel/functions/slack_functions.py +275 -0
  34. camel/functions/twitter_function.py +484 -0
  35. camel/functions/weather_functions.py +36 -23
  36. camel/generators.py +65 -46
  37. camel/human.py +17 -11
  38. camel/interpreters/__init__.py +25 -0
  39. camel/interpreters/base.py +49 -0
  40. camel/{utils/python_interpreter.py → interpreters/internal_python_interpreter.py} +129 -48
  41. camel/interpreters/interpreter_error.py +19 -0
  42. camel/interpreters/subprocess_interpreter.py +190 -0
  43. camel/loaders/__init__.py +22 -0
  44. camel/{functions/base_io_functions.py → loaders/base_io.py} +38 -35
  45. camel/{functions/unstructured_io_fuctions.py → loaders/unstructured_io.py} +199 -110
  46. camel/memories/__init__.py +17 -7
  47. camel/memories/agent_memories.py +156 -0
  48. camel/memories/base.py +97 -32
  49. camel/memories/blocks/__init__.py +21 -0
  50. camel/memories/{chat_history_memory.py → blocks/chat_history_block.py} +34 -34
  51. camel/memories/blocks/vectordb_block.py +101 -0
  52. camel/memories/context_creators/__init__.py +3 -2
  53. camel/memories/context_creators/score_based.py +32 -20
  54. camel/memories/records.py +6 -5
  55. camel/messages/__init__.py +2 -2
  56. camel/messages/base.py +99 -16
  57. camel/messages/func_message.py +7 -4
  58. camel/models/__init__.py +6 -2
  59. camel/models/anthropic_model.py +146 -0
  60. camel/models/base_model.py +10 -3
  61. camel/models/model_factory.py +17 -11
  62. camel/models/open_source_model.py +25 -13
  63. camel/models/openai_audio_models.py +251 -0
  64. camel/models/openai_model.py +20 -13
  65. camel/models/stub_model.py +10 -5
  66. camel/prompts/__init__.py +7 -5
  67. camel/prompts/ai_society.py +21 -14
  68. camel/prompts/base.py +54 -47
  69. camel/prompts/code.py +22 -14
  70. camel/prompts/evaluation.py +8 -5
  71. camel/prompts/misalignment.py +26 -19
  72. camel/prompts/object_recognition.py +35 -0
  73. camel/prompts/prompt_templates.py +14 -8
  74. camel/prompts/role_description_prompt_template.py +16 -10
  75. camel/prompts/solution_extraction.py +9 -5
  76. camel/prompts/task_prompt_template.py +24 -21
  77. camel/prompts/translation.py +9 -5
  78. camel/responses/agent_responses.py +5 -2
  79. camel/retrievers/__init__.py +26 -0
  80. camel/retrievers/auto_retriever.py +330 -0
  81. camel/retrievers/base.py +69 -0
  82. camel/retrievers/bm25_retriever.py +140 -0
  83. camel/retrievers/cohere_rerank_retriever.py +108 -0
  84. camel/retrievers/vector_retriever.py +183 -0
  85. camel/societies/__init__.py +1 -1
  86. camel/societies/babyagi_playing.py +56 -32
  87. camel/societies/role_playing.py +188 -133
  88. camel/storages/__init__.py +18 -0
  89. camel/storages/graph_storages/__init__.py +23 -0
  90. camel/storages/graph_storages/base.py +82 -0
  91. camel/storages/graph_storages/graph_element.py +74 -0
  92. camel/storages/graph_storages/neo4j_graph.py +582 -0
  93. camel/storages/key_value_storages/base.py +1 -2
  94. camel/storages/key_value_storages/in_memory.py +1 -2
  95. camel/storages/key_value_storages/json.py +8 -13
  96. camel/storages/vectordb_storages/__init__.py +33 -0
  97. camel/storages/vectordb_storages/base.py +202 -0
  98. camel/storages/vectordb_storages/milvus.py +396 -0
  99. camel/storages/vectordb_storages/qdrant.py +373 -0
  100. camel/terminators/__init__.py +1 -1
  101. camel/terminators/base.py +2 -3
  102. camel/terminators/response_terminator.py +21 -12
  103. camel/terminators/token_limit_terminator.py +5 -3
  104. camel/toolkits/__init__.py +21 -0
  105. camel/toolkits/base.py +22 -0
  106. camel/toolkits/github_toolkit.py +245 -0
  107. camel/types/__init__.py +18 -6
  108. camel/types/enums.py +129 -15
  109. camel/types/openai_types.py +10 -5
  110. camel/utils/__init__.py +20 -13
  111. camel/utils/commons.py +170 -85
  112. camel/utils/token_counting.py +135 -15
  113. {camel_ai-0.1.1.dist-info → camel_ai-0.1.4.dist-info}/METADATA +123 -75
  114. camel_ai-0.1.4.dist-info/RECORD +119 -0
  115. {camel_ai-0.1.1.dist-info → camel_ai-0.1.4.dist-info}/WHEEL +1 -1
  116. camel/memories/context_creators/base.py +0 -72
  117. camel_ai-0.1.1.dist-info/RECORD +0 -75
camel/utils/commons.py CHANGED
@@ -11,26 +11,18 @@
11
11
  # See the License for the specific language governing permissions and
12
12
  # limitations under the License.
13
13
  # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
- import inspect
14
+ import importlib
15
15
  import os
16
+ import platform
16
17
  import re
17
18
  import socket
18
19
  import time
19
20
  import zipfile
20
21
  from functools import wraps
21
- from typing import (
22
- Any,
23
- Callable,
24
- Dict,
25
- List,
26
- Optional,
27
- Set,
28
- Tuple,
29
- TypeVar,
30
- cast,
31
- )
22
+ from typing import Any, Callable, List, Optional, Set, TypeVar, cast
32
23
  from urllib.parse import urlparse
33
24
 
25
+ import pydantic
34
26
  import requests
35
27
 
36
28
  from camel.types import TaskType
@@ -38,9 +30,8 @@ from camel.types import TaskType
38
30
# Type variable for any callable. Decorators in this module annotate
# `func: F -> F` so the decorated function keeps its original static type.
F = TypeVar('F', bound=Callable[..., Any])
39
31
 
40
32
 
41
def api_key_required(func: F) -> F:
    r"""Decorator that checks if the API key is available either as an
    environment variable or passed directly.

    The wrapped callable is a method. Its instance must expose
    ``model_type`` (with ``is_openai`` / ``is_anthropic`` flags) and may
    expose ``_api_key``, which holds a directly passed key.

    Args:
        func (callable): The function to be wrapped.

    Returns:
        callable: The decorated function.

    Raises:
        ValueError: If the API key is not found, either as an environment
            variable or directly passed, or if the model type is neither an
            OpenAI nor an Anthropic model.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # A directly passed key satisfies the check. `getattr` lets
        # instances that never set `_api_key` fall back to the environment
        # check instead of raising AttributeError.
        api_key = getattr(self, '_api_key', None)
        if self.model_type.is_openai:
            if not api_key and 'OPENAI_API_KEY' not in os.environ:
                raise ValueError('OpenAI API key not found.')
            return func(self, *args, **kwargs)
        elif self.model_type.is_anthropic:
            # Honour a directly passed key here too, matching the OpenAI
            # branch and this decorator's documented contract.
            if not api_key and 'ANTHROPIC_API_KEY' not in os.environ:
                raise ValueError('Anthropic API key not found.')
            return func(self, *args, **kwargs)
        else:
            raise ValueError('Unsupported model type.')

    return cast(F, wrapper)
64
61
 
@@ -116,12 +113,26 @@ def get_first_int(string: str) -> Optional[int]:
116
113
 
117
114
 
118
115
  def download_tasks(task: TaskType, folder_path: str) -> None:
116
+ r"""Downloads task-related files from a specified URL and extracts them.
117
+
118
+ This function downloads a zip file containing tasks based on the specified
119
+ `task` type from a predefined URL, saves it to `folder_path`, and then
120
+ extracts the contents of the zip file into the same folder. After
121
+ extraction, the zip file is deleted.
122
+
123
+ Args:
124
+ task (TaskType): An enum representing the type of task to download.
125
+ folder_path (str): The path of the folder where the zip file will be
126
+ downloaded and extracted.
127
+ """
119
128
  # Define the path to save the zip file
120
129
  zip_file_path = os.path.join(folder_path, "tasks.zip")
121
130
 
122
131
  # Download the zip file from the Google Drive link
123
- response = requests.get("https://huggingface.co/datasets/camel-ai/"
124
- f"metadata/resolve/main/{task.value}_tasks.zip")
132
+ response = requests.get(
133
+ "https://huggingface.co/datasets/camel-ai/"
134
+ f"metadata/resolve/main/{task.value}_tasks.zip"
135
+ )
125
136
 
126
137
  # Save the zip file
127
138
  with open(zip_file_path, "wb") as f:
@@ -134,70 +145,6 @@ def download_tasks(task: TaskType, folder_path: str) -> None:
134
145
  os.remove(zip_file_path)
135
146
 
136
147
 
137
- def parse_doc(func: Callable) -> Dict[str, Any]:
138
- r"""Parse the docstrings of a function to extract the function name,
139
- description and parameters.
140
-
141
- Args:
142
- func (Callable): The function to be parsed.
143
- Returns:
144
- Dict[str, Any]: A dictionary with the function's name,
145
- description, and parameters.
146
- """
147
-
148
- doc = inspect.getdoc(func)
149
- if not doc:
150
- raise ValueError(
151
- f"Invalid function {func.__name__}: no docstring provided.")
152
-
153
- properties = {}
154
- required = []
155
-
156
- parts = re.split(r'\n\s*\n', doc)
157
- func_desc = parts[0].strip()
158
-
159
- args_section = next((p for p in parts if 'Args:' in p), None)
160
- if args_section:
161
- args_descs: List[Tuple[str, str, str, ]] = re.findall(
162
- r'(\w+)\s*\((\w+)\):\s*(.*)', args_section)
163
- properties = {
164
- name.strip(): {
165
- 'type': type,
166
- 'description': desc
167
- }
168
- for name, type, desc in args_descs
169
- }
170
- for name in properties:
171
- required.append(name)
172
-
173
- # Parameters from the function signature
174
- sign_params = list(inspect.signature(func).parameters.keys())
175
- if len(sign_params) != len(required):
176
- raise ValueError(
177
- f"Number of parameters in function signature ({len(sign_params)})"
178
- f" does not match that in docstring ({len(required)}).")
179
-
180
- for param in sign_params:
181
- if param not in required:
182
- raise ValueError(f"Parameter '{param}' in function signature"
183
- " is missing in the docstring.")
184
-
185
- parameters = {
186
- "type": "object",
187
- "properties": properties,
188
- "required": required,
189
- }
190
-
191
- # Construct the function dictionary
192
- function_dict = {
193
- "name": func.__name__,
194
- "description": func_desc,
195
- "parameters": parameters,
196
- }
197
-
198
- return function_dict
199
-
200
-
201
148
  def get_task_list(task_response: str) -> List[str]:
202
149
  r"""Parse the response of the Agent and return task list.
203
150
 
@@ -241,3 +188,141 @@ def check_server_running(server_url: str) -> bool:
241
188
 
242
189
  # if the port is open, the result should be 0.
243
190
  return result == 0
191
+
192
+
193
def dependencies_required(*required_modules: str) -> Callable[[F], F]:
    r"""Build a decorator that refuses to run a function unless every
    named Python module can be imported.

    The availability check runs each time the decorated function is
    called, not when the decorator is applied.

    Args:
        required_modules (str): Names of the modules that must be
            importable.

    Returns:
        Callable[[F], F]: A decorator that wraps a function with the
            module-availability check.

    Raises:
        ImportError: Raised by the decorated function, when called, if any
            of the required modules is unavailable.

    Example:
        ::

            @dependencies_required('numpy', 'pandas')
            def data_processing_function():
                # Function implementation...
    """

    def _decorate(func: F) -> F:
        @wraps(func)
        def guarded(*args: Any, **kwargs: Any) -> Any:
            # Collect every unavailable module so the error lists them all.
            missing_modules = []
            for module_name in required_modules:
                if not is_module_available(module_name):
                    missing_modules.append(module_name)
            if not missing_modules:
                return func(*args, **kwargs)
            raise ImportError(
                f"Missing required modules: {', '.join(missing_modules)}"
            )

        return cast(F, guarded)

    return _decorate
231
+
232
+
233
def is_module_available(module_name: str) -> bool:
    r"""Check if a module is available for import.

    The check actually imports the module, so a successful check leaves it
    imported.

    Args:
        module_name (str): The absolute, dotted name of the module to check
            for availability (e.g. ``"numpy"`` or ``"os.path"``).

    Returns:
        bool: True if the module can be imported, False otherwise. Empty
            names and relative names (starting with ``"."``) are reported
            as not available.
    """
    # `importlib.import_module` raises ValueError for an empty name and
    # TypeError for a relative name without a package, rather than
    # ImportError. Both mean "not importable" here.
    if not module_name or module_name.startswith('.'):
        return False
    try:
        importlib.import_module(module_name)
    except ImportError:
        return False
    return True
247
+
248
+
249
def api_keys_required(*required_keys: str) -> Callable[[F], F]:
    r"""Create a decorator that verifies the given environment variables
    exist before the wrapped function runs.

    The check runs each time the decorated function is called. It tests
    only whether each variable is present in ``os.environ``, not its value.

    Args:
        required_keys (str): Names of the environment variables that hold
            the required API keys.

    Returns:
        Callable[[F], F]: A decorator that adds the API-key check to a
            function.

    Raises:
        ValueError: Raised by the decorated function, when called, if any
            of the required keys is absent from the environment.

    Example:
        ::

            @api_keys_required('API_KEY_1', 'API_KEY_2')
            def some_api_function():
                # Function implementation...
    """

    def _decorate(func: F) -> F:
        @wraps(func)
        def checked(*args: Any, **kwargs: Any) -> Any:
            absent = list(
                filter(lambda name: name not in os.environ, required_keys)
            )
            if not absent:
                return func(*args, **kwargs)
            raise ValueError(f"Missing API keys: {', '.join(absent)}")

        return cast(F, checked)

    return _decorate
283
+
284
+
285
def get_system_information():
    r"""Collect a snapshot of facts about the host operating system.

    Returns:
        dict: Maps the keys ``"OS Name"``, ``"System"``, ``"Release"``,
            ``"Version"``, ``"Machine"``, ``"Processor"`` and
            ``"Platform"`` (in that order) to the values reported by
            :data:`os.name` and the matching :mod:`platform` functions.
    """
    # Queried in this order; each probe is called with no arguments.
    probes = (
        ("System", platform.system),
        ("Release", platform.release),
        ("Version", platform.version),
        ("Machine", platform.machine),
        ("Processor", platform.processor),
        ("Platform", platform.platform),
    )
    info = {"OS Name": os.name}
    for label, probe in probes:
        info[label] = probe()
    return info
302
+
303
+
304
def to_pascal(snake: str) -> str:
    """Convert a snake_case string to PascalCase.

    A string that already looks like PascalCase (an ASCII uppercase letter
    followed only by ASCII letters and digits) is returned unchanged.
    Otherwise the function strips leading and trailing underscores,
    collapses runs of underscores, and capitalizes each word. ``str.title``
    lowercases the rest of each word, so ``"get_HTTP"`` becomes
    ``"GetHttp"``.

    Args:
        snake (str): The snake_case string to be converted.

    Returns:
        str: The converted PascalCase string.
    """
    # Check if the string is already in PascalCase. The pattern has no
    # nested quantifiers. The equivalent
    # `^[A-Z][a-zA-Z0-9]*([A-Z][a-zA-Z0-9]*)*$` matches the same strings
    # but backtracks exponentially on inputs such as "AAAA...A_" (ReDoS).
    if re.match(r'^[A-Z][a-zA-Z0-9]*$', snake):
        return snake
    # Remove leading and trailing underscores
    snake = snake.strip('_')
    # Replace multiple underscores with a single one
    snake = re.sub('_+', '_', snake)
    # Convert to PascalCase
    return re.sub(
        '_([0-9A-Za-z])',
        lambda m: m.group(1).upper(),
        snake.title(),
    )
326
+
327
+
328
# True when the installed pydantic reports a 2.x version string.
PYDANTIC_V2 = pydantic.VERSION[:2] == "2."
@@ -11,11 +11,29 @@
11
11
  # See the License for the specific language governing permissions and
12
12
  # limitations under the License.
13
13
  # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14
+
15
+ from __future__ import annotations
16
+
17
+ import base64
14
18
  from abc import ABC, abstractmethod
15
- from typing import List
19
+ from io import BytesIO
20
+ from math import ceil
21
+ from typing import TYPE_CHECKING, List, Optional
22
+
23
+ from anthropic import Anthropic
24
+ from PIL import Image
16
25
 
17
- from camel.messages import OpenAIMessage
18
- from camel.types import ModelType
26
+ from camel.types import ModelType, OpenAIImageDetailType, OpenAIImageType
27
+
28
+ if TYPE_CHECKING:
29
+ from camel.messages import OpenAIMessage
30
+
31
# Token-accounting constants for OpenAI vision inputs, used by
# `count_tokens_from_image`.
# Flat token cost of an image sent with "low" detail.
LOW_DETAIL_TOKENS = 85
# Other images are first scaled to fit within a square of this side (px).
FIT_SQUARE_PIXELS = 2048
# They are then scaled so their shortest side has this length (px).
SHORTEST_SIDE_PIXELS = 768
# Side length (px) of the square tiles counted over the scaled image.
SQUARE_PIXELS = 512
# Token cost of each tile.
SQUARE_TOKENS = 170
# Tokens always added on top of the per-tile total.
EXTRA_TOKENS = 85
19
37
 
20
38
 
21
39
  def messages_to_prompt(messages: List[OpenAIMessage], model: ModelType) -> str:
@@ -45,8 +63,10 @@ def messages_to_prompt(messages: List[OpenAIMessage], model: ModelType) -> str:
45
63
  content = msg["content"]
46
64
  if content:
47
65
  if not isinstance(content, str):
48
- raise ValueError("Currently multimodal context is not "
49
- "supported by the token counter.")
66
+ raise ValueError(
67
+ "Currently multimodal context is not "
68
+ "supported by the token counter."
69
+ )
50
70
  if i == 0:
51
71
  ret += system_prompt + content
52
72
  else:
@@ -64,8 +84,10 @@ def messages_to_prompt(messages: List[OpenAIMessage], model: ModelType) -> str:
64
84
  role = role_map[msg["role"]]
65
85
  content = msg["content"]
66
86
  if not isinstance(content, str):
67
- raise ValueError("Currently multimodal context is not "
68
- "supported by the token counter.")
87
+ raise ValueError(
88
+ "Currently multimodal context is not "
89
+ "supported by the token counter."
90
+ )
69
91
  if content:
70
92
  ret += role + ": " + content + seps[i % 2]
71
93
  else:
@@ -85,6 +107,7 @@ def get_model_encoding(value_for_tiktoken: str):
85
107
  tiktoken.Encoding: Model encoding.
86
108
  """
87
109
  import tiktoken
110
+
88
111
  try:
89
112
  encoding = tiktoken.encoding_for_model(value_for_tiktoken)
90
113
  except KeyError:
@@ -111,7 +134,6 @@ class BaseTokenCounter(ABC):
111
134
 
112
135
 
113
136
  class OpenSourceTokenCounter(BaseTokenCounter):
114
-
115
137
  def __init__(self, model_type: ModelType, model_path: str):
116
138
  r"""Constructor for the token counter for open-source models.
117
139
 
@@ -126,6 +148,7 @@ class OpenSourceTokenCounter(BaseTokenCounter):
126
148
  # If a fast tokenizer is not available for a given model,
127
149
  # a normal Python-based tokenizer is returned instead.
128
150
  from transformers import AutoTokenizer
151
+
129
152
  try:
130
153
  tokenizer = AutoTokenizer.from_pretrained(
131
154
  model_path,
@@ -136,10 +159,11 @@ class OpenSourceTokenCounter(BaseTokenCounter):
136
159
  model_path,
137
160
  use_fast=False,
138
161
  )
139
- except:
162
+ except Exception:
140
163
  raise ValueError(
141
164
  f"Invalid `model_path` ({model_path}) is provided. "
142
- "Tokenizer loading failed.")
165
+ "Tokenizer loading failed."
166
+ )
143
167
 
144
168
  self.tokenizer = tokenizer
145
169
  self.model_type = model_type
@@ -162,13 +186,11 @@ class OpenSourceTokenCounter(BaseTokenCounter):
162
186
 
163
187
 
164
188
  class OpenAITokenCounter(BaseTokenCounter):
165
-
166
189
  def __init__(self, model: ModelType):
167
190
  r"""Constructor for the token counter for OpenAI models.
168
191
 
169
192
  Args:
170
- model_type (ModelType): Model type for which tokens will be
171
- counted.
193
+ model (ModelType): Model type for which tokens will be counted.
172
194
  """
173
195
  self.model: str = model.value_for_tiktoken
174
196
 
@@ -192,7 +214,8 @@ class OpenAITokenCounter(BaseTokenCounter):
192
214
  "for information on how messages are converted to tokens. "
193
215
  "See https://platform.openai.com/docs/models/gpt-4"
194
216
  "or https://platform.openai.com/docs/models/gpt-3-5"
195
- "for information about openai chat models.")
217
+ "for information about openai chat models."
218
+ )
196
219
 
197
220
  self.encoding = get_model_encoding(self.model)
198
221
 
@@ -211,10 +234,107 @@ class OpenAITokenCounter(BaseTokenCounter):
211
234
  for message in messages:
212
235
  num_tokens += self.tokens_per_message
213
236
  for key, value in message.items():
214
- num_tokens += len(self.encoding.encode(str(value)))
237
+ if not isinstance(value, list):
238
+ num_tokens += len(self.encoding.encode(str(value)))
239
+ else:
240
+ for item in value:
241
+ if item["type"] == "text":
242
+ num_tokens += len(
243
+ self.encoding.encode(str(item["text"]))
244
+ )
245
+ elif item["type"] == "image_url":
246
+ image_str: str = item["image_url"]["url"]
247
+ detail = item["image_url"]["detail"]
248
+ image_prefix_format = "data:image/{};base64,"
249
+ image_prefix: Optional[str] = None
250
+ for image_type in list(OpenAIImageType):
251
+ # Find the correct image format
252
+ image_prefix = image_prefix_format.format(
253
+ image_type.value
254
+ )
255
+ if image_prefix in image_str:
256
+ break
257
+ assert isinstance(image_prefix, str)
258
+ encoded_image = image_str.split(image_prefix)[1]
259
+ image_bytes = BytesIO(
260
+ base64.b64decode(encoded_image)
261
+ )
262
+ image = Image.open(image_bytes)
263
+ num_tokens += count_tokens_from_image(
264
+ image, OpenAIImageDetailType(detail)
265
+ )
215
266
  if key == "name":
216
267
  num_tokens += self.tokens_per_name
217
268
 
218
269
  # every reply is primed with <|start|>assistant<|message|>
219
270
  num_tokens += 3
220
271
  return num_tokens
272
+
273
+
274
class AnthropicTokenCounter(BaseTokenCounter):
    def __init__(self, model_type: ModelType):
        r"""Create a token counter backed by an Anthropic client.

        Args:
            model_type (ModelType): The model whose tokens are counted. It
                is passed to `messages_to_prompt` when flattening messages.
        """
        self.model_type = model_type
        self.client = Anthropic()
        # NOTE(review): `tokenizer` is stored but not read by
        # `count_tokens_from_messages`, which uses
        # `self.client.count_tokens` instead.
        self.tokenizer = self.client.get_tokenizer()

    def count_tokens_from_messages(self, messages: List[OpenAIMessage]) -> int:
        r"""Flatten an OpenAI-format chat history into a single prompt and
        count its tokens with the Anthropic client.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat
                history in OpenAI API format.

        Returns:
            int: Number of tokens in the flattened prompt.
        """
        # NOTE(review): this relies on `messages_to_prompt` accepting this
        # model type. Confirm it handles Anthropic models.
        return self.client.count_tokens(
            messages_to_prompt(messages, self.model_type)
        )
301
+
302
+
303
def count_tokens_from_image(
    image: Image.Image, detail: OpenAIImageDetailType
) -> int:
    r"""Count image tokens for OpenAI vision model. An :obj:`"auto"`
    resolution model will be treated as :obj:`"high"`. All images with
    :obj:`"low"` detail cost 85 tokens each. Images with :obj:`"high"` detail
    are first scaled to fit within a 2048 x 2048 square, maintaining their
    aspect ratio. Then, they are scaled such that the shortest side of the
    image is 768px long. Finally, we count how many 512px squares the image
    consists of. Each of those squares costs 170 tokens. Another 85 tokens are
    always added to the final total. For more details please refer to `OpenAI
    vision docs <https://platform.openai.com/docs/guides/vision>`_

    Args:
        image (PIL.Image.Image): Image to count number of tokens.
        detail (OpenAIImageDetailType): Image detail type to count
            number of tokens.

    Returns:
        int: Number of tokens for the image given a detail type.

    Raises:
        ValueError: If a non-``"low"`` detail image has a zero width or
            height, so no tile count can be computed.
    """
    if detail == OpenAIImageDetailType.LOW:
        return LOW_DETAIL_TOKENS

    width, height = image.size
    if width <= 0 or height <= 0:
        # Previously this surfaced as an opaque ZeroDivisionError below.
        raise ValueError(
            f"Cannot count tokens for an image of size {width}x{height}."
        )

    if width > FIT_SQUARE_PIXELS or height > FIT_SQUARE_PIXELS:
        scaling_factor = max(width, height) / FIT_SQUARE_PIXELS
        # Clamp to 1px. For very thin images (e.g. 4000 x 1) the truncated
        # short side would otherwise become 0, and the shortest-side
        # scaling below would divide by zero.
        width = max(1, int(width / scaling_factor))
        height = max(1, int(height / scaling_factor))

    scaling_factor = min(width, height) / SHORTEST_SIDE_PIXELS
    scaled_width = int(width / scaling_factor)
    scaled_height = int(height / scaling_factor)

    # Number of 512px tiles needed along each axis.
    h = ceil(scaled_height / SQUARE_PIXELS)
    w = ceil(scaled_width / SQUARE_PIXELS)
    total = EXTRA_TOKENS + SQUARE_TOKENS * h * w
    return total