typeagent-py 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- typeagent/aitools/auth.py +61 -0
- typeagent/aitools/embeddings.py +232 -0
- typeagent/aitools/utils.py +244 -0
- typeagent/aitools/vectorbase.py +175 -0
- typeagent/knowpro/answer_context_schema.py +49 -0
- typeagent/knowpro/answer_response_schema.py +34 -0
- typeagent/knowpro/answers.py +577 -0
- typeagent/knowpro/collections.py +759 -0
- typeagent/knowpro/common.py +9 -0
- typeagent/knowpro/convknowledge.py +112 -0
- typeagent/knowpro/convsettings.py +94 -0
- typeagent/knowpro/convutils.py +49 -0
- typeagent/knowpro/date_time_schema.py +32 -0
- typeagent/knowpro/field_helpers.py +87 -0
- typeagent/knowpro/fuzzyindex.py +144 -0
- typeagent/knowpro/interfaces.py +818 -0
- typeagent/knowpro/knowledge.py +88 -0
- typeagent/knowpro/kplib.py +125 -0
- typeagent/knowpro/query.py +1128 -0
- typeagent/knowpro/search.py +628 -0
- typeagent/knowpro/search_query_schema.py +165 -0
- typeagent/knowpro/searchlang.py +729 -0
- typeagent/knowpro/searchlib.py +345 -0
- typeagent/knowpro/secindex.py +100 -0
- typeagent/knowpro/serialization.py +390 -0
- typeagent/knowpro/textlocindex.py +179 -0
- typeagent/knowpro/utils.py +17 -0
- typeagent/mcp/server.py +139 -0
- typeagent/podcasts/podcast.py +473 -0
- typeagent/podcasts/podcast_import.py +105 -0
- typeagent/storage/__init__.py +25 -0
- typeagent/storage/memory/__init__.py +13 -0
- typeagent/storage/memory/collections.py +68 -0
- typeagent/storage/memory/convthreads.py +81 -0
- typeagent/storage/memory/messageindex.py +178 -0
- typeagent/storage/memory/propindex.py +289 -0
- typeagent/storage/memory/provider.py +84 -0
- typeagent/storage/memory/reltermsindex.py +318 -0
- typeagent/storage/memory/semrefindex.py +660 -0
- typeagent/storage/memory/timestampindex.py +176 -0
- typeagent/storage/sqlite/__init__.py +31 -0
- typeagent/storage/sqlite/collections.py +362 -0
- typeagent/storage/sqlite/messageindex.py +382 -0
- typeagent/storage/sqlite/propindex.py +119 -0
- typeagent/storage/sqlite/provider.py +293 -0
- typeagent/storage/sqlite/reltermsindex.py +328 -0
- typeagent/storage/sqlite/schema.py +248 -0
- typeagent/storage/sqlite/semrefindex.py +156 -0
- typeagent/storage/sqlite/timestampindex.py +146 -0
- typeagent/storage/utils.py +41 -0
- typeagent_py-0.1.0.dist-info/METADATA +28 -0
- typeagent_py-0.1.0.dist-info/RECORD +55 -0
- typeagent_py-0.1.0.dist-info/WHEEL +5 -0
- typeagent_py-0.1.0.dist-info/licenses/LICENSE +21 -0
- typeagent_py-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1128 @@
|
|
1
|
+
# Copyright (c) Microsoft Corporation.
|
2
|
+
# Licensed under the MIT License.
|
3
|
+
|
4
|
+
from abc import ABC, abstractmethod
|
5
|
+
from collections.abc import Callable
|
6
|
+
from dataclasses import dataclass, field
|
7
|
+
from re import search
|
8
|
+
from typing import Literal, Protocol, cast
|
9
|
+
|
10
|
+
from ..aitools.embeddings import NormalizedEmbedding
|
11
|
+
|
12
|
+
from .collections import (
|
13
|
+
Match,
|
14
|
+
MatchAccumulator,
|
15
|
+
MessageAccumulator,
|
16
|
+
PropertyTermSet,
|
17
|
+
SemanticRefAccumulator,
|
18
|
+
TermSet,
|
19
|
+
TextRangeCollection,
|
20
|
+
TextRangesInScope,
|
21
|
+
)
|
22
|
+
from .common import is_search_term_wildcard
|
23
|
+
from .interfaces import (
|
24
|
+
Datetime,
|
25
|
+
DateRange,
|
26
|
+
IConversation,
|
27
|
+
IMessage,
|
28
|
+
IMessageCollection,
|
29
|
+
IPropertyToSemanticRefIndex,
|
30
|
+
ISemanticRefCollection,
|
31
|
+
ITermToSemanticRefIndex,
|
32
|
+
ITimestampToTextRangeIndex,
|
33
|
+
KnowledgeType,
|
34
|
+
MessageOrdinal,
|
35
|
+
PropertySearchTerm,
|
36
|
+
ScoredMessageOrdinal,
|
37
|
+
ScoredSemanticRefOrdinal,
|
38
|
+
SearchTerm,
|
39
|
+
SearchTermGroup,
|
40
|
+
SemanticRef,
|
41
|
+
SemanticRefOrdinal,
|
42
|
+
SemanticRefSearchResult,
|
43
|
+
Term,
|
44
|
+
TextLocation,
|
45
|
+
TextRange,
|
46
|
+
Thread,
|
47
|
+
)
|
48
|
+
from .kplib import ConcreteEntity
|
49
|
+
from ..storage.memory.messageindex import IMessageTextEmbeddingIndex
|
50
|
+
from ..storage.memory.propindex import PropertyNames, lookup_property_in_property_index
|
51
|
+
from .searchlib import create_property_search_term, create_tag_search_term_group
|
52
|
+
|
53
|
+
|
54
|
+
# TODO: Move to compilelib.py
# Boolean combinators for compiled term groups: "and" intersects,
# "or" unions, "or_max" unions then keeps only the best-covered matches.
type BooleanOp = Literal["and", "or", "or_max"]
|
56
|
+
|
57
|
+
|
58
|
+
# TODO: Move to compilelib.py
@dataclass
class CompiledSearchTerm(SearchTerm):
    """A SearchTerm augmented with compilation state.

    Compiled terms are SearchTerm instances mutated in place (see
    to_required_search_term / to_non_required_search_term), so existing
    references — including related_terms — remain valid.
    """

    # True when this term's related terms must also match for the term to count.
    related_terms_required: bool = False
|
62
|
+
|
63
|
+
|
64
|
+
# TODO: Move to compilelib.py
def to_required_search_term(term: SearchTerm) -> CompiledSearchTerm:
    """Mark *term* as requiring its related terms and return it.

    The result is the SAME object as the input, only cast to
    CompiledSearchTerm — it must alias the input, otherwise later
    assignments to related_terms would be lost.
    """
    compiled = cast(CompiledSearchTerm, term)
    compiled.related_terms_required = True
    return compiled
