typeagent-py 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. typeagent/aitools/auth.py +61 -0
  2. typeagent/aitools/embeddings.py +232 -0
  3. typeagent/aitools/utils.py +244 -0
  4. typeagent/aitools/vectorbase.py +175 -0
  5. typeagent/knowpro/answer_context_schema.py +49 -0
  6. typeagent/knowpro/answer_response_schema.py +34 -0
  7. typeagent/knowpro/answers.py +577 -0
  8. typeagent/knowpro/collections.py +759 -0
  9. typeagent/knowpro/common.py +9 -0
  10. typeagent/knowpro/convknowledge.py +112 -0
  11. typeagent/knowpro/convsettings.py +94 -0
  12. typeagent/knowpro/convutils.py +49 -0
  13. typeagent/knowpro/date_time_schema.py +32 -0
  14. typeagent/knowpro/field_helpers.py +87 -0
  15. typeagent/knowpro/fuzzyindex.py +144 -0
  16. typeagent/knowpro/interfaces.py +818 -0
  17. typeagent/knowpro/knowledge.py +88 -0
  18. typeagent/knowpro/kplib.py +125 -0
  19. typeagent/knowpro/query.py +1128 -0
  20. typeagent/knowpro/search.py +628 -0
  21. typeagent/knowpro/search_query_schema.py +165 -0
  22. typeagent/knowpro/searchlang.py +729 -0
  23. typeagent/knowpro/searchlib.py +345 -0
  24. typeagent/knowpro/secindex.py +100 -0
  25. typeagent/knowpro/serialization.py +390 -0
  26. typeagent/knowpro/textlocindex.py +179 -0
  27. typeagent/knowpro/utils.py +17 -0
  28. typeagent/mcp/server.py +139 -0
  29. typeagent/podcasts/podcast.py +473 -0
  30. typeagent/podcasts/podcast_import.py +105 -0
  31. typeagent/storage/__init__.py +25 -0
  32. typeagent/storage/memory/__init__.py +13 -0
  33. typeagent/storage/memory/collections.py +68 -0
  34. typeagent/storage/memory/convthreads.py +81 -0
  35. typeagent/storage/memory/messageindex.py +178 -0
  36. typeagent/storage/memory/propindex.py +289 -0
  37. typeagent/storage/memory/provider.py +84 -0
  38. typeagent/storage/memory/reltermsindex.py +318 -0
  39. typeagent/storage/memory/semrefindex.py +660 -0
  40. typeagent/storage/memory/timestampindex.py +176 -0
  41. typeagent/storage/sqlite/__init__.py +31 -0
  42. typeagent/storage/sqlite/collections.py +362 -0
  43. typeagent/storage/sqlite/messageindex.py +382 -0
  44. typeagent/storage/sqlite/propindex.py +119 -0
  45. typeagent/storage/sqlite/provider.py +293 -0
  46. typeagent/storage/sqlite/reltermsindex.py +328 -0
  47. typeagent/storage/sqlite/schema.py +248 -0
  48. typeagent/storage/sqlite/semrefindex.py +156 -0
  49. typeagent/storage/sqlite/timestampindex.py +146 -0
  50. typeagent/storage/utils.py +41 -0
  51. typeagent_py-0.1.0.dist-info/METADATA +28 -0
  52. typeagent_py-0.1.0.dist-info/RECORD +55 -0
  53. typeagent_py-0.1.0.dist-info/WHEEL +5 -0
  54. typeagent_py-0.1.0.dist-info/licenses/LICENSE +21 -0
  55. typeagent_py-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1128 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ from abc import ABC, abstractmethod
5
+ from collections.abc import Callable
6
+ from dataclasses import dataclass, field
7
+ from re import search
8
+ from typing import Literal, Protocol, cast
9
+
10
+ from ..aitools.embeddings import NormalizedEmbedding
11
+
12
+ from .collections import (
13
+ Match,
14
+ MatchAccumulator,
15
+ MessageAccumulator,
16
+ PropertyTermSet,
17
+ SemanticRefAccumulator,
18
+ TermSet,
19
+ TextRangeCollection,
20
+ TextRangesInScope,
21
+ )
22
+ from .common import is_search_term_wildcard
23
+ from .interfaces import (
24
+ Datetime,
25
+ DateRange,
26
+ IConversation,
27
+ IMessage,
28
+ IMessageCollection,
29
+ IPropertyToSemanticRefIndex,
30
+ ISemanticRefCollection,
31
+ ITermToSemanticRefIndex,
32
+ ITimestampToTextRangeIndex,
33
+ KnowledgeType,
34
+ MessageOrdinal,
35
+ PropertySearchTerm,
36
+ ScoredMessageOrdinal,
37
+ ScoredSemanticRefOrdinal,
38
+ SearchTerm,
39
+ SearchTermGroup,
40
+ SemanticRef,
41
+ SemanticRefOrdinal,
42
+ SemanticRefSearchResult,
43
+ Term,
44
+ TextLocation,
45
+ TextRange,
46
+ Thread,
47
+ )
48
+ from .kplib import ConcreteEntity
49
+ from ..storage.memory.messageindex import IMessageTextEmbeddingIndex
50
+ from ..storage.memory.propindex import PropertyNames, lookup_property_in_property_index
51
+ from .searchlib import create_property_search_term, create_tag_search_term_group
52
+
53
+
54
# TODO: Move to compilelib.py
# Boolean combinators for compiled term groups: "and" requires every term,
# "or" accepts any term, "or_max" keeps only the highest-hit-count overlap.
type BooleanOp = Literal["and", "or", "or_max"]
56
+
57
+
58
# TODO: Move to compilelib.py
@dataclass
class CompiledSearchTerm(SearchTerm):
    """A SearchTerm carrying extra compilation state.

    Instances are usually produced by casting an existing SearchTerm in
    place (see to_required_search_term / to_non_required_search_term).
    """

    # Whether matches on related terms are required (vs. optional boosts).
    related_terms_required: bool = False
62
+
63
+
64
# TODO: Move to compilelib.py
def to_required_search_term(term: SearchTerm) -> CompiledSearchTerm:
    """Mark *term* as requiring related-term matches, in place.

    NOTE: We must cast since the output must alias the input.
    If not, assignments to related_terms will be lost.
    """
    compiled = cast(CompiledSearchTerm, term)
    compiled.related_terms_required = True
    return compiled
