echo-vector 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,182 @@
1
+ """Faiss-based index implementation."""
2
+
3
+ import os
4
+ from typing import Any, cast
5
+
6
+ import faiss
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+
10
+ from .base import BaseIndex
11
+ from .store import SQLiteStore
12
+
13
+
14
class FaissIndex(BaseIndex):
    """Faiss-based index using IndexFlatIP (Inner Product) for vector search.

    Vectors are held in a ``faiss.IndexIDMap2`` wrapping an ``IndexFlatIP`` and are
    keyed by integer IDs. The mapping from those integer IDs to the caller's string
    IDs and metadata lives in a :class:`SQLiteStore`.
    """

    # Maximum number of "?" placeholders bound in one SQL statement. SQLite builds
    # older than 3.32 cap host parameters at 999 by default, so stay below that.
    _SQL_BATCH_SIZE = 500

    def __init__(self, dimension: int, db_path: str = ":memory:") -> None:
        """Initialize the Faiss index and the metadata store.

        Args:
            dimension: Dimensionality of the embeddings.
            db_path: Path to the SQLite store database.
        """
        self.dimension = dimension
        self.index: faiss.IndexIDMap2 = faiss.IndexIDMap2(faiss.IndexFlatIP(dimension))
        self.store = SQLiteStore(db_path)

    def add(
        self,
        embeddings: npt.NDArray[np.float32],
        ids: list[str],
        metadata: list[dict[str, Any]] | None = None,
    ) -> None:
        """Add embeddings, their string IDs, and metadata to the index.

        Supports batched and incremental indexing. Re-adding a string ID that is
        already indexed replaces it: its old vector is removed from Faiss and its
        store row is overwritten, so no unreachable vectors are left behind.

        Args:
            embeddings: A 2D numpy array of shape ``(n, dimension)``; it is
                converted to contiguous float32 before indexing.
            ids: ``n`` string IDs, one per embedding, unique within the batch.
            metadata: An optional list of ``n`` metadata dictionaries.

        Raises:
            ValueError: If input dimensions or lengths are invalid, or if ``ids``
                contains duplicates.
        """
        if embeddings.ndim != 2 or embeddings.shape[1] != self.dimension:
            raise ValueError(f"Embeddings must be a 2D array with dimension {self.dimension}")

        num_vectors = embeddings.shape[0]
        if len(ids) != num_vectors:
            raise ValueError("Number of IDs must match the number of embeddings.")
        if len(set(ids)) != num_vectors:
            # The store keeps one row per string ID, so a repeated ID would overwrite
            # its twin's row and leave the twin's vector without metadata.
            raise ValueError("IDs must be unique within a batch.")

        if metadata is None:
            metadata = [{} for _ in range(num_vectors)]
        elif len(metadata) != num_vectors:
            raise ValueError("Length of metadata must match the number of embeddings.")

        if num_vectors == 0:
            return

        # Ensure embeddings are contiguous and float32
        embeddings_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)

        # The store upserts by string ID; remember the integer IDs these string IDs
        # currently own so their superseded vectors can be dropped from Faiss.
        stale_int_ids = self._int_ids_for_string_ids(ids)

        # Use max stored ID + 1 so IDs remain unique after deletions.
        start_id = self.store.get_max_int_id() + 1
        int_ids = list(range(start_id, start_id + num_vectors))
        int_ids_array = np.array(int_ids, dtype=np.int64)

        self.index.add_with_ids(embeddings_f32, int_ids_array)
        try:
            self.store.add(int_ids, ids, metadata)
        except Exception:
            # Undo the Faiss insert so the index never holds vectors without metadata.
            self.index.remove_ids(faiss.IDSelectorBatch(int_ids_array))
            raise

        if stale_int_ids:
            self.index.remove_ids(
                faiss.IDSelectorBatch(np.array(stale_int_ids, dtype=np.int64))
            )

    def search(
        self, query_embeddings: npt.NDArray[np.float32], k: int = 10
    ) -> tuple[
        npt.NDArray[np.float32],
        list[list[str | None]],
        list[list[dict[str, Any] | None]],
    ]:
        """Search for the k nearest neighbors.

        Args:
            query_embeddings: 2D numpy array of query vectors, shape
                ``(num_queries, dimension)``.
            k: Number of nearest neighbors to retrieve. It is clamped to the index
                size; ``0`` yields empty results.

        Returns:
            A tuple of (distances, string_ids, metadata). ``distances`` is the score
            array returned by Faiss, with shape ``(num_queries, min(k, ntotal))``,
            or ``(num_queries, 0)`` when the index is empty or ``k == 0``. In the
            per-query lists, a slot is ``None`` where Faiss returned ``-1`` or the
            store has no row for the hit. A query with no hits at all gets an empty
            list.

        Raises:
            ValueError: If query embeddings dimensions are invalid or ``k`` is negative.
        """
        if query_embeddings.ndim != 2 or query_embeddings.shape[1] != self.dimension:
            raise ValueError(
                f"Query embeddings must be a 2D array with dimension {self.dimension}"
            )
        if k < 0:
            raise ValueError("k must be non-negative.")

        num_queries = query_embeddings.shape[0]
        if k == 0 or self.index.ntotal == 0:
            return (
                np.empty((num_queries, 0), dtype=np.float32),
                [[] for _ in range(num_queries)],
                [[] for _ in range(num_queries)],
            )

        query_f32 = np.ascontiguousarray(query_embeddings, dtype=np.float32)
        # Ensure k is not larger than index size
        actual_k = min(k, self.index.ntotal)

        distances, int_indices = self.index.search(query_f32, actual_k)

        all_string_ids: list[list[str | None]] = []
        all_metadata: list[list[dict[str, Any] | None]] = []

        for row in int_indices:
            row_ids = [int(idx) for idx in row]
            # -1 is returned by Faiss if not enough results are found
            valid_ids = [idx for idx in row_ids if idx != -1]
            if not valid_ids:
                all_string_ids.append([])
                all_metadata.append([])
                continue

            hits = self._lookup(valid_ids)
            # -1 (and any ID without a store row) maps to (None, None).
            pairs = [hits.get(idx, (None, None)) for idx in row_ids]
            all_string_ids.append([string_id for string_id, _ in pairs])
            all_metadata.append([meta for _, meta in pairs])

        return distances, all_string_ids, all_metadata

    def remove_int_ids(self, int_ids: list[int]) -> None:
        """Remove vectors by their integer IDs from the FAISS index and metadata store.

        Args:
            int_ids: List of integer IDs to remove. numpy integers are accepted.
        """
        if not int_ids:
            return
        # sqlite3 cannot bind numpy integer scalars, so normalise to Python ints.
        ids = [int(i) for i in int_ids]
        self.index.remove_ids(faiss.IDSelectorBatch(np.array(ids, dtype=np.int64)))

        # SQLiteStore exposes no delete-by-int-ID method, so use its connection
        # directly; the `with` block commits on success and rolls back on error.
        conn = self.store._conn
        with conn:
            for start in range(0, len(ids), self._SQL_BATCH_SIZE):
                chunk = ids[start : start + self._SQL_BATCH_SIZE]
                placeholders = ",".join("?" for _ in chunk)
                delete_query = f"DELETE FROM metadata WHERE int_id IN ({placeholders})"  # noqa: S608
                conn.execute(delete_query, chunk)

    def save(self, index_path: str) -> None:
        """Save the Faiss index to disk.

        Only the Faiss index is written here. String IDs and metadata stay in the
        SQLite database given as ``db_path``.

        Args:
            index_path: The file path to save the index to. Missing parent
                directories are created.
        """
        parent = os.path.dirname(index_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        faiss.write_index(self.index, index_path)

    def load(self, index_path: str) -> None:
        """Load the Faiss index from disk.

        The current index is replaced only after the loaded one passes validation,
        so a failed load leaves this object unchanged.

        Args:
            index_path: The file path to load the index from.

        Raises:
            FileNotFoundError: If the index file does not exist.
            ValueError: If the loaded index dimension does not match.
        """
        if not os.path.exists(index_path):
            raise FileNotFoundError(f"Index file {index_path} not found.")
        loaded = cast("faiss.IndexIDMap2", faiss.read_index(index_path))
        if loaded.d != self.dimension:
            raise ValueError(
                f"Loaded index dimension ({loaded.d}) does not match "
                f"expected dimension ({self.dimension})"
            )
        self.index = loaded

    def _lookup(self, int_ids: list[int]) -> dict[int, tuple[str | None, dict[str, Any] | None]]:
        """Map each of ``int_ids`` to its ``(string_id, metadata)`` pair from the store.

        The store is queried in chunks so that a large ``k`` never exceeds SQLite's
        limit on bound parameters.
        """
        found: dict[int, tuple[str | None, dict[str, Any] | None]] = {}
        for start in range(0, len(int_ids), self._SQL_BATCH_SIZE):
            chunk = int_ids[start : start + self._SQL_BATCH_SIZE]
            str_ids, metas = self.store.get_by_int_ids(chunk)
            found.update(zip(chunk, zip(str_ids, metas, strict=True), strict=True))
        return found

    def _int_ids_for_string_ids(self, string_ids: list[str]) -> list[int]:
        """Return the integer IDs currently stored for any of ``string_ids``."""
        # SQLiteStore has no lookup by string ID, so query its connection directly.
        conn = self.store._conn
        found: list[int] = []
        for start in range(0, len(string_ids), self._SQL_BATCH_SIZE):
            chunk = list(string_ids[start : start + self._SQL_BATCH_SIZE])
            placeholders = ",".join("?" for _ in chunk)
            query = f"SELECT int_id FROM metadata WHERE string_id IN ({placeholders})"  # noqa: S608
            found.extend(row[0] for row in conn.execute(query, chunk))
        return found
@@ -0,0 +1,165 @@
1
+ """SQLite-based store for metadata persistence."""
2
+
3
+ import contextlib
4
+ import json
5
+ import sqlite3
6
+ from typing import Any
7
+
8
+ from .base import BaseStore
9
+
10
+
11
class SQLiteStore(BaseStore):
    """SQLite-based store for metadata and string ID persistence.

    Each row of the ``metadata`` table maps an integer ID (assigned by the vector
    index) to the caller's string ID and a JSON-encoded metadata dictionary.
    """

    # Maximum number of "?" placeholders bound in one statement. SQLite builds
    # older than 3.32 cap host parameters at 999 by default.
    _SQL_BATCH_SIZE = 500

    def __init__(self, db_path: str = ":memory:") -> None:
        """Initialize the SQLite store.

        Args:
            db_path: Path to the SQLite database file; ``":memory:"`` gives a
                non-persistent in-memory database.
        """
        self.db_path = db_path
        # check_same_thread=False allows use from other threads; sqlite3 then leaves
        # serialising concurrent access to the caller.
        self._conn = sqlite3.connect(self.db_path, check_same_thread=False)
        self.initialize()

    def initialize(self) -> None:
        """Create the ``metadata`` table if it does not already exist."""
        with self._conn:
            self._conn.execute(
                """
                CREATE TABLE IF NOT EXISTS metadata (
                    int_id INTEGER PRIMARY KEY,
                    string_id TEXT UNIQUE NOT NULL,
                    metadata_json TEXT
                )
                """
            )

    def add(
        self, int_ids: list[int], string_ids: list[str], metadata_list: list[dict[str, Any]]
    ) -> None:
        """Add metadata and ID mappings to the store.

        Rows are written with ``INSERT OR REPLACE``, so an existing row that shares
        either the ``int_id`` or the ``string_id`` is replaced. The whole batch is
        committed together and rolled back if any row fails.

        Args:
            int_ids: List of integer IDs assigned by the index.
            string_ids: List of original string IDs.
            metadata_list: List of metadata dictionaries (JSON-serialised).

        Raises:
            ValueError: If lengths of input lists do not match.
            TypeError: If a metadata dictionary is not JSON-serialisable.
        """
        if not (len(int_ids) == len(string_ids) == len(metadata_list)):
            raise ValueError("Mismatched lengths for IDs and metadata.")

        # Serialise before touching the database so bad metadata writes nothing.
        rows = [
            (i_id, s_id, json.dumps(meta))
            for i_id, s_id, meta in zip(int_ids, string_ids, metadata_list, strict=True)
        ]
        with self._conn:
            self._conn.executemany(
                """
                INSERT OR REPLACE INTO metadata (int_id, string_id, metadata_json)
                VALUES (?, ?, ?)
                """,
                rows,
            )

    def get_by_int_ids(
        self, int_ids: list[int]
    ) -> tuple[list[str | None], list[dict[str, Any] | None]]:
        """Retrieve string IDs and metadata for a list of integer IDs.

        Args:
            int_ids: List of integer IDs to query.

        Returns:
            Two lists aligned with ``int_ids``: the string IDs and the decoded
            metadata dictionaries. Both hold ``None`` for IDs with no stored row.
        """
        if not int_ids:
            return [], []

        found: dict[int, tuple[str, dict[str, Any] | None]] = {}
        cols = "int_id, string_id, metadata_json"
        # Query in chunks to stay under SQLite's bound-parameter limit.
        for start in range(0, len(int_ids), self._SQL_BATCH_SIZE):
            chunk = list(int_ids[start : start + self._SQL_BATCH_SIZE])
            placeholders = ",".join("?" for _ in chunk)
            query = f"SELECT {cols} FROM metadata WHERE int_id IN ({placeholders})"  # noqa: S608
            for i_id, s_id, meta_json in self._conn.execute(query, chunk):
                meta = json.loads(meta_json) if meta_json is not None else None
                found[i_id] = (s_id, meta)

        missing: tuple[None, None] = (None, None)
        pairs = [found.get(i_id, missing) for i_id in int_ids]
        return [s_id for s_id, _ in pairs], [meta for _, meta in pairs]

    def get_max_int_id(self) -> int:
        """Get the maximum integer ID currently in the store.

        Returns:
            The maximum integer ID, or -1 if empty.
        """
        (max_id,) = self._conn.execute("SELECT MAX(int_id) FROM metadata").fetchone()
        return int(max_id) if max_id is not None else -1

    @staticmethod
    def _filepath_bounds(filepath: str) -> tuple[str, str]:
        """Return the half-open ``[low, high)`` range of IDs shaped ``"<filepath>#..."``.

        Under SQLite's default BINARY collation, every string that starts with
        ``filepath + "#"`` sorts at or after that prefix and strictly before
        ``filepath + "$"``, because ``"$"`` is the character right after ``"#"``.
        Unlike ``LIKE``, this comparison is case-sensitive and treats ``%`` and
        ``_`` in the path literally.
        """
        return filepath + "#", filepath + "$"

    def has_filepath(self, filepath: str) -> bool:
        """Return True if any chunk from filepath is already stored.

        Args:
            filepath: Absolute or relative path of the source audio file. Matched
                exactly (case-sensitive) against the part of the string ID before ``#``.
        """
        low, high = self._filepath_bounds(filepath)
        row = self._conn.execute(
            "SELECT 1 FROM metadata WHERE string_id >= ? AND string_id < ? LIMIT 1",
            (low, high),
        ).fetchone()
        return row is not None

    def get_int_ids_for_filepath(self, filepath: str) -> list[int]:
        """Return all integer IDs belonging to chunks of filepath, in ascending order.

        Args:
            filepath: Source audio file path.
        """
        low, high = self._filepath_bounds(filepath)
        cursor = self._conn.execute(
            "SELECT int_id FROM metadata WHERE string_id >= ? AND string_id < ? ORDER BY int_id",
            (low, high),
        )
        return [row[0] for row in cursor.fetchall()]

    def delete_by_filepath(self, filepath: str) -> list[int]:
        """Delete all chunks for filepath and return their integer IDs.

        Args:
            filepath: Source audio file path.

        Returns:
            List of integer IDs that were removed.
        """
        int_ids = self.get_int_ids_for_filepath(filepath)
        if int_ids:
            low, high = self._filepath_bounds(filepath)
            # Delete with the same range predicate used for the lookup, which
            # avoids building an unbounded IN (...) list.
            with self._conn:
                self._conn.execute(
                    "DELETE FROM metadata WHERE string_id >= ? AND string_id < ?",
                    (low, high),
                )
        return int_ids

    def close(self) -> None:
        """Close the database connection."""
        self._conn.close()

    def __del__(self) -> None:
        """Ensure connection is closed on GC to avoid Python 3.13 sqlite3 finalizer bug."""
        # Also suppresses AttributeError when __init__ failed before _conn was set.
        with contextlib.suppress(Exception):
            self._conn.close()
@@ -0,0 +1,14 @@
1
+ """Search module for EchoVector."""
2
+
3
+ from echovector.search.engine import Embedder, SearchEngine, VectorIndex
4
+ from echovector.search.filters import SearchFilter
5
+ from echovector.search.results import SearchResult, TimestampRange
6
+
7
# Public names re-exported by this module for `from ... import *` and API docs.
__all__ = [
    "Embedder",
    "SearchEngine",
    "SearchFilter",
    "SearchResult",
    "TimestampRange",
    "VectorIndex",
]
+ ]
@@ -0,0 +1,82 @@
1
+ """Search engine implementation."""
2
+
3
+ from typing import Any, Protocol
4
+
5
+ from echovector.search.filters import SearchFilter
6
+ from echovector.search.results import SearchResult, TimestampRange
7
+
8
+
9
class Embedder(Protocol):
    """Structural interface for any object that turns query text into a vector.

    Any class with a matching ``embed_text`` method satisfies this protocol;
    no inheritance is required.
    """

    def embed_text(self, text: str) -> list[float]:
        """Return the embedding vector for ``text`` as a list of floats."""
        ...