|
71
|
+
|
72
|
+
|
73
|
+
# TODO: Move to compilelib.py
def to_non_required_search_term(term: SearchTerm) -> CompiledSearchTerm:
    """Mark *term* as NOT requiring its related terms and return it.

    The result is the SAME object as the input, only cast to
    CompiledSearchTerm — it must alias the input, otherwise later
    assignments to related_terms would be lost.
    """
    compiled = cast(CompiledSearchTerm, term)
    compiled.related_terms_required = False
    return compiled
|
80
|
+
|
81
|
+
|
82
|
+
# TODO: Move to compilelib.py
@dataclass
class CompiledTermGroup:
    """A group of compiled search terms joined by one boolean operator."""

    # How the terms combine: "and", "or", or "or_max".
    boolean_op: BooleanOp
    # The compiled terms participating in this group.
    terms: list[CompiledSearchTerm]
|
87
|
+
|
88
|
+
|
89
|
+
def is_conversation_searchable(conversation: IConversation) -> bool:
    """Report whether a conversation can be searched.

    A conversation is searchable when both its semantic reference index
    and its semantic reference collection are initialized.
    """
    # TODO: also require secondary indices, once we have removed non-index based retrieval to test.
    has_index = conversation.semantic_ref_index is not None
    has_refs = conversation.semantic_refs is not None
    return has_index and has_refs
|
100
|
+
|
101
|
+
|
102
|
+
async def get_text_range_for_date_range(
    conversation: IConversation,
    date_range: DateRange,
) -> TextRange | None:
    """Return the text range covering messages whose timestamps fall in *date_range*.

    Scans messages in ordinal order and returns a single contiguous range
    [first match, last match + 1), or None when no message matches.

    NOTE(review): assumes messages are ordered by timestamp, so the first
    gap after a match ends the range — TODO confirm against the message
    collection's ordering guarantees. Also assumes message.timestamp is a
    non-None ISO string; fromisoformat would raise otherwise.
    """
    messages = conversation.messages
    range_start_ordinal: MessageOrdinal = -1  # -1 means "no match seen yet"
    range_end_ordinal = range_start_ordinal
    async for message in messages:
        if Datetime.fromisoformat(message.timestamp) in date_range:
            if range_start_ordinal < 0:
                range_start_ordinal = message.ordinal
            range_end_ordinal = message.ordinal
        else:
            if range_start_ordinal >= 0:
                # We have a range, so break.
                break
    if range_start_ordinal >= 0:
        # end is exclusive, hence the +1.
        return TextRange(
            start=TextLocation(range_start_ordinal),
            end=TextLocation(range_end_ordinal + 1),
        )
    return None
|
124
|
+
|
125
|
+
|
126
|
+
def get_matching_term_for_text(search_term: SearchTerm, text: str) -> Term | None:
    """Return the term of *search_term* (primary or related) matching *text*.

    Comparisons are case-INSENSITIVE, since stored entities may have
    different case. Returns None when nothing matches.
    """
    needle = text.lower()
    if needle == search_term.term.text.lower():
        return search_term.term
    for related in search_term.related_terms or ():
        if needle == related.text.lower():
            return related
    return None
|
135
|
+
|
136
|
+
|
137
|
+
def match_search_term_to_text(search_term: SearchTerm, text: str | None) -> bool:
    """Report whether *text* matches the search term or any of its related terms.

    None or empty text never matches.
    """
    if not text:
        return False
    return get_matching_term_for_text(search_term, text) is not None
|
141
|
+
|
142
|
+
|
143
|
+
def match_search_term_to_one_of_text(
    search_term: SearchTerm, texts: list[str] | None
) -> bool:
    """Report whether any entry of *texts* matches *search_term*.

    Returns False for None or an empty list. Matching is delegated to
    match_search_term_to_text (case-insensitive, includes related terms).
    """
    # Idiom: any() with a generator instead of the manual loop-and-return.
    if not texts:
        return False
    return any(match_search_term_to_text(search_term, text) for text in texts)
|
151
|
+
|
152
|
+
|
153
|
+
# TODO: match_search_term_to_entity
|
154
|
+
# TODO: match_property_search_term_to_entity
|
155
|
+
# TODO: match_concrete_entity
|
156
|
+
|
157
|
+
|
158
|
+
def match_entity_name_or_type(
    property_value: SearchTerm,
    entity: ConcreteEntity,
) -> bool:
    """Report whether the search term matches the entity's name or one of its types."""
    if match_search_term_to_text(property_value, entity.name):
        return True
    return match_search_term_to_one_of_text(property_value, entity.type)
|
165
|
+
|
166
|
+
|
167
|
+
# TODO: match_property_name_to_facet_name
|
168
|
+
# TODO: match_property_name_to_facet_value
|
169
|
+
# TODO: match_property_search_term_to_action
|
170
|
+
# TODO: match_property_search_term_to_tag
|
171
|
+
# TODO: match_property_search_term_to_semantic_ref
|
172
|
+
|
173
|
+
|
174
|
+
async def lookup_term_filtered(
    semantic_ref_index: ITermToSemanticRefIndex,
    term: Term,
    semantic_refs: ISemanticRefCollection,
    filter: Callable[[SemanticRef, ScoredSemanticRefOrdinal], bool],
) -> list[ScoredSemanticRefOrdinal] | None:
    """Look up *term* in the index, keeping only results accepted by *filter*.

    Returns None when the index yields nothing for the term; otherwise the
    (possibly empty) list of accepted scored ordinals.
    """
    scored_refs = await semantic_ref_index.lookup_term(term.text)
    if not scored_refs:
        return None
    return [
        scored
        for scored in scored_refs
        if filter(await semantic_refs.get_item(scored.semantic_ref_ordinal), scored)
    ]
|
190
|
+
|
191
|
+
|
192
|
+
async def lookup_term(
    semantic_ref_index: ITermToSemanticRefIndex,
    term: Term,
    semantic_refs: ISemanticRefCollection,
    ranges_in_scope: TextRangesInScope | None = None,
    knowledge_type: KnowledgeType | None = None,
) -> list[ScoredSemanticRefOrdinal] | None:
    """Look up a term in the semantic reference index.

    When *ranges_in_scope* is given, results are filtered to refs whose
    range lies in scope and (if *knowledge_type* is given) whose knowledge
    is of that type. Without scoping, the raw index result is returned and
    *knowledge_type* is NOT applied — callers that need type filtering
    without scope must filter themselves.
    """
    if ranges_in_scope is not None:
        # Scoped lookup: if ranges_in_scope contains no actual text ranges,
        # is_range_in_scope rejects everything and no lookup can match.
        return await lookup_term_filtered(
            semantic_ref_index,
            term,
            semantic_refs,
            lambda sr, _: (
                not knowledge_type or sr.knowledge.knowledge_type == knowledge_type
            )
            and ranges_in_scope.is_range_in_scope(sr.range),
        )
    return await semantic_ref_index.lookup_term(term.text)