71
+
72
+
73
# TODO: Move to compilelib.py
def to_non_required_search_term(term: SearchTerm) -> CompiledSearchTerm:
    """Mark *term* as NOT requiring related-term matches, in place.

    NOTE: We must cast since the output must alias the input.
    If not, assignments to related_terms will be lost.
    """
    compiled = cast(CompiledSearchTerm, term)
    compiled.related_terms_required = False
    return compiled
80
+
81
+
82
# TODO: Move to compilelib.py
@dataclass
class CompiledTermGroup:
    """A group of compiled search terms combined by one boolean operator."""

    # How the terms are combined: "and", "or", or "or_max".
    boolean_op: BooleanOp
    terms: list[CompiledSearchTerm]
87
+
88
+
89
def is_conversation_searchable(conversation: IConversation) -> bool:
    """Determine if a conversation is searchable.

    A conversation is searchable if it has a semantic reference index
    and semantic references initialized.
    """
    # TODO: also require secondary indices, once we have removed non-index based retrieval to test.
    has_index = conversation.semantic_ref_index is not None
    has_refs = conversation.semantic_refs is not None
    return has_index and has_refs
100
+
101
+
102
async def get_text_range_for_date_range(
    conversation: IConversation,
    date_range: DateRange,
) -> TextRange | None:
    """Find the contiguous range of messages whose timestamps fall in *date_range*.

    Scans messages in order and returns a TextRange spanning the FIRST
    contiguous run of in-range messages (end is exclusive), or None when no
    message's timestamp is in range.

    NOTE(review): assumes message.timestamp is a non-None ISO-format string
    for every message; a None timestamp would raise here -- confirm callers
    only use this on timestamped conversations.
    """
    messages = conversation.messages
    # -1 is the "no range started yet" sentinel.
    range_start_ordinal: MessageOrdinal = -1
    range_end_ordinal = range_start_ordinal
    async for message in messages:
        if Datetime.fromisoformat(message.timestamp) in date_range:
            if range_start_ordinal < 0:
                range_start_ordinal = message.ordinal
            range_end_ordinal = message.ordinal
        else:
            if range_start_ordinal >= 0:
                # We have a range, so break.
                break
    if range_start_ordinal >= 0:
        # end is exclusive, hence the +1.
        return TextRange(
            start=TextLocation(range_start_ordinal),
            end=TextLocation(range_end_ordinal + 1),
        )
    return None
124
+
125
+
126
def get_matching_term_for_text(search_term: SearchTerm, text: str) -> Term | None:
    """Return the term of *search_term* (primary or related) matching *text*.

    Comparisons are case-INSENSITIVE, since stored entities may have
    different case. The primary term is checked before related terms.
    Returns None when nothing matches.
    """
    needle = text.lower()
    if needle == search_term.term.text.lower():
        return search_term.term
    for candidate in search_term.related_terms or ():
        if needle == candidate.text.lower():
            return candidate
    return None
135
+
136
+
137
def match_search_term_to_text(search_term: SearchTerm, text: str | None) -> bool:
    """Report whether *text* matches *search_term* or one of its related terms.

    A None or empty string never matches.
    """
    if not text:
        return False
    return get_matching_term_for_text(search_term, text) is not None
141
+
142
+
143
def match_search_term_to_one_of_text(
    search_term: SearchTerm, texts: list[str] | None
) -> bool:
    """Report whether any entry of *texts* matches *search_term*.

    A None or empty list never matches.
    """
    return any(
        match_search_term_to_text(search_term, text) for text in texts or ()
    )
151
+
152
+
153
+ # TODO: match_search_term_to_entity
154
+ # TODO: match_property_search_term_to_entity
155
+ # TODO: match_concrete_entity
156
+
157
+
158
def match_entity_name_or_type(
    property_value: SearchTerm,
    entity: ConcreteEntity,
) -> bool:
    """Report whether *property_value* matches the entity's name or any of its types."""
    if match_search_term_to_text(property_value, entity.name):
        return True
    return match_search_term_to_one_of_text(property_value, entity.type)
165
+
166
+
167
+ # TODO: match_property_name_to_facet_name
168
+ # TODO: match_property_name_to_facet_value
169
+ # TODO: match_property_search_term_to_action
170
+ # TODO: match_property_search_term_to_tag
171
+ # TODO: match_property_search_term_to_semantic_ref
172
+
173
+
174
async def lookup_term_filtered(
    semantic_ref_index: ITermToSemanticRefIndex,
    term: Term,
    semantic_refs: ISemanticRefCollection,
    filter: Callable[[SemanticRef, ScoredSemanticRefOrdinal], bool],
) -> list[ScoredSemanticRefOrdinal] | None:
    """Look up a term in the semantic reference index and filter the results.

    Returns None (not an empty list) when the index has no entry for the
    term's text; otherwise returns the (possibly empty) filtered list.
    """
    scored_refs = await semantic_ref_index.lookup_term(term.text)
    if not scored_refs:
        return None
    kept: list[ScoredSemanticRefOrdinal] = []
    for scored in scored_refs:
        # Resolve the actual SemanticRef so the filter can inspect it.
        resolved = await semantic_refs.get_item(scored.semantic_ref_ordinal)
        if filter(resolved, scored):
            kept.append(scored)
    return kept
190
+
191
+
192
async def lookup_term(
    semantic_ref_index: ITermToSemanticRefIndex,
    term: Term,
    semantic_refs: ISemanticRefCollection,
    ranges_in_scope: TextRangesInScope | None = None,
    knowledge_type: KnowledgeType | None = None,
) -> list[ScoredSemanticRefOrdinal] | None:
    """Look up a term in the semantic reference index, optionally filtering by ranges in scope.

    When *ranges_in_scope* is provided, results are restricted to refs whose
    range is in scope and, if *knowledge_type* is also provided, whose
    knowledge type matches. When *ranges_in_scope* is None, *knowledge_type*
    is ignored and the raw index result is returned.
    """
    if ranges_in_scope is not None:
        # If ranges_in_scope has no actual text ranges, lookups can't possibly match.
        return await lookup_term_filtered(
            semantic_ref_index,
            term,
            semantic_refs,
            # Cheap knowledge-type check first; range check only if it passes.
            lambda sr, _: (
                not knowledge_type or sr.knowledge.knowledge_type == knowledge_type
            )
            and ranges_in_scope.is_range_in_scope(sr.range),
        )
    return await semantic_ref_index.lookup_term(term.text)
