camel-ai 0.1.5.5__py3-none-any.whl → 0.1.5.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of camel-ai was flagged as potentially problematic.
Files changed (97)
  1. camel/__init__.py +1 -1
  2. camel/agents/chat_agent.py +3 -3
  3. camel/agents/critic_agent.py +1 -1
  4. camel/agents/deductive_reasoner_agent.py +4 -4
  5. camel/agents/embodied_agent.py +1 -1
  6. camel/agents/knowledge_graph_agent.py +13 -17
  7. camel/agents/role_assignment_agent.py +1 -1
  8. camel/agents/search_agent.py +4 -5
  9. camel/agents/task_agent.py +5 -6
  10. camel/configs/__init__.py +15 -0
  11. camel/configs/gemini_config.py +98 -0
  12. camel/configs/groq_config.py +119 -0
  13. camel/configs/litellm_config.py +1 -1
  14. camel/configs/mistral_config.py +81 -0
  15. camel/configs/ollama_config.py +1 -1
  16. camel/configs/openai_config.py +1 -1
  17. camel/configs/vllm_config.py +103 -0
  18. camel/configs/zhipuai_config.py +1 -1
  19. camel/embeddings/__init__.py +2 -0
  20. camel/embeddings/mistral_embedding.py +89 -0
  21. camel/interpreters/__init__.py +2 -0
  22. camel/interpreters/ipython_interpreter.py +167 -0
  23. camel/models/__init__.py +10 -0
  24. camel/models/anthropic_model.py +7 -2
  25. camel/models/azure_openai_model.py +152 -0
  26. camel/models/base_model.py +9 -2
  27. camel/models/gemini_model.py +215 -0
  28. camel/models/groq_model.py +131 -0
  29. camel/models/litellm_model.py +26 -4
  30. camel/models/mistral_model.py +169 -0
  31. camel/models/model_factory.py +33 -5
  32. camel/models/ollama_model.py +21 -2
  33. camel/models/open_source_model.py +11 -3
  34. camel/models/openai_model.py +7 -2
  35. camel/models/stub_model.py +4 -4
  36. camel/models/vllm_model.py +138 -0
  37. camel/models/zhipuai_model.py +7 -4
  38. camel/prompts/__init__.py +2 -2
  39. camel/prompts/task_prompt_template.py +4 -4
  40. camel/prompts/{descripte_video_prompt.py → video_description_prompt.py} +1 -1
  41. camel/retrievers/auto_retriever.py +2 -0
  42. camel/storages/graph_storages/neo4j_graph.py +5 -0
  43. camel/toolkits/__init__.py +36 -0
  44. camel/toolkits/base.py +1 -1
  45. camel/toolkits/code_execution.py +1 -1
  46. camel/toolkits/github_toolkit.py +3 -2
  47. camel/toolkits/google_maps_toolkit.py +367 -0
  48. camel/toolkits/math_toolkit.py +79 -0
  49. camel/toolkits/open_api_toolkit.py +548 -0
  50. camel/toolkits/retrieval_toolkit.py +76 -0
  51. camel/toolkits/search_toolkit.py +326 -0
  52. camel/toolkits/slack_toolkit.py +308 -0
  53. camel/toolkits/twitter_toolkit.py +522 -0
  54. camel/toolkits/weather_toolkit.py +173 -0
  55. camel/types/enums.py +163 -30
  56. camel/utils/__init__.py +4 -0
  57. camel/utils/async_func.py +1 -1
  58. camel/utils/token_counting.py +182 -40
  59. {camel_ai-0.1.5.5.dist-info → camel_ai-0.1.5.9.dist-info}/METADATA +43 -3
  60. camel_ai-0.1.5.9.dist-info/RECORD +165 -0
  61. camel/functions/__init__.py +0 -51
  62. camel/functions/google_maps_function.py +0 -335
  63. camel/functions/math_functions.py +0 -61
  64. camel/functions/open_api_function.py +0 -508
  65. camel/functions/retrieval_functions.py +0 -61
  66. camel/functions/search_functions.py +0 -298
  67. camel/functions/slack_functions.py +0 -286
  68. camel/functions/twitter_function.py +0 -479
  69. camel/functions/weather_functions.py +0 -144
  70. camel_ai-0.1.5.5.dist-info/RECORD +0 -155
  71. /camel/{functions → toolkits}/open_api_specs/biztoc/__init__.py +0 -0
  72. /camel/{functions → toolkits}/open_api_specs/biztoc/ai-plugin.json +0 -0
  73. /camel/{functions → toolkits}/open_api_specs/biztoc/openapi.yaml +0 -0
  74. /camel/{functions → toolkits}/open_api_specs/coursera/__init__.py +0 -0
  75. /camel/{functions → toolkits}/open_api_specs/coursera/openapi.yaml +0 -0
  76. /camel/{functions → toolkits}/open_api_specs/create_qr_code/__init__.py +0 -0
  77. /camel/{functions → toolkits}/open_api_specs/create_qr_code/openapi.yaml +0 -0
  78. /camel/{functions → toolkits}/open_api_specs/klarna/__init__.py +0 -0
  79. /camel/{functions → toolkits}/open_api_specs/klarna/openapi.yaml +0 -0
  80. /camel/{functions → toolkits}/open_api_specs/nasa_apod/__init__.py +0 -0
  81. /camel/{functions → toolkits}/open_api_specs/nasa_apod/openapi.yaml +0 -0
  82. /camel/{functions → toolkits}/open_api_specs/outschool/__init__.py +0 -0
  83. /camel/{functions → toolkits}/open_api_specs/outschool/ai-plugin.json +0 -0
  84. /camel/{functions → toolkits}/open_api_specs/outschool/openapi.yaml +0 -0
  85. /camel/{functions → toolkits}/open_api_specs/outschool/paths/__init__.py +0 -0
  86. /camel/{functions → toolkits}/open_api_specs/outschool/paths/get_classes.py +0 -0
  87. /camel/{functions → toolkits}/open_api_specs/outschool/paths/search_teachers.py +0 -0
  88. /camel/{functions → toolkits}/open_api_specs/security_config.py +0 -0
  89. /camel/{functions → toolkits}/open_api_specs/speak/__init__.py +0 -0
  90. /camel/{functions → toolkits}/open_api_specs/speak/openapi.yaml +0 -0
  91. /camel/{functions → toolkits}/open_api_specs/web_scraper/__init__.py +0 -0
  92. /camel/{functions → toolkits}/open_api_specs/web_scraper/ai-plugin.json +0 -0
  93. /camel/{functions → toolkits}/open_api_specs/web_scraper/openapi.yaml +0 -0
  94. /camel/{functions → toolkits}/open_api_specs/web_scraper/paths/__init__.py +0 -0
  95. /camel/{functions → toolkits}/open_api_specs/web_scraper/paths/scraper.py +0 -0
  96. /camel/{functions → toolkits}/openai_function.py +0 -0
  97. {camel_ai-0.1.5.5.dist-info → camel_ai-0.1.5.9.dist-info}/WHEEL +0 -0
camel/types/enums.py CHANGED
@@ -29,11 +29,22 @@ class ModelType(Enum):
     GPT_4_32K = "gpt-4-32k"
     GPT_4_TURBO = "gpt-4-turbo"
     GPT_4O = "gpt-4o"
+    GPT_4O_MINI = "gpt-4o-mini"
+
     GLM_4 = "glm-4"
     GLM_4_OPEN_SOURCE = "glm-4-open-source"
     GLM_4V = 'glm-4v'
     GLM_3_TURBO = "glm-3-turbo"
 
+    GROQ_LLAMA_3_1_8B = "llama-3.1-8b-instant"
+    GROQ_LLAMA_3_1_70B = "llama-3.1-70b-versatile"
+    GROQ_LLAMA_3_1_405B = "llama-3.1-405b-reasoning"
+    GROQ_LLAMA_3_8B = "llama3-8b-8192"
+    GROQ_LLAMA_3_70B = "llama3-70b-8192"
+    GROQ_MIXTRAL_8_7B = "mixtral-8x7b-32768"
+    GROQ_GEMMA_7B_IT = "gemma-7b-it"
+    GROQ_GEMMA_2_9B_IT = "gemma2-9b-it"
+
     STUB = "stub"
 
     LLAMA_2 = "llama-2"
@@ -44,7 +55,7 @@ class ModelType(Enum):
     QWEN_2 = "qwen-2"
 
     # Legacy anthropic models
-    # NOTE: anthropic lagecy models only Claude 2.1 has system prompt support
+    # NOTE: anthropic legacy models only Claude 2.1 has system prompt support
     CLAUDE_2_1 = "claude-2.1"
     CLAUDE_2_0 = "claude-2.0"
     CLAUDE_INSTANT_1_2 = "claude-instant-1.2"
@@ -58,6 +69,19 @@ class ModelType(Enum):
     # Nvidia models
     NEMOTRON_4_REWARD = "nvidia/nemotron-4-340b-reward"
 
+    # Gemini models
+    GEMINI_1_5_FLASH = "gemini-1.5-flash"
+    GEMINI_1_5_PRO = "gemini-1.5-pro"
+
+    # Mistral AI Model
+    MISTRAL_LARGE = "mistral-large-latest"
+    MISTRAL_NEMO = "open-mistral-nemo"
+    MISTRAL_CODESTRAL = "codestral-latest"
+    MISTRAL_7B = "open-mistral-7b"
+    MISTRAL_MIXTRAL_8x7B = "open-mixtral-8x7b"
+    MISTRAL_MIXTRAL_8x22B = "open-mixtral-8x22b"
+    MISTRAL_CODESTRAL_MAMBA = "open-codestral-mamba"
+
     @property
     def value_for_tiktoken(self) -> str:
         return (
@@ -69,6 +93,20 @@ class ModelType(Enum):
     @property
     def is_openai(self) -> bool:
         r"""Returns whether this type of models is an OpenAI-released model."""
+        return self in {
+            ModelType.GPT_3_5_TURBO,
+            ModelType.GPT_4,
+            ModelType.GPT_4_32K,
+            ModelType.GPT_4_TURBO,
+            ModelType.GPT_4O,
+            ModelType.GPT_4O_MINI,
+        }
+
+    @property
+    def is_azure_openai(self) -> bool:
+        r"""Returns whether this type of models is an OpenAI-released model
+        from Azure.
+        """
         return self in {
             ModelType.GPT_3_5_TURBO,
             ModelType.GPT_4,
@@ -115,6 +153,33 @@ class ModelType(Enum):
             ModelType.CLAUDE_3_5_SONNET,
         }
 
+    @property
+    def is_groq(self) -> bool:
+        r"""Returns whether this type of models is served by Groq."""
+        return self in {
+            ModelType.GROQ_LLAMA_3_1_8B,
+            ModelType.GROQ_LLAMA_3_1_70B,
+            ModelType.GROQ_LLAMA_3_1_405B,
+            ModelType.GROQ_LLAMA_3_8B,
+            ModelType.GROQ_LLAMA_3_70B,
+            ModelType.GROQ_MIXTRAL_8_7B,
+            ModelType.GROQ_GEMMA_7B_IT,
+            ModelType.GROQ_GEMMA_2_9B_IT,
+        }
+
+    @property
+    def is_mistral(self) -> bool:
+        r"""Returns whether this type of models is served by Mistral."""
+        return self in {
+            ModelType.MISTRAL_LARGE,
+            ModelType.MISTRAL_NEMO,
+            ModelType.MISTRAL_CODESTRAL,
+            ModelType.MISTRAL_7B,
+            ModelType.MISTRAL_MIXTRAL_8x7B,
+            ModelType.MISTRAL_MIXTRAL_8x22B,
+            ModelType.MISTRAL_CODESTRAL_MAMBA,
+        }
+
     @property
     def is_nvidia(self) -> bool:
         r"""Returns whether this type of models is Nvidia-released model.
@@ -126,45 +191,70 @@ class ModelType(Enum):
             ModelType.NEMOTRON_4_REWARD,
         }
 
+    @property
+    def is_gemini(self) -> bool:
+        return self in {ModelType.GEMINI_1_5_FLASH, ModelType.GEMINI_1_5_PRO}
+
     @property
     def token_limit(self) -> int:
         r"""Returns the maximum token limit for a given model.
+
         Returns:
             int: The maximum token limit for the given model.
         """
-        if self is ModelType.GPT_3_5_TURBO:
-            return 16385
-        elif self is ModelType.GPT_4:
-            return 8192
-        elif self is ModelType.GPT_4_32K:
-            return 32768
-        elif self is ModelType.GPT_4_TURBO:
-            return 128000
-        elif self is ModelType.GPT_4O:
-            return 128000
-        elif self == ModelType.GLM_4_OPEN_SOURCE:
-            return 8192
-        elif self == ModelType.GLM_3_TURBO:
-            return 8192
-        elif self == ModelType.GLM_4V:
+        if self is ModelType.GLM_4V:
             return 1024
-        elif self is ModelType.STUB:
-            return 4096
-        elif self is ModelType.LLAMA_2:
-            return 4096
-        elif self is ModelType.LLAMA_3:
-            return 8192
-        elif self is ModelType.QWEN_2:
-            return 128000
-        elif self is ModelType.GLM_4:
-            return 8192
         elif self is ModelType.VICUNA:
             # reference: https://lmsys.org/blog/2023-03-30-vicuna/
             return 2048
+        elif self in {
+            ModelType.GPT_3_5_TURBO,
+            ModelType.LLAMA_2,
+            ModelType.NEMOTRON_4_REWARD,
+            ModelType.STUB,
+        }:
+            return 4_096
+        elif self in {
+            ModelType.GPT_4,
+            ModelType.GROQ_LLAMA_3_8B,
+            ModelType.GROQ_LLAMA_3_70B,
+            ModelType.GROQ_GEMMA_7B_IT,
+            ModelType.GROQ_GEMMA_2_9B_IT,
+            ModelType.LLAMA_3,
+            ModelType.GLM_3_TURBO,
+            ModelType.GLM_4,
+            ModelType.GLM_4_OPEN_SOURCE,
+        }:
+            return 8_192
         elif self is ModelType.VICUNA_16K:
-            return 16384
+            return 16_384
+        elif self in {
+            ModelType.GPT_4_32K,
+            ModelType.MISTRAL_CODESTRAL,
+            ModelType.MISTRAL_7B,
+            ModelType.MISTRAL_MIXTRAL_8x7B,
+            ModelType.GROQ_MIXTRAL_8_7B,
+        }:
+            return 32_768
+        elif self in {ModelType.MISTRAL_MIXTRAL_8x22B}:
+            return 64_000
         elif self in {ModelType.CLAUDE_2_0, ModelType.CLAUDE_INSTANT_1_2}:
             return 100_000
+        elif self in {
+            ModelType.GPT_4O,
+            ModelType.GPT_4O_MINI,
+            ModelType.GPT_4_TURBO,
+            ModelType.MISTRAL_LARGE,
+            ModelType.MISTRAL_NEMO,
+            ModelType.QWEN_2,
+        }:
+            return 128_000
+        elif self in {
+            ModelType.GROQ_LLAMA_3_1_8B,
+            ModelType.GROQ_LLAMA_3_1_70B,
+            ModelType.GROQ_LLAMA_3_1_405B,
+        }:
+            return 131_072
         elif self in {
             ModelType.CLAUDE_2_1,
             ModelType.CLAUDE_3_OPUS,
@@ -173,8 +263,12 @@ class ModelType(Enum):
             ModelType.CLAUDE_3_5_SONNET,
         }:
             return 200_000
-        elif self is ModelType.NEMOTRON_4_REWARD:
-            return 4096
+        elif self in {
+            ModelType.MISTRAL_CODESTRAL_MAMBA,
+        }:
+            return 256_000
+        elif self in {ModelType.GEMINI_1_5_FLASH, ModelType.GEMINI_1_5_PRO}:
+            return 1_048_576
         else:
             raise ValueError("Unknown model type")
 
@@ -183,8 +277,9 @@ class ModelType(Enum):
 
         Args:
             model_name (str): The name of the model, e.g. "vicuna-7b-v1.5".
+
         Returns:
-            bool: Whether the model type mathches the model name.
+            bool: Whether the model type matches the model name.
         """
         if self is ModelType.VICUNA:
             pattern = r'^vicuna-\d+b-v\d+\.\d+$'
@@ -220,6 +315,8 @@ class EmbeddingModelType(Enum):
     TEXT_EMBEDDING_3_SMALL = "text-embedding-3-small"
     TEXT_EMBEDDING_3_LARGE = "text-embedding-3-large"
 
+    MISTRAL_EMBED = "mistral-embed"
+
     @property
     def is_openai(self) -> bool:
         r"""Returns whether this type of models is an OpenAI-released model."""
@@ -229,6 +326,15 @@ class EmbeddingModelType(Enum):
             EmbeddingModelType.TEXT_EMBEDDING_3_LARGE,
         }
 
+    @property
+    def is_mistral(self) -> bool:
+        r"""Returns whether this type of models is an Mistral-released
+        model.
+        """
+        return self in {
+            EmbeddingModelType.MISTRAL_EMBED,
+        }
+
     @property
     def output_dim(self) -> int:
         if self is EmbeddingModelType.TEXT_EMBEDDING_ADA_2:
@@ -237,6 +343,8 @@ class EmbeddingModelType(Enum):
             return 1536
         elif self is EmbeddingModelType.TEXT_EMBEDDING_3_LARGE:
             return 3072
+        elif self is EmbeddingModelType.MISTRAL_EMBED:
+            return 1024
         else:
             raise ValueError(f"Unknown model type {self}.")
 
@@ -273,6 +381,7 @@ class OpenAIBackendRole(Enum):
     SYSTEM = "system"
     USER = "user"
     FUNCTION = "function"
+    TOOL = "tool"
 
 
 class TerminationMode(Enum):
@@ -326,11 +435,15 @@ class ModelPlatformType(Enum):
     OPENAI = "openai"
     AZURE = "azure"
     ANTHROPIC = "anthropic"
+    GROQ = "groq"
     OPENSOURCE = "opensource"
     OLLAMA = "ollama"
     LITELLM = "litellm"
     ZHIPU = "zhipuai"
     DEFAULT = "default"
+    GEMINI = "gemini"
+    VLLM = "vllm"
+    MISTRAL = "mistral"
 
     @property
     def is_openai(self) -> bool:
@@ -347,11 +460,21 @@ class ModelPlatformType(Enum):
         r"""Returns whether this platform is anthropic."""
         return self is ModelPlatformType.ANTHROPIC
 
+    @property
+    def is_groq(self) -> bool:
+        r"""Returns whether this platform is groq."""
+        return self is ModelPlatformType.GROQ
+
     @property
     def is_ollama(self) -> bool:
         r"""Returns whether this platform is ollama."""
         return self is ModelPlatformType.OLLAMA
 
+    @property
+    def is_vllm(self) -> bool:
+        r"""Returns whether this platform is vllm."""
+        return self is ModelPlatformType.VLLM
+
     @property
     def is_litellm(self) -> bool:
         r"""Returns whether this platform is litellm."""
@@ -362,11 +485,21 @@ class ModelPlatformType(Enum):
         r"""Returns whether this platform is zhipu."""
         return self is ModelPlatformType.ZHIPU
 
+    @property
+    def is_mistral(self) -> bool:
+        r"""Returns whether this platform is mistral."""
+        return self is ModelPlatformType.MISTRAL
+
     @property
     def is_open_source(self) -> bool:
         r"""Returns whether this platform is opensource."""
         return self is ModelPlatformType.OPENSOURCE
 
+    @property
+    def is_gemini(self) -> bool:
+        r"""Returns whether this platform is Gemini."""
+        return self is ModelPlatformType.GEMINI
+
 
 class AudioModelType(Enum):
     TTS_1 = "tts-1"
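
For orientation, here is a minimal sketch of how the new enum surface behaves, assuming camel-ai 0.1.5.9 and the usual `camel.types` re-exports; the values follow directly from the hunks above:

    from camel.types import ModelPlatformType, ModelType

    # New Groq-served models expose the same helpers as existing members,
    # so model-routing code can stay generic.
    model = ModelType.GROQ_LLAMA_3_1_70B
    assert model.is_groq
    assert model.token_limit == 131_072  # per the token_limit table above

    # The new platform members mirror the model-level checks.
    assert ModelPlatformType.MISTRAL.is_mistral
    assert ModelPlatformType.GEMINI.is_gemini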
camel/utils/__init__.py CHANGED
@@ -32,7 +32,9 @@ from .constants import Constants
 from .token_counting import (
     AnthropicTokenCounter,
     BaseTokenCounter,
+    GeminiTokenCounter,
     LiteLLMTokenCounter,
+    MistralTokenCounter,
     OpenAITokenCounter,
     OpenSourceTokenCounter,
     get_model_encoding,
@@ -60,4 +62,6 @@ __all__ = [
     'dependencies_required',
     'api_keys_required',
     'is_docker_running',
+    'GeminiTokenCounter',
+    'MistralTokenCounter',
 ]
camel/utils/async_func.py CHANGED
@@ -14,7 +14,7 @@
 import asyncio
 from copy import deepcopy
 
-from camel.functions.openai_function import OpenAIFunction
+from camel.toolkits import OpenAIFunction
 
 
 def sync_funcs_to_async(funcs: list[OpenAIFunction]) -> list[OpenAIFunction]:
camel/utils/token_counting.py CHANGED
@@ -26,6 +26,8 @@ from PIL import Image
 from camel.types import ModelType, OpenAIImageType, OpenAIVisionDetailType
 
 if TYPE_CHECKING:
+    from mistral_common.protocol.instruct.request import ChatCompletionRequest
+
     from camel.messages import OpenAIMessage
 
 LOW_DETAIL_TOKENS = 85
@@ -37,7 +39,7 @@ EXTRA_TOKENS = 85
 
 
 def messages_to_prompt(messages: List[OpenAIMessage], model: ModelType) -> str:
-    r"""Parse the message list into a single prompt following model-specifc
+    r"""Parse the message list into a single prompt following model-specific
     formats.
 
     Args:
@@ -51,7 +53,12 @@ def messages_to_prompt(messages: List[OpenAIMessage], model: ModelType) -> str:
     system_message = messages[0]["content"]
 
     ret: str
-    if model == ModelType.LLAMA_2 or model == ModelType.LLAMA_3:
+    if model in [
+        ModelType.LLAMA_2,
+        ModelType.LLAMA_3,
+        ModelType.GROQ_LLAMA_3_8B,
+        ModelType.GROQ_LLAMA_3_70B,
+    ]:
         # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
         seps = [" ", " </s><s>"]
         role_map = {"user": "[INST]", "assistant": "[/INST]"}
@@ -74,7 +81,7 @@ def messages_to_prompt(messages: List[OpenAIMessage], model: ModelType) -> str:
             else:
                 ret += role
         return ret
-    elif model == ModelType.VICUNA or model == ModelType.VICUNA_16K:
+    elif model in [ModelType.VICUNA, ModelType.VICUNA_16K]:
         seps = [" ", "</s>"]
         role_map = {"user": "USER", "assistant": "ASSISTANT"}
 
@@ -132,6 +139,40 @@ def messages_to_prompt(messages: List[OpenAIMessage], model: ModelType) -> str:
         else:
             ret += '<|im_start|>' + role + '\n'
         return ret
+    elif model == ModelType.GROQ_MIXTRAL_8_7B:
+        # Mistral/Mixtral format
+        system_prompt = f"<s>[INST] {system_message} [/INST]\n"
+        ret = system_prompt
+
+        for msg in messages[1:]:
+            if msg["role"] == "user":
+                ret += f"[INST] {msg['content']} [/INST]\n"
+            elif msg["role"] == "assistant":
+                ret += f"{msg['content']}</s>\n"
+
+            if not isinstance(msg['content'], str):
+                raise ValueError(
+                    "Currently multimodal context is not "
+                    "supported by the token counter."
+                )
+
+        return ret.strip()
+    elif model in [ModelType.GROQ_GEMMA_7B_IT, ModelType.GROQ_GEMMA_2_9B_IT]:
+        # Gemma format
+        ret = f"<bos>{system_message}\n"
+        for msg in messages:
+            if msg["role"] == "user":
+                ret += f"Human: {msg['content']}\n"
+            elif msg["role"] == "assistant":
+                ret += f"Assistant: {msg['content']}\n"
+
+            if not isinstance(msg['content'], str):
+                raise ValueError(
+                    "Currently multimodal context is not supported by the token counter."
+                )
+
+        ret += "<eos>"
+        return ret
     else:
         raise ValueError(f"Invalid model type: {model}")
 
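The two new Groq branches are easiest to read from their output. A hedged sketch, assuming `messages_to_prompt` is imported from `camel.utils.token_counting` as defined above:

    from camel.types import ModelType
    from camel.utils.token_counting import messages_to_prompt

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hi!"},
        {"role": "assistant", "content": "Hello!"},
    ]
    # The GROQ_MIXTRAL_8_7B branch wraps system and user turns in
    # [INST] ... [/INST] tags and closes assistant turns with </s>.
    print(messages_to_prompt(messages, ModelType.GROQ_MIXTRAL_8_7B))
    # <s>[INST] You are a helpful assistant. [/INST]
    # [INST] Hi! [/INST]
    # Hello!</s>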
@@ -232,6 +273,7 @@ class OpenAITokenCounter(BaseTokenCounter):
             model (ModelType): Model type for which tokens will be counted.
         """
         self.model: str = model.value_for_tiktoken
+        self.model_type = model
 
         self.tokens_per_message: int
         self.tokens_per_name: int
@@ -300,7 +342,7 @@ class OpenAITokenCounter(BaseTokenCounter):
                         base64.b64decode(encoded_image)
                     )
                     image = Image.open(image_bytes)
-                    num_tokens += count_tokens_from_image(
+                    num_tokens += self._count_tokens_from_image(
                         image, OpenAIVisionDetailType(detail)
                     )
                 if key == "name":
@@ -310,6 +352,45 @@ class OpenAITokenCounter(BaseTokenCounter):
                 num_tokens += 3
         return num_tokens
 
+    def _count_tokens_from_image(
+        self, image: Image.Image, detail: OpenAIVisionDetailType
+    ) -> int:
+        r"""Count image tokens for OpenAI vision model. An :obj:`"auto"`
+        resolution model will be treated as :obj:`"high"`. All images with
+        :obj:`"low"` detail cost 85 tokens each. Images with :obj:`"high"` detail
+        are first scaled to fit within a 2048 x 2048 square, maintaining their
+        aspect ratio. Then, they are scaled such that the shortest side of the
+        image is 768px long. Finally, we count how many 512px squares the image
+        consists of. Each of those squares costs 170 tokens. Another 85 tokens are
+        always added to the final total. For more details please refer to `OpenAI
+        vision docs <https://platform.openai.com/docs/guides/vision>`_
+
+        Args:
+            image (PIL.Image.Image): Image to count number of tokens.
+            detail (OpenAIVisionDetailType): Image detail type to count
+                number of tokens.
+
+        Returns:
+            int: Number of tokens for the image given a detail type.
+        """
+        if detail == OpenAIVisionDetailType.LOW:
+            return LOW_DETAIL_TOKENS
+
+        width, height = image.size
+        if width > FIT_SQUARE_PIXELS or height > FIT_SQUARE_PIXELS:
+            scaling_factor = max(width, height) / FIT_SQUARE_PIXELS
+            width = int(width / scaling_factor)
+            height = int(height / scaling_factor)
+
+        scaling_factor = min(width, height) / SHORTEST_SIDE_PIXELS
+        scaled_width = int(width / scaling_factor)
+        scaled_height = int(height / scaling_factor)
+
+        h = ceil(scaled_height / SQUARE_PIXELS)
+        w = ceil(scaled_width / SQUARE_PIXELS)
+        total = EXTRA_TOKENS + SQUARE_TOKENS * h * w
+        return total
+
 
 class AnthropicTokenCounter(BaseTokenCounter):
     def __init__(self, model_type: ModelType):
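
The docstring above encodes OpenAI's published image-pricing rule; a self-contained worked example for a hypothetical 1024 x 4096 image, with the constants copied from this module:

    from math import ceil

    FIT_SQUARE_PIXELS, SHORTEST_SIDE_PIXELS = 2048, 768
    SQUARE_PIXELS, SQUARE_TOKENS, EXTRA_TOKENS = 512, 170, 85

    width, height = 1024, 4096
    # Step 1: fit inside 2048 x 2048, keeping aspect ratio -> 512 x 2048.
    if width > FIT_SQUARE_PIXELS or height > FIT_SQUARE_PIXELS:
        scale = max(width, height) / FIT_SQUARE_PIXELS
        width, height = int(width / scale), int(height / scale)
    # Step 2: rescale so the shortest side is 768px -> 768 x 3072.
    scale = min(width, height) / SHORTEST_SIDE_PIXELS
    width, height = int(width / scale), int(height / scale)
    # Step 3: 85 base tokens plus 170 per 512px tile; 2 * 6 tiles here.
    tokens = EXTRA_TOKENS + SQUARE_TOKENS * ceil(width / SQUARE_PIXELS) * ceil(height / SQUARE_PIXELS)
    print(tokens)  # 85 + 170 * 12 = 2125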
@@ -342,6 +423,40 @@ class AnthropicTokenCounter:
         return num_tokens
 
 
+class GeminiTokenCounter(BaseTokenCounter):
+    def __init__(self, model_type: ModelType):
+        r"""Constructor for the token counter for Gemini models."""
+        import google.generativeai as genai
+
+        self.model_type = model_type
+        self._client = genai.GenerativeModel(self.model_type.value)
+
+    def count_tokens_from_messages(self, messages: List[OpenAIMessage]) -> int:
+        r"""Count number of tokens in the provided message list using
+        loaded tokenizer specific for this type of model.
+
+        Args:
+            messages (List[OpenAIMessage]): Message list with the chat history
+                in OpenAI API format.
+
+        Returns:
+            int: Number of tokens in the messages.
+        """
+        converted_messages = []
+        for message in messages:
+            role = message.get('role')
+            if role == 'assistant':
+                role_to_gemini = 'model'
+            else:
+                role_to_gemini = 'user'
+            converted_message = {
+                "role": role_to_gemini,
+                "parts": message.get("content"),
+            }
+            converted_messages.append(converted_message)
+        return self._client.count_tokens(converted_messages).total_tokens
+
+
 class LiteLLMTokenCounter:
     def __init__(self, model_type: str):
         r"""Constructor for the token counter for LiteLLM models.
@@ -394,41 +509,68 @@ class LiteLLMTokenCounter:
         return self.completion_cost(completion_response=response)
 
 
-def count_tokens_from_image(
-    image: Image.Image, detail: OpenAIVisionDetailType
-) -> int:
-    r"""Count image tokens for OpenAI vision model. An :obj:`"auto"`
-    resolution model will be treated as :obj:`"high"`. All images with
-    :obj:`"low"` detail cost 85 tokens each. Images with :obj:`"high"` detail
-    are first scaled to fit within a 2048 x 2048 square, maintaining their
-    aspect ratio. Then, they are scaled such that the shortest side of the
-    image is 768px long. Finally, we count how many 512px squares the image
-    consists of. Each of those squares costs 170 tokens. Another 85 tokens are
-    always added to the final total. For more details please refer to `OpenAI
-    vision docs <https://platform.openai.com/docs/guides/vision>`_
+class MistralTokenCounter(BaseTokenCounter):
+    def __init__(self, model_type: ModelType):
+        r"""Constructor for the token counter for Mistral models.
 
-    Args:
-        image (PIL.Image.Image): Image to count number of tokens.
-        detail (OpenAIVisionDetailType): Image detail type to count
-            number of tokens.
+        Args:
+            model_type (ModelType): Model type for which tokens will be
+                counted.
+        """
+        from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
 
-    Returns:
-        int: Number of tokens for the image given a detail type.
-    """
-    if detail == OpenAIVisionDetailType.LOW:
-        return LOW_DETAIL_TOKENS
-
-    width, height = image.size
-    if width > FIT_SQUARE_PIXELS or height > FIT_SQUARE_PIXELS:
-        scaling_factor = max(width, height) / FIT_SQUARE_PIXELS
-        width = int(width / scaling_factor)
-        height = int(height / scaling_factor)
-
-    scaling_factor = min(width, height) / SHORTEST_SIDE_PIXELS
-    scaled_width = int(width / scaling_factor)
-    scaled_height = int(height / scaling_factor)
-
-    h = ceil(scaled_height / SQUARE_PIXELS)
-    w = ceil(scaled_width / SQUARE_PIXELS)
-    total = EXTRA_TOKENS + SQUARE_TOKENS * h * w
-    return total
+        self.model_type = model_type
+
+        # Determine the model type and set the tokenizer accordingly
+        model_name = (
+            "codestral-22b"
+            if self.model_type
+            in {ModelType.MISTRAL_CODESTRAL, ModelType.MISTRAL_CODESTRAL_MAMBA}
+            else self.model_type.value
+        )
+
+        self.tokenizer = MistralTokenizer.from_model(model_name)
+
+    def count_tokens_from_messages(self, messages: List[OpenAIMessage]) -> int:
+        r"""Count number of tokens in the provided message list using
+        loaded tokenizer specific for this type of model.
+
+        Args:
+            messages (List[OpenAIMessage]): Message list with the chat history
+                in OpenAI API format.
+
+        Returns:
+            int: Total number of tokens in the messages.
+        """
+        total_tokens = 0
+        for msg in messages:
+            tokens = self.tokenizer.encode_chat_completion(
+                self._convert_response_from_openai_to_mistral(msg)
+            ).tokens
+            total_tokens += len(tokens)
+        return total_tokens
+
+    def _convert_response_from_openai_to_mistral(
+        self, openai_msg: OpenAIMessage
+    ) -> ChatCompletionRequest:
+        r"""Convert an OpenAI message to a Mistral ChatCompletionRequest.
+
+        Args:
+            openai_msg (OpenAIMessage): An individual message with OpenAI
+                format.
+
+        Returns:
+            ChatCompletionRequest: The converted message in Mistral's request
+                format.
+        """
+
+        from mistral_common.protocol.instruct.request import (
+            ChatCompletionRequest,
+        )
+
+        mistral_request = ChatCompletionRequest(  # type: ignore[type-var]
+            model=self.model_type.value,
+            messages=[openai_msg],
+        )
+
+        return mistral_request
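
And a matching sketch for MistralTokenCounter, assuming `mistral_common` is installed; unlike the Gemini counter, tokenization here happens locally and needs no API key:

    from camel.types import ModelType
    from camel.utils import MistralTokenCounter

    counter = MistralTokenCounter(ModelType.MISTRAL_LARGE)
    # Each message is wrapped in its own ChatCompletionRequest, encoded with
    # the model's tokenizer, and the per-message token counts are summed.
    print(counter.count_tokens_from_messages([
        {"role": "user", "content": "Write a haiku about tokenizers."},
    ]))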