noesium 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- noesium/core/__init__.py +4 -0
- noesium/core/agent/__init__.py +14 -0
- noesium/core/agent/base.py +227 -0
- noesium/core/consts.py +6 -0
- noesium/core/goalith/conflict/conflict.py +104 -0
- noesium/core/goalith/conflict/detector.py +53 -0
- noesium/core/goalith/decomposer/__init__.py +6 -0
- noesium/core/goalith/decomposer/base.py +46 -0
- noesium/core/goalith/decomposer/callable_decomposer.py +65 -0
- noesium/core/goalith/decomposer/llm_decomposer.py +326 -0
- noesium/core/goalith/decomposer/prompts.py +140 -0
- noesium/core/goalith/decomposer/simple_decomposer.py +61 -0
- noesium/core/goalith/errors.py +22 -0
- noesium/core/goalith/goalgraph/graph.py +526 -0
- noesium/core/goalith/goalgraph/node.py +179 -0
- noesium/core/goalith/replanner/base.py +31 -0
- noesium/core/goalith/replanner/replanner.py +36 -0
- noesium/core/goalith/service.py +26 -0
- noesium/core/llm/__init__.py +154 -0
- noesium/core/llm/base.py +152 -0
- noesium/core/llm/litellm.py +528 -0
- noesium/core/llm/llamacpp.py +487 -0
- noesium/core/llm/message.py +184 -0
- noesium/core/llm/ollama.py +459 -0
- noesium/core/llm/openai.py +520 -0
- noesium/core/llm/openrouter.py +89 -0
- noesium/core/llm/prompt.py +551 -0
- noesium/core/memory/__init__.py +11 -0
- noesium/core/memory/base.py +464 -0
- noesium/core/memory/memu/__init__.py +24 -0
- noesium/core/memory/memu/config/__init__.py +26 -0
- noesium/core/memory/memu/config/activity/config.py +46 -0
- noesium/core/memory/memu/config/event/config.py +46 -0
- noesium/core/memory/memu/config/markdown_config.py +241 -0
- noesium/core/memory/memu/config/profile/config.py +48 -0
- noesium/core/memory/memu/llm_adapter.py +129 -0
- noesium/core/memory/memu/memory/__init__.py +31 -0
- noesium/core/memory/memu/memory/actions/__init__.py +40 -0
- noesium/core/memory/memu/memory/actions/add_activity_memory.py +299 -0
- noesium/core/memory/memu/memory/actions/base_action.py +342 -0
- noesium/core/memory/memu/memory/actions/cluster_memories.py +262 -0
- noesium/core/memory/memu/memory/actions/generate_suggestions.py +198 -0
- noesium/core/memory/memu/memory/actions/get_available_categories.py +66 -0
- noesium/core/memory/memu/memory/actions/link_related_memories.py +515 -0
- noesium/core/memory/memu/memory/actions/run_theory_of_mind.py +254 -0
- noesium/core/memory/memu/memory/actions/update_memory_with_suggestions.py +514 -0
- noesium/core/memory/memu/memory/embeddings.py +130 -0
- noesium/core/memory/memu/memory/file_manager.py +306 -0
- noesium/core/memory/memu/memory/memory_agent.py +578 -0
- noesium/core/memory/memu/memory/recall_agent.py +376 -0
- noesium/core/memory/memu/memory_store.py +628 -0
- noesium/core/memory/models.py +149 -0
- noesium/core/msgbus/__init__.py +12 -0
- noesium/core/msgbus/base.py +395 -0
- noesium/core/orchestrix/__init__.py +0 -0
- noesium/core/py.typed +0 -0
- noesium/core/routing/__init__.py +20 -0
- noesium/core/routing/base.py +66 -0
- noesium/core/routing/router.py +241 -0
- noesium/core/routing/strategies/__init__.py +9 -0
- noesium/core/routing/strategies/dynamic_complexity.py +361 -0
- noesium/core/routing/strategies/self_assessment.py +147 -0
- noesium/core/routing/types.py +38 -0
- noesium/core/toolify/__init__.py +39 -0
- noesium/core/toolify/base.py +360 -0
- noesium/core/toolify/config.py +138 -0
- noesium/core/toolify/mcp_integration.py +275 -0
- noesium/core/toolify/registry.py +214 -0
- noesium/core/toolify/toolkits/__init__.py +1 -0
- noesium/core/tracing/__init__.py +37 -0
- noesium/core/tracing/langgraph_hooks.py +308 -0
- noesium/core/tracing/opik_tracing.py +144 -0
- noesium/core/tracing/token_tracker.py +166 -0
- noesium/core/utils/__init__.py +10 -0
- noesium/core/utils/logging.py +172 -0
- noesium/core/utils/statistics.py +12 -0
- noesium/core/utils/typing.py +17 -0
- noesium/core/vector_store/__init__.py +79 -0
- noesium/core/vector_store/base.py +94 -0
- noesium/core/vector_store/pgvector.py +304 -0
- noesium/core/vector_store/weaviate.py +383 -0
- noesium-0.1.0.dist-info/METADATA +525 -0
- noesium-0.1.0.dist-info/RECORD +86 -0
- noesium-0.1.0.dist-info/WHEEL +5 -0
- noesium-0.1.0.dist-info/licenses/LICENSE +21 -0
- noesium-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Simplified logging configuration for noesium with color support.
|
|
3
|
+
|
|
4
|
+
Usage:
|
|
5
|
+
from noesium.core.utils.logging import setup_logging, get_logger
|
|
6
|
+
|
|
7
|
+
# Basic usage - automatically uses environment variables if available
|
|
8
|
+
setup_logging()
|
|
9
|
+
logger = get_logger(__name__)
|
|
10
|
+
|
|
11
|
+
# Explicit configuration
|
|
12
|
+
setup_logging(level="DEBUG", enable_colors=True)
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
import os
|
|
17
|
+
import sys
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import Dict, List, Optional
|
|
20
|
+
|
|
21
|
+
try:
|
|
22
|
+
import colorlog
|
|
23
|
+
except ImportError:
|
|
24
|
+
colorlog = None # fallback if colorlog is not installed
|
|
25
|
+
|
|
26
|
+
# Configure logging
# Module-level bootstrap: runs once at import so that code which never calls
# setup_logging() still gets sensible output, honoring the LOG_LEVEL env var.
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
# getattr falls back to INFO if LOG_LEVEL holds an unrecognized level name.
logging.basicConfig(level=getattr(logging, log_level, logging.INFO))
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def setup_logging(
    level: Optional[str] = None,
    log_file: Optional[str] = None,
    log_file_level: Optional[str] = None,
    enable_colors: Optional[bool] = None,
    log_format: Optional[str] = None,
    custom_colors: Optional[Dict[str, str]] = None,
    third_party_level: Optional[str] = None,
    clear_existing: bool = True,
) -> None:
    """Initialize logging for noesium and third-party libs.

    Args:
        level: Log level (INFO, DEBUG, etc.)
        log_file: Path to log file
        log_file_level: Log level for file handler
        enable_colors: Whether to enable colored output
        log_format: Custom log format string
        custom_colors: Custom color mapping
        third_party_level: Log level for third-party libraries
        clear_existing: Whether to clear existing handlers
    """
    # Resolve the optional colorlog dependency locally so this function is
    # self-contained; mirrors the module-level guarded import.
    try:
        import colorlog
    except ImportError:
        colorlog = None

    # Set defaults for any remaining None values
    level = level or "INFO"
    log_file_level = log_file_level or "DEBUG"
    enable_colors = enable_colors if enable_colors is not None else True
    third_party_level = third_party_level or "WARNING"
    # NOTE(review): third_party_level is accepted and defaulted but never
    # applied to any logger below — confirm the intended behavior.

    root_logger = logging.getLogger()
    if clear_existing:
        root_logger.handlers.clear()

    log_level = getattr(logging, level.upper(), logging.INFO)
    root_logger.setLevel(log_level)

    fmt = log_format or "%(log_color)s%(asctime)s [%(levelname)s] %(name)s: %(message)s"
    datefmt = "%H:%M:%S"

    # Plain (color-free) formatter: always built, because the file handler
    # must never receive ANSI escape sequences.
    plain_formatter = logging.Formatter(fmt.replace("%(log_color)s", ""), datefmt=datefmt)

    if enable_colors and colorlog:
        console_formatter = colorlog.ColoredFormatter(
            fmt=fmt,
            datefmt=datefmt,
            log_colors=custom_colors
            or {
                "DEBUG": "cyan",
                "INFO": "green",
                "WARNING": "yellow",
                "ERROR": "red",
                "CRITICAL": "bold_red",
            },
            style="%",
        )
    else:
        console_formatter = plain_formatter

    # Console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(console_formatter)
    console_handler.setLevel(log_level)
    root_logger.addHandler(console_handler)

    # Optional file handler.
    # Fix: use the plain formatter here — previously the colored formatter was
    # shared, embedding raw ANSI codes in log files.
    if log_file:
        Path(log_file).parent.mkdir(parents=True, exist_ok=True)
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(plain_formatter)
        file_handler.setLevel(getattr(logging, log_file_level.upper(), logging.DEBUG))
        root_logger.addHandler(file_handler)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def get_logger(name: str) -> logging.Logger:
    """Return the stdlib logger registered under *name*."""
    logger_instance = logging.getLogger(name)
    return logger_instance
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
# --- Per-message color utilities -------------------------------------------------