212
+
213
+
214
+ # TODO: lookup_property
215
+ # TODO: lookup_knowledge_type
216
+
217
+
218
@dataclass
class QueryEvalContext[TMessage: IMessage, TIndex: ITermToSemanticRefIndex]:
    """Context for evaluating a query within a conversation.

    This class provides the necessary context for query evaluation, including
    the conversation being queried, optional indexes for properties and timestamps,
    and structures for tracking matched terms and text ranges in scope.

    Raises:
        ValueError: from __post_init__ if the conversation is not searchable.
    """

    # TODO: Make property and timestamp indexes NON-OPTIONAL
    # TODO: Move non-index based code to test

    conversation: IConversation[TMessage, TIndex]
    # If a property secondary index is available, the query processor will use it
    property_index: IPropertyToSemanticRefIndex | None = None
    # If a timestamp secondary index is available, the query processor will use it
    timestamp_index: ITimestampToTextRangeIndex | None = None
    # Terms already matched during this evaluation (used to avoid double counting).
    matched_terms: TermSet = field(init=False, default_factory=TermSet)
    # (property name, term) pairs already matched during this evaluation.
    matched_property_terms: PropertyTermSet = field(
        init=False, default_factory=PropertyTermSet
    )
    # Scope restriction for term lookups; replaced by scope expressions at eval time.
    text_ranges_in_scope: TextRangesInScope | None = field(
        init=False, default_factory=TextRangesInScope
    )

    def __post_init__(self):
        # Fail fast: evaluating a query over an un-indexed conversation is meaningless.
        if not is_conversation_searchable(self.conversation):
            raise ValueError(
                f"{self.conversation.name_tag} "
                + "is not initialized and cannot be searched."
            )

    @property
    def semantic_ref_index(self) -> ITermToSemanticRefIndex:
        # Non-None guaranteed by the searchability check in __post_init__.
        assert self.conversation.semantic_ref_index is not None
        return self.conversation.semantic_ref_index

    @property
    def semantic_refs(self) -> ISemanticRefCollection:
        # Non-None guaranteed by the searchability check in __post_init__.
        assert self.conversation.semantic_refs is not None
        return self.conversation.semantic_refs

    @property
    def messages(self) -> IMessageCollection:
        return self.conversation.messages

    async def get_semantic_ref(
        self, semantic_ref_ordinal: SemanticRefOrdinal
    ) -> SemanticRef:
        """Retrieve a semantic reference by its ordinal."""
        assert self.conversation.semantic_refs is not None
        return await self.conversation.semantic_refs.get_item(semantic_ref_ordinal)

    async def get_message_for_ref(self, semantic_ref: SemanticRef) -> TMessage:
        """Retrieve the message associated with a semantic reference."""
        # A ref's range starts inside the message that contains it.
        message_ordinal = semantic_ref.range.start.message_ordinal
        return await self.conversation.messages.get_item(message_ordinal)

    async def get_message(self, message_ordinal: MessageOrdinal) -> TMessage:
        """Retrieve a message by its ordinal."""
        return await self.messages.get_item(message_ordinal)

    def clear_matched_terms(self) -> None:
        """Clear all matched terms and property terms."""
        self.matched_terms.clear()
        self.matched_property_terms.clear()
284
+
285
+
286
async def lookup_knowledge_type(
    semantic_refs: ISemanticRefCollection, knowledge_type: KnowledgeType
) -> list[ScoredSemanticRefOrdinal]:
    """Return every semantic ref of *knowledge_type*, each scored 1.0."""
    results: list[ScoredSemanticRefOrdinal] = []
    async for semantic_ref in semantic_refs:
        if semantic_ref.knowledge.knowledge_type == knowledge_type:
            results.append(
                ScoredSemanticRefOrdinal(semantic_ref.semantic_ref_ordinal, 1.0)
            )
    return results
294
+
295
+
296
class IQueryOpExpr[T](Protocol):
    """Protocol for query operation expressions that can be evaluated in a context."""

    async def eval(self, context: QueryEvalContext) -> T:
        """Evaluate this expression against *context* and return its result."""
        ...
300
+
301
+
302
class QueryOpExpr[T](IQueryOpExpr[T]):
    """Base class for query operation expressions."""

    # Intentionally empty: concrete ops subclass this and implement eval().
304
+
305
+
306
+ @dataclass
307
+ class SelectTopNExpr[T: MatchAccumulator](QueryOpExpr[T]):
308
+ """Expression for selecting the top N matches from a query."""
309
+
310
+ source_expr: IQueryOpExpr[T]
311
+ max_matches: int | None = None
312
+ min_hit_count: int | None = None
313
+
314
+ async def eval(self, context: QueryEvalContext) -> T:
315
+ """Evaluate the expression and return the top N matches."""
316
+ matches = await self.source_expr.eval(context)
317
+ matches.select_top_n_scoring(self.max_matches, self.min_hit_count)
318
+ return matches
319
+
320
+
321
# Abstract base class.
class MatchTermsBooleanExpr(QueryOpExpr[SemanticRefAccumulator]):
    """Expression for matching terms in a boolean query.

    Subclasses implement 'OR', 'OR MAX' and 'AND' logic.
    """

    # Optional scope expression; when set, eval() narrows lookups to its ranges.
    get_scope_expr: "GetScopeExpr | None" = None

    async def begin_match(self, context: QueryEvalContext) -> None:
        """Prepare for matching terms in the context by resetting some things."""
        if self.get_scope_expr is not None:
            # Establish scope BEFORE any term lookups happen.
            context.text_ranges_in_scope = await self.get_scope_expr.eval(context)
        # Reset dedup bookkeeping so this boolean expression counts fresh.
        context.clear_matched_terms()
335
+
336
+
337
@dataclass
class MatchTermsOrExpr(MatchTermsBooleanExpr):
    """Expression for matching terms with an OR condition."""

    term_expressions: list[IQueryOpExpr[SemanticRefAccumulator | None]] = field(
        default_factory=list
    )
    get_scope_expr: "GetScopeExpr | None" = None

    async def eval(self, context: QueryEvalContext) -> SemanticRefAccumulator:
        """Union the matches of all term expressions; never returns None."""
        await self.begin_match(context)
        all_matches: SemanticRefAccumulator | None = None
        for match_expr in self.term_expressions:
            term_matches = await match_expr.eval(context)
            if term_matches:
                if all_matches is None:
                    # Adopt the first non-empty accumulator as the union seed.
                    all_matches = term_matches
                else:
                    all_matches.add_union(term_matches)
        if all_matches is not None:
            all_matches.calculate_total_score()
        # NOTE(review): relies on an empty accumulator being falsy (presumably
        # via __len__) so callers always get a usable accumulator -- confirm.
        return all_matches or SemanticRefAccumulator()
