typeagent-py 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- typeagent/aitools/auth.py +61 -0
- typeagent/aitools/embeddings.py +232 -0
- typeagent/aitools/utils.py +244 -0
- typeagent/aitools/vectorbase.py +175 -0
- typeagent/knowpro/answer_context_schema.py +49 -0
- typeagent/knowpro/answer_response_schema.py +34 -0
- typeagent/knowpro/answers.py +577 -0
- typeagent/knowpro/collections.py +759 -0
- typeagent/knowpro/common.py +9 -0
- typeagent/knowpro/convknowledge.py +112 -0
- typeagent/knowpro/convsettings.py +94 -0
- typeagent/knowpro/convutils.py +49 -0
- typeagent/knowpro/date_time_schema.py +32 -0
- typeagent/knowpro/field_helpers.py +87 -0
- typeagent/knowpro/fuzzyindex.py +144 -0
- typeagent/knowpro/interfaces.py +818 -0
- typeagent/knowpro/knowledge.py +88 -0
- typeagent/knowpro/kplib.py +125 -0
- typeagent/knowpro/query.py +1128 -0
- typeagent/knowpro/search.py +628 -0
- typeagent/knowpro/search_query_schema.py +165 -0
- typeagent/knowpro/searchlang.py +729 -0
- typeagent/knowpro/searchlib.py +345 -0
- typeagent/knowpro/secindex.py +100 -0
- typeagent/knowpro/serialization.py +390 -0
- typeagent/knowpro/textlocindex.py +179 -0
- typeagent/knowpro/utils.py +17 -0
- typeagent/mcp/server.py +139 -0
- typeagent/podcasts/podcast.py +473 -0
- typeagent/podcasts/podcast_import.py +105 -0
- typeagent/storage/__init__.py +25 -0
- typeagent/storage/memory/__init__.py +13 -0
- typeagent/storage/memory/collections.py +68 -0
- typeagent/storage/memory/convthreads.py +81 -0
- typeagent/storage/memory/messageindex.py +178 -0
- typeagent/storage/memory/propindex.py +289 -0
- typeagent/storage/memory/provider.py +84 -0
- typeagent/storage/memory/reltermsindex.py +318 -0
- typeagent/storage/memory/semrefindex.py +660 -0
- typeagent/storage/memory/timestampindex.py +176 -0
- typeagent/storage/sqlite/__init__.py +31 -0
- typeagent/storage/sqlite/collections.py +362 -0
- typeagent/storage/sqlite/messageindex.py +382 -0
- typeagent/storage/sqlite/propindex.py +119 -0
- typeagent/storage/sqlite/provider.py +293 -0
- typeagent/storage/sqlite/reltermsindex.py +328 -0
- typeagent/storage/sqlite/schema.py +248 -0
- typeagent/storage/sqlite/semrefindex.py +156 -0
- typeagent/storage/sqlite/timestampindex.py +146 -0
- typeagent/storage/utils.py +41 -0
- typeagent_py-0.1.0.dist-info/METADATA +28 -0
- typeagent_py-0.1.0.dist-info/RECORD +55 -0
- typeagent_py-0.1.0.dist-info/WHEEL +5 -0
- typeagent_py-0.1.0.dist-info/licenses/LICENSE +21 -0
- typeagent_py-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,293 @@
|
|
1
|
+
# Copyright (c) Microsoft Corporation.
|
2
|
+
# Licensed under the MIT License.
|
3
|
+
|
4
|
+
"""SQLite storage provider implementation."""
|
5
|
+
|
6
|
+
import json
|
7
|
+
import sqlite3
|
8
|
+
|
9
|
+
from ...knowpro import interfaces
|
10
|
+
from ...knowpro.convsettings import MessageTextIndexSettings, RelatedTermIndexSettings
|
11
|
+
from .collections import SqliteMessageCollection, SqliteSemanticRefCollection
|
12
|
+
from .messageindex import SqliteMessageTextIndex
|
13
|
+
from .propindex import SqlitePropertyIndex
|
14
|
+
from .reltermsindex import SqliteRelatedTermsIndex
|
15
|
+
from .semrefindex import SqliteTermToSemanticRefIndex
|
16
|
+
from .timestampindex import SqliteTimestampToTextRangeIndex
|
17
|
+
from .schema import (
|
18
|
+
CONVERSATIONS_SCHEMA,
|
19
|
+
MESSAGE_TEXT_INDEX_SCHEMA,
|
20
|
+
MESSAGES_SCHEMA,
|
21
|
+
PROPERTY_INDEX_SCHEMA,
|
22
|
+
RELATED_TERMS_ALIASES_SCHEMA,
|
23
|
+
RELATED_TERMS_FUZZY_SCHEMA,
|
24
|
+
SEMANTIC_REF_INDEX_SCHEMA,
|
25
|
+
SEMANTIC_REFS_SCHEMA,
|
26
|
+
ConversationMetadata,
|
27
|
+
get_db_schema_version,
|
28
|
+
init_db_schema,
|
29
|
+
)
|
30
|
+
|
31
|
+
|
32
|
+
class SqliteStorageProvider[TMessage: interfaces.IMessage](
    interfaces.IStorageProvider[TMessage]
):
    """SQLite-backed storage provider implementation.

    Owns a single sqlite3 connection and builds all collections and indexes
    (messages, semantic refs, term/property/timestamp/message-text/related-terms
    indexes) on top of it.  Connection lifetime: `close()` commits and closes;
    `__del__` rolls back and closes.
    """

    def __init__(
        self,
        db_path: str = ":memory:",
        conversation_id: str = "default",
        message_type: type[TMessage] = None,  # type: ignore
        semantic_ref_type: type[interfaces.SemanticRef] = None,  # type: ignore
        conversation_index_settings=None,
        message_text_index_settings: MessageTextIndexSettings | None = None,
        related_term_index_settings: RelatedTermIndexSettings | None = None,
    ):
        """Open (or create) the database at *db_path* and wire up all indexes.

        Args:
            db_path: SQLite path; defaults to an in-memory database.
            conversation_id: Identifier used when seeding conversation metadata.
            message_type: Concrete message class used to deserialize rows.
            semantic_ref_type: Semantic-ref class (stored but not used here).
            conversation_index_settings: Opaque settings object; only stored
                (falls back to an empty dict).  # NOTE(review): type unclear from here.
            message_text_index_settings: Embedding settings for the message
                text index; a default model is created when omitted.
            related_term_index_settings: Settings for the related-terms index;
                defaults to reusing the message-text embedding settings.
        """
        self.db_path = db_path
        self.conversation_id = conversation_id
        self.message_type = message_type
        self.semantic_ref_type = semantic_ref_type

        # Settings with defaults (require embedding settings)
        self.conversation_index_settings = conversation_index_settings or {}
        if message_text_index_settings is None:
            # Create default embedding settings if not provided.
            # Imported lazily so the default model is only constructed on demand.
            from ...aitools.embeddings import AsyncEmbeddingModel
            from ...aitools.vectorbase import TextEmbeddingIndexSettings

            model = AsyncEmbeddingModel()
            embedding_settings = TextEmbeddingIndexSettings(model)
            message_text_index_settings = MessageTextIndexSettings(embedding_settings)
        self.message_text_index_settings = message_text_index_settings

        if related_term_index_settings is None:
            # Use the same embedding settings so both indexes share one model.
            embedding_settings = message_text_index_settings.embedding_index_settings
            related_term_index_settings = RelatedTermIndexSettings(embedding_settings)
        self.related_term_index_settings = related_term_index_settings

        # Initialize database connection
        self.db = sqlite3.connect(db_path)

        # Configure SQLite for optimal bulk insertion performance
        self.db.execute("PRAGMA foreign_keys = ON")
        # Improve write performance for bulk operations
        self.db.execute("PRAGMA synchronous = NORMAL")  # Faster than FULL, still safe
        self.db.execute(
            "PRAGMA journal_mode = WAL"
        )  # Write-Ahead Logging for better concurrency
        self.db.execute("PRAGMA cache_size = -64000")  # 64MB cache (negative = KB)
        self.db.execute("PRAGMA temp_store = MEMORY")  # Store temp tables in memory
        self.db.execute("PRAGMA mmap_size = 268435456")  # 256MB memory-mapped I/O

        # Initialize schema
        init_db_schema(self.db)

        # Initialize collections
        # Initialize message collection first
        self._message_collection = SqliteMessageCollection(self.db, self.message_type)
        self._semantic_ref_collection = SqliteSemanticRefCollection(self.db)

        # Initialize indexes
        self._term_to_semantic_ref_index = SqliteTermToSemanticRefIndex(self.db)
        self._property_index = SqlitePropertyIndex(self.db)
        self._timestamp_index = SqliteTimestampToTextRangeIndex(self.db)
        self._message_text_index = SqliteMessageTextIndex(
            self.db,
            self.message_text_index_settings,
            self._message_collection,
        )
        # Initialize related terms index
        self._related_terms_index = SqliteRelatedTermsIndex(
            self.db, self.related_term_index_settings.embedding_index_settings
        )

        # Connect message collection to message text index for automatic indexing
        self._message_collection.set_message_text_index(self._message_text_index)

    async def close(self) -> None:
        """Close the database connection. COMMITS."""
        # hasattr guard: close() may be called twice, or after __del__ ran.
        if hasattr(self, "db"):
            self.db.commit()
            self.db.close()
            # Delete the attribute so repeated close()/__del__ are no-ops.
            del self.db

    def __del__(self) -> None:
        """Ensure database is closed when object is deleted. ROLLS BACK."""
        # Can't use async in __del__, so close directly
        if hasattr(self, "db"):
            self.db.rollback()
            self.db.close()
            del self.db

    # Synchronous accessors for the underlying collections/indexes.

    @property
    def messages(self) -> SqliteMessageCollection[TMessage]:
        return self._message_collection

    @property
    def semantic_refs(self) -> SqliteSemanticRefCollection:
        return self._semantic_ref_collection

    @property
    def term_to_semantic_ref_index(self) -> SqliteTermToSemanticRefIndex:
        return self._term_to_semantic_ref_index

    @property
    def property_index(self) -> SqlitePropertyIndex:
        return self._property_index

    @property
    def timestamp_index(self) -> SqliteTimestampToTextRangeIndex:
        return self._timestamp_index

    @property
    def message_text_index(self) -> SqliteMessageTextIndex:
        return self._message_text_index

    @property
    def related_terms_index(self) -> SqliteRelatedTermsIndex:
        return self._related_terms_index

    # Async getters required by base class
    async def get_message_collection(
        self, message_type: type[TMessage] | None = None
    ) -> interfaces.IMessageCollection[TMessage]:
        """Get the message collection.

        Note: the *message_type* argument is accepted for interface
        compatibility but ignored here; the type fixed at __init__ is used.
        """
        return self._message_collection

    async def get_semantic_ref_collection(self) -> interfaces.ISemanticRefCollection:
        """Get the semantic reference collection."""
        return self._semantic_ref_collection

    async def get_semantic_ref_index(self) -> interfaces.ITermToSemanticRefIndex:
        """Get the semantic reference index."""
        return self._term_to_semantic_ref_index

    async def get_property_index(self) -> interfaces.IPropertyToSemanticRefIndex:
        """Get the property index."""
        return self._property_index

    async def get_timestamp_index(self) -> interfaces.ITimestampToTextRangeIndex:
        """Get the timestamp index."""
        return self._timestamp_index

    async def get_message_text_index(self) -> interfaces.IMessageTextIndex[TMessage]:
        """Get the message text index."""
        return self._message_text_index

    async def get_related_terms_index(self) -> interfaces.ITermToRelatedTermsIndex:
        """Get the related terms index."""
        return self._related_terms_index

    async def get_conversation_threads(self) -> interfaces.IConversationThreads:
        """Get the conversation threads."""
        # For now, return a simple implementation
        # In a full implementation, this would be stored/retrieved from SQLite
        # NOTE(review): returns a fresh in-memory object on every call, so
        # thread state is NOT persisted across calls or processes.
        from ...storage.memory.convthreads import ConversationThreads

        return ConversationThreads(
            self.message_text_index_settings.embedding_index_settings
        )

    async def clear(self) -> None:
        """Clear all data from the storage provider.

        Deletes rows from every table and clears the in-memory portion of the
        message text index.  NOTE(review): does not commit; the deletes become
        durable only on the next commit/close.
        """
        cursor = self.db.cursor()
        # Clear in reverse dependency order
        cursor.execute("DELETE FROM RelatedTermsFuzzy")
        cursor.execute("DELETE FROM RelatedTermsAliases")
        cursor.execute("DELETE FROM MessageTextIndex")
        cursor.execute("DELETE FROM PropertyIndex")
        cursor.execute("DELETE FROM SemanticRefIndex")
        cursor.execute("DELETE FROM SemanticRefs")
        cursor.execute("DELETE FROM Messages")
        cursor.execute("DELETE FROM ConversationMetadata")

        # Clear in-memory indexes
        await self._message_text_index.clear()

    def serialize(self) -> dict:
        """Serialize all storage provider data.

        NOTE(review): only the two index payloads are included; messages,
        semantic refs, and the message text index are not serialized here,
        so this is a partial snapshot (asymmetric with deserialize()).
        """
        return {
            "termToSemanticRefIndexData": self._term_to_semantic_ref_index.serialize(),
            "relatedTermsIndexData": self._related_terms_index.serialize(),
        }

    async def deserialize(self, data: dict) -> None:
        """Deserialize storage provider data.

        Each section is optional; missing/empty keys are skipped.
        """
        # Deserialize term to semantic ref index
        if data.get("termToSemanticRefIndexData"):
            await self._term_to_semantic_ref_index.deserialize(
                data["termToSemanticRefIndexData"]
            )

        # Deserialize related terms index
        if data.get("relatedTermsIndexData"):
            await self._related_terms_index.deserialize(data["relatedTermsIndexData"])

        # Deserialize message text index
        if data.get("messageIndexData"):
            await self._message_text_index.deserialize(data["messageIndexData"])

    def get_conversation_metadata(self) -> ConversationMetadata | None:
        """Get conversation metadata.

        Returns the first (assumed only) row of ConversationMetadata, or None
        when the table is empty.  `tags`/`extra` are stored as JSON text and
        decoded here (NULL -> [] / {}).
        """
        cursor = self.db.cursor()
        cursor.execute(
            "SELECT name_tag, schema_version, created_at, updated_at, tags, extra FROM ConversationMetadata LIMIT 1"
        )
        row = cursor.fetchone()
        if row:
            return ConversationMetadata(
                name_tag=row[0],
                schema_version=row[1],
                created_at=row[2],
                updated_at=row[3],
                tags=json.loads(row[4]) if row[4] else [],
                extra=json.loads(row[5]) if row[5] else {},
            )
        return None

    def update_conversation_metadata(
        self, created_at: str | None = None, updated_at: str | None = None
    ) -> None:
        """Update conversation metadata.

        Updates the timestamp columns of the existing metadata row, or inserts
        a fresh row (with default name/version and empty tags/extra) when none
        exists.  NOTE(review): the UPDATE has no WHERE clause — it assumes the
        table holds at most one row.
        """
        cursor = self.db.cursor()

        # Check if conversation metadata exists
        cursor.execute("SELECT 1 FROM ConversationMetadata LIMIT 1")

        if cursor.fetchone():
            # Update existing — only the columns actually supplied.
            updates = []
            params = []
            if created_at is not None:
                updates.append("created_at = ?")
                params.append(created_at)
            if updated_at is not None:
                updates.append("updated_at = ?")
                params.append(updated_at)

            if updates:
                cursor.execute(
                    f"UPDATE ConversationMetadata SET {', '.join(updates)}",
                    params,
                )
        else:
            # Insert new with default values
            name_tag = f"conversation_{self.conversation_id}"
            schema_version = "1.0"
            tags = json.dumps([])
            extra = json.dumps({})

            cursor.execute(
                "INSERT INTO ConversationMetadata (name_tag, schema_version, created_at, updated_at, tags, extra) VALUES (?, ?, ?, ?, ?, ?)",
                (name_tag, schema_version, created_at, updated_at, tags, extra),
            )

    def get_db_version(self) -> int:
        """Get the database schema version.

        Parses the major component of a dotted version string; falls back to 1
        when the stored value is missing or malformed.
        """
        version_str = get_db_schema_version(self.db)
        try:
            return int(version_str.split(".")[0])  # Get major version as int
        except (ValueError, AttributeError):
            return 1  # Default version
