camel-ai 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of camel-ai might be problematic; see the registry's advisory for more details.

Files changed (55)
  1. camel/__init__.py +1 -1
  2. camel/agents/chat_agent.py +13 -1
  3. camel/benchmarks/__init__.py +18 -0
  4. camel/benchmarks/base.py +152 -0
  5. camel/benchmarks/gaia.py +478 -0
  6. camel/configs/__init__.py +3 -0
  7. camel/configs/ollama_config.py +4 -2
  8. camel/configs/sglang_config.py +71 -0
  9. camel/data_collector/__init__.py +19 -0
  10. camel/data_collector/alpaca_collector.py +127 -0
  11. camel/data_collector/base.py +211 -0
  12. camel/data_collector/sharegpt_collector.py +205 -0
  13. camel/datahubs/__init__.py +23 -0
  14. camel/datahubs/base.py +136 -0
  15. camel/datahubs/huggingface.py +433 -0
  16. camel/datahubs/models.py +22 -0
  17. camel/interpreters/__init__.py +2 -0
  18. camel/interpreters/e2b_interpreter.py +136 -0
  19. camel/loaders/__init__.py +3 -1
  20. camel/loaders/base_io.py +41 -41
  21. camel/messages/__init__.py +2 -0
  22. camel/models/__init__.py +2 -0
  23. camel/models/anthropic_model.py +14 -4
  24. camel/models/base_model.py +28 -0
  25. camel/models/groq_model.py +1 -1
  26. camel/models/model_factory.py +3 -0
  27. camel/models/ollama_model.py +12 -0
  28. camel/models/openai_model.py +0 -26
  29. camel/models/reward/__init__.py +22 -0
  30. camel/models/reward/base_reward_model.py +58 -0
  31. camel/models/reward/evaluator.py +63 -0
  32. camel/models/reward/nemotron_model.py +112 -0
  33. camel/models/sglang_model.py +225 -0
  34. camel/models/vllm_model.py +1 -1
  35. camel/personas/persona_hub.py +2 -2
  36. camel/schemas/openai_converter.py +2 -2
  37. camel/societies/workforce/role_playing_worker.py +2 -2
  38. camel/societies/workforce/single_agent_worker.py +2 -2
  39. camel/societies/workforce/workforce.py +3 -3
  40. camel/storages/object_storages/amazon_s3.py +2 -2
  41. camel/storages/object_storages/azure_blob.py +2 -2
  42. camel/storages/object_storages/google_cloud.py +2 -2
  43. camel/toolkits/__init__.py +2 -0
  44. camel/toolkits/code_execution.py +5 -1
  45. camel/toolkits/function_tool.py +41 -0
  46. camel/toolkits/math_toolkit.py +47 -16
  47. camel/toolkits/search_toolkit.py +154 -2
  48. camel/toolkits/stripe_toolkit.py +273 -0
  49. camel/types/__init__.py +2 -0
  50. camel/types/enums.py +27 -2
  51. camel/utils/token_counting.py +22 -10
  52. {camel_ai-0.2.11.dist-info → camel_ai-0.2.12.dist-info}/METADATA +13 -6
  53. {camel_ai-0.2.11.dist-info → camel_ai-0.2.12.dist-info}/RECORD +55 -36
  54. {camel_ai-0.2.11.dist-info → camel_ai-0.2.12.dist-info}/LICENSE +0 -0
  55. {camel_ai-0.2.11.dist-info → camel_ai-0.2.12.dist-info}/WHEEL +0 -0
@@ -13,7 +13,7 @@
13
13
  # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
14
  import os
15
15
  import xml.etree.ElementTree as ET
16
- from typing import Any, Dict, List, Union
16
+ from typing import Any, Dict, List, Optional, TypeAlias, Union
17
17
 
18
18
  import requests
19
19
 
@@ -26,7 +26,7 @@ class SearchToolkit(BaseToolkit):
26
26
  r"""A class representing a toolkit for web search.
27
27
 
28
28
  This class provides methods for searching information on the web using
29
- search engines like Google, DuckDuckGo, Wikipedia and Wolfram Alpha.
29
+ search engines like Google, DuckDuckGo, Wikipedia and Wolfram Alpha, Brave.
30
30
  """
31
31
 
32
32
  @dependencies_required("wikipedia")
@@ -151,6 +151,152 @@ class SearchToolkit(BaseToolkit):
151
151
  # If no answer found, return an empty list
