noesium-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. noesium/core/__init__.py +4 -0
  2. noesium/core/agent/__init__.py +14 -0
  3. noesium/core/agent/base.py +227 -0
  4. noesium/core/consts.py +6 -0
  5. noesium/core/goalith/conflict/conflict.py +104 -0
  6. noesium/core/goalith/conflict/detector.py +53 -0
  7. noesium/core/goalith/decomposer/__init__.py +6 -0
  8. noesium/core/goalith/decomposer/base.py +46 -0
  9. noesium/core/goalith/decomposer/callable_decomposer.py +65 -0
  10. noesium/core/goalith/decomposer/llm_decomposer.py +326 -0
  11. noesium/core/goalith/decomposer/prompts.py +140 -0
  12. noesium/core/goalith/decomposer/simple_decomposer.py +61 -0
  13. noesium/core/goalith/errors.py +22 -0
  14. noesium/core/goalith/goalgraph/graph.py +526 -0
  15. noesium/core/goalith/goalgraph/node.py +179 -0
  16. noesium/core/goalith/replanner/base.py +31 -0
  17. noesium/core/goalith/replanner/replanner.py +36 -0
  18. noesium/core/goalith/service.py +26 -0
  19. noesium/core/llm/__init__.py +154 -0
  20. noesium/core/llm/base.py +152 -0
  21. noesium/core/llm/litellm.py +528 -0
  22. noesium/core/llm/llamacpp.py +487 -0
  23. noesium/core/llm/message.py +184 -0
  24. noesium/core/llm/ollama.py +459 -0
  25. noesium/core/llm/openai.py +520 -0
  26. noesium/core/llm/openrouter.py +89 -0
  27. noesium/core/llm/prompt.py +551 -0
  28. noesium/core/memory/__init__.py +11 -0
  29. noesium/core/memory/base.py +464 -0
  30. noesium/core/memory/memu/__init__.py +24 -0
  31. noesium/core/memory/memu/config/__init__.py +26 -0
  32. noesium/core/memory/memu/config/activity/config.py +46 -0
  33. noesium/core/memory/memu/config/event/config.py +46 -0
  34. noesium/core/memory/memu/config/markdown_config.py +241 -0
  35. noesium/core/memory/memu/config/profile/config.py +48 -0
  36. noesium/core/memory/memu/llm_adapter.py +129 -0
  37. noesium/core/memory/memu/memory/__init__.py +31 -0
  38. noesium/core/memory/memu/memory/actions/__init__.py +40 -0
  39. noesium/core/memory/memu/memory/actions/add_activity_memory.py +299 -0
  40. noesium/core/memory/memu/memory/actions/base_action.py +342 -0
  41. noesium/core/memory/memu/memory/actions/cluster_memories.py +262 -0
  42. noesium/core/memory/memu/memory/actions/generate_suggestions.py +198 -0
  43. noesium/core/memory/memu/memory/actions/get_available_categories.py +66 -0
  44. noesium/core/memory/memu/memory/actions/link_related_memories.py +515 -0
  45. noesium/core/memory/memu/memory/actions/run_theory_of_mind.py +254 -0
  46. noesium/core/memory/memu/memory/actions/update_memory_with_suggestions.py +514 -0
  47. noesium/core/memory/memu/memory/embeddings.py +130 -0
  48. noesium/core/memory/memu/memory/file_manager.py +306 -0
  49. noesium/core/memory/memu/memory/memory_agent.py +578 -0
  50. noesium/core/memory/memu/memory/recall_agent.py +376 -0
  51. noesium/core/memory/memu/memory_store.py +628 -0
  52. noesium/core/memory/models.py +149 -0
  53. noesium/core/msgbus/__init__.py +12 -0
  54. noesium/core/msgbus/base.py +395 -0
  55. noesium/core/orchestrix/__init__.py +0 -0
  56. noesium/core/py.typed +0 -0
  57. noesium/core/routing/__init__.py +20 -0
  58. noesium/core/routing/base.py +66 -0
  59. noesium/core/routing/router.py +241 -0
  60. noesium/core/routing/strategies/__init__.py +9 -0
  61. noesium/core/routing/strategies/dynamic_complexity.py +361 -0
  62. noesium/core/routing/strategies/self_assessment.py +147 -0
  63. noesium/core/routing/types.py +38 -0
  64. noesium/core/toolify/__init__.py +39 -0
  65. noesium/core/toolify/base.py +360 -0
  66. noesium/core/toolify/config.py +138 -0
  67. noesium/core/toolify/mcp_integration.py +275 -0
  68. noesium/core/toolify/registry.py +214 -0
  69. noesium/core/toolify/toolkits/__init__.py +1 -0
  70. noesium/core/tracing/__init__.py +37 -0
  71. noesium/core/tracing/langgraph_hooks.py +308 -0
  72. noesium/core/tracing/opik_tracing.py +144 -0
  73. noesium/core/tracing/token_tracker.py +166 -0
  74. noesium/core/utils/__init__.py +10 -0
  75. noesium/core/utils/logging.py +172 -0
  76. noesium/core/utils/statistics.py +12 -0
  77. noesium/core/utils/typing.py +17 -0
  78. noesium/core/vector_store/__init__.py +79 -0
  79. noesium/core/vector_store/base.py +94 -0
  80. noesium/core/vector_store/pgvector.py +304 -0
  81. noesium/core/vector_store/weaviate.py +383 -0
  82. noesium-0.1.0.dist-info/METADATA +525 -0
  83. noesium-0.1.0.dist-info/RECORD +86 -0
  84. noesium-0.1.0.dist-info/WHEEL +5 -0
  85. noesium-0.1.0.dist-info/licenses/LICENSE +21 -0
  86. noesium-0.1.0.dist-info/top_level.txt +1 -0
noesium/core/utils/logging.py
@@ -0,0 +1,172 @@
+ """
+ Simplified logging configuration for noesium with color support.
+
+ Usage:
+     from noesium.core.utils.logging import setup_logging, get_logger
+
+     # Basic usage - automatically uses environment variables if available
+     setup_logging()
+     logger = get_logger(__name__)
+
+     # Explicit configuration
+     setup_logging(level="DEBUG", enable_colors=True)
+ """
+
+ import logging
+ import os
+ import sys
+ from pathlib import Path
+ from typing import Dict, List, Optional
+
+ try:
+     import colorlog
+ except ImportError:
+     colorlog = None  # fallback if colorlog is not installed
+
+ # Configure logging
+ log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+ logging.basicConfig(level=getattr(logging, log_level, logging.INFO))
+ logger = logging.getLogger(__name__)
+
+
+ def setup_logging(
+     level: Optional[str] = None,
+     log_file: Optional[str] = None,
+     log_file_level: Optional[str] = None,
+     enable_colors: Optional[bool] = None,
+     log_format: Optional[str] = None,
+     custom_colors: Optional[Dict[str, str]] = None,
+     third_party_level: Optional[str] = None,
+     clear_existing: bool = True,
+ ) -> None:
+     """Initialize logging for noesium and third-party libs.
+
+     Args:
+         level: Log level (INFO, DEBUG, etc.)
+         log_file: Path to log file
+         log_file_level: Log level for file handler
+         enable_colors: Whether to enable colored output
+         log_format: Custom log format string
+         custom_colors: Custom color mapping
+         third_party_level: Log level for third-party libraries
+         clear_existing: Whether to clear existing handlers
+     """
+     # Set defaults for any remaining None values
+     level = level or "INFO"
+     log_file_level = log_file_level or "DEBUG"
+     enable_colors = enable_colors if enable_colors is not None else True
+     third_party_level = third_party_level or "WARNING"
+
+     root_logger = logging.getLogger()
+     if clear_existing:
+         root_logger.handlers.clear()
+
+     log_level = getattr(logging, level.upper(), logging.INFO)
+     root_logger.setLevel(log_level)
+
+     fmt = log_format or "%(log_color)s%(asctime)s [%(levelname)s] %(name)s: %(message)s"
+     datefmt = "%H:%M:%S"
+
+     if enable_colors and colorlog:
+         formatter = colorlog.ColoredFormatter(
+             fmt=fmt,
+             datefmt=datefmt,
+             log_colors=custom_colors
+             or {
+                 "DEBUG": "cyan",
+                 "INFO": "green",
+                 "WARNING": "yellow",
+                 "ERROR": "red",
+                 "CRITICAL": "bold_red",
+             },
+             style="%",
+         )
+     else:
+         formatter = logging.Formatter(fmt.replace("%(log_color)s", ""), datefmt=datefmt)
+
+     # Console handler
+     console_handler = logging.StreamHandler(sys.stdout)
+     console_handler.setFormatter(formatter)
+     console_handler.setLevel(log_level)
+     root_logger.addHandler(console_handler)
+
+     # Optional file handler
+     if log_file:
+         Path(log_file).parent.mkdir(parents=True, exist_ok=True)
+         file_handler = logging.FileHandler(log_file)
+         file_handler.setFormatter(formatter)
+         file_handler.setLevel(getattr(logging, log_file_level.upper(), logging.DEBUG))
+         root_logger.addHandler(file_handler)
+
+
+ def get_logger(name: str) -> logging.Logger:
+     return logging.getLogger(name)
+
+
+ # --- Per-message color utilities -------------------------------------------------
+
+ _ANSI_COLORS: Dict[str, str] = {
+     "black": "\033[30m",
+     "red": "\033[31m",
+     "green": "\033[32m",
+     "yellow": "\033[33m",
+     "blue": "\033[34m",
+     "magenta": "\033[35m",
+     "cyan": "\033[36m",
+     "white": "\033[37m",
+ }
+
+ _ANSI_ATTRS: Dict[str, str] = {
+     "bold": "\033[1m",
+     "dim": "\033[2m",
+     "underline": "\033[4m",
+     "blink": "\033[5m",
+     "reverse": "\033[7m",
+ }
+
+ _ANSI_RESET = "\033[0m"
+
+
+ def color_text(message: str, color: Optional[str] = None, attrs: Optional[List[str]] = None) -> str:
+     """Wrap message with ANSI color/attribute codes.
+
+     Works regardless of whether colorlog is installed. Use to color a specific
+     log invocation, e.g. `logger.info(color_text("done", "cyan"))`.
+     """
+     if not color and not attrs:
+         return message
+
+     parts: List[str] = []
+     if attrs:
+         for attr in attrs:
+             code = _ANSI_ATTRS.get(attr.lower())
+             if code:
+                 parts.append(code)
+     if color:
+         parts.append(_ANSI_COLORS.get(color.lower(), ""))
+
+     return f"{''.join(parts)}{message}{_ANSI_RESET}"
+
+
+ def info_color(logger_obj: logging.Logger, message: str, color: Optional[str] = None, *, bold: bool = False) -> None:
+     """Log INFO with optional per-message color and bold attribute."""
+     attrs = ["bold"] if bold else None
+     logger_obj.info(color_text(message, color=color, attrs=attrs))
+
+
+ def debug_color(logger_obj: logging.Logger, message: str, color: Optional[str] = None, *, bold: bool = False) -> None:
+     """Log DEBUG with optional per-message color and bold attribute."""
+     attrs = ["bold"] if bold else None
+     logger_obj.debug(color_text(message, color=color, attrs=attrs))
+
+
+ def warning_color(logger_obj: logging.Logger, message: str, color: Optional[str] = None, *, bold: bool = False) -> None:
+     """Log WARNING with optional per-message color and bold attribute."""
+     attrs = ["bold"] if bold else None
+     logger_obj.warning(color_text(message, color=color, attrs=attrs))
+
+
+ def error_color(logger_obj: logging.Logger, message: str, color: Optional[str] = None, *, bold: bool = False) -> None:
+     """Log ERROR with optional per-message color and bold attribute."""
+     attrs = ["bold"] if bold else None
+     logger_obj.error(color_text(message, color=color, attrs=attrs))
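A minimal usage sketch of the logging helpers above; the log file path and messages are illustrative only:

    from noesium.core.utils.logging import setup_logging, get_logger, color_text, info_color

    # Configure the root logger once at startup (file path is an example value)
    setup_logging(level="DEBUG", log_file="logs/noesium.log", enable_colors=True)
    logger = get_logger(__name__)

    # Per-message coloring works with or without colorlog installed
    logger.info(color_text("startup complete", "cyan", attrs=["bold"]))
    info_color(logger, "ready", color="green", bold=True)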
noesium/core/utils/statistics.py
@@ -0,0 +1,12 @@
+ import math
+ from typing import List
+
+
+ def cosine_similarity(a: List[float], b: List[float]) -> float:
+     """Calculate cosine similarity between two vectors"""
+     dot_product = sum(x * y for x, y in zip(a, b))
+     magnitude_a = math.sqrt(sum(x * x for x in a))
+     magnitude_b = math.sqrt(sum(x * x for x in b))
+     if magnitude_a == 0 or magnitude_b == 0:
+         return 0.0
+     return dot_product / (magnitude_a * magnitude_b)
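A quick sanity check of cosine_similarity with illustrative values:

    from noesium.core.utils.statistics import cosine_similarity

    assert abs(cosine_similarity([1.0, 0.0], [1.0, 0.0]) - 1.0) < 1e-9  # identical direction -> 1.0
    assert abs(cosine_similarity([1.0, 0.0], [0.0, 1.0])) < 1e-9        # orthogonal vectors -> 0.0
    assert cosine_similarity([0.0, 0.0], [1.0, 2.0]) == 0.0             # zero-magnitude guard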
noesium/core/utils/typing.py
@@ -0,0 +1,17 @@
+ """
+ Compatibility module for typing features not available in older Python versions.
+ """
+
+ import sys
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from typing import override
+ else:
+     if sys.version_info >= (3, 12):
+         from typing import override
+     else:
+         # Fallback for Python < 3.12
+         def override(func):
+             """Fallback override decorator for Python < 3.12."""
+             return func
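The shim is used like the standard decorator; a brief sketch with hypothetical classes:

    from noesium.core.utils.typing import override

    class Base:
        def run(self) -> str:
            return "base"

    class Child(Base):
        @override  # checked by type checkers on 3.12+, a no-op at runtime on older versions
        def run(self) -> str:
            return "child"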
noesium/core/vector_store/__init__.py
@@ -0,0 +1,79 @@
+ import os
+
+ from .base import BaseVectorStore, OutputData
+
+ # Optional imports - vector store providers might not be available
+ try:
+     from .weaviate import WeaviateVectorStore
+
+     WEAVIATE_AVAILABLE = True
+ except ImportError:
+     WeaviateVectorStore = None
+     WEAVIATE_AVAILABLE = False
+
+ try:
+     from .pgvector import PGVectorStore
+
+     PGVECTOR_AVAILABLE = True
+ except ImportError:
+     PGVectorStore = None
+     PGVECTOR_AVAILABLE = False
+
+ __all__ = [
+     "BaseVectorStore",
+     "OutputData",
+     "WeaviateVectorStore",
+     "PGVectorStore",
+     "get_vector_store",
+ ]
+
+ #############################
+ # Common Vector Store helper functions
+ #############################
+
+
+ def get_vector_store(
+     provider: str = os.getenv("COGENTS_VECTOR_STORE_PROVIDER", "weaviate"),
+     collection_name: str = "default_collection",
+     embedding_model_dims: int = int(os.getenv("COGENTS_EMBEDDING_DIMS", "768")),
+     **kwargs,
+ ):
+     """
+     Get a vector store instance based on the specified provider.
+
+     Args:
+         provider: Vector store provider to use ("weaviate", "pgvector")
+         collection_name: Name of the collection/table to use
+         embedding_model_dims: Dimensions of the embedding model
+         **kwargs: Additional provider-specific arguments:
+             - weaviate: cluster_url, auth_client_secret, additional_headers
+             - pgvector: dbname, user, password, host, port, diskann, hnsw
+
+     Returns:
+         BaseVectorStore instance for the specified provider
+
+     Raises:
+         ValueError: If provider is not supported or not available
+     """
+     if provider == "weaviate":
+         if not WEAVIATE_AVAILABLE:
+             raise ValueError(
+                 "weaviate provider is not available. Please install the required dependencies: pip install weaviate-client"
+             )
+         return WeaviateVectorStore(
+             collection_name=collection_name,
+             embedding_model_dims=embedding_model_dims,
+             **kwargs,
+         )
+     elif provider == "pgvector":
+         if not PGVECTOR_AVAILABLE:
+             raise ValueError(
+                 "pgvector provider is not available. Please install the required dependencies: pip install psycopg2"
+             )
+         return PGVectorStore(
+             collection_name=collection_name,
+             embedding_model_dims=embedding_model_dims,
+             **kwargs,
+         )
+     else:
+         raise ValueError(f"Unsupported provider: {provider}. Supported providers: weaviate, pgvector")
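A minimal sketch of get_vector_store, assuming a locally running Weaviate; the cluster_url value is illustrative and is simply forwarded through **kwargs to the provider class:

    from noesium.core.vector_store import get_vector_store

    store = get_vector_store(
        provider="weaviate",
        collection_name="documents",
        embedding_model_dims=768,
        cluster_url="http://localhost:8080",  # provider-specific kwarg documented above; example value
    )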
noesium/core/vector_store/base.py
@@ -0,0 +1,94 @@
+ from abc import ABC, abstractmethod
+ from typing import Any, Dict, List, Optional
+
+ from pydantic import BaseModel
+
+
+ class OutputData(BaseModel):
+     """Standard output data structure for vector store operations."""
+
+     id: str
+     score: Optional[float] = None
+     payload: Dict[str, Any]
+
+
+ class BaseVectorStore(ABC):
+     def __init__(self, embedding_model_dims: int):
+         """
+         Initialize the vector store.
+
+         Args:
+             embedding_model_dims: Expected dimensions for embedding vectors
+         """
+         self.embedding_model_dims = embedding_model_dims
+
+     def _validate_vector_dimensions(self, vectors: List[List[float]]) -> None:
+         """
+         Validate that all vectors have the expected dimensions.
+
+         Args:
+             vectors: List of vectors to validate
+
+         Raises:
+             ValueError: If any vector has incorrect dimensions
+         """
+         for i, vector in enumerate(vectors):
+             if len(vector) != self.embedding_model_dims:
+                 raise ValueError(
+                     f"Vector at index {i} has {len(vector)} dimensions, "
+                     f"expected {self.embedding_model_dims}. "
+                     f"Check that your embedding model matches COGENTS_EMBEDDING_DIMS={self.embedding_model_dims}"
+                 )
+
+     @abstractmethod
+     def create_collection(self, vector_size: int, distance: str = "cosine") -> None:
+         """Create a new collection."""
+
+     @abstractmethod
+     def insert(
+         self,
+         vectors: List[List[float]],
+         payloads: Optional[List[Dict[str, Any]]] = None,
+         ids: Optional[List[str]] = None,
+     ) -> None:
+         """Insert vectors into a collection."""
+
+     @abstractmethod
+     def search(
+         self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict[str, Any]] = None
+     ) -> List[OutputData]:
+         """Search for similar vectors."""
+
+     @abstractmethod
+     def delete(self, vector_id: str) -> None:
+         """Delete a vector by ID."""
+
+     @abstractmethod
+     def update(
+         self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict[str, Any]] = None
+     ) -> None:
+         """Update a vector and its payload."""
+
+     @abstractmethod
+     def get(self, vector_id: str) -> Optional[OutputData]:
+         """Retrieve a vector by ID."""
+
+     @abstractmethod
+     def list_collections(self) -> List[str]:
+         """List all collections."""
+
+     @abstractmethod
+     def delete_collection(self) -> None:
+         """Delete a collection."""
+
+     @abstractmethod
+     def collection_info(self) -> Dict[str, Any]:
+         """Get information about a collection."""
+
+     @abstractmethod
+     def list(self, filters: Optional[Dict[str, Any]] = None, limit: Optional[int] = None) -> List[OutputData]:
+         """List all vectors."""
+
+     @abstractmethod
+     def reset(self) -> None:
+         """Reset by deleting the collection and recreating it."""
noesium/core/vector_store/pgvector.py
@@ -0,0 +1,304 @@
+ import json
+ import logging
+ import uuid
+ from typing import Any, Dict, List, Optional
+
+ try:
+     import psycopg2
+     from psycopg2.extras import execute_values
+ except ImportError:
+     raise ImportError("The 'psycopg2' library is required. Please install it using 'pip install psycopg2'.")
+
+ from .base import BaseVectorStore, OutputData
+
+ logger = logging.getLogger(__name__)
+
+
+ class PGVectorStore(BaseVectorStore):
+     def __init__(
+         self,
+         dbname,
+         collection_name,
+         embedding_model_dims,
+         user,
+         password,
+         host,
+         port,
+         diskann,
+         hnsw,
+     ):
+         """
+         Initialize the PGVector database.
+
+         Args:
+             dbname (str): Database name
+             collection_name (str): Collection name
+             embedding_model_dims (int): Dimension of the embedding vector
+             user (str): Database user
+             password (str): Database password
+             host (str, optional): Database host
+             port (int, optional): Database port
+             diskann (bool, optional): Use DiskANN for faster search
+             hnsw (bool, optional): Use HNSW for faster search
+         """
+         super().__init__(embedding_model_dims)
+
+         self.collection_name = collection_name
+         self.use_diskann = diskann
+         self.use_hnsw = hnsw
+
+         self.conn = psycopg2.connect(dbname=dbname, user=user, password=password, host=host, port=port)
+         self.cur = self.conn.cursor()
+
+         collections = self.list_collections()
+         if collection_name not in collections:
+             self.create_collection(embedding_model_dims)
+
+     def create_collection(self, vector_size: int, distance: str = "cosine") -> None:
+         """
+         Create a new collection (table in PostgreSQL).
+         Will also initialize vector search index if specified.
+
+         Args:
+             vector_size (int): Dimension of the embedding vector.
+             distance (str): Distance metric (not used in PGVector, for compatibility).
+         """
+         self.cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
+         self.cur.execute(
+             f"""
+             CREATE TABLE IF NOT EXISTS {self.collection_name} (
+                 id UUID PRIMARY KEY,
+                 vector vector({vector_size}),
+                 payload JSONB
+             );
+             """
+         )
+
+         if self.use_diskann and vector_size < 2000:
+             # Check if vectorscale extension is installed
+             self.cur.execute("SELECT * FROM pg_extension WHERE extname = 'vectorscale'")
+             if self.cur.fetchone():
+                 # Create DiskANN index if extension is installed for faster search
+                 self.cur.execute(
+                     f"""
+                     CREATE INDEX IF NOT EXISTS {self.collection_name}_diskann_idx
+                     ON {self.collection_name}
+                     USING diskann (vector);
+                     """
+                 )
+         elif self.use_hnsw:
+             self.cur.execute(
+                 f"""
+                 CREATE INDEX IF NOT EXISTS {self.collection_name}_hnsw_idx
+                 ON {self.collection_name}
+                 USING hnsw (vector vector_cosine_ops)
+                 """
+             )
+
+         self.conn.commit()
+
+     def insert(
+         self,
+         vectors: List[List[float]],
+         payloads: Optional[List[Dict[str, Any]]] = None,
+         ids: Optional[List[str]] = None,
+     ) -> None:
+         """
+         Insert vectors into a collection.
+
+         Args:
+             vectors (List[List[float]]): List of vectors to insert.
+             payloads (List[Dict], optional): List of payloads corresponding to vectors.
+             ids (List[str], optional): List of IDs corresponding to vectors.
+         """
+         # Validate vector dimensions
+         self._validate_vector_dimensions(vectors)
+
+         # Fall back to empty payloads and generated UUIDs when not provided
+         if payloads is None:
+             payloads = [{} for _ in vectors]
+         if ids is None:
+             ids = [str(uuid.uuid4()) for _ in vectors]
+
+         logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
+         json_payloads = [json.dumps(payload) for payload in payloads]
+
+         data = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, json_payloads)]
+         execute_values(
+             self.cur,
+             f"INSERT INTO {self.collection_name} (id, vector, payload) VALUES %s",
+             data,
+         )
+         self.conn.commit()
+
+     def search(
+         self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict[str, Any]] = None
+     ) -> List[OutputData]:
+         """
+         Search for similar vectors.
+
+         Args:
+             query (str): Query.
+             vectors (List[float]): Query vector.
+             limit (int, optional): Number of results to return. Defaults to 5.
+             filters (Dict, optional): Filters to apply to the search. Defaults to None.
+
+         Returns:
+             list: Search results.
+         """
+         filter_conditions = []
+         filter_params = []
+
+         if filters:
+             for k, v in filters.items():
+                 filter_conditions.append("payload->>%s = %s")
+                 filter_params.extend([k, str(v)])
+
+         filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
+
+         self.cur.execute(
+             f"""
+             SELECT id, vector <=> %s::vector AS distance, payload
+             FROM {self.collection_name}
+             {filter_clause}
+             ORDER BY distance
+             LIMIT %s
+             """,
+             (vectors, *filter_params, limit),
+         )
+
+         results = self.cur.fetchall()
+         return [OutputData(id=str(r[0]), score=float(r[1]), payload=r[2]) for r in results]
+
+     def delete(self, vector_id: str) -> None:
+         """
+         Delete a vector by ID.
+
+         Args:
+             vector_id (str): ID of the vector to delete.
+         """
+         self.cur.execute(f"DELETE FROM {self.collection_name} WHERE id = %s", (vector_id,))
+         self.conn.commit()
+
+     def update(
+         self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict[str, Any]] = None
+     ) -> None:
+         """
+         Update a vector and its payload.
+
+         Args:
+             vector_id (str): ID of the vector to update.
+             vector (List[float], optional): Updated vector.
+             payload (Dict, optional): Updated payload.
+         """
+         if vector:
+             self.cur.execute(
+                 f"UPDATE {self.collection_name} SET vector = %s WHERE id = %s",
+                 (vector, vector_id),
+             )
+         if payload:
+             self.cur.execute(
+                 f"UPDATE {self.collection_name} SET payload = %s WHERE id = %s",
+                 (psycopg2.extras.Json(payload), vector_id),
+             )
+         self.conn.commit()
+
+     def get(self, vector_id: str) -> Optional[OutputData]:
+         """
+         Retrieve a vector by ID.
+
+         Args:
+             vector_id (str): ID of the vector to retrieve.
+
+         Returns:
+             OutputData: Retrieved vector.
+         """
+         self.cur.execute(
+             f"SELECT id, vector, payload FROM {self.collection_name} WHERE id = %s",
+             (vector_id,),
+         )
+         result = self.cur.fetchone()
+         if not result:
+             return None
+         return OutputData(id=str(result[0]), score=None, payload=result[2])
+
+     def list_collections(self) -> List[str]:
+         """
+         List all collections.
+
+         Returns:
+             List[str]: List of collection names.
+         """
+         self.cur.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
+         return [row[0] for row in self.cur.fetchall()]
+
+     def delete_collection(self) -> None:
+         """Delete a collection."""
+         self.cur.execute(f"DROP TABLE IF EXISTS {self.collection_name}")
+         self.conn.commit()
+
+     def collection_info(self) -> Dict[str, Any]:
+         """
+         Get information about a collection.
+
+         Returns:
+             Dict[str, Any]: Collection information.
+         """
+         self.cur.execute(
+             f"""
+             SELECT
+                 table_name,
+                 (SELECT COUNT(*) FROM {self.collection_name}) as row_count,
+                 (SELECT pg_size_pretty(pg_total_relation_size('{self.collection_name}'))) as total_size
+             FROM information_schema.tables
+             WHERE table_schema = 'public' AND table_name = %s
+             """,
+             (self.collection_name,),
+         )
+         result = self.cur.fetchone()
+         return {"name": result[0], "count": result[1], "size": result[2]}
+
+     def list(self, filters: Optional[Dict[str, Any]] = None, limit: Optional[int] = None) -> List[OutputData]:
+         """
+         List all vectors in a collection.
+
+         Args:
+             filters (Dict, optional): Filters to apply to the list.
+             limit (int, optional): Number of vectors to return. Defaults to 10000 if None.
+
+         Returns:
+             List[OutputData]: List of vectors.
+         """
+         filter_conditions = []
+         filter_params = []
+
+         if filters:
+             for k, v in filters.items():
+                 filter_conditions.append("payload->>%s = %s")
+                 filter_params.extend([k, str(v)])
+
+         filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
+
+         # Handle None limit by using a large default value
+         if limit is None:
+             limit = 10000  # Default large limit
+
+         query = f"""
+             SELECT id, vector, payload
+             FROM {self.collection_name}
+             {filter_clause}
+             LIMIT %s
+         """
+
+         self.cur.execute(query, (*filter_params, limit))
+
+         results = self.cur.fetchall()
+         return [OutputData(id=str(r[0]), score=None, payload=r[2]) for r in results]
+
+     def __del__(self) -> None:
+         """
+         Close the database connection when the object is deleted.
+         """
+         if hasattr(self, "cur"):
+             self.cur.close()
+         if hasattr(self, "conn"):
+             self.conn.close()
+
+     def reset(self) -> None:
+         """Reset the index by deleting and recreating it."""
+         logger.warning(f"Resetting index {self.collection_name}...")
+         self.delete_collection()
+         self.create_collection(self.embedding_model_dims)
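A minimal end-to-end sketch of PGVectorStore, assuming a local PostgreSQL instance with the pgvector extension available; the connection values and the 4-dimensional vectors are illustrative only:

    import uuid
    from noesium.core.vector_store.pgvector import PGVectorStore

    store = PGVectorStore(
        dbname="noesium",          # example database
        collection_name="memories",
        embedding_model_dims=4,    # tiny dimension for illustration; real embeddings are e.g. 768
        user="postgres",
        password="postgres",
        host="localhost",
        port=5432,
        diskann=False,
        hnsw=True,                 # builds an HNSW index with vector_cosine_ops
    )

    vec_id = str(uuid.uuid4())
    store.insert(vectors=[[0.1, 0.2, 0.3, 0.4]], payloads=[{"text": "hello"}], ids=[vec_id])

    hits = store.search(query="hello", vectors=[0.1, 0.2, 0.3, 0.4], limit=3)
    for hit in hits:
        print(hit.id, hit.score, hit.payload)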