|
212
|
+
|
213
|
+
|
214
|
+
# TODO: lookup_property
|
215
|
+
# TODO: lookup_knowledge_type
|
216
|
+
|
217
|
+
|
218
|
+
@dataclass
class QueryEvalContext[TMessage: IMessage, TIndex: ITermToSemanticRefIndex]:
    """Context for evaluating a query within a conversation.

    This class provides the necessary context for query evaluation, including
    the conversation being queried, optional indexes for properties and timestamps,
    and structures for tracking matched terms and text ranges in scope.

    Raises:
        ValueError: at construction, when the conversation is not searchable
            (see is_conversation_searchable).
    """

    # TODO: Make property and timestamp indexes NON-OPTIONAL
    # TODO: Move non-index based code to test

    conversation: IConversation[TMessage, TIndex]
    # If a property secondary index is available, the query processor will use it
    property_index: IPropertyToSemanticRefIndex | None = None
    # If a timestamp secondary index is available, the query processor will use it
    timestamp_index: ITimestampToTextRangeIndex | None = None
    # Terms already matched during this evaluation (dedup across expressions).
    matched_terms: TermSet = field(init=False, default_factory=TermSet)
    # (property_name, term) pairs already matched, to avoid over-counting.
    matched_property_terms: PropertyTermSet = field(
        init=False, default_factory=PropertyTermSet
    )
    # Scope set by boolean expressions (see MatchTermsBooleanExpr.begin_match);
    # starts as an empty TextRangesInScope, never None in practice.
    text_ranges_in_scope: TextRangesInScope | None = field(
        init=False, default_factory=TextRangesInScope
    )

    def __post_init__(self) -> None:
        # Fail fast: evaluation assumes the index and refs exist (see the
        # asserts in the properties below).
        if not is_conversation_searchable(self.conversation):
            raise ValueError(
                f"{self.conversation.name_tag} "
                + "is not initialized and cannot be searched."
            )

    @property
    def semantic_ref_index(self) -> ITermToSemanticRefIndex:
        """The conversation's term-to-semantic-ref index (guaranteed by __post_init__)."""
        assert self.conversation.semantic_ref_index is not None
        return self.conversation.semantic_ref_index

    @property
    def semantic_refs(self) -> ISemanticRefCollection:
        """The conversation's semantic reference collection (guaranteed by __post_init__)."""
        assert self.conversation.semantic_refs is not None
        return self.conversation.semantic_refs

    @property
    def messages(self) -> IMessageCollection:
        """The conversation's message collection."""
        return self.conversation.messages

    async def get_semantic_ref(
        self, semantic_ref_ordinal: SemanticRefOrdinal
    ) -> SemanticRef:
        """Retrieve a semantic reference by its ordinal."""
        assert self.conversation.semantic_refs is not None
        return await self.conversation.semantic_refs.get_item(semantic_ref_ordinal)

    async def get_message_for_ref(self, semantic_ref: SemanticRef) -> TMessage:
        """Retrieve the message associated with a semantic reference."""
        message_ordinal = semantic_ref.range.start.message_ordinal
        return await self.conversation.messages.get_item(message_ordinal)

    async def get_message(self, message_ordinal: MessageOrdinal) -> TMessage:
        """Retrieve a message by its ordinal."""
        return await self.messages.get_item(message_ordinal)

    def clear_matched_terms(self) -> None:
        """Clear all matched terms and property terms."""
        self.matched_terms.clear()
        self.matched_property_terms.clear()
|
284
|
+
|
285
|
+
|
286
|
+
async def lookup_knowledge_type(
    semantic_refs: ISemanticRefCollection, knowledge_type: KnowledgeType
) -> list[ScoredSemanticRefOrdinal]:
    """Collect every semantic ref of the given knowledge type, each scored 1.0."""
    results: list[ScoredSemanticRefOrdinal] = []
    async for semantic_ref in semantic_refs:
        if semantic_ref.knowledge.knowledge_type == knowledge_type:
            results.append(
                ScoredSemanticRefOrdinal(semantic_ref.semantic_ref_ordinal, 1.0)
            )
    return results
|
294
|
+
|
295
|
+
|
296
|
+
class IQueryOpExpr[T](Protocol):
    """Protocol for query operation expressions that can be evaluated in a context.

    Expressions form a tree; eval() returns a value of type T computed
    against the given QueryEvalContext.
    """

    async def eval(self, context: QueryEvalContext) -> T: ...
|
300
|
+
|
301
|
+
|
302
|
+
class QueryOpExpr[T](IQueryOpExpr[T]):
    """Base class for query operation expressions.

    Concrete subclasses implement eval(); this base exists so expression
    classes share a nominal root in addition to the structural protocol.
    """
|
304
|
+
|
305
|
+
|
306
|
+
@dataclass
|
307
|
+
class SelectTopNExpr[T: MatchAccumulator](QueryOpExpr[T]):
|
308
|
+
"""Expression for selecting the top N matches from a query."""
|
309
|
+
|
310
|
+
source_expr: IQueryOpExpr[T]
|
311
|
+
max_matches: int | None = None
|
312
|
+
min_hit_count: int | None = None
|
313
|
+
|
314
|
+
async def eval(self, context: QueryEvalContext) -> T:
|
315
|
+
"""Evaluate the expression and return the top N matches."""
|
316
|
+
matches = await self.source_expr.eval(context)
|
317
|
+
matches.select_top_n_scoring(self.max_matches, self.min_hit_count)
|
318
|
+
return matches
|
319
|
+
|
320
|
+
|
321
|
+
# Abstract base class.
class MatchTermsBooleanExpr(QueryOpExpr[SemanticRefAccumulator]):
    """Expression for matching terms in a boolean query.

    Subclasses implement 'OR', 'OR MAX' and 'AND' logic.
    """

    # Optional scope computation; when set, begin_match installs its result
    # into the context before term matching starts.
    get_scope_expr: "GetScopeExpr | None" = None

    async def begin_match(self, context: QueryEvalContext) -> None:
        """Prepare for matching terms in the context by resetting some things."""
        if self.get_scope_expr is not None:
            context.text_ranges_in_scope = await self.get_scope_expr.eval(context)
        # Reset dedup state so this boolean expression counts terms afresh.
        context.clear_matched_terms()