152
152
  return responses
153
153
 
154
@api_keys_required("BRAVE_API_KEY")
def search_brave(
    self,
    q: str,
    country: str = "US",
    search_lang: str = "en",
    ui_lang: str = "en-US",
    count: int = 20,
    offset: int = 0,
    safesearch: str = "moderate",
    freshness: Optional[str] = None,
    text_decorations: bool = True,
    spellcheck: bool = True,
    result_filter: Optional[str] = None,
    goggles_id: Optional[str] = None,
    units: Optional[str] = None,
    extra_snippets: Optional[bool] = None,
    summary: Optional[bool] = None,
) -> Dict[str, Any]:
    r"""This function queries the Brave search engine API and returns a
    dictionary, representing a search result.
    See https://api.search.brave.com/app/documentation/web-search/query
    for more details.

    Args:
        q (str): The user's search query term. Query cannot be empty.
            Maximum of 400 characters and 50 words in the query.
        country (str): The search query country where results come from.
            The country string is limited to 2 character country codes of
            supported countries. For a list of supported values, see
            Country Codes. (default: :obj:`"US"`)
        search_lang (str): The search language preference. The 2 or more
            character language code for which search results are provided.
            For a list of possible values, see Language Codes.
            (default: :obj:`"en"`)
        ui_lang (str): User interface language preferred in response.
            Usually of the format '<language_code>-<country_code>'. For
            more, see RFC 9110. For a list of supported values, see UI
            Language Codes. (default: :obj:`"en-US"`)
        count (int): The number of search results returned in response.
            The maximum is 20. The actual number delivered may be less than
            requested. Combine this parameter with offset to paginate
            search results. (default: :obj:`20`)
        offset (int): The zero based offset that indicates number of search
            results per page (count) to skip before returning the result.
            The maximum is 9. The actual number delivered may be less than
            requested based on the query. In order to paginate results use
            this parameter together with count. For example, if your user
            interface displays 20 search results per page, set count to 20
            and offset to 0 to show the first page of results. To get
            subsequent pages, increment offset by 1 (e.g. 0, 1, 2). The
            results may overlap across multiple pages. (default: :obj:`0`)
        safesearch (str): Filters search results for adult content.
            The following values are supported:
            - 'off': No filtering is done.
            - 'moderate': Filters explicit content, like images and videos,
                but allows adult domains in the search results.
            - 'strict': Drops all adult content from search results.
            (default: :obj:`"moderate"`)
        freshness (Optional[str]): Filters search results by when they were
            discovered:
            - 'pd': Discovered within the last 24 hours.
            - 'pw': Discovered within the last 7 Days.
            - 'pm': Discovered within the last 31 Days.
            - 'py': Discovered within the last 365 Days.
            - 'YYYY-MM-DDtoYYYY-MM-DD': Timeframe is also supported by
                specifying the date range e.g. '2022-04-01to2022-07-30'.
            (default: :obj:`None`, not sent)
        text_decorations (bool): Whether display strings (e.g. result
            snippets) should include decoration markers (e.g. highlighting
            characters). (default: :obj:`True`)
        spellcheck (bool): Whether to spellcheck provided query. If the
            spellchecker is enabled, the modified query is always used for
            search. The modified query can be found in altered key from the
            query response model. (default: :obj:`True`)
        result_filter (Optional[str]): A comma delimited string of result
            types to include in the search response. Not specifying this
            parameter will return back all result types in search response
            where data is available and a plan with the corresponding
            option is subscribed. The response always includes query and
            type to identify any query modifications and response type
            respectively. Available result filter values are:
            - 'discussions'
            - 'faq'
            - 'infobox'
            - 'news'
            - 'query'
            - 'summarizer'
            - 'videos'
            - 'web'
            - 'locations'
            (default: :obj:`None`, not sent)
        goggles_id (Optional[str]): Goggles act as a custom re-ranking on
            top of Brave's search index. For more details, refer to the
            Goggles repository. (default: :obj:`None`, not sent)
        units (Optional[str]): The measurement units. If not provided,
            units are derived from search country. Possible values are:
            - 'metric': The standardized measurement system
            - 'imperial': The British Imperial system of units.
        extra_snippets (Optional[bool]): A snippet is an excerpt from a
            page you get as a result of the query, and extra_snippets
            allow you to get up to 5 additional, alternative excerpts. Only
            available under Free AI, Base AI, Pro AI, Base Data, Pro Data
            and Custom plans. (default: :obj:`None`, not sent)
        summary (Optional[bool]): This parameter enables summary key
            generation in web search results. This is required for
            summarizer to be enabled. (default: :obj:`None`, not sent)

    Returns:
        Dict[str, Any]: The ``web`` section of the Brave response. If the
            response has no ``web`` section (e.g. ``result_filter``
            excludes web results, or the API answered with an error
            body), the full decoded JSON response is returned instead so
            the caller still sees the requested sections or the error.

    Raises:
        requests.exceptions.RequestException: On network failure or if no
            response arrives within the request timeout.
        ValueError: If the response body is not valid JSON.
    """
    import requests

    # Without a timeout, ``requests`` can block forever on a stalled
    # connection, which would hang the calling agent.
    request_timeout = 30

    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {
        "Content-Type": "application/json",
        "X-BCP-APIV": "1.0",
        "X-Subscription-Token": os.getenv("BRAVE_API_KEY"),
    }

    candidate_params: Dict[str, Any] = {
        "q": q,
        "country": country,
        "search_lang": search_lang,
        "ui_lang": ui_lang,
        "count": count,
        "offset": offset,
        "safesearch": safesearch,
        "freshness": freshness,
        "text_decorations": text_decorations,
        "spellcheck": spellcheck,
        "result_filter": result_filter,
        "goggles_id": goggles_id,
        "units": units,
        "extra_snippets": extra_snippets,
        "summary": summary,
    }

    # Omit unset optional filters, and send booleans as lowercase
    # "true"/"false"; ``requests`` would otherwise encode them as
    # Python's "True"/"False".
    params: Dict[str, Union[str, int]] = {}
    for key, value in candidate_params.items():
        if value is None:
            continue
        params[key] = str(value).lower() if isinstance(value, bool) else value

    response = requests.get(
        url, headers=headers, params=params, timeout=request_timeout
    )
    payload = response.json()
    # Indexing ["web"] directly raised KeyError whenever result_filter
    # excluded web results or the API returned an error body.
    return payload.get("web", payload)
