powermem 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123) hide show
  1. powermem/__init__.py +103 -0
  2. powermem/agent/__init__.py +35 -0
  3. powermem/agent/abstract/__init__.py +22 -0
  4. powermem/agent/abstract/collaboration.py +259 -0
  5. powermem/agent/abstract/context.py +187 -0
  6. powermem/agent/abstract/manager.py +232 -0
  7. powermem/agent/abstract/permission.py +217 -0
  8. powermem/agent/abstract/privacy.py +267 -0
  9. powermem/agent/abstract/scope.py +199 -0
  10. powermem/agent/agent.py +791 -0
  11. powermem/agent/components/__init__.py +18 -0
  12. powermem/agent/components/collaboration_coordinator.py +645 -0
  13. powermem/agent/components/permission_controller.py +586 -0
  14. powermem/agent/components/privacy_protector.py +767 -0
  15. powermem/agent/components/scope_controller.py +685 -0
  16. powermem/agent/factories/__init__.py +16 -0
  17. powermem/agent/factories/agent_factory.py +266 -0
  18. powermem/agent/factories/config_factory.py +308 -0
  19. powermem/agent/factories/memory_factory.py +229 -0
  20. powermem/agent/implementations/__init__.py +16 -0
  21. powermem/agent/implementations/hybrid.py +728 -0
  22. powermem/agent/implementations/multi_agent.py +1040 -0
  23. powermem/agent/implementations/multi_user.py +1020 -0
  24. powermem/agent/types.py +53 -0
  25. powermem/agent/wrappers/__init__.py +14 -0
  26. powermem/agent/wrappers/agent_memory_wrapper.py +427 -0
  27. powermem/agent/wrappers/compatibility_wrapper.py +520 -0
  28. powermem/config_loader.py +318 -0
  29. powermem/configs.py +249 -0
  30. powermem/core/__init__.py +19 -0
  31. powermem/core/async_memory.py +1493 -0
  32. powermem/core/audit.py +258 -0
  33. powermem/core/base.py +165 -0
  34. powermem/core/memory.py +1567 -0
  35. powermem/core/setup.py +162 -0
  36. powermem/core/telemetry.py +215 -0
  37. powermem/integrations/__init__.py +17 -0
  38. powermem/integrations/embeddings/__init__.py +13 -0
  39. powermem/integrations/embeddings/aws_bedrock.py +100 -0
  40. powermem/integrations/embeddings/azure_openai.py +55 -0
  41. powermem/integrations/embeddings/base.py +31 -0
  42. powermem/integrations/embeddings/config/base.py +132 -0
  43. powermem/integrations/embeddings/configs.py +31 -0
  44. powermem/integrations/embeddings/factory.py +48 -0
  45. powermem/integrations/embeddings/gemini.py +39 -0
  46. powermem/integrations/embeddings/huggingface.py +41 -0
  47. powermem/integrations/embeddings/langchain.py +35 -0
  48. powermem/integrations/embeddings/lmstudio.py +29 -0
  49. powermem/integrations/embeddings/mock.py +11 -0
  50. powermem/integrations/embeddings/ollama.py +53 -0
  51. powermem/integrations/embeddings/openai.py +49 -0
  52. powermem/integrations/embeddings/qwen.py +102 -0
  53. powermem/integrations/embeddings/together.py +31 -0
  54. powermem/integrations/embeddings/vertexai.py +54 -0
  55. powermem/integrations/llm/__init__.py +18 -0
  56. powermem/integrations/llm/anthropic.py +87 -0
  57. powermem/integrations/llm/base.py +132 -0
  58. powermem/integrations/llm/config/anthropic.py +56 -0
  59. powermem/integrations/llm/config/azure.py +56 -0
  60. powermem/integrations/llm/config/base.py +62 -0
  61. powermem/integrations/llm/config/deepseek.py +56 -0
  62. powermem/integrations/llm/config/ollama.py +56 -0
  63. powermem/integrations/llm/config/openai.py +79 -0
  64. powermem/integrations/llm/config/qwen.py +68 -0
  65. powermem/integrations/llm/config/qwen_asr.py +46 -0
  66. powermem/integrations/llm/config/vllm.py +56 -0
  67. powermem/integrations/llm/configs.py +26 -0
  68. powermem/integrations/llm/deepseek.py +106 -0
  69. powermem/integrations/llm/factory.py +118 -0
  70. powermem/integrations/llm/gemini.py +201 -0
  71. powermem/integrations/llm/langchain.py +65 -0
  72. powermem/integrations/llm/ollama.py +106 -0
  73. powermem/integrations/llm/openai.py +166 -0
  74. powermem/integrations/llm/openai_structured.py +80 -0
  75. powermem/integrations/llm/qwen.py +207 -0
  76. powermem/integrations/llm/qwen_asr.py +171 -0
  77. powermem/integrations/llm/vllm.py +106 -0
  78. powermem/integrations/rerank/__init__.py +20 -0
  79. powermem/integrations/rerank/base.py +43 -0
  80. powermem/integrations/rerank/config/__init__.py +7 -0
  81. powermem/integrations/rerank/config/base.py +27 -0
  82. powermem/integrations/rerank/configs.py +23 -0
  83. powermem/integrations/rerank/factory.py +68 -0
  84. powermem/integrations/rerank/qwen.py +159 -0
  85. powermem/intelligence/__init__.py +17 -0
  86. powermem/intelligence/ebbinghaus_algorithm.py +354 -0
  87. powermem/intelligence/importance_evaluator.py +361 -0
  88. powermem/intelligence/intelligent_memory_manager.py +284 -0
  89. powermem/intelligence/manager.py +148 -0
  90. powermem/intelligence/plugin.py +229 -0
  91. powermem/prompts/__init__.py +29 -0
  92. powermem/prompts/graph/graph_prompts.py +217 -0
  93. powermem/prompts/graph/graph_tools_prompts.py +469 -0
  94. powermem/prompts/importance_evaluation.py +246 -0
  95. powermem/prompts/intelligent_memory_prompts.py +163 -0
  96. powermem/prompts/templates.py +193 -0
  97. powermem/storage/__init__.py +14 -0
  98. powermem/storage/adapter.py +896 -0
  99. powermem/storage/base.py +109 -0
  100. powermem/storage/config/base.py +13 -0
  101. powermem/storage/config/oceanbase.py +58 -0
  102. powermem/storage/config/pgvector.py +52 -0
  103. powermem/storage/config/sqlite.py +27 -0
  104. powermem/storage/configs.py +159 -0
  105. powermem/storage/factory.py +59 -0
  106. powermem/storage/migration_manager.py +438 -0
  107. powermem/storage/oceanbase/__init__.py +8 -0
  108. powermem/storage/oceanbase/constants.py +162 -0
  109. powermem/storage/oceanbase/oceanbase.py +1384 -0
  110. powermem/storage/oceanbase/oceanbase_graph.py +1441 -0
  111. powermem/storage/pgvector/__init__.py +7 -0
  112. powermem/storage/pgvector/pgvector.py +420 -0
  113. powermem/storage/sqlite/__init__.py +0 -0
  114. powermem/storage/sqlite/sqlite.py +218 -0
  115. powermem/storage/sqlite/sqlite_vector_store.py +311 -0
  116. powermem/utils/__init__.py +35 -0
  117. powermem/utils/utils.py +605 -0
  118. powermem/version.py +23 -0
  119. powermem-0.1.0.dist-info/METADATA +187 -0
  120. powermem-0.1.0.dist-info/RECORD +123 -0
  121. powermem-0.1.0.dist-info/WHEEL +5 -0
  122. powermem-0.1.0.dist-info/licenses/LICENSE +206 -0
  123. powermem-0.1.0.dist-info/top_level.txt +1 -0
