databao-context-engine 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- databao_context_engine/__init__.py +35 -0
- databao_context_engine/build_sources/__init__.py +0 -0
- databao_context_engine/build_sources/internal/__init__.py +0 -0
- databao_context_engine/build_sources/internal/build_runner.py +111 -0
- databao_context_engine/build_sources/internal/build_service.py +77 -0
- databao_context_engine/build_sources/internal/build_wiring.py +52 -0
- databao_context_engine/build_sources/internal/export_results.py +43 -0
- databao_context_engine/build_sources/internal/plugin_execution.py +74 -0
- databao_context_engine/build_sources/public/__init__.py +0 -0
- databao_context_engine/build_sources/public/api.py +4 -0
- databao_context_engine/cli/__init__.py +0 -0
- databao_context_engine/cli/add_datasource_config.py +130 -0
- databao_context_engine/cli/commands.py +256 -0
- databao_context_engine/cli/datasources.py +64 -0
- databao_context_engine/cli/info.py +32 -0
- databao_context_engine/config/__init__.py +0 -0
- databao_context_engine/config/log_config.yaml +16 -0
- databao_context_engine/config/logging.py +43 -0
- databao_context_engine/databao_context_project_manager.py +92 -0
- databao_context_engine/databao_engine.py +85 -0
- databao_context_engine/datasource_config/__init__.py +0 -0
- databao_context_engine/datasource_config/add_config.py +50 -0
- databao_context_engine/datasource_config/check_config.py +131 -0
- databao_context_engine/datasource_config/datasource_context.py +60 -0
- databao_context_engine/event_journal/__init__.py +0 -0
- databao_context_engine/event_journal/writer.py +29 -0
- databao_context_engine/generate_configs_schemas.py +92 -0
- databao_context_engine/init_project.py +18 -0
- databao_context_engine/introspection/__init__.py +0 -0
- databao_context_engine/introspection/property_extract.py +202 -0
- databao_context_engine/llm/__init__.py +0 -0
- databao_context_engine/llm/config.py +20 -0
- databao_context_engine/llm/descriptions/__init__.py +0 -0
- databao_context_engine/llm/descriptions/ollama.py +21 -0
- databao_context_engine/llm/descriptions/provider.py +10 -0
- databao_context_engine/llm/embeddings/__init__.py +0 -0
- databao_context_engine/llm/embeddings/ollama.py +37 -0
- databao_context_engine/llm/embeddings/provider.py +13 -0
- databao_context_engine/llm/errors.py +16 -0
- databao_context_engine/llm/factory.py +61 -0
- databao_context_engine/llm/install.py +227 -0
- databao_context_engine/llm/runtime.py +73 -0
- databao_context_engine/llm/service.py +159 -0
- databao_context_engine/main.py +19 -0
- databao_context_engine/mcp/__init__.py +0 -0
- databao_context_engine/mcp/all_results_tool.py +5 -0
- databao_context_engine/mcp/mcp_runner.py +16 -0
- databao_context_engine/mcp/mcp_server.py +63 -0
- databao_context_engine/mcp/retrieve_tool.py +22 -0
- databao_context_engine/pluginlib/__init__.py +0 -0
- databao_context_engine/pluginlib/build_plugin.py +107 -0
- databao_context_engine/pluginlib/config.py +37 -0
- databao_context_engine/pluginlib/plugin_utils.py +68 -0
- databao_context_engine/plugins/__init__.py +0 -0
- databao_context_engine/plugins/athena_db_plugin.py +12 -0
- databao_context_engine/plugins/base_db_plugin.py +45 -0
- databao_context_engine/plugins/clickhouse_db_plugin.py +15 -0
- databao_context_engine/plugins/databases/__init__.py +0 -0
- databao_context_engine/plugins/databases/athena_introspector.py +101 -0
- databao_context_engine/plugins/databases/base_introspector.py +144 -0
- databao_context_engine/plugins/databases/clickhouse_introspector.py +162 -0
- databao_context_engine/plugins/databases/database_chunker.py +69 -0
- databao_context_engine/plugins/databases/databases_types.py +114 -0
- databao_context_engine/plugins/databases/duckdb_introspector.py +325 -0
- databao_context_engine/plugins/databases/introspection_model_builder.py +270 -0
- databao_context_engine/plugins/databases/introspection_scope.py +74 -0
- databao_context_engine/plugins/databases/introspection_scope_matcher.py +103 -0
- databao_context_engine/plugins/databases/mssql_introspector.py +433 -0
- databao_context_engine/plugins/databases/mysql_introspector.py +338 -0
- databao_context_engine/plugins/databases/postgresql_introspector.py +428 -0
- databao_context_engine/plugins/databases/snowflake_introspector.py +287 -0
- databao_context_engine/plugins/duckdb_db_plugin.py +12 -0
- databao_context_engine/plugins/mssql_db_plugin.py +12 -0
- databao_context_engine/plugins/mysql_db_plugin.py +12 -0
- databao_context_engine/plugins/parquet_plugin.py +32 -0
- databao_context_engine/plugins/plugin_loader.py +110 -0
- databao_context_engine/plugins/postgresql_db_plugin.py +12 -0
- databao_context_engine/plugins/resources/__init__.py +0 -0
- databao_context_engine/plugins/resources/parquet_chunker.py +23 -0
- databao_context_engine/plugins/resources/parquet_introspector.py +154 -0
- databao_context_engine/plugins/snowflake_db_plugin.py +12 -0
- databao_context_engine/plugins/unstructured_files_plugin.py +68 -0
- databao_context_engine/project/__init__.py +0 -0
- databao_context_engine/project/datasource_discovery.py +141 -0
- databao_context_engine/project/info.py +44 -0
- databao_context_engine/project/init_project.py +102 -0
- databao_context_engine/project/layout.py +127 -0
- databao_context_engine/project/project_config.py +32 -0
- databao_context_engine/project/resources/examples/src/databases/example_postgres.yaml +7 -0
- databao_context_engine/project/resources/examples/src/files/documentation.md +30 -0
- databao_context_engine/project/resources/examples/src/files/notes.txt +20 -0
- databao_context_engine/project/runs.py +39 -0
- databao_context_engine/project/types.py +134 -0
- databao_context_engine/retrieve_embeddings/__init__.py +0 -0
- databao_context_engine/retrieve_embeddings/internal/__init__.py +0 -0
- databao_context_engine/retrieve_embeddings/internal/export_results.py +12 -0
- databao_context_engine/retrieve_embeddings/internal/retrieve_runner.py +34 -0
- databao_context_engine/retrieve_embeddings/internal/retrieve_service.py +68 -0
- databao_context_engine/retrieve_embeddings/internal/retrieve_wiring.py +29 -0
- databao_context_engine/retrieve_embeddings/public/__init__.py +0 -0
- databao_context_engine/retrieve_embeddings/public/api.py +3 -0
- databao_context_engine/serialisation/__init__.py +0 -0
- databao_context_engine/serialisation/yaml.py +35 -0
- databao_context_engine/services/__init__.py +0 -0
- databao_context_engine/services/chunk_embedding_service.py +104 -0
- databao_context_engine/services/embedding_shard_resolver.py +64 -0
- databao_context_engine/services/factories.py +88 -0
- databao_context_engine/services/models.py +12 -0
- databao_context_engine/services/persistence_service.py +61 -0
- databao_context_engine/services/run_name_policy.py +8 -0
- databao_context_engine/services/table_name_policy.py +15 -0
- databao_context_engine/storage/__init__.py +0 -0
- databao_context_engine/storage/connection.py +32 -0
- databao_context_engine/storage/exceptions/__init__.py +0 -0
- databao_context_engine/storage/exceptions/exceptions.py +6 -0
- databao_context_engine/storage/migrate.py +127 -0
- databao_context_engine/storage/migrations/V01__init.sql +63 -0
- databao_context_engine/storage/models.py +51 -0
- databao_context_engine/storage/repositories/__init__.py +0 -0
- databao_context_engine/storage/repositories/chunk_repository.py +130 -0
- databao_context_engine/storage/repositories/datasource_run_repository.py +136 -0
- databao_context_engine/storage/repositories/embedding_model_registry_repository.py +87 -0
- databao_context_engine/storage/repositories/embedding_repository.py +113 -0
- databao_context_engine/storage/repositories/factories.py +35 -0
- databao_context_engine/storage/repositories/run_repository.py +157 -0
- databao_context_engine/storage/repositories/vector_search_repository.py +63 -0
- databao_context_engine/storage/transaction.py +14 -0
- databao_context_engine/system/__init__.py +0 -0
- databao_context_engine/system/properties.py +13 -0
- databao_context_engine/templating/__init__.py +0 -0
- databao_context_engine/templating/renderer.py +29 -0
- databao_context_engine-0.1.1.dist-info/METADATA +186 -0
- databao_context_engine-0.1.1.dist-info/RECORD +135 -0
- databao_context_engine-0.1.1.dist-info/WHEEL +4 -0
- databao_context_engine-0.1.1.dist-info/entry_points.txt +4 -0
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from enum import StrEnum
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from databao_context_engine.pluginlib.build_plugin import DatasourceType
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class DatasourceKind(StrEnum):
    """How a datasource is declared in the project: via a config file or as a raw file."""

    CONFIG = "config"
    FILE = "file"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass(frozen=True)
