databao-context-engine 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- databao_context_engine/__init__.py +18 -6
- databao_context_engine/build_sources/__init__.py +4 -0
- databao_context_engine/build_sources/{internal/build_runner.py → build_runner.py} +27 -23
- databao_context_engine/build_sources/build_service.py +53 -0
- databao_context_engine/build_sources/build_wiring.py +84 -0
- databao_context_engine/build_sources/export_results.py +41 -0
- databao_context_engine/build_sources/{internal/plugin_execution.py → plugin_execution.py} +3 -7
- databao_context_engine/cli/add_datasource_config.py +41 -15
- databao_context_engine/cli/commands.py +12 -43
- databao_context_engine/cli/info.py +2 -2
- databao_context_engine/databao_context_engine.py +137 -0
- databao_context_engine/databao_context_project_manager.py +96 -6
- databao_context_engine/datasources/add_config.py +34 -0
- databao_context_engine/{datasource_config → datasources}/check_config.py +18 -7
- databao_context_engine/datasources/datasource_context.py +93 -0
- databao_context_engine/{project → datasources}/datasource_discovery.py +18 -17
- databao_context_engine/{project → datasources}/types.py +64 -15
- databao_context_engine/init_project.py +25 -3
- databao_context_engine/introspection/property_extract.py +59 -30
- databao_context_engine/llm/errors.py +2 -8
- databao_context_engine/llm/install.py +13 -20
- databao_context_engine/llm/service.py +1 -3
- databao_context_engine/mcp/all_results_tool.py +2 -2
- databao_context_engine/mcp/mcp_runner.py +4 -2
- databao_context_engine/mcp/mcp_server.py +1 -4
- databao_context_engine/mcp/retrieve_tool.py +3 -11
- databao_context_engine/plugin_loader.py +111 -0
- databao_context_engine/pluginlib/build_plugin.py +25 -9
- databao_context_engine/pluginlib/config.py +16 -2
- databao_context_engine/plugins/databases/athena_introspector.py +85 -22
- databao_context_engine/plugins/databases/base_introspector.py +5 -3
- databao_context_engine/plugins/databases/clickhouse_introspector.py +22 -11
- databao_context_engine/plugins/databases/duckdb_introspector.py +1 -1
- databao_context_engine/plugins/databases/introspection_scope.py +11 -9
- databao_context_engine/plugins/databases/introspection_scope_matcher.py +2 -5
- databao_context_engine/plugins/databases/mssql_introspector.py +26 -17
- databao_context_engine/plugins/databases/mysql_introspector.py +23 -12
- databao_context_engine/plugins/databases/postgresql_introspector.py +2 -2
- databao_context_engine/plugins/databases/snowflake_introspector.py +43 -10
- databao_context_engine/plugins/plugin_loader.py +54 -45
- databao_context_engine/plugins/resources/parquet_introspector.py +2 -3
- databao_context_engine/project/info.py +31 -2
- databao_context_engine/project/init_project.py +16 -7
- databao_context_engine/project/layout.py +3 -3
- databao_context_engine/retrieve_embeddings/__init__.py +3 -0
- databao_context_engine/retrieve_embeddings/{internal/export_results.py → export_results.py} +2 -2
- databao_context_engine/retrieve_embeddings/{internal/retrieve_runner.py → retrieve_runner.py} +5 -9
- databao_context_engine/retrieve_embeddings/{internal/retrieve_service.py → retrieve_service.py} +3 -17
- databao_context_engine/retrieve_embeddings/retrieve_wiring.py +49 -0
- databao_context_engine/{serialisation → serialization}/yaml.py +1 -1
- databao_context_engine/services/chunk_embedding_service.py +23 -11
- databao_context_engine/services/factories.py +1 -46
- databao_context_engine/services/persistence_service.py +11 -11
- databao_context_engine/storage/connection.py +11 -7
- databao_context_engine/storage/exceptions/exceptions.py +2 -2
- databao_context_engine/storage/migrate.py +2 -4
- databao_context_engine/storage/migrations/V01__init.sql +6 -31
- databao_context_engine/storage/models.py +2 -23
- databao_context_engine/storage/repositories/chunk_repository.py +16 -12
- databao_context_engine/storage/repositories/factories.py +1 -12
- databao_context_engine/storage/repositories/vector_search_repository.py +8 -13
- databao_context_engine/system/properties.py +4 -2
- databao_context_engine-0.1.3.dist-info/METADATA +75 -0
- {databao_context_engine-0.1.1.dist-info → databao_context_engine-0.1.3.dist-info}/RECORD +68 -77
- databao_context_engine/build_sources/internal/build_service.py +0 -77
- databao_context_engine/build_sources/internal/build_wiring.py +0 -52
- databao_context_engine/build_sources/internal/export_results.py +0 -43
- databao_context_engine/build_sources/public/api.py +0 -4
- databao_context_engine/databao_engine.py +0 -85
- databao_context_engine/datasource_config/__init__.py +0 -0
- databao_context_engine/datasource_config/add_config.py +0 -50
- databao_context_engine/datasource_config/datasource_context.py +0 -60
- databao_context_engine/project/runs.py +0 -39
- databao_context_engine/retrieve_embeddings/internal/__init__.py +0 -0
- databao_context_engine/retrieve_embeddings/internal/retrieve_wiring.py +0 -29
- databao_context_engine/retrieve_embeddings/public/__init__.py +0 -0
- databao_context_engine/retrieve_embeddings/public/api.py +0 -3
- databao_context_engine/serialisation/__init__.py +0 -0
- databao_context_engine/services/run_name_policy.py +0 -8
- databao_context_engine/storage/repositories/datasource_run_repository.py +0 -136
- databao_context_engine/storage/repositories/run_repository.py +0 -157
- databao_context_engine-0.1.1.dist-info/METADATA +0 -186
- /databao_context_engine/{build_sources/internal → datasources}/__init__.py +0 -0
- /databao_context_engine/{build_sources/public → serialization}/__init__.py +0 -0
- {databao_context_engine-0.1.1.dist-info → databao_context_engine-0.1.3.dist-info}/WHEEL +0 -0
- {databao_context_engine-0.1.1.dist-info → databao_context_engine-0.1.3.dist-info}/entry_points.txt +0 -0
|
@@ -14,29 +14,11 @@ class DuplicatePluginTypeError(RuntimeError):
|
|
|
14
14
|
"""Raised when two plugins register the same <main>/<sub> plugin key."""
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
def get_all_available_plugin_types(exclude_file_plugins: bool = False) -> set[DatasourceType]:
|
|
21
|
-
return set(load_plugins(exclude_file_plugins=exclude_file_plugins).keys())
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
def get_plugin_for_type(datasource_type: DatasourceType) -> BuildPlugin:
|
|
25
|
-
all_plugins = load_plugins()
|
|
26
|
-
|
|
27
|
-
if datasource_type not in all_plugins:
|
|
28
|
-
raise ValueError(f"No plugin found for type '{datasource_type.full_type}'")
|
|
29
|
-
|
|
30
|
-
return load_plugins()[datasource_type]
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
def load_plugins(exclude_file_plugins: bool = False) -> PluginList:
|
|
34
|
-
"""
|
|
35
|
-
Loads both builtin and external plugins and merges them into one list
|
|
36
|
-
"""
|
|
17
|
+
def load_plugins(exclude_file_plugins: bool = False) -> dict[DatasourceType, BuildPlugin]:
|
|
18
|
+
"""Load both builtin and external plugins and merges them into one list."""
|
|
37
19
|
builtin_plugins = _load_builtin_plugins(exclude_file_plugins)
|
|
38
20
|
external_plugins = _load_external_plugins(exclude_file_plugins)
|
|
39
|
-
plugins =
|
|
21
|
+
plugins = _merge_plugins(builtin_plugins, external_plugins)
|
|
40
22
|
|
|
41
23
|
return plugins
|
|
42
24
|
|
|
@@ -61,43 +43,70 @@ def _load_builtin_file_plugins() -> list[BuildFilePlugin]:
|
|
|
61
43
|
|
|
62
44
|
|
|
63
45
|
def _load_builtin_datasource_plugins() -> list[BuildDatasourcePlugin]:
|
|
64
|
-
"""
|
|
65
|
-
Statically register built-in plugins
|
|
66
|
-
"""
|
|
67
|
-
from databao_context_engine.plugins.athena_db_plugin import AthenaDbPlugin
|
|
68
|
-
from databao_context_engine.plugins.clickhouse_db_plugin import ClickhouseDbPlugin
|
|
46
|
+
"""Statically register built-in plugins."""
|
|
69
47
|
from databao_context_engine.plugins.duckdb_db_plugin import DuckDbPlugin
|
|
70
|
-
from databao_context_engine.plugins.mssql_db_plugin import MSSQLDbPlugin
|
|
71
|
-
from databao_context_engine.plugins.mysql_db_plugin import MySQLDbPlugin
|
|
72
48
|
from databao_context_engine.plugins.parquet_plugin import ParquetPlugin
|
|
73
|
-
from databao_context_engine.plugins.postgresql_db_plugin import PostgresqlDbPlugin
|
|
74
|
-
from databao_context_engine.plugins.snowflake_db_plugin import SnowflakeDbPlugin
|
|
75
49
|
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
50
|
+
# optional plugins are added to the python environment via extras
|
|
51
|
+
optional_plugins: list[BuildDatasourcePlugin] = []
|
|
52
|
+
try:
|
|
53
|
+
from databao_context_engine.plugins.mssql_db_plugin import MSSQLDbPlugin
|
|
54
|
+
|
|
55
|
+
optional_plugins = [MSSQLDbPlugin()]
|
|
56
|
+
except ImportError:
|
|
57
|
+
pass
|
|
58
|
+
|
|
59
|
+
try:
|
|
60
|
+
from databao_context_engine.plugins.clickhouse_db_plugin import ClickhouseDbPlugin
|
|
61
|
+
|
|
62
|
+
optional_plugins.append(ClickhouseDbPlugin())
|
|
63
|
+
except ImportError:
|
|
64
|
+
pass
|
|
65
|
+
|
|
66
|
+
try:
|
|
67
|
+
from databao_context_engine.plugins.athena_db_plugin import AthenaDbPlugin
|
|
68
|
+
|
|
69
|
+
optional_plugins.append(AthenaDbPlugin())
|
|
70
|
+
except ImportError:
|
|
71
|
+
pass
|
|
72
|
+
|
|
73
|
+
try:
|
|
74
|
+
from databao_context_engine.plugins.snowflake_db_plugin import SnowflakeDbPlugin
|
|
75
|
+
|
|
76
|
+
optional_plugins.append(SnowflakeDbPlugin())
|
|
77
|
+
except ImportError:
|
|
78
|
+
pass
|
|
79
|
+
|
|
80
|
+
try:
|
|
81
|
+
from databao_context_engine.plugins.mysql_db_plugin import MySQLDbPlugin
|
|
82
|
+
|
|
83
|
+
optional_plugins.append(MySQLDbPlugin())
|
|
84
|
+
except ImportError:
|
|
85
|
+
pass
|
|
86
|
+
|
|
87
|
+
try:
|
|
88
|
+
from databao_context_engine.plugins.postgresql_db_plugin import PostgresqlDbPlugin
|
|
89
|
+
|
|
90
|
+
optional_plugins.append(PostgresqlDbPlugin())
|
|
91
|
+
except ImportError:
|
|
92
|
+
pass
|
|
93
|
+
|
|
94
|
+
required_plugins: list[BuildDatasourcePlugin] = [
|
|
79
95
|
DuckDbPlugin(),
|
|
80
|
-
MSSQLDbPlugin(),
|
|
81
|
-
MySQLDbPlugin(),
|
|
82
|
-
PostgresqlDbPlugin(),
|
|
83
|
-
SnowflakeDbPlugin(),
|
|
84
96
|
ParquetPlugin(),
|
|
85
97
|
]
|
|
98
|
+
return required_plugins + optional_plugins
|
|
86
99
|
|
|
87
100
|
|
|
88
101
|
def _load_external_plugins(exclude_file_plugins: bool = False) -> list[BuildPlugin]:
|
|
89
|
-
"""
|
|
90
|
-
Discover external plugins via entry points
|
|
91
|
-
"""
|
|
102
|
+
"""Discover external plugins via entry points."""
|
|
92
103
|
# TODO: implement external plugin loading
|
|
93
104
|
return []
|
|
94
105
|
|
|
95
106
|
|
|
96
|
-
def
|
|
97
|
-
"""
|
|
98
|
-
|
|
99
|
-
"""
|
|
100
|
-
registry: PluginList = {}
|
|
107
|
+
def _merge_plugins(*plugin_lists: list[BuildPlugin]) -> dict[DatasourceType, BuildPlugin]:
|
|
108
|
+
"""Merge multiple plugin maps."""
|
|
109
|
+
registry: dict[DatasourceType, BuildPlugin] = {}
|
|
101
110
|
for plugins in plugin_lists:
|
|
102
111
|
for plugin in plugins:
|
|
103
112
|
for full_type in plugin.supported_types():
|
|
@@ -6,7 +6,7 @@ from dataclasses import dataclass, replace
|
|
|
6
6
|
from urllib.parse import urlparse
|
|
7
7
|
|
|
8
8
|
import duckdb
|
|
9
|
-
from
|
|
9
|
+
from duckdb import DuckDBPyConnection
|
|
10
10
|
from pydantic import BaseModel, Field
|
|
11
11
|
|
|
12
12
|
from databao_context_engine.pluginlib.config import DuckDBSecret
|
|
@@ -18,9 +18,8 @@ logger = logging.getLogger(__name__)
|
|
|
18
18
|
|
|
19
19
|
class ParquetConfigFile(BaseModel):
|
|
20
20
|
name: str | None = Field(default=None)
|
|
21
|
-
type: str = Field(default=
|
|
21
|
+
type: str = Field(default="parquet")
|
|
22
22
|
url: str = Field(
|
|
23
|
-
default=type,
|
|
24
23
|
description="Parquet resource location. Should be a valid URL or a path to a local file. "
|
|
25
24
|
"Examples: s3://your_bucket/file.parquet, s3://your-bucket/*.parquet, https://some.url/some_file.parquet, ~/path_to/file.parquet",
|
|
26
25
|
)
|
|
@@ -9,13 +9,29 @@ from databao_context_engine.system.properties import get_dce_path
|
|
|
9
9
|
|
|
10
10
|
@dataclass(kw_only=True, frozen=True)
|
|
11
11
|
class DceProjectInfo:
|
|
12
|
+
"""Information about a Databao Context Engine project.
|
|
13
|
+
|
|
14
|
+
Attributes:
|
|
15
|
+
project_path: The root directory of the Databao Context Engine project.
|
|
16
|
+
is_initialized: Whether the project has been initialized.
|
|
17
|
+
project_id: The UUID of the project, or None if the project has not been initialized.
|
|
18
|
+
"""
|
|
19
|
+
|
|
12
20
|
project_path: Path
|
|
13
|
-
|
|
21
|
+
is_initialized: bool
|
|
14
22
|
project_id: UUID | None
|
|
15
23
|
|
|
16
24
|
|
|
17
25
|
@dataclass(kw_only=True, frozen=True)
|
|
18
26
|
class DceInfo:
|
|
27
|
+
"""Information about the current Databao Context Engine installation and project.
|
|
28
|
+
|
|
29
|
+
Attributes:
|
|
30
|
+
version: The version of the databao_context_engine package installed on the system.
|
|
31
|
+
dce_path: The path where databao_context_engine stores its global data.
|
|
32
|
+
project_info: Information about the Databao Context Engine project.
|
|
33
|
+
"""
|
|
34
|
+
|
|
19
35
|
version: str
|
|
20
36
|
dce_path: Path
|
|
21
37
|
|
|
@@ -23,6 +39,14 @@ class DceInfo:
|
|
|
23
39
|
|
|
24
40
|
|
|
25
41
|
def get_databao_context_engine_info(project_dir: Path) -> DceInfo:
|
|
42
|
+
"""Return information about the current Databao Context Engine installation and project.
|
|
43
|
+
|
|
44
|
+
Args:
|
|
45
|
+
project_dir: The root directory of the Databao Context Project.
|
|
46
|
+
|
|
47
|
+
Returns:
|
|
48
|
+
A DceInfo instance containing information about the Databao Context Engine installation and project.
|
|
49
|
+
"""
|
|
26
50
|
return DceInfo(
|
|
27
51
|
version=get_dce_version(),
|
|
28
52
|
dce_path=get_dce_path(),
|
|
@@ -35,10 +59,15 @@ def _get_project_info(project_dir: Path) -> DceProjectInfo:
|
|
|
35
59
|
|
|
36
60
|
return DceProjectInfo(
|
|
37
61
|
project_path=project_dir,
|
|
38
|
-
|
|
62
|
+
is_initialized=project_layout is not None,
|
|
39
63
|
project_id=project_layout.read_config_file().project_id if project_layout is not None else None,
|
|
40
64
|
)
|
|
41
65
|
|
|
42
66
|
|
|
43
67
|
def get_dce_version() -> str:
|
|
68
|
+
"""Return the installed version of the databao_context_engine package.
|
|
69
|
+
|
|
70
|
+
Returns:
|
|
71
|
+
The installed version of the databao_context_engine package.
|
|
72
|
+
"""
|
|
44
73
|
return version("databao_context_engine")
|
|
@@ -13,12 +13,21 @@ from databao_context_engine.project.project_config import ProjectConfig
|
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
class InitErrorReason(Enum):
|
|
16
|
+
"""Reasons for which project initialization can fail."""
|
|
17
|
+
|
|
16
18
|
PROJECT_DIR_DOESNT_EXIST = "PROJECT_DIR_DOESNT_EXIST"
|
|
17
19
|
PROJECT_DIR_NOT_DIRECTORY = "PROJECT_DIR_NOT_DIRECTORY"
|
|
18
|
-
|
|
20
|
+
PROJECT_DIR_ALREADY_INITIALIZED = "PROJECT_DIR_ALREADY_INITIALIZED"
|
|
19
21
|
|
|
20
22
|
|
|
21
23
|
class InitProjectError(Exception):
|
|
24
|
+
"""Raised when a project can't be initialized.
|
|
25
|
+
|
|
26
|
+
Attributes:
|
|
27
|
+
message: The error message.
|
|
28
|
+
reason: The reason for the initialization failure.
|
|
29
|
+
"""
|
|
30
|
+
|
|
22
31
|
reason: InitErrorReason
|
|
23
32
|
|
|
24
33
|
def __init__(self, reason: InitErrorReason, message: str | None):
|
|
@@ -65,20 +74,20 @@ class _ProjectCreator:
|
|
|
65
74
|
|
|
66
75
|
if self.config_file.is_file() or self.deprecated_config_file.is_file():
|
|
67
76
|
raise InitProjectError(
|
|
68
|
-
message=f"Can't
|
|
69
|
-
reason=InitErrorReason.
|
|
77
|
+
message=f"Can't initialize a Databao Context Engine project in a folder that already contains a config file. [project_dir: {self.project_dir.resolve()}]",
|
|
78
|
+
reason=InitErrorReason.PROJECT_DIR_ALREADY_INITIALIZED,
|
|
70
79
|
)
|
|
71
80
|
|
|
72
81
|
if self.src_dir.is_dir():
|
|
73
82
|
raise InitProjectError(
|
|
74
|
-
message=f"Can't
|
|
75
|
-
reason=InitErrorReason.
|
|
83
|
+
message=f"Can't initialize a Databao Context Engine project in a folder that already contains a src directory. [project_dir: {self.project_dir.resolve()}]",
|
|
84
|
+
reason=InitErrorReason.PROJECT_DIR_ALREADY_INITIALIZED,
|
|
76
85
|
)
|
|
77
86
|
|
|
78
87
|
if self.examples_dir.is_file():
|
|
79
88
|
raise InitProjectError(
|
|
80
|
-
message=f"Can't
|
|
81
|
-
reason=InitErrorReason.
|
|
89
|
+
message=f"Can't initialize a Databao Context Engine project in a folder that already contains an examples dir. [project_dir: {self.project_dir.resolve()}]",
|
|
90
|
+
reason=InitErrorReason.PROJECT_DIR_ALREADY_INITIALIZED,
|
|
82
91
|
)
|
|
83
92
|
|
|
84
93
|
return True
|
|
@@ -2,8 +2,8 @@ import logging
|
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
|
|
5
|
+
from databao_context_engine.datasources.types import DatasourceId
|
|
5
6
|
from databao_context_engine.project.project_config import ProjectConfig
|
|
6
|
-
from databao_context_engine.project.types import DatasourceId
|
|
7
7
|
|
|
8
8
|
SOURCE_FOLDER_NAME = "src"
|
|
9
9
|
OUTPUT_FOLDER_NAME = "output"
|
|
@@ -96,12 +96,12 @@ class _ProjectValidator:
|
|
|
96
96
|
|
|
97
97
|
if self.config_file is None:
|
|
98
98
|
raise ValueError(
|
|
99
|
-
f"The current project directory has not been
|
|
99
|
+
f"The current project directory has not been initialized. It should contain a config file. [project_dir: {self.project_dir.resolve()}]"
|
|
100
100
|
)
|
|
101
101
|
|
|
102
102
|
if not self.is_src_valid():
|
|
103
103
|
raise ValueError(
|
|
104
|
-
f"The current project directory has not been
|
|
104
|
+
f"The current project directory has not been initialized. It should contain a src directory. [project_dir: {self.project_dir.resolve()}]"
|
|
105
105
|
)
|
|
106
106
|
|
|
107
107
|
return ProjectLayout(project_dir=self.project_dir, config_file=self.config_file)
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
from pathlib import Path
|
|
2
2
|
|
|
3
3
|
|
|
4
|
-
def export_retrieve_results(
|
|
5
|
-
path =
|
|
4
|
+
def export_retrieve_results(output_dir: Path, retrieve_results: list[str]) -> Path:
|
|
5
|
+
path = output_dir.joinpath("context_duckdb.yaml")
|
|
6
6
|
|
|
7
7
|
with path.open("w") as export_file:
|
|
8
8
|
for result in retrieve_results:
|
databao_context_engine/retrieve_embeddings/{internal/retrieve_runner.py → retrieve_runner.py}
RENAMED
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
from pathlib import Path
|
|
3
3
|
|
|
4
|
-
from databao_context_engine.project.
|
|
5
|
-
from databao_context_engine.retrieve_embeddings.
|
|
6
|
-
from databao_context_engine.retrieve_embeddings.
|
|
4
|
+
from databao_context_engine.project.layout import get_output_dir
|
|
5
|
+
from databao_context_engine.retrieve_embeddings.export_results import export_retrieve_results
|
|
6
|
+
from databao_context_engine.retrieve_embeddings.retrieve_service import RetrieveService
|
|
7
7
|
from databao_context_engine.storage.repositories.vector_search_repository import VectorSearchResult
|
|
8
8
|
|
|
9
9
|
logger = logging.getLogger(__name__)
|
|
@@ -15,17 +15,13 @@ def retrieve(
|
|
|
15
15
|
retrieve_service: RetrieveService,
|
|
16
16
|
project_id: str,
|
|
17
17
|
text: str,
|
|
18
|
-
run_name: str | None,
|
|
19
18
|
limit: int | None,
|
|
20
19
|
export_to_file: bool,
|
|
21
20
|
) -> list[VectorSearchResult]:
|
|
22
|
-
|
|
23
|
-
retrieve_results = retrieve_service.retrieve(
|
|
24
|
-
project_id=project_id, text=text, run_name=resolved_run_name, limit=limit
|
|
25
|
-
)
|
|
21
|
+
retrieve_results = retrieve_service.retrieve(project_id=project_id, text=text, limit=limit)
|
|
26
22
|
|
|
27
23
|
if export_to_file:
|
|
28
|
-
export_directory =
|
|
24
|
+
export_directory = get_output_dir(project_dir)
|
|
29
25
|
|
|
30
26
|
display_texts = [result.display_text for result in retrieve_results]
|
|
31
27
|
export_file = export_retrieve_results(export_directory, display_texts)
|
databao_context_engine/retrieve_embeddings/{internal/retrieve_service.py → retrieve_service.py}
RENAMED
|
@@ -2,9 +2,7 @@ import logging
|
|
|
2
2
|
from collections.abc import Sequence
|
|
3
3
|
|
|
4
4
|
from databao_context_engine.llm.embeddings.provider import EmbeddingProvider
|
|
5
|
-
from databao_context_engine.project.runs import resolve_run_name_from_repo
|
|
6
5
|
from databao_context_engine.services.embedding_shard_resolver import EmbeddingShardResolver
|
|
7
|
-
from databao_context_engine.storage.repositories.run_repository import RunRepository
|
|
8
6
|
from databao_context_engine.storage.repositories.vector_search_repository import (
|
|
9
7
|
VectorSearchRepository,
|
|
10
8
|
VectorSearchResult,
|
|
@@ -17,43 +15,34 @@ class RetrieveService:
|
|
|
17
15
|
def __init__(
|
|
18
16
|
self,
|
|
19
17
|
*,
|
|
20
|
-
run_repo: RunRepository,
|
|
21
18
|
vector_search_repo: VectorSearchRepository,
|
|
22
19
|
shard_resolver: EmbeddingShardResolver,
|
|
23
20
|
provider: EmbeddingProvider,
|
|
24
21
|
):
|
|
25
|
-
self._run_repo = run_repo
|
|
26
22
|
self._shard_resolver = shard_resolver
|
|
27
23
|
self._provider = provider
|
|
28
24
|
self._vector_search_repo = vector_search_repo
|
|
29
25
|
|
|
30
|
-
def retrieve(
|
|
31
|
-
self, *, project_id: str, text: str, run_name: str, limit: int | None = None
|
|
32
|
-
) -> list[VectorSearchResult]:
|
|
26
|
+
def retrieve(self, *, project_id: str, text: str, limit: int | None = None) -> list[VectorSearchResult]:
|
|
33
27
|
if limit is None:
|
|
34
28
|
limit = 10
|
|
35
29
|
|
|
36
|
-
run = self._run_repo.get_by_run_name(project_id=project_id, run_name=run_name)
|
|
37
|
-
if run is None:
|
|
38
|
-
raise LookupError(f"Run '{run_name}' not found for project '{project_id}'.")
|
|
39
|
-
|
|
40
30
|
table_name, dimension = self._shard_resolver.resolve(
|
|
41
31
|
embedder=self._provider.embedder, model_id=self._provider.model_id
|
|
42
32
|
)
|
|
43
33
|
|
|
44
34
|
retrieve_vec: Sequence[float] = self._provider.embed(text)
|
|
45
35
|
|
|
46
|
-
logger.debug(f"Retrieving display texts
|
|
36
|
+
logger.debug(f"Retrieving display texts in table {table_name}")
|
|
47
37
|
|
|
48
38
|
search_results = self._vector_search_repo.get_display_texts_by_similarity(
|
|
49
39
|
table_name=table_name,
|
|
50
|
-
run_id=run.run_id,
|
|
51
40
|
retrieve_vec=retrieve_vec,
|
|
52
41
|
dimension=dimension,
|
|
53
42
|
limit=limit,
|
|
54
43
|
)
|
|
55
44
|
|
|
56
|
-
logger.debug(f"Retrieved {len(search_results)} display texts
|
|
45
|
+
logger.debug(f"Retrieved {len(search_results)} display texts in table {table_name}")
|
|
57
46
|
|
|
58
47
|
if logger.isEnabledFor(logging.DEBUG):
|
|
59
48
|
closest_result = min(search_results, key=lambda result: result.cosine_distance)
|
|
@@ -63,6 +52,3 @@ class RetrieveService:
|
|
|
63
52
|
logger.debug(f"Worst result: ({farthest_result.cosine_distance}, {farthest_result.embeddable_text})")
|
|
64
53
|
|
|
65
54
|
return search_results
|
|
66
|
-
|
|
67
|
-
def resolve_run_name(self, *, project_id: str, run_name: str | None) -> str:
|
|
68
|
-
return resolve_run_name_from_repo(run_repository=self._run_repo, project_id=project_id, run_name=run_name)
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from duckdb import DuckDBPyConnection
|
|
2
|
+
|
|
3
|
+
from databao_context_engine.llm.embeddings.provider import EmbeddingProvider
|
|
4
|
+
from databao_context_engine.llm.factory import create_ollama_embedding_provider, create_ollama_service
|
|
5
|
+
from databao_context_engine.project.layout import ProjectLayout, ensure_project_dir
|
|
6
|
+
from databao_context_engine.retrieve_embeddings.retrieve_runner import retrieve
|
|
7
|
+
from databao_context_engine.retrieve_embeddings.retrieve_service import RetrieveService
|
|
8
|
+
from databao_context_engine.services.factories import create_shard_resolver
|
|
9
|
+
from databao_context_engine.storage.connection import open_duckdb_connection
|
|
10
|
+
from databao_context_engine.storage.repositories.factories import create_vector_search_repository
|
|
11
|
+
from databao_context_engine.storage.repositories.vector_search_repository import VectorSearchResult
|
|
12
|
+
from databao_context_engine.system.properties import get_db_path
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def retrieve_embeddings(
|
|
16
|
+
project_layout: ProjectLayout,
|
|
17
|
+
retrieve_text: str,
|
|
18
|
+
limit: int | None,
|
|
19
|
+
export_to_file: bool,
|
|
20
|
+
) -> list[VectorSearchResult]:
|
|
21
|
+
ensure_project_dir(project_layout.project_dir)
|
|
22
|
+
|
|
23
|
+
with open_duckdb_connection(get_db_path(project_layout.project_dir)) as conn:
|
|
24
|
+
ollama_service = create_ollama_service()
|
|
25
|
+
embedding_provider = create_ollama_embedding_provider(ollama_service)
|
|
26
|
+
retrieve_service = _create_retrieve_service(conn, embedding_provider=embedding_provider)
|
|
27
|
+
return retrieve(
|
|
28
|
+
project_dir=project_layout.project_dir,
|
|
29
|
+
retrieve_service=retrieve_service,
|
|
30
|
+
project_id=str(project_layout.read_config_file().project_id),
|
|
31
|
+
text=retrieve_text,
|
|
32
|
+
limit=limit,
|
|
33
|
+
export_to_file=export_to_file,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _create_retrieve_service(
|
|
38
|
+
conn: DuckDBPyConnection,
|
|
39
|
+
*,
|
|
40
|
+
embedding_provider: EmbeddingProvider,
|
|
41
|
+
) -> RetrieveService:
|
|
42
|
+
vector_search_repo = create_vector_search_repository(conn)
|
|
43
|
+
shard_resolver = create_shard_resolver(conn)
|
|
44
|
+
|
|
45
|
+
return RetrieveService(
|
|
46
|
+
vector_search_repo=vector_search_repo,
|
|
47
|
+
shard_resolver=shard_resolver,
|
|
48
|
+
provider=embedding_provider,
|
|
49
|
+
)
|
|
@@ -8,7 +8,7 @@ def default_representer(dumper: SafeDumper, data: object) -> Node:
|
|
|
8
8
|
if isinstance(data, Mapping):
|
|
9
9
|
return dumper.represent_dict(data)
|
|
10
10
|
elif hasattr(data, "__dict__"):
|
|
11
|
-
# Doesn't
|
|
11
|
+
# Doesn't serialize "private" attributes (those whose names start with an underscore)
|
|
12
12
|
data_public_attributes = {key: value for key, value in data.__dict__.items() if not key.startswith("_")}
|
|
13
13
|
if data_public_attributes:
|
|
14
14
|
return dumper.represent_dict(data_public_attributes)
|
|
@@ -5,7 +5,7 @@ from typing import cast
|
|
|
5
5
|
from databao_context_engine.llm.descriptions.provider import DescriptionProvider
|
|
6
6
|
from databao_context_engine.llm.embeddings.provider import EmbeddingProvider
|
|
7
7
|
from databao_context_engine.pluginlib.build_plugin import EmbeddableChunk
|
|
8
|
-
from databao_context_engine.
|
|
8
|
+
from databao_context_engine.serialization.yaml import to_yaml_string
|
|
9
9
|
from databao_context_engine.services.embedding_shard_resolver import EmbeddingShardResolver
|
|
10
10
|
from databao_context_engine.services.models import ChunkEmbedding
|
|
11
11
|
from databao_context_engine.services.persistence_service import PersistenceService
|
|
@@ -14,9 +14,22 @@ logger = logging.getLogger(__name__)
|
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
class ChunkEmbeddingMode(Enum):
|
|
17
|
+
"""Mode controlling how chunks are embedded."""
|
|
18
|
+
|
|
17
19
|
EMBEDDABLE_TEXT_ONLY = "EMBEDDABLE_TEXT_ONLY"
|
|
20
|
+
"""
|
|
21
|
+
The embedding is generated only from the string defined by the plugin as embeddable for a chunk.
|
|
22
|
+
"""
|
|
23
|
+
|
|
18
24
|
GENERATED_DESCRIPTION_ONLY = "GENERATED_DESCRIPTION_ONLY"
|
|
25
|
+
"""
|
|
26
|
+
The embedding is generated only from a description of the chunk generated by an LLM.
|
|
27
|
+
"""
|
|
28
|
+
|
|
19
29
|
EMBEDDABLE_TEXT_AND_GENERATED_DESCRIPTION = "EMBEDDABLE_TEXT_AND_GENERATED_DESCRIPTION"
|
|
30
|
+
"""
|
|
31
|
+
The embedding is generated from both the embeddable string of the chunk and the description of the chunk generated by an LLM.
|
|
32
|
+
"""
|
|
20
33
|
|
|
21
34
|
def should_generate_description(self) -> bool:
|
|
22
35
|
return self in (
|
|
@@ -44,26 +57,24 @@ class ChunkEmbeddingService:
|
|
|
44
57
|
if self._chunk_embedding_mode.should_generate_description() and description_provider is None:
|
|
45
58
|
raise ValueError("A DescriptionProvider must be provided when generating descriptions")
|
|
46
59
|
|
|
47
|
-
def embed_chunks(self, *,
|
|
48
|
-
"""
|
|
49
|
-
Turn plugin chunks into persisted chunks and embeddings
|
|
60
|
+
def embed_chunks(self, *, chunks: list[EmbeddableChunk], result: str, full_type: str, datasource_id: str) -> None:
|
|
61
|
+
"""Turn plugin chunks into persisted chunks and embeddings.
|
|
50
62
|
|
|
51
63
|
Flow:
|
|
52
|
-
1) Embed each chunk into an embedded vector
|
|
53
|
-
2) Get or create embedding table for the appropriate model and embedding dimensions
|
|
54
|
-
3) Persist chunks and embeddings vectors in a single transaction
|
|
64
|
+
1) Embed each chunk into an embedded vector.
|
|
65
|
+
2) Get or create embedding table for the appropriate model and embedding dimensions.
|
|
66
|
+
3) Persist chunks and embedding vectors in a single transaction.
|
|
55
67
|
"""
|
|
56
|
-
|
|
57
68
|
if not chunks:
|
|
58
69
|
return
|
|
59
70
|
|
|
60
71
|
logger.debug(
|
|
61
|
-
f"Embedding {len(chunks)} chunks for datasource
|
|
72
|
+
f"Embedding {len(chunks)} chunks for datasource {datasource_id}, with chunk_embedding_mode={self._chunk_embedding_mode}"
|
|
62
73
|
)
|
|
63
74
|
|
|
64
75
|
enriched_embeddings: list[ChunkEmbedding] = []
|
|
65
76
|
for chunk in chunks:
|
|
66
|
-
chunk_display_text = to_yaml_string(chunk.content)
|
|
77
|
+
chunk_display_text = chunk.content if isinstance(chunk.content, str) else to_yaml_string(chunk.content)
|
|
67
78
|
|
|
68
79
|
generated_description = ""
|
|
69
80
|
match self._chunk_embedding_mode:
|
|
@@ -98,7 +109,8 @@ class ChunkEmbeddingService:
|
|
|
98
109
|
)
|
|
99
110
|
|
|
100
111
|
self._persistence_service.write_chunks_and_embeddings(
|
|
101
|
-
datasource_run_id=datasource_run_id,
|
|
102
112
|
chunk_embeddings=enriched_embeddings,
|
|
103
113
|
table_name=table_name,
|
|
114
|
+
full_type=full_type,
|
|
115
|
+
datasource_id=datasource_id,
|
|
104
116
|
)
|
|
@@ -1,20 +1,15 @@
|
|
|
1
|
-
from
|
|
1
|
+
from duckdb import DuckDBPyConnection
|
|
2
2
|
|
|
3
|
-
from databao_context_engine.build_sources.internal.build_service import BuildService
|
|
4
3
|
from databao_context_engine.llm.descriptions.provider import DescriptionProvider
|
|
5
4
|
from databao_context_engine.llm.embeddings.provider import EmbeddingProvider
|
|
6
|
-
from databao_context_engine.retrieve_embeddings.internal.retrieve_service import RetrieveService
|
|
7
5
|
from databao_context_engine.services.chunk_embedding_service import ChunkEmbeddingMode, ChunkEmbeddingService
|
|
8
6
|
from databao_context_engine.services.embedding_shard_resolver import EmbeddingShardResolver
|
|
9
7
|
from databao_context_engine.services.persistence_service import PersistenceService
|
|
10
8
|
from databao_context_engine.services.table_name_policy import TableNamePolicy
|
|
11
9
|
from databao_context_engine.storage.repositories.factories import (
|
|
12
10
|
create_chunk_repository,
|
|
13
|
-
create_datasource_run_repository,
|
|
14
11
|
create_embedding_repository,
|
|
15
12
|
create_registry_repository,
|
|
16
|
-
create_run_repository,
|
|
17
|
-
create_vector_search_repository,
|
|
18
13
|
)
|
|
19
14
|
|
|
20
15
|
|
|
@@ -46,43 +41,3 @@ def create_chunk_embedding_service(
|
|
|
46
41
|
description_provider=description_provider,
|
|
47
42
|
chunk_embedding_mode=chunk_embedding_mode,
|
|
48
43
|
)
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
def create_build_service(
|
|
52
|
-
conn: DuckDBPyConnection,
|
|
53
|
-
*,
|
|
54
|
-
embedding_provider: EmbeddingProvider,
|
|
55
|
-
description_provider: DescriptionProvider | None,
|
|
56
|
-
chunk_embedding_mode: ChunkEmbeddingMode,
|
|
57
|
-
) -> BuildService:
|
|
58
|
-
run_repo = create_run_repository(conn)
|
|
59
|
-
datasource_run_repo = create_datasource_run_repository(conn)
|
|
60
|
-
chunk_embedding_service = create_chunk_embedding_service(
|
|
61
|
-
conn,
|
|
62
|
-
embedding_provider=embedding_provider,
|
|
63
|
-
description_provider=description_provider,
|
|
64
|
-
chunk_embedding_mode=chunk_embedding_mode,
|
|
65
|
-
)
|
|
66
|
-
|
|
67
|
-
return BuildService(
|
|
68
|
-
run_repo=run_repo,
|
|
69
|
-
datasource_run_repo=datasource_run_repo,
|
|
70
|
-
chunk_embedding_service=chunk_embedding_service,
|
|
71
|
-
)
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
def create_retrieve_service(
|
|
75
|
-
conn: DuckDBPyConnection,
|
|
76
|
-
*,
|
|
77
|
-
embedding_provider: EmbeddingProvider,
|
|
78
|
-
) -> RetrieveService:
|
|
79
|
-
run_repo = create_run_repository(conn)
|
|
80
|
-
vector_search_repo = create_vector_search_repository(conn)
|
|
81
|
-
shard_resolver = create_shard_resolver(conn)
|
|
82
|
-
|
|
83
|
-
return RetrieveService(
|
|
84
|
-
run_repo=run_repo,
|
|
85
|
-
vector_search_repo=vector_search_repo,
|
|
86
|
-
shard_resolver=shard_resolver,
|
|
87
|
-
provider=embedding_provider,
|
|
88
|
-
)
|
|
@@ -24,11 +24,13 @@ class PersistenceService:
|
|
|
24
24
|
self._dim = dim
|
|
25
25
|
|
|
26
26
|
def write_chunks_and_embeddings(
|
|
27
|
-
self, *,
|
|
27
|
+
self, *, chunk_embeddings: list[ChunkEmbedding], table_name: str, full_type: str, datasource_id: str
|
|
28
28
|
):
|
|
29
|
-
"""
|
|
30
|
-
|
|
31
|
-
|
|
29
|
+
"""Atomically persist chunks and their vectors.
|
|
30
|
+
|
|
31
|
+
Raises:
|
|
32
|
+
ValueError: If chunk_embeddings is an empty list.
|
|
33
|
+
|
|
32
34
|
"""
|
|
33
35
|
if not chunk_embeddings:
|
|
34
36
|
raise ValueError("chunk_embeddings must be a non-empty list")
|
|
@@ -36,21 +38,19 @@ class PersistenceService:
|
|
|
36
38
|
with transaction(self._conn):
|
|
37
39
|
for chunk_embedding in chunk_embeddings:
|
|
38
40
|
chunk_dto = self.create_chunk(
|
|
39
|
-
|
|
41
|
+
full_type=full_type,
|
|
42
|
+
datasource_id=datasource_id,
|
|
40
43
|
embeddable_text=chunk_embedding.chunk.embeddable_text,
|
|
41
44
|
display_text=chunk_embedding.display_text,
|
|
42
|
-
generated_description=chunk_embedding.generated_description,
|
|
43
45
|
)
|
|
44
46
|
self.create_embedding(table_name=table_name, chunk_id=chunk_dto.chunk_id, vec=chunk_embedding.vec)
|
|
45
47
|
|
|
46
|
-
def create_chunk(
|
|
47
|
-
self, *, datasource_run_id: int, embeddable_text: str, display_text: str, generated_description: str
|
|
48
|
-
) -> ChunkDTO:
|
|
48
|
+
def create_chunk(self, *, full_type: str, datasource_id: str, embeddable_text: str, display_text: str) -> ChunkDTO:
|
|
49
49
|
return self._chunk_repo.create(
|
|
50
|
-
|
|
50
|
+
full_type=full_type,
|
|
51
|
+
datasource_id=datasource_id,
|
|
51
52
|
embeddable_text=embeddable_text,
|
|
52
53
|
display_text=display_text,
|
|
53
|
-
generated_description=generated_description,
|
|
54
54
|
)
|
|
55
55
|
|
|
56
56
|
def create_embedding(self, *, table_name: str, chunk_id: int, vec: Sequence[float]):
|