typeagent-py 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- typeagent/aitools/auth.py +61 -0
- typeagent/aitools/embeddings.py +232 -0
- typeagent/aitools/utils.py +244 -0
- typeagent/aitools/vectorbase.py +175 -0
- typeagent/knowpro/answer_context_schema.py +49 -0
- typeagent/knowpro/answer_response_schema.py +34 -0
- typeagent/knowpro/answers.py +577 -0
- typeagent/knowpro/collections.py +759 -0
- typeagent/knowpro/common.py +9 -0
- typeagent/knowpro/convknowledge.py +112 -0
- typeagent/knowpro/convsettings.py +94 -0
- typeagent/knowpro/convutils.py +49 -0
- typeagent/knowpro/date_time_schema.py +32 -0
- typeagent/knowpro/field_helpers.py +87 -0
- typeagent/knowpro/fuzzyindex.py +144 -0
- typeagent/knowpro/interfaces.py +818 -0
- typeagent/knowpro/knowledge.py +88 -0
- typeagent/knowpro/kplib.py +125 -0
- typeagent/knowpro/query.py +1128 -0
- typeagent/knowpro/search.py +628 -0
- typeagent/knowpro/search_query_schema.py +165 -0
- typeagent/knowpro/searchlang.py +729 -0
- typeagent/knowpro/searchlib.py +345 -0
- typeagent/knowpro/secindex.py +100 -0
- typeagent/knowpro/serialization.py +390 -0
- typeagent/knowpro/textlocindex.py +179 -0
- typeagent/knowpro/utils.py +17 -0
- typeagent/mcp/server.py +139 -0
- typeagent/podcasts/podcast.py +473 -0
- typeagent/podcasts/podcast_import.py +105 -0
- typeagent/storage/__init__.py +25 -0
- typeagent/storage/memory/__init__.py +13 -0
- typeagent/storage/memory/collections.py +68 -0
- typeagent/storage/memory/convthreads.py +81 -0
- typeagent/storage/memory/messageindex.py +178 -0
- typeagent/storage/memory/propindex.py +289 -0
- typeagent/storage/memory/provider.py +84 -0
- typeagent/storage/memory/reltermsindex.py +318 -0
- typeagent/storage/memory/semrefindex.py +660 -0
- typeagent/storage/memory/timestampindex.py +176 -0
- typeagent/storage/sqlite/__init__.py +31 -0
- typeagent/storage/sqlite/collections.py +362 -0
- typeagent/storage/sqlite/messageindex.py +382 -0
- typeagent/storage/sqlite/propindex.py +119 -0
- typeagent/storage/sqlite/provider.py +293 -0
- typeagent/storage/sqlite/reltermsindex.py +328 -0
- typeagent/storage/sqlite/schema.py +248 -0
- typeagent/storage/sqlite/semrefindex.py +156 -0
- typeagent/storage/sqlite/timestampindex.py +146 -0
- typeagent/storage/utils.py +41 -0
- typeagent_py-0.1.0.dist-info/METADATA +28 -0
- typeagent_py-0.1.0.dist-info/RECORD +55 -0
- typeagent_py-0.1.0.dist-info/WHEEL +5 -0
- typeagent_py-0.1.0.dist-info/licenses/LICENSE +21 -0
- typeagent_py-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,759 @@
|
|
1
|
+
# Copyright (c) Microsoft Corporation.
|
2
|
+
# Licensed under the MIT License.
|
3
|
+
|
4
|
+
import bisect
|
5
|
+
from collections.abc import Callable, Iterable, Iterator
|
6
|
+
from dataclasses import dataclass, field
|
7
|
+
import heapq
|
8
|
+
import math
|
9
|
+
import sys
|
10
|
+
from typing import cast, Set # Set is an alias for builtin set
|
11
|
+
|
12
|
+
from .interfaces import (
|
13
|
+
ICollection,
|
14
|
+
IMessage,
|
15
|
+
IMessageCollection,
|
16
|
+
ISemanticRefCollection,
|
17
|
+
Knowledge,
|
18
|
+
KnowledgeType,
|
19
|
+
MessageOrdinal,
|
20
|
+
ScoredMessageOrdinal,
|
21
|
+
ScoredSemanticRefOrdinal,
|
22
|
+
SemanticRef,
|
23
|
+
SemanticRefOrdinal,
|
24
|
+
Term,
|
25
|
+
TextRange,
|
26
|
+
)
|
27
|
+
|
28
|
+
|
29
|
+
@dataclass
class Match[T]:
    """A candidate match for `value`, with separate tallies for exact-term
    and related-term contributions (see MatchAccumulator.add)."""

    value: T
    score: float  # accumulated score from exact-term hits
    hit_count: int  # number of exact-term hits
    related_score: float  # accumulated score from related-term hits
    related_hit_count: int  # number of related-term hits