class DatasourceDescriptor:
    """Immutable pointer to a discovered datasource in the project source tree."""

    # Filesystem path to the datasource (a config file or a raw file, per `kind`).
    path: Path
    # Whether `path` points at a config file or a raw file.
    kind: DatasourceKind
    # Main type identifier for the datasource — presumably used to select a
    # plugin / DatasourceType; confirm against the discovery code.
    main_type: str
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass(frozen=True)
class PreparedConfig:
    """A config-declared datasource that has been parsed and is ready to build."""

    # Resolved plugin datasource type.
    datasource_type: DatasourceType
    # Path to the config file this datasource was declared in.
    path: Path
    # Parsed config contents (shape depends on the datasource type's schema).
    config: dict[Any, Any]
    # Human-readable datasource name.
    datasource_name: str
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass(frozen=True)
class PreparedFile:
    """A raw-file datasource ready to build; carries no parsed config."""

    # Resolved plugin datasource type.
    datasource_type: DatasourceType
    # Path to the raw file itself.
    path: Path
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# A prepared datasource is either a parsed config entry or a raw file entry.
PreparedDatasource = PreparedConfig | PreparedFile
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass(kw_only=True, frozen=True, eq=True)
|
|
40
|
+
class DatasourceId:
|
|
41
|
+
"""
|
|
42
|
+
The ID of a datasource. The ID is the path to the datasource's config file relative to the src folder in the project.
|
|
43
|
+
|
|
44
|
+
e.g: "databases/my_postgres_datasource.yaml"
|
|
45
|
+
|
|
46
|
+
Use the provided factory methods `from_string_repr` and `from_datasource_config_file_path` to create a DatasourceId, rather than its constructor.
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
datasource_config_folder: str
|
|
50
|
+
datasource_name: str
|
|
51
|
+
config_file_suffix: str
|
|
52
|
+
|
|
53
|
+
def __post_init__(self):
|
|
54
|
+
if not self.datasource_config_folder.strip():
|
|
55
|
+
raise ValueError(f"Invalid DatasourceId ({str(self)}): datasource_config_folder must not be empty")
|
|
56
|
+
if not self.datasource_name.strip():
|
|
57
|
+
raise ValueError(f"Invalid DatasourceId ({str(self)}): datasource_name must not be empty")
|
|
58
|
+
if not self.config_file_suffix.strip():
|
|
59
|
+
raise ValueError(f"Invalid DatasourceId ({str(self)}): config_file_suffix must not be empty")
|
|
60
|
+
|
|
61
|
+
if os.sep in self.datasource_config_folder:
|
|
62
|
+
raise ValueError(
|
|
63
|
+
f"Invalid DatasourceId ({str(self)}): datasource_config_folder must not contain a path separator"
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
if os.sep in self.datasource_name:
|
|
67
|
+
raise ValueError(f"Invalid DatasourceId ({str(self)}): datasource_name must not contain a path separator")
|
|
68
|
+
|
|
69
|
+
if not self.config_file_suffix.startswith("."):
|
|
70
|
+
raise ValueError(
|
|
71
|
+
f'Invalid DatasourceId ({str(self)}): config_file_suffix must start with a dot "." (e.g.: .yaml)'
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
if self.datasource_name.endswith(self.config_file_suffix):
|
|
75
|
+
raise ValueError(f"Invalid DatasourceId ({str(self)}): datasource_name must not contain the file suffix")
|
|
76
|
+
|
|
77
|
+
def __str__(self):
|
|
78
|
+
return str(self.relative_path_to_config_file())
|
|
79
|
+
|
|
80
|
+
def relative_path_to_config_file(self) -> Path:
|
|
81
|
+
"""
|
|
82
|
+
Returns a path to the config file for this datasource.
|
|
83
|
+
|
|
84
|
+
The returned path is relative to the src folder in the project.
|
|
85
|
+
"""
|
|
86
|
+
return Path(self.datasource_config_folder).joinpath(self.datasource_name + self.config_file_suffix)
|
|
87
|
+
|
|
88
|
+
def relative_path_to_context_file(self) -> Path:
|
|
89
|
+
"""
|
|
90
|
+
Returns a path to the config file for this datasource.
|
|
91
|
+
|
|
92
|
+
The returned path is relative to an output run folder in the project.
|
|
93
|
+
"""
|
|
94
|
+
# Keep the suffix in the filename if this datasource is a raw file, to handle multiple files with the same name and different extensions
|
|
95
|
+
suffix = ".yaml" if self.config_file_suffix == ".yaml" else (self.config_file_suffix + ".yaml")
|
|
96
|
+
|
|
97
|
+
return Path(self.datasource_config_folder).joinpath(self.datasource_name + suffix)
|
|
98
|
+
|
|
99
|
+
@classmethod
|
|
100
|
+
def from_string_repr(cls, datasource_id_as_string: str):
|
|
101
|
+
"""
|
|
102
|
+
Creates a DatasourceId from a string representation.
|
|
103
|
+
|
|
104
|
+
The string representation of a DatasourceId is the path to the datasource's config file relative to the src folder in the project.
|
|
105
|
+
|
|
106
|
+
e.g: "databases/my_postgres_datasource.yaml"
|
|
107
|
+
"""
|
|
108
|
+
config_file_path = Path(datasource_id_as_string)
|
|
109
|
+
|
|
110
|
+
if len(config_file_path.parents) > 2:
|
|
111
|
+
raise ValueError(
|
|
112
|
+
f"Invalid string representation of a DatasourceId: too many parent folders defined in {datasource_id_as_string}"
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
return DatasourceId.from_datasource_config_file_path(config_file_path)
|
|
116
|
+
|
|
117
|
+
@classmethod
|
|
118
|
+
def from_datasource_config_file_path(cls, datasource_config_file: Path):
|
|
119
|
+
"""
|
|
120
|
+
Creates a DatasourceId from a config file path.
|
|
121
|
+
|
|
122
|
+
The `datasource_config_file` path provided can either be the config file path relative to the src folder or the full path to the config file.
|
|
123
|
+
"""
|
|
124
|
+
return DatasourceId(
|
|
125
|
+
datasource_config_folder=datasource_config_file.parent.name,
|
|
126
|
+
datasource_name=datasource_config_file.stem,
|
|
127
|
+
config_file_suffix=datasource_config_file.suffix,
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
@dataclass
class Datasource:
    """A datasource resolved to its identity and its plugin datasource type."""

    # Stable identifier derived from the config file location (see DatasourceId).
    id: DatasourceId
    # Plugin datasource type (from pluginlib.build_plugin) — presumably selects
    # which build plugin handles this datasource; confirm at the call sites.
    type: DatasourceType
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def export_retrieve_results(run_dir: Path, retrieve_results: list[str]) -> Path:
    """Write the retrieve results to `context_duckdb.yaml` inside `run_dir`.

    Each result is written verbatim on its own line (a trailing newline is
    appended after every result, including the last). An existing file is
    overwritten. Returns the path to the written file.
    """
    path = run_dir.joinpath("context_duckdb.yaml")

    # Explicit encoding: the platform default would make exports differ
    # between machines (e.g. cp1252 on Windows vs UTF-8 elsewhere).
    with path.open("w", encoding="utf-8") as export_file:
        export_file.writelines(f"{result}\n" for result in retrieve_results)

    return path
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
from databao_context_engine.project.runs import get_run_dir
|
|
5
|
+
from databao_context_engine.retrieve_embeddings.internal.export_results import export_retrieve_results
|
|
6
|
+
from databao_context_engine.retrieve_embeddings.internal.retrieve_service import RetrieveService
|
|
7
|
+
from databao_context_engine.storage.repositories.vector_search_repository import VectorSearchResult
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def retrieve(
    project_dir: Path,
    *,
    retrieve_service: RetrieveService,
    project_id: str,
    text: str,
    run_name: str | None,
    limit: int | None,
    export_to_file: bool,
) -> list[VectorSearchResult]:
    """Run a similarity retrieval for `text` against the given (or resolved) run.

    When `export_to_file` is set, the result display texts are also written
    into the run's output folder and the export location is logged.
    """
    effective_run_name = retrieve_service.resolve_run_name(project_id=project_id, run_name=run_name)

    results = retrieve_service.retrieve(
        project_id=project_id, text=text, run_name=effective_run_name, limit=limit
    )

    if export_to_file:
        target_dir = get_run_dir(project_dir=project_dir, run_name=effective_run_name)
        exported_path = export_retrieve_results(target_dir, [result.display_text for result in results])
        logger.info(f"Exported results to {exported_path}")

    return results
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from collections.abc import Sequence
|
|
3
|
+
|
|
4
|
+
from databao_context_engine.llm.embeddings.provider import EmbeddingProvider
|
|
5
|
+
from databao_context_engine.project.runs import resolve_run_name_from_repo
|
|
6
|
+
from databao_context_engine.services.embedding_shard_resolver import EmbeddingShardResolver
|
|
7
|
+
from databao_context_engine.storage.repositories.run_repository import RunRepository
|
|
8
|
+
from databao_context_engine.storage.repositories.vector_search_repository import (
|
|
9
|
+
VectorSearchRepository,
|
|
10
|
+
VectorSearchResult,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class RetrieveService:
    """Embeds a query text and runs a vector-similarity search against a run's embeddings."""

    def __init__(
        self,
        *,
        run_repo: RunRepository,
        vector_search_repo: VectorSearchRepository,
        shard_resolver: EmbeddingShardResolver,
        provider: EmbeddingProvider,
    ):
        self._run_repo = run_repo
        self._shard_resolver = shard_resolver
        self._provider = provider
        self._vector_search_repo = vector_search_repo

    def retrieve(
        self, *, project_id: str, text: str, run_name: str, limit: int | None = None
    ) -> list[VectorSearchResult]:
        """Return up to `limit` (default 10) results most similar to `text` for the given run.

        Raises:
            LookupError: when `run_name` does not exist for `project_id`.
            ValueError: from the shard resolver when the embedding model is not registered.
        """
        if limit is None:
            limit = 10

        run = self._run_repo.get_by_run_name(project_id=project_id, run_name=run_name)
        if run is None:
            raise LookupError(f"Run '{run_name}' not found for project '{project_id}'.")

        # Find which embedding table ("shard") holds vectors for the configured model.
        table_name, dimension = self._shard_resolver.resolve(
            embedder=self._provider.embedder, model_id=self._provider.model_id
        )

        retrieve_vec: Sequence[float] = self._provider.embed(text)

        logger.debug(f"Retrieving display texts for run {run.run_id} in table {table_name}")

        search_results = self._vector_search_repo.get_display_texts_by_similarity(
            table_name=table_name,
            run_id=run.run_id,
            retrieve_vec=retrieve_vec,
            dimension=dimension,
            limit=limit,
        )

        logger.debug(f"Retrieved {len(search_results)} display texts for run {run.run_id} in table {table_name}")

        # Guard on non-empty results: min()/max() on an empty sequence raise
        # ValueError, which previously crashed retrieval whenever DEBUG
        # logging was enabled and the search returned nothing.
        if search_results and logger.isEnabledFor(logging.DEBUG):
            closest_result = min(search_results, key=lambda result: result.cosine_distance)
            logger.debug(f"Best result: ({closest_result.cosine_distance}, {closest_result.embeddable_text})")

            farthest_result = max(search_results, key=lambda result: result.cosine_distance)
            logger.debug(f"Worst result: ({farthest_result.cosine_distance}, {farthest_result.embeddable_text})")

        return search_results

    def resolve_run_name(self, *, project_id: str, run_name: str | None) -> str:
        """Resolve `run_name` via the run repository — presumably falls back to the latest run when None; confirm in runs.py."""
        return resolve_run_name_from_repo(run_repository=self._run_repo, project_id=project_id, run_name=run_name)
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from databao_context_engine.llm.factory import create_ollama_embedding_provider, create_ollama_service
|
|
2
|
+
from databao_context_engine.project.layout import ProjectLayout
|
|
3
|
+
from databao_context_engine.retrieve_embeddings.internal.retrieve_runner import retrieve
|
|
4
|
+
from databao_context_engine.services.factories import create_retrieve_service
|
|
5
|
+
from databao_context_engine.storage.connection import open_duckdb_connection
|
|
6
|
+
from databao_context_engine.storage.repositories.vector_search_repository import VectorSearchResult
|
|
7
|
+
from databao_context_engine.system.properties import get_db_path
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def retrieve_embeddings(
    project_layout: ProjectLayout,
    retrieve_text: str,
    run_name: str | None,
    limit: int | None,
    export_to_file: bool,
) -> list[VectorSearchResult]:
    """Wire the retrieval stack against the DuckDB store and run one retrieval."""
    with open_duckdb_connection(get_db_path()) as conn:
        # Assemble the Ollama-backed embedding provider and retrieval service.
        provider = create_ollama_embedding_provider(create_ollama_service())
        service = create_retrieve_service(conn, embedding_provider=provider)

        project_id = str(project_layout.read_config_file().project_id)

        return retrieve(
            project_dir=project_layout.project_dir,
            retrieve_service=service,
            project_id=project_id,
            text=retrieve_text,
            run_name=run_name,
            limit=limit,
            export_to_file=export_to_file,
        )
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from typing import Any, Mapping, TextIO, cast
|
|
2
|
+
|
|
3
|
+
import yaml
|
|
4
|
+
from yaml import Node, SafeDumper
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def default_representer(dumper: SafeDumper, data: object) -> Node:
    """Fallback YAML representer: mappings and objects with public attributes
    become dicts; everything else falls back to its string representation."""
    if isinstance(data, Mapping):
        return dumper.represent_dict(data)

    if hasattr(data, "__dict__"):
        # Skip "private" attributes (names starting with an underscore).
        public_attributes = {name: value for name, value in vars(data).items() if not name.startswith("_")}
        if public_attributes:
            return dumper.represent_dict(public_attributes)
        # No public attributes: fall through to the string representation.

    return dumper.represent_str(str(data))
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Register our default representer exactly once, when this module is first imported.
yaml.add_multi_representer(object, default_representer, Dumper=SafeDumper)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def write_yaml_to_stream(*, data: Any, file_stream: TextIO) -> None:
    """Serialise `data` as YAML directly into `file_stream`."""
    _to_yaml(data, file_stream)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def to_yaml_string(data: Any) -> str:
    """Serialise `data` and return the YAML text as a string."""
    # safe_dump returns the string when no stream is given; cast narrows the
    # str | None signature of the shared helper.
    rendered = _to_yaml(data, None)
    return cast(str, rendered)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _to_yaml(data: Any, stream: TextIO | None) -> str | None:
    """Dump `data` via SafeDumper (which has the registered default representer).

    Returns the YAML string when `stream` is None, otherwise writes to
    `stream` and returns None (yaml.safe_dump semantics).
    """
    return yaml.safe_dump(data, stream, sort_keys=False, default_flow_style=False)