|
335
|
+
|
336
|
+
|
337
|
+
@dataclass
class MatchTermsOrExpr(MatchTermsBooleanExpr):
    """Expression for matching terms with an OR condition.

    Evaluates each term expression and unions their accumulators; returns
    an empty accumulator when nothing matched.
    """

    term_expressions: list[IQueryOpExpr[SemanticRefAccumulator | None]] = field(
        default_factory=list
    )
    get_scope_expr: "GetScopeExpr | None" = None

    async def eval(self, context: QueryEvalContext) -> SemanticRefAccumulator:
        await self.begin_match(context)
        all_matches: SemanticRefAccumulator | None = None
        for match_expr in self.term_expressions:
            term_matches = await match_expr.eval(context)
            if term_matches:
                if all_matches is None:
                    # First non-empty result: adopt it (aliased, not copied).
                    all_matches = term_matches
                else:
                    all_matches.add_union(term_matches)
        if all_matches is not None:
            all_matches.calculate_total_score()
        return all_matches or SemanticRefAccumulator()
|
359
|
+
|
360
|
+
|
361
|
+
@dataclass
class MatchTermsOrMaxExpr(MatchTermsOrExpr):
    """OR-MAX returns the union if there are no common matches, else the maximum scoring match."""

    term_expressions: list[IQueryOpExpr[SemanticRefAccumulator | None]] = field(
        default_factory=list
    )
    get_scope_expr: "GetScopeExpr | None" = None

    async def eval(self, context: QueryEvalContext) -> SemanticRefAccumulator:
        """Run the OR evaluation, then keep only the best-covered matches."""
        union = await super().eval(context)
        top_hits = union.get_max_hit_count()
        if top_hits > 1:
            # Some refs matched more than one term: restrict to those.
            union.select_with_hit_count(top_hits)
        return union
|
376
|
+
|
377
|
+
|
378
|
+
@dataclass
class MatchTermsAndExpr(MatchTermsBooleanExpr):
    """Expression for matching terms with an AND condition.

    A semantic ref survives only if every term expression matched it;
    any term with no matches at all empties the result.
    """

    term_expressions: list[IQueryOpExpr[SemanticRefAccumulator | None]] = field(
        default_factory=list
    )
    get_scope_expr: "GetScopeExpr | None" = None

    async def eval(self, context: QueryEvalContext) -> SemanticRefAccumulator:
        await self.begin_match(context)
        all_matches: SemanticRefAccumulator | None = None
        for match_expr in self.term_expressions:
            term_matches = await match_expr.eval(context)
            if not term_matches:
                # One term had no matches: AND fails. Clear what we have
                # (all_matches stays non-None so scoring below is a no-op
                # on an empty accumulator) and stop early.
                if all_matches is not None:
                    all_matches.clear_matches()
                break
            if all_matches is None:
                all_matches = term_matches
            else:
                all_matches = all_matches.intersect(term_matches)
        if all_matches is not None:
            all_matches.calculate_total_score()
            # Require a hit from every term expression.
            all_matches.select_with_hit_count(len(self.term_expressions))
        else:
            all_matches = SemanticRefAccumulator()
        return all_matches
|
404
|
+
|
405
|
+
|
406
|
+
class MatchTermExpr(QueryOpExpr[SemanticRefAccumulator | None], ABC):
    """Expression for matching terms in a query.

    Subclasses need to define accumulate_matches(), which must add
    matches to its SemanticRefAccumulator argument.
    """

    async def eval(self, context: QueryEvalContext) -> SemanticRefAccumulator | None:
        """Collect matches into a fresh accumulator; None when nothing matched."""
        accumulated = SemanticRefAccumulator()
        await self.accumulate_matches(context, accumulated)
        return accumulated if len(accumulated) > 0 else None

    @abstractmethod
    async def accumulate_matches(
        self, context: QueryEvalContext, matches: SemanticRefAccumulator
    ) -> None: ...
|
424
|
+
|
425
|
+
|
426
|
+
# Callback that may rescore a single lookup result: receives the search term,
# the resolved SemanticRef, and the scored ordinal, and returns the
# (possibly re-scored) ordinal to use in its place.
type ScoreBoosterType = Callable[
    [SearchTerm, SemanticRef, ScoredSemanticRefOrdinal],
    ScoredSemanticRefOrdinal,
]
|
430
|
+
|
431
|
+
|
432
|
+
@dataclass
class MatchSearchTermExpr(MatchTermExpr):
    """Matches a single SearchTerm (and its related terms) against the index.

    Uses context.matched_terms to avoid counting the same term twice across
    expressions in the same boolean group.
    """

    search_term: SearchTerm
    # Optional per-result rescoring hook applied after lookup.
    score_booster: ScoreBoosterType | None = None

    async def accumulate_matches(
        self, context: QueryEvalContext, matches: SemanticRefAccumulator
    ) -> None:
        """Accumulate matches for the search term and its related terms."""
        # Match the search term
        await self.accumulate_matches_for_term(context, matches, self.search_term.term)

        # And any related terms
        if self.search_term.related_terms is not None:
            for related_term in self.search_term.related_terms:
                await self.accumulate_matches_for_term(
                    context, matches, self.search_term.term, related_term
                )

    async def lookup_term(
        self, context: QueryEvalContext, term: Term
    ) -> list[ScoredSemanticRefOrdinal] | None:
        """Look up a term in the semantic reference index."""
        matches = await lookup_term(
            context.semantic_ref_index,
            term,
            context.semantic_refs,
            context.text_ranges_in_scope,
        )
        # Let the booster rewrite each scored result in place.
        if matches and self.score_booster:
            for i in range(len(matches)):
                matches[i] = self.score_booster(
                    self.search_term,
                    await context.get_semantic_ref(matches[i].semantic_ref_ordinal),
                    matches[i],
                )
        return matches

    async def accumulate_matches_for_term(
        self,
        context: QueryEvalContext,
        matches: SemanticRefAccumulator,
        term: Term,
        related_term: Term | None = None,
    ) -> None:
        """Accumulate matches for a term or a related term.

        When *related_term* is given, its matches are credited to *term*
        (weighted by related_term.weight) rather than to the related term
        itself.
        """
        if related_term is None:
            if term not in context.matched_terms:
                semantic_refs = await self.lookup_term(context, term)
                matches.add_term_matches(term, semantic_refs, True)
                context.matched_terms.add(term)
        else:
            if related_term not in context.matched_terms:
                # If this related term had not already matched as a related term for some other term
                # Minimize over counting
                semantic_refs = await self.lookup_term(context, related_term)
                # This will only consider semantic refs that have not already matched this expression.
                # In other words, if a semantic ref already matched due to the term 'novel',
                # don't also match it because it matched the related term 'book'
                matches.add_term_matches_if_new(
                    term, semantic_refs, False, related_term.weight
                )
                context.matched_terms.add(related_term)