powermem/core/setup.py ADDED
@@ -0,0 +1,162 @@
1
+ """
2
+ Setup utilities for powermem
3
+
4
+ This module provides setup functions used during initialization.
5
+ """
6
+
7
+ import json
8
+ import os
9
+ import uuid
10
+ from typing import Optional, Dict, Any
11
+ import logging
12
+
13
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Random identifier generated afresh every time this module is imported.
VECTOR_ID = str(uuid.uuid4())

# Set up the directory path: a non-empty POWERMEM_DIR environment variable
# wins, otherwise the ".powermem" folder under the user's home directory is
# used. The directory is created as a side effect of importing this module.
home_dir = os.path.expanduser("~")
powermem_dir = os.environ.get("POWERMEM_DIR") or os.path.join(home_dir, ".powermem")
os.makedirs(powermem_dir, exist_ok=True)
20
+
21
+
22
def setup_config():
    """Create ``config.json`` with a fresh random user ID when it is missing.

    The file lives inside ``powermem_dir``. An existing file is left untouched.
    """
    target = os.path.join(powermem_dir, "config.json")
    if os.path.exists(target):
        return

    payload = {"user_id": str(uuid.uuid4())}
    with open(target, "w") as handle:
        json.dump(payload, handle, indent=4)
30
+
31
+
32
def get_user_id() -> str:
    """Return the persisted user ID, creating the config file if necessary.

    The ID is read from ``config.json`` inside ``powermem_dir``. When the file
    does not exist yet, ``setup_config()`` is asked to create it first.

    Returns:
        The stored user ID, or ``"anonymous_user"`` when the config file cannot
        be created, read or parsed, or when it holds no usable ``user_id``
        string.
    """
    config_path = os.path.join(powermem_dir, "config.json")

    try:
        # Creating the file can fail too (e.g. read-only directory); treat that
        # the same as an unreadable file instead of crashing the caller.
        if not os.path.exists(config_path):
            setup_config()

        with open(config_path, "r") as config_file:
            config = json.load(config_file)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return "anonymous_user"

    user_id = config.get("user_id") if isinstance(config, dict) else None
    # The signature promises a string: never leak None (missing/null key) or
    # another JSON type to callers.
    if not isinstance(user_id, str) or not user_id:
        return "anonymous_user"
    return user_id