|
36
|
+
|
37
|
+
|
38
|
+
# TODO: sortMatchesByRelevance
|
39
|
+
|
40
|
+
|
41
|
+
class MatchAccumulator[T]:
|
42
|
+
def __init__(self):
|
43
|
+
self._matches: dict[T, Match[T]] = {}
|
44
|
+
|
45
|
+
def __len__(self) -> int:
|
46
|
+
return len(self._matches)
|
47
|
+
|
48
|
+
def __iter__(self) -> Iterator[Match[T]]:
|
49
|
+
return iter(self._matches.values())
|
50
|
+
|
51
|
+
def __contains__(self, value: T) -> bool:
|
52
|
+
return value in self._matches
|
53
|
+
|
54
|
+
def get_match(self, value: T) -> Match[T] | None:
|
55
|
+
return self._matches.get(value)
|
56
|
+
|
57
|
+
def set_match(self, match: Match[T]) -> None:
|
58
|
+
self._matches[match.value] = match
|
59
|
+
|
60
|
+
# TODO: Maybe make the callers call clear_matches()?
|
61
|
+
def set_matches(self, matches: Iterable[Match[T]], *, clear: bool = False) -> None:
|
62
|
+
if clear:
|
63
|
+
self.clear_matches()
|
64
|
+
for match in matches:
|
65
|
+
self.set_match(match)
|
66
|
+
|
67
|
+
def get_max_hit_count(self) -> int:
|
68
|
+
count = 0
|
69
|
+
for match in self._matches.values():
|
70
|
+
count = max(count, match.hit_count)
|
71
|
+
return count
|
72
|
+
|
73
|
+
# TODO: Rename to add_exact if we ever add add_related
|
74
|
+
def add(self, value: T, score: float, is_exact_match: bool = True) -> None:
|
75
|
+
existing_match = self.get_match(value)
|
76
|
+
if existing_match is not None:
|
77
|
+
if is_exact_match:
|
78
|
+
existing_match.hit_count += 1
|
79
|
+
existing_match.score += score
|
80
|
+
else:
|
81
|
+
existing_match.related_hit_count += 1
|
82
|
+
existing_match.related_score += score
|
83
|
+
else:
|
84
|
+
if is_exact_match:
|
85
|
+
self.set_match(
|
86
|
+
Match(
|
87
|
+
value,
|
88
|
+
hit_count=1,
|
89
|
+
score=score,
|
90
|
+
related_hit_count=0,
|
91
|
+
related_score=0.0,
|
92
|
+
)
|
93
|
+
)
|
94
|
+
else:
|
95
|
+
self.set_match(
|
96
|
+
Match(
|
97
|
+
value,
|
98
|
+
hit_count=1,
|
99
|
+
score=0.0,
|
100
|
+
related_hit_count=1,
|
101
|
+
related_score=score,
|
102
|
+
)
|
103
|
+
)
|
104
|
+
|
105
|
+
def add_union(self, other: "MatchAccumulator[T]") -> None:
|
106
|
+
"""Add matches from another collection of matches."""
|
107
|
+
for other_match in other:
|
108
|
+
existing_match = self.get_match(other_match.value)
|
109
|
+
if existing_match is None:
|
110
|
+
self.set_match(other_match)
|
111
|
+
else:
|
112
|
+
self.combine_matches(existing_match, other_match)
|
113
|
+
|
114
|
+
def intersect(
|
115
|
+
self, other: "MatchAccumulator[T]", intersection: "MatchAccumulator[T]"
|
116
|
+
) -> "MatchAccumulator[T]":
|
117
|
+
"""Intersect with another collection of matches."""
|
118
|
+
for self_match in self:
|
119
|
+
other_match = other.get_match(self_match.value)
|
120
|
+
if other_match is not None:
|
121
|
+
self.combine_matches(self_match, other_match)
|
122
|
+
intersection.set_match(self_match)
|
123
|
+
return intersection
|
124
|
+
|
125
|
+
def combine_matches(self, match: Match[T], other: Match[T]) -> None:
|
126
|
+
"""Combine the other match into the first."""
|
127
|
+
match.hit_count += other.hit_count
|
128
|
+
match.score += other.score
|
129
|
+
match.related_hit_count += other.related_hit_count
|
130
|
+
match.related_score += other.related_score
|
131
|
+
|
132
|
+
def calculate_total_score(
|
133
|
+
self, scorer: Callable[[Match[T]], None] | None = None
|
134
|
+
) -> None:
|
135
|
+
if scorer is None:
|
136
|
+
scorer = add_smooth_related_score_to_match_score
|
137
|
+
for match in self:
|
138
|
+
scorer(match)
|
139
|
+
|
140
|
+
def get_sorted_by_score(self, min_hit_count: int | None = None) -> list[Match[T]]:
|
141
|
+
"""Get matches sorted by score"""
|
142
|
+
if len(self._matches) == 0:
|
143
|
+
return []
|
144
|
+
matches = [*self._matches_with_min_hit_count(min_hit_count)]
|
145
|
+
matches.sort(key=lambda m: m.score, reverse=True)
|
146
|
+
return matches
|
147
|
+
|
148
|
+
def get_top_n_scoring(
|
149
|
+
self,
|
150
|
+
max_matches: int | None = None,
|
151
|
+
min_hit_count: int | None = None,
|
152
|
+
) -> list[Match[T]]:
|
153
|
+
"""Get the top N scoring matches."""
|
154
|
+
if not self._matches:
|
155
|
+
return []
|
156
|
+
if max_matches and max_matches > 0:
|
157
|
+
top_list = TopNList[T](max_matches)
|
158
|
+
for match in self._matches_with_min_hit_count(min_hit_count):
|
159
|
+
top_list.push(match.value, match.score)
|
160
|
+
ranked = top_list.by_rank()
|
161
|
+
return [self._matches[match.item] for match in ranked]
|
162
|
+
else:
|
163
|
+
return self.get_sorted_by_score(min_hit_count)
|
164
|
+
|
165
|
+
def get_with_hit_count(self, min_hit_count: int) -> list[Match[T]]:
|
166
|
+
"""Get matches with a minimum hit count."""
|
167
|
+
return list(self.matches_with_min_hit_count(min_hit_count))
|
168
|
+
|
169
|
+
def get_matches(
|
170
|
+
self, predicate: Callable[[Match[T]], bool] | None = None
|
171
|
+
) -> Iterator[Match[T]]:
|
172
|
+
"""Iterate over all matches."""
|
173
|
+
if predicate is None:
|
174
|
+
return iter(self._matches.values())
|
175
|
+
else:
|
176
|
+
return filter(predicate, self._matches.values())
|
177
|
+
|
178
|
+
def get_matched_values(self) -> Iterator[T]:
|
179
|
+
"""Iterate over all matched values."""
|
180
|
+
return iter(self._matches)
|
181
|
+
|
182
|
+
def clear_matches(self):
|
183
|
+
self._matches.clear()
|
184
|
+
|
185
|
+
def select_top_n_scoring(
|
186
|
+
self,
|
187
|
+
max_matches: int | None = None,
|
188
|
+
min_hit_count: int | None = None,
|
189
|
+
) -> int:
|
190
|
+
"""Retain only the top N matches sorted by score."""
|
191
|
+
top_n = self.get_top_n_scoring(max_matches, min_hit_count)
|
192
|
+
self.set_matches(top_n, clear=True)
|
193
|
+
return len(top_n)
|
194
|
+
|
195
|
+
def select_with_hit_count(self, min_hit_count: int) -> int:
|
196
|
+
"""Retain only matches with a minimum hit count."""
|
197
|
+
matches = self.get_with_hit_count(min_hit_count)
|
198
|
+
self.set_matches(matches, clear=True)
|
199
|
+
return len(matches)
|
200
|
+
|
201
|
+
def _matches_with_min_hit_count(
|
202
|
+
self, min_hit_count: int | None
|
203
|
+
) -> Iterable[Match[T]]:
|
204
|
+
"""Get matches with a minimum hit count"""
|
205
|
+
if min_hit_count is not None and min_hit_count > 0:
|
206
|
+
return self.get_matches(lambda m: m.hit_count >= min_hit_count)
|
207
|
+
else:
|
208
|
+
return self._matches.values()
|
209
|
+
|
210
|
+
def matches_with_min_hit_count(
|
211
|
+
self, min_hit_count: int | None
|
212
|
+
) -> Iterable[Match[T]]:
|
213
|
+
if min_hit_count is not None and min_hit_count > 0:
|
214
|
+
return filter(lambda m: m.hit_count >= min_hit_count, self.get_matches())
|
215
|
+
else:
|
216
|
+
return self._matches.values()
|
217
|
+
|
218
|
+
|
219
|
+
def get_smooth_score(
    total_score: float,
    hit_count: int,
) -> float:
    """Dampen an accumulated score by averaging it over hits and log-scaling.

    See the long comment in collections.ts for an explanation.
    """
    if hit_count <= 0:
        return 0.0
    if hit_count == 1:
        return total_score
    # log(hit_count + 1) grows slowly, so many weak hits cannot
    # dominate one strong hit.
    return math.log(hit_count + 1) * (total_score / hit_count)