299
+
154
300
  @api_keys_required("GOOGLE_API_KEY", "SEARCH_ENGINE_ID")
155
301
  def search_google(
156
302
  self, query: str, num_result_pages: int = 5
@@ -219,6 +365,11 @@ class SearchToolkit(BaseToolkit):
219
365
 
220
366
  # Iterate over 10 results found
221
367
  for i, search_item in enumerate(search_items, start=1):
368
+ # Check metatags are present
369
+ if "pagemap" not in search_item:
370
+ continue
371
+ if "metatags" not in search_item["pagemap"]:
372
+ continue
222
373
  if (
223
374
  "og:description"
224
375
  in search_item["pagemap"]["metatags"][0]
@@ -471,4 +622,5 @@ class SearchToolkit(BaseToolkit):
471
622
  FunctionTool(self.search_duckduckgo),
472
623
  FunctionTool(self.query_wolfram_alpha),
473
624
  FunctionTool(self.tavily_search),
625
+ FunctionTool(self.search_brave),
474
626
  ]
@@ -0,0 +1,273 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ import json
16
+ import logging
17
+ import os
18
+ from typing import List
19
+
20
+ from camel.toolkits import FunctionTool
21
+ from camel.toolkits.base import BaseToolkit
22
+ from camel.utils import api_keys_required
23
+
24
+
25
class StripeToolkit(BaseToolkit):
    r"""A class representing a toolkit for Stripe operations.

    This toolkit provides methods to interact with the Stripe API, allowing
    users to operate on core Stripe resources: Customer, Balance,
    BalanceTransaction, Payment (PaymentIntent) and Refund.

    Use the Developers Dashboard https://dashboard.stripe.com/test/apikeys to
    create an API key and export it as the ``STRIPE_API_KEY`` environment
    variable.

    Every public operation returns a JSON string on success and an error
    message string on failure. Exceptions raised inside an operation are
    caught and converted by :meth:`handle_exception`.

    Attributes:
        logger (Logger): a logger to write logs.
    """

    @api_keys_required("STRIPE_API_KEY")
    def __init__(self, retries: int = 3):
        r"""Initializes the StripeToolkit with the specified number of
        retries.

        Note that this sets attributes on the ``stripe`` module itself
        (API key, retry count and log level), so the configuration is
        shared with any other code using ``stripe`` in the same process.

        Args:
            retries (int, optional): Number of times to retry the request in
                case of failure. (default: :obj:`3`)
        """
        import stripe

        stripe.max_network_retries = retries
        stripe.log = 'info'
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        # The module logger is shared by every instance; attach a handler
        # only once so that repeated construction does not duplicate output.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(
                logging.Formatter(
                    '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
                )
            )
            self.logger.addHandler(handler)
        stripe.api_key = os.environ.get("STRIPE_API_KEY")

    def customer_get(self, customer_id: str) -> str:
        r"""Retrieve a customer by ID.

        Args:
            customer_id (str): The ID of the customer to retrieve.

        Returns:
            str: The customer data as a JSON string if successful, or an
                error message string if failed.
        """
        import stripe

        try:
            self.logger.info("Retrieving customer with ID: %s", customer_id)
            customer = stripe.Customer.retrieve(customer_id)
            self.logger.info("Retrieved customer: %s", customer.id)
            return json.dumps(customer)
        except Exception as e:
            return self.handle_exception("customer_get", e)

    def customer_list(self, limit: int = 100) -> str:
        r"""List customers.

        Args:
            limit (int, optional): Number of customers to retrieve; Stripe
                accepts values between 1 and 100. (default: :obj:`100`)

        Returns:
            str: A JSON array of customers if successful, or an error
                message string if failed.
        """
        import stripe

        try:
            self.logger.info("Listing customers with limit=%s", limit)
            customers = stripe.Customer.list(limit=limit).data
            self.logger.info(
                "Successfully retrieved %d customers.", len(customers)
            )
            return json.dumps(list(customers))
        except Exception as e:
            return self.handle_exception("customer_list", e)

    def balance_get(self) -> str:
        r"""Retrieve your account balance.

        Returns:
            str: The account balance as a JSON string if successful, or an
                error message string if failed.
        """
        import stripe

        try:
            self.logger.info("Retrieving account balance.")
            balance = stripe.Balance.retrieve()
            # Do not write the balance itself to the logs: it is financial
            # data and is already returned to the caller.
            self.logger.info("Successfully retrieved account balance.")
            return json.dumps(balance)
        except Exception as e:
            return self.handle_exception("balance_get", e)

    def balance_transaction_list(self, limit: int = 100) -> str:
        r"""List your balance transactions.

        Args:
            limit (int, optional): Number of balance transactions to
                retrieve; Stripe accepts values between 1 and 100.
                (default: :obj:`100`)

        Returns:
            str: A JSON array of balance transactions if successful, or an
                error message string if failed.
        """
        import stripe

        try:
            self.logger.info(
                "Listing balance transactions with limit=%s", limit
            )
            transactions = stripe.BalanceTransaction.list(limit=limit).data
            self.logger.info(
                "Successfully retrieved %d balance transactions.",
                len(transactions),
            )
            return json.dumps(list(transactions))
        except Exception as e:
            return self.handle_exception("balance_transaction_list", e)

    def payment_get(self, payment_id: str) -> str:
        r"""Retrieve a payment (PaymentIntent) by ID.

        Args:
            payment_id (str): The ID of the PaymentIntent to retrieve.

        Returns:
            str: The payment data as a JSON string if successful, or an
                error message string if failed.
        """
        import stripe

        try:
            self.logger.info("Retrieving payment with ID: %s", payment_id)
            payment = stripe.PaymentIntent.retrieve(payment_id)
            self.logger.info("Retrieved payment: %s", payment.id)
            return json.dumps(payment)
        except Exception as e:
            return self.handle_exception("payment_get", e)

    def payment_list(self, limit: int = 100) -> str:
        r"""List payments (PaymentIntents).

        Args:
            limit (int, optional): Number of payments to retrieve; Stripe
                accepts values between 1 and 100. (default: :obj:`100`)

        Returns:
            str: A JSON array of payments if successful, or an error
                message string if failed.
        """
        import stripe

        try:
            self.logger.info("Listing payments with limit=%s", limit)
            payments = stripe.PaymentIntent.list(limit=limit).data
            self.logger.info(
                "Successfully retrieved %d payments.", len(payments)
            )
            return json.dumps(list(payments))
        except Exception as e:
            return self.handle_exception("payment_list", e)

    def refund_get(self, refund_id: str) -> str:
        r"""Retrieve a refund by ID.

        Args:
            refund_id (str): The ID of the refund to retrieve.

        Returns:
            str: The refund data as a JSON string if successful, or an
                error message string if failed.
        """
        import stripe

        try:
            self.logger.info("Retrieving refund with ID: %s", refund_id)
            refund = stripe.Refund.retrieve(refund_id)
            self.logger.info("Retrieved refund: %s", refund.id)
            return json.dumps(refund)
        except Exception as e:
            return self.handle_exception("refund_get", e)

    def refund_list(self, limit: int = 100) -> str:
        r"""List refunds.

        Args:
            limit (int, optional): Number of refunds to retrieve; Stripe
                accepts values between 1 and 100. (default: :obj:`100`)

        Returns:
            str: A JSON array of refunds if successful, or an error
                message string if failed.
        """
        import stripe

        try:
            self.logger.info("Listing refunds with limit=%s", limit)
            refunds = stripe.Refund.list(limit=limit).data
            self.logger.info("Successfully retrieved %d refunds.", len(refunds))
            return json.dumps(list(refunds))
        except Exception as e:
            return self.handle_exception("refund_list", e)

    def handle_exception(self, func_name: str, error: Exception) -> str:
        r"""Handle exceptions by logging and returning an error message.

        Stripe errors are reported with their user-facing message when one
        is available; any other exception is reported via ``str(error)``.

        Args:
            func_name (str): The name of the function where the exception
                occurred.
            error (Exception): The exception instance.

        Returns:
            str: An error message string.
        """
        from stripe import StripeError

        if isinstance(error, StripeError):
            message = error.user_message or str(error)
            self.logger.error(f"Stripe error in {func_name}: {message}")
            return f"Stripe error in {func_name}: {message}"
        self.logger.error(f"Unexpected error in {func_name}: {error!s}")
        return f"Unexpected error in {func_name}: {error!s}"

    def get_tools(self) -> List[FunctionTool]:
        r"""Returns a list of FunctionTool objects representing the
        functions in the toolkit.

        Returns:
            List[FunctionTool]: A list of FunctionTool objects for the
                toolkit methods.
        """
        return [
            FunctionTool(self.customer_get),
            FunctionTool(self.customer_list),
            FunctionTool(self.balance_get),
            FunctionTool(self.balance_transaction_list),
            FunctionTool(self.payment_get),
            FunctionTool(self.payment_list),
            FunctionTool(self.refund_get),
            FunctionTool(self.refund_list),
        ]