|
@@ -0,0 +1,328 @@
|
|
1
|
+
# Copyright (c) Microsoft Corporation.
|
2
|
+
# Licensed under the MIT License.
|
3
|
+
|
4
|
+
"""SQLite-based related terms index implementations."""
|
5
|
+
|
6
|
+
import sqlite3
|
7
|
+
|
8
|
+
from ...aitools.embeddings import AsyncEmbeddingModel, NormalizedEmbeddings
|
9
|
+
from ...aitools.vectorbase import TextEmbeddingIndexSettings, VectorBase
|
10
|
+
from ...knowpro import interfaces
|
11
|
+
|
12
|
+
from .schema import serialize_embedding, deserialize_embedding
|
13
|
+
|
14
|
+
|
15
|
+
class SqliteRelatedTermsAliases(interfaces.ITermToRelatedTerms):
    """SQLite-backed implementation of term to related terms aliases.

    All state lives in the RelatedTermsAliases table (term, alias columns);
    the instance holds only the database connection.
    """

    def __init__(self, db: sqlite3.Connection):
        self.db = db

    async def lookup_term(self, text: str) -> list[interfaces.Term] | None:
        """Return the aliases stored for *text*, or None when there are none."""
        cursor = self.db.cursor()
        cursor.execute("SELECT alias FROM RelatedTermsAliases WHERE term = ?", (text,))
        results = [interfaces.Term(row[0]) for row in cursor.fetchall()]
        return results if results else None

    async def add_related_term(
        self, text: str, related_terms: interfaces.Term | list[interfaces.Term]
    ) -> None:
        """Add one or more aliases for *text*; duplicate pairs are ignored."""
        # Convert single Term to list
        if isinstance(related_terms, interfaces.Term):
            related_terms = [related_terms]

        cursor = self.db.cursor()
        # Batch all rows in one executemany call (was one execute per alias);
        # INSERT OR IGNORE avoids duplicate (term, alias) rows.
        cursor.executemany(
            "INSERT OR IGNORE INTO RelatedTermsAliases (term, alias) VALUES (?, ?)",
            [(text, related_term.text) for related_term in related_terms],
        )

    async def remove_term(self, text: str) -> None:
        """Delete all aliases recorded for *text*."""
        cursor = self.db.cursor()
        cursor.execute("DELETE FROM RelatedTermsAliases WHERE term = ?", (text,))

    async def clear(self) -> None:
        """Delete every alias row."""
        cursor = self.db.cursor()
        cursor.execute("DELETE FROM RelatedTermsAliases")

    async def set_related_terms(self, term: str, related_terms: list[str]) -> None:
        """Replace the alias list for *term* with *related_terms*."""
        cursor = self.db.cursor()
        # Clear existing aliases for this term, then bulk-insert the new set.
        cursor.execute("DELETE FROM RelatedTermsAliases WHERE term = ?", (term,))
        # INSERT OR IGNORE for consistency with add_related_term: duplicate
        # aliases in the input no longer trip a uniqueness constraint.
        cursor.executemany(
            "INSERT OR IGNORE INTO RelatedTermsAliases (term, alias) VALUES (?, ?)",
            [(term, alias) for alias in related_terms],
        )

    async def size(self) -> int:
        """Number of distinct terms that have at least one alias."""
        cursor = self.db.cursor()
        cursor.execute("SELECT COUNT(DISTINCT term) FROM RelatedTermsAliases")
        return cursor.fetchone()[0]

    async def get_terms(self) -> list[str]:
        """All distinct terms, sorted ascending."""
        cursor = self.db.cursor()
        cursor.execute("SELECT DISTINCT term FROM RelatedTermsAliases ORDER BY term")
        return [row[0] for row in cursor.fetchall()]

    async def is_empty(self) -> bool:
        """True when the table has no rows at all."""
        cursor = self.db.cursor()
        cursor.execute("SELECT COUNT(*) FROM RelatedTermsAliases")
        return cursor.fetchone()[0] == 0

    async def serialize(self) -> interfaces.TermToRelatedTermsData:
        """Serialize the aliases data."""
        cursor = self.db.cursor()
        cursor.execute(
            "SELECT term, alias FROM RelatedTermsAliases ORDER BY term, alias"
        )

        # Group aliases by term (dict preserves the sorted row order).
        term_to_aliases: dict[str, list[str]] = {}
        for term, alias in cursor.fetchall():
            term_to_aliases.setdefault(term, []).append(alias)

        # Convert to the expected wire format.
        items = [
            interfaces.TermsToRelatedTermsDataItem(
                termText=term,
                relatedTerms=[interfaces.TermData(text=alias) for alias in aliases],
            )
            for term, aliases in term_to_aliases.items()
        ]
        return interfaces.TermToRelatedTermsData(relatedTerms=items)

    async def deserialize(self, data: interfaces.TermToRelatedTermsData | None) -> None:
        """Replace the table contents with *data* (None just clears the table)."""
        cursor = self.db.cursor()

        # Clear existing data before loading the snapshot.
        cursor.execute("DELETE FROM RelatedTermsAliases")

        if data is None:
            return

        # Flatten the nested structure into (term, alias) pairs so a single
        # executemany can load everything; items missing a term or alias list
        # are skipped, matching the previous behavior.
        insertion_data = [
            (item["termText"], term_data["text"])
            for item in data.get("relatedTerms", [])
            if item and item.get("termText") and item.get("relatedTerms")
            for term_data in item["relatedTerms"]
        ]

        # Bulk insert all the data
        if insertion_data:
            cursor.executemany(
                "INSERT INTO RelatedTermsAliases (term, alias) VALUES (?, ?)",
                insertion_data,
            )