15
+
16
+
17
class VectorIndex(Protocol):
    """Structural interface for a vector index that ``SearchEngine`` can query."""

    def search(self, vector: list[float], top_k: int) -> list[dict[str, Any]]:
        """Return up to ``top_k`` hits for ``vector``, one dictionary per hit.

        Each hit dictionary is expected to provide:
            - 'filepath': str
            - 'start': float
            - 'end': float
            - 'score': float
            - 'metadata': dict (optional)
        """
        ...
31
+
32
+
33
class SearchEngine:
    """Engine for executing searches against an index."""

    # With filters, the first index request asks for top_k * FILTER_OVERFETCH
    # candidates. The request size doubles (up to MAX_FILTER_ROUNDS requests) while
    # too few candidates survive filtering and the index still returns full pages.
    FILTER_OVERFETCH = 5
    MAX_FILTER_ROUNDS = 3

    def __init__(self, index: VectorIndex, embedder: Embedder) -> None:
        """Initialize the search engine.

        Args:
            index: The vector index to search against.
            embedder: The embedder to use for queries.
        """
        self._index = index
        self._embedder = embedder

    def search(
        self, query: str, top_k: int = 10, filters: SearchFilter | None = None
    ) -> list[SearchResult]:
        """Search the index for a given query.

        Args:
            query: The text query.
            top_k: Number of results to return; values <= 0 return an empty list.
            filters: Optional filters to apply.

        Returns:
            At most ``top_k`` hydrated SearchResult objects, in the order the index
            returned them.
        """
        if top_k <= 0:
            # Guards against negative slicing (results[:-n]) and bogus index requests.
            return []

        vector = self._embedder.embed_text(query)

        if filters is None:
            return self._hydrate(self._index.search(vector, top_k))[:top_k]

        fetch_k = top_k * self.FILTER_OVERFETCH
        results: list[SearchResult] = []
        for _ in range(self.MAX_FILTER_ROUNDS):
            raw_results = self._index.search(vector, fetch_k)
            results = filters.apply(self._hydrate(raw_results))
            # Stop once enough results survive, or once the index returned fewer
            # candidates than requested (a bigger request cannot add any).
            if len(results) >= top_k or len(raw_results) < fetch_k:
                break
            fetch_k *= 2

        return results[:top_k]

    @staticmethod
    def _hydrate(raw_results: list[dict[str, Any]]) -> list[SearchResult]:
        """Convert raw index hit dictionaries into SearchResult objects."""
        return [
            SearchResult(
                filepath=raw["filepath"],
                timestamp_range=TimestampRange(start=raw["start"], end=raw["end"]),
                score=raw["score"],
                # A missing or explicit-None metadata entry becomes an empty dict.
                metadata=raw.get("metadata") or {},
            )
            for raw in raw_results
        ]