45
+
46
+
47
def from_config(config: Optional[Dict[str, Any]] = None, **kwargs):
    """
    Build a Memory instance from a configuration dictionary.

    The configuration is passed through ``_convert_legacy_to_mem_config`` so
    that both the native field names ('embedder', 'vector_store') and the
    legacy ones are accepted.

    Args:
        config: Configuration dictionary. When None, the configuration is
            obtained from ``auto_config()``.
            - llm: LLM provider configuration
            - embedder: Embedder configuration
            - vector_store: Vector store configuration
        **kwargs: Extra keyword arguments forwarded to ``Memory``

    Returns:
        Memory instance

    Example:
        ```python
        from powermem import from_config

        memory = from_config({
            "llm": {"provider": "openai", "config": {"api_key": "..."}},
            "embedder": {"provider": "openai", "config": {"api_key": "..."}},
            "vector_store": {"provider": "oceanbase", "config": {...}},
        })
        ```
    """
    from ..core.memory import Memory

    source = config
    if source is None:
        # No explicit configuration: fall back to the automatic loader.
        from ..config_loader import auto_config
        source = auto_config()

    return Memory(config=_convert_legacy_to_mem_config(source), **kwargs)
84
+
85
+
86
+ def _convert_legacy_to_mem_config(config: Dict[str, Any]) -> Dict[str, Any]:
87
+ """
88
+ Convert legacy powermem config format.
89
+
90
+ Now powermem uses field names natively, so we only convert legacy format.
91
+
92
+ Args:
93
+ config: Legacy powermem configuration dictionary
94
+
95
+ Returns:
96
+ configuration dictionary
97
+ """
98
+ if "embedder" in config or "vector_store" in config:
99
+ return config
100
+
101
+ converted = {}
102
+
103
+ # LLM stays the same
104
+ if "llm" in config:
105
+ converted["llm"] = config["llm"]
106
+
107
+ # Convert embedding to embedder
108
+ if "embedding" in config:
109
+ converted["embedder"] = config["embedding"]
110
+
111
+ # Convert database to vector_store
112
+ if "database" in config:
113
+ db_config = config["database"]
114
+ converted["vector_store"] = {
115
+ "provider": db_config.get("provider", "oceanbase"),
116
+ "config": db_config.get("config", {})
117
+ }
118
+ else:
119
+ converted["vector_store"] = {
120
+ "provider": "oceanbase",
121
+ "config": {}
122
+ }
123
+
124
+ return converted
125
+
126
+
127
def get_or_create_user_id(vector_store=None) -> str:
    """
    Return the user ID, mirroring it into the vector store when one is given.

    The ID comes from ``get_user_id()``. With a vector store, the record stored
    under that ID is looked up first; if its payload carries a ``user_id`` that
    value is returned. Otherwise a placeholder record holding the ID is
    inserted. Both vector store operations are best-effort: failures are
    logged at debug level and never raised.

    Args:
        vector_store: Optional vector store instance

    Returns:
        User ID
    """
    user_id = get_user_id()

    if vector_store is None:
        return user_id

    # Try to get existing user_id from vector store
    try:
        existing = vector_store.get(vector_id=user_id)
        if existing and hasattr(existing, "payload") and existing.payload and "user_id" in existing.payload:
            return existing.payload["user_id"]
    except Exception as lookup_error:
        # Best-effort: fall through to the insert attempt, but leave a trace.
        logger.debug("Could not read user identity from vector store: %s", lookup_error)

    # If we get here, we need to insert the user_id
    try:
        # Placeholder vector sized to the store's dimensionality (1536 if unknown).
        dims = getattr(vector_store, "embedding_model_dims", 1536)
        vector_store.insert(
            vectors=[[0.1] * dims],
            payloads=[{"user_id": user_id, "type": "user_identity"}],
            ids=[user_id],
        )
    except Exception as insert_error:
        # Best-effort: the locally stored ID is still valid without the mirror.
        logger.debug("Could not store user identity in vector store: %s", insert_error)

    return user_id