camel/types/__init__.py CHANGED
@@ -14,6 +14,7 @@
14
14
  from .enums import (
15
15
  AudioModelType,
16
16
  EmbeddingModelType,
17
+ HuggingFaceRepoType,
17
18
  ModelPlatformType,
18
19
  ModelType,
19
20
  OpenAIBackendRole,
@@ -73,4 +74,5 @@ __all__ = [
73
74
  'NOT_GIVEN',
74
75
  'NotGiven',
75
76
  'ParsedChatCompletion',
77
+ 'HuggingFaceRepoType',
76
78
  ]
camel/types/enums.py CHANGED
@@ -41,9 +41,12 @@ class ModelType(UnifiedModelType, Enum):
41
41
  GLM_4V = 'glm-4v'
42
42
  GLM_3_TURBO = "glm-3-turbo"
43
43
 
44
+ # Groq platform models
44
45
  GROQ_LLAMA_3_1_8B = "llama-3.1-8b-instant"
45
46
  GROQ_LLAMA_3_1_70B = "llama-3.1-70b-versatile"
46
47
  GROQ_LLAMA_3_1_405B = "llama-3.1-405b-reasoning"
48
+ GROQ_LLAMA_3_3_70B = "llama-3.3-70b-versatile"
49
+ GROQ_LLAMA_3_3_70B_PREVIEW = "llama-3.3-70b-specdec"
47
50
  GROQ_LLAMA_3_8B = "llama3-8b-8192"
48
51
  GROQ_LLAMA_3_70B = "llama3-70b-8192"
49
52
  GROQ_MIXTRAL_8_7B = "mixtral-8x7b-32768"
@@ -59,10 +62,11 @@ class ModelType(UnifiedModelType, Enum):
59
62
  CLAUDE_INSTANT_1_2 = "claude-instant-1.2"
60
63
 
61
64
  # Claude3 models
62
- CLAUDE_3_OPUS = "claude-3-opus-20240229"
65
+ CLAUDE_3_OPUS = "claude-3-opus-latest"
63
66
  CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
64
67
  CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
65
- CLAUDE_3_5_SONNET = "claude-3-5-sonnet-20240620"
68
+ CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
69
+ CLAUDE_3_5_HAIKU = "claude-3-5-haiku-latest"
66
70
 
67
71
  # Nvidia models
68
72
  NVIDIA_NEMOTRON_340B_INSTRUCT = "nvidia/nemotron-4-340b-instruct"
@@ -76,6 +80,7 @@ class ModelType(UnifiedModelType, Enum):
76
80
  NVIDIA_LLAMA3_1_405B_INSTRUCT = "meta/llama-3.1-405b-instruct"
77
81
  NVIDIA_LLAMA3_2_1B_INSTRUCT = "meta/llama-3.2-1b-instruct"
78
82
  NVIDIA_LLAMA3_2_3B_INSTRUCT = "meta/llama-3.2-3b-instruct"
83
+ NVIDIA_LLAMA3_3_70B_INSTRUCT = "meta/llama-3.3-70b-instruct"
79
84
 
80
85
  # Gemini models
81
86
  GEMINI_1_5_FLASH = "gemini-1.5-flash"
@@ -202,6 +207,7 @@ class ModelType(UnifiedModelType, Enum):
202
207
  ModelType.CLAUDE_3_SONNET,
203
208
  ModelType.CLAUDE_3_HAIKU,
204
209
  ModelType.CLAUDE_3_5_SONNET,
210
+ ModelType.CLAUDE_3_5_HAIKU,
205
211
  }