|
232
|
+
|
233
|
+
|
234
|
+
def add_smooth_related_score_to_match_score[T](match: Match[T]) -> None:
|
235
|
+
"""Add the smooth related score to the match score."""
|
236
|
+
if match.related_hit_count > 0:
|
237
|
+
# Related term matches can be noisy and duplicative.
|
238
|
+
# See the comment on getSmoothScore in collections.ts.
|
239
|
+
smooth_related_score = get_smooth_score(
|
240
|
+
match.related_score, match.related_hit_count
|
241
|
+
)
|
242
|
+
match.score += smooth_related_score
|
243
|
+
|
244
|
+
|
245
|
+
def smooth_match_score[T](match: Match[T]) -> None:
|
246
|
+
if match.hit_count > 0:
|
247
|
+
match.score = get_smooth_score(match.score, match.hit_count)
|
248
|
+
|
249
|
+
|
250
|
+
# A predicate over a concrete Knowledge subtype T, used to filter semantic refs.
type KnowledgePredicate[T: Knowledge] = Callable[[T], bool]
|
251
|
+
|
252
|
+
|
253
|
+
class SemanticRefAccumulator(MatchAccumulator[SemanticRefOrdinal]):
|
254
|
+
def __init__(self, search_term_matches: set[str] = set()):
|
255
|
+
super().__init__()
|
256
|
+
self.search_term_matches = search_term_matches
|
257
|
+
|
258
|
+
def add_term_matches(
|
259
|
+
self,
|
260
|
+
search_term: Term,
|
261
|
+
scored_refs: Iterable[ScoredSemanticRefOrdinal] | None,
|
262
|
+
is_exact_match: bool,
|
263
|
+
*,
|
264
|
+
weight: float | None = None,
|
265
|
+
) -> None:
|
266
|
+
"""Add term matches to the accumulator"""
|
267
|
+
if scored_refs is not None:
|
268
|
+
if weight is None:
|
269
|
+
weight = search_term.weight
|
270
|
+
if weight is None:
|
271
|
+
weight = 1.0
|
272
|
+
for scored_ref in scored_refs:
|
273
|
+
self.add(
|
274
|
+
scored_ref.semantic_ref_ordinal,
|
275
|
+
scored_ref.score * weight,
|
276
|
+
is_exact_match,
|
277
|
+
)
|
278
|
+
self.search_term_matches.add(search_term.text)
|
279
|
+
|
280
|
+
def add_term_matches_if_new(
|
281
|
+
self,
|
282
|
+
search_term: Term,
|
283
|
+
scored_refs: Iterable[ScoredSemanticRefOrdinal] | None,
|
284
|
+
is_exact_match: bool,
|
285
|
+
weight: float | None = None,
|
286
|
+
) -> None:
|
287
|
+
"""Add term matches if they are new."""
|
288
|
+
if scored_refs is not None:
|
289
|
+
if weight is None:
|
290
|
+
weight = search_term.weight
|
291
|
+
if weight is None:
|
292
|
+
weight = 1.0
|
293
|
+
for scored_ref in scored_refs:
|
294
|
+
if scored_ref.semantic_ref_ordinal not in self:
|
295
|
+
self.add(
|
296
|
+
scored_ref.semantic_ref_ordinal,
|
297
|
+
scored_ref.score * weight,
|
298
|
+
is_exact_match,
|
299
|
+
)
|
300
|
+
self.search_term_matches.add(search_term.text)
|
301
|
+
|
302
|
+
async def get_semantic_refs(
|
303
|
+
self,
|
304
|
+
semantic_refs: ISemanticRefCollection,
|
305
|
+
predicate: Callable[[SemanticRef], bool],
|
306
|
+
) -> list[SemanticRef]:
|
307
|
+
result = []
|
308
|
+
for match in self:
|
309
|
+
semantic_ref = await semantic_refs.get_item(match.value)
|
310
|
+
if predicate is None or predicate(semantic_ref):
|
311
|
+
result.append(semantic_ref)
|
312
|
+
return result
|
313
|
+
|
314
|
+
def get_matches_of_type[T: Knowledge](
|
315
|
+
self,
|
316
|
+
semantic_refs: list[SemanticRef],
|
317
|
+
knowledgeType: KnowledgeType,
|
318
|
+
predicate: KnowledgePredicate[T] | None = None,
|
319
|
+
) -> Iterable[Match[SemanticRefOrdinal]]:
|
320
|
+
for match in self:
|
321
|
+
semantic_ref = semantic_refs[match.value]
|
322
|
+
if predicate is None or predicate(cast(T, semantic_ref.knowledge)):
|
323
|
+
yield match
|
324
|
+
|
325
|
+
async def group_matches_by_type(
|
326
|
+
self,
|
327
|
+
semantic_refs: ISemanticRefCollection,
|
328
|
+
) -> dict[KnowledgeType, "SemanticRefAccumulator"]:
|
329
|
+
groups: dict[KnowledgeType, SemanticRefAccumulator] = {}
|
330
|
+
for match in self:
|
331
|
+
semantic_ref = await semantic_refs.get_item(match.value)
|
332
|
+
group = groups.get(semantic_ref.knowledge.knowledge_type)
|
333
|
+
if group is None:
|
334
|
+
group = SemanticRefAccumulator()
|
335
|
+
group.search_term_matches = self.search_term_matches
|
336
|
+
groups[semantic_ref.knowledge.knowledge_type] = group
|
337
|
+
group.set_match(match)
|
338
|
+
return groups
|
339
|
+
|
340
|
+
async def get_matches_in_scope(
|
341
|
+
self,
|
342
|
+
semantic_refs: ISemanticRefCollection,
|
343
|
+
ranges_in_scope: "TextRangesInScope",
|
344
|
+
) -> "SemanticRefAccumulator":
|
345
|
+
accumulator = SemanticRefAccumulator(self.search_term_matches)
|
346
|
+
for match in self:
|
347
|
+
if ranges_in_scope.is_range_in_scope(
|
348
|
+
(await semantic_refs.get_item(match.value)).range
|
349
|
+
):
|
350
|
+
accumulator.set_match(match)
|
351
|
+
return accumulator
|
352
|
+
|
353
|
+
def add_union(self, other: "MatchAccumulator[SemanticRefOrdinal]") -> None:
|
354
|
+
"""Add matches from another SemanticRefAccumulator."""
|
355
|
+
assert isinstance(other, SemanticRefAccumulator)
|
356
|
+
super().add_union(other)
|
357
|
+
self.search_term_matches.update(other.search_term_matches)
|
358
|
+
|
359
|
+
def intersect(
|
360
|
+
self,
|
361
|
+
other: MatchAccumulator[SemanticRefOrdinal],
|
362
|
+
intersection: MatchAccumulator[SemanticRefOrdinal] | None = None,
|
363
|
+
) -> "SemanticRefAccumulator":
|
364
|
+
"""Intersect with another SemanticRefAccumulator."""
|
365
|
+
assert isinstance(other, SemanticRefAccumulator)
|
366
|
+
if intersection is None:
|
367
|
+
intersection = SemanticRefAccumulator()
|
368
|
+
else:
|
369
|
+
assert isinstance(intersection, SemanticRefAccumulator)
|
370
|
+
super().intersect(other, intersection)
|
371
|
+
if len(intersection) > 0:
|
372
|
+
intersection.search_term_matches.update(self.search_term_matches)
|
373
|
+
intersection.search_term_matches.update(other.search_term_matches)
|
374
|
+
return intersection
|
375
|
+
|
376
|
+
def to_scored_semantic_refs(self) -> list[ScoredSemanticRefOrdinal]:
|
377
|
+
"""Convert the accumulator to a list of scored semantic references."""
|
378
|
+
return [
|
379
|
+
ScoredSemanticRefOrdinal(
|
380
|
+
semantic_ref_ordinal=match.value,
|
381
|
+
score=match.score,
|
382
|
+
)
|
383
|
+
for match in self.get_sorted_by_score()
|
384
|
+
]
|
385
|
+
|
386
|
+
|
387
|
+
class MessageAccumulator(MatchAccumulator[MessageOrdinal]):
    """Accumulates scored message ordinals.

    Unlike the base class, add() keeps the *maximum* score seen for a
    message rather than summing scores.
    """

    def __init__(self, matches: list[Match[MessageOrdinal]] | None = None):
        super().__init__()
        if matches:
            self.set_matches(matches)

    def add(
        self, value: MessageOrdinal, score: float, is_exact_match: bool = True
    ) -> None:
        """Record a hit for a message, keeping only the highest score seen.

        is_exact_match is accepted for base-class signature compatibility
        but is not used in this body.
        """
        match = self.get_match(value)
        if match is None:
            match = Match(value, score, 1, 0.0, 0)
            self.set_match(match)
        elif score > match.score:
            match.score = score
            # TODO: Question(Guido->Umesh): Why not increment hit_count always?
            match.hit_count += 1

    # TODO: add_messages_from_locations

    def add_messages_for_semantic_ref(
        self,
        semantic_ref: SemanticRef,
        score: float,
    ) -> None:
        """Add every message ordinal covered by the semantic ref's text range."""
        message_ordinal_start = semantic_ref.range.start.message_ordinal
        if semantic_ref.range.end is not None:
            message_ordinal_end = semantic_ref.range.end.message_ordinal
            for message_ordinal in range(
                message_ordinal_start, message_ordinal_end + 1
            ):
                self.add(message_ordinal, score)
        else:
            # No explicit end: the range covers only the start message.
            self.add(message_ordinal_start, score)

    def add_scored_matches(self, scored_ordinals: list[ScoredMessageOrdinal]) -> None:
        """Add scored message ordinals to the accumulator."""
        for scored_ordinal in scored_ordinals:
            self.add(scored_ordinal.message_ordinal, scored_ordinal.score)

    def intersect(
        self,
        other: MatchAccumulator[MessageOrdinal],
        intersection: MatchAccumulator[MessageOrdinal] | None = None,
    ) -> "MessageAccumulator":
        """Intersect with another accumulator, returning a MessageAccumulator."""
        if intersection is None:
            intersection = MessageAccumulator()
        else:
            assert isinstance(intersection, MessageAccumulator)
        super().intersect(other, intersection)
        return intersection

    def smooth_scores(self) -> None:
        """Smooth every match's score by its hit count (see smooth_match_score)."""
        for match in self:
            smooth_match_score(match)

    def to_scored_message_ordinals(self) -> list[ScoredMessageOrdinal]:
        """Convert to scored ordinals, highest score first."""
        sorted_matches = self.get_sorted_by_score()
        return [ScoredMessageOrdinal(m.value, m.score) for m in sorted_matches]

    async def select_messages_in_budget(
        self, messages: IMessageCollection, max_chars_in_budget: int
    ) -> None:
        """Select messages that fit within the character budget.

        Messages are considered in descending score order; all matches that
        do not fit are dropped.
        """
        scored_matches = self.get_sorted_by_score()
        ranked_ordinals = [m.value for m in scored_matches]
        message_count_in_budget = await get_count_of_messages_in_char_budget(
            messages, ranked_ordinals, max_chars_in_budget
        )
        self.clear_matches()
        if message_count_in_budget > 0:
            scored_matches = scored_matches[:message_count_in_budget]
            self.set_matches(scored_matches)

    @staticmethod
    def from_scored_ordinals(
        ordinals: list[ScoredMessageOrdinal] | None,
    ) -> "MessageAccumulator":
        """Create a MessageAccumulator from scored ordinals."""
        accumulator = MessageAccumulator()
        if ordinals and len(ordinals) > 0:
            accumulator.add_scored_matches(ordinals)
        return accumulator
