camel-ai 0.2.20a1__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

camel/toolkits/whatsapp_toolkit.py CHANGED
@@ -19,7 +19,7 @@ import requests
 
 from camel.toolkits import FunctionTool
 from camel.toolkits.base import BaseToolkit
-from camel.utils.commons import retry_request
+from camel.utils import retry_on_error
 
 
 class WhatsAppToolkit(BaseToolkit):
@@ -36,18 +36,8 @@ class WhatsAppToolkit(BaseToolkit):
         version (str): API version.
     """
 
-    def __init__(self, retries: int = 3, delay: int = 1):
-        r"""Initializes the WhatsAppToolkit with the specified number of
-        retries and delay.
-
-        Args:
-            retries (int): Number of times to retry the request in case of
-                failure. (default: :obj:`3`)
-            delay (int): Time in seconds to wait between retries.
-                (default: :obj:`1`)
-        """
-        self.retries = retries
-        self.delay = delay
+    def __init__(self):
+        r"""Initializes the WhatsAppToolkit."""
         self.base_url = "https://graph.facebook.com"
         self.version = "v17.0"
 
@@ -61,6 +51,7 @@ class WhatsAppToolkit(BaseToolkit):
                 "WHATSAPP_PHONE_NUMBER_ID environment variables."
             )
 
+    @retry_on_error()
     def send_message(
         self, to: str, message: str
     ) -> Union[Dict[str, Any], str]:
@@ -88,19 +79,15 @@
         }
 
         try:
-            response = retry_request(
-                requests.post,
-                retries=self.retries,
-                delay=self.delay,
-                url=url,
-                headers=headers,
-                json=data,
-            )
+            response = requests.post(url=url, headers=headers, json=data)
             response.raise_for_status()
             return response.json()
+        except requests.exceptions.RequestException as e:
+            raise e
         except Exception as e:
             return f"Failed to send message: {e!s}"
 
+    @retry_on_error()
     def get_message_templates(self) -> Union[List[Dict[str, Any]], str]:
         r"""Retrieves all message templates for the WhatsApp Business account.
 
@@ -116,18 +103,13 @@
         headers = {"Authorization": f"Bearer {self.access_token}"}
 
         try:
-            response = retry_request(
-                requests.get,
-                retries=self.retries,
-                delay=self.delay,
-                url=url,
-                headers=headers,
-            )
+            response = requests.get(url=url, headers=headers)
             response.raise_for_status()
             return response.json().get("data", [])
         except Exception as e:
             return f"Failed to retrieve message templates: {e!s}"
 
+    @retry_on_error()
     def get_business_profile(self) -> Union[Dict[str, Any], str]:
         r"""Retrieves the WhatsApp Business profile information.
 
@@ -149,10 +131,7 @@
         }
 
        try:
-            response = retry_request(
-                requests.get,
-                retries=self.retries,
-                delay=self.delay,
+            response = requests.get(
                 url=url,
                 headers=headers,
                 params=params,
camel/types/enums.py CHANGED
@@ -204,6 +204,10 @@ class ModelType(UnifiedModelType, Enum):
     SILICONFLOW_THUDM_GLM_4_9B_CHAT = "THUDM/glm-4-9b-chat"
     SILICONFLOW_PRO_THUDM_GLM_4_9B_CHAT = "Pro/THUDM/glm-4-9b-chat"
 
+    # AIML models support tool calling
+    AIML_MIXTRAL_8X7B = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+    AIML_MISTRAL_7B_INSTRUCT = "mistralai/Mistral-7B-Instruct-v0.1"
+
     def __str__(self):
         return self.value
 
@@ -218,7 +222,11 @@ class ModelType(UnifiedModelType, Enum):
 
     @property
     def support_native_structured_output(self) -> bool:
-        return self.is_openai
+        return any(
+            [
+                self.is_openai,
+            ]
+        )
 
     @property
     def support_native_tool_calling(self) -> bool:
@@ -238,6 +246,7 @@ class ModelType(UnifiedModelType, Enum):
                 self.is_moonshot,
                 self.is_siliconflow,
                 self.is_zhipuai,
+                self.is_aiml,
             ]
         )
 
@@ -513,6 +522,13 @@ class ModelType(UnifiedModelType, Enum):
             ModelType.SILICONFLOW_PRO_THUDM_GLM_4_9B_CHAT,
         }
 
+    @property
+    def is_aiml(self) -> bool:
+        return self in {
+            ModelType.AIML_MIXTRAL_8X7B,
+            ModelType.AIML_MISTRAL_7B_INSTRUCT,
+        }
+
     @property
     def token_limit(self) -> int:
         r"""Returns the maximum token limit for a given model.
@@ -586,6 +602,8 @@ class ModelType(UnifiedModelType, Enum):
             ModelType.TOGETHER_MIXTRAL_8_7B,
             ModelType.SGLANG_MISTRAL_7B,
             ModelType.MOONSHOT_V1_32K,
+            ModelType.AIML_MIXTRAL_8X7B,
+            ModelType.AIML_MISTRAL_7B_INSTRUCT,
         }:
             return 32_768
         elif self in {
@@ -864,6 +882,7 @@ class ModelPlatformType(Enum):
     INTERNLM = "internlm"
     MOONSHOT = "moonshot"
     SILICONFLOW = "siliconflow"
+    AIML = "aiml"
 
     @property
     def is_openai(self) -> bool:
@@ -981,6 +1000,11 @@ class ModelPlatformType(Enum):
         r"""Returns whether this platform is SiliconFlow."""
         return self is ModelPlatformType.SILICONFLOW
 
+    @property
+    def is_aiml(self) -> bool:
+        r"""Returns whether this platform is AIML."""
+        return self is ModelPlatformType.AIML
+
 
 class AudioModelType(Enum):
     TTS_1 = "tts-1"
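Taken together, these hunks register AIML as a platform and wire its two models into tool calling and the 32k token-limit bucket. A quick check of the new properties, assuming the standard `camel.types` exports; the expected values follow directly from the hunks above:

from camel.types import ModelPlatformType, ModelType

model = ModelType.AIML_MIXTRAL_8X7B
print(model.is_aiml)                      # True, per the new is_aiml property
print(model.support_native_tool_calling)  # True, is_aiml joins the any() list
print(model.token_limit)                  # 32768, both AIML models are in the 32_768 bucket
print(ModelPlatformType.AIML.is_aiml)     # True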
camel/utils/__init__.py CHANGED
@@ -14,6 +14,7 @@
 
 from .commons import (
     AgentOpsMeta,
+    BatchProcessor,
     agentops_decorator,
     api_keys_required,
     check_server_running,
@@ -33,16 +34,17 @@ from .commons import (
     is_docker_running,
     json_to_function_code,
     print_text_animated,
+    retry_on_error,
     text_extract_from_web,
     to_pascal,
     track_agent,
 )
 from .constants import Constants
+from .deduplication import DeduplicationResult, deduplicate_internally
 from .response_format import get_pydantic_model
 from .token_counting import (
     AnthropicTokenCounter,
     BaseTokenCounter,
-    GeminiTokenCounter,
     LiteLLMTokenCounter,
     MistralTokenCounter,
     OpenAITokenCounter,
@@ -69,7 +71,6 @@ __all__ = [
     "dependencies_required",
     "api_keys_required",
     "is_docker_running",
-    "GeminiTokenCounter",
     "MistralTokenCounter",
     "get_pydantic_major_version",
     "get_pydantic_object_schema",
@@ -82,4 +83,8 @@ __all__ = [
     "get_pydantic_model",
     "download_github_subdirectory",
     "generate_prompt_for_structured_output",
+    "deduplicate_internally",
+    "DeduplicationResult",
+    "retry_on_error",
+    "BatchProcessor",
 ]
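For consumers, the visible change to the `camel.utils` surface is four new exports and the removal of `GeminiTokenCounter`, so code importing that name from this module will now fail. The new names should be importable like so after upgrading:

from camel.utils import (
    BatchProcessor,
    DeduplicationResult,
    deduplicate_internally,
    retry_on_error,
)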
camel/utils/commons.py CHANGED
@@ -11,7 +11,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+import functools
 import importlib
+import logging
 import os
 import platform
 import re
@@ -47,6 +49,8 @@ from .constants import Constants
 
 F = TypeVar('F', bound=Callable[..., Any])
 
+logger = logging.getLogger(__name__)
+
 
 def print_text_animated(text, delay: float = 0.02, end: str = ""):
     r"""Prints the given text with an animated effect.
@@ -620,33 +624,206 @@ def handle_http_error(response: requests.Response) -> str:
     return "HTTP Error"
 
 
-def retry_request(
-    func: Callable, retries: int = 3, delay: int = 1, *args: Any, **kwargs: Any
-) -> Any:
-    r"""Retries a function in case of any errors.
+def retry_on_error(
+    max_retries: int = 3, initial_delay: float = 1.0
+) -> Callable:
+    r"""Decorator to retry function calls on exception with exponential
+    backoff.
 
     Args:
-        func (Callable): The function to be retried.
-        retries (int): Number of retry attempts. (default: :obj:`3`)
-        delay (int): Delay between retries in seconds. (default: :obj:`1`)
-        *args: Arguments to pass to the function.
-        **kwargs: Keyword arguments to pass to the function.
+        max_retries (int): Maximum number of retry attempts
+        initial_delay (float): Initial delay between retries in seconds
 
     Returns:
-        Any: The result of the function call if successful.
+        Callable: Decorated function with retry logic
+    """
 
-    Raises:
-        Exception: If all retry attempts fail.
+    def decorator(func: Callable) -> Callable:
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            delay = initial_delay
+            last_exception = None
+
+            for attempt in range(max_retries + 1):
+                try:
+                    return func(*args, **kwargs)
+                except Exception as e:
+                    last_exception = e
+                    if attempt == max_retries:
+                        logger.error(
+                            f"Failed after {max_retries} retries: {e!s}"
+                        )
+                        raise
+
+                    logger.warning(
+                        f"Attempt {attempt + 1} failed: {e!s}. "
+                        f"Retrying in {delay:.1f}s..."
+                    )
+                    time.sleep(delay)
+                    delay *= 2  # Exponential backoff
+
+            raise last_exception
+
+        return wrapper
+
+    return decorator
+
+
+class BatchProcessor:
+    r"""Handles batch processing with dynamic sizing and error handling based
+    on system load.
     """
-    for attempt in range(retries):
-        try:
-            return func(*args, **kwargs)
-        except Exception as e:
-            print(f"Attempt {attempt + 1}/{retries} failed: {e}")
-            if attempt < retries - 1:
-                time.sleep(delay)
-            else:
-                raise
+
+    def __init__(
+        self,
+        max_workers: Optional[int] = None,
+        initial_batch_size: Optional[int] = None,
+        monitoring_interval: float = 5.0,
+        cpu_threshold: float = 80.0,
+        memory_threshold: float = 85.0,
+    ):
+        r"""Initialize the BatchProcessor with dynamic worker allocation.
+
+        Args:
+            max_workers: Maximum number of workers. If None, will be
+                determined dynamically based on system resources.
+                (default: :obj:`None`)
+            initial_batch_size: Initial size of each batch. If `None`,
+                defaults to `10`. (default: :obj:`None`)
+            monitoring_interval: Interval in seconds between resource checks.
+                (default: :obj:`5.0`)
+            cpu_threshold: CPU usage percentage threshold for scaling down.
+                (default: :obj:`80.0`)
+            memory_threshold: Memory usage percentage threshold for scaling
+                down. (default: :obj:`85.0`)
+        """
+        import psutil
+
+        self.monitoring_interval = monitoring_interval
+        self.cpu_threshold = cpu_threshold
+        self.memory_threshold = memory_threshold
+        self.last_check_time = time.time()
+        self.psutil = psutil
+
+        # Initialize performance metrics
+        self.total_processed = 0
+        self.total_errors = 0
+        self.processing_times: List = []
+
+        if max_workers is None:
+            self.max_workers = self._calculate_optimal_workers()
+        else:
+            self.max_workers = max_workers
+
+        self.batch_size = (
+            10 if initial_batch_size is None else initial_batch_size
+        )
+        self.min_batch_size = 1
+        self.max_batch_size = 20
+        self.backoff_factor = 0.8
+        self.success_factor = 1.2
+
+        # Initial resource check
+        self._update_resource_metrics()
+
+    def _calculate_optimal_workers(self) -> int:
+        r"""Calculate optimal number of workers based on system resources."""
+        cpu_count = self.psutil.cpu_count()
+        cpu_percent = self.psutil.cpu_percent(interval=1)
+        memory = self.psutil.virtual_memory()
+
+        # Base number of workers on CPU count and current load
+        if cpu_percent > self.cpu_threshold:
+            workers = max(1, cpu_count // 4)
+        elif cpu_percent > 60:
+            workers = max(1, cpu_count // 2)
+        else:
+            workers = max(1, cpu_count - 1)
+
+        # Further reduce if memory is constrained
+        if memory.percent > self.memory_threshold:
+            workers = max(1, workers // 2)
+
+        return workers
+
+    def _update_resource_metrics(self) -> None:
+        r"""Update current resource usage metrics."""
+        self.current_cpu = self.psutil.cpu_percent()
+        self.current_memory = self.psutil.virtual_memory().percent
+        self.last_check_time = time.time()
+
+    def _should_check_resources(self) -> bool:
+        r"""Determine if it's time to check resource usage again."""
+        return time.time() - self.last_check_time >= self.monitoring_interval
+
+    def adjust_batch_size(
+        self, success: bool, processing_time: Optional[float] = None
+    ) -> None:
+        r"""Adjust batch size based on success/failure and system resources.
+
+        Args:
+            success (bool): Whether the last batch completed successfully
+            processing_time (Optional[float]): Time taken to process the last
+                batch. (default: :obj:`None`)
+        """
+        # Update metrics
+        self.total_processed += 1
+        if not success:
+            self.total_errors += 1
+        if processing_time is not None:
+            self.processing_times.append(processing_time)
+
+        # Check system resources if interval has elapsed
+        if self._should_check_resources():
+            self._update_resource_metrics()
+
+            # Adjust based on resource usage
+            if (
+                self.current_cpu > self.cpu_threshold
+                or self.current_memory > self.memory_threshold
+            ):
+                self.batch_size = max(
+                    int(self.batch_size * self.backoff_factor),
+                    self.min_batch_size,
+                )
+                self.max_workers = max(1, self.max_workers - 1)
+                return
+
+        # Adjust based on success/failure
+        if success:
+            self.batch_size = min(
+                int(self.batch_size * self.success_factor), self.max_batch_size
+            )
+        else:
+            self.batch_size = max(
+                int(self.batch_size * self.backoff_factor), self.min_batch_size
+            )
+
+    def get_performance_metrics(self) -> Dict[str, Any]:
+        r"""Get current performance metrics.
+
+        Returns:
+            Dict containing performance metrics including:
+            - total_processed: Total number of batches processed
+            - error_rate: Percentage of failed batches
+            - avg_processing_time: Average time per batch
+            - current_batch_size: Current batch size
+            - current_workers: Current number of workers
+            - current_cpu: Current CPU usage percentage
+            - current_memory: Current memory usage percentage
+        """
+        metrics = {
+            "total_processed": self.total_processed,
+            "error_rate": (self.total_errors / max(1, self.total_processed))
+            * 100,
+            "avg_processing_time": sum(self.processing_times)
+            / max(1, len(self.processing_times)),
+            "current_batch_size": self.batch_size,
+            "current_workers": self.max_workers,
+            "current_cpu": self.current_cpu,
+            "current_memory": self.current_memory,
+        }
+        return metrics
 
 
 def download_github_subdirectory(
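The package does not ship a driver for `BatchProcessor`, so the following is only one plausible way to use it, inferred from the methods above: pull batches sized by `processor.batch_size`, report success or failure via `adjust_batch_size`, and read the metrics at the end. `process` is a hypothetical work function, and the constructor imports `psutil`, so that package must be installed:

import time

from camel.utils import BatchProcessor

def process(batch):
    # Hypothetical per-batch work; stands in for real I/O or model calls.
    time.sleep(0.01)

processor = BatchProcessor(initial_batch_size=4)
items = list(range(100))
pos = 0
while pos < len(items):
    batch = items[pos : pos + processor.batch_size]  # size adapts per batch
    pos += len(batch)
    start = time.time()
    try:
        process(batch)
        processor.adjust_batch_size(True, time.time() - start)
    except Exception:
        processor.adjust_batch_size(False, time.time() - start)

print(processor.get_performance_metrics())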
camel/utils/deduplication.py ADDED
@@ -0,0 +1,232 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+
+
+from typing import Dict, List, Literal, Optional
+
+from pydantic import BaseModel
+
+from camel.embeddings.base import BaseEmbedding
+
+
+class DeduplicationResult(BaseModel):
+    r"""The result of deduplication.
+
+    Attributes:
+        original_texts (List[str]): The original texts.
+        unique_ids (List[int]): A list of ids that are unique (not duplicates).
+        unique_embeddings_dict (Dict[int, List[float]]): A mapping from the
+            index of each unique text to its embedding.
+        duplicate_to_target_map (Dict[int, int]): A mapping from the index of
+            the duplicate text to the index of the text it is considered a
+            duplicate of.
+    """
+
+    original_texts: List[str]
+    unique_ids: List[int]
+    unique_embeddings_dict: Dict[int, List[float]]
+    duplicate_to_target_map: Dict[int, int]
+
+
+def deduplicate_internally(
+    texts: List[str],
+    threshold: float = 0.65,
+    embedding_instance: Optional[BaseEmbedding[str]] = None,
+    embeddings: Optional[List[List[float]]] = None,
+    strategy: Literal["top1", "llm-supervise"] = "top1",
+    batch_size: int = 1000,
+) -> DeduplicationResult:
+    r"""Deduplicate a list of strings based on their cosine similarity.
+
+    You can either:
+    1) Provide a CAMEL `BaseEmbedding` instance via `embedding_instance` to let
+       this function handle the embedding internally, OR
+    2) Directly pass a list of pre-computed embeddings to `embeddings`.
+
+    If both `embedding_instance` and `embeddings` are provided, the function
+    will raise a ValueError to avoid ambiguous usage.
+
+    strategy is used to specify different strategies, where 'top1' selects the
+    one with highest similarity, and 'llm-supervise' uses LLM to determine if
+    texts are duplicates (not yet implemented).
+
+    Args:
+        texts (List[str]): The list of texts to be deduplicated.
+        threshold (float, optional): The similarity threshold for considering
+            two texts as duplicates. (default: :obj:`0.65`)
+        embedding_instance (Optional[BaseEmbedding[str]], optional):
+            A CAMEL embedding instance for automatic embedding. (default:
+            :obj:`None`)
+        embeddings (Optional[List[List[float]]], optional):
+            Pre-computed embeddings of `texts`. Each element in the list
+            corresponds to the embedding of the text in the same index of
+            `texts`. (default: :obj:`None`)
+        strategy (Literal["top1", "llm-supervise"], optional):
+            The strategy to use for deduplication. (default: :obj:`"top1"`)
+        batch_size (int, optional): The size of the batch to use for
+            calculating cosine similarities. (default: :obj:`1000`)
+
+    Returns:
+        DeduplicationResult: An object that contains:
+            - `original_texts`: The original texts.
+            - `unique_ids`: The unique ids after deduplication.
+            - `unique_embeddings_dict`: A dict mapping from (unique) text id
+              to its embedding.
+            - `duplicate_to_target_map`: A dict mapping from the id of a
+              duplicate text to the id of the text it is considered a duplicate
+              of.
+
+    Raises:
+        NotImplementedError: If the strategy is not "top1".
+        ValueError: If neither embeddings nor embedding_instance is provided,
+            or if both are provided at the same time.
+        ValueError: If the length of `embeddings` does not match the length of
+            `texts`.
+
+    Example:
+        >>> from camel.embeddings.openai_embedding import OpenAIEmbedding
+        >>> # Suppose we have 5 texts, some of which may be duplicates
+        >>> texts = [
+        ...     "What is AI?",
+        ...     "Artificial Intelligence is about machines",
+        ...     "What is AI?",
+        ...     "Deep Learning is a subset of AI",
+        ...     "What is artificial intelligence?"
+        ... ]
+        >>> # or any other BaseEmbedding instance
+        >>> embedding_model = OpenAIEmbedding()
+        >>> result = deduplicate_internally(
+        ...     texts=texts,
+        ...     threshold=0.7,
+        ...     embedding_instance=embedding_model
+        ... )
+        >>> print("Unique ids:")
+        >>> for uid in result.unique_ids:
+        ...     print(texts[uid])
+        Unique ids:
+        What is AI?
+        Artificial Intelligence is about machines
+        Deep Learning is a subset of AI
+        What is artificial intelligence?
+
+        >>> print("Duplicate map:")
+        >>> print(result.duplicate_to_target_map)
+        {2: 0}
+        # This indicates the text at index 2 is considered
+        # a duplicate of index 0.
+    """
+    import numpy as np
+    from sklearn.metrics.pairwise import cosine_similarity
+
+    if len(texts) == 0:
+        return DeduplicationResult(
+            original_texts=[],
+            unique_ids=[],
+            unique_embeddings_dict={},
+            duplicate_to_target_map={},
+        )
+
+    if len(texts) == 1:
+        return DeduplicationResult(
+            original_texts=texts,
+            unique_ids=[0],
+            unique_embeddings_dict={
+                0: embeddings[0]
+                if embeddings
+                else embedding_instance.embed_list(texts)[0]  # type: ignore[union-attr]
+            },
+            duplicate_to_target_map={},
+        )
+
+    if strategy == "llm-supervise":
+        # TODO: Implement LLM-supervise deduplication.
+        raise NotImplementedError(
+            "LLM-supervise deduplication is not yet implemented."
+        )
+
+    # Check if the parameters are valid.
+    if not 0 <= threshold <= 1:
+        raise ValueError("Threshold must be between 0 and 1")
+
+    if embedding_instance is None and embeddings is None:
+        raise ValueError(
+            "Either 'embedding_instance' or 'embeddings' must be provided."
+        )
+    if embedding_instance is not None and embeddings is not None:
+        raise ValueError(
+            "Cannot provide both 'embedding_instance' and 'embeddings'. "
+            "Please choose only one way to supply embeddings."
+        )
+
+    if embedding_instance is not None:
+        # Use Camel's embedding_instance to vectorize.
+        embeddings = embedding_instance.embed_list(texts)
+    else:
+        # Use pre-supplied embeddings.
+        if embeddings and len(embeddings) != len(texts):
+            raise ValueError(
+                "The length of 'embeddings' does not match the length "
+                "of 'texts'."
+            )
+
+    # Convert embeddings to numpy array for efficient computation
+    embeddings_array = np.array(embeddings)
+    n = len(texts)
+    duplicate_to_target_map: Dict[int, int] = {}
+
+    # Process in batches to reduce memory usage
+    for i in range(0, n, batch_size):
+        batch_end = min(i + batch_size, n)
+        # Calculate cosine similarity for current batch
+        batch_similarities = cosine_similarity(
+            embeddings_array[i:batch_end], embeddings_array[:batch_end]
+        )
+
+        # Create mask for lower triangle (avoid self-comparison and redundant
+        # checks)
+        tril_mask = np.tril(np.ones_like(batch_similarities), k=-1)
+        batch_similarities = batch_similarities * tril_mask
+
+        # Find duplicates in current batch
+        masked_similarities = np.where(
+            batch_similarities > threshold, batch_similarities, -1
+        )
+        max_indices = masked_similarities.argmax(axis=1)
+        above_threshold = (
+            batch_similarities[np.arange(batch_end - i), max_indices]
+            > threshold
+        )
+
+        # Update duplicate map
+        for j, is_duplicate in enumerate(above_threshold):
+            if is_duplicate:
+                duplicate_to_target_map[i + j] = max_indices[j]
+
+    # Get the actual unique ids and embeddings.
+    unique_ids = []
+    unique_embeddings_dict = {}
+
+    assert embeddings, "embeddings must be valid"
+
+    for i, (_, emb) in enumerate(zip(texts, embeddings)):
+        if i not in duplicate_to_target_map:
+            unique_ids.append(i)
+            unique_embeddings_dict[i] = emb
+
+    return DeduplicationResult(
+        original_texts=texts,
+        unique_ids=unique_ids,
+        unique_embeddings_dict=unique_embeddings_dict,
+        duplicate_to_target_map=duplicate_to_target_map,
+    )
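The docstring above demonstrates the `embedding_instance` path; the `embeddings` path can be exercised without any API key. A small sketch with toy vectors, chosen so the first two texts are exact duplicates under cosine similarity:

from camel.utils import deduplicate_internally

texts = ["hello world", "hello world!", "goodbye"]
# Toy 2-d vectors standing in for real embeddings: the first two are
# parallel, so their cosine similarity is 1.0 and index 1 folds into 0.
embeddings = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]

result = deduplicate_internally(texts, threshold=0.9, embeddings=embeddings)
print(result.unique_ids)               # [0, 2]
print(result.duplicate_to_target_map)  # {1: 0}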