206
212
 
207
213
  @property
@@ -211,6 +217,8 @@ class ModelType(UnifiedModelType, Enum):
211
217
  ModelType.GROQ_LLAMA_3_1_8B,
212
218
  ModelType.GROQ_LLAMA_3_1_70B,
213
219
  ModelType.GROQ_LLAMA_3_1_405B,
220
+ ModelType.GROQ_LLAMA_3_3_70B,
221
+ ModelType.GROQ_LLAMA_3_3_70B_PREVIEW,
214
222
  ModelType.GROQ_LLAMA_3_8B,
215
223
  ModelType.GROQ_LLAMA_3_70B,
216
224
  ModelType.GROQ_MIXTRAL_8_7B,
@@ -249,6 +257,7 @@ class ModelType(UnifiedModelType, Enum):
249
257
  ModelType.NVIDIA_LLAMA3_1_405B_INSTRUCT,
250
258
  ModelType.NVIDIA_LLAMA3_2_1B_INSTRUCT,
251
259
  ModelType.NVIDIA_LLAMA3_2_3B_INSTRUCT,
260
+ ModelType.NVIDIA_LLAMA3_3_70B_INSTRUCT,
252
261
  }
253
262
 
254
263
  @property
@@ -362,6 +371,7 @@ class ModelType(UnifiedModelType, Enum):
362
371
  ModelType.GPT_4,