|
495
|
+
|
496
|
+
|
497
|
+
@dataclass
class MatchPropertySearchTermExpr(MatchTermExpr):
    """Matches a PropertySearchTerm against the property index.

    property_name may be a plain string (direct property lookup) or a
    SearchTerm (facet name/value matching). Uses
    context.matched_property_terms to avoid over-counting the same
    (property, term) pair across expressions.
    """

    property_search_term: PropertySearchTerm

    async def accumulate_matches(
        self, context: QueryEvalContext, matches: SemanticRefAccumulator
    ) -> None:
        """Dispatch on the property_name kind: str → property, SearchTerm → facets."""
        if isinstance(self.property_search_term.property_name, str):
            await self.accumulate_matches_for_property(
                context,
                self.property_search_term.property_name,
                self.property_search_term.property_value,
                matches,
            )
        else:
            await self.accumulate_matches_for_facets(
                context,
                self.property_search_term.property_name,
                self.property_search_term.property_value,
                matches,
            )

    async def accumulate_matches_for_facets(
        self,
        context: QueryEvalContext,
        property_name: SearchTerm,
        property_value: SearchTerm,
        matches: SemanticRefAccumulator,
    ):
        """Match the facet name, and the facet value unless it is a wildcard."""
        await self.accumulate_matches_for_property(
            context,
            PropertyNames.FacetName.value,
            property_name,
            matches,
        )
        if not is_search_term_wildcard(property_value):
            await self.accumulate_matches_for_property(
                context,
                PropertyNames.FacetValue.value,
                property_value,
                matches,
            )

    async def accumulate_matches_for_property(
        self,
        context: QueryEvalContext,
        property_name: str,
        property_value: SearchTerm,
        matches: SemanticRefAccumulator,
    ):
        """Match the value's primary term, then each of its related terms."""
        await self.accumulate_matches_for_property_value(
            context,
            matches,
            property_name,
            property_value.term,
        )
        if property_value.related_terms:
            for related_property_value in property_value.related_terms:
                await self.accumulate_matches_for_property_value(
                    context,
                    matches,
                    property_name,
                    property_value.term,
                    related_property_value,
                )

    async def accumulate_matches_for_property_value(
        self,
        context: QueryEvalContext,
        matches: SemanticRefAccumulator,
        property_name: str,
        property_value: Term,
        related_prop_val: Term | None = None,
    ) -> None:
        """Look up one (property, value) pair; credit matches to *property_value*.

        When *related_prop_val* is given, its matches are added with that
        term's weight and only for refs not already matched here.
        """
        if related_prop_val is None:
            if not context.matched_property_terms.has(property_name, property_value):
                semantic_refs = await self.lookup_property(
                    context,
                    property_name,
                    property_value.text,
                )
                if semantic_refs:
                    matches.add_term_matches(property_value, semantic_refs, True)
                    context.matched_property_terms.add(property_name, property_value)
        else:
            # To prevent over-counting, ensure this related_prop_val was not already used to match terms earlier
            if not context.matched_property_terms.has(property_name, related_prop_val):
                semantic_refs = await self.lookup_property(
                    context,
                    property_name,
                    related_prop_val.text,
                )
                if semantic_refs:
                    # This will only consider semantic refs that were not already matched by this expression.
                    # In other words, if a semantic ref already matched due to the term 'novel',
                    # don't also match it because it matched the related term 'book'
                    matches.add_term_matches_if_new(
                        property_value,
                        semantic_refs,
                        False,
                        related_prop_val.weight,
                    )
                    context.matched_property_terms.add(property_name, related_prop_val)

    async def lookup_property(
        self,
        context: QueryEvalContext,
        property_name: str,
        property_value: str,
    ) -> list[ScoredSemanticRefOrdinal] | None:
        """Look up a property pair in the property index.

        Falls through to an implicit None when no property index is
        configured on the context.
        """
        if context.property_index is not None:
            return await lookup_property_in_property_index(
                context.property_index,
                property_name,
                property_value,
                context.semantic_refs,
                context.text_ranges_in_scope,
            )
|
615
|
+
|
616
|
+
|
617
|
+
class MatchTagExpr(MatchSearchTermExpr):
    """Matches a tag term; "*" matches every semantic ref of type 'tag'."""

    def __init__(self, tag_term: SearchTerm):
        # Kept alongside the inherited search_term for the wildcard check below.
        self.tag_term = tag_term
        super().__init__(tag_term)

    async def lookup_term(
        self, context: QueryEvalContext, term: Term
    ) -> list[ScoredSemanticRefOrdinal] | None:
        """Wildcard returns all 'tag' refs; otherwise a type-filtered index lookup."""
        if self.tag_term.term.text == "*":
            return await lookup_knowledge_type(context.semantic_refs, "tag")
        else:
            return await lookup_term(
                context.semantic_ref_index,
                term,
                context.semantic_refs,
                context.text_ranges_in_scope,
                "tag",
            )
|
635
|
+
|
636
|
+
|
637
|
+
class MatchTopicExpr(MatchSearchTermExpr):
    """Matches a topic term; "*" matches every semantic ref of type 'topic'."""

    def __init__(self, topic: SearchTerm):
        # Kept alongside the inherited search_term for the wildcard check below.
        self.topic = topic
        super().__init__(topic)

    async def lookup_term(
        self, context: QueryEvalContext, term: Term
    ) -> list[ScoredSemanticRefOrdinal] | None:
        """Wildcard returns all 'topic' refs; otherwise a type-filtered index lookup."""
        if self.topic.term.text == "*":
            return await lookup_knowledge_type(context.semantic_refs, "topic")
        else:
            return await lookup_term(
                context.semantic_ref_index,
                term,
                context.semantic_refs,
                context.text_ranges_in_scope,
                "topic",
            )
|
655
|
+
|
656
|
+
|
657
|
+
@dataclass
class GroupByKnowledgeTypeExpr(
    QueryOpExpr[dict[KnowledgeType, SemanticRefAccumulator]]
):
    """Group a source expression's matches into one accumulator per knowledge type."""

    matches: IQueryOpExpr[SemanticRefAccumulator]

    async def eval(
        self, context: QueryEvalContext
    ) -> dict[KnowledgeType, SemanticRefAccumulator]:
        """Evaluate the source expression, then bucket its matches by type."""
        accumulated = await self.matches.eval(context)
        grouped = await accumulated.group_matches_by_type(context.semantic_refs)
        return grouped
|
668
|
+
|
669
|
+
|
670
|
+
@dataclass
class SelectTopNKnowledgeGroupExpr(
    QueryOpExpr[dict[KnowledgeType, SemanticRefAccumulator]]
):
    """Prune every knowledge-type group down to its top-N scoring matches."""

    source_expr: IQueryOpExpr[dict[KnowledgeType, SemanticRefAccumulator]]
    # Upper bound per group; None means unbounded.
    max_matches: int | None = None
    # Minimum hit count per match; None disables the check.
    min_hit_count: int | None = None

    async def eval(
        self, context: QueryEvalContext
    ) -> dict[KnowledgeType, SemanticRefAccumulator]:
        """Evaluate the grouped source, then trim each group in place."""
        grouped = await self.source_expr.eval(context)
        for group in grouped.values():
            group.select_top_n_scoring(self.max_matches, self.min_hit_count)
        return grouped
