databao-context-engine 0.1.4.dev1__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- databao_context_engine/__init__.py +14 -1
- databao_context_engine/build_sources/build_runner.py +7 -7
- databao_context_engine/build_sources/build_wiring.py +8 -10
- databao_context_engine/build_sources/plugin_execution.py +9 -12
- databao_context_engine/cli/add_datasource_config.py +9 -30
- databao_context_engine/cli/commands.py +29 -13
- databao_context_engine/databao_context_engine.py +3 -13
- databao_context_engine/databao_context_project_manager.py +56 -29
- databao_context_engine/datasources/check_config.py +13 -16
- databao_context_engine/datasources/datasource_context.py +21 -24
- databao_context_engine/datasources/datasource_discovery.py +45 -44
- databao_context_engine/datasources/types.py +53 -42
- databao_context_engine/generate_configs_schemas.py +4 -5
- databao_context_engine/introspection/property_extract.py +52 -47
- databao_context_engine/llm/__init__.py +10 -0
- databao_context_engine/llm/api.py +57 -0
- databao_context_engine/llm/descriptions/ollama.py +1 -3
- databao_context_engine/llm/factory.py +5 -2
- databao_context_engine/llm/install.py +13 -10
- databao_context_engine/llm/runtime.py +3 -5
- databao_context_engine/mcp/mcp_server.py +1 -3
- databao_context_engine/plugin_loader.py +6 -7
- databao_context_engine/pluginlib/build_plugin.py +0 -33
- databao_context_engine/plugins/databases/athena/__init__.py +0 -0
- databao_context_engine/plugins/{athena_db_plugin.py → databases/athena/athena_db_plugin.py} +3 -3
- databao_context_engine/plugins/databases/{athena_introspector.py → athena/athena_introspector.py} +2 -5
- databao_context_engine/plugins/{base_db_plugin.py → databases/base_db_plugin.py} +1 -3
- databao_context_engine/plugins/databases/base_introspector.py +11 -14
- databao_context_engine/plugins/databases/clickhouse/__init__.py +0 -0
- databao_context_engine/plugins/{clickhouse_db_plugin.py → databases/clickhouse/clickhouse_db_plugin.py} +3 -3
- databao_context_engine/plugins/databases/{clickhouse_introspector.py → clickhouse/clickhouse_introspector.py} +2 -5
- databao_context_engine/plugins/databases/duckdb/__init__.py +0 -0
- databao_context_engine/plugins/databases/duckdb/duckdb_db_plugin.py +12 -0
- databao_context_engine/plugins/databases/{duckdb_introspector.py → duckdb/duckdb_introspector.py} +4 -7
- databao_context_engine/plugins/databases/mssql/__init__.py +0 -0
- databao_context_engine/plugins/{mssql_db_plugin.py → databases/mssql/mssql_db_plugin.py} +3 -3
- databao_context_engine/plugins/databases/{mssql_introspector.py → mssql/mssql_introspector.py} +9 -10
- databao_context_engine/plugins/databases/mysql/__init__.py +0 -0
- databao_context_engine/plugins/{mysql_db_plugin.py → databases/mysql/mysql_db_plugin.py} +3 -3
- databao_context_engine/plugins/databases/{mysql_introspector.py → mysql/mysql_introspector.py} +8 -8
- databao_context_engine/plugins/databases/postgresql/__init__.py +0 -0
- databao_context_engine/plugins/databases/postgresql/postgresql_db_plugin.py +15 -0
- databao_context_engine/plugins/databases/{postgresql_introspector.py → postgresql/postgresql_introspector.py} +9 -16
- databao_context_engine/plugins/databases/snowflake/__init__.py +0 -0
- databao_context_engine/plugins/databases/snowflake/snowflake_db_plugin.py +15 -0
- databao_context_engine/plugins/databases/{snowflake_introspector.py → snowflake/snowflake_introspector.py} +8 -9
- databao_context_engine/plugins/databases/sqlite/__init__.py +0 -0
- databao_context_engine/plugins/databases/sqlite/sqlite_db_plugin.py +12 -0
- databao_context_engine/plugins/databases/sqlite/sqlite_introspector.py +241 -0
- databao_context_engine/plugins/dbt/__init__.py +0 -0
- databao_context_engine/plugins/dbt/dbt_chunker.py +47 -0
- databao_context_engine/plugins/dbt/dbt_context_extractor.py +106 -0
- databao_context_engine/plugins/dbt/dbt_plugin.py +25 -0
- databao_context_engine/plugins/dbt/types.py +44 -0
- databao_context_engine/plugins/dbt/types_artifacts.py +58 -0
- databao_context_engine/plugins/files/__init__.py +0 -0
- databao_context_engine/plugins/{unstructured_files_plugin.py → files/unstructured_files_plugin.py} +1 -1
- databao_context_engine/plugins/plugin_loader.py +13 -15
- databao_context_engine/plugins/resources/parquet_introspector.py +1 -1
- databao_context_engine/plugins/{parquet_plugin.py → resources/parquet_plugin.py} +1 -3
- databao_context_engine/project/layout.py +12 -13
- databao_context_engine/retrieve_embeddings/retrieve_runner.py +3 -16
- databao_context_engine/retrieve_embeddings/retrieve_service.py +13 -6
- databao_context_engine/retrieve_embeddings/retrieve_wiring.py +4 -7
- databao_context_engine/serialization/yaml.py +5 -5
- databao_context_engine/storage/migrate.py +1 -1
- databao_context_engine/storage/repositories/vector_search_repository.py +18 -6
- databao_context_engine-0.1.6.dist-info/METADATA +228 -0
- {databao_context_engine-0.1.4.dev1.dist-info → databao_context_engine-0.1.6.dist-info}/RECORD +71 -55
- databao_context_engine/datasources/add_config.py +0 -34
- databao_context_engine/plugins/duckdb_db_plugin.py +0 -12
- databao_context_engine/plugins/postgresql_db_plugin.py +0 -12
- databao_context_engine/plugins/snowflake_db_plugin.py +0 -12
- databao_context_engine/retrieve_embeddings/export_results.py +0 -12
- databao_context_engine-0.1.4.dev1.dist-info/METADATA +0 -75
- {databao_context_engine-0.1.4.dev1.dist-info → databao_context_engine-0.1.6.dist-info}/WHEEL +0 -0
- {databao_context_engine-0.1.4.dev1.dist-info → databao_context_engine-0.1.6.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
from databao_context_engine.llm.errors import OllamaError
|
|
4
|
+
from databao_context_engine.llm.factory import (
|
|
5
|
+
DEFAULT_DESCRIPTION_GENERATOR_MODEL,
|
|
6
|
+
DEFAULT_EMBED_MODEL_ID,
|
|
7
|
+
create_ollama_service,
|
|
8
|
+
)
|
|
9
|
+
from databao_context_engine.llm.install import resolve_ollama_bin
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def install_ollama_if_needed() -> Path:
|
|
13
|
+
"""Install the Ollama CLI locally if needed.
|
|
14
|
+
|
|
15
|
+
This will look for any existing installation of Ollama on the system. If none is found, it will install it locally.
|
|
16
|
+
|
|
17
|
+
Here is the priority order of how it looks for an installed Ollama CLI binary:
|
|
18
|
+
1. Look at the path defined in the DCE_OLLAMA_BIN env var, if it is set
|
|
19
|
+
2. Look for `ollama` in the PATH
|
|
20
|
+
3. Look for a DCE-managed installation in the global DCE path
|
|
21
|
+
|
|
22
|
+
If Ollama is not found, it will get installed as a DCE-managed installation in the global DCE path.
|
|
23
|
+
|
|
24
|
+
Returns:
|
|
25
|
+
The path to the Ollama CLI executable.
|
|
26
|
+
"""
|
|
27
|
+
return Path(resolve_ollama_bin())
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def download_ollama_models_if_needed(
|
|
31
|
+
*, download_embed_model: bool = True, download_description_generator_model: bool = False
|
|
32
|
+
) -> None:
|
|
33
|
+
"""Download the Ollama models required to run DCE if needed.
|
|
34
|
+
|
|
35
|
+
If the models were already downloaded, this method will do nothing.
|
|
36
|
+
|
|
37
|
+
If no Ollama CLI is found on the system, this method will install one as a DCE-managed installation in the global DCE path.
|
|
38
|
+
|
|
39
|
+
Args:
|
|
40
|
+
download_embed_model: Whether to download the embedding model.
|
|
41
|
+
download_description_generator_model: Whether to download the description generator model.
|
|
42
|
+
|
|
43
|
+
Raises:
|
|
44
|
+
OllamaError: If there is an error downloading one of the models.
|
|
45
|
+
"""
|
|
46
|
+
ollama_service = create_ollama_service()
|
|
47
|
+
|
|
48
|
+
if download_embed_model:
|
|
49
|
+
try:
|
|
50
|
+
ollama_service.pull_model_if_needed(model=DEFAULT_EMBED_MODEL_ID)
|
|
51
|
+
except OllamaError as e:
|
|
52
|
+
raise e
|
|
53
|
+
if download_description_generator_model:
|
|
54
|
+
try:
|
|
55
|
+
ollama_service.pull_model_if_needed(model=DEFAULT_DESCRIPTION_GENERATOR_MODEL)
|
|
56
|
+
except OllamaError as e:
|
|
57
|
+
raise e
|
|
@@ -16,6 +16,4 @@ class OllamaDescriptionProvider(DescriptionProvider):
|
|
|
16
16
|
return self._model_id
|
|
17
17
|
|
|
18
18
|
def describe(self, text: str, context: str) -> str:
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
return description
|
|
19
|
+
return self._service.describe(model=self._model_id, text=text, context=context)
|
|
@@ -5,6 +5,9 @@ from databao_context_engine.llm.install import resolve_ollama_bin
|
|
|
5
5
|
from databao_context_engine.llm.runtime import OllamaRuntime
|
|
6
6
|
from databao_context_engine.llm.service import OllamaService
|
|
7
7
|
|
|
8
|
+
DEFAULT_EMBED_MODEL_ID = "nomic-embed-text:v1.5"
|
|
9
|
+
DEFAULT_DESCRIPTION_GENERATOR_MODEL = "llama3.2:1b"
|
|
10
|
+
|
|
8
11
|
|
|
9
12
|
def _create_ollama_service_common(
|
|
10
13
|
*,
|
|
@@ -39,7 +42,7 @@ def create_ollama_service(
|
|
|
39
42
|
def create_ollama_embedding_provider(
|
|
40
43
|
service: OllamaService,
|
|
41
44
|
*,
|
|
42
|
-
model_id: str =
|
|
45
|
+
model_id: str = DEFAULT_EMBED_MODEL_ID,
|
|
43
46
|
dim: int = 768,
|
|
44
47
|
pull_if_needed: bool = True,
|
|
45
48
|
) -> OllamaEmbeddingProvider:
|
|
@@ -52,7 +55,7 @@ def create_ollama_embedding_provider(
|
|
|
52
55
|
def create_ollama_description_provider(
|
|
53
56
|
service: OllamaService,
|
|
54
57
|
*,
|
|
55
|
-
model_id: str =
|
|
58
|
+
model_id: str = DEFAULT_DESCRIPTION_GENERATOR_MODEL,
|
|
56
59
|
pull_if_needed: bool = True,
|
|
57
60
|
):
|
|
58
61
|
if pull_if_needed:
|
|
@@ -95,16 +95,18 @@ def _detect_platform() -> str:
|
|
|
95
95
|
raise RuntimeError(f"Unsupported OS/arch: os={os_name!r} arch={arch!r}")
|
|
96
96
|
|
|
97
97
|
|
|
98
|
-
def
|
|
98
|
+
def _download_artifact_to_temp(artifact_version: str, artifact_name: str) -> Path:
|
|
99
99
|
"""Download to a temporary file and return its path."""
|
|
100
100
|
import urllib.request
|
|
101
101
|
|
|
102
|
+
artifact_url = f"https://github.com/ollama/ollama/releases/download/{artifact_version}/{artifact_name}"
|
|
103
|
+
|
|
102
104
|
tmp_dir = Path(tempfile.mkdtemp(prefix="ollama-download-"))
|
|
103
|
-
file_name =
|
|
105
|
+
file_name = artifact_url.rsplit("/", 1)[-1]
|
|
104
106
|
dest = tmp_dir / file_name
|
|
105
107
|
|
|
106
|
-
logger.info("Downloading %s to %s",
|
|
107
|
-
with urllib.request.urlopen(
|
|
108
|
+
logger.info("Downloading %s to %s", artifact_url, dest)
|
|
109
|
+
with urllib.request.urlopen(artifact_url) as resp, dest.open("wb") as out:
|
|
108
110
|
shutil.copyfileobj(resp, out)
|
|
109
111
|
|
|
110
112
|
return dest
|
|
@@ -128,10 +130,12 @@ def _extract_archive(archive: Path, target_dir: Path) -> None:
|
|
|
128
130
|
|
|
129
131
|
if name.endswith(".zip"):
|
|
130
132
|
with ZipFile(archive, "r") as zf:
|
|
131
|
-
|
|
133
|
+
# There is no built-in protection against zip bombs in ZipFile.
|
|
134
|
+
# However, we previously checked the sha256 of the downloaded archive and we trust the origin (GitHub repo of Ollama)
|
|
135
|
+
zf.extractall(target_dir) # noqa: S202
|
|
132
136
|
elif name.endswith(".tgz") or name.endswith(".tar.gz"):
|
|
133
137
|
with tarfile.open(archive, "r:gz") as tf:
|
|
134
|
-
tf.extractall(target_dir)
|
|
138
|
+
tf.extractall(target_dir, filter="data")
|
|
135
139
|
else:
|
|
136
140
|
raise RuntimeError(f"Unsupported archive format: {archive}")
|
|
137
141
|
|
|
@@ -142,7 +146,7 @@ def _ensure_executable(path: Path) -> None:
|
|
|
142
146
|
mode = path.stat().st_mode
|
|
143
147
|
path.chmod(mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
|
|
144
148
|
except Exception:
|
|
145
|
-
|
|
149
|
+
logger.debug("Failed to mark %s as executable", path, exc_info=True, stack_info=True)
|
|
146
150
|
|
|
147
151
|
|
|
148
152
|
def install_ollama_to(target: Path) -> None:
|
|
@@ -172,8 +176,7 @@ def install_ollama_to(target: Path) -> None:
|
|
|
172
176
|
except KeyError as e:
|
|
173
177
|
raise RuntimeError(f"Unsupported platform: {platform_key}") from e
|
|
174
178
|
|
|
175
|
-
|
|
176
|
-
archive_path = _download_to_temp(url)
|
|
179
|
+
archive_path = _download_artifact_to_temp(DEFAULT_VERSION, artifact.name)
|
|
177
180
|
|
|
178
181
|
try:
|
|
179
182
|
_verify_sha256(archive_path, artifact.sha256)
|
|
@@ -217,4 +220,4 @@ def install_ollama_to(target: Path) -> None:
|
|
|
217
220
|
try:
|
|
218
221
|
archive_path.unlink(missing_ok=True)
|
|
219
222
|
except Exception:
|
|
220
|
-
|
|
223
|
+
logger.debug("Failed to remove temporary archive %s", archive_path, exc_info=True, stack_info=True)
|
|
@@ -26,7 +26,7 @@ class OllamaRuntime:
|
|
|
26
26
|
|
|
27
27
|
stdout = subprocess.DEVNULL
|
|
28
28
|
|
|
29
|
-
|
|
29
|
+
return subprocess.Popen( # noqa: S603 We're always running Ollama
|
|
30
30
|
cmd,
|
|
31
31
|
cwd=str(self._config.work_dir) if self._config.work_dir else None,
|
|
32
32
|
env=env,
|
|
@@ -36,8 +36,6 @@ class OllamaRuntime:
|
|
|
36
36
|
close_fds=os.name != "nt",
|
|
37
37
|
)
|
|
38
38
|
|
|
39
|
-
return proc
|
|
40
|
-
|
|
41
39
|
def start_and_await(
|
|
42
40
|
self,
|
|
43
41
|
*,
|
|
@@ -62,11 +60,11 @@ class OllamaRuntime:
|
|
|
62
60
|
try:
|
|
63
61
|
proc.terminate()
|
|
64
62
|
except Exception:
|
|
65
|
-
|
|
63
|
+
logger.debug("Failed to terminate Ollama server", exc_info=True, stack_info=True)
|
|
66
64
|
try:
|
|
67
65
|
proc.kill()
|
|
68
66
|
except Exception:
|
|
69
|
-
|
|
67
|
+
logger.debug("Failed to kill Ollama server", exc_info=True, stack_info=True)
|
|
70
68
|
|
|
71
69
|
raise TimeoutError(
|
|
72
70
|
f"Timed out waiting for Ollama to become healthy at http://{self._config.host}:{self._config.port}"
|
|
@@ -47,9 +47,7 @@ class McpServer:
|
|
|
47
47
|
annotations=ToolAnnotations(readOnlyHint=True, idempotentHint=True, openWorldHint=False),
|
|
48
48
|
)
|
|
49
49
|
def retrieve_tool(text: str, limit: int | None):
|
|
50
|
-
retrieve_results = self._databao_context_engine.search_context(
|
|
51
|
-
retrieve_text=text, limit=limit, export_to_file=False
|
|
52
|
-
)
|
|
50
|
+
retrieve_results = self._databao_context_engine.search_context(retrieve_text=text, limit=limit)
|
|
53
51
|
|
|
54
52
|
display_results = [context_search_result.context_result for context_search_result in retrieve_results]
|
|
55
53
|
|
|
@@ -38,8 +38,8 @@ class DatabaoContextPluginLoader:
|
|
|
38
38
|
for (datasource_type, plugin) in self._all_plugins_by_type.items()
|
|
39
39
|
if not isinstance(plugin, BuildFilePlugin)
|
|
40
40
|
}
|
|
41
|
-
|
|
42
|
-
|
|
41
|
+
|
|
42
|
+
return set(self._all_plugins_by_type.keys())
|
|
43
43
|
|
|
44
44
|
def get_plugin_for_datasource_type(self, datasource_type: DatasourceType) -> BuildPlugin:
|
|
45
45
|
"""Return the plugin able to build a context for the given datasource type.
|
|
@@ -103,9 +103,8 @@ class DatabaoContextPluginLoader:
|
|
|
103
103
|
|
|
104
104
|
if isinstance(plugin, CustomiseConfigProperties):
|
|
105
105
|
return plugin.get_config_file_properties()
|
|
106
|
-
|
|
106
|
+
if isinstance(plugin, BuildDatasourcePlugin):
|
|
107
107
|
return get_property_list_from_type(plugin.config_file_type)
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
)
|
|
108
|
+
raise ValueError(
|
|
109
|
+
f'Impossible to create a config for datasource type "{datasource_type.full_type}". The corresponding plugin is a {type(plugin).__name__} but should be a BuildDatasourcePlugin or CustomiseConfigProperties'
|
|
110
|
+
)
|
|
@@ -88,36 +88,3 @@ class DatasourceType:
|
|
|
88
88
|
"""
|
|
89
89
|
|
|
90
90
|
full_type: str
|
|
91
|
-
|
|
92
|
-
def __post_init__(self):
|
|
93
|
-
type_segments = self.full_type.split("/")
|
|
94
|
-
if len(type_segments) != 2:
|
|
95
|
-
raise ValueError(f"Invalid DatasourceType: {self.full_type}")
|
|
96
|
-
|
|
97
|
-
@property
|
|
98
|
-
def main_type(self) -> str:
|
|
99
|
-
"""The main type of the datasource, aka the folder in which the config or raw file is located."""
|
|
100
|
-
return self.full_type.split("/")[0]
|
|
101
|
-
|
|
102
|
-
@property
|
|
103
|
-
def config_folder(self) -> str:
|
|
104
|
-
"""The folder in which the config or raw file is located. This is equivalent to `main_type`."""
|
|
105
|
-
return self.main_type
|
|
106
|
-
|
|
107
|
-
@property
|
|
108
|
-
def subtype(self) -> str:
|
|
109
|
-
"""The subtype of the datasource. This is the actual type declared in the config file or the raw file's extension."""
|
|
110
|
-
return self.full_type.split("/")[1]
|
|
111
|
-
|
|
112
|
-
@staticmethod
|
|
113
|
-
def from_main_and_subtypes(main_type: str, subtype: str) -> "DatasourceType":
|
|
114
|
-
"""Create a DatasourceType from its main type and subtype.
|
|
115
|
-
|
|
116
|
-
Args:
|
|
117
|
-
main_type: The main type (aka config folder) of the datasource.
|
|
118
|
-
subtype: The subtype of the datasource.
|
|
119
|
-
|
|
120
|
-
Returns:
|
|
121
|
-
A DatasourceType instance with the specified main type and subtype.
|
|
122
|
-
"""
|
|
123
|
-
return DatasourceType(full_type=f"{main_type}/{subtype}")
|
|
File without changes
|
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
from databao_context_engine.plugins.
|
|
2
|
-
from databao_context_engine.plugins.databases.
|
|
1
|
+
from databao_context_engine.plugins.databases.athena.athena_introspector import AthenaConfigFile, AthenaIntrospector
|
|
2
|
+
from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabasePlugin
|
|
3
3
|
|
|
4
4
|
|
|
5
5
|
class AthenaDbPlugin(BaseDatabasePlugin[AthenaConfigFile]):
|
|
6
6
|
id = "jetbrains/athena"
|
|
7
7
|
name = "Athena DB Plugin"
|
|
8
|
-
supported = {"
|
|
8
|
+
supported = {"athena"}
|
|
9
9
|
config_file_type = AthenaConfigFile
|
|
10
10
|
|
|
11
11
|
def __init__(self):
|
databao_context_engine/plugins/databases/{athena_introspector.py → athena/athena_introspector.py}
RENAMED
|
@@ -7,7 +7,7 @@ from pyathena.cursor import DictCursor
|
|
|
7
7
|
from pydantic import BaseModel, Field
|
|
8
8
|
|
|
9
9
|
from databao_context_engine.pluginlib.config import ConfigPropertyAnnotation
|
|
10
|
-
from databao_context_engine.plugins.base_db_plugin import BaseDatabaseConfigFile
|
|
10
|
+
from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabaseConfigFile
|
|
11
11
|
from databao_context_engine.plugins.databases.base_introspector import BaseIntrospector, SQLQuery
|
|
12
12
|
from databao_context_engine.plugins.databases.databases_types import DatabaseSchema
|
|
13
13
|
from databao_context_engine.plugins.databases.introspection_model_builder import IntrospectionModelBuilder
|
|
@@ -67,7 +67,7 @@ class AthenaIntrospector(BaseIntrospector[AthenaConfigFile]):
|
|
|
67
67
|
}
|
|
68
68
|
supports_catalogs = True
|
|
69
69
|
|
|
70
|
-
def _connect(self, file_config: AthenaConfigFile):
|
|
70
|
+
def _connect(self, file_config: AthenaConfigFile, *, catalog: str | None = None) -> Any:
|
|
71
71
|
return connect(**file_config.connection.to_athena_kwargs(), cursor_class=DictCursor)
|
|
72
72
|
|
|
73
73
|
def _fetchall_dicts(self, connection, sql: str, params) -> list[dict]:
|
|
@@ -79,9 +79,6 @@ class AthenaIntrospector(BaseIntrospector[AthenaConfigFile]):
|
|
|
79
79
|
catalog = file_config.connection.catalog or self._resolve_pseudo_catalog_name(file_config)
|
|
80
80
|
return [catalog]
|
|
81
81
|
|
|
82
|
-
def _connect_to_catalog(self, file_config: AthenaConfigFile, catalog: str):
|
|
83
|
-
return self._connect(file_config)
|
|
84
|
-
|
|
85
82
|
def _sql_list_schemas(self, catalogs: list[str] | None) -> SQLQuery:
|
|
86
83
|
if not catalogs:
|
|
87
84
|
return SQLQuery("SELECT schema_name, catalog_name FROM information_schema.schemata", None)
|
|
@@ -37,9 +37,7 @@ class BaseDatabasePlugin(BuildDatasourcePlugin[T]):
|
|
|
37
37
|
return self.supported
|
|
38
38
|
|
|
39
39
|
def build_context(self, full_type: str, datasource_name: str, file_config: T) -> Any:
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
return introspection_result
|
|
40
|
+
return self._introspector.introspect_database(file_config)
|
|
43
41
|
|
|
44
42
|
def check_connection(self, full_type: str, datasource_name: str, file_config: T) -> None:
|
|
45
43
|
self._introspector.check_connection(file_config)
|
|
@@ -40,7 +40,7 @@ class BaseIntrospector[T: SupportsIntrospectionScope](ABC):
|
|
|
40
40
|
|
|
41
41
|
discovered_schemas_per_catalog: dict[str, list[str]] = {}
|
|
42
42
|
for catalog in catalogs:
|
|
43
|
-
with self.
|
|
43
|
+
with self._connect(file_config, catalog=catalog) as conn:
|
|
44
44
|
discovered_schemas_per_catalog[catalog] = self._list_schemas_for_catalog(conn, catalog)
|
|
45
45
|
scope = scope_matcher.filter_scopes(catalogs, discovered_schemas_per_catalog)
|
|
46
46
|
|
|
@@ -50,7 +50,7 @@ class BaseIntrospector[T: SupportsIntrospectionScope](ABC):
|
|
|
50
50
|
if not schemas_to_introspect:
|
|
51
51
|
continue
|
|
52
52
|
|
|
53
|
-
with self.
|
|
53
|
+
with self._connect(file_config, catalog=catalog) as catalog_connection:
|
|
54
54
|
introspected_schemas = self.collect_catalog_model(catalog_connection, catalog, schemas_to_introspect)
|
|
55
55
|
|
|
56
56
|
if not introspected_schemas:
|
|
@@ -74,9 +74,9 @@ class BaseIntrospector[T: SupportsIntrospectionScope](ABC):
|
|
|
74
74
|
if self.supports_catalogs:
|
|
75
75
|
sql = "SELECT catalog_name, schema_name FROM information_schema.schemata WHERE catalog_name = ANY(%s)"
|
|
76
76
|
return SQLQuery(sql, (catalogs,))
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
77
|
+
|
|
78
|
+
sql = "SELECT schema_name FROM information_schema.schemata"
|
|
79
|
+
return SQLQuery(sql, None)
|
|
80
80
|
|
|
81
81
|
def _list_schemas_for_catalog(self, connection: Any, catalog: str) -> list[str]:
|
|
82
82
|
sql_query = self._sql_list_schemas([catalog] if self.supports_catalogs else None)
|
|
@@ -108,7 +108,12 @@ class BaseIntrospector[T: SupportsIntrospectionScope](ABC):
|
|
|
108
108
|
return samples
|
|
109
109
|
|
|
110
110
|
@abstractmethod
|
|
111
|
-
def _connect(self, file_config: T):
|
|
111
|
+
def _connect(self, file_config: T, *, catalog: str | None = None) -> Any:
|
|
112
|
+
"""Connect to the database.
|
|
113
|
+
|
|
114
|
+
If the `catalog` argument is provided, the connection is "scoped" to that catalog. For engines that don’t need a new connection,
|
|
115
|
+
return a connection with the session set/USE’d to that catalog.
|
|
116
|
+
"""
|
|
112
117
|
raise NotImplementedError
|
|
113
118
|
|
|
114
119
|
@abstractmethod
|
|
@@ -119,14 +124,6 @@ class BaseIntrospector[T: SupportsIntrospectionScope](ABC):
|
|
|
119
124
|
def _get_catalogs(self, connection, file_config: T) -> list[str]:
|
|
120
125
|
raise NotImplementedError
|
|
121
126
|
|
|
122
|
-
@abstractmethod
|
|
123
|
-
def _connect_to_catalog(self, file_config: T, catalog: str):
|
|
124
|
-
"""Return a connection scoped to `catalog`.
|
|
125
|
-
|
|
126
|
-
For engines that don’t need a new connection, return a connection with the
|
|
127
|
-
session set/USE’d to that catalog.
|
|
128
|
-
"""
|
|
129
|
-
|
|
130
127
|
def _sql_sample_rows(self, catalog: str, schema: str, table: str, limit: int) -> SQLQuery:
|
|
131
128
|
raise NotImplementedError
|
|
132
129
|
|
|
File without changes
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
from databao_context_engine.plugins.base_db_plugin import BaseDatabasePlugin
|
|
2
|
-
from databao_context_engine.plugins.databases.clickhouse_introspector import (
|
|
1
|
+
from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabasePlugin
|
|
2
|
+
from databao_context_engine.plugins.databases.clickhouse.clickhouse_introspector import (
|
|
3
3
|
ClickhouseConfigFile,
|
|
4
4
|
ClickhouseIntrospector,
|
|
5
5
|
)
|
|
@@ -8,7 +8,7 @@ from databao_context_engine.plugins.databases.clickhouse_introspector import (
|
|
|
8
8
|
class ClickhouseDbPlugin(BaseDatabasePlugin[ClickhouseConfigFile]):
|
|
9
9
|
id = "jetbrains/clickhouse"
|
|
10
10
|
name = "Clickhouse DB Plugin"
|
|
11
|
-
supported = {"
|
|
11
|
+
supported = {"clickhouse"}
|
|
12
12
|
config_file_type = ClickhouseConfigFile
|
|
13
13
|
|
|
14
14
|
def __init__(self):
|
|
@@ -6,7 +6,7 @@ import clickhouse_connect
|
|
|
6
6
|
from pydantic import BaseModel, Field
|
|
7
7
|
|
|
8
8
|
from databao_context_engine.pluginlib.config import ConfigPropertyAnnotation
|
|
9
|
-
from databao_context_engine.plugins.base_db_plugin import BaseDatabaseConfigFile
|
|
9
|
+
from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabaseConfigFile
|
|
10
10
|
from databao_context_engine.plugins.databases.base_introspector import BaseIntrospector, SQLQuery
|
|
11
11
|
from databao_context_engine.plugins.databases.databases_types import DatabaseSchema
|
|
12
12
|
from databao_context_engine.plugins.databases.introspection_model_builder import IntrospectionModelBuilder
|
|
@@ -36,14 +36,11 @@ class ClickhouseIntrospector(BaseIntrospector[ClickhouseConfigFile]):
|
|
|
36
36
|
|
|
37
37
|
supports_catalogs = True
|
|
38
38
|
|
|
39
|
-
def _connect(self, file_config: ClickhouseConfigFile):
|
|
39
|
+
def _connect(self, file_config: ClickhouseConfigFile, *, catalog: str | None = None):
|
|
40
40
|
return clickhouse_connect.get_client(
|
|
41
41
|
**file_config.connection.to_clickhouse_kwargs(),
|
|
42
42
|
)
|
|
43
43
|
|
|
44
|
-
def _connect_to_catalog(self, file_config: ClickhouseConfigFile, catalog: str):
|
|
45
|
-
return self._connect(file_config)
|
|
46
|
-
|
|
47
44
|
def _get_catalogs(self, connection, file_config: ClickhouseConfigFile) -> list[str]:
|
|
48
45
|
return ["clickhouse"]
|
|
49
46
|
|
|
File without changes
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabasePlugin
|
|
2
|
+
from databao_context_engine.plugins.databases.duckdb.duckdb_introspector import DuckDBConfigFile, DuckDBIntrospector
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class DuckDbPlugin(BaseDatabasePlugin[DuckDBConfigFile]):
|
|
6
|
+
id = "jetbrains/duckdb"
|
|
7
|
+
name = "DuckDB Plugin"
|
|
8
|
+
supported = {"duckdb"}
|
|
9
|
+
config_file_type = DuckDBConfigFile
|
|
10
|
+
|
|
11
|
+
def __init__(self):
|
|
12
|
+
super().__init__(DuckDBIntrospector())
|
databao_context_engine/plugins/databases/{duckdb_introspector.py → duckdb/duckdb_introspector.py}
RENAMED
|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
|
3
3
|
import duckdb
|
|
4
4
|
from pydantic import BaseModel, Field
|
|
5
5
|
|
|
6
|
-
from databao_context_engine.plugins.base_db_plugin import BaseDatabaseConfigFile
|
|
6
|
+
from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabaseConfigFile
|
|
7
7
|
from databao_context_engine.plugins.databases.base_introspector import BaseIntrospector, SQLQuery
|
|
8
8
|
from databao_context_engine.plugins.databases.databases_types import DatabaseSchema
|
|
9
9
|
from databao_context_engine.plugins.databases.introspection_model_builder import IntrospectionModelBuilder
|
|
@@ -16,7 +16,7 @@ class DuckDBConfigFile(BaseDatabaseConfigFile):
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
class DuckDBConnectionConfig(BaseModel):
|
|
19
|
-
|
|
19
|
+
database_path: str = Field(description="Path to the DuckDB database file")
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
class DuckDBIntrospector(BaseIntrospector[DuckDBConfigFile]):
|
|
@@ -24,13 +24,10 @@ class DuckDBIntrospector(BaseIntrospector[DuckDBConfigFile]):
|
|
|
24
24
|
_IGNORED_SCHEMAS = {"information_schema", "pg_catalog"}
|
|
25
25
|
supports_catalogs = True
|
|
26
26
|
|
|
27
|
-
def _connect(self, file_config: DuckDBConfigFile):
|
|
28
|
-
database_path = str(file_config.connection.
|
|
27
|
+
def _connect(self, file_config: DuckDBConfigFile, *, catalog: str | None = None):
|
|
28
|
+
database_path = str(file_config.connection.database_path)
|
|
29
29
|
return duckdb.connect(database=database_path)
|
|
30
30
|
|
|
31
|
-
def _connect_to_catalog(self, file_config: DuckDBConfigFile, catalog: str):
|
|
32
|
-
return self._connect(file_config)
|
|
33
|
-
|
|
34
31
|
def _get_catalogs(self, connection, file_config: DuckDBConfigFile) -> list[str]:
|
|
35
32
|
rows = self._fetchall_dicts(connection, "SELECT database_name FROM duckdb_databases();", None)
|
|
36
33
|
catalogs = [r["database_name"] for r in rows if r.get("database_name")]
|
|
File without changes
|
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
from databao_context_engine.plugins.base_db_plugin import BaseDatabasePlugin
|
|
2
|
-
from databao_context_engine.plugins.databases.mssql_introspector import MSSQLConfigFile, MSSQLIntrospector
|
|
1
|
+
from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabasePlugin
|
|
2
|
+
from databao_context_engine.plugins.databases.mssql.mssql_introspector import MSSQLConfigFile, MSSQLIntrospector
|
|
3
3
|
|
|
4
4
|
|
|
5
5
|
class MSSQLDbPlugin(BaseDatabasePlugin[MSSQLConfigFile]):
|
|
6
6
|
id = "jetbrains/mssql"
|
|
7
7
|
name = "MSSQL DB Plugin"
|
|
8
|
-
supported = {"
|
|
8
|
+
supported = {"mssql"}
|
|
9
9
|
config_file_type = MSSQLConfigFile
|
|
10
10
|
|
|
11
11
|
def __init__(self):
|
databao_context_engine/plugins/databases/{mssql_introspector.py → mssql/mssql_introspector.py}
RENAMED
|
@@ -6,7 +6,7 @@ from mssql_python import connect # type: ignore[import-untyped]
|
|
|
6
6
|
from pydantic import BaseModel, Field
|
|
7
7
|
|
|
8
8
|
from databao_context_engine.pluginlib.config import ConfigPropertyAnnotation
|
|
9
|
-
from databao_context_engine.plugins.base_db_plugin import BaseDatabaseConfigFile
|
|
9
|
+
from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabaseConfigFile
|
|
10
10
|
from databao_context_engine.plugins.databases.base_introspector import BaseIntrospector, SQLQuery
|
|
11
11
|
from databao_context_engine.plugins.databases.databases_types import DatabaseSchema, DatabaseTable
|
|
12
12
|
from databao_context_engine.plugins.databases.introspection_model_builder import IntrospectionModelBuilder
|
|
@@ -55,15 +55,15 @@ class MSSQLIntrospector(BaseIntrospector[MSSQLConfigFile]):
|
|
|
55
55
|
)
|
|
56
56
|
supports_catalogs = True
|
|
57
57
|
|
|
58
|
-
def _connect(self, file_config: MSSQLConfigFile):
|
|
58
|
+
def _connect(self, file_config: MSSQLConfigFile, *, catalog: str | None = None):
|
|
59
59
|
connection = file_config.connection
|
|
60
|
-
connection_string = self._create_connection_string_for_config(connection.to_mssql_kwargs())
|
|
61
|
-
return connect(connection_string)
|
|
62
60
|
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
61
|
+
connection_kwargs = connection.to_mssql_kwargs()
|
|
62
|
+
if catalog:
|
|
63
|
+
connection_kwargs["database"] = catalog
|
|
64
|
+
|
|
65
|
+
connection_string = self._create_connection_string_for_config(connection_kwargs)
|
|
66
|
+
return connect(connection_string)
|
|
67
67
|
|
|
68
68
|
def _get_catalogs(self, connection, file_config: MSSQLConfigFile) -> list[str]:
|
|
69
69
|
database = file_config.connection.database
|
|
@@ -422,8 +422,7 @@ class MSSQLIntrospector(BaseIntrospector[MSSQLConfigFile]):
|
|
|
422
422
|
"trust_server_certificate": "yes" if file_config.get("trust_server_certificate") else None,
|
|
423
423
|
}
|
|
424
424
|
|
|
425
|
-
|
|
426
|
-
return connection_string
|
|
425
|
+
return ";".join(f"{k}={v}" for k, v in connection_parts.items() if v is not None)
|
|
427
426
|
|
|
428
427
|
def _fetchall_dicts(self, connection, sql: str, params) -> list[dict]:
|
|
429
428
|
with connection.cursor() as cursor:
|
|
File without changes
|
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
from databao_context_engine.plugins.base_db_plugin import BaseDatabasePlugin
|
|
2
|
-
from databao_context_engine.plugins.databases.mysql_introspector import MySQLConfigFile, MySQLIntrospector
|
|
1
|
+
from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabasePlugin
|
|
2
|
+
from databao_context_engine.plugins.databases.mysql.mysql_introspector import MySQLConfigFile, MySQLIntrospector
|
|
3
3
|
|
|
4
4
|
|
|
5
5
|
class MySQLDbPlugin(BaseDatabasePlugin[MySQLConfigFile]):
|
|
6
6
|
id = "jetbrains/mysql"
|
|
7
7
|
name = "MySQL DB Plugin"
|
|
8
|
-
supported = {"
|
|
8
|
+
supported = {"mysql"}
|
|
9
9
|
config_file_type = MySQLConfigFile
|
|
10
10
|
|
|
11
11
|
def __init__(self):
|
databao_context_engine/plugins/databases/{mysql_introspector.py → mysql/mysql_introspector.py}
RENAMED
|
@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
|
|
|
5
5
|
from pymysql.constants import CLIENT
|
|
6
6
|
|
|
7
7
|
from databao_context_engine.pluginlib.config import ConfigPropertyAnnotation
|
|
8
|
-
from databao_context_engine.plugins.base_db_plugin import BaseDatabaseConfigFile
|
|
8
|
+
from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabaseConfigFile
|
|
9
9
|
from databao_context_engine.plugins.databases.base_introspector import BaseIntrospector, SQLQuery
|
|
10
10
|
from databao_context_engine.plugins.databases.databases_types import DatabaseSchema, DatabaseTable
|
|
11
11
|
from databao_context_engine.plugins.databases.introspection_model_builder import IntrospectionModelBuilder
|
|
@@ -35,18 +35,18 @@ class MySQLIntrospector(BaseIntrospector[MySQLConfigFile]):
|
|
|
35
35
|
|
|
36
36
|
supports_catalogs = True
|
|
37
37
|
|
|
38
|
-
def _connect(self, file_config: MySQLConfigFile):
|
|
38
|
+
def _connect(self, file_config: MySQLConfigFile, *, catalog: str | None = None):
|
|
39
|
+
connection_kwargs = file_config.connection.to_pymysql_kwargs()
|
|
40
|
+
|
|
41
|
+
if catalog:
|
|
42
|
+
connection_kwargs["database"] = catalog
|
|
43
|
+
|
|
39
44
|
return pymysql.connect(
|
|
40
|
-
**
|
|
45
|
+
**connection_kwargs,
|
|
41
46
|
cursorclass=pymysql.cursors.DictCursor,
|
|
42
47
|
client_flag=CLIENT.MULTI_STATEMENTS | CLIENT.MULTI_RESULTS,
|
|
43
48
|
)
|
|
44
49
|
|
|
45
|
-
def _connect_to_catalog(self, file_config: MySQLConfigFile, catalog: str):
|
|
46
|
-
cfg = file_config.model_copy(deep=True)
|
|
47
|
-
cfg.connection.database = catalog
|
|
48
|
-
return self._connect(cfg)
|
|
49
|
-
|
|
50
50
|
def _get_catalogs(self, connection, file_config: MySQLConfigFile) -> list[str]:
|
|
51
51
|
with connection.cursor() as cur:
|
|
52
52
|
cur.execute(
|
|
File without changes
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabasePlugin
|
|
2
|
+
from databao_context_engine.plugins.databases.postgresql.postgresql_introspector import (
|
|
3
|
+
PostgresConfigFile,
|
|
4
|
+
PostgresqlIntrospector,
|
|
5
|
+
)
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class PostgresqlDbPlugin(BaseDatabasePlugin[PostgresConfigFile]):
|
|
9
|
+
id = "jetbrains/postgres"
|
|
10
|
+
name = "PostgreSQL DB Plugin"
|
|
11
|
+
supported = {"postgres"}
|
|
12
|
+
config_file_type = PostgresConfigFile
|
|
13
|
+
|
|
14
|
+
def __init__(self):
|
|
15
|
+
super().__init__(PostgresqlIntrospector())
|