typeagent-py 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- typeagent/aitools/auth.py +61 -0
- typeagent/aitools/embeddings.py +232 -0
- typeagent/aitools/utils.py +244 -0
- typeagent/aitools/vectorbase.py +175 -0
- typeagent/knowpro/answer_context_schema.py +49 -0
- typeagent/knowpro/answer_response_schema.py +34 -0
- typeagent/knowpro/answers.py +577 -0
- typeagent/knowpro/collections.py +759 -0
- typeagent/knowpro/common.py +9 -0
- typeagent/knowpro/convknowledge.py +112 -0
- typeagent/knowpro/convsettings.py +94 -0
- typeagent/knowpro/convutils.py +49 -0
- typeagent/knowpro/date_time_schema.py +32 -0
- typeagent/knowpro/field_helpers.py +87 -0
- typeagent/knowpro/fuzzyindex.py +144 -0
- typeagent/knowpro/interfaces.py +818 -0
- typeagent/knowpro/knowledge.py +88 -0
- typeagent/knowpro/kplib.py +125 -0
- typeagent/knowpro/query.py +1128 -0
- typeagent/knowpro/search.py +628 -0
- typeagent/knowpro/search_query_schema.py +165 -0
- typeagent/knowpro/searchlang.py +729 -0
- typeagent/knowpro/searchlib.py +345 -0
- typeagent/knowpro/secindex.py +100 -0
- typeagent/knowpro/serialization.py +390 -0
- typeagent/knowpro/textlocindex.py +179 -0
- typeagent/knowpro/utils.py +17 -0
- typeagent/mcp/server.py +139 -0
- typeagent/podcasts/podcast.py +473 -0
- typeagent/podcasts/podcast_import.py +105 -0
- typeagent/storage/__init__.py +25 -0
- typeagent/storage/memory/__init__.py +13 -0
- typeagent/storage/memory/collections.py +68 -0
- typeagent/storage/memory/convthreads.py +81 -0
- typeagent/storage/memory/messageindex.py +178 -0
- typeagent/storage/memory/propindex.py +289 -0
- typeagent/storage/memory/provider.py +84 -0
- typeagent/storage/memory/reltermsindex.py +318 -0
- typeagent/storage/memory/semrefindex.py +660 -0
- typeagent/storage/memory/timestampindex.py +176 -0
- typeagent/storage/sqlite/__init__.py +31 -0
- typeagent/storage/sqlite/collections.py +362 -0
- typeagent/storage/sqlite/messageindex.py +382 -0
- typeagent/storage/sqlite/propindex.py +119 -0
- typeagent/storage/sqlite/provider.py +293 -0
- typeagent/storage/sqlite/reltermsindex.py +328 -0
- typeagent/storage/sqlite/schema.py +248 -0
- typeagent/storage/sqlite/semrefindex.py +156 -0
- typeagent/storage/sqlite/timestampindex.py +146 -0
- typeagent/storage/utils.py +41 -0
- typeagent_py-0.1.0.dist-info/METADATA +28 -0
- typeagent_py-0.1.0.dist-info/RECORD +55 -0
- typeagent_py-0.1.0.dist-info/WHEEL +5 -0
- typeagent_py-0.1.0.dist-info/licenses/LICENSE +21 -0
- typeagent_py-0.1.0.dist-info/top_level.txt +1 -0
typeagent/knowpro/serialization.py
@@ -0,0 +1,390 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from dataclasses import is_dataclass, MISSING
from datetime import datetime
import functools
import json
import types
from typing import (
    Annotated,
    Any,
    cast,
    get_args,
    get_origin,
    Literal,
    NotRequired,
    overload,
    TypeAliasType,
    TypedDict,
    Union,
)

import numpy as np
from pydantic.alias_generators import to_camel

from ..aitools.embeddings import NormalizedEmbeddings

from .interfaces import (
    ConversationDataWithIndexes,
    SearchTermGroupTypes,
    Tag,
    Topic,
)
from . import kplib


# -------------------
# Shared definitions
# -------------------


DATA_FILE_SUFFIX = "_data.json"
EMBEDDING_FILE_SUFFIX = "_embeddings.bin"


class FileHeader(TypedDict):
    version: str


# Needed to create a TypedDict.
def create_file_header() -> FileHeader:
    return FileHeader(version="0.1")


class EmbeddingFileHeader(TypedDict):
    relatedCount: NotRequired[int | None]
    messageCount: NotRequired[int | None]


class EmbeddingData(TypedDict):
    embeddings: NormalizedEmbeddings | None


class ConversationJsonData[TMessageData](ConversationDataWithIndexes[TMessageData]):
    fileHeader: NotRequired[FileHeader | None]
    embeddingFileHeader: NotRequired[EmbeddingFileHeader | None]


class ConversationBinaryData(TypedDict):
    embeddingsList: NotRequired[list[NormalizedEmbeddings] | None]


class ConversationFileData[TMessageData](TypedDict):
    # This data goes into a JSON text file
    jsonData: ConversationJsonData[TMessageData]
    # This goes into a single binary file
    binaryData: ConversationBinaryData


# --------------
# Serialization
# --------------


def write_conversation_data_to_file[TMessageData](
    conversation_data: ConversationDataWithIndexes[TMessageData],
    filename: str,
) -> None:
    file_data = to_conversation_file_data(conversation_data)
    binary_data = file_data["binaryData"]
    if binary_data:
        embeddings_list = binary_data.get("embeddingsList")
        if embeddings_list:
            with open(filename + EMBEDDING_FILE_SUFFIX, "wb") as f:
                for embeddings in embeddings_list:
                    embeddings.tofile(f)
    with open(filename + DATA_FILE_SUFFIX, "w", encoding="utf-8") as f:
        # f.write(repr(file_data["jsonData"]))
        json.dump(file_data["jsonData"], f)


def serialize_embeddings(embeddings: NormalizedEmbeddings) -> NormalizedEmbeddings:
    return np.concatenate(embeddings)


def to_conversation_file_data[TMessageData](
    conversation_data: ConversationDataWithIndexes[TMessageData],
) -> ConversationFileData[TMessageData]:
    file_header = create_file_header()
    embedding_file_header = EmbeddingFileHeader()

    embeddings_list: list[NormalizedEmbeddings] = []

    related_terms_index_data = conversation_data.get("relatedTermsIndexData")
    if related_terms_index_data is not None:
        text_embedding_data = related_terms_index_data.get("textEmbeddingData")
        if text_embedding_data is not None:
            embeddings = text_embedding_data.get("embeddings")
            if embeddings is not None:
                embeddings_list.append(embeddings)
                text_embedding_data["embeddings"] = None
                embedding_file_header["relatedCount"] = len(embeddings)

    message_index_data = conversation_data.get("messageIndexData")
    if message_index_data is not None:
        text_embedding_data = message_index_data.get("indexData")
        if text_embedding_data is not None:
            embeddings = text_embedding_data.get("embeddings")
            if embeddings is not None:
                embeddings_list.append(embeddings)
                text_embedding_data["embeddings"] = None
                embedding_file_header["messageCount"] = len(embeddings)

    binary_data = ConversationBinaryData(embeddingsList=embeddings_list)
    json_data = ConversationJsonData(
        **conversation_data,
        fileHeader=file_header,
        embeddingFileHeader=embedding_file_header,
    )
    file_data = ConversationFileData(
        jsonData=json_data,
        binaryData=binary_data,
    )

    return file_data


# This converts any Pydantic dataclass instance to a dict, recursively,
# with field names converted to camelCase.
@overload
def serialize_object(arg: None) -> None: ...
@overload
def serialize_object(arg: object) -> Any:
    # NOTE: This now works specifically with Pydantic dataclasses.
    ...


def serialize_object(arg: Any) -> Any | None:
    if arg is None:
        return None

    # Require Pydantic dataclass
    if not hasattr(arg, "__pydantic_serializer__"):
        raise TypeError(f"Object must be a Pydantic dataclass, got {type(arg)}")

    # Use Pydantic's serialization with aliases
    return arg.__pydantic_serializer__.to_python(arg, by_alias=True)  # type: ignore


# ----------------
# Deserialization
# ----------------


def from_conversation_file_data(
    file_data: ConversationFileData[Any],
) -> ConversationDataWithIndexes[Any]:
    json_data = file_data["jsonData"]
    file_header = json_data.get("fileHeader")
    if file_header is None:
        raise DeserializationError("Missing file header")
    if file_header["version"] != "0.1":
        raise DeserializationError(f"Unsupported file version {file_header['version']}")
    embedding_file_header = json_data.get("embeddingFileHeader")
    if embedding_file_header is None:
        raise DeserializationError("Missing embedding file header")

    binary_data = file_data["binaryData"]
    if binary_data:
        embeddings_list = binary_data.get("embeddingsList")
        if embeddings_list is None:
            raise DeserializationError("Missing embeddings list")
        if len(embeddings_list) != 1:
            raise ValueError(
f"Expected embeddings list of lengt 1, got {len(embeddings_list)}"
            )
        embeddings = embeddings_list[0]
        pos = 0
        pos += get_embeddings_from_binary_data(
            embeddings,
            json_data,
            ("relatedTermsIndexData", "textEmbeddingData"),
            pos,
            embedding_file_header.get("relatedCount"),
        )
        pos += get_embeddings_from_binary_data(
            embeddings,
            json_data,
            ("messageIndexData", "indexData"),
            pos,
            embedding_file_header.get("messageCount"),
        )
    return json_data


def get_embeddings_from_binary_data(
    embeddings: NormalizedEmbeddings,
    json_data: ConversationJsonData[Any],
    keys: tuple[str, ...],
    offset: int,
    count: int | None,
) -> int:
    if count is None or count <= 0:
        return 0
    embeddings = embeddings[offset : offset + count]  # Simple np slice creates a view.
    if len(embeddings) != count:
        raise DeserializationError(
            f"Expected {count} embeddings, got {len(embeddings)}"
        )
    data: dict[str, object] = cast(
        dict[str, object], json_data
    )  # We know it's a dict, but pyright doesn't.
    # Traverse the keys to get to the embeddings.
    for key in keys:
        new_data = data.get(key)
        if new_data is None or type(new_data) is not dict:
            return 0
        data = new_data
    if "embeddings" in data:
        data["embeddings"] = embeddings
    return count


TYPE_MAP = {
    "entity": kplib.ConcreteEntity,
    "action": kplib.Action,
    "topic": Topic,
    "tag": Tag,
}


# Looks like this only works for knowledge...
def deserialize_knowledge(knowledge_type: str, obj: Any) -> Any:
    typ = TYPE_MAP[knowledge_type]
    return deserialize_object(typ, obj)


class DeserializationError(Exception):
    pass


@functools.cache
def is_primitive(typ: type) -> bool:
    return typ in (int, float, bool, str, type(None))


# TODO: Use type(obj) is X instead of isinstance(obj, X). It's faster.
# TODO: Design a consistent reporting format.
# TODO: Doesn't Pydantic have this functionality?
def deserialize_object(typ: Any, obj: Any) -> Any:
    if isinstance(typ, str):
        # A forward reference; special-case those that matter.
        match typ:
            case "SearchTermGroupTypes":
                typ = SearchTermGroupTypes
            case _:
                raise DeserializationError(f"Unknown forward type reference {typ!r}")
    elif typ is datetime:
        # Special case for datetime, which is serialized as a string.
        if isinstance(obj, str):
            try:
                return datetime.fromisoformat(obj)
            except ValueError as e:
                raise DeserializationError(f"Invalid datetime string {obj!r}") from e
        else:
            raise DeserializationError(
                f"Expected datetime string, got {type(obj)}: {obj!r}"
            )
    if typ.__class__ is TypeAliasType:
        typ = typ.__value__
    origin = get_origin(typ)

    # Handle Annotated by substituting its first argument for typ.
    if origin is Annotated:
        typ = get_args(typ)[0]
        if typ.__class__ is TypeAliasType:
            typ = typ.__value__
        origin = get_origin(typ)  # Recompute origin for the substituted type.

    # Non-generic: primitives and dataclasses.
    if origin is None:
        if is_primitive(typ):
            if typ is int and type(obj) is float:
                return int(obj)
            if typ is float and type(obj) is int:
                return float(obj)
            if not isinstance(obj, typ):
                raise DeserializationError(f"Expected {typ} but got {type(obj)}")
            return obj
        elif isinstance(typ, type) and is_dataclass(typ):
            if not isinstance(obj, dict):
                raise DeserializationError(f"Expected dict for {typ}, got {type(obj)}")

            # Require Pydantic dataclass
            if not hasattr(typ, "__pydantic_validator__"):
                raise TypeError(f"Type must be a Pydantic dataclass, got {typ}")

            try:
                # Use Pydantic's validator with aliases
                return typ.__pydantic_validator__.validate_python(obj)  # type: ignore
            except Exception as e:
                raise DeserializationError(
                    f"Pydantic validation failed for {typ.__name__}: {e}"
                ) from e
        else:
            # Could be a class that's not a dataclass -- we don't know the signature.
            raise TypeError(f"Unsupported origin-less type {typ}")

    # Handle Literal.
    if origin is Literal:
        if type(obj) is str and obj in get_args(typ):
            return obj
        raise DeserializationError(
            f"Expected one of {get_args(typ)} for Literal, but got {obj!r} of type {type(obj)}"
        )

    # Handle list[T] / List[T].
    if origin is list:
        if not isinstance(obj, list):
            raise DeserializationError(f"Expected list for list, got {type(obj)}")
        (elem_type,) = get_args(typ)
        return [deserialize_object(elem_type, item) for item in obj]

    # Handle tuple[T1, T2, etc.] / Tuple[T1, T2, etc.].
    if origin is tuple:
        if not isinstance(obj, list):
            raise DeserializationError(f"Expected list for tuple, got {type(obj)}")
        args = get_args(typ)
        if len(args) != len(obj):
            raise DeserializationError(
                f"Tuple length mismatch: expected {len(args)}, got {len(obj)}"
            )
        return tuple(deserialize_object(t, item) for t, item in zip(args, obj))

    # Handle Union[X, Y], Optional[X], and X | Y.
    if origin in (Union, types.UnionType):
        candidates = get_args(typ)
        # Disambiguate among dataclasses if possible.
        dataclass_candidates = [
            c for c in candidates if isinstance(c, type) and is_dataclass(c)
        ]
        if dataclass_candidates and isinstance(obj, dict):
            matching = []
            for candidate in dataclass_candidates:
                mandatory = {
                    to_camel(name)
                    for name, field in candidate.__dataclass_fields__.items()
                    if field.default is MISSING and field.default_factory is MISSING
                }
                if mandatory.issubset(obj.keys()):
                    matching.append(candidate)
            if len(matching) == 1:
                return deserialize_object(matching[0], obj)
            elif len(matching) > 1:
                raise TypeError(
                    f"Ambiguous union {typ}: multiple dataclass candidates match: "
                    + str([c.__name__ for c in matching])
                )
        # Try each candidate until one succeeds.
        all_excs = []
        for candidate in candidates:
            try:
                return deserialize_object(candidate, obj)
            except DeserializationError as e:
                all_excs.append(e)
        raise DeserializationError(
            f"No candidate from union {typ} succeeded -- errors: {all_excs}"
        )

    raise TypeError(f"Unsupported type {typ}, object {obj!r} of type {type(obj)}")
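To make the camelCase round trip concrete, here is a minimal sketch (not part of the package) of serialize_object and deserialize_object working together on a Pydantic dataclass. The Person class is hypothetical, invented only for illustration; the sketch assumes pydantic v2 is installed and configures the dataclass with the same to_camel alias generator the module imports.

import pydantic.dataclasses
from pydantic import ConfigDict
from pydantic.alias_generators import to_camel

from typeagent.knowpro.serialization import deserialize_object, serialize_object


# Hypothetical example type, not from the package.
@pydantic.dataclasses.dataclass(
    config=ConfigDict(alias_generator=to_camel, populate_by_name=True)
)
class Person:
    first_name: str
    birth_year: int | None = None


data = serialize_object(Person(first_name="Ada", birth_year=1815))
assert data == {"firstName": "Ada", "birthYear": 1815}  # field names camelCased

restored = deserialize_object(Person, data)
assert restored == Person(first_name="Ada", birth_year=1815)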
typeagent/knowpro/textlocindex.py
@@ -0,0 +1,179 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from collections.abc import Callable
from dataclasses import dataclass
from typing import Protocol

from ..aitools.embeddings import NormalizedEmbedding
from ..aitools.vectorbase import TextEmbeddingIndexSettings

from .fuzzyindex import ScoredInt, EmbeddingIndex
from .interfaces import (
    TextToTextLocationIndexData,
    TextLocation,
)


@dataclass
class ScoredTextLocation:
    text_location: TextLocation
    score: float


class ITextToTextLocationIndex(Protocol):
    async def add_text_location(
        self, text: str, text_location: TextLocation
    ) -> None: ...

    async def add_text_locations(
        self,
        text_and_locations: list[tuple[str, TextLocation]],
    ) -> None: ...

    async def lookup_text(
        self,
        text: str,
        max_matches: int | None = None,
        threshold_score: float | None = None,
    ) -> list[ScoredTextLocation]: ...

    async def size(self) -> int: ...

    async def is_empty(self) -> bool: ...

    def serialize(self) -> TextToTextLocationIndexData: ...

    def deserialize(self, data: TextToTextLocationIndexData) -> None: ...


class TextToTextLocationIndex(ITextToTextLocationIndex):
    def __init__(self, settings: TextEmbeddingIndexSettings):
        self._text_locations: list[TextLocation] = []
        self._embedding_index: EmbeddingIndex = EmbeddingIndex(settings=settings)
        self._settings = settings

    async def size(self) -> int:
        return await self._embedding_index.size()

    async def is_empty(self) -> bool:
        return await self._embedding_index.is_empty()

    def get(self, pos: int, default: TextLocation | None = None) -> TextLocation | None:
        size = len(self._text_locations)
        if 0 <= pos < size:
            return self._text_locations[pos]
        return default

    async def add_text_location(self, text: str, text_location: TextLocation) -> None:
        await self.add_text_locations([(text, text_location)])

    async def add_text_locations(
        self,
        text_and_locations: list[tuple[str, TextLocation]],
    ) -> None:
        await self._embedding_index.add_texts([text for text, _ in text_and_locations])
        self._text_locations.extend([loc for _, loc in text_and_locations])

    async def lookup_text(
        self,
        text: str,
        max_matches: int | None = None,
        threshold_score: float | None = None,
    ) -> list[ScoredTextLocation]:
        embedding = await self.generate_embedding(text)
        matches = self._embedding_index.get_indexes_of_nearest(
            embedding,
            max_matches=max_matches,
            min_score=threshold_score if threshold_score is not None else 0.85,
        )
        return [
            ScoredTextLocation(self._text_locations[match.item], match.score)
            for match in matches
        ]

    async def lookup_text_in_subset(
        self,
        text: str,
        ordinals_to_search: list[int],
        max_matches: int | None = None,
        threshold_score: float | None = None,
    ) -> list[ScoredTextLocation]:
        embedding = await self.generate_embedding(text)
        matches = self._embedding_index.get_indexes_of_nearest_in_subset(
            embedding,
            ordinals_to_search,
            max_matches,
            threshold_score,
        )
        return [
            ScoredTextLocation(self._text_locations[match.item], match.score)
            for match in matches
        ]

    async def generate_embedding(
        self, text: str, cache: bool = True
    ) -> NormalizedEmbedding:
        return await self._embedding_index.get_embedding(text, cache)

    def lookup_by_embedding(
        self,
        text_embedding: NormalizedEmbedding,
        max_matches: int | None = None,
        threshold_score: float | None = None,
        predicate: Callable[[int], bool] | None = None,
    ) -> list[ScoredTextLocation]:
        matches = self._embedding_index.get_indexes_of_nearest(
            text_embedding,
            max_matches,
            threshold_score,
            predicate,
        )
        return self.to_scored_locations(matches)

    def lookup_in_subset_by_embedding(
        self,
        text_embedding: NormalizedEmbedding,
        ordinals_to_match: list[int],
        max_matches: int | None = None,
        threshold_score: float | None = None,
    ) -> list[ScoredTextLocation]:
        matches = self._embedding_index.get_indexes_of_nearest_in_subset(
            text_embedding,
            ordinals_to_match,
            max_matches,
            threshold_score,
        )
        return self.to_scored_locations(matches)

    def to_scored_locations(self, matches: list[ScoredInt]) -> list[ScoredTextLocation]:
        return [
            ScoredTextLocation(self._text_locations[match.item], match.score)
            for match in matches
        ]

    def clear(self) -> None:
        self._text_locations = []
        self._embedding_index.clear()

    def serialize(self) -> TextToTextLocationIndexData:
        return TextToTextLocationIndexData(
            textLocations=[loc.serialize() for loc in self._text_locations],
            embeddings=self._embedding_index.serialize(),
        )

    def deserialize(self, data: TextToTextLocationIndexData) -> None:
        self._text_locations.clear()
        self._embedding_index.clear()
        text_locations = data["textLocations"]
        embeddings = data["embeddings"]

        if embeddings is None:
            raise ValueError("No embeddings found")
        if len(text_locations) != len(embeddings):
            raise ValueError(
f"TextToTextLocationIndexData corrupt. textLocation.length {len(text_locations)} != {len(embeddings)}"
            )

        self._text_locations = [TextLocation.deserialize(loc) for loc in text_locations]
        self._embedding_index.deserialize(embeddings)
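A hypothetical end-to-end sketch of the in-memory index above (not from the package): it assumes TextEmbeddingIndexSettings can be constructed with defaults, which in practice may require credentials for the underlying embedding model, and it uses the TextLocation(message_ordinal, chunk_ordinal) constructor seen elsewhere in the package.

import asyncio

from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings
from typeagent.knowpro.interfaces import TextLocation
from typeagent.knowpro.textlocindex import TextToTextLocationIndex


async def main() -> None:
    # Assumption: default-constructible settings; adjust for your embedding model.
    index = TextToTextLocationIndex(TextEmbeddingIndexSettings())
    await index.add_text_locations(
        [
            ("the quick brown fox", TextLocation(0, 0)),
            ("jumped over the lazy dog", TextLocation(0, 1)),
        ]
    )
    # lookup_text embeds the query and returns ScoredTextLocation entries;
    # when threshold_score is omitted, min_score falls back to 0.85.
    for scored in await index.lookup_text("a fast fox", max_matches=3):
        print(scored.text_location, scored.score)


asyncio.run(main())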
typeagent/knowpro/utils.py
@@ -0,0 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Utility functions for the knowpro package."""

from .interfaces import MessageOrdinal, TextLocation, TextRange


def text_range_from_message_chunk(
    message_ordinal: MessageOrdinal,
    chunk_ordinal: int = 0,
) -> TextRange:
    """Create a TextRange from message and chunk ordinals."""
    return TextRange(
        start=TextLocation(message_ordinal, chunk_ordinal),
        end=None,
    )