359
+
360
+
361
@dataclass
class MatchTermsOrMaxExpr(MatchTermsOrExpr):
    """OR-MAX returns the union if there are no common matches, else the maximum scoring match."""

    term_expressions: list[IQueryOpExpr[SemanticRefAccumulator | None]] = field(
        default_factory=list
    )
    get_scope_expr: "GetScopeExpr | None" = None

    async def eval(self, context: QueryEvalContext) -> SemanticRefAccumulator:
        """Evaluate as OR, then narrow to matches with the highest hit count."""
        union = await super().eval(context)
        best_hit_count = union.get_max_hit_count()
        # A max hit count <= 1 means no ref matched more than one term,
        # so the plain union is returned unchanged.
        if best_hit_count > 1:
            union.select_with_hit_count(best_hit_count)
        return union
376
+
377
+
378
@dataclass
class MatchTermsAndExpr(MatchTermsBooleanExpr):
    """Expression for matching terms with an AND condition.

    A semantic ref survives only if it matches every term expression.
    """

    term_expressions: list[IQueryOpExpr[SemanticRefAccumulator | None]] = field(
        default_factory=list
    )
    get_scope_expr: "GetScopeExpr | None" = None

    async def eval(self, context: QueryEvalContext) -> SemanticRefAccumulator:
        """Intersect the matches of all term expressions; never returns None."""
        await self.begin_match(context)
        all_matches: SemanticRefAccumulator | None = None
        for match_expr in self.term_expressions:
            term_matches = await match_expr.eval(context)
            if not term_matches:
                # One AND operand produced nothing: the intersection is empty.
                if all_matches is not None:
                    all_matches.clear_matches()
                break
            if all_matches is None:
                all_matches = term_matches
            else:
                all_matches = all_matches.intersect(term_matches)
        if all_matches is not None:
            all_matches.calculate_total_score()
            # Keep only refs that matched every term expression.
            all_matches.select_with_hit_count(len(self.term_expressions))
        else:
            all_matches = SemanticRefAccumulator()
        return all_matches
404
+
405
+
406
class MatchTermExpr(QueryOpExpr[SemanticRefAccumulator | None], ABC):
    """Expression for matching terms in a query.

    Subclasses need to define accumulate_matches(), which must add
    matches to its SemanticRefAccumulator argument.
    """

    async def eval(self, context: QueryEvalContext) -> SemanticRefAccumulator | None:
        """Return an accumulator with matches, or None when nothing matched."""
        accumulated = SemanticRefAccumulator()
        await self.accumulate_matches(context, accumulated)
        return accumulated if len(accumulated) > 0 else None

    @abstractmethod
    async def accumulate_matches(
        self, context: QueryEvalContext, matches: SemanticRefAccumulator
    ) -> None: ...
424
+
425
+
426
# Callback that may rescore a single scored semantic-ref match, given the
# search term that produced it and the resolved SemanticRef itself.
type ScoreBoosterType = Callable[
    [SearchTerm, SemanticRef, ScoredSemanticRefOrdinal],
    ScoredSemanticRefOrdinal,
]
430
+
431
+
432
@dataclass
class MatchSearchTermExpr(MatchTermExpr):
    """Matches a single SearchTerm (and its related terms) against the index.

    An optional score_booster callback can rescore each raw match.
    """

    search_term: SearchTerm
    # Optional hook to adjust the score of each scored semantic ref.
    score_booster: ScoreBoosterType | None = None

    async def accumulate_matches(
        self, context: QueryEvalContext, matches: SemanticRefAccumulator
    ) -> None:
        """Accumulate matches for the search term and its related terms."""
        # Match the search term
        await self.accumulate_matches_for_term(context, matches, self.search_term.term)

        # And any related terms
        if self.search_term.related_terms is not None:
            for related_term in self.search_term.related_terms:
                await self.accumulate_matches_for_term(
                    context, matches, self.search_term.term, related_term
                )

    async def lookup_term(
        self, context: QueryEvalContext, term: Term
    ) -> list[ScoredSemanticRefOrdinal] | None:
        """Look up a term in the semantic reference index."""
        matches = await lookup_term(
            context.semantic_ref_index,
            term,
            context.semantic_refs,
            context.text_ranges_in_scope,
        )
        if matches and self.score_booster:
            # Rewrite each match in place with its boosted score.
            for i in range(len(matches)):
                matches[i] = self.score_booster(
                    self.search_term,
                    await context.get_semantic_ref(matches[i].semantic_ref_ordinal),
                    matches[i],
                )
        return matches

    async def accumulate_matches_for_term(
        self,
        context: QueryEvalContext,
        matches: SemanticRefAccumulator,
        term: Term,
        related_term: Term | None = None,
    ) -> None:
        """Accumulate matches for a term or a related term.

        Primary terms match at full weight; related terms match at their
        own weight and only add refs not already matched by this expression.
        """
        if related_term is None:
            # Dedup: skip terms already matched during this boolean evaluation.
            if term not in context.matched_terms:
                semantic_refs = await self.lookup_term(context, term)
                matches.add_term_matches(term, semantic_refs, True)
                context.matched_terms.add(term)
        else:
            if related_term not in context.matched_terms:
                # If this related term had not already matched as a related term for some other term
                # Minimize over counting
                semantic_refs = await self.lookup_term(context, related_term)
                # This will only consider semantic refs that have not already matched this expression.
                # In other words, if a semantic ref already matched due to the term 'novel',
                # don't also match it because it matched the related term 'book'
                matches.add_term_matches_if_new(
                    term, semantic_refs, False, related_term.weight
                )
                context.matched_terms.add(related_term)