363
372
  ModelType.GROQ_LLAMA_3_8B,
364
373
  ModelType.GROQ_LLAMA_3_70B,
374
+ ModelType.GROQ_LLAMA_3_3_70B_PREVIEW,
365
375
  ModelType.GROQ_GEMMA_7B_IT,
366
376
  ModelType.GROQ_GEMMA_2_9B_IT,
367
377
  ModelType.GLM_3_TURBO,
@@ -428,6 +438,8 @@ class ModelType(UnifiedModelType, Enum):
428
438
  ModelType.NVIDIA_LLAMA3_1_405B_INSTRUCT,
429
439
  ModelType.NVIDIA_LLAMA3_2_1B_INSTRUCT,
430
440
  ModelType.NVIDIA_LLAMA3_2_3B_INSTRUCT,
441
+ ModelType.NVIDIA_LLAMA3_3_70B_INSTRUCT,
442
+ ModelType.GROQ_LLAMA_3_3_70B,
431
443
  }:
432
444
  return 128_000
433
445
  elif self in {
@@ -445,6 +457,7 @@ class ModelType(UnifiedModelType, Enum):
445
457
  ModelType.CLAUDE_3_SONNET,
446
458
  ModelType.CLAUDE_3_HAIKU,
447
459
  ModelType.CLAUDE_3_5_SONNET,
460
+ ModelType.CLAUDE_3_5_HAIKU,
448
461
  ModelType.YI_MEDIUM_200K,
449
462
  }:
