typeagent-py 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. typeagent/aitools/auth.py +61 -0
  2. typeagent/aitools/embeddings.py +232 -0
  3. typeagent/aitools/utils.py +244 -0
  4. typeagent/aitools/vectorbase.py +175 -0
  5. typeagent/knowpro/answer_context_schema.py +49 -0
  6. typeagent/knowpro/answer_response_schema.py +34 -0
  7. typeagent/knowpro/answers.py +577 -0
  8. typeagent/knowpro/collections.py +759 -0
  9. typeagent/knowpro/common.py +9 -0
  10. typeagent/knowpro/convknowledge.py +112 -0
  11. typeagent/knowpro/convsettings.py +94 -0
  12. typeagent/knowpro/convutils.py +49 -0
  13. typeagent/knowpro/date_time_schema.py +32 -0
  14. typeagent/knowpro/field_helpers.py +87 -0
  15. typeagent/knowpro/fuzzyindex.py +144 -0
  16. typeagent/knowpro/interfaces.py +818 -0
  17. typeagent/knowpro/knowledge.py +88 -0
  18. typeagent/knowpro/kplib.py +125 -0
  19. typeagent/knowpro/query.py +1128 -0
  20. typeagent/knowpro/search.py +628 -0
  21. typeagent/knowpro/search_query_schema.py +165 -0
  22. typeagent/knowpro/searchlang.py +729 -0
  23. typeagent/knowpro/searchlib.py +345 -0
  24. typeagent/knowpro/secindex.py +100 -0
  25. typeagent/knowpro/serialization.py +390 -0
  26. typeagent/knowpro/textlocindex.py +179 -0
  27. typeagent/knowpro/utils.py +17 -0
  28. typeagent/mcp/server.py +139 -0
  29. typeagent/podcasts/podcast.py +473 -0
  30. typeagent/podcasts/podcast_import.py +105 -0
  31. typeagent/storage/__init__.py +25 -0
  32. typeagent/storage/memory/__init__.py +13 -0
  33. typeagent/storage/memory/collections.py +68 -0
  34. typeagent/storage/memory/convthreads.py +81 -0
  35. typeagent/storage/memory/messageindex.py +178 -0
  36. typeagent/storage/memory/propindex.py +289 -0
  37. typeagent/storage/memory/provider.py +84 -0
  38. typeagent/storage/memory/reltermsindex.py +318 -0
  39. typeagent/storage/memory/semrefindex.py +660 -0
  40. typeagent/storage/memory/timestampindex.py +176 -0
  41. typeagent/storage/sqlite/__init__.py +31 -0
  42. typeagent/storage/sqlite/collections.py +362 -0
  43. typeagent/storage/sqlite/messageindex.py +382 -0
  44. typeagent/storage/sqlite/propindex.py +119 -0
  45. typeagent/storage/sqlite/provider.py +293 -0
  46. typeagent/storage/sqlite/reltermsindex.py +328 -0
  47. typeagent/storage/sqlite/schema.py +248 -0
  48. typeagent/storage/sqlite/semrefindex.py +156 -0
  49. typeagent/storage/sqlite/timestampindex.py +146 -0
  50. typeagent/storage/utils.py +41 -0
  51. typeagent_py-0.1.0.dist-info/METADATA +28 -0
  52. typeagent_py-0.1.0.dist-info/RECORD +55 -0
  53. typeagent_py-0.1.0.dist-info/WHEEL +5 -0
  54. typeagent_py-0.1.0.dist-info/licenses/LICENSE +21 -0
  55. typeagent_py-0.1.0.dist-info/top_level.txt +1 -0
typeagent/knowpro/collections.py
@@ -0,0 +1,759 @@
+ # Copyright (c) Microsoft Corporation.
+ # Licensed under the MIT License.
+
+ import bisect
+ from collections.abc import Callable, Iterable, Iterator
+ from dataclasses import dataclass, field
+ import heapq
+ import math
+ import sys
+ from typing import cast, Set  # Set is an alias for builtin set
+
+ from .interfaces import (
+     ICollection,
+     IMessage,
+     IMessageCollection,
+     ISemanticRefCollection,
+     Knowledge,
+     KnowledgeType,
+     MessageOrdinal,
+     ScoredMessageOrdinal,
+     ScoredSemanticRefOrdinal,
+     SemanticRef,
+     SemanticRefOrdinal,
+     Term,
+     TextRange,
+ )
+
+
+ @dataclass
+ class Match[T]:
+     value: T
+     score: float
+     hit_count: int
+     related_score: float
+     related_hit_count: int
+
+
+ # TODO: sortMatchesByRelevance
+
+
+ class MatchAccumulator[T]:
+     def __init__(self):
+         self._matches: dict[T, Match[T]] = {}
+
+     def __len__(self) -> int:
+         return len(self._matches)
+
+     def __iter__(self) -> Iterator[Match[T]]:
+         return iter(self._matches.values())
+
+     def __contains__(self, value: T) -> bool:
+         return value in self._matches
+
+     def get_match(self, value: T) -> Match[T] | None:
+         return self._matches.get(value)
+
+     def set_match(self, match: Match[T]) -> None:
+         self._matches[match.value] = match
+
+     # TODO: Maybe make the callers call clear_matches()?
+     def set_matches(self, matches: Iterable[Match[T]], *, clear: bool = False) -> None:
+         if clear:
+             self.clear_matches()
+         for match in matches:
+             self.set_match(match)
+
+     def get_max_hit_count(self) -> int:
+         count = 0
+         for match in self._matches.values():
+             count = max(count, match.hit_count)
+         return count
+
+     # TODO: Rename to add_exact if we ever add add_related
+     def add(self, value: T, score: float, is_exact_match: bool = True) -> None:
+         existing_match = self.get_match(value)
+         if existing_match is not None:
+             if is_exact_match:
+                 existing_match.hit_count += 1
+                 existing_match.score += score
+             else:
+                 existing_match.related_hit_count += 1
+                 existing_match.related_score += score
+         else:
+             if is_exact_match:
+                 self.set_match(
+                     Match(
+                         value,
+                         hit_count=1,
+                         score=score,
+                         related_hit_count=0,
+                         related_score=0.0,
+                     )
+                 )
+             else:
+                 self.set_match(
+                     Match(
+                         value,
+                         hit_count=1,
+                         score=0.0,
+                         related_hit_count=1,
+                         related_score=score,
+                     )
+                 )
+
+     def add_union(self, other: "MatchAccumulator[T]") -> None:
+         """Add matches from another collection of matches."""
+         for other_match in other:
+             existing_match = self.get_match(other_match.value)
+             if existing_match is None:
+                 self.set_match(other_match)
+             else:
+                 self.combine_matches(existing_match, other_match)
+
+     def intersect(
+         self, other: "MatchAccumulator[T]", intersection: "MatchAccumulator[T]"
+     ) -> "MatchAccumulator[T]":
+         """Intersect with another collection of matches."""
+         for self_match in self:
+             other_match = other.get_match(self_match.value)
+             if other_match is not None:
+                 self.combine_matches(self_match, other_match)
+                 intersection.set_match(self_match)
+         return intersection
+
+     def combine_matches(self, match: Match[T], other: Match[T]) -> None:
+         """Combine the other match into the first."""
+         match.hit_count += other.hit_count
+         match.score += other.score
+         match.related_hit_count += other.related_hit_count
+         match.related_score += other.related_score
+
+     def calculate_total_score(
+         self, scorer: Callable[[Match[T]], None] | None = None
+     ) -> None:
+         if scorer is None:
+             scorer = add_smooth_related_score_to_match_score
+         for match in self:
+             scorer(match)
+
+     def get_sorted_by_score(self, min_hit_count: int | None = None) -> list[Match[T]]:
+         """Get matches sorted by score"""
+         if len(self._matches) == 0:
+             return []
+         matches = [*self._matches_with_min_hit_count(min_hit_count)]
+         matches.sort(key=lambda m: m.score, reverse=True)
+         return matches
+
+     def get_top_n_scoring(
+         self,
+         max_matches: int | None = None,
+         min_hit_count: int | None = None,
+     ) -> list[Match[T]]:
+         """Get the top N scoring matches."""
+         if not self._matches:
+             return []
+         if max_matches and max_matches > 0:
+             top_list = TopNList[T](max_matches)
+             for match in self._matches_with_min_hit_count(min_hit_count):
+                 top_list.push(match.value, match.score)
+             ranked = top_list.by_rank()
+             return [self._matches[match.item] for match in ranked]
+         else:
+             return self.get_sorted_by_score(min_hit_count)
+
+     def get_with_hit_count(self, min_hit_count: int) -> list[Match[T]]:
+         """Get matches with a minimum hit count."""
+         return list(self.matches_with_min_hit_count(min_hit_count))
+
+     def get_matches(
+         self, predicate: Callable[[Match[T]], bool] | None = None
+     ) -> Iterator[Match[T]]:
+         """Iterate over all matches."""
+         if predicate is None:
+             return iter(self._matches.values())
+         else:
+             return filter(predicate, self._matches.values())
+
+     def get_matched_values(self) -> Iterator[T]:
+         """Iterate over all matched values."""
+         return iter(self._matches)
+
+     def clear_matches(self):
+         self._matches.clear()
+
+     def select_top_n_scoring(
+         self,
+         max_matches: int | None = None,
+         min_hit_count: int | None = None,
+     ) -> int:
+         """Retain only the top N matches sorted by score."""
+         top_n = self.get_top_n_scoring(max_matches, min_hit_count)
+         self.set_matches(top_n, clear=True)
+         return len(top_n)
+
+     def select_with_hit_count(self, min_hit_count: int) -> int:
+         """Retain only matches with a minimum hit count."""
+         matches = self.get_with_hit_count(min_hit_count)
+         self.set_matches(matches, clear=True)
+         return len(matches)
+
+     def _matches_with_min_hit_count(
+         self, min_hit_count: int | None
+     ) -> Iterable[Match[T]]:
+         """Get matches with a minimum hit count"""
+         if min_hit_count is not None and min_hit_count > 0:
+             return self.get_matches(lambda m: m.hit_count >= min_hit_count)
+         else:
+             return self._matches.values()
+
+     def matches_with_min_hit_count(
+         self, min_hit_count: int | None
+     ) -> Iterable[Match[T]]:
+         if min_hit_count is not None and min_hit_count > 0:
+             return filter(lambda m: m.hit_count >= min_hit_count, self.get_matches())
+         else:
+             return self._matches.values()
+
+
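Note: MatchAccumulator keeps exact and related-term contributions in separate score/hit-count fields and only folds the related part into the main score when calculate_total_score() is called. A minimal usage sketch (not part of the diff; the import path follows the file list above):

    from typeagent.knowpro.collections import MatchAccumulator

    acc = MatchAccumulator[str]()
    acc.add("budget", 1.0)                          # exact hit
    acc.add("budget", 0.5)                          # second exact hit: score 1.5, hit_count 2
    acc.add("forecast", 0.8, is_exact_match=False)  # related hit only
    acc.calculate_total_score()                     # folds smoothed related scores into .score
    top = acc.get_top_n_scoring(max_matches=1)
    print([(m.value, m.score, m.hit_count) for m in top])  # [('budget', 1.5, 2)]
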
+ def get_smooth_score(
+     total_score: float,
+     hit_count: int,
+ ) -> float:
+     """See the long comment in collections.ts for an explanation."""
+     if hit_count > 0:
+         if hit_count == 1:
+             return total_score
+         avg = total_score / hit_count
+         smooth_avg = math.log(hit_count + 1) * avg
+         return smooth_avg
+     else:
+         return 0.0
+
+
+ def add_smooth_related_score_to_match_score[T](match: Match[T]) -> None:
+     """Add the smooth related score to the match score."""
+     if match.related_hit_count > 0:
+         # Related term matches can be noisy and duplicative.
+         # See the comment on getSmoothScore in collections.ts.
+         smooth_related_score = get_smooth_score(
+             match.related_score, match.related_hit_count
+         )
+         match.score += smooth_related_score
+
+
+ def smooth_match_score[T](match: Match[T]) -> None:
+     if match.hit_count > 0:
+         match.score = get_smooth_score(match.score, match.hit_count)
+
+
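Note: for hit_count n > 1 the smoothed value is log(n + 1) times the average score, so many duplicative hits grow the total only logarithmically instead of linearly. A couple of worked values:

    from typeagent.knowpro.collections import get_smooth_score

    get_smooth_score(2.0, 1)    # 2.0  (single hits pass through unchanged)
    get_smooth_score(3.0, 3)    # log(4) * 1.0  ~= 1.386
    get_smooth_score(10.0, 10)  # log(11) * 1.0 ~= 2.398
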
+ type KnowledgePredicate[T: Knowledge] = Callable[[T], bool]
+
+
+ class SemanticRefAccumulator(MatchAccumulator[SemanticRefOrdinal]):
+     def __init__(self, search_term_matches: set[str] = set()):
+         super().__init__()
+         self.search_term_matches = search_term_matches
+
+     def add_term_matches(
+         self,
+         search_term: Term,
+         scored_refs: Iterable[ScoredSemanticRefOrdinal] | None,
+         is_exact_match: bool,
+         *,
+         weight: float | None = None,
+     ) -> None:
+         """Add term matches to the accumulator"""
+         if scored_refs is not None:
+             if weight is None:
+                 weight = search_term.weight
+             if weight is None:
+                 weight = 1.0
+             for scored_ref in scored_refs:
+                 self.add(
+                     scored_ref.semantic_ref_ordinal,
+                     scored_ref.score * weight,
+                     is_exact_match,
+                 )
+             self.search_term_matches.add(search_term.text)
+
+     def add_term_matches_if_new(
+         self,
+         search_term: Term,
+         scored_refs: Iterable[ScoredSemanticRefOrdinal] | None,
+         is_exact_match: bool,
+         weight: float | None = None,
+     ) -> None:
+         """Add term matches if they are new."""
+         if scored_refs is not None:
+             if weight is None:
+                 weight = search_term.weight
+             if weight is None:
+                 weight = 1.0
+             for scored_ref in scored_refs:
+                 if scored_ref.semantic_ref_ordinal not in self:
+                     self.add(
+                         scored_ref.semantic_ref_ordinal,
+                         scored_ref.score * weight,
+                         is_exact_match,
+                     )
+             self.search_term_matches.add(search_term.text)
+
+     async def get_semantic_refs(
+         self,
+         semantic_refs: ISemanticRefCollection,
+         predicate: Callable[[SemanticRef], bool],
+     ) -> list[SemanticRef]:
+         result = []
+         for match in self:
+             semantic_ref = await semantic_refs.get_item(match.value)
+             if predicate is None or predicate(semantic_ref):
+                 result.append(semantic_ref)
+         return result
+
+     def get_matches_of_type[T: Knowledge](
+         self,
+         semantic_refs: list[SemanticRef],
+         knowledgeType: KnowledgeType,
+         predicate: KnowledgePredicate[T] | None = None,
+     ) -> Iterable[Match[SemanticRefOrdinal]]:
+         for match in self:
+             semantic_ref = semantic_refs[match.value]
+             if predicate is None or predicate(cast(T, semantic_ref.knowledge)):
+                 yield match
+
+     async def group_matches_by_type(
+         self,
+         semantic_refs: ISemanticRefCollection,
+     ) -> dict[KnowledgeType, "SemanticRefAccumulator"]:
+         groups: dict[KnowledgeType, SemanticRefAccumulator] = {}
+         for match in self:
+             semantic_ref = await semantic_refs.get_item(match.value)
+             group = groups.get(semantic_ref.knowledge.knowledge_type)
+             if group is None:
+                 group = SemanticRefAccumulator()
+                 group.search_term_matches = self.search_term_matches
+                 groups[semantic_ref.knowledge.knowledge_type] = group
+             group.set_match(match)
+         return groups
+
+     async def get_matches_in_scope(
+         self,
+         semantic_refs: ISemanticRefCollection,
+         ranges_in_scope: "TextRangesInScope",
+     ) -> "SemanticRefAccumulator":
+         accumulator = SemanticRefAccumulator(self.search_term_matches)
+         for match in self:
+             if ranges_in_scope.is_range_in_scope(
+                 (await semantic_refs.get_item(match.value)).range
+             ):
+                 accumulator.set_match(match)
+         return accumulator
+
+     def add_union(self, other: "MatchAccumulator[SemanticRefOrdinal]") -> None:
+         """Add matches from another SemanticRefAccumulator."""
+         assert isinstance(other, SemanticRefAccumulator)
+         super().add_union(other)
+         self.search_term_matches.update(other.search_term_matches)
+
+     def intersect(
+         self,
+         other: MatchAccumulator[SemanticRefOrdinal],
+         intersection: MatchAccumulator[SemanticRefOrdinal] | None = None,
+     ) -> "SemanticRefAccumulator":
+         """Intersect with another SemanticRefAccumulator."""
+         assert isinstance(other, SemanticRefAccumulator)
+         if intersection is None:
+             intersection = SemanticRefAccumulator()
+         else:
+             assert isinstance(intersection, SemanticRefAccumulator)
+         super().intersect(other, intersection)
+         if len(intersection) > 0:
+             intersection.search_term_matches.update(self.search_term_matches)
+             intersection.search_term_matches.update(other.search_term_matches)
+         return intersection
+
+     def to_scored_semantic_refs(self) -> list[ScoredSemanticRefOrdinal]:
+         """Convert the accumulator to a list of scored semantic references."""
+         return [
+             ScoredSemanticRefOrdinal(
+                 semantic_ref_ordinal=match.value,
+                 score=match.score,
+             )
+             for match in self.get_sorted_by_score()
+         ]
+
+
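Note: a hedged sketch of how term matches flow into the accumulator. It assumes Term can be constructed from a text string with an optional weight, as its uses above suggest (check interfaces.py for the real definition); ScoredSemanticRefOrdinal is constructed exactly as in to_scored_semantic_refs above.

    from typeagent.knowpro.collections import SemanticRefAccumulator
    from typeagent.knowpro.interfaces import ScoredSemanticRefOrdinal, Term

    acc = SemanticRefAccumulator(set())   # pass a fresh set; the default set() is shared
    term = Term("birthday")               # assumed constructor: Term(text, weight=None)
    scored = [
        ScoredSemanticRefOrdinal(semantic_ref_ordinal=3, score=0.9),
        ScoredSemanticRefOrdinal(semantic_ref_ordinal=7, score=0.4),
    ]
    acc.add_term_matches(term, scored, is_exact_match=True)
    print(acc.search_term_matches)        # {'birthday'}
    print(acc.to_scored_semantic_refs())  # highest score first
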
+ class MessageAccumulator(MatchAccumulator[MessageOrdinal]):
+     def __init__(self, matches: list[Match[MessageOrdinal]] | None = None):
+         super().__init__()
+         if matches:
+             self.set_matches(matches)
+
+     def add(
+         self, value: MessageOrdinal, score: float, is_exact_match: bool = True
+     ) -> None:
+         match = self.get_match(value)
+         if match is None:
+             match = Match(value, score, 1, 0.0, 0)
+             self.set_match(match)
+         elif score > match.score:
+             match.score = score
+             # TODO: Question(Guido->Umesh): Why not increment hit_count always?
+             match.hit_count += 1
+
+     # TODO: add_messages_from_locations
+
+     def add_messages_for_semantic_ref(
+         self,
+         semantic_ref: SemanticRef,
+         score: float,
+     ) -> None:
+         message_ordinal_start = semantic_ref.range.start.message_ordinal
+         if semantic_ref.range.end is not None:
+             message_ordinal_end = semantic_ref.range.end.message_ordinal
+             for message_ordinal in range(
+                 message_ordinal_start, message_ordinal_end + 1
+             ):
+                 self.add(message_ordinal, score)
+         else:
+             self.add(message_ordinal_start, score)
+
+     def add_scored_matches(self, scored_ordinals: list[ScoredMessageOrdinal]) -> None:
+         """Add scored message ordinals to the accumulator."""
+         for scored_ordinal in scored_ordinals:
+             self.add(scored_ordinal.message_ordinal, scored_ordinal.score)
+
+     def intersect(
+         self,
+         other: MatchAccumulator[MessageOrdinal],
+         intersection: MatchAccumulator[MessageOrdinal] | None = None,
+     ) -> "MessageAccumulator":
+         if intersection is None:
+             intersection = MessageAccumulator()
+         else:
+             assert isinstance(intersection, MessageAccumulator)
+         super().intersect(other, intersection)
+         return intersection
+
+     def smooth_scores(self) -> None:
+         for match in self:
+             smooth_match_score(match)
+
+     def to_scored_message_ordinals(self) -> list[ScoredMessageOrdinal]:
+         sorted_matches = self.get_sorted_by_score()
+         return [ScoredMessageOrdinal(m.value, m.score) for m in sorted_matches]
+
+     async def select_messages_in_budget(
+         self, messages: IMessageCollection, max_chars_in_budget: int
+     ) -> None:
+         """Select messages that fit within the character budget."""
+         scored_matches = self.get_sorted_by_score()
+         ranked_ordinals = [m.value for m in scored_matches]
+         message_count_in_budget = await get_count_of_messages_in_char_budget(
+             messages, ranked_ordinals, max_chars_in_budget
+         )
+         self.clear_matches()
+         if message_count_in_budget > 0:
+             scored_matches = scored_matches[:message_count_in_budget]
+             self.set_matches(scored_matches)
+
+     @staticmethod
+     def from_scored_ordinals(
+         ordinals: list[ScoredMessageOrdinal] | None,
+     ) -> "MessageAccumulator":
+         """Create a MessageAccumulator from scored ordinals."""
+         accumulator = MessageAccumulator()
+         if ordinals and len(ordinals) > 0:
+             accumulator.add_scored_matches(ordinals)
+         return accumulator
+
+
+ # TODO: intersectScoredMessageOrdinals
+
+
+ @dataclass
+ class TextRangeCollection(Iterable[TextRange]):
+     _ranges: list[TextRange]
+
+     def __init__(
+         self,
+         ranges: list[TextRange] | None = None,
+         ensure_sorted: bool = False,
+     ) -> None:
+         if ensure_sorted:
+             self._ranges = []
+             if ranges:
+                 self.add_ranges(ranges)
+         else:
+             self._ranges = ranges if ranges is not None else []
+
+     def __len__(self) -> int:
+         return len(self._ranges)
+
+     def __iter__(self) -> Iterator[TextRange]:
+         return iter(self._ranges)
+
+     def get_ranges(self) -> list[TextRange]:
+         return self._ranges  # TODO: Maybe return a copy?
+
+     def add_range(self, text_range: TextRange) -> bool:
+         # This assumes TextRanges are totally ordered.
+         pos = bisect.bisect_left(self._ranges, text_range)
+         if pos < len(self._ranges) and self._ranges[pos] == text_range:
+             return False
+         self._ranges.insert(pos, text_range)
+         return True
+
+     def add_ranges(self, text_ranges: "list[TextRange] | TextRangeCollection") -> None:
+         if isinstance(text_ranges, list):
+             for text_range in text_ranges:
+                 self.add_range(text_range)
+         else:
+             assert isinstance(text_ranges, TextRangeCollection)
+             for text_range in text_ranges._ranges:
+                 self.add_range(text_range)
+
+     def is_in_range(self, inner_range: TextRange) -> bool:
+         if len(self._ranges) == 0:
+             return False
+         i = bisect.bisect_left(self._ranges, inner_range)
+         for outer_range in self._ranges[i:]:
+             if outer_range.start > inner_range.start:
+                 break
+             if inner_range in outer_range:
+                 return True
+         return False
+
+
+ @dataclass
+ class TextRangesInScope:
+     text_ranges: list[TextRangeCollection] | None = None
+
+     def add_text_ranges(
+         self,
+         ranges: TextRangeCollection,
+     ) -> None:
+         if self.text_ranges is None:
+             self.text_ranges = []
+         self.text_ranges.append(ranges)
+
+     def is_range_in_scope(self, inner_range: TextRange) -> bool:
+         if self.text_ranges is not None:
+             # Since outer ranges come from a set of range selectors, they may overlap, or may not agree.
+             # Outer ranges allowed by say a date range selector... may not be allowed by a tag selector.
+             # We have a very simple impl: we don't intersect/union ranges yet.
+             # Instead, we ensure that the inner range is not rejected by any outer ranges.
+             for outer_ranges in self.text_ranges:
+                 if not outer_ranges.is_in_range(inner_range):
+                     return False
+         return True
+
+
+ @dataclass
+ class TermSet:
+     """A collection of terms with support for adding, updating, and retrieving terms."""
+
+     terms: dict[str, Term]
+
+     def __init__(self, terms: list[Term] | None = None):
+         self.terms = {}
+         self.add_or_union(terms)
+
+     def __len__(self) -> int:
+         """Return the number of terms in the set."""
+         return len(self.terms)
+
+     def add(self, term: Term) -> bool:
+         """Add a term to the set if it doesn't already exist."""
+         if term.text in self.terms:
+             return False
+         self.terms[term.text] = term
+         return True
+
+     def add_or_union(self, terms: Term | list[Term] | None) -> None:
+         """Add a term or merge a list of terms into the set."""
+         if terms is None:
+             return
+         if isinstance(terms, list):
+             for term in terms:
+                 self.add_or_union(term)
+         else:
+             existing_term = self.terms.get(terms.text)
+             if existing_term:
+                 existing_score = existing_term.weight or 0
+                 new_score = terms.weight or 0
+                 if new_score > existing_score:
+                     existing_term.weight = new_score
+             else:
+                 self.terms[terms.text] = terms
+
+     def get(self, term: str | Term) -> Term | None:
+         """Retrieve a term by its text."""
+         return self.terms.get(term if isinstance(term, str) else term.text)
+
+     def get_weight(self, term: Term) -> float | None:
+         """Retrieve the weight of a term."""
+         t = self.terms.get(term.text)
+         return t.weight if t is not None else None
+
+     def __contains__(self, term: Term) -> bool:
+         """Check if a term exists in the set."""
+         return term.text in self.terms
+
+     def remove(self, term: Term):
+         """Remove a term from the set, if present."""
+         self.terms.pop(term.text, None)
+
+     def clear(self):
+         """Clear all terms from the set."""
+         self.terms.clear()
+
+     def values(self) -> list[Term]:
+         """Retrieve all terms in the set."""
+         return list(self.terms.values())
+
+
+ @dataclass
+ class PropertyTermSet:
+     """A collection of property terms with support for adding, checking, and clearing."""
+
+     terms: dict[str, Term] = field(default_factory=dict[str, Term])
+
+     def add(self, property_name: str, property_value: Term) -> None:
+         """Add a property term to the set."""
+         key = self._make_key(property_name, property_value)
+         if key not in self.terms:
+             self.terms[key] = property_value
+
+     def has(self, property_name: str, property_value: Term | str) -> bool:
+         """Check if a property term exists in the set."""
+         key = self._make_key(property_name, property_value)
+         return key in self.terms
+
+     def clear(self) -> None:
+         """Clear all property terms from the set."""
+         self.terms.clear()
+
+     def _make_key(self, property_name: str, property_value: Term | str) -> str:
+         """Create a unique key for a property term."""
+         value = (
+             property_value if isinstance(property_value, str) else property_value.text
+         )
+         return f"{property_name}:{value}"
+
+
+ # TODO: unionArrays
+ # TODO: union
+ # TODO: addToSet
+ # TODO: setUnion
+ # TODO: setIntersect
+ # TODO: getBatches
+
+
+ @dataclass
+ class Scored[T]:
+     item: T
+     score: float
+
+     def __lt__(self, other: "Scored[T]") -> bool:
+         return self.score < other.score
+
+     def __gt__(self, other: "Scored[T]") -> bool:
+         return self.score > other.score
+
+     def __le__(self, other: "Scored[T]") -> bool:
+         return self.score <= other.score
+
+     def __ge__(self, other: "Scored[T]") -> bool:
+         return self.score >= other.score
+
+
+ # Implementation change compared to TS version: Use heapq; no sentinel.
+ # API change: pop/top are not properties.
+ class TopNCollection[T]:
+     """A collection that maintains the top N items based on their scores."""
+
+     def __init__(self, max_count: int):
+         self._max_count = max_count
+         self._heap: list[Scored[T]] = []
+
+     def __len__(self) -> int:
+         return len(self._heap)
+
+     def reset(self) -> None:
+         self._heap = []
+
+     def pop(self) -> Scored[T]:
+         return heapq.heappop(self._heap)
+
+     def top(self) -> Scored[T]:
+         return self._heap[0]
+
+     def push(self, item: T, score: float) -> None:
+         if len(self._heap) < self._max_count:
+             heapq.heappush(self._heap, Scored(item, score))
+         else:
+             heapq.heappushpop(self._heap, Scored(item, score))
+
+     def by_rank(self) -> list[Scored[T]]:
+         return sorted(self._heap, reverse=True)
+
+     def values_by_rank(self) -> list[T]:
+         return [item.item for item in self.by_rank()]
+
+
+ class TopNList[T](TopNCollection[T]):
+     """Alias for TopNCollection."""
+
+
+ class TopNListAll[T](TopNList[T]):
+     """A Top N list for N = infinity (approximated by sys.maxsize)."""
+
+     def __init__(self):
+         super().__init__(sys.maxsize)
+
+
+ def get_top_k[T](
+     scored_items: Iterable[Scored[T]],
+     top_k: int,
+ ) -> list[Scored[T]]:
+     """A function to get the top K of an unsorted list of scored items."""
+     top_n_list = TopNCollection[T](top_k)
+     for scored_item in scored_items:
+         top_n_list.push(scored_item.item, scored_item.score)
+     return top_n_list.by_rank()
+
+
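Note: Scored, TopNCollection, and get_top_k are self-contained in this module, so a quick sketch of the bounded-heap behavior:

    from typeagent.knowpro.collections import Scored, TopNCollection, get_top_k

    top = TopNCollection[str](2)          # keep only the two best-scoring items
    for name, score in [("a", 0.1), ("b", 0.9), ("c", 0.5), ("d", 0.7)]:
        top.push(name, score)
    print(top.values_by_rank())           # ['b', 'd']

    # get_top_k does the same for an already-scored iterable.
    print(get_top_k([Scored("x", 1.0), Scored("y", 2.0), Scored("z", 0.5)], top_k=2))
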
+ def add_to_set[T](
+     set: Set[T],
+     values: Iterable[T],
+ ) -> None:
+     """Add values to a set."""
+     set.update(values)
+
+
+ def get_message_char_count(message: IMessage) -> int:
+     """Get the character count of a message."""
+     total = 0
+     for chunk in message.text_chunks:
+         total += len(chunk)
+     return total
+
+
+ async def get_count_of_messages_in_char_budget(
+     messages: IMessageCollection,
+     message_ordinals: Iterable[MessageOrdinal],
+     max_chars_in_budget: int,
+ ) -> int:
+     """Get the count of messages that fit within the character budget."""
+     i = 0
+     total_char_count = 0
+     for message_ordinal in message_ordinals:
+         message = await messages.get_item(message_ordinal)
+         message_char_count = get_message_char_count(message)
+         if message_char_count + total_char_count > max_chars_in_budget:
+             break
+         total_char_count += message_char_count
+         i += 1
+     return i
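
Note: the budgeting helpers only touch message.text_chunks and an async get_item, so they can be exercised with small stand-ins (the Fake* classes below are hypothetical stubs, not part of the package):

    import asyncio
    from dataclasses import dataclass

    from typeagent.knowpro.collections import get_count_of_messages_in_char_budget

    @dataclass
    class FakeMessage:        # hypothetical stub exposing only .text_chunks
        text_chunks: list[str]

    class FakeMessages:       # hypothetical stub exposing only an async get_item()
        def __init__(self, messages: list[FakeMessage]):
            self._messages = messages

        async def get_item(self, ordinal: int) -> FakeMessage:
            return self._messages[ordinal]

    messages = FakeMessages(
        [FakeMessage(["aaaa"]), FakeMessage(["bbbb", "cc"]), FakeMessage(["dd"])]
    )
    # Messages are 4, 6, and 2 chars; a 10-char budget admits the first two only.
    count = asyncio.run(get_count_of_messages_in_char_budget(messages, [0, 1, 2], 10))
    print(count)  # 2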