495
+
496
+
497
@dataclass
class MatchPropertySearchTermExpr(MatchTermExpr):
    """Matches a PropertySearchTerm against the property index.

    When property_name is a plain string it names a well-known property;
    otherwise it is itself a SearchTerm and the pair is matched as a
    facet name / facet value combination.
    """

    property_search_term: PropertySearchTerm

    async def accumulate_matches(
        self, context: QueryEvalContext, matches: SemanticRefAccumulator
    ) -> None:
        # str name -> known property; SearchTerm name -> facet name/value.
        if isinstance(self.property_search_term.property_name, str):
            await self.accumulate_matches_for_property(
                context,
                self.property_search_term.property_name,
                self.property_search_term.property_value,
                matches,
            )
        else:
            await self.accumulate_matches_for_facets(
                context,
                self.property_search_term.property_name,
                self.property_search_term.property_value,
                matches,
            )

    async def accumulate_matches_for_facets(
        self,
        context: QueryEvalContext,
        property_name: SearchTerm,
        property_value: SearchTerm,
        matches: SemanticRefAccumulator,
    ):
        """Match the facet name, plus the facet value unless it is a wildcard."""
        await self.accumulate_matches_for_property(
            context,
            PropertyNames.FacetName.value,
            property_name,
            matches,
        )
        if not is_search_term_wildcard(property_value):
            await self.accumulate_matches_for_property(
                context,
                PropertyNames.FacetValue.value,
                property_value,
                matches,
            )

    async def accumulate_matches_for_property(
        self,
        context: QueryEvalContext,
        property_name: str,
        property_value: SearchTerm,
        matches: SemanticRefAccumulator,
    ):
        """Match a property's value term and all of its related terms."""
        await self.accumulate_matches_for_property_value(
            context,
            matches,
            property_name,
            property_value.term,
        )
        if property_value.related_terms:
            for related_property_value in property_value.related_terms:
                await self.accumulate_matches_for_property_value(
                    context,
                    matches,
                    property_name,
                    property_value.term,
                    related_property_value,
                )

    async def accumulate_matches_for_property_value(
        self,
        context: QueryEvalContext,
        matches: SemanticRefAccumulator,
        property_name: str,
        property_value: Term,
        related_prop_val: Term | None = None,
    ) -> None:
        """Match one concrete (property name, value) pair.

        Primary values match at full weight; related values match at their
        own weight and only add refs not already matched by this expression.
        """
        if related_prop_val is None:
            # Dedup: skip (name, value) pairs already matched this evaluation.
            if not context.matched_property_terms.has(property_name, property_value):
                semantic_refs = await self.lookup_property(
                    context,
                    property_name,
                    property_value.text,
                )
                if semantic_refs:
                    matches.add_term_matches(property_value, semantic_refs, True)
                context.matched_property_terms.add(property_name, property_value)
        else:
            # To prevent over-counting, ensure this related_prop_val was not already used to match terms earlier
            if not context.matched_property_terms.has(property_name, related_prop_val):
                semantic_refs = await self.lookup_property(
                    context,
                    property_name,
                    related_prop_val.text,
                )
                if semantic_refs:
                    # This will only consider semantic refs that were not already matched by this expression.
                    # In other words, if a semantic ref already matched due to the term 'novel',
                    # don't also match it because it matched the related term 'book'
                    matches.add_term_matches_if_new(
                        property_value,
                        semantic_refs,
                        False,
                        related_prop_val.weight,
                    )
                context.matched_property_terms.add(property_name, related_prop_val)

    async def lookup_property(
        self,
        context: QueryEvalContext,
        property_name: str,
        property_value: str,
    ) -> list[ScoredSemanticRefOrdinal] | None:
        """Look up a (name, value) pair in the property index.

        NOTE(review): implicitly returns None when no property index is
        configured; callers treat None as "no matches". See the class-level
        TODO about making the index non-optional.
        """
        if context.property_index is not None:
            return await lookup_property_in_property_index(
                context.property_index,
                property_name,
                property_value,
                context.semantic_refs,
                context.text_ranges_in_scope,
            )
615
+
616
+
617
class MatchTagExpr(MatchSearchTermExpr):
    """Matches 'tag' knowledge; a '*' tag term matches every tag."""

    def __init__(self, tag_term: SearchTerm):
        self.tag_term = tag_term
        super().__init__(tag_term)

    async def lookup_term(
        self, context: QueryEvalContext, term: Term
    ) -> list[ScoredSemanticRefOrdinal] | None:
        """Look up *term* restricted to tag knowledge ('*' means all tags)."""
        if self.tag_term.term.text == "*":
            # Wildcard: every semantic ref whose knowledge type is "tag".
            return await lookup_knowledge_type(context.semantic_refs, "tag")
        return await lookup_term(
            context.semantic_ref_index,
            term,
            context.semantic_refs,
            context.text_ranges_in_scope,
            "tag",
        )
635
+
636
+
637
class MatchTopicExpr(MatchSearchTermExpr):
    """Matches 'topic' knowledge; a '*' topic term matches every topic."""

    def __init__(self, topic: SearchTerm):
        self.topic = topic
        super().__init__(topic)

    async def lookup_term(
        self, context: QueryEvalContext, term: Term
    ) -> list[ScoredSemanticRefOrdinal] | None:
        """Look up *term* restricted to topic knowledge ('*' means all topics)."""
        if self.topic.term.text == "*":
            # Wildcard: every semantic ref whose knowledge type is "topic".
            return await lookup_knowledge_type(context.semantic_refs, "topic")
        return await lookup_term(
            context.semantic_ref_index,
            term,
            context.semantic_refs,
            context.text_ranges_in_scope,
            "topic",
        )
655
+
656
+
657
@dataclass
class GroupByKnowledgeTypeExpr(
    QueryOpExpr[dict[KnowledgeType, SemanticRefAccumulator]]
):
    """Groups a semantic-ref accumulator by the knowledge type of each match."""

    matches: IQueryOpExpr[SemanticRefAccumulator]

    async def eval(
        self, context: QueryEvalContext
    ) -> dict[KnowledgeType, SemanticRefAccumulator]:
        """Evaluate the source expression, then bucket its matches by type."""
        accumulator = await self.matches.eval(context)
        return await accumulator.group_matches_by_type(context.semantic_refs)
668
+
669
+
670
@dataclass
class SelectTopNKnowledgeGroupExpr(
    QueryOpExpr[dict[KnowledgeType, SemanticRefAccumulator]]
):
    """Prunes each knowledge-type group to its top-N scoring matches."""

    source_expr: IQueryOpExpr[dict[KnowledgeType, SemanticRefAccumulator]]
    max_matches: int | None = None
    min_hit_count: int | None = None

    async def eval(
        self, context: QueryEvalContext
    ) -> dict[KnowledgeType, SemanticRefAccumulator]:
        """Return the groups with every accumulator pruned in place."""
        groups = await self.source_expr.eval(context)
        for group in groups.values():
            group.select_top_n_scoring(self.max_matches, self.min_hit_count)
        return groups