162
+
@@ -0,0 +1,215 @@
1
+ """
2
+ Telemetry management for memory operations
3
+
4
+ This module handles telemetry data collection and reporting.
5
+ """
6
+
7
+ import logging
8
+ import json
9
+ import time
10
+ from typing import Any, Dict, Optional
11
+ from datetime import datetime
12
+ import httpx
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class TelemetryManager:
    """
    Manages telemetry data collection and reporting.

    Events are buffered in memory and posted to ``<endpoint>/events`` once the
    buffer reaches ``batch_size`` or when ``flush()`` is called. Nothing is
    recorded or sent unless ``enable_telemetry`` is set in the configuration.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialize telemetry manager.

        Args:
            config: Configuration dictionary. Recognized keys:
                ``enable_telemetry``, ``telemetry_endpoint``,
                ``telemetry_api_key``, ``telemetry_batch_size`` and
                ``telemetry_flush_interval``.
        """
        self.config = config or {}
        self.enabled = self.config.get("enable_telemetry", False)
        self.endpoint = self.config.get("telemetry_endpoint", "https://telemetry.powermem.ai")
        self.api_key = self.config.get("telemetry_api_key")
        self.batch_size = self.config.get("telemetry_batch_size", 100)
        # Stored for callers; this class itself only flushes on batch size
        # or on an explicit flush().
        self.flush_interval = self.config.get("telemetry_flush_interval", 30)

        self.events = []
        self.last_flush = time.time()

        # Strong references to in-flight send tasks. The event loop only keeps
        # weak references to tasks, so a fire-and-forget task could otherwise
        # be garbage-collected before it finishes.
        self._pending_tasks = set()

        logger.info(f"TelemetryManager initialized - enabled: {self.enabled}")

    def capture_event(
        self,
        event_name: str,
        properties: Optional[Dict[str, Any]] = None,
        user_id: Optional[str] = None,
        agent_id: Optional[str] = None,
    ) -> None:
        """
        Capture a telemetry event.

        The event is appended to the in-memory buffer; the buffer is flushed
        as soon as it holds ``batch_size`` events. Does nothing when telemetry
        is disabled. Errors are logged, never raised.

        Args:
            event_name: Name of the event
            properties: Event properties
            user_id: User ID associated with the event
            agent_id: Agent ID associated with the event
        """
        if not self.enabled:
            return

        try:
            event = {
                "event_name": event_name,
                "properties": properties or {},
                "user_id": user_id,
                "agent_id": agent_id,
                "timestamp": datetime.utcnow().isoformat(),
                "version": "0.1.0",
            }

            self.events.append(event)

            # Flush if batch size reached
            if len(self.events) >= self.batch_size:
                self._flush_events()

        except Exception as e:
            logger.error(f"Failed to capture telemetry event: {e}")

    def _flush_events(self) -> None:
        """Flush events to the telemetry endpoint.

        The buffer is cleared once the batch has been handed to the sender,
        whether or not delivery succeeds (telemetry is best-effort).
        """
        if not self.events or not self.enabled:
            return

        try:
            headers = {"Content-Type": "application/json"}
            if self.api_key:
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                }

            payload = {
                # Copy: the buffer is cleared below while the send may still
                # be in flight.
                "events": self.events.copy(),
                "timestamp": datetime.utcnow().isoformat(),
            }

            # Send events asynchronously to avoid blocking
            self._send_events_async(payload, headers)

            # Clear events after sending
            self.events.clear()
            self.last_flush = time.time()

        except Exception as e:
            logger.error(f"Failed to flush telemetry events: {e}")

    def _send_events_async(self, payload: Dict[str, Any], headers: Dict[str, str]) -> None:
        """Send events without raising.

        Inside a running event loop the request is scheduled as a background
        task. Without one, the request is sent with a blocking HTTP call.

        Args:
            payload: JSON-serializable body to post
            headers: HTTP headers for the request
        """
        try:
            import asyncio

            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # No running event loop in this thread.
                loop = None

            if loop is not None:
                task = loop.create_task(self._send_request(payload, headers))
                # Keep the task alive until it completes and make sure its
                # outcome is consumed (see _on_send_done).
                self._pending_tasks.add(task)
                task.add_done_callback(self._on_send_done)
                return

            # No running event loop, use httpx in sync mode for now
            try:
                with httpx.Client(timeout=10.0) as client:
                    response = client.post(
                        f"{self.endpoint}/events",
                        json=payload,
                        headers=headers,
                        timeout=10.0,
                    )
                    response.raise_for_status()
            except Exception as sync_e:
                logger.debug(f"Failed to send telemetry events synchronously: {sync_e}")

        except Exception as e:
            logger.debug(f"Failed to send telemetry events: {e}")

    def _on_send_done(self, task) -> None:
        """Release a finished background send task and log its failure, if any."""
        self._pending_tasks.discard(task)
        if task.cancelled():
            return
        # Retrieving the exception prevents asyncio's
        # "Task exception was never retrieved" warning.
        error = task.exception()
        if error is not None:
            logger.debug(f"Failed to send telemetry events asynchronously: {error}")

    async def _send_request(self, payload: Dict[str, Any], headers: Dict[str, str]) -> None:
        """Helper method to send HTTP request asynchronously.

        Raises whatever the HTTP client raises, including an error for a
        non-success status code.
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.endpoint}/events",
                json=payload,
                headers=headers,
                timeout=10.0,
            )
            response.raise_for_status()

    def flush(self) -> None:
        """Manually flush all pending events."""
        self._flush_events()

    def set_user_properties(self, user_id: str, properties: Dict[str, Any]) -> None:
        """
        Set user properties for telemetry.

        Buffers a ``user_properties`` event; it is sent with the next flush.
        Does nothing when telemetry is disabled.

        Args:
            user_id: User ID
            properties: User properties
        """
        if not self.enabled:
            return

        try:
            event = {
                "event_name": "user_properties",
                "properties": properties,
                "user_id": user_id,
                "timestamp": datetime.utcnow().isoformat(),
                "version": "0.1.0",
            }

            self.events.append(event)

        except Exception as e:
            logger.error(f"Failed to set user properties: {e}")

    def track_performance(self, operation: str, duration: float, metadata: Optional[Dict[str, Any]] = None) -> None:
        """
        Track performance metrics.

        Recorded as a ``performance_metric`` event.

        Args:
            operation: Operation name
            duration: Duration in seconds
            metadata: Additional metadata
        """
        if not self.enabled:
            return

        self.capture_event(
            "performance_metric",
            {
                "operation": operation,
                "duration": duration,
                "metadata": metadata or {},
            },
        )

    def track_error(self, error_type: str, error_message: str, context: Optional[Dict[str, Any]] = None) -> None:
        """
        Track error events.

        Recorded as an ``error`` event.

        Args:
            error_type: Type of error
            error_message: Error message
            context: Additional context
        """
        if not self.enabled:
            return

        self.capture_event(
            "error",
            {
                "error_type": error_type,
                "error_message": error_message,
                "context": context or {},
            },
        )
@@ -0,0 +1,17 @@
1
+ """
2
+ Integration layer for external services
3
+
4
+ This module provides integrations with LLMs, embeddings, rerank, and other services.
5
+ """
6
+
7
+ from .llm.factory import LLMFactory
8
+ from .embeddings.factory import EmbedderFactory
9
+ from .rerank.factory import RerankFactory
10
+ from .rerank.configs import RerankConfig
11
+
12
+ __all__ = [
13
+ "LLMFactory",
14
+ "EmbedderFactory",
15
+ "RerankFactory",
16
+ "RerankConfig",
17
+ ]
@@ -0,0 +1,13 @@
1
+ """
2
+ Embeddings integration module
3
+
4
+ This module provides embeddings integrations and factory.
5
+ """
6
+
7
+ from .factory import EmbedderFactory
8
+
9
+ EmbedderFactory = EmbedderFactory
10
+
11
+ __all__ = [
12
+ "EmbedderFactory",
13
+ ]
@@ -0,0 +1,100 @@
1
+ import json
2
+ import os
3
+ from typing import Literal, Optional
4
+
5
+ from powermem.integrations.embeddings.base import EmbeddingBase
6
+ from powermem.integrations.embeddings.config.base import BaseEmbedderConfig
7
+
8
+ try:
9
+ import boto3
10
+ except ImportError:
11
+ raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")
12
+
13
+ import numpy as np
14
+
15
+
16
class AWSBedrockEmbedding(EmbeddingBase):
    """AWS Bedrock embedding implementation.

    This class uses AWS Bedrock's embedding models through the
    ``bedrock-runtime`` client. The model defaults to
    ``amazon.titan-embed-text-v1`` when the config does not name one.
    """

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        """Create the Bedrock runtime client.

        Credentials are taken from the config when it carries non-empty
        ``aws_access_key_id`` / ``aws_secret_access_key`` values, otherwise
        from the standard AWS environment variables. The session token is read
        from ``AWS_SESSION_TOKEN`` only.

        Args:
            config: Embedder configuration; a default one is used when None.
        """
        super().__init__(config)

        self.config.model = self.config.model or "amazon.titan-embed-text-v1"

        # Prefer credentials from the config, but only when they are actually
        # set: an attribute that merely exists with a None/empty value must not
        # mask the environment variables.
        aws_access_key = getattr(self.config, "aws_access_key_id", None) or os.environ.get(
            "AWS_ACCESS_KEY_ID", ""
        )
        aws_secret_key = getattr(self.config, "aws_secret_access_key", None) or os.environ.get(
            "AWS_SECRET_ACCESS_KEY", ""
        )
        aws_session_token = os.environ.get("AWS_SESSION_TOKEN", "")

        # AWS region is always set in config - see BaseEmbedderConfig
        aws_region = self.config.aws_region or "us-west-2"

        # Empty values are passed as None so that they are treated as "not
        # provided" rather than as empty credentials.
        self.client = boto3.client(
            "bedrock-runtime",
            region_name=aws_region,
            aws_access_key_id=aws_access_key if aws_access_key else None,
            aws_secret_access_key=aws_secret_key if aws_secret_key else None,
            aws_session_token=aws_session_token if aws_session_token else None,
        )

    def _normalize_vector(self, embeddings):
        """Normalize the embedding to a unit vector.

        A zero vector has no direction and is returned unchanged instead of
        being divided by zero.
        """
        emb = np.array(embeddings)
        norm = np.linalg.norm(emb)
        if norm == 0:
            return emb.tolist()
        return (emb / norm).tolist()

    def _get_embedding(self, text):
        """Call out to Bedrock embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            The embedding extracted from the model response.

        Raises:
            ValueError: If the request or the response handling fails; the
                original exception is attached as the cause.
        """
        # Format input body based on the provider (prefix of the model ID).
        provider = self.config.model.split(".")[0]
        input_body = {}

        if provider == "cohere":
            input_body["input_type"] = "search_document"
            input_body["texts"] = [text]
        else:
            # Amazon and other providers
            input_body["inputText"] = text

        body = json.dumps(input_body)

        try:
            response = self.client.invoke_model(
                body=body,
                modelId=self.config.model,
                accept="application/json",
                contentType="application/json",
            )

            response_body = json.loads(response.get("body").read())

            if provider == "cohere":
                # One input text was sent, so take the first embedding.
                embeddings = response_body.get("embeddings")[0]
            else:
                embeddings = response_body.get("embedding")

            return embeddings
        except Exception as e:
            raise ValueError(f"Error getting embedding from AWS Bedrock: {e}") from e

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using AWS Bedrock.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None. Not used by this implementation.
        Returns:
            list: The embedding vector.
        """
        return self._get_embedding(text)
