typeagent-py 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- typeagent/aitools/auth.py +61 -0
- typeagent/aitools/embeddings.py +232 -0
- typeagent/aitools/utils.py +244 -0
- typeagent/aitools/vectorbase.py +175 -0
- typeagent/knowpro/answer_context_schema.py +49 -0
- typeagent/knowpro/answer_response_schema.py +34 -0
- typeagent/knowpro/answers.py +577 -0
- typeagent/knowpro/collections.py +759 -0
- typeagent/knowpro/common.py +9 -0
- typeagent/knowpro/convknowledge.py +112 -0
- typeagent/knowpro/convsettings.py +94 -0
- typeagent/knowpro/convutils.py +49 -0
- typeagent/knowpro/date_time_schema.py +32 -0
- typeagent/knowpro/field_helpers.py +87 -0
- typeagent/knowpro/fuzzyindex.py +144 -0
- typeagent/knowpro/interfaces.py +818 -0
- typeagent/knowpro/knowledge.py +88 -0
- typeagent/knowpro/kplib.py +125 -0
- typeagent/knowpro/query.py +1128 -0
- typeagent/knowpro/search.py +628 -0
- typeagent/knowpro/search_query_schema.py +165 -0
- typeagent/knowpro/searchlang.py +729 -0
- typeagent/knowpro/searchlib.py +345 -0
- typeagent/knowpro/secindex.py +100 -0
- typeagent/knowpro/serialization.py +390 -0
- typeagent/knowpro/textlocindex.py +179 -0
- typeagent/knowpro/utils.py +17 -0
- typeagent/mcp/server.py +139 -0
- typeagent/podcasts/podcast.py +473 -0
- typeagent/podcasts/podcast_import.py +105 -0
- typeagent/storage/__init__.py +25 -0
- typeagent/storage/memory/__init__.py +13 -0
- typeagent/storage/memory/collections.py +68 -0
- typeagent/storage/memory/convthreads.py +81 -0
- typeagent/storage/memory/messageindex.py +178 -0
- typeagent/storage/memory/propindex.py +289 -0
- typeagent/storage/memory/provider.py +84 -0
- typeagent/storage/memory/reltermsindex.py +318 -0
- typeagent/storage/memory/semrefindex.py +660 -0
- typeagent/storage/memory/timestampindex.py +176 -0
- typeagent/storage/sqlite/__init__.py +31 -0
- typeagent/storage/sqlite/collections.py +362 -0
- typeagent/storage/sqlite/messageindex.py +382 -0
- typeagent/storage/sqlite/propindex.py +119 -0
- typeagent/storage/sqlite/provider.py +293 -0
- typeagent/storage/sqlite/reltermsindex.py +328 -0
- typeagent/storage/sqlite/schema.py +248 -0
- typeagent/storage/sqlite/semrefindex.py +156 -0
- typeagent/storage/sqlite/timestampindex.py +146 -0
- typeagent/storage/utils.py +41 -0
- typeagent_py-0.1.0.dist-info/METADATA +28 -0
- typeagent_py-0.1.0.dist-info/RECORD +55 -0
- typeagent_py-0.1.0.dist-info/WHEEL +5 -0
- typeagent_py-0.1.0.dist-info/licenses/LICENSE +21 -0
- typeagent_py-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,9 @@
|
|
1
|
+
# Copyright (c) Microsoft Corporation.
|
2
|
+
# Licensed under the MIT License.
|
3
|
+
|
4
|
+
from .interfaces import SearchTerm
|
5
|
+
|
6
|
+
|
7
|
+
def is_search_term_wildcard(search_term: SearchTerm) -> bool:
    """Return True when the term text of ``search_term`` is exactly ``"*"``."""
    term_text = search_term.term.text
    return term_text == "*"
