databao-context-engine 0.1.2__py3-none-any.whl → 0.1.4.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- databao_context_engine/__init__.py +18 -6
- databao_context_engine/build_sources/__init__.py +4 -0
- databao_context_engine/build_sources/{internal/build_runner.py → build_runner.py} +27 -23
- databao_context_engine/build_sources/build_service.py +53 -0
- databao_context_engine/build_sources/build_wiring.py +84 -0
- databao_context_engine/build_sources/export_results.py +41 -0
- databao_context_engine/build_sources/{internal/plugin_execution.py → plugin_execution.py} +3 -7
- databao_context_engine/cli/add_datasource_config.py +41 -15
- databao_context_engine/cli/commands.py +12 -43
- databao_context_engine/cli/info.py +3 -2
- databao_context_engine/databao_context_engine.py +137 -0
- databao_context_engine/databao_context_project_manager.py +96 -6
- databao_context_engine/datasources/add_config.py +34 -0
- databao_context_engine/{datasource_config → datasources}/check_config.py +18 -7
- databao_context_engine/datasources/datasource_context.py +93 -0
- databao_context_engine/{project → datasources}/datasource_discovery.py +17 -16
- databao_context_engine/{project → datasources}/types.py +64 -15
- databao_context_engine/init_project.py +25 -3
- databao_context_engine/introspection/property_extract.py +67 -53
- databao_context_engine/llm/errors.py +2 -8
- databao_context_engine/llm/install.py +13 -20
- databao_context_engine/llm/service.py +1 -3
- databao_context_engine/mcp/mcp_runner.py +4 -2
- databao_context_engine/mcp/mcp_server.py +10 -10
- databao_context_engine/plugin_loader.py +111 -0
- databao_context_engine/pluginlib/build_plugin.py +25 -9
- databao_context_engine/pluginlib/config.py +16 -2
- databao_context_engine/plugins/base_db_plugin.py +5 -2
- databao_context_engine/plugins/databases/athena_introspector.py +85 -22
- databao_context_engine/plugins/databases/base_introspector.py +5 -3
- databao_context_engine/plugins/databases/clickhouse_introspector.py +22 -11
- databao_context_engine/plugins/databases/duckdb_introspector.py +3 -5
- databao_context_engine/plugins/databases/introspection_model_builder.py +1 -1
- databao_context_engine/plugins/databases/introspection_scope.py +11 -9
- databao_context_engine/plugins/databases/introspection_scope_matcher.py +2 -5
- databao_context_engine/plugins/databases/mssql_introspector.py +26 -17
- databao_context_engine/plugins/databases/mysql_introspector.py +23 -12
- databao_context_engine/plugins/databases/postgresql_introspector.py +2 -2
- databao_context_engine/plugins/databases/snowflake_introspector.py +43 -10
- databao_context_engine/plugins/duckdb_tools.py +18 -0
- databao_context_engine/plugins/plugin_loader.py +43 -42
- databao_context_engine/plugins/resources/parquet_introspector.py +7 -19
- databao_context_engine/project/info.py +34 -2
- databao_context_engine/project/init_project.py +16 -7
- databao_context_engine/project/layout.py +3 -3
- databao_context_engine/retrieve_embeddings/__init__.py +3 -0
- databao_context_engine/retrieve_embeddings/{internal/export_results.py → export_results.py} +2 -2
- databao_context_engine/retrieve_embeddings/{internal/retrieve_runner.py → retrieve_runner.py} +5 -9
- databao_context_engine/retrieve_embeddings/{internal/retrieve_service.py → retrieve_service.py} +3 -17
- databao_context_engine/retrieve_embeddings/retrieve_wiring.py +49 -0
- databao_context_engine/{serialisation → serialization}/yaml.py +1 -1
- databao_context_engine/services/chunk_embedding_service.py +23 -11
- databao_context_engine/services/factories.py +1 -46
- databao_context_engine/services/persistence_service.py +11 -11
- databao_context_engine/storage/connection.py +11 -7
- databao_context_engine/storage/exceptions/exceptions.py +2 -2
- databao_context_engine/storage/migrate.py +2 -4
- databao_context_engine/storage/migrations/V01__init.sql +6 -31
- databao_context_engine/storage/models.py +2 -23
- databao_context_engine/storage/repositories/chunk_repository.py +16 -12
- databao_context_engine/storage/repositories/factories.py +1 -12
- databao_context_engine/storage/repositories/vector_search_repository.py +8 -13
- databao_context_engine/system/properties.py +4 -2
- databao_context_engine-0.1.4.dev1.dist-info/METADATA +75 -0
- databao_context_engine-0.1.4.dev1.dist-info/RECORD +125 -0
- {databao_context_engine-0.1.2.dist-info → databao_context_engine-0.1.4.dev1.dist-info}/WHEEL +1 -1
- databao_context_engine/build_sources/internal/build_service.py +0 -77
- databao_context_engine/build_sources/internal/build_wiring.py +0 -52
- databao_context_engine/build_sources/internal/export_results.py +0 -43
- databao_context_engine/build_sources/public/api.py +0 -4
- databao_context_engine/databao_engine.py +0 -85
- databao_context_engine/datasource_config/__init__.py +0 -0
- databao_context_engine/datasource_config/add_config.py +0 -50
- databao_context_engine/datasource_config/datasource_context.py +0 -60
- databao_context_engine/mcp/all_results_tool.py +0 -5
- databao_context_engine/mcp/retrieve_tool.py +0 -22
- databao_context_engine/project/runs.py +0 -39
- databao_context_engine/retrieve_embeddings/internal/__init__.py +0 -0
- databao_context_engine/retrieve_embeddings/internal/retrieve_wiring.py +0 -29
- databao_context_engine/retrieve_embeddings/public/__init__.py +0 -0
- databao_context_engine/retrieve_embeddings/public/api.py +0 -3
- databao_context_engine/serialisation/__init__.py +0 -0
- databao_context_engine/services/run_name_policy.py +0 -8
- databao_context_engine/storage/repositories/datasource_run_repository.py +0 -136
- databao_context_engine/storage/repositories/run_repository.py +0 -157
- databao_context_engine-0.1.2.dist-info/METADATA +0 -187
- databao_context_engine-0.1.2.dist-info/RECORD +0 -135
- /databao_context_engine/{build_sources/internal → datasources}/__init__.py +0 -0
- /databao_context_engine/{build_sources/public → serialization}/__init__.py +0 -0
- {databao_context_engine-0.1.2.dist-info → databao_context_engine-0.1.4.dev1.dist-info}/entry_points.txt +0 -0

databao_context_engine/plugins/databases/snowflake_introspector.py
CHANGED

@@ -1,22 +1,57 @@
 from __future__ import annotations

-from typing import Any, Dict, List
+from typing import Annotated, Any, Dict, List

 import snowflake.connector
-from pydantic import Field
+from pydantic import BaseModel, Field
 from snowflake.connector import DictCursor

+from databao_context_engine.pluginlib.config import ConfigPropertyAnnotation
 from databao_context_engine.plugins.base_db_plugin import BaseDatabaseConfigFile
 from databao_context_engine.plugins.databases.base_introspector import BaseIntrospector, SQLQuery
 from databao_context_engine.plugins.databases.databases_types import DatabaseSchema
 from databao_context_engine.plugins.databases.introspection_model_builder import IntrospectionModelBuilder


+class SnowflakePasswordAuth(BaseModel):
+    password: Annotated[str, ConfigPropertyAnnotation(secret=True)]
+
+
+class SnowflakeKeyPairAuth(BaseModel):
+    private_key_file: str | None = None
+    private_key_file_pwd: str | None = None
+    private_key: Annotated[str, ConfigPropertyAnnotation(secret=True)]
+
+
+class SnowflakeSSOAuth(BaseModel):
+    authenticator: str = Field(description='e.g. "externalbrowser"')
+
+
+class SnowflakeConnectionProperties(BaseModel):
+    account: Annotated[str, ConfigPropertyAnnotation(required=True)]
+    warehouse: str | None = None
+    database: str | None = None
+    user: str | None = None
+    role: str | None = None
+    auth: SnowflakePasswordAuth | SnowflakeKeyPairAuth | SnowflakeSSOAuth
+    additional_properties: dict[str, Any] = {}
+
+    def to_snowflake_kwargs(self) -> dict[str, Any]:
+        kwargs = self.model_dump(
+            exclude={
+                "additional_properties": True,
+            },
+            exclude_none=True,
+        )
+        auth_fields = kwargs.pop("auth", {})
+        kwargs.update(auth_fields)
+        kwargs.update(self.additional_properties)
+        return kwargs
+
+
 class SnowflakeConfigFile(BaseDatabaseConfigFile):
-    type: str = Field(default="
-    connection:
-        description="Connection parameters for Snowflake. It can contain any of the keys supported by the Snowflake connection library"
-    )
+    type: str = Field(default="snowflake")
+    connection: SnowflakeConnectionProperties


 class SnowflakeIntrospector(BaseIntrospector[SnowflakeConfigFile]):
@@ -28,11 +63,9 @@ class SnowflakeIntrospector(BaseIntrospector[SnowflakeConfigFile]):

     def _connect(self, file_config: SnowflakeConfigFile):
         connection = file_config.connection
-        if not isinstance(connection, Mapping):
-            raise ValueError("Invalid YAML config: 'connection' must be a mapping of connection parameters")
         snowflake.connector.paramstyle = "qmark"
         return snowflake.connector.connect(
-            **connection,
+            **connection.to_snowflake_kwargs(),
         )

     def _connect_to_catalog(self, file_config: SnowflakeConfigFile, catalog: str):

@@ -41,7 +74,7 @@ class SnowflakeIntrospector(BaseIntrospector[SnowflakeConfigFile]):
         return snowflake.connector.connect(**cfg)

     def _get_catalogs(self, connection, file_config: SnowflakeConfigFile) -> list[str]:
-        database = file_config.connection.
+        database = file_config.connection.database
         if database:
             return [database]
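
The Snowflake connection config thus moves from a free-form mapping to a typed model whose nested auth block is flattened back into the keyword arguments snowflake.connector.connect expects. A minimal sketch of that flattening, using simplified stand-in models (plain pydantic, without ConfigPropertyAnnotation); the names below mirror but do not reproduce the package classes:

from typing import Any

from pydantic import BaseModel


class PasswordAuth(BaseModel):  # stand-in for SnowflakePasswordAuth
    password: str


class ConnectionProperties(BaseModel):  # stand-in for SnowflakeConnectionProperties
    account: str
    warehouse: str | None = None
    user: str | None = None
    auth: PasswordAuth
    additional_properties: dict[str, Any] = {}

    def to_snowflake_kwargs(self) -> dict[str, Any]:
        # Drop None fields and additional_properties, then hoist the auth
        # sub-object's fields to the top level, as in the diff above.
        kwargs = self.model_dump(exclude={"additional_properties": True}, exclude_none=True)
        kwargs.update(kwargs.pop("auth", {}))
        kwargs.update(self.additional_properties)
        return kwargs


props = ConnectionProperties(account="my_account", user="my_user", auth=PasswordAuth(password="s3cret"))
print(props.to_snowflake_kwargs())
# {'account': 'my_account', 'user': 'my_user', 'password': 's3cret'}

Keeping auth as a union of dedicated models lets each auth style mark only its sensitive field as secret, while connect() still receives a flat kwargs dict.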
databao_context_engine/plugins/duckdb_tools.py
ADDED

@@ -0,0 +1,18 @@
+from typing import Any
+
+from databao_context_engine.pluginlib.config import DuckDBSecret
+
+
+def generate_create_secret_sql(secret_name, duckdb_secret: DuckDBSecret) -> str:
+    parameters = [("type", duckdb_secret.type)] + list(duckdb_secret.properties.items())
+    return f"""CREATE SECRET {secret_name} (
+        {", ".join([f"{k} '{v}'" for (k, v) in parameters])}
+    );
+    """
+
+
+def fetchall_dicts(cur, sql: str, params=None) -> list[dict[str, Any]]:
+    cur.execute(sql, params or [])
+    columns = [desc[0].lower() for desc in cur.description] if cur.description else []
+    rows = cur.fetchall()
+    return [dict(zip(columns, row)) for row in rows]
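
These helpers are extracted from parquet_introspector.py (see the removed hunks further down); note that the moved generate_create_secret_sql now quotes secret values ({k} '{v}' rather than the old unquoted {k} {v}). A hedged usage sketch; the DuckDBSecret stand-in here is a plain dataclass, since only its .type and .properties attributes are read:

from dataclasses import dataclass, field

import duckdb

from databao_context_engine.plugins.duckdb_tools import fetchall_dicts, generate_create_secret_sql


@dataclass
class FakeSecret:  # illustrative stand-in for DuckDBSecret
    type: str
    properties: dict[str, str] = field(default_factory=dict)


# Renders roughly: CREATE SECRET s3_secret ( type 's3', region 'us-east-1' );
print(generate_create_secret_sql("s3_secret", FakeSecret("s3", {"region": "us-east-1"})))

with duckdb.connect() as conn:
    cur = conn.cursor()
    # Column names come from cur.description, lower-cased by the helper.
    print(fetchall_dicts(cur, "SELECT 1 AS answer"))  # [{'answer': 1}]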
databao_context_engine/plugins/plugin_loader.py
CHANGED

@@ -14,29 +14,11 @@ class DuplicatePluginTypeError(RuntimeError):
     """Raised when two plugins register the same <main>/<sub> plugin key."""


-
-
-
-def get_all_available_plugin_types(exclude_file_plugins: bool = False) -> set[DatasourceType]:
-    return set(load_plugins(exclude_file_plugins=exclude_file_plugins).keys())
-
-
-def get_plugin_for_type(datasource_type: DatasourceType) -> BuildPlugin:
-    all_plugins = load_plugins()
-
-    if datasource_type not in all_plugins:
-        raise ValueError(f"No plugin found for type '{datasource_type.full_type}'")
-
-    return load_plugins()[datasource_type]
-
-
-def load_plugins(exclude_file_plugins: bool = False) -> PluginList:
-    """
-    Loads both builtin and external plugins and merges them into one list
-    """
+def load_plugins(exclude_file_plugins: bool = False) -> dict[DatasourceType, BuildPlugin]:
+    """Load both builtin and external plugins and merges them into one list."""
     builtin_plugins = _load_builtin_plugins(exclude_file_plugins)
     external_plugins = _load_external_plugins(exclude_file_plugins)
-    plugins =
+    plugins = _merge_plugins(builtin_plugins, external_plugins)

     return plugins

@@ -61,16 +43,9 @@ def _load_builtin_file_plugins() -> list[BuildFilePlugin]:


 def _load_builtin_datasource_plugins() -> list[BuildDatasourcePlugin]:
-    """
-    Statically register built-in plugins
-    """
-    from databao_context_engine.plugins.athena_db_plugin import AthenaDbPlugin
-    from databao_context_engine.plugins.clickhouse_db_plugin import ClickhouseDbPlugin
+    """Statically register built-in plugins."""
     from databao_context_engine.plugins.duckdb_db_plugin import DuckDbPlugin
-    from databao_context_engine.plugins.mysql_db_plugin import MySQLDbPlugin
     from databao_context_engine.plugins.parquet_plugin import ParquetPlugin
-    from databao_context_engine.plugins.postgresql_db_plugin import PostgresqlDbPlugin
-    from databao_context_engine.plugins.snowflake_db_plugin import SnowflakeDbPlugin

     # optional plugins are added to the python environment via extras
     optional_plugins: list[BuildDatasourcePlugin] = []
@@ -81,31 +56,57 @@ def _load_builtin_datasource_plugins() -> list[BuildDatasourcePlugin]:
     except ImportError:
         pass

+    try:
+        from databao_context_engine.plugins.clickhouse_db_plugin import ClickhouseDbPlugin
+
+        optional_plugins.append(ClickhouseDbPlugin())
+    except ImportError:
+        pass
+
+    try:
+        from databao_context_engine.plugins.athena_db_plugin import AthenaDbPlugin
+
+        optional_plugins.append(AthenaDbPlugin())
+    except ImportError:
+        pass
+
+    try:
+        from databao_context_engine.plugins.snowflake_db_plugin import SnowflakeDbPlugin
+
+        optional_plugins.append(SnowflakeDbPlugin())
+    except ImportError:
+        pass
+
+    try:
+        from databao_context_engine.plugins.mysql_db_plugin import MySQLDbPlugin
+
+        optional_plugins.append(MySQLDbPlugin())
+    except ImportError:
+        pass
+
+    try:
+        from databao_context_engine.plugins.postgresql_db_plugin import PostgresqlDbPlugin
+
+        optional_plugins.append(PostgresqlDbPlugin())
+    except ImportError:
+        pass
+
     required_plugins: list[BuildDatasourcePlugin] = [
-        AthenaDbPlugin(),
-        ClickhouseDbPlugin(),
         DuckDbPlugin(),
-        MySQLDbPlugin(),
-        PostgresqlDbPlugin(),
-        SnowflakeDbPlugin(),
         ParquetPlugin(),
     ]
     return required_plugins + optional_plugins


 def _load_external_plugins(exclude_file_plugins: bool = False) -> list[BuildPlugin]:
-    """
-    Discover external plugins via entry points
-    """
+    """Discover external plugins via entry points."""
     # TODO: implement external plugin loading
     return []


-def
-    """
-
-    """
-    registry: PluginList = {}
+def _merge_plugins(*plugin_lists: list[BuildPlugin]) -> dict[DatasourceType, BuildPlugin]:
+    """Merge multiple plugin maps."""
+    registry: dict[DatasourceType, BuildPlugin] = {}
     for plugins in plugin_lists:
         for plugin in plugins:
             for full_type in plugin.supported_types():
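
Demoting every driver-backed plugin to an optional import means a bare install registers only the DuckDB and Parquet built-ins; the rest appear once their extra's dependencies import cleanly. The hunk cuts off inside the registration loop, but given DuplicatePluginTypeError above, the merge presumably rejects two plugins claiming the same type key. A hedged, self-contained sketch of that shape (simplified types, not the package's code):

from typing import Protocol


class Plugin(Protocol):
    def supported_types(self) -> list[str]: ...


def merge_plugins(*plugin_lists: list[Plugin]) -> dict[str, Plugin]:
    registry: dict[str, Plugin] = {}
    for plugins in plugin_lists:
        for plugin in plugins:
            for full_type in plugin.supported_types():
                if full_type in registry:
                    # mirrors DuplicatePluginTypeError's documented purpose
                    raise RuntimeError(f"duplicate plugin key: {full_type}")
                registry[full_type] = plugin
    return registry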
databao_context_engine/plugins/resources/parquet_introspector.py
CHANGED

@@ -6,10 +6,11 @@ from dataclasses import dataclass, replace
 from urllib.parse import urlparse

 import duckdb
-from
+from duckdb import DuckDBPyConnection
 from pydantic import BaseModel, Field

 from databao_context_engine.pluginlib.config import DuckDBSecret
+from databao_context_engine.plugins.duckdb_tools import fetchall_dicts, generate_create_secret_sql

 parquet_type = "resources/parquet"

@@ -18,9 +19,8 @@ logger = logging.getLogger(__name__)

 class ParquetConfigFile(BaseModel):
     name: str | None = Field(default=None)
-    type: str = Field(default=
+    type: str = Field(default="parquet")
     url: str = Field(
-        default=type,
         description="Parquet resource location. Should be a valid URL or a path to a local file. "
        "Examples: s3://your_bucket/file.parquet, s3://your-bucket/*.parquet, https://some.url/some_file.parquet, ~/path_to/file.parquet",
     )
@@ -50,14 +50,6 @@ class ParquetIntrospectionResult:
     files: list[ParquetFile]


-def generate_create_secret_sql(secret_name, duckdb_secret: DuckDBSecret) -> str:
-    parameters = [("type", duckdb_secret.type)] + list(duckdb_secret.properties.items())
-    return f"""CREATE SECRET {secret_name} (
-        {", ".join([f"{k} {v}" for (k, v) in parameters])}
-    );
-    """
-
-
 @contextlib.contextmanager
 def _create_secret(conn: DuckDBPyConnection, duckdb_secret: DuckDBSecret):
     secret_name = duckdb_secret.name or "gen_secret_" + str(uuid.uuid4()).replace("-", "_")
@@ -98,10 +90,9 @@ class ParquetIntrospector:
         with self._connect(file_config) as conn:
             with conn.cursor() as cur:
                 resolved_url = _resolve_url(file_config)
-
-
-
-                parquet_file_metadata = [dict(zip(columns, row)) for row in rows]
+                parquet_file_metadata = fetchall_dicts(
+                    cur, f"SELECT * FROM parquet_file_metadata('{resolved_url}') LIMIT 1"
+                )
                 if not parquet_file_metadata:
                     raise ValueError(f"No parquet files found by url {resolved_url}")
                 if not parquet_file_metadata or not parquet_file_metadata[0]["file_name"]:
@@ -111,10 +102,7 @@ class ParquetIntrospector:
         with self._connect(file_config) as conn:
             with conn.cursor() as cur:
                 resolved_url = _resolve_url(file_config)
-                cur
-                cols = [desc[0].lower() for desc in cur.description] if cur.description else []
-                rows = cur.fetchall()
-                file_metas = [dict(zip(cols, row)) for row in rows]
+                file_metas = fetchall_dicts(cur, f"SELECT * from parquet_metadata('{resolved_url}')")

                 columns_per_file: dict[str, dict[int, ParquetColumn]] = defaultdict(defaultdict)
                 for file_meta in file_metas:
databao_context_engine/project/info.py
CHANGED

@@ -3,29 +3,56 @@ from importlib.metadata import version
 from pathlib import Path
 from uuid import UUID

+from databao_context_engine.plugins.plugin_loader import load_plugins
 from databao_context_engine.project.layout import validate_project_dir
 from databao_context_engine.system.properties import get_dce_path


 @dataclass(kw_only=True, frozen=True)
 class DceProjectInfo:
+    """Information about a Databao Context Engine project.
+
+    Attributes:
+        project_path: The root directory of the Databao Context Engine project.
+        is_initialized: Whether the project has been initialized.
+        project_id: The UUID of the project, or None if the project has not been initialized.
+    """
+
     project_path: Path
-
+    is_initialized: bool
     project_id: UUID | None


 @dataclass(kw_only=True, frozen=True)
 class DceInfo:
+    """Information about the current Databao Context Engine installation and project.
+
+    Attributes:
+        version: The version of the databao_context_engine package installed on the system.
+        dce_path: The path where databao_context_engine stores its global data.
+        project_info: Information about the Databao Context Engine project.
+    """
+
     version: str
     dce_path: Path
+    plugin_ids: list[str]

     project_info: DceProjectInfo


 def get_databao_context_engine_info(project_dir: Path) -> DceInfo:
+    """Return information about the current Databao Context Engine installation and project.
+
+    Args:
+        project_dir: The root directory of the Databao Context Project.
+
+    Returns:
+        A DceInfo instance containing information about the Databao Context Engine installation and project.
+    """
     return DceInfo(
         version=get_dce_version(),
         dce_path=get_dce_path(),
+        plugin_ids=[plugin.id for plugin in load_plugins().values()],
         project_info=_get_project_info(project_dir),
     )

@@ -35,10 +62,15 @@ def _get_project_info(project_dir: Path) -> DceProjectInfo:

     return DceProjectInfo(
         project_path=project_dir,
-
+        is_initialized=project_layout is not None,
         project_id=project_layout.read_config_file().project_id if project_layout is not None else None,
     )


 def get_dce_version() -> str:
+    """Return the installed version of the databao_context_engine package.
+
+    Returns:
+        The installed version of the databao_context_engine package.
+    """
     return version("databao_context_engine")
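
A usage sketch for the expanded info API (the module path is taken from the file list above; printed values are illustrative):

from pathlib import Path

from databao_context_engine.project.info import get_databao_context_engine_info

info = get_databao_context_engine_info(Path.cwd())
print(info.version, info.dce_path)
print(info.plugin_ids)  # new in this release: one id per loaded plugin
print(info.project_info.is_initialized, info.project_info.project_id)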
databao_context_engine/project/init_project.py
CHANGED

@@ -13,12 +13,21 @@ from databao_context_engine.project.project_config import ProjectConfig


 class InitErrorReason(Enum):
+    """Reasons for which project initialization can fail."""
+
     PROJECT_DIR_DOESNT_EXIST = "PROJECT_DIR_DOESNT_EXIST"
     PROJECT_DIR_NOT_DIRECTORY = "PROJECT_DIR_NOT_DIRECTORY"
-
+    PROJECT_DIR_ALREADY_INITIALIZED = "PROJECT_DIR_ALREADY_INITIALIZED"


 class InitProjectError(Exception):
+    """Raised when a project can't be initialized.
+
+    Attributes:
+        message: The error message.
+        reason: The reason for the initialization failure.
+    """
+
     reason: InitErrorReason

     def __init__(self, reason: InitErrorReason, message: str | None):
@@ -65,20 +74,20 @@ class _ProjectCreator:

         if self.config_file.is_file() or self.deprecated_config_file.is_file():
             raise InitProjectError(
-                message=f"Can't
-                reason=InitErrorReason.
+                message=f"Can't initialize a Databao Context Engine project in a folder that already contains a config file. [project_dir: {self.project_dir.resolve()}]",
+                reason=InitErrorReason.PROJECT_DIR_ALREADY_INITIALIZED,
             )

         if self.src_dir.is_dir():
             raise InitProjectError(
-                message=f"Can't
-                reason=InitErrorReason.
+                message=f"Can't initialize a Databao Context Engine project in a folder that already contains a src directory. [project_dir: {self.project_dir.resolve()}]",
+                reason=InitErrorReason.PROJECT_DIR_ALREADY_INITIALIZED,
             )

         if self.examples_dir.is_file():
             raise InitProjectError(
-                message=f"Can't
-                reason=InitErrorReason.
+                message=f"Can't initialize a Databao Context Engine project in a folder that already contains an examples dir. [project_dir: {self.project_dir.resolve()}]",
+                reason=InitErrorReason.PROJECT_DIR_ALREADY_INITIALIZED,
             )

         return True
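
Callers can now branch on the consolidated reason instead of parsing messages. A hedged sketch (the init entry point itself is not shown in this diff, so it is passed in as a callable here):

from pathlib import Path

from databao_context_engine.project.init_project import InitErrorReason, InitProjectError


def init_or_skip(project_dir: Path, init_fn) -> None:
    # init_fn stands in for the package's real init entry point (hypothetical).
    try:
        init_fn(project_dir)
    except InitProjectError as err:
        if err.reason is InitErrorReason.PROJECT_DIR_ALREADY_INITIALIZED:
            print(f"{project_dir} is already initialized; skipping.")
        else:
            raise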
databao_context_engine/project/layout.py
CHANGED

@@ -2,8 +2,8 @@ import logging
 from dataclasses import dataclass
 from pathlib import Path

+from databao_context_engine.datasources.types import DatasourceId
 from databao_context_engine.project.project_config import ProjectConfig
-from databao_context_engine.project.types import DatasourceId

 SOURCE_FOLDER_NAME = "src"
 OUTPUT_FOLDER_NAME = "output"
@@ -96,12 +96,12 @@ class _ProjectValidator:

         if self.config_file is None:
             raise ValueError(
-                f"The current project directory has not been
+                f"The current project directory has not been initialized. It should contain a config file. [project_dir: {self.project_dir.resolve()}]"
             )

         if not self.is_src_valid():
             raise ValueError(
-                f"The current project directory has not been
+                f"The current project directory has not been initialized. It should contain a src directory. [project_dir: {self.project_dir.resolve()}]"
             )

         return ProjectLayout(project_dir=self.project_dir, config_file=self.config_file)
databao_context_engine/retrieve_embeddings/{internal/export_results.py → export_results.py}
RENAMED

@@ -1,8 +1,8 @@
 from pathlib import Path


-def export_retrieve_results(
-    path =
+def export_retrieve_results(output_dir: Path, retrieve_results: list[str]) -> Path:
+    path = output_dir.joinpath("context_duckdb.yaml")

     with path.open("w") as export_file:
         for result in retrieve_results:
databao_context_engine/retrieve_embeddings/{internal/retrieve_runner.py → retrieve_runner.py}
RENAMED

@@ -1,9 +1,9 @@
 import logging
 from pathlib import Path

-from databao_context_engine.project.
-from databao_context_engine.retrieve_embeddings.
-from databao_context_engine.retrieve_embeddings.
+from databao_context_engine.project.layout import get_output_dir
+from databao_context_engine.retrieve_embeddings.export_results import export_retrieve_results
+from databao_context_engine.retrieve_embeddings.retrieve_service import RetrieveService
 from databao_context_engine.storage.repositories.vector_search_repository import VectorSearchResult

 logger = logging.getLogger(__name__)
@@ -15,17 +15,13 @@ def retrieve(
     retrieve_service: RetrieveService,
     project_id: str,
     text: str,
-    run_name: str | None,
     limit: int | None,
     export_to_file: bool,
 ) -> list[VectorSearchResult]:
-
-    retrieve_results = retrieve_service.retrieve(
-        project_id=project_id, text=text, run_name=resolved_run_name, limit=limit
-    )
+    retrieve_results = retrieve_service.retrieve(project_id=project_id, text=text, limit=limit)

     if export_to_file:
-        export_directory =
+        export_directory = get_output_dir(project_dir)

         display_texts = [result.display_text for result in retrieve_results]
         export_file = export_retrieve_results(export_directory, display_texts)
databao_context_engine/retrieve_embeddings/{internal/retrieve_service.py → retrieve_service.py}
RENAMED

@@ -2,9 +2,7 @@ import logging
 from collections.abc import Sequence

 from databao_context_engine.llm.embeddings.provider import EmbeddingProvider
-from databao_context_engine.project.runs import resolve_run_name_from_repo
 from databao_context_engine.services.embedding_shard_resolver import EmbeddingShardResolver
-from databao_context_engine.storage.repositories.run_repository import RunRepository
 from databao_context_engine.storage.repositories.vector_search_repository import (
     VectorSearchRepository,
     VectorSearchResult,
@@ -17,43 +15,34 @@ class RetrieveService:
     def __init__(
         self,
         *,
-        run_repo: RunRepository,
         vector_search_repo: VectorSearchRepository,
         shard_resolver: EmbeddingShardResolver,
         provider: EmbeddingProvider,
     ):
-        self._run_repo = run_repo
         self._shard_resolver = shard_resolver
         self._provider = provider
         self._vector_search_repo = vector_search_repo

-    def retrieve(
-        self, *, project_id: str, text: str, run_name: str, limit: int | None = None
-    ) -> list[VectorSearchResult]:
+    def retrieve(self, *, project_id: str, text: str, limit: int | None = None) -> list[VectorSearchResult]:
         if limit is None:
             limit = 10

-        run = self._run_repo.get_by_run_name(project_id=project_id, run_name=run_name)
-        if run is None:
-            raise LookupError(f"Run '{run_name}' not found for project '{project_id}'.")
-
         table_name, dimension = self._shard_resolver.resolve(
             embedder=self._provider.embedder, model_id=self._provider.model_id
         )

         retrieve_vec: Sequence[float] = self._provider.embed(text)

-        logger.debug(f"Retrieving display texts
+        logger.debug(f"Retrieving display texts in table {table_name}")

         search_results = self._vector_search_repo.get_display_texts_by_similarity(
             table_name=table_name,
-            run_id=run.run_id,
             retrieve_vec=retrieve_vec,
             dimension=dimension,
             limit=limit,
         )

-        logger.debug(f"Retrieved {len(search_results)} display texts
+        logger.debug(f"Retrieved {len(search_results)} display texts in table {table_name}")

         if logger.isEnabledFor(logging.DEBUG):
             closest_result = min(search_results, key=lambda result: result.cosine_distance)
@@ -63,6 +52,3 @@ class RetrieveService:
         logger.debug(f"Worst result: ({farthest_result.cosine_distance}, {farthest_result.embeddable_text})")

         return search_results
-
-    def resolve_run_name(self, *, project_id: str, run_name: str | None) -> str:
-        return resolve_run_name_from_repo(run_repository=self._run_repo, project_id=project_id, run_name=run_name)
databao_context_engine/retrieve_embeddings/retrieve_wiring.py
ADDED

@@ -0,0 +1,49 @@
+from duckdb import DuckDBPyConnection
+
+from databao_context_engine.llm.embeddings.provider import EmbeddingProvider
+from databao_context_engine.llm.factory import create_ollama_embedding_provider, create_ollama_service
+from databao_context_engine.project.layout import ProjectLayout, ensure_project_dir
+from databao_context_engine.retrieve_embeddings.retrieve_runner import retrieve
+from databao_context_engine.retrieve_embeddings.retrieve_service import RetrieveService
+from databao_context_engine.services.factories import create_shard_resolver
+from databao_context_engine.storage.connection import open_duckdb_connection
+from databao_context_engine.storage.repositories.factories import create_vector_search_repository
+from databao_context_engine.storage.repositories.vector_search_repository import VectorSearchResult
+from databao_context_engine.system.properties import get_db_path
+
+
+def retrieve_embeddings(
+    project_layout: ProjectLayout,
+    retrieve_text: str,
+    limit: int | None,
+    export_to_file: bool,
+) -> list[VectorSearchResult]:
+    ensure_project_dir(project_layout.project_dir)
+
+    with open_duckdb_connection(get_db_path(project_layout.project_dir)) as conn:
+        ollama_service = create_ollama_service()
+        embedding_provider = create_ollama_embedding_provider(ollama_service)
+        retrieve_service = _create_retrieve_service(conn, embedding_provider=embedding_provider)
+        return retrieve(
+            project_dir=project_layout.project_dir,
+            retrieve_service=retrieve_service,
+            project_id=str(project_layout.read_config_file().project_id),
+            text=retrieve_text,
+            limit=limit,
+            export_to_file=export_to_file,
+        )
+
+
+def _create_retrieve_service(
+    conn: DuckDBPyConnection,
+    *,
+    embedding_provider: EmbeddingProvider,
+) -> RetrieveService:
+    vector_search_repo = create_vector_search_repository(conn)
+    shard_resolver = create_shard_resolver(conn)
+
+    return RetrieveService(
+        vector_search_repo=vector_search_repo,
+        shard_resolver=shard_resolver,
+        provider=embedding_provider,
+    )
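
A hedged usage sketch for the new wiring entry point. It assumes an initialized project and a reachable Ollama instance; validate_project_dir's exact contract isn't shown in this diff and is assumed here to return a ProjectLayout (raising if the directory is uninitialized, per the _ProjectValidator messages above):

from pathlib import Path

from databao_context_engine.project.layout import validate_project_dir
from databao_context_engine.retrieve_embeddings.retrieve_wiring import retrieve_embeddings

layout = validate_project_dir(Path("."))  # assumed to raise if uninitialized
results = retrieve_embeddings(layout, "orders by customer", limit=5, export_to_file=False)
for result in results:
    print(result.cosine_distance, result.display_text)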
databao_context_engine/{serialisation → serialization}/yaml.py
RENAMED

@@ -8,7 +8,7 @@ def default_representer(dumper: SafeDumper, data: object) -> Node:
     if isinstance(data, Mapping):
         return dumper.represent_dict(data)
     elif hasattr(data, "__dict__"):
-        # Doesn't
+        # Doesn't serialize "private" attributes (that starts with an _)
         data_public_attributes = {key: value for key, value in data.__dict__.items() if not key.startswith("_")}
         if data_public_attributes:
             return dumper.represent_dict(data_public_attributes)
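
For context, default_representer is fallback-representer material for PyYAML: registering it under the None key makes it the catch-all for otherwise unrepresentable objects. A hedged sketch (this wiring is not shown in the diff; note it mutates SafeDumper globally):

import yaml

from databao_context_engine.serialization.yaml import default_representer


class Point:
    def __init__(self, x: int, y: int) -> None:
        self.x = x
        self.y = y
        self._cache = None  # private attribute; skipped by the representer


yaml.SafeDumper.add_representer(None, default_representer)
print(yaml.dump(Point(1, 2), Dumper=yaml.SafeDumper))
# x: 1
# y: 2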