|
130
|
+
|
131
|
+
|
132
|
+
class SqliteRelatedTermsFuzzy(interfaces.ITermToRelatedTermsFuzzy):
    """SQLite-backed implementation of fuzzy term relationships with persistent embeddings.

    Embeddings are persisted in the RelatedTermsFuzzy table and mirrored into
    an in-memory VectorBase for similarity search.  `_terms_list` maps
    VectorBase ordinals back to term text; `_added_terms` deduplicates
    additions.  These three in-memory structures must stay in sync with the
    table.
    """

    def __init__(self, db: sqlite3.Connection, settings: TextEmbeddingIndexSettings):
        self.db = db
        self._embedding_settings = settings
        self._vector_base = VectorBase(self._embedding_settings)
        # Maintain our own list of terms to map ordinals back to keys
        self._terms_list: list[str] = []  # TODO: Use the database instead?
        self._added_terms: set[str] = set()  # TODO: Ditto?
        # If items exist in the db, copy them into the VectorBase, terms list, and added terms
        if self._size() > 0:
            cursor = self.db.cursor()
            cursor.execute(
                "SELECT term, term_embedding FROM RelatedTermsFuzzy ORDER BY term"
            )
            rows = cursor.fetchall()
            for term, blob in rows:
                assert blob is not None, term
                embedding: NormalizedEmbeddings = deserialize_embedding(blob)
                # Add to VectorBase at the correct ordinal (same order as
                # _terms_list so ordinals map back to the right term).
                self._vector_base.add_embedding(term, embedding)
                self._terms_list.append(term)
                self._added_terms.add(term)

    async def lookup_term(
        self,
        text: str,
        max_hits: int | None = None,
        min_score: float | None = None,
    ) -> list[interfaces.Term]:
        """Look up similar terms using fuzzy matching.

        Returns scored Terms ordered as produced by the VectorBase search.
        """
        # Search for similar terms using VectorBase
        similar_results = await self._vector_base.fuzzy_lookup(
            text, max_hits=max_hits, min_score=min_score
        )

        # Convert VectorBase results to Term objects
        results = []
        for scored_int in similar_results:
            # Map the ordinal back to its term text.  # TODO: Use the database instead?
            if scored_int.item < len(self._terms_list):
                term_text = self._terms_list[scored_int.item]
                results.append(interfaces.Term(term_text, scored_int.score))

        return results

    async def remove_term(self, term: str) -> None:
        # Removing a single term would leave VectorBase ordinals out of sync
        # with _terms_list; a correct implementation must rebuild or compact
        # all three in-memory structures alongside the DELETE.
        raise NotImplementedError(
            "TODO: Removal from VectorBase, _terms_list, _terms_to_ordinal"
        )

    async def clear(self) -> None:
        """Remove all fuzzy terms from both the database and memory."""
        cursor = self.db.cursor()
        cursor.execute("DELETE FROM RelatedTermsFuzzy")
        # Reset the in-memory mirrors as well.  Previously only the table was
        # cleared, so lookup_term/add_terms kept serving stale entries after
        # clear().
        self._vector_base.clear()
        self._terms_list.clear()
        self._added_terms.clear()

    async def size(self) -> int:
        """Number of terms stored in the database."""
        return self._size()

    def _size(self) -> int:
        # Synchronous helper so __init__ can use it too.
        cursor = self.db.cursor()
        cursor.execute("SELECT COUNT(term) FROM RelatedTermsFuzzy")
        return cursor.fetchone()[0]

    async def get_terms(self) -> list[str]:
        """All stored terms, sorted ascending."""
        cursor = self.db.cursor()
        cursor.execute("SELECT term FROM RelatedTermsFuzzy ORDER BY term")
        return [row[0] for row in cursor.fetchall()]

    async def add_terms(self, texts: list[str]) -> None:
        """Add terms, skipping any already present; embeddings are persisted."""
        cursor = self.db.cursor()
        # TODO: Batch additions to database
        for text in texts:
            if text in self._added_terms:
                continue

            # Add to VectorBase for fuzzy lookup
            await self._vector_base.add_key(text)
            self._terms_list.append(text)
            self._added_terms.add(text)

            # Generate embedding for term and store in database
            embedding = await self._vector_base.get_embedding(text)  # Cached
            serialized_embedding = serialize_embedding(embedding)
            # Insert term and embedding
            cursor.execute(
                """
                INSERT OR REPLACE INTO RelatedTermsFuzzy
                (term, term_embedding)
                VALUES (?, ?)
                """,
                (text, serialized_embedding),
            )

    async def lookup_terms(
        self,
        texts: list[str],
        max_hits: int | None = None,
        min_score: float | None = None,
    ) -> list[list[interfaces.Term]]:
        """Look up multiple terms at once (one result list per input)."""
        # TODO: Some kind of batching?
        results = []
        for text in texts:
            term_results = await self.lookup_term(text, max_hits, min_score)
            results.append(term_results)
        return results

    async def deserialize(self, data: interfaces.TextEmbeddingIndexData) -> None:
        """Deserialize fuzzy index data from JSON into SQLite database.

        Replaces both the table contents and the in-memory mirrors.
        """
        # Clear existing data
        cursor = self.db.cursor()
        cursor.execute("DELETE FROM RelatedTermsFuzzy")

        # Clear local mappings
        self._terms_list.clear()
        self._added_terms.clear()

        # Get text items and embeddings from the data
        text_items = data.get("textItems")
        embeddings_data = data.get("embeddings")

        if not text_items or embeddings_data is None:
            return

        # Use persistent VectorBase to deserialize embeddings (preserves caching)
        self._vector_base.deserialize(embeddings_data)

        # Prepare all insertion data for bulk operation.
        # (serialize_embedding is already imported at module level.)
        insertion_data = []
        for i, text in enumerate(text_items):
            if i < len(self._vector_base):
                # Get embedding from persistent VectorBase
                embedding = self._vector_base.get_embedding_at(i)
                if embedding is not None:
                    serialized_embedding = serialize_embedding(embedding)
                    # Insert as self-referential entry with only term_embedding
                    insertion_data.append((text, serialized_embedding))
                    # Update local mappings
                    self._terms_list.append(text)
                    self._added_terms.add(text)

        # Bulk insert all the data
        if insertion_data:
            cursor.executemany(
                """
                INSERT OR REPLACE INTO RelatedTermsFuzzy
                (term, term_embedding)
                VALUES (?, ?)
                """,
                insertion_data,
            )