|
|
File without changes
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import cast
|
|
4
|
+
|
|
5
|
+
from databao_context_engine.llm.descriptions.provider import DescriptionProvider
|
|
6
|
+
from databao_context_engine.llm.embeddings.provider import EmbeddingProvider
|
|
7
|
+
from databao_context_engine.pluginlib.build_plugin import EmbeddableChunk
|
|
8
|
+
from databao_context_engine.serialisation.yaml import to_yaml_string
|
|
9
|
+
from databao_context_engine.services.embedding_shard_resolver import EmbeddingShardResolver
|
|
10
|
+
from databao_context_engine.services.models import ChunkEmbedding
|
|
11
|
+
from databao_context_engine.services.persistence_service import PersistenceService
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ChunkEmbeddingMode(Enum):
    """What text gets embedded per chunk: the raw embeddable text, an
    LLM-generated description, or the concatenation of both."""

    EMBEDDABLE_TEXT_ONLY = "EMBEDDABLE_TEXT_ONLY"
    GENERATED_DESCRIPTION_ONLY = "GENERATED_DESCRIPTION_ONLY"
    EMBEDDABLE_TEXT_AND_GENERATED_DESCRIPTION = "EMBEDDABLE_TEXT_AND_GENERATED_DESCRIPTION"

    def should_generate_description(self) -> bool:
        """True when this mode requires generating an LLM description."""
        modes_needing_description = {
            ChunkEmbeddingMode.GENERATED_DESCRIPTION_ONLY,
            ChunkEmbeddingMode.EMBEDDABLE_TEXT_AND_GENERATED_DESCRIPTION,
        }
        return self in modes_needing_description
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ChunkEmbeddingService:
|
|
29
|
+
def __init__(
|
|
30
|
+
self,
|
|
31
|
+
*,
|
|
32
|
+
persistence_service: PersistenceService,
|
|
33
|
+
embedding_provider: EmbeddingProvider,
|
|
34
|
+
description_provider: DescriptionProvider | None,
|
|
35
|
+
shard_resolver: EmbeddingShardResolver,
|
|
36
|
+
chunk_embedding_mode: ChunkEmbeddingMode = ChunkEmbeddingMode.EMBEDDABLE_TEXT_ONLY,
|
|
37
|
+
):
|
|
38
|
+
self._persistence_service = persistence_service
|
|
39
|
+
self._embedding_provider = embedding_provider
|
|
40
|
+
self._description_provider = description_provider
|
|
41
|
+
self._shard_resolver = shard_resolver
|
|
42
|
+
self._chunk_embedding_mode = chunk_embedding_mode
|
|
43
|
+
|
|
44
|
+
if self._chunk_embedding_mode.should_generate_description() and description_provider is None:
|
|
45
|
+
raise ValueError("A DescriptionProvider must be provided when generating descriptions")
|
|
46
|
+
|
|
47
|
+
def embed_chunks(self, *, datasource_run_id: int, chunks: list[EmbeddableChunk], result: str) -> None:
|
|
48
|
+
"""
|
|
49
|
+
Turn plugin chunks into persisted chunks and embeddings
|
|
50
|
+
|
|
51
|
+
Flow:
|
|
52
|
+
1) Embed each chunk into an embedded vector
|
|
53
|
+
2) Get or create embedding table for the appropriate model and embedding dimensions
|
|
54
|
+
3) Persist chunks and embeddings vectors in a single transaction
|
|
55
|
+
"""
|
|
56
|
+
|
|
57
|
+
if not chunks:
|
|
58
|
+
return
|
|
59
|
+
|
|
60
|
+
logger.debug(
|
|
61
|
+
f"Embedding {len(chunks)} chunks for datasource run {datasource_run_id}, with chunk_embedding_mode={self._chunk_embedding_mode}"
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
enriched_embeddings: list[ChunkEmbedding] = []
|
|
65
|
+
for chunk in chunks:
|
|
66
|
+
chunk_display_text = to_yaml_string(chunk.content)
|
|
67
|
+
|
|
68
|
+
generated_description = ""
|
|
69
|
+
match self._chunk_embedding_mode:
|
|
70
|
+
case ChunkEmbeddingMode.EMBEDDABLE_TEXT_ONLY:
|
|
71
|
+
embedding_text = chunk.embeddable_text
|
|
72
|
+
case ChunkEmbeddingMode.GENERATED_DESCRIPTION_ONLY:
|
|
73
|
+
generated_description = cast(DescriptionProvider, self._description_provider).describe(
|
|
74
|
+
text=chunk_display_text, context=result
|
|
75
|
+
)
|
|
76
|
+
embedding_text = generated_description
|
|
77
|
+
case ChunkEmbeddingMode.EMBEDDABLE_TEXT_AND_GENERATED_DESCRIPTION:
|
|
78
|
+
generated_description = cast(DescriptionProvider, self._description_provider).describe(
|
|
79
|
+
text=chunk_display_text, context=result
|
|
80
|
+
)
|
|
81
|
+
embedding_text = generated_description + "\n" + chunk.embeddable_text
|
|
82
|
+
|
|
83
|
+
vec = self._embedding_provider.embed(embedding_text)
|
|
84
|
+
|
|
85
|
+
enriched_embeddings.append(
|
|
86
|
+
ChunkEmbedding(
|
|
87
|
+
chunk=chunk,
|
|
88
|
+
vec=vec,
|
|
89
|
+
display_text=chunk_display_text,
|
|
90
|
+
generated_description=generated_description,
|
|
91
|
+
)
|
|
92
|
+
)
|
|
93
|
+
|
|
94
|
+
table_name = self._shard_resolver.resolve_or_create(
|
|
95
|
+
embedder=self._embedding_provider.embedder,
|
|
96
|
+
model_id=self._embedding_provider.model_id,
|
|
97
|
+
dim=self._embedding_provider.dim,
|
|
98
|
+
)
|
|
99
|
+
|
|
100
|
+
self._persistence_service.write_chunks_and_embeddings(
|
|
101
|
+
datasource_run_id=datasource_run_id,
|
|
102
|
+
chunk_embeddings=enriched_embeddings,
|
|
103
|
+
table_name=table_name,
|
|
104
|
+
)
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
import duckdb
|
|
2
|
+
|
|
3
|
+
from databao_context_engine.services.table_name_policy import TableNamePolicy
|
|
4
|
+
from databao_context_engine.storage.repositories.embedding_model_registry_repository import (
|
|
5
|
+
EmbeddingModelRegistryRepository,
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class EmbeddingShardResolver:
    """Maps an (embedder, model_id) pair to its per-model embedding table
    ("shard"), creating the table and its HNSW index on first use."""

    def __init__(
        self,
        *,
        conn: duckdb.DuckDBPyConnection,
        registry_repo: EmbeddingModelRegistryRepository,
        table_name_policy: TableNamePolicy | None = None,
    ):
        self._conn = conn
        self._registry = registry_repo
        # Default policy when the caller does not supply one.
        self._policy = table_name_policy or TableNamePolicy()

    def resolve(self, *, embedder: str, model_id: str) -> tuple[str, int]:
        """Return (table_name, dim) for an already-registered model.

        Raises ValueError when the model was never registered.
        """
        row = self._registry.get(embedder=embedder, model_id=model_id)
        if not row:
            raise ValueError(f"Model not registered: {embedder}:{model_id}")
        return row.table_name, row.dim

    def resolve_or_create(self, *, embedder: str, model_id: str, dim: int) -> str:
        """Return the model's table name, registering the model and creating
        its table when missing.

        Raises ValueError when the model is already registered with a
        different embedding dimension.
        """
        row = self._registry.get(embedder=embedder, model_id=model_id)
        if row:
            # Dimension mismatch means the caller's provider disagrees with
            # what was persisted — refuse rather than corrupt the shard.
            if row.dim != dim:
                raise ValueError(f"Model already registered with dim={row.dim}, requested dim={dim}")
            return row.table_name

        table_name = self._policy.build(embedder=embedder, model_id=model_id, dim=dim)
        self._create_table_and_index(table_name, dim)

        self._registry.create(
            embedder=embedder,
            model_id=model_id,
            dim=dim,
            table_name=table_name,
        )

        return table_name

    def _create_table_and_index(self, table_name: str, dim: int) -> None:
        """Create the fixed-dimension embedding table and its HNSW cosine index.

        NOTE(review): table_name and dim are interpolated into DDL f-strings.
        They originate from TableNamePolicy / the registry, not user input —
        confirm the policy only emits safe SQL identifiers.
        """
        # The DuckDB VSS extension provides the HNSW index type; persisting
        # HNSW indexes to disk is gated behind the experimental flag below.
        self._conn.execute("LOAD vss;")
        self._conn.execute("SET hnsw_enable_experimental_persistence = true;")

        self._conn.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {table_name} (
                chunk_id BIGINT NOT NULL REFERENCES chunk(chunk_id),
                vec FLOAT[{dim}] NOT NULL,
                created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
                PRIMARY KEY (chunk_id)
            )
            """
        )
        self._conn.execute(
            f"""
            CREATE INDEX IF NOT EXISTS emb_hnsw_{table_name} ON {table_name} USING HNSW (vec) WITH (metric='cosine');
            """
        )
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from _duckdb import DuckDBPyConnection
|
|
2
|
+
|
|
3
|
+
from databao_context_engine.build_sources.internal.build_service import BuildService
|
|
4
|
+
from databao_context_engine.llm.descriptions.provider import DescriptionProvider
|
|
5
|
+
from databao_context_engine.llm.embeddings.provider import EmbeddingProvider
|
|
6
|
+
from databao_context_engine.retrieve_embeddings.internal.retrieve_service import RetrieveService
|
|
7
|
+
from databao_context_engine.services.chunk_embedding_service import ChunkEmbeddingMode, ChunkEmbeddingService
|
|
8
|
+
from databao_context_engine.services.embedding_shard_resolver import EmbeddingShardResolver
|
|
9
|
+
from databao_context_engine.services.persistence_service import PersistenceService
|
|
10
|
+
from databao_context_engine.services.table_name_policy import TableNamePolicy
|
|
11
|
+
from databao_context_engine.storage.repositories.factories import (
|
|
12
|
+
create_chunk_repository,
|
|
13
|
+
create_datasource_run_repository,
|
|
14
|
+
create_embedding_repository,
|
|
15
|
+
create_registry_repository,
|
|
16
|
+
create_run_repository,
|
|
17
|
+
create_vector_search_repository,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def create_shard_resolver(conn: DuckDBPyConnection, policy: TableNamePolicy | None = None) -> EmbeddingShardResolver:
    """Build an EmbeddingShardResolver on `conn`, defaulting the naming policy."""
    registry_repo = create_registry_repository(conn)
    table_name_policy = policy or TableNamePolicy()
    return EmbeddingShardResolver(conn=conn, registry_repo=registry_repo, table_name_policy=table_name_policy)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def create_persistence_service(conn: DuckDBPyConnection) -> PersistenceService:
    """Build a PersistenceService with chunk and embedding repositories on `conn`."""
    chunk_repo = create_chunk_repository(conn)
    embedding_repo = create_embedding_repository(conn)
    return PersistenceService(conn=conn, chunk_repo=chunk_repo, embedding_repo=embedding_repo)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def create_chunk_embedding_service(
    conn: DuckDBPyConnection,
    *,
    embedding_provider: EmbeddingProvider,
    description_provider: DescriptionProvider | None,
    chunk_embedding_mode: ChunkEmbeddingMode,
) -> ChunkEmbeddingService:
    """Wire up a ChunkEmbeddingService from its collaborators.

    The shard resolver and persistence service are built on the same
    connection; the providers and mode are passed through unchanged.
    """
    shard_resolver = create_shard_resolver(conn)
    persistence_service = create_persistence_service(conn)
    return ChunkEmbeddingService(
        shard_resolver=shard_resolver,
        persistence_service=persistence_service,
        embedding_provider=embedding_provider,
        description_provider=description_provider,
        chunk_embedding_mode=chunk_embedding_mode,
    )
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def create_build_service(
    conn: DuckDBPyConnection,
    *,
    embedding_provider: EmbeddingProvider,
    description_provider: DescriptionProvider | None,
    chunk_embedding_mode: ChunkEmbeddingMode,
) -> BuildService:
    """Construct a BuildService wired to repositories and the embedding pipeline on *conn*."""
    embedding_service = create_chunk_embedding_service(
        conn,
        embedding_provider=embedding_provider,
        description_provider=description_provider,
        chunk_embedding_mode=chunk_embedding_mode,
    )
    return BuildService(
        run_repo=create_run_repository(conn),
        datasource_run_repo=create_datasource_run_repository(conn),
        chunk_embedding_service=embedding_service,
    )
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def create_retrieve_service(
    conn: DuckDBPyConnection,
    *,
    embedding_provider: EmbeddingProvider,
) -> RetrieveService:
    """Construct a RetrieveService for vector search over *conn*.

    Uses the run repository, the vector-search repository, the shard
    resolver, and the given embedding provider for query embedding.
    """
    return RetrieveService(
        run_repo=create_run_repository(conn),
        vector_search_repo=create_vector_search_repository(conn),
        shard_resolver=create_shard_resolver(conn),
        provider=embedding_provider,
    )
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
|
|
4
|
+
from databao_context_engine.pluginlib.build_plugin import EmbeddableChunk
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass(frozen=True)
class ChunkEmbedding:
    """Immutable pairing of a source chunk with its computed embedding vector.

    Produced by the embedding pipeline and consumed by
    PersistenceService.write_chunks_and_embeddings.
    """

    # The original chunk the vector was computed from.
    chunk: EmbeddableChunk
    # Embedding vector for the chunk's embeddable text.
    vec: Sequence[float]
    # Text shown to users when this chunk is surfaced.
    display_text: str
    # LLM-generated description persisted alongside the chunk.
    generated_description: str
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
|
|
3
|
+
import duckdb
|
|
4
|
+
|
|
5
|
+
from databao_context_engine.services.models import ChunkEmbedding
|
|
6
|
+
from databao_context_engine.storage.models import ChunkDTO
|
|
7
|
+
from databao_context_engine.storage.repositories.chunk_repository import ChunkRepository
|
|
8
|
+
from databao_context_engine.storage.repositories.embedding_repository import EmbeddingRepository
|
|
9
|
+
from databao_context_engine.storage.transaction import transaction
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PersistenceService:
    """Persists chunks and their embedding vectors through the repositories.

    Writes are wrapped in a single transaction so a failure on any chunk
    leaves the store unchanged.
    """

    def __init__(
        self,
        conn: duckdb.DuckDBPyConnection,
        chunk_repo: ChunkRepository,
        embedding_repo: EmbeddingRepository,
        *,
        dim: int = 768,
    ):
        self._conn = conn
        self._chunk_repo = chunk_repo
        self._embedding_repo = embedding_repo
        # Expected embedding dimensionality; stored for reference — not
        # enforced anywhere in this class (TODO confirm intended use).
        self._dim = dim

    def write_chunks_and_embeddings(
        self, *, datasource_run_id: int, chunk_embeddings: list[ChunkEmbedding], table_name: str
    ) -> int:
        """
        Atomically persist chunks and their vectors.

        Returns the number of embeddings written.

        Raises:
            ValueError: if chunk_embeddings is empty.
        """
        if not chunk_embeddings:
            raise ValueError("chunk_embeddings must be a non-empty list")

        # One transaction for the whole batch: either every chunk+vector
        # pair lands, or none do.
        with transaction(self._conn):
            for chunk_embedding in chunk_embeddings:
                chunk_dto = self.create_chunk(
                    datasource_run_id=datasource_run_id,
                    embeddable_text=chunk_embedding.chunk.embeddable_text,
                    display_text=chunk_embedding.display_text,
                    generated_description=chunk_embedding.generated_description,
                )
                self.create_embedding(table_name=table_name, chunk_id=chunk_dto.chunk_id, vec=chunk_embedding.vec)
        # Fix: the docstring promised the count of embeddings written, but
        # the method previously returned None. Returning the count is
        # backward-compatible (no caller could rely on the old None).
        return len(chunk_embeddings)

    def create_chunk(
        self, *, datasource_run_id: int, embeddable_text: str, display_text: str, generated_description: str
    ) -> ChunkDTO:
        """Insert a single chunk row and return the created ChunkDTO."""
        return self._chunk_repo.create(
            datasource_run_id=datasource_run_id,
            embeddable_text=embeddable_text,
            display_text=display_text,
            generated_description=generated_description,
        )

    def create_embedding(self, *, table_name: str, chunk_id: int, vec: Sequence[float]) -> None:
        """Insert one embedding vector for *chunk_id* into *table_name*."""
        self._embedding_repo.create(
            table_name=table_name,
            chunk_id=chunk_id,
            vec=vec,
        )
|