|
470
|
+
|
471
|
+
|
472
|
+
# TODO: intersectScoredMessageOrdinals
|
473
|
+
|
474
|
+
|
475
|
+
@dataclass
class TextRangeCollection(Iterable[TextRange]):
    """An ordered collection of TextRanges, kept sorted and de-duplicated
    when built through add_range/add_ranges."""

    _ranges: list[TextRange]

    def __init__(
        self,
        ranges: list[TextRange] | None = None,
        ensure_sorted: bool = False,
    ) -> None:
        # When ensure_sorted is False the caller's list is adopted as-is
        # (assumed already sorted); otherwise each range is re-inserted in
        # sorted order via add_ranges.
        if ensure_sorted:
            self._ranges = []
            if ranges:
                self.add_ranges(ranges)
        else:
            self._ranges = ranges if ranges is not None else []

    def __len__(self) -> int:
        return len(self._ranges)

    def __iter__(self) -> Iterator[TextRange]:
        return iter(self._ranges)

    def get_ranges(self) -> list[TextRange]:
        return self._ranges  # TODO: Maybe return a copy?

    def add_range(self, text_range: TextRange) -> bool:
        """Insert text_range at its sorted position; return False if already present."""
        # This assumes TextRanges are totally ordered.
        pos = bisect.bisect_left(self._ranges, text_range)
        if pos < len(self._ranges) and self._ranges[pos] == text_range:
            return False
        self._ranges.insert(pos, text_range)
        return True

    def add_ranges(self, text_ranges: "list[TextRange] | TextRangeCollection") -> None:
        """Insert each range from a list or another collection."""
        if isinstance(text_ranges, list):
            for text_range in text_ranges:
                self.add_range(text_range)
        else:
            assert isinstance(text_ranges, TextRangeCollection)
            for text_range in text_ranges._ranges:
                self.add_range(text_range)

    def is_in_range(self, inner_range: TextRange) -> bool:
        """Report whether some range in the collection contains inner_range.

        NOTE(review): the scan starts at bisect_left(inner_range), so an
        outer range that starts *before* inner_range yet still contains it
        sorts before that index and would be missed -- confirm this matches
        the TS original's intent.
        """
        if len(self._ranges) == 0:
            return False
        i = bisect.bisect_left(self._ranges, inner_range)
        for outer_range in self._ranges[i:]:
            if outer_range.start > inner_range.start:
                break
            if inner_range in outer_range:
                return True
        return False