685
+
686
+
687
@dataclass
class GroupSearchResultsExpr(QueryOpExpr[dict[KnowledgeType, SemanticRefSearchResult]]):
    """Converts grouped accumulators into grouped SemanticRefSearchResults."""

    src_expr: IQueryOpExpr[dict[KnowledgeType, SemanticRefAccumulator]]

    async def eval(
        self, context: QueryEvalContext
    ) -> dict[KnowledgeType, SemanticRefSearchResult]:
        """Evaluate the source groups and reshape them into search results."""
        grouped = await self.src_expr.eval(context)
        return to_grouped_search_results(grouped)
695
+
696
+
697
@dataclass
class WhereSemanticRefExpr(QueryOpExpr[SemanticRefAccumulator]):
    """Filters an accumulator's matches through a list of predicates (AND).

    A match is kept only if every predicate accepts its semantic ref.
    """

    source_expr: IQueryOpExpr[SemanticRefAccumulator]
    predicates: list["IQuerySemanticRefPredicate"]

    async def eval(self, context: QueryEvalContext) -> SemanticRefAccumulator:
        """Return a new accumulator holding only matches passing all predicates."""
        accumulator = await self.source_expr.eval(context)
        # Carry over the term-match bookkeeping into the filtered accumulator.
        filtered = SemanticRefAccumulator(accumulator.search_term_matches)

        # Filter matches asynchronously (predicates may await index lookups).
        filtered_matches = []
        for match in accumulator.get_matches():
            if await self._eval_predicates(context, self.predicates, match):
                filtered_matches.append(match)

        filtered.set_matches(filtered_matches)
        return filtered

    async def _eval_predicates(
        self,
        context: QueryEvalContext,
        predicates: list["IQuerySemanticRefPredicate"],
        match: Match[SemanticRefOrdinal],
    ) -> bool:
        """Return True iff *match* satisfies every predicate (short-circuits).

        The semantic ref is fetched once, outside the predicate loop; the
        fetch does not depend on the predicate, so re-fetching it per
        predicate (as before) only added redundant awaited lookups.
        """
        semantic_ref = await context.get_semantic_ref(match.value)
        for predicate in predicates:
            if not await predicate.eval(context, semantic_ref):
                return False
        return True
727
+
728
+
729
class IQuerySemanticRefPredicate(Protocol):
    """Protocol for an async predicate over a single semantic ref."""

    async def eval(
        self, context: QueryEvalContext, semantic_ref: SemanticRef
    ) -> bool:
        """Return True when *semantic_ref* should be kept."""
        ...
733
+
734
+
735
+ # TODO: match_predicates
736
+ # TODO: knowledge_type_predicate
737
+ # TODO: property_match_predicate
738
+
739
+
740
+ # NOTE: GetScopeExpr is moved after TextRangeSelector to avoid circular references.
741
+
742
+
743
class IQueryTextRangeSelector(Protocol):
    """Protocol for a selector that can evaluate to a text range."""

    async def eval(
        self,
        context: QueryEvalContext,
        # Optional prior matches a selector may use to narrow its result.
        semantic_refs: SemanticRefAccumulator | None = None,
    ) -> TextRangeCollection | None:
        """Evaluate the selector and return the text range."""
        ...
753
+
754
+
755
class TextRangeSelector(IQueryTextRangeSelector):
    """A selector that evaluates to a pre-computed TextRangeCollection."""

    text_ranges_in_scope: TextRangeCollection

    def __init__(self, ranges_in_scope: list[TextRange]) -> None:
        # NOTE(review): the True flag's meaning depends on the
        # TextRangeCollection constructor (presumably "already sorted") --
        # confirm against collections.py.
        self.text_ranges_in_scope = TextRangeCollection(ranges_in_scope, True)

    async def eval(
        self,
        context: QueryEvalContext,
        semantic_refs: SemanticRefAccumulator | None = None,
    ) -> TextRangeCollection | None:
        """Return the fixed collection; both arguments are ignored."""
        return self.text_ranges_in_scope
769
+
770
+
771
@dataclass
class GetScopeExpr(QueryOpExpr[TextRangesInScope]):
    """Expression for getting the scope of a query."""

    range_selectors: list[IQueryTextRangeSelector]

    async def eval(self, context: QueryEvalContext) -> TextRangesInScope:
        """Combine every selector's ranges into a single TextRangesInScope."""
        scope = TextRangesInScope()
        for range_selector in self.range_selectors:
            ranges = await range_selector.eval(context)
            if ranges is None:
                # Selectors may yield nothing; skip them.
                continue
            scope.add_text_ranges(ranges)
        return scope
785
+
786
+
787
+ # TODO: SelectInScopeExpr
788
+
789
+
790
@dataclass
class TextRangesInDateRangeSelector(IQueryTextRangeSelector):
    """Selects the text ranges whose timestamps fall in a given date range."""

    date_range_in_scope: DateRange

    async def eval(
        self,
        context: QueryEvalContext,
        semantic_refs: SemanticRefAccumulator | None = None,
    ) -> TextRangeCollection | None:
        """Evaluate the selector and return text ranges in the specified date range."""
        collected = TextRangeCollection()
        if context.timestamp_index is None:
            # No secondary index: fall back to scanning the messages directly.
            fallback_range = await get_text_range_for_date_range(
                context.conversation,
                self.date_range_in_scope,
            )
            if fallback_range is not None:
                collected.add_range(fallback_range)
        else:
            timestamped_ranges = await context.timestamp_index.lookup_range(
                self.date_range_in_scope,
            )
            for entry in timestamped_ranges:
                collected.add_range(entry.range)
        return collected
817
+
818
+
819
+ # TODO: TextRangesPredicateSelector
820
+ # TODO: TextRangesWithTagSelector
821
+ # TODO: TextRangesFromSemanticRefsSelector
822
+
823
+
824
@dataclass
class TextRangesFromMessagesSelector(IQueryTextRangeSelector):
    """Selects the text ranges of the messages produced by a message expression."""

    source_expr: IQueryOpExpr[MessageAccumulator]

    async def eval(
        self,
        context: QueryEvalContext,
        semantic_refs: SemanticRefAccumulator | None = None,
    ) -> TextRangeCollection | None:
        """Return whole-message ranges for every matched message ordinal."""
        matches = await self.source_expr.eval(context)
        ranges_in_scope: list[TextRange] | None = None
        if matches:
            # Sort ordinals so ranges come out in document order.
            all_ordinals = sorted(matches.get_matched_values())
            ranges_in_scope = text_ranges_from_message_ordinals(all_ordinals)
        # NOTE(review): ranges_in_scope may still be None here; presumably
        # TextRangeCollection treats None as empty -- confirm.
        return TextRangeCollection(ranges_in_scope)