|
@@ -0,0 +1,112 @@
|
|
1
|
+
# Copyright (c) Microsoft Corporation.
|
2
|
+
# Licensed under the MIT License.
|
3
|
+
|
4
|
+
import asyncio
|
5
|
+
from dataclasses import dataclass, field
|
6
|
+
import os
|
7
|
+
|
8
|
+
import typechat
|
9
|
+
|
10
|
+
from ..aitools import auth
|
11
|
+
from . import kplib
|
12
|
+
from .interfaces import IKnowledgeExtractor
|
13
|
+
|
14
|
+
|
15
|
+
# TODO: Move ModelWrapper and create_typechat_model() to aitools package.
|
16
|
+
|
17
|
+
|
18
|
+
class ModelWrapper(typechat.TypeChatLanguageModel):
    """Language-model proxy that keeps an Azure token-based API key fresh.

    Completions are forwarded to ``base_model``. Before each call, if
    ``token_provider`` reports that a refresh is needed, a new token is
    fetched. A new TypeChat model is then built from the current process
    environment, with ``AZURE_OPENAI_API_KEY`` replaced by that token.
    """

    def __init__(
        self,
        base_model: typechat.TypeChatLanguageModel,
        token_provider: auth.AzureTokenProvider,
    ):
        self.base_model = base_model
        self.token_provider = token_provider

    async def complete(
        self, prompt: str | list[typechat.PromptSection]
    ) -> typechat.Result[str]:
        """Send ``prompt`` to the wrapped model, refreshing credentials first if needed."""
        provider = self.token_provider
        if provider.needs_refresh():
            # refresh_token is invoked in the default executor so the
            # (synchronously called) refresh does not block the event loop.
            running_loop = asyncio.get_running_loop()
            fresh_key = await running_loop.run_in_executor(None, provider.refresh_token)
            # Copy of the environment with only the API key swapped out.
            patched_env: dict[str, str | None] = {
                **os.environ,
                "AZURE_OPENAI_API_KEY": fresh_key,
            }
            self.base_model = typechat.create_language_model(patched_env)
        return await self.base_model.complete(prompt)
|
40
|
+
|
41
|
+
|
42
|
+
def create_typechat_model() -> typechat.TypeChatLanguageModel:
    """Create a TypeChat language model configured from ``os.environ``.

    If ``AZURE_OPENAI_API_KEY`` is set to ``"identity"`` (case-insensitive),
    a token from the shared Azure token provider is used as the key. The
    resulting model is wrapped in ``ModelWrapper`` so the token can be
    refreshed later. Otherwise the environment is passed through unchanged.
    """
    key_name = "AZURE_OPENAI_API_KEY"
    env: dict[str, str | None] = dict(os.environ)
    configured_key = env.get(key_name)
    uses_identity = configured_key is not None and configured_key.lower() == "identity"

    if not uses_identity:
        return typechat.create_language_model(env)

    provider = auth.get_shared_token_provider()
    env[key_name] = provider.get_token()
    return ModelWrapper(typechat.create_language_model(env), provider)
|
54
|
+
|
55
|
+
|
56
|
+
@dataclass
class KnowledgeExtractor:
    """Extract structured knowledge from conversation text with a TypeChat model.

    ``extract`` sends a message to ``model`` through a JSON translator. The
    translator's request prompt embeds the TypeScript schema of
    ``kplib.KnowledgeResponse`` and validates the reply against it.
    """

    # Defaults to a model configured from the environment.
    model: typechat.TypeChatLanguageModel = field(default_factory=create_typechat_model)
    # NOTE(review): not read anywhere in this class yet (see TODO below).
    max_chars_per_chunk: int = 2048
    # When true, successful results are passed to
    # merge_action_knowledge_into_response(). NOTE(review): that method
    # currently raises NotImplementedError, so with this default every
    # successful extract() call raises -- implement the merge or change
    # the default before relying on extract().
    merge_action_knowledge: bool = True
    # Not in the signature: built from `model` in __post_init__.
    translator: typechat.TypeChatJsonTranslator[kplib.KnowledgeResponse] = field(
        init=False
    )

    def __post_init__(self):
        self.translator = self.create_translator(self.model)

    # TODO: Use max_chars_per_chunk and merge_action_knowledge.

    async def extract(self, message: str) -> typechat.Result[kplib.KnowledgeResponse]:
        """Translate ``message`` into a ``KnowledgeResponse``.

        A failed translation is returned unchanged. On success, action
        knowledge is merged in place first if ``merge_action_knowledge`` is set.
        """
        result = await self.translator.translate(message)
        if isinstance(result, typechat.Success):
            if self.merge_action_knowledge:
                self.merge_action_knowledge_into_response(result.value)
        return result

    def create_translator(
        self, model: typechat.TypeChatLanguageModel
    ) -> typechat.TypeChatJsonTranslator[kplib.KnowledgeResponse]:
        """Build a KnowledgeResponse JSON translator that uses a conversation-oriented prompt."""
        schema = kplib.KnowledgeResponse
        type_name = "KnowledgeResponse"
        validator = typechat.TypeChatValidator[kplib.KnowledgeResponse](schema)
        translator = typechat.TypeChatJsonTranslator[kplib.KnowledgeResponse](
            model, validator, schema
        )
        schema_text = translator.schema_str.rstrip()

        def create_request_prompt(intent: str) -> str:
            # The input is presented to the model as messages in a conversation.
            return (
                "You are a service that translates user messages in a conversation "
                + f'into JSON objects of type "{type_name}" '
                + "according to the following TypeScript definitions:\n"
                + "```\n"
                + f"{schema_text}\n"
                + "```\n"
                + "The following are messages in a conversation:\n"
                + '"""\n'
                + f"{intent}\n"
                + '"""\n'
                + "The following is the user request translated into a JSON object "
                + "with 2 spaces of indentation and no properties with the value undefined:\n"
            )

        # Replace the translator's (private) prompt builder with the one above.
        translator._create_request_prompt = create_request_prompt
        return translator

    def merge_action_knowledge_into_response(
        self, knowledge: kplib.KnowledgeResponse
    ) -> None:
        """Merge action knowledge into a single knowledge object.

        Not implemented yet: always raises ``NotImplementedError``.
        """
        raise NotImplementedError("TODO")