|
527
|
+
|
528
|
+
|
529
|
+
@dataclass
class TextRangesInScope:
    """A conjunction of text-range collections: a range is in scope only if
    every registered collection contains it."""

    text_ranges: list[TextRangeCollection] | None = None

    def add_text_ranges(
        self,
        ranges: TextRangeCollection,
    ) -> None:
        """Register another collection of allowed ranges."""
        if self.text_ranges is None:
            self.text_ranges = []
        self.text_ranges.append(ranges)

    def is_range_in_scope(self, inner_range: TextRange) -> bool:
        """Report whether inner_range is accepted by every registered collection.

        With no collections registered, everything is in scope.
        """
        if self.text_ranges is not None:
            # Since outer ranges come from a set of range selectors, they may overlap, or may not agree.
            # Outer ranges allowed by say a date range selector... may not be allowed by a tag selector.
            # We have a very simple impl: we don't intersect/union ranges yet.
            # Instead, we ensure that the inner range is not rejected by any outer ranges.
            for outer_ranges in self.text_ranges:
                if not outer_ranges.is_in_range(inner_range):
                    return False
        # Bug fix: `return True` was previously inside the `if`, so the method
        # implicitly returned None (falsy) when no ranges were registered,
        # violating the declared `-> bool` contract.
        return True