|
296
|
+
|
297
|
+
|
298
|
+
class SqliteRelatedTermsIndex(interfaces.ITermToRelatedTermsIndex):
    """SQLite-backed ITermToRelatedTermsIndex.

    Composes an exact-match alias table with an embedding-based fuzzy index,
    both sharing the same database connection.
    """

    def __init__(self, db: sqlite3.Connection, settings: TextEmbeddingIndexSettings):
        self.db = db
        # Exact aliases and fuzzy (embedding) lookup over one connection.
        self._aliases = SqliteRelatedTermsAliases(db)
        self._fuzzy_index = SqliteRelatedTermsFuzzy(db, settings)

    @property
    def aliases(self) -> interfaces.ITermToRelatedTerms:
        """The exact-match alias sub-index."""
        return self._aliases

    @property
    def fuzzy_index(self) -> interfaces.ITermToRelatedTermsFuzzy | None:
        """The embedding-based fuzzy sub-index."""
        return self._fuzzy_index

    async def serialize(self) -> interfaces.TermsToRelatedTermsIndexData:
        raise NotImplementedError("TODO")

    async def deserialize(self, data: interfaces.TermsToRelatedTermsIndexData) -> None:
        """Restore alias and fuzzy-index state from serialized data.

        Each section is optional and skipped when absent.
        """
        if (alias_data := data.get("aliasData")) is not None:
            await self._aliases.deserialize(alias_data)
        if (embedding_data := data.get("textEmbeddingData")) is not None:
            await self._fuzzy_index.deserialize(embedding_data)
|