typeagent-py 0.1.0 (typeagent_py-0.1.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- typeagent/aitools/auth.py +61 -0
- typeagent/aitools/embeddings.py +232 -0
- typeagent/aitools/utils.py +244 -0
- typeagent/aitools/vectorbase.py +175 -0
- typeagent/knowpro/answer_context_schema.py +49 -0
- typeagent/knowpro/answer_response_schema.py +34 -0
- typeagent/knowpro/answers.py +577 -0
- typeagent/knowpro/collections.py +759 -0
- typeagent/knowpro/common.py +9 -0
- typeagent/knowpro/convknowledge.py +112 -0
- typeagent/knowpro/convsettings.py +94 -0
- typeagent/knowpro/convutils.py +49 -0
- typeagent/knowpro/date_time_schema.py +32 -0
- typeagent/knowpro/field_helpers.py +87 -0
- typeagent/knowpro/fuzzyindex.py +144 -0
- typeagent/knowpro/interfaces.py +818 -0
- typeagent/knowpro/knowledge.py +88 -0
- typeagent/knowpro/kplib.py +125 -0
- typeagent/knowpro/query.py +1128 -0
- typeagent/knowpro/search.py +628 -0
- typeagent/knowpro/search_query_schema.py +165 -0
- typeagent/knowpro/searchlang.py +729 -0
- typeagent/knowpro/searchlib.py +345 -0
- typeagent/knowpro/secindex.py +100 -0
- typeagent/knowpro/serialization.py +390 -0
- typeagent/knowpro/textlocindex.py +179 -0
- typeagent/knowpro/utils.py +17 -0
- typeagent/mcp/server.py +139 -0
- typeagent/podcasts/podcast.py +473 -0
- typeagent/podcasts/podcast_import.py +105 -0
- typeagent/storage/__init__.py +25 -0
- typeagent/storage/memory/__init__.py +13 -0
- typeagent/storage/memory/collections.py +68 -0
- typeagent/storage/memory/convthreads.py +81 -0
- typeagent/storage/memory/messageindex.py +178 -0
- typeagent/storage/memory/propindex.py +289 -0
- typeagent/storage/memory/provider.py +84 -0
- typeagent/storage/memory/reltermsindex.py +318 -0
- typeagent/storage/memory/semrefindex.py +660 -0
- typeagent/storage/memory/timestampindex.py +176 -0
- typeagent/storage/sqlite/__init__.py +31 -0
- typeagent/storage/sqlite/collections.py +362 -0
- typeagent/storage/sqlite/messageindex.py +382 -0
- typeagent/storage/sqlite/propindex.py +119 -0
- typeagent/storage/sqlite/provider.py +293 -0
- typeagent/storage/sqlite/reltermsindex.py +328 -0
- typeagent/storage/sqlite/schema.py +248 -0
- typeagent/storage/sqlite/semrefindex.py +156 -0
- typeagent/storage/sqlite/timestampindex.py +146 -0
- typeagent/storage/utils.py +41 -0
- typeagent_py-0.1.0.dist-info/METADATA +28 -0
- typeagent_py-0.1.0.dist-info/RECORD +55 -0
- typeagent_py-0.1.0.dist-info/WHEEL +5 -0
- typeagent_py-0.1.0.dist-info/licenses/LICENSE +21 -0
- typeagent_py-0.1.0.dist-info/top_level.txt +1 -0
typeagent/mcp/server.py
ADDED
@@ -0,0 +1,139 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Fledgling MCP server on top of knowpro."""

from dataclasses import dataclass
import time

from mcp.server.fastmcp import FastMCP
import typechat

from typeagent.aitools import embeddings, utils
from typeagent.aitools.embeddings import AsyncEmbeddingModel
from typeagent.knowpro import answers, convknowledge, query, searchlang
from typeagent.knowpro.convsettings import ConversationSettings
from typeagent.knowpro.answer_response_schema import AnswerResponse
from typeagent.knowpro.search_query_schema import SearchQuery
from typeagent.podcasts import podcast


@dataclass
class ProcessingContext:
    lang_search_options: searchlang.LanguageSearchOptions
    answer_context_options: answers.AnswerContextOptions
    query_context: query.QueryEvalContext
    embedding_model: embeddings.AsyncEmbeddingModel
    query_translator: typechat.TypeChatJsonTranslator[SearchQuery]
    answer_translator: typechat.TypeChatJsonTranslator[AnswerResponse]

    def __repr__(self) -> str:
        parts = []
        parts.append(f"{self.lang_search_options}")
        parts.append(f"{self.answer_context_options}")
        return f"Context({', '.join(parts)})"


async def make_context() -> ProcessingContext:
    utils.load_dotenv()

    settings = ConversationSettings()
    lang_search_options = searchlang.LanguageSearchOptions(
        compile_options=searchlang.LanguageQueryCompileOptions(
            exact_scope=False, verb_scope=True, term_filter=None, apply_scope=True
        ),
        exact_match=False,
        max_message_matches=25,
    )
    answer_context_options = answers.AnswerContextOptions(
        entities_top_k=50, topics_top_k=50, messages_top_k=None, chunking=None
    )

    query_context = await load_podcast_index(
        "testdata/Episode_53_AdrianTchaikovsky_index", settings
    )

    model = convknowledge.create_typechat_model()
    query_translator = utils.create_translator(model, SearchQuery)
    answer_translator = utils.create_translator(model, AnswerResponse)

    context = ProcessingContext(
        lang_search_options,
        answer_context_options,
        query_context,
        settings.embedding_model,
        query_translator,
        answer_translator,
    )

    return context


async def load_podcast_index(
    podcast_file_prefix: str, settings: ConversationSettings
) -> query.QueryEvalContext:
    conversation = await podcast.Podcast.read_from_file(podcast_file_prefix, settings)
    assert (
        conversation is not None
    ), f"Failed to load podcast from {podcast_file_prefix!r}"
    return query.QueryEvalContext(conversation)


# Create an MCP server
mcp = FastMCP("knowpro")


@dataclass
class QuestionResponse:
    success: bool
    answer: str
    time_used: int  # Milliseconds


@mcp.tool()
async def query_conversation(question: str) -> QuestionResponse:
    """Send a question to the memory server and get an answer back"""
    t0 = time.time()
    question = question.strip()
    if not question:
        dt = int((time.time() - t0) * 1000)  # Convert to milliseconds
        return QuestionResponse(
            success=False, answer="No question provided", time_used=dt
        )
    context = await make_context()

    # Stages 1, 2, 3 (LLM -> proto-query, compile, execute query)
    result = await searchlang.search_conversation_with_language(
        context.query_context.conversation,
        context.query_translator,
        question,
        context.lang_search_options,
    )
    if isinstance(result, typechat.Failure):
        dt = int((time.time() - t0) * 1000)  # Convert to milliseconds
        return QuestionResponse(success=False, answer=result.message, time_used=dt)

    # Stages 3a, 4 (ordinals -> messages/semrefs, LLM -> answer)
    _, combined_answer = await answers.generate_answers(
        context.answer_translator,
        result.value,
        context.query_context.conversation,
        question,
        options=context.answer_context_options,
    )
    dt = int((time.time() - t0) * 1000)  # Convert to milliseconds
    match combined_answer.type:
        case "NoAnswer":
            return QuestionResponse(
                success=False, answer=combined_answer.whyNoAnswer or "", time_used=dt
            )
        case "Answered":
            return QuestionResponse(
                success=True, answer=combined_answer.answer or "", time_used=dt
            )


# Run the MCP server
if __name__ == "__main__":
    # Use stdio transport for simplicity
    mcp.run(transport="stdio")
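Note: The tool can be exercised without an MCP client. Since FastMCP's tool() decorator returns the wrapped function, query_conversation can be awaited directly. A minimal smoke-test sketch, assuming credentials that utils.load_dotenv() can find and the test index referenced in make_context() on disk; the file name smoke_test.py is hypothetical:

    # smoke_test.py (hypothetical) -- drive the tool function directly.
    import asyncio

    from typeagent.mcp.server import query_conversation


    async def main() -> None:
        # Builds a fresh ProcessingContext internally (model + podcast index).
        response = await query_conversation("What does Adrian say about spiders?")
        print(f"success={response.success}, took {response.time_used} ms")
        print(response.answer)


    asyncio.run(main())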
typeagent/podcasts/podcast.py
ADDED
@@ -0,0 +1,473 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from dataclasses import dataclass
import json
import os
from typing import TypedDict, cast, Any

import numpy as np
from pydantic.dataclasses import dataclass as pydantic_dataclass
from pydantic import Field, AliasChoices

from ..aitools.embeddings import NormalizedEmbeddings
from ..storage.memory import semrefindex
from ..knowpro import kplib, secindex
from ..knowpro.field_helpers import CamelCaseField
from ..storage.memory.convthreads import ConversationThreads
from ..knowpro.convsettings import ConversationSettings
from ..knowpro.interfaces import (
    ConversationDataWithIndexes,
    Datetime,
    ICollection,
    IConversation,
    IConversationSecondaryIndexes,
    IKnowledgeSource,
    IMessage,
    IMessageCollection,
    IMessageMetadata,
    ISemanticRefCollection,
    IStorageProvider,
    ITermToSemanticRefIndex,
    MessageOrdinal,
    SemanticRef,
    Term,
    Timedelta,
)
from ..storage.memory.messageindex import MessageTextIndex
from ..storage.memory.reltermsindex import TermToRelatedTermsMap
from ..storage.utils import create_storage_provider
from ..knowpro import serialization
from ..storage.memory.collections import (
    MemoryMessageCollection,
    MemorySemanticRefCollection,
)


@pydantic_dataclass
class PodcastMessageMeta(IKnowledgeSource, IMessageMetadata):
    """Metadata class (!= metaclass) for podcast messages."""

    speaker: str | None = None
    listeners: list[str] = Field(default_factory=list)

    @property
    def source(self) -> str | None:  # type: ignore[reportIncompatibleVariableOverride]
        return self.speaker

    @property
    def dest(self) -> str | list[str] | None:  # type: ignore[reportIncompatibleVariableOverride]
        return self.listeners

    def get_knowledge(self) -> kplib.KnowledgeResponse:
        if not self.speaker:
            return kplib.KnowledgeResponse(
                entities=[],
                actions=[],
                inverse_actions=[],
                topics=[],
            )
        else:
            entities: list[kplib.ConcreteEntity] = []
            entities.append(
                kplib.ConcreteEntity(
                    name=self.speaker,
                    type=["person"],
                )
            )
            listener_entities = [
                kplib.ConcreteEntity(
                    name=listener,
                    type=["person"],
                )
                for listener in self.listeners
            ]
            entities.extend(listener_entities)
            actions = [
                kplib.Action(
                    verbs=["say"],
                    verb_tense="past",
                    subject_entity_name=self.speaker,
                    object_entity_name=listener,
                    indirect_object_entity_name="none",
                )
                for listener in self.listeners
            ]
            return kplib.KnowledgeResponse(
                entities=entities,
                actions=actions,
                # TODO: Also create inverse actions.
                inverse_actions=[],
                topics=[],
            )

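Note: To make the metadata-to-knowledge mapping concrete: a message with a speaker and one listener produces a person entity for each participant plus a past-tense "say" action per listener. A rough illustration, with field values as constructed above (names are made up):

    meta = PodcastMessageMeta(speaker="alice", listeners=["bob"])
    knowledge = meta.get_knowledge()
    # knowledge.entities: ConcreteEntity(name="alice", type=["person"]),
    #                     ConcreteEntity(name="bob", type=["person"])
    # knowledge.actions:  Action(verbs=["say"], verb_tense="past",
    #                            subject_entity_name="alice",
    #                            object_entity_name="bob",
    #                            indirect_object_entity_name="none")
    # knowledge.inverse_actions and knowledge.topics are empty (see TODO above).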
class PodcastMessageMetaData(TypedDict):
    speaker: str | None
    listeners: list[str]


class PodcastMessageData(TypedDict):
    metadata: PodcastMessageMetaData
    textChunks: list[str]
    tags: list[str]
    timestamp: str | None


@pydantic_dataclass
class PodcastMessage(IMessage):
    text_chunks: list[str] = CamelCaseField("The text chunks of the podcast message")
    metadata: PodcastMessageMeta = CamelCaseField(
        "Metadata associated with the podcast message"
    )
    tags: list[str] = CamelCaseField(
        "Tags associated with the message", default_factory=list
    )
    timestamp: str | None = None

    def get_knowledge(self) -> kplib.KnowledgeResponse:
        return self.metadata.get_knowledge()

    def add_timestamp(self, timestamp: str) -> None:
        self.timestamp = timestamp

    def add_content(self, content: str) -> None:
        self.text_chunks[0] += content

    def serialize(self) -> PodcastMessageData:
        return self.__pydantic_serializer__.to_python(self, by_alias=True)  # type: ignore

    @staticmethod
    def deserialize(message_data: PodcastMessageData) -> "PodcastMessage":
        return PodcastMessage.__pydantic_validator__.validate_python(message_data)  # type: ignore


class PodcastData(ConversationDataWithIndexes[PodcastMessageData]):
    pass

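Note: The CamelCaseField aliases make the serialized form use camelCase keys (textChunks) while the Python attributes stay snake_case, which is exactly the shape PodcastMessageData describes. A round-trip sketch, assuming standard pydantic dataclass equality semantics:

    msg = PodcastMessage(
        text_chunks=["Hello there."],
        metadata=PodcastMessageMeta(speaker="alice", listeners=["bob"]),
    )
    data = msg.serialize()
    # data == {"textChunks": ["Hello there."],
    #          "metadata": {"speaker": "alice", "listeners": ["bob"]},
    #          "tags": [], "timestamp": None}
    assert PodcastMessage.deserialize(data) == msg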
@dataclass
class Podcast(IConversation[PodcastMessage, ITermToSemanticRefIndex]):
    settings: ConversationSettings
    name_tag: str
    messages: IMessageCollection[PodcastMessage]
    semantic_refs: ISemanticRefCollection
    tags: list[str]
    semantic_ref_index: ITermToSemanticRefIndex
    secondary_indexes: IConversationSecondaryIndexes[PodcastMessage] | None

    @classmethod
    async def create(
        cls,
        settings: ConversationSettings,
        name_tag: str | None = None,
        messages: IMessageCollection[PodcastMessage] | None = None,
        semantic_refs: ISemanticRefCollection | None = None,
        semantic_ref_index: ITermToSemanticRefIndex | None = None,
        tags: list[str] | None = None,
        secondary_indexes: IConversationSecondaryIndexes[PodcastMessage] | None = None,
    ) -> "Podcast":
        """Create a fully initialized Podcast instance."""
        storage_provider = await settings.get_storage_provider()
        return cls(
            settings,
            name_tag or "",
            messages or await storage_provider.get_message_collection(),
            semantic_refs or await storage_provider.get_semantic_ref_collection(),
            tags if tags is not None else [],
            semantic_ref_index or await storage_provider.get_semantic_ref_index(),
            secondary_indexes
            or await secindex.ConversationSecondaryIndexes.create(
                storage_provider, settings.related_term_index_settings
            ),
        )

    def _get_secondary_indexes(self) -> IConversationSecondaryIndexes[PodcastMessage]:
        """Get secondary indexes, asserting they are initialized."""
        assert (
            self.secondary_indexes is not None
        ), "Use await Podcast.create() to create an initialized instance"
        return self.secondary_indexes

    async def add_metadata_to_index(self) -> None:
        await semrefindex.add_metadata_to_index(
            self.messages,
            self.semantic_refs,
            self.semantic_ref_index,
        )

    async def generate_timestamps(
        self, start_date: Datetime, length_minutes: float = 60.0
    ) -> None:
        await timestamp_messages(
            self.messages, start_date, start_date + Timedelta(minutes=length_minutes)
        )

    async def build_index(
        self,
    ) -> None:
        await self.add_metadata_to_index()
        assert (
            self.settings is not None
        ), "Settings must be initialized before building index"
        await semrefindex.build_semantic_ref(self, self.settings)
        # build_semantic_ref automatically builds standard secondary indexes.
        # Pass False here to build podcast-specific secondary indexes only.
        await self._build_transient_secondary_indexes(False)
        if self.secondary_indexes is not None:
            if self.secondary_indexes.threads is not None:
                await self.secondary_indexes.threads.build_index()  # type: ignore  # TODO

    async def serialize(self) -> PodcastData:
        data = PodcastData(
            nameTag=self.name_tag,
            messages=[m.serialize() async for m in self.messages],
            tags=self.tags,
            semanticRefs=(
                [r.serialize() async for r in self.semantic_refs]
                if self.semantic_refs is not None
                else None
            ),
        )
        data["semanticIndexData"] = await self.semantic_ref_index.serialize()

        secondary_indexes = self._get_secondary_indexes()
        if secondary_indexes.term_to_related_terms_index is not None:
            data["relatedTermsIndexData"] = (
                await secondary_indexes.term_to_related_terms_index.serialize()
            )
        if secondary_indexes.threads:
            data["threadData"] = secondary_indexes.threads.serialize()
        if secondary_indexes.message_index is not None:
            data["messageIndexData"] = await secondary_indexes.message_index.serialize()
        return data

    async def write_to_file(self, filename: str) -> None:
        data = await self.serialize()
        serialization.write_conversation_data_to_file(data, filename)

    async def deserialize(
        self, podcast_data: ConversationDataWithIndexes[PodcastMessageData]
    ) -> None:
        if await self.messages.size() or (
            self.semantic_refs is not None and await self.semantic_refs.size()
        ):
            raise RuntimeError("Cannot deserialize into a non-empty Podcast.")

        self.name_tag = podcast_data["nameTag"]

        message_list = [PodcastMessage.deserialize(m) for m in podcast_data["messages"]]
        await self.messages.extend(message_list)

        semantic_refs_data = podcast_data.get("semanticRefs")
        if semantic_refs_data is not None:
            semrefs = [SemanticRef.deserialize(r) for r in semantic_refs_data]
            await self.semantic_refs.extend(semrefs)

        self.tags = podcast_data["tags"]

        semantic_index_data = podcast_data.get("semanticIndexData")
        if semantic_index_data is not None:
            await self.semantic_ref_index.deserialize(semantic_index_data)

        related_terms_index_data = podcast_data.get("relatedTermsIndexData")
        if related_terms_index_data is not None:
            secondary_indexes = self._get_secondary_indexes()
            term_to_related_terms_index = secondary_indexes.term_to_related_terms_index
            if term_to_related_terms_index is not None:
                # Assert empty before deserializing
                assert (
                    await term_to_related_terms_index.aliases.is_empty()
                ), "Term to related terms index must be empty before deserializing"
                await term_to_related_terms_index.deserialize(related_terms_index_data)

        thread_data = podcast_data.get("threadData")
        if thread_data is not None:
            assert (
                self.settings is not None
            ), "Settings must be initialized for deserialization"
            secondary_indexes = self._get_secondary_indexes()
            secondary_indexes.threads = ConversationThreads(
                self.settings.thread_settings
            )
            secondary_indexes.threads.deserialize(thread_data)

        message_index_data = podcast_data.get("messageIndexData")
        if message_index_data is not None:
            secondary_indexes = self._get_secondary_indexes()
            # Assert the message index is empty before deserializing
            assert (
                secondary_indexes.message_index is not None
            ), "Message index should be initialized"

            if isinstance(secondary_indexes.message_index, MessageTextIndex):
                index_size = await secondary_indexes.message_index.size()
                assert (
                    index_size == 0
                ), "Message index must be empty before deserializing"
                await secondary_indexes.message_index.deserialize(message_index_data)

        await self._build_transient_secondary_indexes(True)

    @staticmethod
    def _read_conversation_data_from_file(
        filename_prefix: str, embedding_size: int
    ) -> ConversationDataWithIndexes[Any]:
        """Read podcast conversation data from files. No exceptions are caught; they just bubble out."""
        with open(filename_prefix + "_data.json", "r", encoding="utf-8") as f:
            json_data: serialization.ConversationJsonData[PodcastMessageData] = (
                json.load(f)
            )
        embeddings_list: list[NormalizedEmbeddings] | None = None
        if embedding_size:
            with open(filename_prefix + "_embeddings.bin", "rb") as f:
                embeddings = np.fromfile(f, dtype=np.float32).reshape(
                    (-1, embedding_size)
                )
            embeddings_list = [embeddings]
        else:
            print(
                f"Warning: not reading embeddings file because size is {embedding_size}"
            )
            embeddings_list = None
        file_data = serialization.ConversationFileData(
            jsonData=json_data,
            binaryData=serialization.ConversationBinaryData(
                embeddingsList=embeddings_list
            ),
        )
        if json_data.get("fileHeader") is None:
            json_data["fileHeader"] = serialization.create_file_header()
        return serialization.from_conversation_file_data(file_data)

    @staticmethod
    async def read_from_file(
        filename_prefix: str,
        settings: ConversationSettings,
        dbname: str | None = None,
    ) -> "Podcast":
        embedding_size = settings.embedding_model.embedding_size
        data = Podcast._read_conversation_data_from_file(
            filename_prefix, embedding_size
        )

        provider = await settings.get_storage_provider()
        msgs = await provider.get_message_collection()
        semrefs = await provider.get_semantic_ref_collection()
        if await msgs.size() or await semrefs.size():
            raise RuntimeError(
                f"Database {dbname!r} already has messages or semantic refs."
            )
        podcast = await Podcast.create(settings, messages=msgs, semantic_refs=semrefs)
        await podcast.deserialize(data)
        return podcast

    async def _build_transient_secondary_indexes(self, build_all: bool) -> None:
        # Secondary indexes are already initialized via the create() factory method.
        if build_all:
            await secindex.build_transient_secondary_indexes(self, self.settings)
        await self._build_participant_aliases()
        await self._add_synonyms()

    async def _build_participant_aliases(self) -> None:
        secondary_indexes = self._get_secondary_indexes()
        term_to_related_terms_index = secondary_indexes.term_to_related_terms_index
        assert term_to_related_terms_index is not None
        aliases = term_to_related_terms_index.aliases
        await aliases.clear()
        name_to_alias_map = await self._collect_participant_aliases()
        for name in name_to_alias_map.keys():
            related_terms: list[Term] = [
                Term(text=alias) for alias in name_to_alias_map[name]
            ]
            await aliases.add_related_term(name, related_terms)

    async def _add_synonyms(self) -> None:
        secondary_indexes = self._get_secondary_indexes()
        assert secondary_indexes.term_to_related_terms_index is not None
        aliases = secondary_indexes.term_to_related_terms_index.aliases
        synonym_file = os.path.join(os.path.dirname(__file__), "podcastVerbs.json")
        with open(synonym_file) as f:
            data: list[dict] = json.load(f)
        if data:
            for obj in data:
                text = obj.get("term")
                synonyms = obj.get("relatedTerms")
                if text and synonyms:
                    related_term = Term(text=text.lower())
                    for synonym in synonyms:
                        await aliases.add_related_term(synonym.lower(), related_term)

    async def _collect_participant_aliases(self) -> dict[str, set[str]]:

        aliases: dict[str, set[str]] = {}

        def collect_name(participant_name: str | None):
            if not participant_name:
                return
            participant_name = participant_name.lower()
            parsed_name = split_participant_name(participant_name)
            if parsed_name and parsed_name.first_name and parsed_name.last_name:
                # If participant_name is a full name, associate first_name with the full name.
                aliases.setdefault(parsed_name.first_name, set()).add(participant_name)
                # And also the reverse.
                aliases.setdefault(participant_name, set()).add(parsed_name.first_name)

        async for message in self.messages:
            collect_name(message.metadata.speaker)
            for listener in message.metadata.listeners:
                collect_name(listener)

        return aliases

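Note: As far as this file shows, the intended lifecycle is Podcast.create() for a fresh instance, build_index() once messages are in, write_to_file() to persist, and read_from_file() to load an existing index. A usage sketch, assuming the test index referenced by server.py exists on disk:

    import asyncio

    from typeagent.knowpro.convsettings import ConversationSettings
    from typeagent.podcasts.podcast import Podcast


    async def main() -> None:
        settings = ConversationSettings()
        # Expects <prefix>_data.json and <prefix>_embeddings.bin on disk.
        pod = await Podcast.read_from_file(
            "testdata/Episode_53_AdrianTchaikovsky_index", settings
        )
        print(pod.name_tag, await pod.messages.size())


    asyncio.run(main())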
# Text (such as a transcript) can be collected over a time range.
# This text can be partitioned into blocks.
# However, timestamps for individual blocks are not available.
# This function assigns individual timestamps to blocks proportional to their lengths.
async def timestamp_messages(
    messages: ICollection[PodcastMessage, MessageOrdinal],
    start_time: Datetime,
    end_time: Datetime,
) -> None:
    start = start_time.timestamp()
    duration = end_time.timestamp() - start
    if duration <= 0:
        raise RuntimeError(f"{start_time} is not < {end_time}")
    message_lengths = [
        sum(len(chunk) for chunk in m.text_chunks) async for m in messages
    ]
    text_length = sum(message_lengths)
    seconds_per_char = duration / text_length
    messages_list = [m async for m in messages]
    for message, length in zip(messages_list, message_lengths):
        message.timestamp = Datetime.fromtimestamp(start).isoformat()
        start += seconds_per_char * length

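Note: The proportional allocation is easiest to check with numbers: two messages of 900 and 300 characters over one hour give seconds_per_char = 3600 / 1200 = 3.0, so the first message is stamped at the start time and the second 900 * 3 = 2700 seconds (45 minutes) later. A standalone replication of the arithmetic:

    from datetime import datetime, timedelta

    start_time = datetime(2024, 1, 1, 12, 0, 0)
    end_time = start_time + timedelta(minutes=60)
    message_lengths = [900, 300]  # total characters per message
    seconds_per_char = (
        end_time.timestamp() - start_time.timestamp()
    ) / sum(message_lengths)  # 3.0
    start = start_time.timestamp()
    for length in message_lengths:
        print(datetime.fromtimestamp(start).isoformat())
        start += seconds_per_char * length
    # Prints 2024-01-01T12:00:00, then 2024-01-01T12:45:00.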
@dataclass
class ParticipantName:
    first_name: str
    last_name: str | None = None
    middle_name: str | None = None


def split_participant_name(full_name: str) -> ParticipantName | None:
    parts = full_name.split(None, 2)
    match len(parts):
        case 0:
            return None
        case 1:
            return ParticipantName(first_name=parts[0])
        case 2:
            return ParticipantName(first_name=parts[0], last_name=parts[1])
        case 3:
            if parts[1].lower() == "van":
                parts[1:] = [f"{parts[1]} {parts[2]}"]
                return ParticipantName(first_name=parts[0], last_name=parts[1])
            last_name = " ".join(parts[2].split())
            return ParticipantName(
                first_name=parts[0], middle_name=parts[1], last_name=last_name
            )
        case _:
            assert False, "SHOULD BE UNREACHABLE: Full name has too many parts"