|
551
|
+
|
552
|
+
|
553
|
+
@dataclass
class TermSet:
    """A set of Terms keyed by term text; merging keeps the larger weight."""

    terms: dict[str, Term]

    def __init__(self, terms: list[Term] | None = None):
        self.terms = {}
        self.add_or_union(terms)

    def __len__(self) -> int:
        """Number of distinct terms held."""
        return len(self.terms)

    def add(self, term: Term) -> bool:
        """Insert term unless its text is already present; report whether inserted."""
        if term.text not in self.terms:
            self.terms[term.text] = term
            return True
        return False

    def add_or_union(self, terms: Term | list[Term] | None) -> None:
        """Merge one term, or each term of a list, into the set.

        When the text already exists, the stored term's weight is raised to
        the incoming weight if that is larger.
        """
        if terms is None:
            return
        if isinstance(terms, list):
            for item in terms:
                self.add_or_union(item)
            return
        present = self.terms.get(terms.text)
        if present is None:
            self.terms[terms.text] = terms
        else:
            incoming_score = terms.weight or 0
            if incoming_score > (present.weight or 0):
                present.weight = incoming_score

    def get(self, term: str | Term) -> Term | None:
        """Look up a term by its text (given directly or via a Term)."""
        key = term if isinstance(term, str) else term.text
        return self.terms.get(key)

    def get_weight(self, term: Term) -> float | None:
        """Return the stored weight for a term's text, or None if absent."""
        found = self.terms.get(term.text)
        return None if found is None else found.weight

    def __contains__(self, term: Term) -> bool:
        """Membership is keyed by the term's text."""
        return term.text in self.terms

    def remove(self, term: Term):
        """Remove a term from the set, if present."""
        self.terms.pop(term.text, None)

    def clear(self):
        """Drop all terms."""
        self.terms.clear()

    def values(self) -> list[Term]:
        """All stored terms, in insertion order."""
        return [*self.terms.values()]