|
@@ -0,0 +1,94 @@
|
|
1
|
+
# Copyright (c) Microsoft Corporation.
|
2
|
+
# Licensed under the MIT License.
|
3
|
+
|
4
|
+
from __future__ import annotations
|
5
|
+
|
6
|
+
from dataclasses import dataclass
|
7
|
+
|
8
|
+
from ..aitools.embeddings import AsyncEmbeddingModel
|
9
|
+
from ..aitools.vectorbase import TextEmbeddingIndexSettings
|
10
|
+
from .interfaces import IKnowledgeExtractor, IStorageProvider
|
11
|
+
|
12
|
+
|
13
|
+
@dataclass
class MessageTextIndexSettings:
    """Settings for the message-text embedding index."""

    # Embedding-index settings applied when indexing message text.
    # The dataclass-generated __init__ accepts this positionally or by keyword.
    embedding_index_settings: TextEmbeddingIndexSettings
|
19
|
+
|
20
|
+
|
21
|
+
@dataclass
class RelatedTermIndexSettings:
    """Settings for the related-terms embedding index."""

    # Embedding-index settings used for related-term lookups.
    # The dataclass-generated __init__ accepts this positionally or by keyword.
    embedding_index_settings: TextEmbeddingIndexSettings
|
27
|
+
|
28
|
+
|
29
|
+
@dataclass
class SemanticRefIndexSettings:
    """Settings for building the semantic-ref (knowledge) index."""

    # Batch size for indexing work -- presumably messages per batch; verify at call sites.
    batch_size: int
    # Whether knowledge should be extracted automatically while indexing.
    auto_extract_knowledge: bool
    # Optional extractor to use; defaults to None (none supplied).
    knowledge_extractor: IKnowledgeExtractor | None = None