|
685
|
+
|
686
|
+
|
687
|
+
@dataclass
class GroupSearchResultsExpr(QueryOpExpr[dict[KnowledgeType, SemanticRefSearchResult]]):
    """Convert grouped accumulators into grouped search results."""

    src_expr: IQueryOpExpr[dict[KnowledgeType, SemanticRefAccumulator]]

    async def eval(
        self, context: QueryEvalContext
    ) -> dict[KnowledgeType, SemanticRefSearchResult]:
        """Evaluate the grouped source and convert each group to a search result."""
        grouped = await self.src_expr.eval(context)
        return to_grouped_search_results(grouped)
|
695
|
+
|
696
|
+
|
697
|
+
@dataclass
class WhereSemanticRefExpr(QueryOpExpr[SemanticRefAccumulator]):
    """Filter the matches of a source expression through a list of predicates.

    Produces a NEW accumulator (sharing the source's search_term_matches)
    containing only the matches every predicate accepts.
    """

    source_expr: IQueryOpExpr[SemanticRefAccumulator]
    predicates: list["IQuerySemanticRefPredicate"]

    async def eval(self, context: QueryEvalContext) -> SemanticRefAccumulator:
        """Evaluate the source, then keep only matches passing all predicates."""
        accumulator = await self.source_expr.eval(context)
        filtered = SemanticRefAccumulator(accumulator.search_term_matches)

        # Filter matches asynchronously
        filtered_matches = []
        for match in accumulator.get_matches():
            if await self._eval_predicates(context, self.predicates, match):
                filtered_matches.append(match)

        filtered.set_matches(filtered_matches)
        return filtered

    async def _eval_predicates(
        self,
        context: QueryEvalContext,
        predicates: list["IQuerySemanticRefPredicate"],
        match: Match[SemanticRefOrdinal],
    ) -> bool:
        """AND together all predicates for one match, short-circuiting on failure."""
        if not predicates:
            return True
        # Fix: fetch the semantic ref ONCE per match; the original re-fetched
        # it inside the loop, once per predicate.
        semantic_ref = await context.get_semantic_ref(match.value)
        for predicate in predicates:
            if not await predicate.eval(context, semantic_ref):
                return False
        return True
|
727
|
+
|
728
|
+
|
729
|
+
class IQuerySemanticRefPredicate(Protocol):
    """Protocol for predicates that accept or reject a single semantic ref."""

    async def eval(
        self, context: QueryEvalContext, semantic_ref: SemanticRef
    ) -> bool: ...
|
733
|
+
|
734
|
+
|
735
|
+
# TODO: match_predicates
|
736
|
+
# TODO: knowledge_type_predicate
|
737
|
+
# TODO: property_match_predicate
|
738
|
+
|
739
|
+
|
740
|
+
# NOTE: GetScopeExpr is moved after TextRangeSelector to avoid circular references.
|
741
|
+
|
742
|
+
|
743
|
+
class IQueryTextRangeSelector(Protocol):
    """Protocol for a selector that can evaluate to a text range."""

    async def eval(
        self,
        context: QueryEvalContext,
        semantic_refs: SemanticRefAccumulator | None = None,
    ) -> TextRangeCollection | None:
        """Evaluate the selector and return the text range.

        Returns None when the selector contributes no scoping ranges.
        """
        ...
|
753
|
+
|
754
|
+
|
755
|
+
class TextRangeSelector(IQueryTextRangeSelector):
    """A selector that evaluates to a pre-computed TextRangeCollection."""

    # The fixed collection returned by every eval() call.
    text_ranges_in_scope: TextRangeCollection

    def __init__(self, ranges_in_scope: list[TextRange]) -> None:
        # NOTE(review): second argument True presumably marks the ranges as
        # already sorted — confirm against TextRangeCollection's constructor.
        self.text_ranges_in_scope = TextRangeCollection(ranges_in_scope, True)

    async def eval(
        self,
        context: QueryEvalContext,
        semantic_refs: SemanticRefAccumulator | None = None,
    ) -> TextRangeCollection | None:
        """Return the pre-computed ranges; ignores both arguments."""
        return self.text_ranges_in_scope
|
769
|
+
|
770
|
+
|
771
|
+
@dataclass
class GetScopeExpr(QueryOpExpr[TextRangesInScope]):
    """Expression for getting the scope of a query."""

    range_selectors: list[IQueryTextRangeSelector]

    async def eval(self, context: QueryEvalContext) -> TextRangesInScope:
        """Collect the text ranges produced by every selector into one scope."""
        scope = TextRangesInScope()
        for selector in self.range_selectors:
            ranges = await selector.eval(context)
            if ranges is None:
                continue  # Selectors may legitimately yield nothing.
            scope.add_text_ranges(ranges)
        return scope
|
785
|
+
|
786
|
+
|
787
|
+
# TODO: SelectInScopeExpr
|
788
|
+
|
789
|
+
|
790
|
+
@dataclass
class TextRangesInDateRangeSelector(IQueryTextRangeSelector):
    """Selects text ranges that fall within a given date range."""

    date_range_in_scope: DateRange

    async def eval(
        self,
        context: QueryEvalContext,
        semantic_refs: SemanticRefAccumulator | None = None,
    ) -> TextRangeCollection | None:
        """Evaluate the selector and return text ranges in the specified date range."""
        result = TextRangeCollection()
        timestamp_index = context.timestamp_index
        if timestamp_index is None:
            # No timestamp index: fall back to scanning the conversation directly.
            fallback_range = await get_text_range_for_date_range(
                context.conversation,
                self.date_range_in_scope,
            )
            if fallback_range is not None:
                result.add_range(fallback_range)
        else:
            # Fast path: the timestamp index maps date ranges to timestamped ranges.
            timestamped = await timestamp_index.lookup_range(
                self.date_range_in_scope,
            )
            for entry in timestamped:
                result.add_range(entry.range)
        return result
|
817
|
+
|
818
|
+
|
819
|
+
# TODO: TextRangesPredicateSelector
|
820
|
+
# TODO: TextRangesWithTagSelector
|
821
|
+
# TODO: TextRangesFromSemanticRefsSelector
|
822
|
+
|
823
|
+
|
824
|
+
@dataclass
class TextRangesFromMessagesSelector(IQueryTextRangeSelector):
    """Turns the messages matched by a sub-expression into a range collection."""

    source_expr: IQueryOpExpr[MessageAccumulator]

    async def eval(
        self,
        context: QueryEvalContext,
        semantic_refs: SemanticRefAccumulator | None = None,
    ) -> TextRangeCollection | None:
        accumulator = await self.source_expr.eval(context)
        if not accumulator:
            # No matched messages: return an empty collection.
            return TextRangeCollection(None)
        ordered_ordinals = sorted(accumulator.get_matched_values())
        return TextRangeCollection(text_ranges_from_message_ordinals(ordered_ordinals))
