mem0ai-azure-mysql 0.1.115.2__py3-none-any.whl → 0.1.116.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. mem0/client/main.py +20 -17
  2. mem0/configs/llms/anthropic.py +56 -0
  3. mem0/configs/llms/aws_bedrock.py +191 -0
  4. mem0/configs/llms/azure.py +57 -0
  5. mem0/configs/llms/base.py +29 -119
  6. mem0/configs/llms/deepseek.py +56 -0
  7. mem0/configs/llms/lmstudio.py +59 -0
  8. mem0/configs/llms/ollama.py +56 -0
  9. mem0/configs/llms/openai.py +76 -0
  10. mem0/configs/llms/vllm.py +56 -0
  11. mem0/configs/vector_stores/databricks.py +63 -0
  12. mem0/configs/vector_stores/elasticsearch.py +18 -0
  13. mem0/configs/vector_stores/milvus.py +1 -0
  14. mem0/configs/vector_stores/pgvector.py +15 -2
  15. mem0/configs/vector_stores/pinecone.py +1 -0
  16. mem0/embeddings/azure_openai.py +0 -3
  17. mem0/embeddings/ollama.py +1 -1
  18. mem0/graphs/configs.py +26 -2
  19. mem0/graphs/neptune/main.py +1 -0
  20. mem0/graphs/tools.py +6 -6
  21. mem0/llms/anthropic.py +33 -10
  22. mem0/llms/aws_bedrock.py +484 -154
  23. mem0/llms/azure_openai.py +30 -19
  24. mem0/llms/azure_openai_structured.py +19 -4
  25. mem0/llms/base.py +105 -6
  26. mem0/llms/deepseek.py +31 -9
  27. mem0/llms/lmstudio.py +75 -14
  28. mem0/llms/ollama.py +44 -32
  29. mem0/llms/openai.py +39 -22
  30. mem0/llms/vllm.py +32 -14
  31. mem0/memory/base.py +2 -2
  32. mem0/memory/graph_memory.py +166 -54
  33. mem0/memory/kuzu_memory.py +710 -0
  34. mem0/memory/main.py +59 -37
  35. mem0/memory/memgraph_memory.py +43 -35
  36. mem0/memory/utils.py +51 -0
  37. mem0/proxy/main.py +5 -10
  38. mem0/utils/factory.py +132 -25
  39. mem0/vector_stores/azure_ai_search.py +0 -3
  40. mem0/vector_stores/chroma.py +27 -2
  41. mem0/vector_stores/configs.py +1 -0
  42. mem0/vector_stores/databricks.py +759 -0
  43. mem0/vector_stores/elasticsearch.py +2 -0
  44. mem0/vector_stores/langchain.py +3 -2
  45. mem0/vector_stores/milvus.py +3 -1
  46. mem0/vector_stores/mongodb.py +20 -1
  47. mem0/vector_stores/pgvector.py +83 -9
  48. mem0/vector_stores/pinecone.py +17 -8
  49. mem0/vector_stores/qdrant.py +30 -0
  50. mem0ai_azure_mysql-0.1.116.2.data/data/README.md +24 -0
  51. mem0ai_azure_mysql-0.1.116.2.dist-info/METADATA +88 -0
  52. {mem0ai_azure_mysql-0.1.115.2.dist-info → mem0ai_azure_mysql-0.1.116.2.dist-info}/RECORD +53 -43
  53. mem0ai_azure_mysql-0.1.115.2.data/data/README.md +0 -169
  54. mem0ai_azure_mysql-0.1.115.2.dist-info/METADATA +0 -224
  55. mem0ai_azure_mysql-0.1.115.2.dist-info/licenses/LICENSE +0 -201
  56. {mem0ai_azure_mysql-0.1.115.2.dist-info → mem0ai_azure_mysql-0.1.116.2.dist-info}/WHEEL +0 -0
mem0/configs/llms/ollama.py ADDED
@@ -0,0 +1,56 @@
+ from typing import Optional
+
+ from mem0.configs.llms.base import BaseLlmConfig
+
+
+ class OllamaConfig(BaseLlmConfig):
+     """
+     Configuration class for Ollama-specific parameters.
+     Inherits from BaseLlmConfig and adds Ollama-specific settings.
+     """
+
+     def __init__(
+         self,
+         # Base parameters
+         model: Optional[str] = None,
+         temperature: float = 0.1,
+         api_key: Optional[str] = None,
+         max_tokens: int = 2000,
+         top_p: float = 0.1,
+         top_k: int = 1,
+         enable_vision: bool = False,
+         vision_details: Optional[str] = "auto",
+         http_client_proxies: Optional[dict] = None,
+         # Ollama-specific parameters
+         ollama_base_url: Optional[str] = None,
+     ):
+         """
+         Initialize Ollama configuration.
+
+         Args:
+             model: Ollama model to use, defaults to None
+             temperature: Controls randomness, defaults to 0.1
+             api_key: Ollama API key, defaults to None
+             max_tokens: Maximum tokens to generate, defaults to 2000
+             top_p: Nucleus sampling parameter, defaults to 0.1
+             top_k: Top-k sampling parameter, defaults to 1
+             enable_vision: Enable vision capabilities, defaults to False
+             vision_details: Vision detail level, defaults to "auto"
+             http_client_proxies: HTTP client proxy settings, defaults to None
+             ollama_base_url: Ollama base URL, defaults to None
+         """
+         # Initialize base parameters
+         super().__init__(
+             model=model,
+             temperature=temperature,
+             api_key=api_key,
+             max_tokens=max_tokens,
+             top_p=top_p,
+             top_k=top_k,
+             enable_vision=enable_vision,
+             vision_details=vision_details,
+             http_client_proxies=http_client_proxies,
+         )
+
+         # Ollama-specific parameters
+         self.ollama_base_url = ollama_base_url
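For orientation, a minimal usage sketch of the new per-provider config class; the model name and base URL are illustrative values, not package defaults:

```python
from mem0.configs.llms.ollama import OllamaConfig

config = OllamaConfig(
    model="llama3.1",                          # illustrative model name
    temperature=0.1,
    ollama_base_url="http://localhost:11434",  # common Ollama default port
)
print(config.ollama_base_url)
```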
mem0/configs/llms/openai.py ADDED
@@ -0,0 +1,76 @@
+ from typing import Any, Callable, List, Optional
+
+ from mem0.configs.llms.base import BaseLlmConfig
+
+
+ class OpenAIConfig(BaseLlmConfig):
+     """
+     Configuration class for OpenAI and OpenRouter-specific parameters.
+     Inherits from BaseLlmConfig and adds OpenAI-specific settings.
+     """
+
+     def __init__(
+         self,
+         # Base parameters
+         model: Optional[str] = None,
+         temperature: float = 0.1,
+         api_key: Optional[str] = None,
+         max_tokens: int = 2000,
+         top_p: float = 0.1,
+         top_k: int = 1,
+         enable_vision: bool = False,
+         vision_details: Optional[str] = "auto",
+         http_client_proxies: Optional[dict] = None,
+         # OpenAI-specific parameters
+         openai_base_url: Optional[str] = None,
+         models: Optional[List[str]] = None,
+         route: Optional[str] = "fallback",
+         openrouter_base_url: Optional[str] = None,
+         site_url: Optional[str] = None,
+         app_name: Optional[str] = None,
+         # Response monitoring callback
+         response_callback: Optional[Callable[[Any, dict, dict], None]] = None,
+     ):
+         """
+         Initialize OpenAI configuration.
+
+         Args:
+             model: OpenAI model to use, defaults to None
+             temperature: Controls randomness, defaults to 0.1
+             api_key: OpenAI API key, defaults to None
+             max_tokens: Maximum tokens to generate, defaults to 2000
+             top_p: Nucleus sampling parameter, defaults to 0.1
+             top_k: Top-k sampling parameter, defaults to 1
+             enable_vision: Enable vision capabilities, defaults to False
+             vision_details: Vision detail level, defaults to "auto"
+             http_client_proxies: HTTP client proxy settings, defaults to None
+             openai_base_url: OpenAI API base URL, defaults to None
+             models: List of models for OpenRouter, defaults to None
+             route: OpenRouter route strategy, defaults to "fallback"
+             openrouter_base_url: OpenRouter base URL, defaults to None
+             site_url: Site URL for OpenRouter, defaults to None
+             app_name: Application name for OpenRouter, defaults to None
+             response_callback: Optional callback for monitoring LLM responses.
+         """
+         # Initialize base parameters
+         super().__init__(
+             model=model,
+             temperature=temperature,
+             api_key=api_key,
+             max_tokens=max_tokens,
+             top_p=top_p,
+             top_k=top_k,
+             enable_vision=enable_vision,
+             vision_details=vision_details,
+             http_client_proxies=http_client_proxies,
+         )
+
+         # OpenAI-specific parameters
+         self.openai_base_url = openai_base_url
+         self.models = models
+         self.route = route
+         self.openrouter_base_url = openrouter_base_url
+         self.site_url = site_url
+         self.app_name = app_name
+         # Response monitoring
+         self.response_callback = response_callback
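A short sketch of the OpenRouter-oriented fields and the new response_callback hook; the diff only shows the callback's type, so the argument semantics in the example are assumptions:

```python
from mem0.configs.llms.openai import OpenAIConfig

def on_response(response, params, extra):
    # Matches Callable[[Any, dict, dict], None]; argument meanings assumed.
    print("LLM call completed for model:", params.get("model"))

config = OpenAIConfig(
    model="gpt-4o-mini",       # illustrative
    route="fallback",          # OpenRouter route strategy
    response_callback=on_response,
)
```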
mem0/configs/llms/vllm.py ADDED
@@ -0,0 +1,56 @@
+ from typing import Optional
+
+ from mem0.configs.llms.base import BaseLlmConfig
+
+
+ class VllmConfig(BaseLlmConfig):
+     """
+     Configuration class for vLLM-specific parameters.
+     Inherits from BaseLlmConfig and adds vLLM-specific settings.
+     """
+
+     def __init__(
+         self,
+         # Base parameters
+         model: Optional[str] = None,
+         temperature: float = 0.1,
+         api_key: Optional[str] = None,
+         max_tokens: int = 2000,
+         top_p: float = 0.1,
+         top_k: int = 1,
+         enable_vision: bool = False,
+         vision_details: Optional[str] = "auto",
+         http_client_proxies: Optional[dict] = None,
+         # vLLM-specific parameters
+         vllm_base_url: Optional[str] = None,
+     ):
+         """
+         Initialize vLLM configuration.
+
+         Args:
+             model: vLLM model to use, defaults to None
+             temperature: Controls randomness, defaults to 0.1
+             api_key: vLLM API key, defaults to None
+             max_tokens: Maximum tokens to generate, defaults to 2000
+             top_p: Nucleus sampling parameter, defaults to 0.1
+             top_k: Top-k sampling parameter, defaults to 1
+             enable_vision: Enable vision capabilities, defaults to False
+             vision_details: Vision detail level, defaults to "auto"
+             http_client_proxies: HTTP client proxy settings, defaults to None
+             vllm_base_url: vLLM base URL, defaults to None
+         """
+         # Initialize base parameters
+         super().__init__(
+             model=model,
+             temperature=temperature,
+             api_key=api_key,
+             max_tokens=max_tokens,
+             top_p=top_p,
+             top_k=top_k,
+             enable_vision=enable_vision,
+             vision_details=vision_details,
+             http_client_proxies=http_client_proxies,
+         )
+
+         # vLLM-specific parameters
+         self.vllm_base_url = vllm_base_url or "http://localhost:8000/v1"
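Unlike the other new configs, vllm_base_url falls back to a hardcoded default when left unset; a minimal sketch:

```python
from mem0.configs.llms.vllm import VllmConfig

cfg = VllmConfig(model="qwen2.5-7b-instruct")            # illustrative model name
assert cfg.vllm_base_url == "http://localhost:8000/v1"   # default applied in __init__
```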
mem0/configs/vector_stores/databricks.py ADDED
@@ -0,0 +1,63 @@
+ from typing import Any, Dict, Optional
+
+ from pydantic import BaseModel, Field, model_validator
+
+ from databricks.sdk.service.vectorsearch import EndpointType, VectorIndexType, PipelineType
+
+
+ class DatabricksConfig(BaseModel):
+     """Configuration for Databricks Vector Search vector store."""
+
+     workspace_url: str = Field(..., description="Databricks workspace URL")
+     access_token: Optional[str] = Field(None, description="Personal access token for authentication")
+     client_id: Optional[str] = Field(None, description="Databricks service principal client ID")
+     client_secret: Optional[str] = Field(None, description="Databricks service principal client secret")
+     azure_client_id: Optional[str] = Field(None, description="Azure AD application client ID (for Azure Databricks)")
+     azure_client_secret: Optional[str] = Field(
+         None, description="Azure AD application client secret (for Azure Databricks)"
+     )
+     endpoint_name: str = Field(..., description="Vector search endpoint name")
+     catalog: str = Field(..., description="The Unity Catalog catalog name")
+     schema: str = Field(..., description="The Unity Catalog schema name")
+     table_name: str = Field(..., description="Source Delta table name")
+     collection_name: str = Field("mem0", description="Vector search index name")
+     index_type: VectorIndexType = Field("DELTA_SYNC", description="Index type: DELTA_SYNC or DIRECT_ACCESS")
+     embedding_model_endpoint_name: Optional[str] = Field(
+         None, description="Embedding model endpoint for Databricks-computed embeddings"
+     )
+     embedding_dimension: int = Field(1536, description="Vector embedding dimensions")
+     endpoint_type: EndpointType = Field("STANDARD", description="Endpoint type: STANDARD or STORAGE_OPTIMIZED")
+     pipeline_type: PipelineType = Field("TRIGGERED", description="Sync pipeline type: TRIGGERED or CONTINUOUS")
+     warehouse_name: Optional[str] = Field(None, description="Databricks SQL warehouse name")
+     query_type: str = Field("ANN", description="Query type: `ANN` or `HYBRID`")
+
+     @model_validator(mode="before")
+     @classmethod
+     def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+         allowed_fields = set(cls.model_fields.keys())
+         input_fields = set(values.keys())
+         extra_fields = input_fields - allowed_fields
+         if extra_fields:
+             raise ValueError(
+                 f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+             )
+         return values
+
+     @model_validator(mode="after")
+     def validate_authentication(self):
+         """Validate that either access_token or service principal credentials are provided."""
+         has_token = self.access_token is not None
+         has_service_principal = (self.client_id is not None and self.client_secret is not None) or (
+             self.azure_client_id is not None and self.azure_client_secret is not None
+         )
+
+         if not has_token and not has_service_principal:
+             raise ValueError(
+                 "Either access_token or both client_id/client_secret or azure_client_id/azure_client_secret must be provided"
+             )
+
+         return self
+
+     model_config = {
+         "arbitrary_types_allowed": True,
+     }
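Because validate_authentication runs after field validation, a config with no credentials fails at construction time; a sketch with illustrative identifiers:

```python
from mem0.configs.vector_stores.databricks import DatabricksConfig

try:
    DatabricksConfig(
        workspace_url="https://adb-1234567890.azuredatabricks.net",  # illustrative
        endpoint_name="mem0-endpoint",
        catalog="main",
        schema="default",
        table_name="memories",
        # no access_token and no service-principal pair
    )
except ValueError as exc:
    print(exc)  # "Either access_token or both client_id/client_secret ..."
```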
mem0/configs/vector_stores/elasticsearch.py CHANGED
@@ -19,6 +19,7 @@ class ElasticsearchConfig(BaseModel):
      custom_search_query: Optional[Callable[[List[float], int, Optional[Dict]], Dict]] = Field(
          None, description="Custom search query function. Parameters: (query, limit, filters) -> Dict"
      )
+     headers: Optional[Dict[str, str]] = Field(None, description="Custom headers to include in requests")

      @model_validator(mode="before")
      @classmethod
@@ -33,6 +34,23 @@ class ElasticsearchConfig(BaseModel):

          return values

+     @model_validator(mode="before")
+     @classmethod
+     def validate_headers(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+         """Validate headers format and content"""
+         headers = values.get("headers")
+         if headers is not None:
+             # Check if headers is a dictionary
+             if not isinstance(headers, dict):
+                 raise ValueError("headers must be a dictionary")
+
+             # Check if all keys and values are strings
+             for key, value in headers.items():
+                 if not isinstance(key, str) or not isinstance(value, str):
+                     raise ValueError("All header keys and values must be strings")
+
+         return values
+
      @model_validator(mode="before")
      @classmethod
      def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
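A hypothetical config fragment using the new headers option; the connection fields come from the unchanged part of ElasticsearchConfig and are illustrative here:

```python
vector_store = {
    "provider": "elasticsearch",
    "config": {
        "host": "localhost",   # illustrative connection details
        "port": 9200,
        # validate_headers enforces str -> str; e.g. {"X-Retry": 3} would raise
        "headers": {"X-Request-Source": "mem0"},
    },
}
```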
mem0/configs/vector_stores/milvus.py CHANGED
@@ -25,6 +25,7 @@ class MilvusDBConfig(BaseModel):
      collection_name: str = Field("mem0", description="Name of the collection")
      embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
      metric_type: str = Field("L2", description="Metric type for similarity search")
+     db_name: str = Field("", description="Name of the database")

      @model_validator(mode="before")
      @classmethod
mem0/configs/vector_stores/pgvector.py CHANGED
@@ -13,15 +13,28 @@ class PGVectorConfig(BaseModel):
      port: Optional[int] = Field(None, description="Database port. Default is 1536")
      diskann: Optional[bool] = Field(True, description="Use diskann for approximate nearest neighbors search")
      hnsw: Optional[bool] = Field(False, description="Use hnsw for faster search")
+     # New SSL and connection options
+     sslmode: Optional[str] = Field(None, description="SSL mode for PostgreSQL connection (e.g., 'require', 'prefer', 'disable')")
+     connection_string: Optional[str] = Field(None, description="PostgreSQL connection string (overrides individual connection parameters)")
+     connection_pool: Optional[Any] = Field(None, description="psycopg2 connection pool object (overrides connection string and individual parameters)")

      @model_validator(mode="before")
      def check_auth_and_connection(cls, values):
+         # If connection_pool is provided, skip validation of individual connection parameters
+         if values.get("connection_pool") is not None:
+             return values
+
+         # If connection_string is provided, skip validation of individual connection parameters
+         if values.get("connection_string") is not None:
+             return values
+
+         # Otherwise, validate individual connection parameters
          user, password = values.get("user"), values.get("password")
          host, port = values.get("host"), values.get("port")
          if not user and not password:
-             raise ValueError("Both 'user' and 'password' must be provided.")
+             raise ValueError("Both 'user' and 'password' must be provided when not using connection_string or connection_pool.")
          if not host and not port:
-             raise ValueError("Both 'host' and 'port' must be provided.")
+             raise ValueError("Both 'host' and 'port' must be provided when not using connection_string or connection_pool.")
          return values

      @model_validator(mode="before")
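With the new precedence, a connection pool or a DSN short-circuits the user/password and host/port checks; a hypothetical config fragment (the DSN is illustrative):

```python
vector_store = {
    "provider": "pgvector",
    "config": {
        # A connection string alone now passes check_auth_and_connection.
        "connection_string": "postgresql://mem0:secret@db.example.com:5432/mem0",
        "sslmode": "require",
    },
}
```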
mem0/configs/vector_stores/pinecone.py CHANGED
@@ -18,6 +18,7 @@ class PineconeConfig(BaseModel):
      metric: str = Field("cosine", description="Distance metric for vector similarity")
      batch_size: int = Field(100, description="Batch size for operations")
      extra_params: Optional[Dict[str, Any]] = Field(None, description="Additional parameters for Pinecone client")
+     namespace: Optional[str] = Field(None, description="Namespace for the collection")

      @model_validator(mode="before")
      @classmethod
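A hypothetical config fragment for the new namespace field; the remaining fields come from the unchanged part of PineconeConfig and are illustrative:

```python
vector_store = {
    "provider": "pinecone",
    "config": {
        "collection_name": "mem0",
        "namespace": "user-123",  # new: scopes vectors, e.g. per tenant
    },
}
```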
mem0/embeddings/azure_openai.py CHANGED
@@ -24,13 +24,10 @@ class AzureOpenAIEmbedding(EmbeddingBase):

          auth_kwargs = {}
          if api_key:
-             print("Using API key for Azure OpenAI Embedding authentication.")
              auth_kwargs["api_key"] = api_key
          elif azure_ad_token:
-             print("Using Azure AD token for Azure OpenAI Embedding authentication.")
              auth_kwargs["azure_ad_token"] = azure_ad_token
          else:
-             print("Using DefaultAzureCredential for Azure OpenAI Embedding authentication.")
              credential = DefaultAzureCredential()
              token_provider = get_bearer_token_provider(credential, "https://cognitiveservices.azure.com/.default")
              auth_kwargs["azure_ad_token_provider"] = token_provider
mem0/embeddings/ollama.py CHANGED
@@ -36,7 +36,7 @@ class OllamaEmbedding(EmbeddingBase):
          Ensure the specified model exists locally. If not, pull it from Ollama.
          """
          local_models = self.client.list()["models"]
-         if not any(model.get("name") == self.config.model for model in local_models):
+         if not any(model.get("name") == self.config.model or model.get("model") == self.config.model for model in local_models):
              self.client.pull(self.config.model)

      def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
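The broadened predicate covers both shapes the Ollama client may return for a model entry; a small sketch (model names illustrative):

```python
# Both entry shapes now satisfy the existence check:
local_models = [
    {"name": "nomic-embed-text"},          # older response shape
    {"model": "nomic-embed-text"},         # newer response shape
]
wanted = "nomic-embed-text"
present = any(m.get("name") == wanted or m.get("model") == wanted for m in local_models)
assert present
```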
mem0/graphs/configs.py CHANGED
@@ -5,12 +5,24 @@ from pydantic import BaseModel, Field, field_validator, model_validator
  from mem0.llms.configs import LlmConfig


+ class RerankConfig(BaseModel):
+     provider: str = Field(
+         description="Provider of the rerank model (e.g., 'openai', 'azure', 'cohere')",
+         default="cohere",
+     )
+     config: Optional[dict] = Field(
+         description="Configuration for the specific rerank model", default={}
+     )
+
  class Neo4jConfig(BaseModel):
      url: Optional[str] = Field(None, description="Host address for the graph database")
      username: Optional[str] = Field(None, description="Username for the graph database")
      password: Optional[str] = Field(None, description="Password for the graph database")
      database: Optional[str] = Field(None, description="Database for the graph database")
      base_label: Optional[bool] = Field(None, description="Whether to use base node label __Entity__ for all entities")
+     similarity_threshold: float = Field(0.7, description="Threshold for the similarity of nodes")
+     top_k: int = Field(5, description="Number of top scored results to return")
+     rerank: Optional[RerankConfig] = Field(None, description="Rerank configuration")

      @model_validator(mode="before")
      def check_host_port_or_path(cls, values):
@@ -21,6 +33,12 @@ class Neo4jConfig(BaseModel):
          )
          if not url or not username or not password:
              raise ValueError("Please provide 'url', 'username' and 'password'.")
+
+         if values.get("rerank") is not None:
+             values["rerank"] = RerankConfig(**values.get("rerank"))
+             if values["rerank"].provider not in ("cohere",):
+                 raise ValueError("Invalid rerank provider. Supported providers are: cohere")
+
          return values


@@ -70,12 +88,16 @@ class NeptuneConfig(BaseModel):
      )


+ class KuzuConfig(BaseModel):
+     db: Optional[str] = Field(":memory:", description="Path to a Kuzu database file")
+
+
  class GraphStoreConfig(BaseModel):
      provider: str = Field(
-         description="Provider of the data store (e.g., 'neo4j', 'memgraph', 'neptune')",
+         description="Provider of the data store (e.g., 'neo4j', 'memgraph', 'neptune', 'kuzu')",
          default="neo4j",
      )
-     config: Union[Neo4jConfig, MemgraphConfig, NeptuneConfig] = Field(
+     config: Union[Neo4jConfig, MemgraphConfig, NeptuneConfig, KuzuConfig] = Field(
          description="Configuration for the specific data store", default=None
      )
      llm: Optional[LlmConfig] = Field(description="LLM configuration for querying the graph store", default=None)
@@ -92,5 +114,7 @@ class GraphStoreConfig(BaseModel):
          return MemgraphConfig(**v.model_dump())
      elif provider == "neptune":
          return NeptuneConfig(**v.model_dump())
+     elif provider == "kuzu":
+         return KuzuConfig(**v.model_dump())
      else:
          raise ValueError(f"Unsupported graph store provider: {provider}")
mem0/graphs/neptune/main.py CHANGED
@@ -1,4 +1,5 @@
  import logging
+
  from .base import NeptuneBase

  try:
mem0/graphs/tools.py CHANGED
@@ -249,23 +249,23 @@ RELATIONS_STRUCT_TOOL = {
          "items": {
              "type": "object",
              "properties": {
-                 "source_entity": {
+                 "source": {
                      "type": "string",
                      "description": "The source entity of the relationship.",
                  },
-                 "relatationship": {
+                 "relationship": {
                      "type": "string",
                      "description": "The relationship between the source and destination entities.",
                  },
-                 "destination_entity": {
+                 "destination": {
                      "type": "string",
                      "description": "The destination entity of the relationship.",
                  },
              },
              "required": [
-                 "source_entity",
-                 "relatationship",
-                 "destination_entity",
+                 "source",
+                 "relationship",
+                 "destination",
              ],
              "additionalProperties": False,
          },
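After the rename, which also fixes the "relatationship" misspelling, a payload conforming to the struct-tool schema looks like this (entity values illustrative):

```python
relation = {
    "source": "alice",
    "relationship": "works_at",
    "destination": "acme_corp",
}
```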
mem0/llms/anthropic.py CHANGED
@@ -1,17 +1,37 @@
  import os
- from typing import Dict, List, Optional
+ from typing import Dict, List, Optional, Union

  try:
      import anthropic
  except ImportError:
      raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.")

+ from mem0.configs.llms.anthropic import AnthropicConfig
  from mem0.configs.llms.base import BaseLlmConfig
  from mem0.llms.base import LLMBase


  class AnthropicLLM(LLMBase):
-     def __init__(self, config: Optional[BaseLlmConfig] = None):
+     def __init__(self, config: Optional[Union[BaseLlmConfig, AnthropicConfig, Dict]] = None):
+         # Convert to AnthropicConfig if needed
+         if config is None:
+             config = AnthropicConfig()
+         elif isinstance(config, dict):
+             config = AnthropicConfig(**config)
+         elif isinstance(config, BaseLlmConfig) and not isinstance(config, AnthropicConfig):
+             # Convert BaseLlmConfig to AnthropicConfig
+             config = AnthropicConfig(
+                 model=config.model,
+                 temperature=config.temperature,
+                 api_key=config.api_key,
+                 max_tokens=config.max_tokens,
+                 top_p=config.top_p,
+                 top_k=config.top_k,
+                 enable_vision=config.enable_vision,
+                 vision_details=config.vision_details,
+                 http_client_proxies=config.http_client,
+             )
+
          super().__init__(config)

          if not self.config.model:
@@ -26,6 +46,7 @@ class AnthropicLLM(LLMBase):
          response_format=None,
          tools: Optional[List[Dict]] = None,
          tool_choice: str = "auto",
+         **kwargs,
      ):
          """
          Generate a response based on the given messages using Anthropic.
@@ -35,6 +56,7 @@ class AnthropicLLM(LLMBase):
          response_format (str or object, optional): Format of the response. Defaults to "text".
          tools (list, optional): List of tools that the model can call. Defaults to None.
          tool_choice (str, optional): Tool choice method. Defaults to "auto".
+         **kwargs: Additional Anthropic-specific parameters.

          Returns:
              str: The generated response.
@@ -48,14 +70,15 @@ class AnthropicLLM(LLMBase):
          else:
              filtered_messages.append(message)

-         params = {
-             "model": self.config.model,
-             "messages": filtered_messages,
-             "system": system_message,
-             "temperature": self.config.temperature,
-             "max_tokens": self.config.max_tokens,
-             "top_p": self.config.top_p,
-         }
+         params = self._get_supported_params(messages=messages, **kwargs)
+         params.update(
+             {
+                 "model": self.config.model,
+                 "messages": filtered_messages,
+                 "system": system_message,
+             }
+         )
+
          if tools: # TODO: Remove tools if no issues found with new memory addition logic
              params["tools"] = tools
              params["tool_choice"] = tool_choice