|
34
|
+
|
35
|
+
|
36
|
+
class ConversationSettings:
    """Settings for conversation processing and indexing.

    Every embedding-based settings object below is built from the same
    embedding model instance, so they all share that model's embedding cache.
    The storage provider is either supplied up front or created lazily by
    ``get_storage_provider``.
    """

    def __init__(
        self,
        model: AsyncEmbeddingModel | None = None,
        storage_provider: IStorageProvider | None = None,
    ):
        # One shared model instance means one shared embedding cache.
        shared_model = model or AsyncEmbeddingModel()
        self.embedding_model = shared_model

        score_threshold = 0.85
        self.related_term_index_settings = RelatedTermIndexSettings(
            TextEmbeddingIndexSettings(
                shared_model, min_score=score_threshold, max_matches=50
            )
        )
        self.thread_settings = TextEmbeddingIndexSettings(
            shared_model, min_score=score_threshold
        )
        self.message_text_index_settings = MessageTextIndexSettings(
            TextEmbeddingIndexSettings(shared_model, min_score=score_threshold)
        )
        self.semantic_ref_index_settings = SemanticRefIndexSettings(
            batch_size=10,
            auto_extract_knowledge=False,
        )

        # When no provider is given, get_storage_provider() builds an
        # in-memory one on first use.
        self._storage_provider: IStorageProvider | None = storage_provider
        self._storage_provider_created = storage_provider is not None

    @property
    def storage_provider(self) -> IStorageProvider:
        """The storage provider.

        Raises RuntimeError if none was supplied and ``get_storage_provider``
        has not been awaited yet.
        """
        if self._storage_provider_created:
            assert (
                self._storage_provider is not None
            ), "Storage provider should be set when _storage_provider_created is True"
            return self._storage_provider
        raise RuntimeError(
            "Storage provider not initialized. Use await ConversationSettings.get_storage_provider() "
            "or provide storage_provider in constructor."
        )

    @storage_provider.setter
    def storage_provider(self, value: IStorageProvider) -> None:
        self._storage_provider = value
        self._storage_provider_created = True

    async def get_storage_provider(self) -> IStorageProvider:
        """Return the storage provider, creating a MemoryStorageProvider on first use if needed."""
        if not self._storage_provider_created:
            # Imported lazily -- presumably to avoid an import cycle; verify.
            from ..storage.memory import MemoryStorageProvider

            # Goes through the property setter, which also marks it as created.
            self.storage_provider = MemoryStorageProvider(
                message_text_settings=self.message_text_index_settings,
                related_terms_settings=self.related_term_index_settings,
            )
        assert (
            self._storage_provider is not None
        ), "Storage provider should be set after creation"
        return self._storage_provider
|
@@ -0,0 +1,49 @@
|
|
1
|
+
# Copyright (c) Microsoft Corporation.
|
2
|
+
# Licensed under the MIT License.
|
3
|
+
|
4
|
+
import typechat
|
5
|
+
|
6
|
+
from .convsettings import ConversationSettings
|
7
|
+
from .interfaces import (
|
8
|
+
DateRange,
|
9
|
+
Datetime,
|
10
|
+
IConversation,
|
11
|
+
IMessage,
|
12
|
+
ITermToSemanticRefIndex,
|
13
|
+
)
|
14
|
+
|
15
|
+
|
16
|
+
async def get_time_range_prompt_section_for_conversation[
|
17
|
+
TMessage: IMessage, TIndex: ITermToSemanticRefIndex
|
18
|
+
](
|
19
|
+
conversation: IConversation[TMessage, TIndex],
|
20
|
+
) -> typechat.PromptSection | None:
|
21
|
+
time_range = await get_time_range_for_conversation(conversation)
|
22
|
+
if time_range is not None:
|
23
|
+
start = time_range.start.replace(tzinfo=None).isoformat()
|
24
|
+
end = (
|
25
|
+
time_range.end.replace(tzinfo=None).isoformat() if time_range.end else "now"
|
26
|
+
)
|
27
|
+
return typechat.PromptSection(
|
28
|
+
role="system",
|
29
|
+
content=f"ONLY IF user request explicitly asks for time ranges, "
|
30
|
+
f'THEN use the CONVERSATION TIME RANGE: "{start} to {end}"',
|
31
|
+
)
|
32
|
+
|
33
|
+
|
34
|
+
async def get_time_range_for_conversation[
|
35
|
+
TMessage: IMessage, TIndex: ITermToSemanticRefIndex
|
36
|
+
](
|
37
|
+
conversation: IConversation[TMessage, TIndex],
|
38
|
+
) -> DateRange | None:
|
39
|
+
messages = conversation.messages
|
40
|
+
size = await messages.size()
|
41
|
+
if size > 0:
|
42
|
+
start = (await messages.get_item(0)).timestamp
|
43
|
+
if start is not None:
|
44
|
+
end = (await messages.get_item(size - 1)).timestamp
|
45
|
+
return DateRange(
|
46
|
+
start=Datetime.fromisoformat(start),
|
47
|
+
end=Datetime.fromisoformat(end) if end else None,
|
48
|
+
)
|
49
|
+
return None
|
@@ -0,0 +1,32 @@
|
|
1
|
+
# Copyright (c) Microsoft Corporation.
|
2
|
+
# Licensed under the MIT License.
|
3
|
+
|
4
|
+
from pydantic.dataclasses import dataclass
|
5
|
+
from typing import Annotated
|
6
|
+
from typing_extensions import Doc
|
7
|
+
|
8
|
+
|
9
|
+
@dataclass
class DateVal:
    # Calendar date split into day / month / year components.
    # NOTE(review): appears to be part of a schema (pydantic dataclass in a
    # *_schema module); comments are used instead of a docstring so __doc__,
    # which may feed generated schema text, is left untouched -- confirm.
    day: int
    month: int
    year: int