|
615
|
+
|
616
|
+
|
617
|
+
@dataclass
class PropertyTermSet:
    """Tracks (property name, value term) pairs that have been recorded."""

    terms: dict[str, Term] = field(default_factory=dict[str, Term])

    def add(self, property_name: str, property_value: Term) -> None:
        """Record a property term; an existing entry for the same key wins."""
        key = self._make_key(property_name, property_value)
        self.terms.setdefault(key, property_value)

    def has(self, property_name: str, property_value: Term | str) -> bool:
        """Report whether this property term was recorded."""
        return self._make_key(property_name, property_value) in self.terms

    def clear(self) -> None:
        """Drop all recorded property terms."""
        self.terms.clear()

    def _make_key(self, property_name: str, property_value: Term | str) -> str:
        """Build the unique 'name:value' key for a property term."""
        if isinstance(property_value, str):
            value = property_value
        else:
            value = property_value.text
        return f"{property_name}:{value}"
|
644
|
+
|
645
|
+
|
646
|
+
# TODO: unionArrays
|
647
|
+
# TODO: union
|
648
|
+
# TODO: addToSet
|
649
|
+
# TODO: setUnion
|
650
|
+
# TODO: setIntersect
|
651
|
+
# TODO: getBatches
|
652
|
+
|
653
|
+
|
654
|
+
@dataclass
|
655
|
+
class Scored[T]:
|
656
|
+
item: T
|
657
|
+
score: float
|
658
|
+
|
659
|
+
def __lt__(self, other: "Scored[T]") -> bool:
|
660
|
+
return self.score < other.score
|
661
|
+
|
662
|
+
def __gt__(self, other: "Scored[T]") -> bool:
|
663
|
+
return self.score > other.score
|
664
|
+
|
665
|
+
def __le__(self, other: "Scored[T]") -> bool:
|
666
|
+
return self.score <= other.score
|
667
|
+
|
668
|
+
def __ge__(self, other: "Scored[T]") -> bool:
|
669
|
+
return self.score >= other.score
|
670
|
+
|
671
|
+
|
672
|
+
# Implementation change compared to TS version: Use heapq; no sentinel.
|
673
|
+
# API change: pop/top are not properties.
|
674
|
+
class TopNCollection[T]:
|
675
|
+
"""A collection that maintains the top N items based on their scores."""
|
676
|
+
|
677
|
+
def __init__(self, max_count: int):
|
678
|
+
self._max_count = max_count
|
679
|
+
self._heap: list[Scored[T]] = []
|
680
|
+
|
681
|
+
def __len__(self) -> int:
|
682
|
+
return len(self._heap)
|
683
|
+
|
684
|
+
def reset(self) -> None:
|
685
|
+
self._heap = []
|
686
|
+
|
687
|
+
def pop(self) -> Scored[T]:
|
688
|
+
return heapq.heappop(self._heap)
|
689
|
+
|
690
|
+
def top(self) -> Scored[T]:
|
691
|
+
return self._heap[0]
|
692
|
+
|
693
|
+
def push(self, item: T, score: float) -> None:
|
694
|
+
if len(self._heap) < self._max_count:
|
695
|
+
heapq.heappush(self._heap, Scored(item, score))
|
696
|
+
else:
|
697
|
+
heapq.heappushpop(self._heap, Scored(item, score))
|
698
|
+
|
699
|
+
def by_rank(self) -> list[Scored[T]]:
|
700
|
+
return sorted(self._heap, reverse=True)
|
701
|
+
|
702
|
+
def values_by_rank(self) -> list[T]:
|
703
|
+
return [item.item for item in self.by_rank()]
|
704
|
+
|
705
|
+
|
706
|
+
class TopNList[T](TopNCollection[T]):
    """Alias for TopNCollection (same behavior, alternative name)."""