@@ -0,0 +1,55 @@
1
+ import os
2
+ from typing import Literal, Optional
3
+
4
+ from azure.identity import DefaultAzureCredential, get_bearer_token_provider
5
+ from openai import AzureOpenAI
6
+
7
+ from powermem.integrations.embeddings.base import EmbeddingBase
8
+ from powermem.integrations.embeddings.config.base import BaseEmbedderConfig
9
+
10
+ SCOPE = "https://cognitiveservices.azure.com/.default"
11
+
12
+
13
class AzureOpenAIEmbedding(EmbeddingBase):
    """Embedder backed by an Azure OpenAI deployment."""

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        """Build the Azure OpenAI client.

        Each connection setting is taken from ``config.azure_kwargs`` first and
        from the matching ``EMBEDDING_AZURE_*`` environment variable second.
        """
        super().__init__(config)

        azure = self.config.azure_kwargs
        api_key = azure.api_key or os.getenv("EMBEDDING_AZURE_OPENAI_API_KEY")
        azure_deployment = azure.azure_deployment or os.getenv("EMBEDDING_AZURE_DEPLOYMENT")
        azure_endpoint = azure.azure_endpoint or os.getenv("EMBEDDING_AZURE_ENDPOINT")
        api_version = azure.api_version or os.getenv("EMBEDDING_AZURE_API_VERSION")
        default_headers = azure.default_headers

        # A missing, empty or placeholder key means: authenticate with
        # DefaultAzureCredential through a bearer token provider instead.
        token_provider = None
        if api_key in (None, "", "your-api-key"):
            self.credential = DefaultAzureCredential()
            token_provider = get_bearer_token_provider(self.credential, SCOPE)
            api_key = None

        self.client = AzureOpenAI(
            azure_deployment=azure_deployment,
            azure_endpoint=azure_endpoint,
            azure_ad_token_provider=token_provider,
            api_version=api_version,
            api_key=api_key,
            http_client=self.config.http_client,
            default_headers=default_headers,
        )

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using Azure OpenAI.

        Newlines in the text are replaced with spaces before the request.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None. Not used by this implementation.
        Returns:
            list: The embedding vector.
        """
        single_line = text.replace("\n", " ")
        response = self.client.embeddings.create(input=[single_line], model=self.config.model)
        return response.data[0].embedding
@@ -0,0 +1,31 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Literal, Optional
3
+
4
+ from powermem.integrations.embeddings.config.base import BaseEmbedderConfig
5
+
6
+
7
class EmbeddingBase(ABC):
    """Abstract base class for embedding providers.

    :param config: Embedding configuration option class, defaults to None
    :type config: Optional[BaseEmbedderConfig], optional
    """

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        # Fall back to a default configuration when none is supplied.
        if config is None:
            self.config = BaseEmbedderConfig()
        else:
            self.config = config

    @abstractmethod
    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
        Returns:
            list: The embedding vector.
        """
        pass