|
839
|
+
|
840
|
+
|
841
|
+
# TODO: Move to messagelib.py
def text_ranges_from_message_ordinals(
    message_ordinals: list[MessageOrdinal],
) -> list[TextRange]:
    """Build a single-message text range for each ordinal, preserving order."""
    return list(map(text_range_from_message, message_ordinals))
|
846
|
+
|
847
|
+
|
848
|
+
# TODO: Move to messagelib.py
def text_range_from_message(message_ordinal: MessageOrdinal) -> TextRange:
    """Return a TextRange anchored at the start of the given message (no end set)."""
    return TextRange(start=TextLocation(message_ordinal))
|
851
|
+
|
852
|
+
|
853
|
+
# TODO: ThreadSelector
|
854
|
+
# TODO: to_grouped_search_results
|
855
|
+
|
856
|
+
|
857
|
+
def to_grouped_search_results(
    eval_results: dict[KnowledgeType, SemanticRefAccumulator],
) -> dict[KnowledgeType, SemanticRefSearchResult]:
    """Convert non-empty accumulators into search results, keyed by knowledge type.

    Accumulators with no matches are dropped entirely.
    """
    grouped: dict[KnowledgeType, SemanticRefSearchResult] = {}
    for knowledge_type, acc in eval_results.items():
        if len(acc) == 0:
            continue
        grouped[knowledge_type] = SemanticRefSearchResult(
            term_matches=acc.search_term_matches,
            semantic_ref_matches=acc.to_scored_semantic_refs(),
        )
    return grouped
|
868
|
+
|
869
|
+
|
870
|
+
@dataclass
class MessagesFromKnowledgeExpr(QueryOpExpr[MessageAccumulator]):
    """Maps knowledge-level search results to the messages they came from."""

    src_expr: (
        IQueryOpExpr[dict[KnowledgeType, SemanticRefSearchResult]]
        | dict[KnowledgeType, SemanticRefSearchResult]
    )

    async def eval(self, context: QueryEvalContext) -> MessageAccumulator:
        # src_expr may be pre-computed results or an expression to evaluate.
        if isinstance(self.src_expr, dict):
            knowledge = self.src_expr
        else:
            knowledge = await self.src_expr.eval(context)
        return await message_matches_from_knowledge_matches(
            context.semantic_refs, knowledge
        )
|
886
|
+
|
887
|
+
|
888
|
+
# TODO: SelectMessagesInCharBudget
|
889
|
+
|
890
|
+
|
891
|
+
@dataclass
class RankMessagesBySimilarityExpr(QueryOpExpr[MessageAccumulator]):
    """Re-ranks matched messages by embedding similarity to a query embedding."""

    src_expr: IQueryOpExpr[MessageAccumulator]
    embedding: NormalizedEmbedding
    # Cap on how many messages to keep; None means keep everything.
    max_messages: int | None = None
    # Minimum similarity score for a message to survive re-ranking.
    threshold_score: float | None = None

    async def eval(self, context: QueryEvalContext) -> MessageAccumulator:
        matches = await self.src_expr.eval(context)
        # Already within budget: re-ranking would not change the selection.
        if self.max_messages is not None and len(matches) <= self.max_messages:
            return matches

        # Try to use the message embedding index for re-ranking if available.
        message_index = (
            None
            if context.conversation.secondary_indexes is None
            else context.conversation.secondary_indexes.message_index
        )
        if isinstance(message_index, IMessageTextEmbeddingIndex):
            message_ordinals = await self._get_message_ordinals_in_index(
                message_index, matches
            )
            # Only re-rank when every matched message is covered by the index;
            # otherwise un-indexed messages would silently disappear.
            if len(message_ordinals) == len(matches):
                matches.clear_matches()
                # NOTE(review): this call is not awaited, unlike the other index
                # lookups here — confirm lookup_in_subset_by_embedding is sync.
                ranked_messages = message_index.lookup_in_subset_by_embedding(
                    self.embedding,
                    message_ordinals,
                    self.max_messages,
                    self.threshold_score,
                )
                for match in ranked_messages:
                    matches.add(match.message_ordinal, match.score)
                return matches

        if self.max_messages is not None:
            # Can't re rank, so just take the top K from what we already have.
            matches.select_top_n_scoring(self.max_messages)
        return matches

    async def _get_message_ordinals_in_index(
        self, message_index, matches: MessageAccumulator
    ) -> list[MessageOrdinal]:
        """Collect matched message ordinals that fall inside the index's size.

        Stops at the first out-of-range ordinal — assumes get_matched_values()
        yields ordinals in ascending order (TODO confirm).
        """
        message_ordinals: list[MessageOrdinal] = []
        index_size = await message_index.size()
        for message_ordinal in matches.get_matched_values():
            if message_ordinal >= index_size:
                break
            message_ordinals.append(message_ordinal)
        return message_ordinals
|
940
|
+
|
941
|
+
|
942
|
+
@dataclass
class GetScoredMessagesExpr(QueryOpExpr[list[ScoredMessageOrdinal]]):
    """Materializes a message accumulator into scored message ordinals."""

    src_expr: IQueryOpExpr[MessageAccumulator]

    async def eval(self, context: QueryEvalContext) -> list[ScoredMessageOrdinal]:
        accumulator = await self.src_expr.eval(context)
        return accumulator.to_scored_message_ordinals()
|
949
|
+
|
950
|
+
|
951
|
+
@dataclass
class MatchMessagesBooleanExpr(IQueryOpExpr[MessageAccumulator]):
    """Shared base for boolean (AND/OR) combinators over message matches."""

    term_expressions: list[
        IQueryOpExpr[SemanticRefAccumulator | MessageAccumulator | None]
    ]

    def _begin_match(self, context: QueryEvalContext) -> None:
        # Reset per-evaluation matched-term state before combining terms.
        context.clear_matched_terms()

    async def _accumulate_messages(
        self,
        context: QueryEvalContext,
        semantic_ref_matches: SemanticRefAccumulator,
    ) -> MessageAccumulator:
        """Project semantic-ref matches onto the messages they reference."""
        projected = MessageAccumulator()
        for ref_match in semantic_ref_matches:
            ref = await context.get_semantic_ref(ref_match.value)
            projected.add_messages_for_semantic_ref(ref, ref_match.score)
        return projected