|
708
|
+
|
709
|
+
|
710
|
+
class TopNListAll[T](TopNList[T]):
    """A Top N list for N = infinity (approximated by sys.maxsize)."""

    def __init__(self):
        # With max_count = sys.maxsize the heap effectively never evicts,
        # so every pushed item is retained.
        super().__init__(sys.maxsize)
|
715
|
+
|
716
|
+
|
717
|
+
def get_top_k[T](
|
718
|
+
scored_items: Iterable[Scored[T]],
|
719
|
+
top_k: int,
|
720
|
+
) -> list[Scored[T]]:
|
721
|
+
"""A function to get the top K of an unsorted list of scored items."""
|
722
|
+
top_n_list = TopNCollection[T](top_k)
|
723
|
+
for scored_item in scored_items:
|
724
|
+
top_n_list.push(scored_item.item, scored_item.score)
|
725
|
+
return top_n_list.by_rank()
|
726
|
+
|
727
|
+
|
728
|
+
def add_to_set[T](
    set: Set[T],  # NOTE: parameter name shadows the builtin `set`; kept for API compatibility
    values: Iterable[T],
) -> None:
    """Add values to a set (in place)."""
    set.update(values)
|
734
|
+
|
735
|
+
|
736
|
+
def get_message_char_count(message: IMessage) -> int:
    """Return the total character count over all of the message's text chunks."""
    return sum(len(chunk) for chunk in message.text_chunks)
|
742
|
+
|
743
|
+
|
744
|
+
async def get_count_of_messages_in_char_budget(
    messages: IMessageCollection,
    message_ordinals: Iterable[MessageOrdinal],
    max_chars_in_budget: int,
) -> int:
    """Count how many leading messages (in the given order) fit in the budget.

    Stops at the first message that would push the running total past
    max_chars_in_budget.
    """
    chars_used = 0
    fitting = 0
    for message_ordinal in message_ordinals:
        message = await messages.get_item(message_ordinal)
        size = get_message_char_count(message)
        if chars_used + size > max_chars_in_budget:
            break
        chars_used += size
        fitting += 1
    return fitting
|