|
14
|
+
|
15
|
+
|
16
|
+
@dataclass
class TimeVal:
    # Time of day split into hour / minute / seconds components.
    hour: Annotated[int, Doc("In 24 hour form")]
    minute: int
    # Note the plural field name ("seconds"), unlike "hour" and "minute".
    seconds: int
|
21
|
+
|
22
|
+
|
23
|
+
@dataclass
class DateTime:
    # A calendar date with an optional time of day.
    date: DateVal
    # Defaults to None when no time of day is given.
    time: TimeVal | None = None
|
27
|
+
|
28
|
+
|
29
|
+
@dataclass
class DateTimeRange:
    # A range starting at start_date with an optional stop_date.
    start_date: DateTime
    # Defaults to None; presumably an open-ended range or a single point in
    # time -- NOTE(review): confirm against consumers of this schema.
    stop_date: DateTime | None = None
|
@@ -0,0 +1,87 @@
|
|
1
|
+
# Copyright (c) Microsoft Corporation.
|
2
|
+
# Licensed under the MIT License.
|
3
|
+
|
4
|
+
from dataclasses import MISSING
|
5
|
+
from typing import Any
|
6
|
+
|
7
|
+
from pydantic import Field, AliasChoices
|
8
|
+
from pydantic.alias_generators import to_camel
|
9
|
+
|
10
|
+
|
11
|
+
def _camel_case_field_for(
    name: str,
    description: str | None,
    default: Any,
    default_factory: Any,
) -> Any:
    """Build a pydantic Field whose aliases are the camelCase form of ``name``."""
    # Trailing underscores (e.g. ``from_``) are dropped before camel-casing.
    camel_name = to_camel(name.rstrip("_"))
    field_kwargs: dict[str, Any] = {
        "description": description,
        "serialization_alias": camel_name,
        # Accept both the Python attribute name and the camelCase name on input.
        "validation_alias": AliasChoices(name, camel_name),
    }
    # MISSING means "not supplied"; `default` wins if both are given.
    if default is not MISSING:
        field_kwargs["default"] = default
    elif default_factory is not MISSING:
        field_kwargs["default_factory"] = default_factory
    return Field(**field_kwargs)


def CamelCaseField(
    description: str | None = None,
    *,
    field_name: str | None = None,
    default: Any = MISSING,
    default_factory: Any = MISSING,
) -> Any:
    """
    Helper function to create a Field with camelCase serialization alias.

    Args:
        description: The field description
        field_name: The snake_case field name (if provided, creates Field directly)
        default: The default value for the field (optional)
        default_factory: The default factory for the field (optional)

    Returns:
        If field_name is provided: A Field with serialization_alias set to the camelCase version
        Otherwise: A descriptor that will create a Field with serialization_alias set to the camelCase version
        of the field name.
        In both cases validation_alias is set to accept both the snake_case and camelCase versions.

    Note: For fields ending with underscore (like 'from_'), the underscore is removed in the camelCase version.
    """
    # If field_name is provided, create the Field directly.
    if field_name is not None:
        return _camel_case_field_for(field_name, description, default, default_factory)

    # Otherwise, use the descriptor approach for backward compatibility: the
    # Field is built once the attribute name is known (via __set_name__).
    class CamelCaseFieldDescriptor:
        def __init__(
            self,
            description: str | None = None,
            default: Any = MISSING,
            default_factory: Any = MISSING,
        ):
            self.description = description
            self.default = default
            self.default_factory = default_factory

        def __set_name__(self, owner, name):
            # Replace this descriptor on the owning class with the real Field.
            setattr(
                owner,
                name,
                _camel_case_field_for(
                    name, self.description, self.default, self.default_factory
                ),
            )

    return CamelCaseFieldDescriptor(description, default, default_factory)