839
+
840
+
841
# TODO: Move to messagelib.py
def text_ranges_from_message_ordinals(
    message_ordinals: list[MessageOrdinal],
) -> list[TextRange]:
    """Map each message ordinal to the text range covering that message."""
    return list(map(text_range_from_message, message_ordinals))
846
+
847
+
848
# TODO: Move to messagelib.py
def text_range_from_message(message_ordinal: MessageOrdinal) -> TextRange:
    """Build a text range that starts at the given message ordinal."""
    start_location = TextLocation(message_ordinal)
    return TextRange(start=start_location)
851
+
852
+
853
+ # TODO: ThreadSelector
854
+ # TODO: to_grouped_search_results
855
+
856
+
857
def to_grouped_search_results(
    eval_results: dict[KnowledgeType, SemanticRefAccumulator],
) -> dict[KnowledgeType, SemanticRefSearchResult]:
    """Convert per-knowledge-type accumulators into search results.

    Accumulators with no matches are dropped from the output.
    """
    results: dict[KnowledgeType, SemanticRefSearchResult] = {}
    for knowledge_type, accumulated in eval_results.items():
        if not len(accumulated):
            continue  # Skip knowledge types that matched nothing.
        results[knowledge_type] = SemanticRefSearchResult(
            term_matches=accumulated.search_term_matches,
            semantic_ref_matches=accumulated.to_scored_semantic_refs(),
        )
    return results
868
+
869
+
870
@dataclass
class MessagesFromKnowledgeExpr(QueryOpExpr[MessageAccumulator]):
    """Turns knowledge-level search results into a message-level accumulator."""

    # Either an expression producing knowledge matches, or the matches themselves.
    src_expr: (
        IQueryOpExpr[dict[KnowledgeType, SemanticRefSearchResult]]
        | dict[KnowledgeType, SemanticRefSearchResult]
    )

    async def eval(self, context: QueryEvalContext) -> MessageAccumulator:
        """Resolve the source (evaluating if needed) and project it onto messages."""
        if isinstance(self.src_expr, dict):
            knowledge_matches = self.src_expr
        else:
            knowledge_matches = await self.src_expr.eval(context)
        return await message_matches_from_knowledge_matches(
            context.semantic_refs,
            knowledge_matches,
        )
886
+
887
+
888
+ # TODO: SelectMessagesInCharBudget
889
+
890
+
891
@dataclass
class RankMessagesBySimilarityExpr(QueryOpExpr[MessageAccumulator]):
    """Re-ranks matched messages by embedding similarity, keeping at most
    `max_messages` and dropping matches below `threshold_score` when a
    message embedding index is available."""

    src_expr: IQueryOpExpr[MessageAccumulator]  # Produces the candidate messages.
    embedding: NormalizedEmbedding  # Query embedding to rank against.
    max_messages: int | None = None  # Cap on results; None means no cap.
    threshold_score: float | None = None  # Minimum similarity; None means no floor.

    async def eval(self, context: QueryEvalContext) -> MessageAccumulator:
        matches = await self.src_expr.eval(context)
        # Already within budget: re-ranking cannot change which messages survive.
        if self.max_messages is not None and len(matches) <= self.max_messages:
            return matches

        # Try to use the message embedding index for re-ranking if available.
        message_index = (
            None
            if context.conversation.secondary_indexes is None
            else context.conversation.secondary_indexes.message_index
        )
        if isinstance(message_index, IMessageTextEmbeddingIndex):
            message_ordinals = await self._get_message_ordinals_in_index(
                message_index, matches
            )
            # Only re-rank when every matched message is covered by the index;
            # otherwise a partial re-rank would silently drop the uncovered ones.
            if len(message_ordinals) == len(matches):
                matches.clear_matches()
                # NOTE(review): not awaited, unlike size() above — presumably a
                # synchronous lookup on this index type; confirm against the
                # IMessageTextEmbeddingIndex interface.
                ranked_messages = message_index.lookup_in_subset_by_embedding(
                    self.embedding,
                    message_ordinals,
                    self.max_messages,
                    self.threshold_score,
                )
                for match in ranked_messages:
                    matches.add(match.message_ordinal, match.score)
                return matches

        if self.max_messages is not None:
            # Can't re rank, so just take the top K from what we already have.
            matches.select_top_n_scoring(self.max_messages)
        return matches

    async def _get_message_ordinals_in_index(
        self, message_index, matches: MessageAccumulator
    ) -> list[MessageOrdinal]:
        """Collect matched message ordinals that fall inside the index's range."""
        message_ordinals: list[MessageOrdinal] = []
        index_size = await message_index.size()
        for message_ordinal in matches.get_matched_values():
            if message_ordinal >= index_size:
                # NOTE(review): `break` assumes get_matched_values() yields
                # ordinals in ascending order — TODO confirm; the length check
                # at the call site makes the end result the same either way.
                break
            message_ordinals.append(message_ordinal)
        return message_ordinals
940
+
941
+
942
@dataclass
class GetScoredMessagesExpr(QueryOpExpr[list[ScoredMessageOrdinal]]):
    """Flattens a message accumulator into a list of scored message ordinals."""

    src_expr: IQueryOpExpr[MessageAccumulator]

    async def eval(self, context: QueryEvalContext) -> list[ScoredMessageOrdinal]:
        accumulated = await self.src_expr.eval(context)
        return accumulated.to_scored_message_ordinals()
949
+
950
+
951
@dataclass
class MatchMessagesBooleanExpr(IQueryOpExpr[MessageAccumulator]):
    """Shared machinery for boolean (AND/OR) combinations of message matches."""

    # Sub-expressions may yield semantic-ref matches, message matches, or None.
    term_expressions: list[
        IQueryOpExpr[SemanticRefAccumulator | MessageAccumulator | None]
    ]

    def _begin_match(self, context: QueryEvalContext) -> None:
        """Reset per-evaluation term-match state before combining sub-expressions."""
        context.clear_matched_terms()

    async def _accumulate_messages(
        self,
        context: QueryEvalContext,
        semantic_ref_matches: SemanticRefAccumulator,
    ) -> MessageAccumulator:
        """Project semantic-ref matches down to the messages they refer to."""
        accumulated = MessageAccumulator()
        for ref_match in semantic_ref_matches:
            resolved_ref = await context.get_semantic_ref(ref_match.value)
            accumulated.add_messages_for_semantic_ref(
                resolved_ref,
                ref_match.score,
            )
        return accumulated
