databao-context-engine 0.1.4.dev1__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. databao_context_engine/__init__.py +14 -1
  2. databao_context_engine/build_sources/build_runner.py +7 -7
  3. databao_context_engine/build_sources/build_wiring.py +8 -10
  4. databao_context_engine/build_sources/plugin_execution.py +9 -12
  5. databao_context_engine/cli/add_datasource_config.py +9 -30
  6. databao_context_engine/cli/commands.py +29 -13
  7. databao_context_engine/databao_context_engine.py +3 -13
  8. databao_context_engine/databao_context_project_manager.py +56 -29
  9. databao_context_engine/datasources/check_config.py +13 -16
  10. databao_context_engine/datasources/datasource_context.py +21 -24
  11. databao_context_engine/datasources/datasource_discovery.py +45 -44
  12. databao_context_engine/datasources/types.py +53 -42
  13. databao_context_engine/generate_configs_schemas.py +4 -5
  14. databao_context_engine/introspection/property_extract.py +52 -47
  15. databao_context_engine/llm/__init__.py +10 -0
  16. databao_context_engine/llm/api.py +57 -0
  17. databao_context_engine/llm/descriptions/ollama.py +1 -3
  18. databao_context_engine/llm/factory.py +5 -2
  19. databao_context_engine/llm/install.py +13 -10
  20. databao_context_engine/llm/runtime.py +3 -5
  21. databao_context_engine/mcp/mcp_server.py +1 -3
  22. databao_context_engine/plugin_loader.py +6 -7
  23. databao_context_engine/pluginlib/build_plugin.py +0 -33
  24. databao_context_engine/plugins/databases/athena/__init__.py +0 -0
  25. databao_context_engine/plugins/{athena_db_plugin.py → databases/athena/athena_db_plugin.py} +3 -3
  26. databao_context_engine/plugins/databases/{athena_introspector.py → athena/athena_introspector.py} +2 -5
  27. databao_context_engine/plugins/{base_db_plugin.py → databases/base_db_plugin.py} +1 -3
  28. databao_context_engine/plugins/databases/base_introspector.py +11 -14
  29. databao_context_engine/plugins/databases/clickhouse/__init__.py +0 -0
  30. databao_context_engine/plugins/{clickhouse_db_plugin.py → databases/clickhouse/clickhouse_db_plugin.py} +3 -3
  31. databao_context_engine/plugins/databases/{clickhouse_introspector.py → clickhouse/clickhouse_introspector.py} +2 -5
  32. databao_context_engine/plugins/databases/duckdb/__init__.py +0 -0
  33. databao_context_engine/plugins/databases/duckdb/duckdb_db_plugin.py +12 -0
  34. databao_context_engine/plugins/databases/{duckdb_introspector.py → duckdb/duckdb_introspector.py} +4 -7
  35. databao_context_engine/plugins/databases/mssql/__init__.py +0 -0
  36. databao_context_engine/plugins/{mssql_db_plugin.py → databases/mssql/mssql_db_plugin.py} +3 -3
  37. databao_context_engine/plugins/databases/{mssql_introspector.py → mssql/mssql_introspector.py} +9 -10
  38. databao_context_engine/plugins/databases/mysql/__init__.py +0 -0
  39. databao_context_engine/plugins/{mysql_db_plugin.py → databases/mysql/mysql_db_plugin.py} +3 -3
  40. databao_context_engine/plugins/databases/{mysql_introspector.py → mysql/mysql_introspector.py} +8 -8
  41. databao_context_engine/plugins/databases/postgresql/__init__.py +0 -0
  42. databao_context_engine/plugins/databases/postgresql/postgresql_db_plugin.py +15 -0
  43. databao_context_engine/plugins/databases/{postgresql_introspector.py → postgresql/postgresql_introspector.py} +9 -16
  44. databao_context_engine/plugins/databases/snowflake/__init__.py +0 -0
  45. databao_context_engine/plugins/databases/snowflake/snowflake_db_plugin.py +15 -0
  46. databao_context_engine/plugins/databases/{snowflake_introspector.py → snowflake/snowflake_introspector.py} +8 -9
  47. databao_context_engine/plugins/databases/sqlite/__init__.py +0 -0
  48. databao_context_engine/plugins/databases/sqlite/sqlite_db_plugin.py +12 -0
  49. databao_context_engine/plugins/databases/sqlite/sqlite_introspector.py +241 -0
  50. databao_context_engine/plugins/dbt/__init__.py +0 -0
  51. databao_context_engine/plugins/dbt/dbt_chunker.py +47 -0
  52. databao_context_engine/plugins/dbt/dbt_context_extractor.py +106 -0
  53. databao_context_engine/plugins/dbt/dbt_plugin.py +25 -0
  54. databao_context_engine/plugins/dbt/types.py +44 -0
  55. databao_context_engine/plugins/dbt/types_artifacts.py +58 -0
  56. databao_context_engine/plugins/files/__init__.py +0 -0
  57. databao_context_engine/plugins/{unstructured_files_plugin.py → files/unstructured_files_plugin.py} +1 -1
  58. databao_context_engine/plugins/plugin_loader.py +13 -15
  59. databao_context_engine/plugins/resources/parquet_introspector.py +1 -1
  60. databao_context_engine/plugins/{parquet_plugin.py → resources/parquet_plugin.py} +1 -3
  61. databao_context_engine/project/layout.py +12 -13
  62. databao_context_engine/retrieve_embeddings/retrieve_runner.py +3 -16
  63. databao_context_engine/retrieve_embeddings/retrieve_service.py +13 -6
  64. databao_context_engine/retrieve_embeddings/retrieve_wiring.py +4 -7
  65. databao_context_engine/serialization/yaml.py +5 -5
  66. databao_context_engine/storage/migrate.py +1 -1
  67. databao_context_engine/storage/repositories/vector_search_repository.py +18 -6
  68. databao_context_engine-0.1.6.dist-info/METADATA +228 -0
  69. {databao_context_engine-0.1.4.dev1.dist-info → databao_context_engine-0.1.6.dist-info}/RECORD +71 -55
  70. databao_context_engine/datasources/add_config.py +0 -34
  71. databao_context_engine/plugins/duckdb_db_plugin.py +0 -12
  72. databao_context_engine/plugins/postgresql_db_plugin.py +0 -12
  73. databao_context_engine/plugins/snowflake_db_plugin.py +0 -12
  74. databao_context_engine/retrieve_embeddings/export_results.py +0 -12
  75. databao_context_engine-0.1.4.dev1.dist-info/METADATA +0 -75
  76. {databao_context_engine-0.1.4.dev1.dist-info → databao_context_engine-0.1.6.dist-info}/WHEEL +0 -0
  77. {databao_context_engine-0.1.4.dev1.dist-info → databao_context_engine-0.1.6.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,57 @@
1
+ from pathlib import Path
2
+
3
+ from databao_context_engine.llm.errors import OllamaError
4
+ from databao_context_engine.llm.factory import (
5
+ DEFAULT_DESCRIPTION_GENERATOR_MODEL,
6
+ DEFAULT_EMBED_MODEL_ID,
7
+ create_ollama_service,
8
+ )
9
+ from databao_context_engine.llm.install import resolve_ollama_bin
10
+
11
+
12
+ def install_ollama_if_needed() -> Path:
13
+ """Install the Ollama CLI locally if needed.
14
+
15
+ This will look for any existing installation of Ollama on the system. If none is found, it will install it locally.
16
+
17
+ Here is the priority order of how it looks for an installed Ollama CLI binary:
18
+ 1. Look at the path defined in the DCE_OLLAMA_BIN env var, if it is set
19
+ 2. Look for `ollama` in the PATH
20
+ 3. Look for a DCE-managed installation in the global DCE path
21
+
22
+ If Ollama is not found, it will get installed as a DCE-managed installation in the global DCE path.
23
+
24
+ Returns:
25
+ The path to the Ollama CLI executable.
26
+ """
27
+ return Path(resolve_ollama_bin())
28
+
29
+
30
+ def download_ollama_models_if_needed(
31
+ *, download_embed_model: bool = True, download_description_generator_model: bool = False
32
+ ) -> None:
33
+ """Download the Ollama models required to run DCE if needed.
34
+
35
+ If the models were already downloaded, this method will do nothing.
36
+
37
+ If no Ollama CLI is found on the system, this method will install one as a DCE-managed installation in the global DCE path.
38
+
39
+ Args:
40
+ download_embed_model: Whether to download the embedding model.
41
+ download_description_generator_model: Whether to download the description generator model.
42
+
43
+ Raises:
44
+ OllamaError: If there is an error downloading one of the models.
45
+ """
46
+ ollama_service = create_ollama_service()
47
+
48
+ if download_embed_model:
49
+ try:
50
+ ollama_service.pull_model_if_needed(model=DEFAULT_EMBED_MODEL_ID)
51
+ except OllamaError as e:
52
+ raise e
53
+ if download_description_generator_model:
54
+ try:
55
+ ollama_service.pull_model_if_needed(model=DEFAULT_DESCRIPTION_GENERATOR_MODEL)
56
+ except OllamaError as e:
57
+ raise e
@@ -16,6 +16,4 @@ class OllamaDescriptionProvider(DescriptionProvider):
16
16
  return self._model_id
17
17
 
18
18
  def describe(self, text: str, context: str) -> str:
19
- description = self._service.describe(model=self._model_id, text=text, context=context)
20
-
21
- return description
19
+ return self._service.describe(model=self._model_id, text=text, context=context)
@@ -5,6 +5,9 @@ from databao_context_engine.llm.install import resolve_ollama_bin
5
5
  from databao_context_engine.llm.runtime import OllamaRuntime
6
6
  from databao_context_engine.llm.service import OllamaService
7
7
 
8
+ DEFAULT_EMBED_MODEL_ID = "nomic-embed-text:v1.5"
9
+ DEFAULT_DESCRIPTION_GENERATOR_MODEL = "llama3.2:1b"
10
+
8
11
 
9
12
  def _create_ollama_service_common(
10
13
  *,
@@ -39,7 +42,7 @@ def create_ollama_service(
39
42
  def create_ollama_embedding_provider(
40
43
  service: OllamaService,
41
44
  *,
42
- model_id: str = "nomic-embed-text:v1.5",
45
+ model_id: str = DEFAULT_EMBED_MODEL_ID,
43
46
  dim: int = 768,
44
47
  pull_if_needed: bool = True,
45
48
  ) -> OllamaEmbeddingProvider:
@@ -52,7 +55,7 @@ def create_ollama_embedding_provider(
52
55
  def create_ollama_description_provider(
53
56
  service: OllamaService,
54
57
  *,
55
- model_id: str = "llama3.2:1b",
58
+ model_id: str = DEFAULT_DESCRIPTION_GENERATOR_MODEL,
56
59
  pull_if_needed: bool = True,
57
60
  ):
58
61
  if pull_if_needed:
@@ -95,16 +95,18 @@ def _detect_platform() -> str:
95
95
  raise RuntimeError(f"Unsupported OS/arch: os={os_name!r} arch={arch!r}")
96
96
 
97
97
 
98
- def _download_to_temp(url: str) -> Path:
98
+ def _download_artifact_to_temp(artifact_version: str, artifact_name: str) -> Path:
99
99
  """Download to a temporary file and return its path."""
100
100
  import urllib.request
101
101
 
102
+ artifact_url = f"https://github.com/ollama/ollama/releases/download/{artifact_version}/{artifact_name}"
103
+
102
104
  tmp_dir = Path(tempfile.mkdtemp(prefix="ollama-download-"))
103
- file_name = url.rsplit("/", 1)[-1]
105
+ file_name = artifact_url.rsplit("/", 1)[-1]
104
106
  dest = tmp_dir / file_name
105
107
 
106
- logger.info("Downloading %s to %s", url, dest)
107
- with urllib.request.urlopen(url) as resp, dest.open("wb") as out:
108
+ logger.info("Downloading %s to %s", artifact_url, dest)
109
+ with urllib.request.urlopen(artifact_url) as resp, dest.open("wb") as out:
108
110
  shutil.copyfileobj(resp, out)
109
111
 
110
112
  return dest
@@ -128,10 +130,12 @@ def _extract_archive(archive: Path, target_dir: Path) -> None:
128
130
 
129
131
  if name.endswith(".zip"):
130
132
  with ZipFile(archive, "r") as zf:
131
- zf.extractall(target_dir)
133
+ # There is no built-in protection against zip bombs in ZipFile.
134
+ # However, we previously checked the sha256 of the downloaded archive and we trust the origin (GitHub repo of Ollama)
135
+ zf.extractall(target_dir) # noqa: S202
132
136
  elif name.endswith(".tgz") or name.endswith(".tar.gz"):
133
137
  with tarfile.open(archive, "r:gz") as tf:
134
- tf.extractall(target_dir)
138
+ tf.extractall(target_dir, filter="data")
135
139
  else:
136
140
  raise RuntimeError(f"Unsupported archive format: {archive}")
137
141
 
@@ -142,7 +146,7 @@ def _ensure_executable(path: Path) -> None:
142
146
  mode = path.stat().st_mode
143
147
  path.chmod(mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
144
148
  except Exception:
145
- pass
149
+ logger.debug("Failed to mark %s as executable", path, exc_info=True, stack_info=True)
146
150
 
147
151
 
148
152
  def install_ollama_to(target: Path) -> None:
@@ -172,8 +176,7 @@ def install_ollama_to(target: Path) -> None:
172
176
  except KeyError as e:
173
177
  raise RuntimeError(f"Unsupported platform: {platform_key}") from e
174
178
 
175
- url = f"https://github.com/ollama/ollama/releases/download/{DEFAULT_VERSION}/{artifact.name}"
176
- archive_path = _download_to_temp(url)
179
+ archive_path = _download_artifact_to_temp(DEFAULT_VERSION, artifact.name)
177
180
 
178
181
  try:
179
182
  _verify_sha256(archive_path, artifact.sha256)
@@ -217,4 +220,4 @@ def install_ollama_to(target: Path) -> None:
217
220
  try:
218
221
  archive_path.unlink(missing_ok=True)
219
222
  except Exception:
220
- pass
223
+ logger.debug("Failed to remove temporary archive %s", archive_path, exc_info=True, stack_info=True)
@@ -26,7 +26,7 @@ class OllamaRuntime:
26
26
 
27
27
  stdout = subprocess.DEVNULL
28
28
 
29
- proc = subprocess.Popen(
29
+ return subprocess.Popen( # noqa: S603 We're always running Ollama
30
30
  cmd,
31
31
  cwd=str(self._config.work_dir) if self._config.work_dir else None,
32
32
  env=env,
@@ -36,8 +36,6 @@ class OllamaRuntime:
36
36
  close_fds=os.name != "nt",
37
37
  )
38
38
 
39
- return proc
40
-
41
39
  def start_and_await(
42
40
  self,
43
41
  *,
@@ -62,11 +60,11 @@ class OllamaRuntime:
62
60
  try:
63
61
  proc.terminate()
64
62
  except Exception:
65
- pass
63
+ logger.debug("Failed to terminate Ollama server", exc_info=True, stack_info=True)
66
64
  try:
67
65
  proc.kill()
68
66
  except Exception:
69
- pass
67
+ logger.debug("Failed to kill Ollama server", exc_info=True, stack_info=True)
70
68
 
71
69
  raise TimeoutError(
72
70
  f"Timed out waiting for Ollama to become healthy at http://{self._config.host}:{self._config.port}"
@@ -47,9 +47,7 @@ class McpServer:
47
47
  annotations=ToolAnnotations(readOnlyHint=True, idempotentHint=True, openWorldHint=False),
48
48
  )
49
49
  def retrieve_tool(text: str, limit: int | None):
50
- retrieve_results = self._databao_context_engine.search_context(
51
- retrieve_text=text, limit=limit, export_to_file=False
52
- )
50
+ retrieve_results = self._databao_context_engine.search_context(retrieve_text=text, limit=limit)
53
51
 
54
52
  display_results = [context_search_result.context_result for context_search_result in retrieve_results]
55
53
 
@@ -38,8 +38,8 @@ class DatabaoContextPluginLoader:
38
38
  for (datasource_type, plugin) in self._all_plugins_by_type.items()
39
39
  if not isinstance(plugin, BuildFilePlugin)
40
40
  }
41
- else:
42
- return set(self._all_plugins_by_type.keys())
41
+
42
+ return set(self._all_plugins_by_type.keys())
43
43
 
44
44
  def get_plugin_for_datasource_type(self, datasource_type: DatasourceType) -> BuildPlugin:
45
45
  """Return the plugin able to build a context for the given datasource type.
@@ -103,9 +103,8 @@ class DatabaoContextPluginLoader:
103
103
 
104
104
  if isinstance(plugin, CustomiseConfigProperties):
105
105
  return plugin.get_config_file_properties()
106
- elif isinstance(plugin, BuildDatasourcePlugin):
106
+ if isinstance(plugin, BuildDatasourcePlugin):
107
107
  return get_property_list_from_type(plugin.config_file_type)
108
- else:
109
- raise ValueError(
110
- f'Impossible to create a config for datasource type "{datasource_type.full_type}". The corresponding plugin is a {type(plugin).__name__} but should be a BuildDatasourcePlugin or CustomiseConfigProperties'
111
- )
108
+ raise ValueError(
109
+ f'Impossible to create a config for datasource type "{datasource_type.full_type}". The corresponding plugin is a {type(plugin).__name__} but should be a BuildDatasourcePlugin or CustomiseConfigProperties'
110
+ )
@@ -88,36 +88,3 @@ class DatasourceType:
88
88
  """
89
89
 
90
90
  full_type: str
91
-
92
- def __post_init__(self):
93
- type_segments = self.full_type.split("/")
94
- if len(type_segments) != 2:
95
- raise ValueError(f"Invalid DatasourceType: {self.full_type}")
96
-
97
- @property
98
- def main_type(self) -> str:
99
- """The main type of the datasource, aka the folder in which the config or raw file is located."""
100
- return self.full_type.split("/")[0]
101
-
102
- @property
103
- def config_folder(self) -> str:
104
- """The folder in which the config or raw file is located. This is equivalent to `main_type`."""
105
- return self.main_type
106
-
107
- @property
108
- def subtype(self) -> str:
109
- """The subtype of the datasource. This is the actual type declared in the config file or the raw file's extension."""
110
- return self.full_type.split("/")[1]
111
-
112
- @staticmethod
113
- def from_main_and_subtypes(main_type: str, subtype: str) -> "DatasourceType":
114
- """Create a DatasourceType from its main type and subtype.
115
-
116
- Args:
117
- main_type: The main type (aka config folder) of the datasource.
118
- subtype: The subtype of the datasource.
119
-
120
- Returns:
121
- A DatasourceType instance with the specified main type and subtype.
122
- """
123
- return DatasourceType(full_type=f"{main_type}/{subtype}")
@@ -1,11 +1,11 @@
1
- from databao_context_engine.plugins.base_db_plugin import BaseDatabasePlugin
2
- from databao_context_engine.plugins.databases.athena_introspector import AthenaConfigFile, AthenaIntrospector
1
+ from databao_context_engine.plugins.databases.athena.athena_introspector import AthenaConfigFile, AthenaIntrospector
2
+ from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabasePlugin
3
3
 
4
4
 
5
5
  class AthenaDbPlugin(BaseDatabasePlugin[AthenaConfigFile]):
6
6
  id = "jetbrains/athena"
7
7
  name = "Athena DB Plugin"
8
- supported = {"databases/athena"}
8
+ supported = {"athena"}
9
9
  config_file_type = AthenaConfigFile
10
10
 
11
11
  def __init__(self):
@@ -7,7 +7,7 @@ from pyathena.cursor import DictCursor
7
7
  from pydantic import BaseModel, Field
8
8
 
9
9
  from databao_context_engine.pluginlib.config import ConfigPropertyAnnotation
10
- from databao_context_engine.plugins.base_db_plugin import BaseDatabaseConfigFile
10
+ from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabaseConfigFile
11
11
  from databao_context_engine.plugins.databases.base_introspector import BaseIntrospector, SQLQuery
12
12
  from databao_context_engine.plugins.databases.databases_types import DatabaseSchema
13
13
  from databao_context_engine.plugins.databases.introspection_model_builder import IntrospectionModelBuilder
@@ -67,7 +67,7 @@ class AthenaIntrospector(BaseIntrospector[AthenaConfigFile]):
67
67
  }
68
68
  supports_catalogs = True
69
69
 
70
- def _connect(self, file_config: AthenaConfigFile):
70
+ def _connect(self, file_config: AthenaConfigFile, *, catalog: str | None = None) -> Any:
71
71
  return connect(**file_config.connection.to_athena_kwargs(), cursor_class=DictCursor)
72
72
 
73
73
  def _fetchall_dicts(self, connection, sql: str, params) -> list[dict]:
@@ -79,9 +79,6 @@ class AthenaIntrospector(BaseIntrospector[AthenaConfigFile]):
79
79
  catalog = file_config.connection.catalog or self._resolve_pseudo_catalog_name(file_config)
80
80
  return [catalog]
81
81
 
82
- def _connect_to_catalog(self, file_config: AthenaConfigFile, catalog: str):
83
- return self._connect(file_config)
84
-
85
82
  def _sql_list_schemas(self, catalogs: list[str] | None) -> SQLQuery:
86
83
  if not catalogs:
87
84
  return SQLQuery("SELECT schema_name, catalog_name FROM information_schema.schemata", None)
@@ -37,9 +37,7 @@ class BaseDatabasePlugin(BuildDatasourcePlugin[T]):
37
37
  return self.supported
38
38
 
39
39
  def build_context(self, full_type: str, datasource_name: str, file_config: T) -> Any:
40
- introspection_result = self._introspector.introspect_database(file_config)
41
-
42
- return introspection_result
40
+ return self._introspector.introspect_database(file_config)
43
41
 
44
42
  def check_connection(self, full_type: str, datasource_name: str, file_config: T) -> None:
45
43
  self._introspector.check_connection(file_config)
@@ -40,7 +40,7 @@ class BaseIntrospector[T: SupportsIntrospectionScope](ABC):
40
40
 
41
41
  discovered_schemas_per_catalog: dict[str, list[str]] = {}
42
42
  for catalog in catalogs:
43
- with self._connect_to_catalog(file_config, catalog) as conn:
43
+ with self._connect(file_config, catalog=catalog) as conn:
44
44
  discovered_schemas_per_catalog[catalog] = self._list_schemas_for_catalog(conn, catalog)
45
45
  scope = scope_matcher.filter_scopes(catalogs, discovered_schemas_per_catalog)
46
46
 
@@ -50,7 +50,7 @@ class BaseIntrospector[T: SupportsIntrospectionScope](ABC):
50
50
  if not schemas_to_introspect:
51
51
  continue
52
52
 
53
- with self._connect_to_catalog(file_config, catalog) as catalog_connection:
53
+ with self._connect(file_config, catalog=catalog) as catalog_connection:
54
54
  introspected_schemas = self.collect_catalog_model(catalog_connection, catalog, schemas_to_introspect)
55
55
 
56
56
  if not introspected_schemas:
@@ -74,9 +74,9 @@ class BaseIntrospector[T: SupportsIntrospectionScope](ABC):
74
74
  if self.supports_catalogs:
75
75
  sql = "SELECT catalog_name, schema_name FROM information_schema.schemata WHERE catalog_name = ANY(%s)"
76
76
  return SQLQuery(sql, (catalogs,))
77
- else:
78
- sql = "SELECT schema_name FROM information_schema.schemata"
79
- return SQLQuery(sql, None)
77
+
78
+ sql = "SELECT schema_name FROM information_schema.schemata"
79
+ return SQLQuery(sql, None)
80
80
 
81
81
  def _list_schemas_for_catalog(self, connection: Any, catalog: str) -> list[str]:
82
82
  sql_query = self._sql_list_schemas([catalog] if self.supports_catalogs else None)
@@ -108,7 +108,12 @@ class BaseIntrospector[T: SupportsIntrospectionScope](ABC):
108
108
  return samples
109
109
 
110
110
  @abstractmethod
111
- def _connect(self, file_config: T):
111
+ def _connect(self, file_config: T, *, catalog: str | None = None) -> Any:
112
+ """Connect to the database.
113
+
114
+ If the `catalog` argument is provided, the connection is "scoped" to that catalog. For engines that don’t need a new connection,
115
+ return a connection with the session set/USE’d to that catalog.
116
+ """
112
117
  raise NotImplementedError
113
118
 
114
119
  @abstractmethod
@@ -119,14 +124,6 @@ class BaseIntrospector[T: SupportsIntrospectionScope](ABC):
119
124
  def _get_catalogs(self, connection, file_config: T) -> list[str]:
120
125
  raise NotImplementedError
121
126
 
122
- @abstractmethod
123
- def _connect_to_catalog(self, file_config: T, catalog: str):
124
- """Return a connection scoped to `catalog`.
125
-
126
- For engines that don’t need a new connection, return a connection with the
127
- session set/USE’d to that catalog.
128
- """
129
-
130
127
  def _sql_sample_rows(self, catalog: str, schema: str, table: str, limit: int) -> SQLQuery:
131
128
  raise NotImplementedError
132
129
 
@@ -1,5 +1,5 @@
1
- from databao_context_engine.plugins.base_db_plugin import BaseDatabasePlugin
2
- from databao_context_engine.plugins.databases.clickhouse_introspector import (
1
+ from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabasePlugin
2
+ from databao_context_engine.plugins.databases.clickhouse.clickhouse_introspector import (
3
3
  ClickhouseConfigFile,
4
4
  ClickhouseIntrospector,
5
5
  )
@@ -8,7 +8,7 @@ from databao_context_engine.plugins.databases.clickhouse_introspector import (
8
8
  class ClickhouseDbPlugin(BaseDatabasePlugin[ClickhouseConfigFile]):
9
9
  id = "jetbrains/clickhouse"
10
10
  name = "Clickhouse DB Plugin"
11
- supported = {"databases/clickhouse"}
11
+ supported = {"clickhouse"}
12
12
  config_file_type = ClickhouseConfigFile
13
13
 
14
14
  def __init__(self):
@@ -6,7 +6,7 @@ import clickhouse_connect
6
6
  from pydantic import BaseModel, Field
7
7
 
8
8
  from databao_context_engine.pluginlib.config import ConfigPropertyAnnotation
9
- from databao_context_engine.plugins.base_db_plugin import BaseDatabaseConfigFile
9
+ from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabaseConfigFile
10
10
  from databao_context_engine.plugins.databases.base_introspector import BaseIntrospector, SQLQuery
11
11
  from databao_context_engine.plugins.databases.databases_types import DatabaseSchema
12
12
  from databao_context_engine.plugins.databases.introspection_model_builder import IntrospectionModelBuilder
@@ -36,14 +36,11 @@ class ClickhouseIntrospector(BaseIntrospector[ClickhouseConfigFile]):
36
36
 
37
37
  supports_catalogs = True
38
38
 
39
- def _connect(self, file_config: ClickhouseConfigFile):
39
+ def _connect(self, file_config: ClickhouseConfigFile, *, catalog: str | None = None):
40
40
  return clickhouse_connect.get_client(
41
41
  **file_config.connection.to_clickhouse_kwargs(),
42
42
  )
43
43
 
44
- def _connect_to_catalog(self, file_config: ClickhouseConfigFile, catalog: str):
45
- return self._connect(file_config)
46
-
47
44
  def _get_catalogs(self, connection, file_config: ClickhouseConfigFile) -> list[str]:
48
45
  return ["clickhouse"]
49
46
 
@@ -0,0 +1,12 @@
1
+ from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabasePlugin
2
+ from databao_context_engine.plugins.databases.duckdb.duckdb_introspector import DuckDBConfigFile, DuckDBIntrospector
3
+
4
+
5
+ class DuckDbPlugin(BaseDatabasePlugin[DuckDBConfigFile]):
6
+ id = "jetbrains/duckdb"
7
+ name = "DuckDB Plugin"
8
+ supported = {"duckdb"}
9
+ config_file_type = DuckDBConfigFile
10
+
11
+ def __init__(self):
12
+ super().__init__(DuckDBIntrospector())
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import duckdb
4
4
  from pydantic import BaseModel, Field
5
5
 
6
- from databao_context_engine.plugins.base_db_plugin import BaseDatabaseConfigFile
6
+ from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabaseConfigFile
7
7
  from databao_context_engine.plugins.databases.base_introspector import BaseIntrospector, SQLQuery
8
8
  from databao_context_engine.plugins.databases.databases_types import DatabaseSchema
9
9
  from databao_context_engine.plugins.databases.introspection_model_builder import IntrospectionModelBuilder
@@ -16,7 +16,7 @@ class DuckDBConfigFile(BaseDatabaseConfigFile):
16
16
 
17
17
 
18
18
  class DuckDBConnectionConfig(BaseModel):
19
- database: str = Field(description="Path to the DuckDB database file")
19
+ database_path: str = Field(description="Path to the DuckDB database file")
20
20
 
21
21
 
22
22
  class DuckDBIntrospector(BaseIntrospector[DuckDBConfigFile]):
@@ -24,13 +24,10 @@ class DuckDBIntrospector(BaseIntrospector[DuckDBConfigFile]):
24
24
  _IGNORED_SCHEMAS = {"information_schema", "pg_catalog"}
25
25
  supports_catalogs = True
26
26
 
27
- def _connect(self, file_config: DuckDBConfigFile):
28
- database_path = str(file_config.connection.database)
27
+ def _connect(self, file_config: DuckDBConfigFile, *, catalog: str | None = None):
28
+ database_path = str(file_config.connection.database_path)
29
29
  return duckdb.connect(database=database_path)
30
30
 
31
- def _connect_to_catalog(self, file_config: DuckDBConfigFile, catalog: str):
32
- return self._connect(file_config)
33
-
34
31
  def _get_catalogs(self, connection, file_config: DuckDBConfigFile) -> list[str]:
35
32
  rows = self._fetchall_dicts(connection, "SELECT database_name FROM duckdb_databases();", None)
36
33
  catalogs = [r["database_name"] for r in rows if r.get("database_name")]
@@ -1,11 +1,11 @@
1
- from databao_context_engine.plugins.base_db_plugin import BaseDatabasePlugin
2
- from databao_context_engine.plugins.databases.mssql_introspector import MSSQLConfigFile, MSSQLIntrospector
1
+ from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabasePlugin
2
+ from databao_context_engine.plugins.databases.mssql.mssql_introspector import MSSQLConfigFile, MSSQLIntrospector
3
3
 
4
4
 
5
5
  class MSSQLDbPlugin(BaseDatabasePlugin[MSSQLConfigFile]):
6
6
  id = "jetbrains/mssql"
7
7
  name = "MSSQL DB Plugin"
8
- supported = {"databases/mssql"}
8
+ supported = {"mssql"}
9
9
  config_file_type = MSSQLConfigFile
10
10
 
11
11
  def __init__(self):
@@ -6,7 +6,7 @@ from mssql_python import connect # type: ignore[import-untyped]
6
6
  from pydantic import BaseModel, Field
7
7
 
8
8
  from databao_context_engine.pluginlib.config import ConfigPropertyAnnotation
9
- from databao_context_engine.plugins.base_db_plugin import BaseDatabaseConfigFile
9
+ from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabaseConfigFile
10
10
  from databao_context_engine.plugins.databases.base_introspector import BaseIntrospector, SQLQuery
11
11
  from databao_context_engine.plugins.databases.databases_types import DatabaseSchema, DatabaseTable
12
12
  from databao_context_engine.plugins.databases.introspection_model_builder import IntrospectionModelBuilder
@@ -55,15 +55,15 @@ class MSSQLIntrospector(BaseIntrospector[MSSQLConfigFile]):
55
55
  )
56
56
  supports_catalogs = True
57
57
 
58
- def _connect(self, file_config: MSSQLConfigFile):
58
+ def _connect(self, file_config: MSSQLConfigFile, *, catalog: str | None = None):
59
59
  connection = file_config.connection
60
- connection_string = self._create_connection_string_for_config(connection.to_mssql_kwargs())
61
- return connect(connection_string)
62
60
 
63
- def _connect_to_catalog(self, file_config: MSSQLConfigFile, catalog: str):
64
- cfg = file_config.model_copy(deep=True)
65
- cfg.connection.database = catalog
66
- return self._connect(cfg)
61
+ connection_kwargs = connection.to_mssql_kwargs()
62
+ if catalog:
63
+ connection_kwargs["database"] = catalog
64
+
65
+ connection_string = self._create_connection_string_for_config(connection_kwargs)
66
+ return connect(connection_string)
67
67
 
68
68
  def _get_catalogs(self, connection, file_config: MSSQLConfigFile) -> list[str]:
69
69
  database = file_config.connection.database
@@ -422,8 +422,7 @@ class MSSQLIntrospector(BaseIntrospector[MSSQLConfigFile]):
422
422
  "trust_server_certificate": "yes" if file_config.get("trust_server_certificate") else None,
423
423
  }
424
424
 
425
- connection_string = ";".join(f"{k}={v}" for k, v in connection_parts.items() if v is not None)
426
- return connection_string
425
+ return ";".join(f"{k}={v}" for k, v in connection_parts.items() if v is not None)
427
426
 
428
427
  def _fetchall_dicts(self, connection, sql: str, params) -> list[dict]:
429
428
  with connection.cursor() as cursor:
@@ -1,11 +1,11 @@
1
- from databao_context_engine.plugins.base_db_plugin import BaseDatabasePlugin
2
- from databao_context_engine.plugins.databases.mysql_introspector import MySQLConfigFile, MySQLIntrospector
1
+ from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabasePlugin
2
+ from databao_context_engine.plugins.databases.mysql.mysql_introspector import MySQLConfigFile, MySQLIntrospector
3
3
 
4
4
 
5
5
  class MySQLDbPlugin(BaseDatabasePlugin[MySQLConfigFile]):
6
6
  id = "jetbrains/mysql"
7
7
  name = "MySQL DB Plugin"
8
- supported = {"databases/mysql"}
8
+ supported = {"mysql"}
9
9
  config_file_type = MySQLConfigFile
10
10
 
11
11
  def __init__(self):
@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
5
5
  from pymysql.constants import CLIENT
6
6
 
7
7
  from databao_context_engine.pluginlib.config import ConfigPropertyAnnotation
8
- from databao_context_engine.plugins.base_db_plugin import BaseDatabaseConfigFile
8
+ from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabaseConfigFile
9
9
  from databao_context_engine.plugins.databases.base_introspector import BaseIntrospector, SQLQuery
10
10
  from databao_context_engine.plugins.databases.databases_types import DatabaseSchema, DatabaseTable
11
11
  from databao_context_engine.plugins.databases.introspection_model_builder import IntrospectionModelBuilder
@@ -35,18 +35,18 @@ class MySQLIntrospector(BaseIntrospector[MySQLConfigFile]):
35
35
 
36
36
  supports_catalogs = True
37
37
 
38
- def _connect(self, file_config: MySQLConfigFile):
38
+ def _connect(self, file_config: MySQLConfigFile, *, catalog: str | None = None):
39
+ connection_kwargs = file_config.connection.to_pymysql_kwargs()
40
+
41
+ if catalog:
42
+ connection_kwargs["database"] = catalog
43
+
39
44
  return pymysql.connect(
40
- **file_config.connection.to_pymysql_kwargs(),
45
+ **connection_kwargs,
41
46
  cursorclass=pymysql.cursors.DictCursor,
42
47
  client_flag=CLIENT.MULTI_STATEMENTS | CLIENT.MULTI_RESULTS,
43
48
  )
44
49
 
45
- def _connect_to_catalog(self, file_config: MySQLConfigFile, catalog: str):
46
- cfg = file_config.model_copy(deep=True)
47
- cfg.connection.database = catalog
48
- return self._connect(cfg)
49
-
50
50
  def _get_catalogs(self, connection, file_config: MySQLConfigFile) -> list[str]:
51
51
  with connection.cursor() as cur:
52
52
  cur.execute(
@@ -0,0 +1,15 @@
1
+ from databao_context_engine.plugins.databases.base_db_plugin import BaseDatabasePlugin
2
+ from databao_context_engine.plugins.databases.postgresql.postgresql_introspector import (
3
+ PostgresConfigFile,
4
+ PostgresqlIntrospector,
5
+ )
6
+
7
+
8
+ class PostgresqlDbPlugin(BaseDatabasePlugin[PostgresConfigFile]):
9
+ id = "jetbrains/postgres"
10
+ name = "PostgreSQL DB Plugin"
11
+ supported = {"postgres"}
12
+ config_file_type = PostgresConfigFile
13
+
14
+ def __init__(self):
15
+ super().__init__(PostgresqlIntrospector())