|
@@ -0,0 +1,144 @@
|
|
1
|
+
# Copyright (c) Microsoft Corporation.
|
2
|
+
# Licensed under the MIT License.
|
3
|
+
|
4
|
+
from collections.abc import Callable
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
|
8
|
+
from ..aitools.vectorbase import VectorBase, TextEmbeddingIndexSettings, ScoredInt
|
9
|
+
from ..aitools.embeddings import NormalizedEmbedding, NormalizedEmbeddings
|
10
|
+
|
11
|
+
|
12
|
+
class EmbeddingIndex:
    """Wrapper around VectorBase.

    Most methods delegate directly to the underlying VectorBase.
    """

    # TODO: Don't reach into VectorBase private attributes (currently
    # _embedding_size in deserialize); use VectorBase methods instead.

    def __init__(
        self,
        settings: TextEmbeddingIndexSettings,
        embeddings: NormalizedEmbeddings | None = None,
    ):
        """Create an index using ``settings``, optionally pre-filled with ``embeddings``."""
        # Use VectorBase for storage and operations on embeddings.
        self._vector_base = VectorBase(settings)

        # Initialize with embeddings if provided; each one is added with no key.
        if embeddings is not None:
            for embedding in embeddings:
                self._vector_base.add_embedding(None, embedding)

    def __len__(self) -> int:
        return len(self._vector_base)

    async def size(self) -> int:
        return len(self._vector_base)

    async def is_empty(self) -> bool:
        return len(self._vector_base) == 0

    async def get_embedding(self, key: str, cache: bool = True) -> NormalizedEmbedding:
        """Return the embedding for text ``key`` (delegates to VectorBase.get_embedding)."""
        return await self._vector_base.get_embedding(key, cache)

    def get(self, pos: int) -> NormalizedEmbedding:
        """Return the stored embedding at position ``pos``."""
        return self._vector_base.get_embedding_at(pos)

    def push(self, embeddings: NormalizedEmbeddings) -> None:
        """Append ``embeddings`` to the index."""
        self._vector_base.add_embeddings(embeddings)

    async def add_texts(self, texts: list[str]) -> None:
        """Add ``texts`` to the index (delegates to VectorBase.add_keys)."""
        await self._vector_base.add_keys(texts)

    def get_indexes_of_nearest(
        self,
        embedding: NormalizedEmbedding,
        max_matches: int | None = None,
        min_score: float | None = None,
        predicate: Callable[[int], bool] | None = None,
    ) -> list[ScoredInt]:
        """Fuzzy-match ``embedding`` against the whole index.

        ``max_matches`` is forwarded as VectorBase's ``max_hits``.
        """
        return self._vector_base.fuzzy_lookup_embedding(
            embedding,
            max_hits=max_matches,
            min_score=min_score,
            predicate=predicate,
        )

    def get_indexes_of_nearest_in_subset(
        self,
        embedding: NormalizedEmbedding,
        ordinals_of_subset: list[int],
        max_matches: int | None = None,
        min_score: float | None = None,
    ) -> list[ScoredInt]:
        """Fuzzy-match ``embedding`` against only the positions in ``ordinals_of_subset``."""
        return self._vector_base.fuzzy_lookup_embedding_in_subset(
            embedding,
            ordinals_of_subset,
            max_matches,
            min_score,
        )

    def clear(self) -> None:
        self._vector_base.clear()

    def serialize(self) -> NormalizedEmbeddings:
        return self._vector_base.serialize()

    def deserialize(self, embeddings: NormalizedEmbeddings) -> None:
        """Replace the index contents with ``embeddings``.

        ``embeddings`` must be a 2-D float32 ndarray whose second dimension
        equals the VectorBase embedding size. These checks use ``assert`` and
        are therefore skipped under ``python -O``.
        """
        assert isinstance(embeddings, np.ndarray), type(embeddings)
        assert embeddings.dtype == np.float32, embeddings.dtype
        assert embeddings.ndim == 2, embeddings.shape
        assert (
            embeddings.shape[1] == self._vector_base._embedding_size
        ), embeddings.shape
        self.clear()
        self.push(embeddings)
|