973
+
974
+
975
@dataclass
class MatchMessagesOrExpr(MatchMessagesBooleanExpr):
    """Boolean OR: union of the messages matched by any term expression."""

    async def eval(self, context: QueryEvalContext) -> MessageAccumulator:
        self._begin_match(context)

        combined: MessageAccumulator | None = None
        for term_expr in self.term_expressions:
            term_result = await term_expr.eval(context)
            if not term_result:
                continue  # A non-matching term contributes nothing to the union.
            if isinstance(term_result, SemanticRefAccumulator):
                as_messages = await self._accumulate_messages(context, term_result)
            else:
                as_messages = term_result
            if combined is None:
                combined = as_messages
            else:
                combined.add_union(as_messages)

        if combined is None:
            # Nothing matched at all: return an empty accumulator.
            return MessageAccumulator()
        combined.calculate_total_score()
        return combined
999
+
1000
+
1001
@dataclass
class MatchMessagesAndExpr(MatchMessagesBooleanExpr):
    """Boolean AND: intersection of the messages matched by every term expression."""

    async def eval(self, context: QueryEvalContext) -> MessageAccumulator:
        self._begin_match(context)

        combined: MessageAccumulator | None = None
        every_term_matched = True
        for term_expr in self.term_expressions:
            term_result = await term_expr.eval(context)
            if not term_result:
                # One term with no matches sinks the whole conjunction.
                every_term_matched = False
                break
            if isinstance(term_result, SemanticRefAccumulator):
                as_messages = await self._accumulate_messages(context, term_result)
            else:
                as_messages = term_result
            if combined is None:
                combined = as_messages
            else:
                # Keep only messages matched by both sides so far.
                combined = combined.intersect(as_messages)
                if not combined:
                    # Empty intersection: no point evaluating further terms.
                    every_term_matched = False
                    break

        if combined is None:
            return MessageAccumulator()
        if every_term_matched:
            combined.calculate_total_score()
            combined.select_with_hit_count(len(self.term_expressions))
        else:
            combined.clear_matches()
        return combined
1039
+
1040
+
1041
@dataclass
class MatchMessagesOrMaxExpr(MatchMessagesOrExpr):
    """OR variant that keeps only the messages with the highest hit count."""

    async def eval(self, context: QueryEvalContext) -> MessageAccumulator:
        union_matches = await super().eval(context)
        top_hit_count = union_matches.get_max_hit_count()
        # With a single hit everywhere there is nothing to narrow down.
        if top_hit_count > 1:
            union_matches.select_with_hit_count(top_hit_count)
        return union_matches
1050
+
1051
+
1052
+ # TODO: class MatchMessagesBySimilarityExpr(QueryOpExpr[list[ScoredMessageOrdinal]]):
1053
+
1054
+
1055
+ class NoOpExpr[T](QueryOpExpr[T]):
1056
+ def __init__(self, src_expr: IQueryOpExpr[T]) -> None:
1057
+ self.src_expr = src_expr
1058
+ super().__init__()
1059
+
1060
+ async def eval(self, context: QueryEvalContext) -> T:
1061
+ return await self.src_expr.eval(context)
1062
+
1063
+
1064
async def message_matches_from_knowledge_matches(
    semantic_refs: ISemanticRefCollection,
    knowledge_matches: dict[KnowledgeType, SemanticRefSearchResult],
    intersect_across_knowledge_types: bool = True,
) -> MessageAccumulator:
    """Project knowledge-level matches down to the messages they occur in.

    When `intersect_across_knowledge_types` is true and at least one
    knowledge type matched, only messages matched by every matching
    knowledge type are kept, and their scores are smoothed.
    """
    accumulated = MessageAccumulator()
    matched_type_count = 0  # How many types of knowledge matched?
    for _knowledge_type, type_result in knowledge_matches.items():
        if not (type_result and type_result.semantic_ref_matches):
            continue
        matched_type_count += 1
        for scored_ref in type_result.semantic_ref_matches:
            semantic_ref = await semantic_refs.get_item(scored_ref.semantic_ref_ordinal)
            accumulated.add_messages_for_semantic_ref(
                semantic_ref,
                scored_ref.score,
            )
    if intersect_across_knowledge_types and matched_type_count > 0:
        # Intersect the sets of messages that matched each knowledge type.
        surviving = accumulated.get_with_hit_count(matched_type_count)
        if surviving:
            accumulated = MessageAccumulator(surviving)
            accumulated.smooth_scores()
    return accumulated
1086
+
1087
+
1088
# TODO: Implement proper SelectMessagesInCharBudget functionality
@dataclass
class SelectMessagesInCharBudget(QueryOpExpr[MessageAccumulator]):
    """Trims matched messages so their combined text fits a character budget."""

    src_expr: IQueryOpExpr[MessageAccumulator]
    max_chars: int

    async def eval(self, context: QueryEvalContext) -> MessageAccumulator:
        """Evaluate the source, then prune matches in place to fit `max_chars`."""
        selected = await self.src_expr.eval(context)
        await selected.select_messages_in_budget(context.messages, self.max_chars)
        return selected
1100
+
1101
+
1102
# TODO: Implement proper KnowledgeTypePredicate functionality
@dataclass
class KnowledgeTypePredicate(IQuerySemanticRefPredicate):
    """Predicate matching semantic refs whose knowledge has a specific type."""

    knowledge_type: KnowledgeType

    async def eval(self, context: QueryEvalContext, semantic_ref: SemanticRef) -> bool:
        """True when the semantic ref carries knowledge of the configured type."""
        return self.knowledge_type == semantic_ref.knowledge.knowledge_type
1111
+
1112
+
1113
# TODO: Implement proper ThreadSelector functionality
@dataclass
class ThreadSelector(IQueryTextRangeSelector):
    """Selector yielding the text ranges of a fixed set of threads."""

    threads: list[Thread]

    async def eval(
        self,
        context: QueryEvalContext,
        semantic_refs: SemanticRefAccumulator | None = None,
    ) -> TextRangeCollection | None:
        """Collect every range of every configured thread into one collection."""
        collected = TextRangeCollection()
        for thread in self.threads:
            collected.add_ranges(list(thread.ranges))
        return collected