@@ -0,0 +1,55 @@
1
+ """Filtering logic for search results."""
2
+
3
+ from typing import Any
4
+
5
+ from echovector.search.results import SearchResult
6
+
7
+
8
+ class SearchFilter:
9
+ """Filter parameters for search queries.
10
+
11
+ Attributes:
12
+ filepaths: Optional list of allowed file paths.
13
+ min_score: Optional minimum score threshold.
14
+ metadata_filters: Optional exact match metadata filters.
15
+ """
16
+
17
+ def __init__(
18
+ self,
19
+ filepaths: list[str] | None = None,
20
+ min_score: float | None = None,
21
+ metadata_filters: dict[str, Any] | None = None,
22
+ ) -> None:
23
+ """Initialize the search filter.
24
+
25
+ Args:
26
+ filepaths: List of valid file paths.
27
+ min_score: Minimum required score.
28
+ metadata_filters: Key-value pairs that must match exactly.
29
+ """
30
+ self.filepaths = filepaths
31
+ self.min_score = min_score
32
+ self.metadata_filters = metadata_filters or {}
33
+
34
+ def apply(self, results: list[SearchResult]) -> list[SearchResult]:
35
+ """Apply filters to a list of results.
36
+
37
+ Args:
38
+ results: List of SearchResult objects.
39
+
40
+ Returns:
41
+ Filtered list of SearchResult objects.
42
+ """
43
+ filtered = results
44
+ if self.min_score is not None:
45
+ filtered = [r for r in filtered if r.score >= self.min_score]
46
+
47
+ if self.filepaths is not None:
48
+ valid_paths = set(self.filepaths)
49
+ filtered = [r for r in filtered if r.filepath in valid_paths]
50
+
51
+ if self.metadata_filters:
52
+ for key, val in self.metadata_filters.items():
53
+ filtered = [r for r in filtered if r.metadata and r.metadata.get(key) == val]
54
+
55
+ return filtered
@@ -0,0 +1,41 @@
1
+ """Models for search results."""
2
+
3
import math
from dataclasses import dataclass, field
from typing import Any
5
+
6
+
7
@dataclass(frozen=True)
class TimestampRange:
    """Represents a time range in an audio file.

    Attributes:
        start: Start time in seconds.
        end: End time in seconds.
    """

    start: float
    end: float

    def __post_init__(self) -> None:
        """Validate timestamp range.

        Raises:
            ValueError: If either bound is NaN, ``start`` is negative, or ``end``
                is before ``start``.
        """
        # NaN compares False against everything, so without this check it would
        # slip past both ordering tests below.
        if math.isnan(self.start) or math.isnan(self.end):
            raise ValueError("Start and end times must not be NaN.")
        if self.start < 0:
            raise ValueError("Start time cannot be negative.")
        if self.end < self.start:
            raise ValueError("End time cannot be less than start time.")