|
973
|
+
|
974
|
+
|
975
|
+
@dataclass
class MatchMessagesOrExpr(MatchMessagesBooleanExpr):
    """Unions the message matches of all term expressions (boolean OR)."""

    async def eval(self, context: QueryEvalContext) -> MessageAccumulator:
        self._begin_match(context)

        combined: MessageAccumulator | None = None
        for term_expr in self.term_expressions:
            term_result = await term_expr.eval(context)
            if not term_result:
                continue  # Non-matching terms contribute nothing to an OR.
            if isinstance(term_result, SemanticRefAccumulator):
                as_messages = await self._accumulate_messages(context, term_result)
            else:
                as_messages = term_result
            if combined is None:
                combined = as_messages
            else:
                combined.add_union(as_messages)

        if combined is None:
            # No term matched at all: return an empty accumulator.
            return MessageAccumulator()
        combined.calculate_total_score()
        return combined
|
999
|
+
|
1000
|
+
|
1001
|
+
@dataclass
class MatchMessagesAndExpr(MatchMessagesBooleanExpr):
    """Intersects the message matches of all term expressions (boolean AND).

    A message survives only if every term expression matched it; any
    non-matching term (or an empty intersection) short-circuits to an
    empty result.
    """

    async def eval(self, context: QueryEvalContext) -> MessageAccumulator:
        self._begin_match(context)

        all_matches: MessageAccumulator | None = None
        all_done = False  # Set only when every term was processed without a break.
        for match_expr in self.term_expressions:
            matches = await match_expr.eval(context)
            if not matches:
                # If any expr does not match, the AND fails.
                break
            # Semantic-ref matches must first be projected onto messages.
            if isinstance(matches, SemanticRefAccumulator):
                message_matches = await self._accumulate_messages(context, matches)
            else:
                message_matches = matches
            if all_matches is None:
                all_matches = message_matches
            else:
                # Intersect the message matches
                all_matches = all_matches.intersect(message_matches)
                if not all_matches:
                    # If the intersection is empty, we can stop early.
                    break
        else:
            # If we did not break, all terms matched.
            all_done = True

        if all_matches is not None:
            if all_done:
                all_matches.calculate_total_score()
                # Keep only messages hit by every term expression.
                all_matches.select_with_hit_count(len(self.term_expressions))
            else:
                # A term failed part-way: the AND yields nothing.
                all_matches.clear_matches()
        else:
            all_matches = MessageAccumulator()
        return all_matches
|
1039
|
+
|
1040
|
+
|
1041
|
+
@dataclass
class MatchMessagesOrMaxExpr(MatchMessagesOrExpr):
    """OR the terms, then keep only the messages with the maximal hit count."""

    async def eval(self, context: QueryEvalContext) -> MessageAccumulator:
        union = await super().eval(context)
        best_hit_count = union.get_max_hit_count()
        if best_hit_count > 1:
            # Restrict to messages matched by the largest number of terms.
            union.select_with_hit_count(best_hit_count)
        return union
|
1050
|
+
|
1051
|
+
|
1052
|
+
# TODO: class MatchMessagesBySimilarityExpr(QueryOpExpr[list[ScoredMessageOrdinal]]):
|
1053
|
+
|
1054
|
+
|
1055
|
+
class NoOpExpr[T](QueryOpExpr[T]):
|
1056
|
+
def __init__(self, src_expr: IQueryOpExpr[T]) -> None:
|
1057
|
+
self.src_expr = src_expr
|
1058
|
+
super().__init__()
|
1059
|
+
|
1060
|
+
async def eval(self, context: QueryEvalContext) -> T:
|
1061
|
+
return await self.src_expr.eval(context)
|
1062
|
+
|
1063
|
+
|
1064
|
+
async def message_matches_from_knowledge_matches(
    semantic_refs: ISemanticRefCollection,
    knowledge_matches: dict[KnowledgeType, SemanticRefSearchResult],
    intersect_across_knowledge_types: bool = True,
) -> MessageAccumulator:
    """Collect the messages referenced by knowledge-level search results.

    When intersect_across_knowledge_types is True, only messages matched
    by every knowledge type that produced hits are kept, and their scores
    are smoothed.
    """
    accumulator = MessageAccumulator()
    types_matched = 0  # How many types of knowledge matched?
    for result in knowledge_matches.values():
        if not (result and result.semantic_ref_matches):
            continue
        types_matched += 1
        for scored_ref in result.semantic_ref_matches:
            ref = await semantic_refs.get_item(scored_ref.semantic_ref_ordinal)
            accumulator.add_messages_for_semantic_ref(ref, scored_ref.score)

    if intersect_across_knowledge_types and types_matched > 0:
        # Intersect the sets of messages that matched each knowledge type
        surviving = accumulator.get_with_hit_count(types_matched)
        if surviving:
            accumulator = MessageAccumulator(surviving)
            accumulator.smooth_scores()
    return accumulator
|
1086
|
+
|
1087
|
+
|
1088
|
+
# TODO: Implement proper SelectMessagesInCharBudget functionality
|
1089
|
+
@dataclass
class SelectMessagesInCharBudget(QueryOpExpr[MessageAccumulator]):
    """Selects messages within a character budget."""

    src_expr: IQueryOpExpr[MessageAccumulator]
    max_chars: int

    async def eval(self, context: QueryEvalContext) -> MessageAccumulator:
        """Evaluate the source, then trim its matches to fit max_chars (in place)."""
        selected = await self.src_expr.eval(context)
        await selected.select_messages_in_budget(context.messages, self.max_chars)
        return selected
|
1100
|
+
|
1101
|
+
|
1102
|
+
# TODO: Implement proper KnowledgeTypePredicate functionality
|
1103
|
+
@dataclass
class KnowledgeTypePredicate(IQuerySemanticRefPredicate):
    """Predicate to filter by knowledge type."""

    # The knowledge type a semantic ref must carry to pass the filter.
    knowledge_type: KnowledgeType

    async def eval(self, context: QueryEvalContext, semantic_ref: SemanticRef) -> bool:
        # NOTE(review): reads knowledge_type from semantic_ref.knowledge —
        # confirm the knowledge payload exposes that attribute (in some
        # schemas knowledge_type lives directly on the SemanticRef).
        return semantic_ref.knowledge.knowledge_type == self.knowledge_type
|
1111
|
+
|
1112
|
+
|
1113
|
+
# TODO: Implement proper ThreadSelector functionality
|
1114
|
+
@dataclass
class ThreadSelector(IQueryTextRangeSelector):
    """Selector for thread-based text ranges."""

    threads: list[Thread]

    async def eval(
        self,
        context: QueryEvalContext,
        semantic_refs: SemanticRefAccumulator | None = None,
    ) -> TextRangeCollection | None:
        """Concatenate the ranges of every thread into one collection."""
        collection = TextRangeCollection()
        for thread in self.threads:
            collection.add_ranges(list(thread.ranges))
        return collection
|