450
463
  return 200_000
@@ -611,6 +624,7 @@ class ModelPlatformType(Enum):
611
624
  QWEN = "tongyi-qianwen"
612
625
  NVIDIA = "nvidia"
613
626
  DEEPSEEK = "deepseek"
627
+ SGLANG = "sglang"
614
628
 
615
629
  @property
616
630
  def is_openai(self) -> bool:
@@ -642,6 +656,11 @@ class ModelPlatformType(Enum):
642
656
  r"""Returns whether this platform is vllm."""
643
657
  return self is ModelPlatformType.VLLM
644
658
 
659
@property
def is_sglang(self) -> bool:
    r"""Tells whether this member denotes the SGLang serving platform."""
    sglang_member = ModelPlatformType.SGLANG
    return sglang_member is self
+
645
664
  @property
646
665
  def is_together(self) -> bool:
647
666
  r"""Returns whether this platform is together."""
@@ -749,3 +768,9 @@ class JinaReturnFormat(Enum):
749
768
  MARKDOWN = "markdown"
750
769
  HTML = "html"
751
770
  TEXT = "text"
771
+
772
+
773
class HuggingFaceRepoType(str, Enum):
    r"""Kinds of Hugging Face Hub repository.

    Because the enum mixes in :class:`str`, each member compares equal to
    its plain string value and can be used wherever that string is
    expected.
    """

    DATASET = "dataset"  # a dataset repository
    MODEL = "model"  # a model repository
    SPACE = "space"  # a Spaces repository
@@ -222,13 +222,18 @@ class OpenAITokenCounter(BaseTokenCounter):
222
222
 
223
223
  class AnthropicTokenCounter(BaseTokenCounter):
224
224
  @dependencies_required('anthropic')
225
- def __init__(self):
226
- r"""Constructor for the token counter for Anthropic models."""
225
+ def __init__(self, model: str):
226
+ r"""Constructor for the token counter for Anthropic models.
227
+
228
+ Args:
229
+ model (str): The name of the Anthropic model being used.
230
+ """
227
231
  from anthropic import Anthropic
228
232
 
229
233
  self.client = Anthropic()
230
- self.tokenizer = self.client.get_tokenizer()
234
+ self.model = model
231
235
 
236
+ @dependencies_required('anthropic')
232
237
  def count_tokens_from_messages(self, messages: List[OpenAIMessage]) -> int:
233
238
  r"""Count number of tokens in the provided message list using
234
239
  loaded tokenizer specific for this type of model.
@@ -240,11 +245,18 @@ class AnthropicTokenCounter(BaseTokenCounter):
240
245
  Returns:
241
246
  int: Number of tokens in the messages.
242
247
  """
243
- num_tokens = 0
244
- for message in messages:
245
- content = str(message["content"])
246
- num_tokens += self.client.count_tokens(content)
247
- return num_tokens
248
+ from anthropic.types.beta import BetaMessageParam
249
+
250
+ return self.client.beta.messages.count_tokens(
251
+ messages=[
252
+ BetaMessageParam(
253
+ content=str(msg["content"]),
254
+ role="user" if msg["role"] == "user" else "assistant",
255
+ )
256
+ for msg in messages
257
+ ],
258
+ model=self.model,
259
+ ).input_tokens
248
260
 
249
261
 
250
262
  class GeminiTokenCounter(BaseTokenCounter):
@@ -360,7 +372,7 @@ class MistralTokenCounter(BaseTokenCounter):
360
372
  ModelType.MISTRAL_CODESTRAL,
361
373
  ModelType.MISTRAL_CODESTRAL_MAMBA,
362
374
  }
363
- else self.model_type.value
375
+ else self.model_type
364
376
  )
365
377
 
366
378
  self.tokenizer = MistralTokenizer.from_model(model_name)
@@ -403,7 +415,7 @@ class MistralTokenCounter(BaseTokenCounter):
403
415
  )
404
416
 
405
417
  mistral_request = ChatCompletionRequest( # type: ignore[type-var]
406
- model=self.model_type.value,
418
+ model=self.model_type,
407
419
  messages=[openai_msg],
408
420
  )
409
421