# SGR foreground color codes (30–37), keyed by lowercase color name.
_ANSI_COLORS: Dict[str, str] = {
    "black": "\033[30m",
    "red": "\033[31m",
    "green": "\033[32m",
    "yellow": "\033[33m",
    "blue": "\033[34m",
    "magenta": "\033[35m",
    "cyan": "\033[36m",
    "white": "\033[37m",
}

# SGR text attribute codes, keyed by lowercase attribute name.
_ANSI_ATTRS: Dict[str, str] = {
    "bold": "\033[1m",
    "dim": "\033[2m",
    "underline": "\033[4m",
    "blink": "\033[5m",
    "reverse": "\033[7m",
}

# Resets all colors and attributes.
_ANSI_RESET = "\033[0m"


def color_text(message: str, color: Optional[str] = None, attrs: Optional[List[str]] = None) -> str:
    """Wrap message with ANSI color/attribute codes.

    Works regardless of whether colorlog is installed. Use to color a specific
    log invocation, e.g. `logger.info(color_text("done", "cyan"))`.

    Args:
        message: Text to wrap.
        color: Optional color name (case-insensitive); see _ANSI_COLORS.
        attrs: Optional attribute names (case-insensitive); see _ANSI_ATTRS.

    Returns:
        The message wrapped in escape codes, or unchanged when no valid
        color/attribute was requested.
    """
    if not color and not attrs:
        return message

    parts: List[str] = []
    if attrs:
        for attr in attrs:
            code = _ANSI_ATTRS.get(attr.lower())
            if code:
                parts.append(code)
    if color:
        color_code = _ANSI_COLORS.get(color.lower())
        if color_code:
            parts.append(color_code)

    # Fix: if every requested name was unknown, return the message untouched
    # instead of appending a stray reset sequence with no opening codes.
    if not parts:
        return message

    return f"{''.join(parts)}{message}{_ANSI_RESET}"
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def info_color(logger_obj: logging.Logger, message: str, color: Optional[str] = None, *, bold: bool = False) -> None:
    """Emit an INFO record, optionally wrapped in ANSI color/bold codes."""
    styled = color_text(message, color=color, attrs=(["bold"] if bold else None))
    logger_obj.info(styled)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def debug_color(logger_obj: logging.Logger, message: str, color: Optional[str] = None, *, bold: bool = False) -> None:
    """Emit a DEBUG record, optionally wrapped in ANSI color/bold codes."""
    styled = color_text(message, color=color, attrs=(["bold"] if bold else None))
    logger_obj.debug(styled)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def warning_color(logger_obj: logging.Logger, message: str, color: Optional[str] = None, *, bold: bool = False) -> None:
    """Emit a WARNING record, optionally wrapped in ANSI color/bold codes."""
    styled = color_text(message, color=color, attrs=(["bold"] if bold else None))
    logger_obj.warning(styled)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def error_color(logger_obj: logging.Logger, message: str, color: Optional[str] = None, *, bold: bool = False) -> None:
    """Emit an ERROR record, optionally wrapped in ANSI color/bold codes."""
    styled = color_text(message, color=color, attrs=(["bold"] if bold else None))
    logger_obj.error(styled)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Return the cosine similarity between two vectors.

    Args:
        a: First vector.
        b: Second vector. NOTE(review): if the lengths differ, ``zip``
           silently truncates to the shorter one — confirm callers always
           pass equal-length vectors.

    Returns:
        Similarity in [-1.0, 1.0]; 0.0 when either vector has zero magnitude
        (the similarity is mathematically undefined in that case).
    """
    dot_product = sum(x * y for x, y in zip(a, b))
    magnitude_a = math.sqrt(sum(x * x for x in a))
    magnitude_b = math.sqrt(sum(x * x for x in b))
    if magnitude_a == 0 or magnitude_b == 0:
        # Fix: return a float (was the int 0), matching the annotated return type.
        return 0.0
    return dot_product / (magnitude_a * magnitude_b)
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Compatibility module for typing features not available in older Python versions.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import sys
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from typing import override
|
|
10
|
+
else:
|
|
11
|
+
if sys.version_info >= (3, 12):
|
|
12
|
+
from typing import override
|
|
13
|
+
else:
|
|
14
|
+
# Fallback for Python < 3.12
|
|
15
|
+
def override(func):
|
|
16
|
+
"""Fallback override decorator for Python < 3.12."""
|
|
17
|
+
return func
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
import os
from typing import Optional

from .base import BaseVectorStore, OutputData
|
|
4
|
+
|
|
5
|
+
# Optional imports - vector store providers might not be available
|
|
6
|
+
try:
|
|
7
|
+
from .weaviate import WeaviateVectorStore
|
|
8
|
+
|
|
9
|
+
WEAVIATE_AVAILABLE = True
|
|
10
|
+
except ImportError:
|
|
11
|
+
WeaviateVectorStore = None
|
|
12
|
+
WEAVIATE_AVAILABLE = False
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
from .pgvector import PGVectorStore
|
|
16
|
+
|
|
17
|
+
PGVECTOR_AVAILABLE = True
|
|
18
|
+
except ImportError:
|
|
19
|
+
PGVectorStore = None
|
|
20
|
+
PGVECTOR_AVAILABLE = False
|
|
21
|
+
|
|
22
|
+
__all__ = [
|
|
23
|
+
"BaseVectorStore",
|
|
24
|
+
"OutputData",
|
|
25
|
+
"WeaviateVectorStore",
|
|
26
|
+
"PGVectorStore",
|
|
27
|
+
"get_vector_store",
|
|
28
|
+
]
|
|
29
|
+
|
|
30
|
+
#############################
|
|
31
|
+
# Common Vector Store helper functions
|
|
32
|
+
#############################
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def get_vector_store(
    provider: Optional[str] = None,
    collection_name: str = "default_collection",
    embedding_model_dims: Optional[int] = None,
    **kwargs,
):
    """
    Get a vector store instance based on the specified provider.

    Args:
        provider: Vector store provider to use ("weaviate", "pgvector").
            Defaults to the COGENTS_VECTOR_STORE_PROVIDER env var, else "weaviate".
        collection_name: Name of the collection/table to use
        embedding_model_dims: Dimensions of the embedding model. Defaults to
            the COGENTS_EMBEDDING_DIMS env var, else 768.
        **kwargs: Additional provider-specific arguments:
            - weaviate: cluster_url, auth_client_secret, additional_headers
            - pgvector: dbname, user, password, host, port, diskann, hnsw

    Returns:
        BaseVectorStore instance for the specified provider

    Raises:
        ValueError: If provider is not supported or not available
    """
    # Fix: resolve environment-backed defaults at call time. The previous
    # signature evaluated os.getenv()/int() once at import, freezing the
    # environment and letting a malformed COGENTS_EMBEDDING_DIMS crash
    # module import instead of this call.
    if provider is None:
        provider = os.getenv("COGENTS_VECTOR_STORE_PROVIDER", "weaviate")
    if embedding_model_dims is None:
        embedding_model_dims = int(os.getenv("COGENTS_EMBEDDING_DIMS", "768"))

    if provider == "weaviate":
        if not WEAVIATE_AVAILABLE:
            raise ValueError(
                "weaviate provider is not available. Please install the required dependencies: pip install weaviate-client"
            )
        return WeaviateVectorStore(
            collection_name=collection_name,
            embedding_model_dims=embedding_model_dims,
            **kwargs,
        )
    elif provider == "pgvector":
        if not PGVECTOR_AVAILABLE:
            raise ValueError(
                "pgvector provider is not available. Please install the required dependencies: pip install psycopg2"
            )
        return PGVectorStore(
            collection_name=collection_name,
            embedding_model_dims=embedding_model_dims,
            **kwargs,
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}. Supported providers: weaviate, pgvector")
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any, Dict, List, Optional
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class OutputData(BaseModel):
    """Standard output data structure for vector store operations."""

    # Unique identifier of the stored vector.
    id: str
    # Similarity/distance score from a search; None for direct get/list results.
    score: Optional[float] = None
    # Arbitrary metadata stored alongside the vector.
    payload: Dict[str, Any]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class BaseVectorStore(ABC):
    """Abstract interface that every vector store backend implements."""

    def __init__(self, embedding_model_dims: int):
        """
        Initialize the vector store.

        Args:
            embedding_model_dims: Expected dimensions for embedding vectors
        """
        self.embedding_model_dims = embedding_model_dims

    def _validate_vector_dimensions(self, vectors: List[List[float]]) -> None:
        """
        Validate that all vectors have the expected dimensions.

        Args:
            vectors: List of vectors to validate

        Raises:
            ValueError: If any vector has incorrect dimensions
        """
        expected = self.embedding_model_dims
        for index, vec in enumerate(vectors):
            actual = len(vec)
            if actual == expected:
                continue
            raise ValueError(
                f"Vector at index {index} has {actual} dimensions, "
                f"expected {expected}. "
                f"Check that your embedding model matches COGENTS_EMBEDDING_DIMS={expected}"
            )

    @abstractmethod
    def create_collection(self, vector_size: int, distance: str = "cosine") -> None:
        """Create a new collection."""

    @abstractmethod
    def insert(
        self,
        vectors: List[List[float]],
        payloads: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None,
    ) -> None:
        """Insert vectors into a collection."""

    @abstractmethod
    def search(
        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict[str, Any]] = None
    ) -> List[OutputData]:
        """Search for similar vectors."""

    @abstractmethod
    def delete(self, vector_id: str) -> None:
        """Delete a vector by ID."""

    @abstractmethod
    def update(
        self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict[str, Any]] = None
    ) -> None:
        """Update a vector and its payload."""

    @abstractmethod
    def get(self, vector_id: str) -> Optional[OutputData]:
        """Retrieve a vector by ID."""

    @abstractmethod
    def list_collections(self) -> List[str]:
        """List all collections."""

    @abstractmethod
    def delete_collection(self) -> None:
        """Delete a collection."""

    @abstractmethod
    def collection_info(self) -> Dict[str, Any]:
        """Get information about a collection."""

    @abstractmethod
    def list(self, filters: Optional[Dict[str, Any]] = None, limit: Optional[int] = None) -> List[OutputData]:
        """List all memories."""

    @abstractmethod
    def reset(self) -> None:
        """Reset by delete the collection and recreate it."""
|
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
from typing import Any, Dict, List, Optional
|
|
4
|
+
|
|
5
|
+
try:
|
|
6
|
+
import psycopg2
|
|
7
|
+
from psycopg2.extras import execute_values
|
|
8
|
+
except ImportError:
|
|
9
|
+
raise ImportError("The 'psycopg2' library is required. Please install it using 'pip install psycopg2'.")
|
|
10
|
+
|
|
11
|
+
from .base import BaseVectorStore, OutputData
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PGVectorStore(BaseVectorStore):
    """PostgreSQL + pgvector implementation of BaseVectorStore.

    Each collection is a table with columns ``id UUID``, ``vector vector(N)``
    and ``payload JSONB``.

    NOTE(review): the collection/table name is interpolated into SQL via
    f-strings (identifiers cannot be bound as parameters), so
    ``collection_name`` must come from trusted configuration, never from
    user input.
    """

    def __init__(
        self,
        dbname,
        collection_name,
        embedding_model_dims,
        user,
        password,
        host,
        port,
        diskann,
        hnsw,
    ):
        """
        Initialize the PGVector database.

        Args:
            dbname (str): Database name
            collection_name (str): Collection name
            embedding_model_dims (int): Dimension of the embedding vector
            user (str): Database user
            password (str): Database password
            host (str, optional): Database host
            port (int, optional): Database port
            diskann (bool, optional): Use DiskANN for faster search
            hnsw (bool, optional): Use HNSW for faster search
        """
        super().__init__(embedding_model_dims)

        self.collection_name = collection_name
        self.use_diskann = diskann
        self.use_hnsw = hnsw

        self.conn = psycopg2.connect(dbname=dbname, user=user, password=password, host=host, port=port)
        self.cur = self.conn.cursor()

        # Create the backing table on first use.
        collections = self.list_collections()
        if collection_name not in collections:
            self.create_collection(embedding_model_dims)

    def create_collection(self, vector_size: int, distance: str = "cosine") -> None:
        """
        Create a new collection (table in PostgreSQL).
        Will also initialize vector search index if specified.

        Args:
            vector_size (int): Dimension of the embedding vector.
            distance (str): Distance metric (not used in PGVector, for compatibility).
        """
        self.cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
        self.cur.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self.collection_name} (
                id UUID PRIMARY KEY,
                vector vector({vector_size}),
                payload JSONB
            );
        """
        )

        # pgvector's DiskANN index (vectorscale extension) only supports
        # vectors below 2000 dimensions.
        if self.use_diskann and vector_size < 2000:
            # Check if vectorscale extension is installed
            self.cur.execute("SELECT * FROM pg_extension WHERE extname = 'vectorscale'")
            if self.cur.fetchone():
                # Create DiskANN index if extension is installed for faster search
                self.cur.execute(
                    f"""
                    CREATE INDEX IF NOT EXISTS {self.collection_name}_diskann_idx
                    ON {self.collection_name}
                    USING diskann (vector);
                """
                )
        elif self.use_hnsw:
            self.cur.execute(
                f"""
                CREATE INDEX IF NOT EXISTS {self.collection_name}_hnsw_idx
                ON {self.collection_name}
                USING hnsw (vector vector_cosine_ops)
            """
            )

        self.conn.commit()

    def insert(
        self,
        vectors: List[List[float]],
        payloads: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None,
    ) -> None:
        """
        Insert vectors into a collection.

        Args:
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict], optional): List of payloads corresponding to
                vectors. Defaults to empty payloads.
            ids (List[str], optional): List of IDs corresponding to vectors.
                Defaults to freshly generated UUIDs.
        """
        # Validate vector dimensions
        self._validate_vector_dimensions(vectors)

        # Fix: the defaults (payloads=None, ids=None) previously crashed —
        # None was iterated directly. Generate sensible fallbacks instead.
        import uuid  # local import: only needed for fallback id generation

        if ids is None:
            ids = [str(uuid.uuid4()) for _ in vectors]
        if payloads is None:
            payloads = [{} for _ in vectors]

        logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
        json_payloads = [json.dumps(payload) for payload in payloads]

        data = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, json_payloads)]
        execute_values(
            self.cur,
            f"INSERT INTO {self.collection_name} (id, vector, payload) VALUES %s",
            data,
        )
        self.conn.commit()

    def search(
        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict[str, Any]] = None
    ) -> List[OutputData]:
        """
        Search for similar vectors.

        Args:
            query (str): Query.
            vectors (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            list: Search results, ordered by ascending cosine distance.
        """
        filter_conditions = []
        filter_params = []

        if filters:
            for k, v in filters.items():
                # Filter on top-level payload keys as text.
                filter_conditions.append("payload->>%s = %s")
                filter_params.extend([k, str(v)])

        filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""

        # <=> is pgvector's cosine-distance operator; parameter order must
        # match placeholder order: query vector, filter params, then limit.
        self.cur.execute(
            f"""
            SELECT id, vector <=> %s::vector AS distance, payload
            FROM {self.collection_name}
            {filter_clause}
            ORDER BY distance
            LIMIT %s
            """,
            (vectors, *filter_params, limit),
        )

        results = self.cur.fetchall()
        return [OutputData(id=str(r[0]), score=float(r[1]), payload=r[2]) for r in results]

    def delete(self, vector_id: str) -> None:
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        self.cur.execute(f"DELETE FROM {self.collection_name} WHERE id = %s", (vector_id,))
        self.conn.commit()

    def update(
        self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Update a vector and its payload.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.
        """
        if vector:
            self.cur.execute(
                f"UPDATE {self.collection_name} SET vector = %s WHERE id = %s",
                (vector, vector_id),
            )
        if payload:
            self.cur.execute(
                f"UPDATE {self.collection_name} SET payload = %s WHERE id = %s",
                (psycopg2.extras.Json(payload), vector_id),
            )
        self.conn.commit()

    def get(self, vector_id: str) -> Optional[OutputData]:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved vector, or None if not found.
        """
        self.cur.execute(
            f"SELECT id, vector, payload FROM {self.collection_name} WHERE id = %s",
            (vector_id,),
        )
        result = self.cur.fetchone()
        if not result:
            return None
        return OutputData(id=str(result[0]), score=None, payload=result[2])

    def list_collections(self) -> List[str]:
        """
        List all collections.

        Returns:
            List[str]: List of collection names (all public tables, not only
            ones created by this class).
        """
        self.cur.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
        return [row[0] for row in self.cur.fetchall()]

    def delete_collection(self) -> None:
        """Delete a collection."""
        self.cur.execute(f"DROP TABLE IF EXISTS {self.collection_name}")
        self.conn.commit()

    def collection_info(self) -> Dict[str, Any]:
        """
        Get information about a collection.

        Returns:
            Dict[str, Any]: Collection information (name, row count, size).
        """
        self.cur.execute(
            f"""
            SELECT
                table_name,
                (SELECT COUNT(*) FROM {self.collection_name}) as row_count,
                (SELECT pg_size_pretty(pg_total_relation_size('{self.collection_name}'))) as total_size
            FROM information_schema.tables
            WHERE table_schema = 'public' AND table_name = %s
        """,
            (self.collection_name,),
        )
        result = self.cur.fetchone()
        # NOTE(review): result is None if the table does not exist, which
        # would raise TypeError below — confirm callers only query existing
        # collections.
        return {"name": result[0], "count": result[1], "size": result[2]}

    def list(self, filters: Optional[Dict[str, Any]] = None, limit: Optional[int] = None) -> List[OutputData]:
        """
        List all vectors in a collection.

        Args:
            filters (Dict, optional): Filters to apply to the list.
            limit (int, optional): Number of vectors to return. Defaults to 10000 if None.

        Returns:
            List[OutputData]: List of vectors.
        """
        filter_conditions = []
        filter_params = []

        if filters:
            for k, v in filters.items():
                filter_conditions.append("payload->>%s = %s")
                filter_params.extend([k, str(v)])

        filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""

        # Handle None limit by using a large default value
        if limit is None:
            limit = 10000  # Default large limit

        query = f"""
            SELECT id, vector, payload
            FROM {self.collection_name}
            {filter_clause}
            LIMIT %s
        """

        self.cur.execute(query, (*filter_params, limit))

        results = self.cur.fetchall()
        return [OutputData(id=str(r[0]), score=None, payload=r[2]) for r in results]

    def __del__(self) -> None:
        """
        Close the database connection when the object is deleted.
        """
        # Fix: guard the closes — attributes may be missing if __init__
        # failed partway, and close() may raise during interpreter shutdown;
        # exceptions in __del__ must never propagate.
        try:
            if hasattr(self, "cur"):
                self.cur.close()
            if hasattr(self, "conn"):
                self.conn.close()
        except Exception:
            pass

    def reset(self) -> None:
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_collection()
        self.create_collection(self.embedding_model_dims)
|