@dataclass(frozen=True)
class SearchResult:
    """A hydrated search result from the engine.

    Attributes:
        filepath: Path to the audio file.
        timestamp_range: The time range within the audio file.
        score: The search score (e.g., cosine similarity).
        metadata: Optional metadata dictionary. It takes part in equality but is
            excluded from hashing (dicts are unhashable), so results can be
            placed in sets and used as dict keys.
    """

    filepath: str
    timestamp_range: TimestampRange
    score: float
    metadata: dict[str, Any] = field(default_factory=dict, hash=False)
@@ -0,0 +1,6 @@
1
+ """Utility modules for EchoVector."""
2
+
3
+ from echovector.utils.config import Config
4
+ from echovector.utils.logging import logger, setup_logger
5
+
6
# Public names re-exported by this module.
__all__ = [
    "Config",
    "logger",
    "setup_logger",
]
@@ -0,0 +1,69 @@
1
+ """Configuration management for EchoVector."""
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+
8
+ class Config:
9
+ """Configuration class to manage settings for EchoVector."""
10
+
11
+ def __init__(self, config_dict: dict[str, Any] | None = None) -> None:
12
+ """Initialize the Config object.
13
+
14
+ Args:
15
+ config_dict: Dictionary containing configuration parameters.
16
+ """
17
+ self._config = config_dict or {}
18
+
19
+ @classmethod
20
+ def from_json(cls, file_path: str | Path) -> "Config":
21
+ """Load configuration from a JSON file.
22
+
23
+ Args:
24
+ file_path: Path to the JSON configuration file.
25
+
26
+ Returns:
27
+ Config instance populated with data from the JSON file.
28
+ """
29
+ with open(file_path, encoding="utf-8") as f:
30
+ data = json.load(f)
31
+ return cls(data)
32
+
33
+ def to_json(self, file_path: str | Path) -> None:
34
+ """Save current configuration to a JSON file.
35
+
36
+ Args:
37
+ file_path: Path where the JSON configuration will be saved.
38
+ """
39
+ with open(file_path, "w", encoding="utf-8") as f:
40
+ json.dump(self._config, f, indent=4)
41
+
42
+ def get(self, key: str, default: Any = None) -> Any:
43
+ """Get a configuration value.
44
+
45
+ Args:
46
+ key: Configuration key.
47
+ default: Default value if key is not found.
48
+
49
+ Returns:
50
+ The value for the specified key or the default value.
51
+ """
52
+ return self._config.get(key, default)
53
+
54
+ def set(self, key: str, value: Any) -> None:
55
+ """Set a configuration value.
56
+
57
+ Args:
58
+ key: Configuration key.
59
+ value: Configuration value.
60
+ """
61
+ self._config[key] = value
62
+
63
+ def update(self, other_config: dict[str, Any]) -> None:
64
+ """Update configuration with another dictionary.
65
+
66
+ Args:
67
+ other_config: Dictionary to update current configuration with.
68
+ """
69
+ self._config.update(other_config)
@@ -0,0 +1,31 @@
1
+ """Logging configuration for EchoVector."""
2
+
3
+ import logging
4
+ import sys
5
+
6
+
7
def setup_logger(name: str = "echovector", level: int = logging.INFO) -> logging.Logger:
    """Return the logger called *name*, set to *level*.

    A stdout handler with a timestamped format is attached only when the logger
    has no handlers yet. Calling this again therefore just updates the level and
    never duplicates output.

    Args:
        name: Name of the logger.
        level: Logging level.

    Returns:
        Configured logger instance.
    """
    log = logging.getLogger(name)
    log.setLevel(level)

    if log.handlers:
        return log

    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(
        logging.Formatter(
            fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    log.addHandler(stdout_handler)
    return log


# Package-wide logger, configured when this module is imported.
logger = setup_logger()