databao-context-engine 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (135)
  1. databao_context_engine/__init__.py +35 -0
  2. databao_context_engine/build_sources/__init__.py +0 -0
  3. databao_context_engine/build_sources/internal/__init__.py +0 -0
  4. databao_context_engine/build_sources/internal/build_runner.py +111 -0
  5. databao_context_engine/build_sources/internal/build_service.py +77 -0
  6. databao_context_engine/build_sources/internal/build_wiring.py +52 -0
  7. databao_context_engine/build_sources/internal/export_results.py +43 -0
  8. databao_context_engine/build_sources/internal/plugin_execution.py +74 -0
  9. databao_context_engine/build_sources/public/__init__.py +0 -0
  10. databao_context_engine/build_sources/public/api.py +4 -0
  11. databao_context_engine/cli/__init__.py +0 -0
  12. databao_context_engine/cli/add_datasource_config.py +130 -0
  13. databao_context_engine/cli/commands.py +256 -0
  14. databao_context_engine/cli/datasources.py +64 -0
  15. databao_context_engine/cli/info.py +32 -0
  16. databao_context_engine/config/__init__.py +0 -0
  17. databao_context_engine/config/log_config.yaml +16 -0
  18. databao_context_engine/config/logging.py +43 -0
  19. databao_context_engine/databao_context_project_manager.py +92 -0
  20. databao_context_engine/databao_engine.py +85 -0
  21. databao_context_engine/datasource_config/__init__.py +0 -0
  22. databao_context_engine/datasource_config/add_config.py +50 -0
  23. databao_context_engine/datasource_config/check_config.py +131 -0
  24. databao_context_engine/datasource_config/datasource_context.py +60 -0
  25. databao_context_engine/event_journal/__init__.py +0 -0
  26. databao_context_engine/event_journal/writer.py +29 -0
  27. databao_context_engine/generate_configs_schemas.py +92 -0
  28. databao_context_engine/init_project.py +18 -0
  29. databao_context_engine/introspection/__init__.py +0 -0
  30. databao_context_engine/introspection/property_extract.py +202 -0
  31. databao_context_engine/llm/__init__.py +0 -0
  32. databao_context_engine/llm/config.py +20 -0
  33. databao_context_engine/llm/descriptions/__init__.py +0 -0
  34. databao_context_engine/llm/descriptions/ollama.py +21 -0
  35. databao_context_engine/llm/descriptions/provider.py +10 -0
  36. databao_context_engine/llm/embeddings/__init__.py +0 -0
  37. databao_context_engine/llm/embeddings/ollama.py +37 -0
  38. databao_context_engine/llm/embeddings/provider.py +13 -0
  39. databao_context_engine/llm/errors.py +16 -0
  40. databao_context_engine/llm/factory.py +61 -0
  41. databao_context_engine/llm/install.py +227 -0
  42. databao_context_engine/llm/runtime.py +73 -0
  43. databao_context_engine/llm/service.py +159 -0
  44. databao_context_engine/main.py +19 -0
  45. databao_context_engine/mcp/__init__.py +0 -0
  46. databao_context_engine/mcp/all_results_tool.py +5 -0
  47. databao_context_engine/mcp/mcp_runner.py +16 -0
  48. databao_context_engine/mcp/mcp_server.py +63 -0
  49. databao_context_engine/mcp/retrieve_tool.py +22 -0
  50. databao_context_engine/pluginlib/__init__.py +0 -0
  51. databao_context_engine/pluginlib/build_plugin.py +107 -0
  52. databao_context_engine/pluginlib/config.py +37 -0
  53. databao_context_engine/pluginlib/plugin_utils.py +68 -0
  54. databao_context_engine/plugins/__init__.py +0 -0
  55. databao_context_engine/plugins/athena_db_plugin.py +12 -0
  56. databao_context_engine/plugins/base_db_plugin.py +45 -0
  57. databao_context_engine/plugins/clickhouse_db_plugin.py +15 -0
  58. databao_context_engine/plugins/databases/__init__.py +0 -0
  59. databao_context_engine/plugins/databases/athena_introspector.py +101 -0
  60. databao_context_engine/plugins/databases/base_introspector.py +144 -0
  61. databao_context_engine/plugins/databases/clickhouse_introspector.py +162 -0
  62. databao_context_engine/plugins/databases/database_chunker.py +69 -0
  63. databao_context_engine/plugins/databases/databases_types.py +114 -0
  64. databao_context_engine/plugins/databases/duckdb_introspector.py +325 -0
  65. databao_context_engine/plugins/databases/introspection_model_builder.py +270 -0
  66. databao_context_engine/plugins/databases/introspection_scope.py +74 -0
  67. databao_context_engine/plugins/databases/introspection_scope_matcher.py +103 -0
  68. databao_context_engine/plugins/databases/mssql_introspector.py +433 -0
  69. databao_context_engine/plugins/databases/mysql_introspector.py +338 -0
  70. databao_context_engine/plugins/databases/postgresql_introspector.py +428 -0
  71. databao_context_engine/plugins/databases/snowflake_introspector.py +287 -0
  72. databao_context_engine/plugins/duckdb_db_plugin.py +12 -0
  73. databao_context_engine/plugins/mssql_db_plugin.py +12 -0
  74. databao_context_engine/plugins/mysql_db_plugin.py +12 -0
  75. databao_context_engine/plugins/parquet_plugin.py +32 -0
  76. databao_context_engine/plugins/plugin_loader.py +110 -0
  77. databao_context_engine/plugins/postgresql_db_plugin.py +12 -0
  78. databao_context_engine/plugins/resources/__init__.py +0 -0
  79. databao_context_engine/plugins/resources/parquet_chunker.py +23 -0
  80. databao_context_engine/plugins/resources/parquet_introspector.py +154 -0
  81. databao_context_engine/plugins/snowflake_db_plugin.py +12 -0
  82. databao_context_engine/plugins/unstructured_files_plugin.py +68 -0
  83. databao_context_engine/project/__init__.py +0 -0
  84. databao_context_engine/project/datasource_discovery.py +141 -0
  85. databao_context_engine/project/info.py +44 -0
  86. databao_context_engine/project/init_project.py +102 -0
  87. databao_context_engine/project/layout.py +127 -0
  88. databao_context_engine/project/project_config.py +32 -0
  89. databao_context_engine/project/resources/examples/src/databases/example_postgres.yaml +7 -0
  90. databao_context_engine/project/resources/examples/src/files/documentation.md +30 -0
  91. databao_context_engine/project/resources/examples/src/files/notes.txt +20 -0
  92. databao_context_engine/project/runs.py +39 -0
  93. databao_context_engine/project/types.py +134 -0
  94. databao_context_engine/retrieve_embeddings/__init__.py +0 -0
  95. databao_context_engine/retrieve_embeddings/internal/__init__.py +0 -0
  96. databao_context_engine/retrieve_embeddings/internal/export_results.py +12 -0
  97. databao_context_engine/retrieve_embeddings/internal/retrieve_runner.py +34 -0
  98. databao_context_engine/retrieve_embeddings/internal/retrieve_service.py +68 -0
  99. databao_context_engine/retrieve_embeddings/internal/retrieve_wiring.py +29 -0
  100. databao_context_engine/retrieve_embeddings/public/__init__.py +0 -0
  101. databao_context_engine/retrieve_embeddings/public/api.py +3 -0
  102. databao_context_engine/serialisation/__init__.py +0 -0
  103. databao_context_engine/serialisation/yaml.py +35 -0
  104. databao_context_engine/services/__init__.py +0 -0
  105. databao_context_engine/services/chunk_embedding_service.py +104 -0
  106. databao_context_engine/services/embedding_shard_resolver.py +64 -0
  107. databao_context_engine/services/factories.py +88 -0
  108. databao_context_engine/services/models.py +12 -0
  109. databao_context_engine/services/persistence_service.py +61 -0
  110. databao_context_engine/services/run_name_policy.py +8 -0
  111. databao_context_engine/services/table_name_policy.py +15 -0
  112. databao_context_engine/storage/__init__.py +0 -0
  113. databao_context_engine/storage/connection.py +32 -0
  114. databao_context_engine/storage/exceptions/__init__.py +0 -0
  115. databao_context_engine/storage/exceptions/exceptions.py +6 -0
  116. databao_context_engine/storage/migrate.py +127 -0
  117. databao_context_engine/storage/migrations/V01__init.sql +63 -0
  118. databao_context_engine/storage/models.py +51 -0
  119. databao_context_engine/storage/repositories/__init__.py +0 -0
  120. databao_context_engine/storage/repositories/chunk_repository.py +130 -0
  121. databao_context_engine/storage/repositories/datasource_run_repository.py +136 -0
  122. databao_context_engine/storage/repositories/embedding_model_registry_repository.py +87 -0
  123. databao_context_engine/storage/repositories/embedding_repository.py +113 -0
  124. databao_context_engine/storage/repositories/factories.py +35 -0
  125. databao_context_engine/storage/repositories/run_repository.py +157 -0
  126. databao_context_engine/storage/repositories/vector_search_repository.py +63 -0
  127. databao_context_engine/storage/transaction.py +14 -0
  128. databao_context_engine/system/__init__.py +0 -0
  129. databao_context_engine/system/properties.py +13 -0
  130. databao_context_engine/templating/__init__.py +0 -0
  131. databao_context_engine/templating/renderer.py +29 -0
  132. databao_context_engine-0.1.1.dist-info/METADATA +186 -0
  133. databao_context_engine-0.1.1.dist-info/RECORD +135 -0
  134. databao_context_engine-0.1.1.dist-info/WHEEL +4 -0
  135. databao_context_engine-0.1.1.dist-info/entry_points.txt +4 -0
databao_context_engine/pluginlib/build_plugin.py
@@ -0,0 +1,107 @@
+ from dataclasses import dataclass
+ from io import BufferedReader
+ from typing import Any, Protocol, runtime_checkable
+
+
+ @dataclass
+ class EmbeddableChunk:
+     """
+     A chunk that will be embedded as a vector and used when searching context for a given AI prompt.
+     """
+
+     embeddable_text: str
+     """
+     The text to embed as a vector for search usage.
+     """
+     content: Any
+     """
+     The content to return as a response when the embedding has been selected in a search.
+     """
+
+
+ class BaseBuildPlugin(Protocol):
+     id: str
+     name: str
+
+     def supported_types(self) -> set[str]:
+         """
+         Returns the set of all supported types for this plugin.
+         If the plugin supports multiple types, it should check the type given in the `full_type` argument when `execute` is called.
+         """
+         ...
+
+     def divide_context_into_chunks(self, context: Any) -> list[EmbeddableChunk]:
+         """
+         Divides the data source context into meaningful chunks that will be used when searching the context from an AI prompt.
+         """
+         ...
+
+
+ @runtime_checkable
+ class BuildDatasourcePlugin[T](BaseBuildPlugin, Protocol):
+     config_file_type: type[T]
+
+     def build_context(self, full_type: str, datasource_name: str, file_config: T) -> Any:
+         """
+         Called when a config file has been found for a data source supported by this plugin.
+         """
+         ...
+
+     def check_connection(self, full_type: str, datasource_name: str, file_config: T) -> None:
+         """
+         Checks whether the configured connection to the datasource is working.
+
+         The method is expected to succeed without a result if the connection is working.
+         If something is wrong with the connection, it should raise an exception.
+         """
+         raise NotSupportedError("This method is not implemented for this plugin")
+
+
+ class DefaultBuildDatasourcePlugin(BuildDatasourcePlugin[dict[str, Any]], Protocol):
+     """
+     Use this as a base class for plugins that don't need a specific config file type.
+     """
+
+     config_file_type: type[dict[str, Any]] = dict[str, Any]
+
+
+ @runtime_checkable
+ class BuildFilePlugin(BaseBuildPlugin, Protocol):
+     def build_file_context(self, full_type: str, file_name: str, file_buffer: BufferedReader) -> Any:
+         """
+         Called when a file has been found as a data source supported by this plugin.
+         """
+         ...
+
+
+ class NotSupportedError(RuntimeError):
+     """Exception raised by methods not supported by a plugin"""
+
+
+ BuildPlugin = BuildDatasourcePlugin | BuildFilePlugin
+
+
+ @dataclass(kw_only=True, frozen=True)
+ class DatasourceType:
+     full_type: str
+
+     def __post_init__(self):
+         type_segments = self.full_type.split("/")
+         if len(type_segments) != 2:
+             raise ValueError(f"Invalid DatasourceType: {self.full_type}")
+
+     @property
+     def main_type(self) -> str:
+         return self.full_type.split("/")[0]
+
+     @property
+     def config_folder(self) -> str:
+         return self.main_type
+
+     @property
+     def subtype(self) -> str:
+         return self.full_type.split("/")[1]
+
+     @staticmethod
+     def from_main_and_subtypes(main_type: str, subtype: str) -> "DatasourceType":
+         return DatasourceType(full_type=f"{main_type}/{subtype}")
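To make the protocol contracts above concrete, here is a minimal sketch of a datasource plugin built on these types. The EchoPlugin name, its id, and its echo-style context are hypothetical illustrations, not part of the package; shipped plugins such as AthenaDbPlugin further down return database introspection results instead.

from typing import Any

from databao_context_engine.pluginlib.build_plugin import (
    DatasourceType,
    DefaultBuildDatasourcePlugin,
    EmbeddableChunk,
)


class EchoPlugin(DefaultBuildDatasourcePlugin):
    # Hypothetical id/name; shipped plugins use ids like "jetbrains/athena".
    id = "example/echo"
    name = "Echo Plugin"

    def supported_types(self) -> set[str]:
        return {"examples/echo"}

    def build_context(self, full_type: str, datasource_name: str, file_config: dict[str, Any]) -> Any:
        # The "context" here is just the raw config; real plugins build something useful.
        return {"datasource": datasource_name, "config": file_config}

    def divide_context_into_chunks(self, context: Any) -> list[EmbeddableChunk]:
        # One chunk whose embeddable text doubles as its returned content.
        return [EmbeddableChunk(embeddable_text=str(context), content=context)]


# DatasourceType enforces the "main/sub" shape of type identifiers:
dt = DatasourceType.from_main_and_subtypes("examples", "echo")
assert dt.main_type == "examples" and dt.subtype == "echo"

Since BuildDatasourcePlugin is runtime_checkable, an instance like this also passes the isinstance checks in plugin_utils.py below.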
databao_context_engine/pluginlib/config.py
@@ -0,0 +1,37 @@
+ from dataclasses import dataclass
+ from typing import Any, Protocol, runtime_checkable
+
+ from pydantic import BaseModel, Field
+
+
+ class DuckDBSecret(BaseModel):
+     name: str | None = Field(default=None)
+     type: str = Field(
+         description="DuckDB secret type. Examples: s3, postgres, iceberg, etc. See https://duckdb.org/docs/stable/configuration/secrets_manager#types-of-secrets"
+     )
+     properties: dict[str, Any] = Field(
+         default={},
+         description="Key/Value pairs which will be used to create a duckdb secret. "
+         "See https://duckdb.org/docs/stable/configuration/secrets_manager",
+     )
+
+
+ @dataclass(kw_only=True)
+ class ConfigPropertyDefinition:
+     property_key: str
+     required: bool
+     property_type: type | None = str
+     default_value: str | None = None
+     nested_properties: list["ConfigPropertyDefinition"] | None = None
+
+
+ @dataclass(kw_only=True)
+ class ConfigPropertyAnnotation:
+     required: bool = False
+     default_value: str | None = None
+     ignored_for_config_wizard: bool = False
+
+
+ @runtime_checkable
+ class CustomiseConfigProperties(Protocol):
+     def get_config_file_properties(self) -> list[ConfigPropertyDefinition]: ...
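Judging by the ignored_for_config_wizard flag and cli/add_datasource_config.py in the file list, these definitions appear to back an interactive config wizard. A plugin can advertise wizard fields by matching CustomiseConfigProperties structurally; in this sketch the property names (host, port, user) are hypothetical:

from databao_context_engine.pluginlib.config import ConfigPropertyDefinition


class PostgresLikeConfigProperties:
    # Structural typing: matching the method signature is enough to satisfy
    # the runtime_checkable CustomiseConfigProperties protocol.
    def get_config_file_properties(self) -> list[ConfigPropertyDefinition]:
        return [
            ConfigPropertyDefinition(property_key="host", required=True),
            ConfigPropertyDefinition(property_key="port", required=False, default_value="5432"),
            ConfigPropertyDefinition(
                property_key="connection",
                required=True,
                property_type=None,  # container property: only the nested keys carry values
                nested_properties=[
                    ConfigPropertyDefinition(property_key="user", required=True),
                ],
            ),
        ]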
databao_context_engine/pluginlib/plugin_utils.py
@@ -0,0 +1,68 @@
+ import json
+ import logging
+ import os
+ from pathlib import Path
+ from typing import Any, Mapping
+
+ from pydantic import TypeAdapter
+
+ from databao_context_engine.pluginlib.build_plugin import BuildDatasourcePlugin, BuildFilePlugin, DatasourceType
+
+ logger = logging.getLogger(__name__)
+
+
+ def execute_datasource_plugin(
+     plugin: BuildDatasourcePlugin, datasource_type: DatasourceType, config: Mapping[str, Any], datasource_name: str
+ ) -> Any:
+     if not isinstance(plugin, BuildDatasourcePlugin):
+         raise ValueError("This method can only execute a BuildDatasourcePlugin")
+
+     validated_config = _validate_datasource_config_file(config, plugin)
+
+     return plugin.build_context(
+         full_type=datasource_type.full_type,
+         datasource_name=datasource_name,
+         file_config=validated_config,
+     )
+
+
+ def check_connection_for_datasource(
+     plugin: BuildDatasourcePlugin, datasource_type: DatasourceType, config: Mapping[str, Any], datasource_name: str
+ ) -> None:
+     if not isinstance(plugin, BuildDatasourcePlugin):
+         raise ValueError("Connection checks can only be performed on BuildDatasourcePlugin")
+
+     validated_config = _validate_datasource_config_file(config, plugin)
+
+     plugin.check_connection(
+         full_type=datasource_type.full_type,
+         datasource_name=datasource_name,
+         file_config=validated_config,
+     )
+
+
+ def _validate_datasource_config_file(config: Mapping[str, Any], plugin: BuildDatasourcePlugin) -> Any:
+     return TypeAdapter(plugin.config_file_type).validate_python(config)
+
+
+ def execute_file_plugin(plugin: BuildFilePlugin, datasource_type: DatasourceType, file_path: Path) -> Any:
+     with file_path.open("rb") as fh:
+         return plugin.build_file_context(
+             full_type=datasource_type.full_type,
+             file_name=file_path.name,
+             file_buffer=fh,
+         )
+
+
+ def generate_json_schema(plugin: BuildDatasourcePlugin, pretty_print: bool = True) -> str | None:
+     if plugin.config_file_type == dict[str, Any]:
+         logger.debug(f"Skipping json schema generation for plugin {plugin.id}: no custom config_file_type provided")
+         return None
+
+     json_schema = TypeAdapter(plugin.config_file_type).json_schema(mode="serialization")
+
+     return json.dumps(json_schema, indent=4 if pretty_print else None)
+
+
+ def format_json_schema_for_output(plugin: BuildDatasourcePlugin, json_schema: str) -> str:
+     return os.linesep.join([f"JSON Schema for plugin {plugin.id}:", json_schema])
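Putting these helpers together, a caller would validate a raw config against the plugin's config_file_type and run it roughly like this; EchoPlugin is the hypothetical plugin sketched after build_plugin.py above, and the config payload is made up:

from databao_context_engine.pluginlib.build_plugin import DatasourceType
from databao_context_engine.pluginlib.plugin_utils import (
    execute_datasource_plugin,
    generate_json_schema,
)

plugin = EchoPlugin()
dt = DatasourceType(full_type="examples/echo")

# The config mapping is validated with pydantic's TypeAdapter against
# plugin.config_file_type before build_context is invoked.
context = execute_datasource_plugin(plugin, dt, {"greeting": "hi"}, datasource_name="demo")

# EchoPlugin keeps the default dict[str, Any] config type, so no custom
# JSON schema is generated and the helper returns None.
assert generate_json_schema(plugin) is None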
File without changes: databao_context_engine/plugins/__init__.py
databao_context_engine/plugins/athena_db_plugin.py
@@ -0,0 +1,12 @@
+ from databao_context_engine.plugins.base_db_plugin import BaseDatabasePlugin
+ from databao_context_engine.plugins.databases.athena_introspector import AthenaConfigFile, AthenaIntrospector
+
+
+ class AthenaDbPlugin(BaseDatabasePlugin[AthenaConfigFile]):
+     id = "jetbrains/athena"
+     name = "Athena DB Plugin"
+     supported = {"databases/athena"}
+     config_file_type = AthenaConfigFile
+
+     def __init__(self):
+         super().__init__(AthenaIntrospector())
databao_context_engine/plugins/base_db_plugin.py
@@ -0,0 +1,45 @@
+ from __future__ import annotations
+
+ from typing import Any, TypeVar
+
+ from pydantic import BaseModel, ConfigDict, Field
+
+ from databao_context_engine.pluginlib.build_plugin import (
+     BuildDatasourcePlugin,
+     EmbeddableChunk,
+ )
+ from databao_context_engine.plugins.databases.base_introspector import BaseIntrospector
+ from databao_context_engine.plugins.databases.database_chunker import build_database_chunks
+ from databao_context_engine.plugins.databases.introspection_scope import IntrospectionScope
+
+
+ class BaseDatabaseConfigFile(BaseModel):
+     model_config = ConfigDict(populate_by_name=True)
+     name: str | None = Field(default=None)
+     type: str
+     introspection_scope: IntrospectionScope | None = Field(default=None, alias="introspection-scope")
+
+
+ T = TypeVar("T", bound=BaseDatabaseConfigFile)
+
+
+ class BaseDatabasePlugin(BuildDatasourcePlugin[T]):
+     name: str
+     supported: set[str]
+
+     def __init__(self, introspector: BaseIntrospector):
+         self._introspector = introspector
+
+     def supported_types(self) -> set[str]:
+         return self.supported
+
+     def build_context(self, full_type: str, datasource_name: str, file_config: T) -> Any:
+         introspection_result = self._introspector.introspect_database(file_config)
+
+         return introspection_result
+
+     def check_connection(self, full_type: str, datasource_name: str, file_config: T) -> None:
+         self._introspector.check_connection(file_config)
+
+     def divide_context_into_chunks(self, context: Any) -> list[EmbeddableChunk]:
+         return build_database_chunks(context)
databao_context_engine/plugins/clickhouse_db_plugin.py
@@ -0,0 +1,15 @@
+ from databao_context_engine.plugins.base_db_plugin import BaseDatabasePlugin
+ from databao_context_engine.plugins.databases.clickhouse_introspector import (
+     ClickhouseConfigFile,
+     ClickhouseIntrospector,
+ )
+
+
+ class ClickhouseDbPlugin(BaseDatabasePlugin[ClickhouseConfigFile]):
+     id = "jetbrains/clickhouse"
+     name = "Clickhouse DB Plugin"
+     supported = {"databases/clickhouse"}
+     config_file_type = ClickhouseConfigFile
+
+     def __init__(self):
+         super().__init__(ClickhouseIntrospector())
File without changes: databao_context_engine/plugins/databases/__init__.py
databao_context_engine/plugins/databases/athena_introspector.py
@@ -0,0 +1,101 @@
+ from __future__ import annotations
+
+ from typing import Any, Mapping
+
+ from pyathena import connect
+ from pyathena.cursor import DictCursor
+ from pydantic import Field
+
+ from databao_context_engine.plugins.base_db_plugin import BaseDatabaseConfigFile
+ from databao_context_engine.plugins.databases.base_introspector import BaseIntrospector, SQLQuery
+ from databao_context_engine.plugins.databases.databases_types import DatabaseSchema
+ from databao_context_engine.plugins.databases.introspection_model_builder import IntrospectionModelBuilder
+
+
+ class AthenaConfigFile(BaseDatabaseConfigFile):
+     type: str = Field(default="databases/athena")
+     connection: dict[str, Any] = Field(
+         description="Connection parameters for the Athena database. It can contain any of the keys supported by the Athena connection library"
+     )
+
+
+ class AthenaIntrospector(BaseIntrospector[AthenaConfigFile]):
+     _IGNORED_SCHEMAS = {
+         "information_schema",
+     }
+     supports_catalogs = True
+
+     def _connect(self, file_config: AthenaConfigFile):
+         connection = file_config.connection
+         if not isinstance(connection, Mapping):
+             raise ValueError("Invalid YAML config: 'connection' must be a mapping of connection parameters")
+
+         return connect(**connection, cursor_class=DictCursor)
+
+     def _fetchall_dicts(self, connection, sql: str, params) -> list[dict]:
+         with connection.cursor() as cur:
+             cur.execute(sql, params or {})
+             return cur.fetchall()
+
+     def _get_catalogs(self, connection, file_config: AthenaConfigFile) -> list[str]:
+         catalog = file_config.connection.get("catalog", self._resolve_pseudo_catalog_name(file_config))
+         return [catalog]
+
+     def _connect_to_catalog(self, file_config: AthenaConfigFile, catalog: str):
+         # Must return the connection: callers use it as a context manager.
+         return self._connect(file_config)
+
+     def _sql_list_schemas(self, catalogs: list[str] | None) -> SQLQuery:
+         if not catalogs:
+             return SQLQuery("SELECT schema_name, catalog_name FROM information_schema.schemata", None)
+         catalog = catalogs[0]
+         sql = "SELECT schema_name, catalog_name FROM information_schema.schemata WHERE catalog_name = %(catalog)s"
+         return SQLQuery(sql, {"catalog": catalog})
+
+     # TODO: Incomplete plugin. Awaiting permission access to AWS to properly develop
+     def collect_catalog_model(self, connection, catalog: str, schemas: list[str]) -> list[DatabaseSchema] | None:
+         if not schemas:
+             return []
+
+         comps = {"columns": self._sql_columns(catalog, schemas)}
+         results: dict[str, list[dict]] = {}
+
+         for name, q in comps.items():
+             results[name] = self._fetchall_dicts(connection, q.sql, q.params)
+
+         return IntrospectionModelBuilder.build_schemas_from_components(
+             schemas=schemas,
+             rels=results.get("relations", []),
+             cols=results.get("columns", []),
+             pk_cols=[],
+             uq_cols=[],
+             checks=[],
+             fk_cols=[],
+             idx_cols=[],
+         )
+
+     def _sql_columns(self, catalog: str, schemas: list[str]) -> SQLQuery:
+         # Bind each schema name as a named parameter; interpolating the Python
+         # list directly into the IN (...) clause would produce invalid SQL.
+         placeholders = ", ".join(f"%(schema_{i})s" for i in range(len(schemas)))
+         sql = f"""
+             SELECT
+                 table_schema AS schema_name,
+                 table_name,
+                 column_name,
+                 ordinal_position,
+                 data_type,
+                 is_nullable
+             FROM
+                 {catalog}.information_schema.columns
+             WHERE
+                 table_schema IN ({placeholders})
+             ORDER BY
+                 table_schema,
+                 table_name,
+                 ordinal_position
+         """
+         return SQLQuery(sql, {f"schema_{i}": schema for i, schema in enumerate(schemas)})
+
+     def _resolve_pseudo_catalog_name(self, file_config: AthenaConfigFile) -> str:
+         return "awsdatacatalog"
+
+     def _sql_sample_rows(self, catalog: str, schema: str, table: str, limit: int) -> SQLQuery:
+         sql = f'SELECT * FROM "{schema}"."{table}" LIMIT %(limit)s'
+         return SQLQuery(sql, {"limit": limit})
databao_context_engine/plugins/databases/base_introspector.py
@@ -0,0 +1,144 @@
+ from __future__ import annotations
+
+ import logging
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from typing import Any, Mapping, Protocol, Sequence, Union
+
+ from databao_context_engine.plugins.databases.databases_types import (
+     DatabaseCatalog,
+     DatabaseIntrospectionResult,
+     DatabaseSchema,
+ )
+ from databao_context_engine.plugins.databases.introspection_scope import IntrospectionScope
+ from databao_context_engine.plugins.databases.introspection_scope_matcher import IntrospectionScopeMatcher
+
+ logger = logging.getLogger(__name__)
+
+ ParamsType = Union[Mapping[str, Any], Sequence[Any], None]
+
+
+ class SupportsIntrospectionScope(Protocol):
+     introspection_scope: IntrospectionScope | None
+
+
+ class BaseIntrospector[T: SupportsIntrospectionScope](ABC):
+     supports_catalogs: bool = True
+     _IGNORED_SCHEMAS: set[str] = {"information_schema"}
+     _SAMPLE_LIMIT: int = 5
+
+     def check_connection(self, file_config: T) -> None:
+         with self._connect(file_config) as connection:
+             self._fetchall_dicts(connection, "SELECT 1 as test", None)
+
+     def introspect_database(self, file_config: T) -> DatabaseIntrospectionResult:
+         scope_matcher = IntrospectionScopeMatcher(
+             file_config.introspection_scope,
+             ignored_schemas=self._ignored_schemas(),
+         )
+
+         with self._connect(file_config) as root_connection:
+             catalogs = self._get_catalogs_adapted(root_connection, file_config)
+
+         discovered_schemas_per_catalog: dict[str, list[str]] = {}
+         for catalog in catalogs:
+             with self._connect_to_catalog(file_config, catalog) as conn:
+                 discovered_schemas_per_catalog[catalog] = self._list_schemas_for_catalog(conn, catalog)
+         scope = scope_matcher.filter_scopes(catalogs, discovered_schemas_per_catalog)
+
+         introspected_catalogs: list[DatabaseCatalog] = []
+         for catalog in scope.catalogs:
+             schemas_to_introspect = scope.schemas_per_catalog.get(catalog, [])
+             if not schemas_to_introspect:
+                 continue
+
+             with self._connect_to_catalog(file_config, catalog) as catalog_connection:
+                 introspected_schemas = self.collect_catalog_model(catalog_connection, catalog, schemas_to_introspect)
+
+                 if not introspected_schemas:
+                     continue
+
+                 for schema in introspected_schemas:
+                     for table in schema.tables:
+                         table.samples = self._collect_samples_for_table(
+                             catalog_connection, catalog, schema.name, table.name
+                         )
+
+             introspected_catalogs.append(DatabaseCatalog(name=catalog, schemas=introspected_schemas))
+         return DatabaseIntrospectionResult(catalogs=introspected_catalogs)
+
+     def _get_catalogs_adapted(self, connection, file_config: T) -> list[str]:
+         if self.supports_catalogs:
+             return self._get_catalogs(connection, file_config)
+         return [self._resolve_pseudo_catalog_name(file_config)]
+
+     def _sql_list_schemas(self, catalogs: list[str] | None) -> SQLQuery:
+         if self.supports_catalogs:
+             sql = "SELECT catalog_name, schema_name FROM information_schema.schemata WHERE catalog_name = ANY(%s)"
+             return SQLQuery(sql, (catalogs,))
+         else:
+             sql = "SELECT schema_name FROM information_schema.schemata"
+             return SQLQuery(sql, None)
+
+     def _list_schemas_for_catalog(self, connection: Any, catalog: str) -> list[str]:
+         sql_query = self._sql_list_schemas([catalog] if self.supports_catalogs else None)
+         rows = self._fetchall_dicts(connection, sql_query.sql, sql_query.params)
+
+         schemas: list[str] = []
+         for row in rows:
+             schema_name = row.get("schema_name")
+             if schema_name:
+                 schemas.append(schema_name)
+
+         return schemas
+
+     @abstractmethod
+     def collect_catalog_model(self, connection, catalog: str, schemas: list[str]) -> list[DatabaseSchema] | None:
+         raise NotImplementedError
+
+     def _collect_samples_for_table(self, connection, catalog: str, schema: str, table: str) -> list[dict[str, Any]]:
+         samples: list[dict[str, Any]] = []
+         if self._SAMPLE_LIMIT > 0:
+             try:
+                 sql_query = self._sql_sample_rows(catalog, schema, table, self._SAMPLE_LIMIT)
+                 samples = self._fetchall_dicts(connection, sql_query.sql, sql_query.params)
+             except NotImplementedError:
+                 samples = []
+             except Exception as e:
+                 logger.warning("Failed to fetch samples for %s.%s (catalog=%s): %s", schema, table, catalog, e)
+                 samples = []
+         return samples
+
+     @abstractmethod
+     def _connect(self, file_config: T):
+         raise NotImplementedError
+
+     @abstractmethod
+     def _fetchall_dicts(self, connection, sql: str, params) -> list[dict]:
+         raise NotImplementedError
+
+     @abstractmethod
+     def _get_catalogs(self, connection, file_config: T) -> list[str]:
+         raise NotImplementedError
+
+     @abstractmethod
+     def _connect_to_catalog(self, file_config: T, catalog: str):
+         """Return a connection scoped to `catalog`. For engines that
+         don’t need a new connection, return a connection with the
+         session set/USE’d to that catalog."""
+
+     def _sql_sample_rows(self, catalog: str, schema: str, table: str, limit: int) -> SQLQuery:
+         raise NotImplementedError
+
+     def _resolve_pseudo_catalog_name(self, file_config: T) -> str:
+         return "default"
+
+     def _ignored_schemas(self) -> set[str]:
+         return self._IGNORED_SCHEMAS
+
+
+ @dataclass
+ class SQLQuery:
+     sql: str
+     params: ParamsType = None
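BaseIntrospector is a template method: introspect_database drives catalog discovery, scope filtering, model collection and row sampling, while subclasses supply the connection and SQL primitives. As an illustration of which hooks a new engine must fill in, here is a hypothetical SQLite introspector; SQLite is not among the shipped introspectors, and the databases/sqlite type and SqliteConfigFile are made up for this sketch:

import sqlite3

from databao_context_engine.plugins.base_db_plugin import BaseDatabaseConfigFile
from databao_context_engine.plugins.databases.base_introspector import BaseIntrospector, SQLQuery
from databao_context_engine.plugins.databases.databases_types import DatabaseSchema


class SqliteConfigFile(BaseDatabaseConfigFile):
    type: str = "databases/sqlite"  # hypothetical type identifier
    path: str


class SqliteIntrospector(BaseIntrospector[SqliteConfigFile]):
    supports_catalogs = False  # SQLite has no catalogs; a pseudo-catalog name is used

    def _connect(self, file_config: SqliteConfigFile):
        connection = sqlite3.connect(file_config.path)
        connection.row_factory = sqlite3.Row
        return connection

    def _connect_to_catalog(self, file_config: SqliteConfigFile, catalog: str):
        # No per-catalog sessions in SQLite; reuse a plain connection.
        return self._connect(file_config)

    def _get_catalogs(self, connection, file_config: SqliteConfigFile) -> list[str]:
        return [self._resolve_pseudo_catalog_name(file_config)]

    def _sql_list_schemas(self, catalogs: list[str] | None) -> SQLQuery:
        # SQLite has no information_schema; "main" is the only schema.
        return SQLQuery("SELECT 'main' AS schema_name", None)

    def _fetchall_dicts(self, connection, sql: str, params) -> list[dict]:
        cursor = connection.execute(sql, params or ())
        return [dict(row) for row in cursor.fetchall()]

    def _sql_sample_rows(self, catalog: str, schema: str, table: str, limit: int) -> SQLQuery:
        return SQLQuery(f'SELECT * FROM "{table}" LIMIT ?', (limit,))

    def collect_catalog_model(self, connection, catalog: str, schemas: list[str]) -> list[DatabaseSchema] | None:
        # A real introspector would query sqlite_master / PRAGMA table_info and
        # assemble DatabaseSchema objects (e.g. via IntrospectionModelBuilder); elided here.
        return []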