databao-context-engine 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- databao_context_engine/__init__.py +35 -0
- databao_context_engine/build_sources/__init__.py +0 -0
- databao_context_engine/build_sources/internal/__init__.py +0 -0
- databao_context_engine/build_sources/internal/build_runner.py +111 -0
- databao_context_engine/build_sources/internal/build_service.py +77 -0
- databao_context_engine/build_sources/internal/build_wiring.py +52 -0
- databao_context_engine/build_sources/internal/export_results.py +43 -0
- databao_context_engine/build_sources/internal/plugin_execution.py +74 -0
- databao_context_engine/build_sources/public/__init__.py +0 -0
- databao_context_engine/build_sources/public/api.py +4 -0
- databao_context_engine/cli/__init__.py +0 -0
- databao_context_engine/cli/add_datasource_config.py +130 -0
- databao_context_engine/cli/commands.py +256 -0
- databao_context_engine/cli/datasources.py +64 -0
- databao_context_engine/cli/info.py +32 -0
- databao_context_engine/config/__init__.py +0 -0
- databao_context_engine/config/log_config.yaml +16 -0
- databao_context_engine/config/logging.py +43 -0
- databao_context_engine/databao_context_project_manager.py +92 -0
- databao_context_engine/databao_engine.py +85 -0
- databao_context_engine/datasource_config/__init__.py +0 -0
- databao_context_engine/datasource_config/add_config.py +50 -0
- databao_context_engine/datasource_config/check_config.py +131 -0
- databao_context_engine/datasource_config/datasource_context.py +60 -0
- databao_context_engine/event_journal/__init__.py +0 -0
- databao_context_engine/event_journal/writer.py +29 -0
- databao_context_engine/generate_configs_schemas.py +92 -0
- databao_context_engine/init_project.py +18 -0
- databao_context_engine/introspection/__init__.py +0 -0
- databao_context_engine/introspection/property_extract.py +202 -0
- databao_context_engine/llm/__init__.py +0 -0
- databao_context_engine/llm/config.py +20 -0
- databao_context_engine/llm/descriptions/__init__.py +0 -0
- databao_context_engine/llm/descriptions/ollama.py +21 -0
- databao_context_engine/llm/descriptions/provider.py +10 -0
- databao_context_engine/llm/embeddings/__init__.py +0 -0
- databao_context_engine/llm/embeddings/ollama.py +37 -0
- databao_context_engine/llm/embeddings/provider.py +13 -0
- databao_context_engine/llm/errors.py +16 -0
- databao_context_engine/llm/factory.py +61 -0
- databao_context_engine/llm/install.py +227 -0
- databao_context_engine/llm/runtime.py +73 -0
- databao_context_engine/llm/service.py +159 -0
- databao_context_engine/main.py +19 -0
- databao_context_engine/mcp/__init__.py +0 -0
- databao_context_engine/mcp/all_results_tool.py +5 -0
- databao_context_engine/mcp/mcp_runner.py +16 -0
- databao_context_engine/mcp/mcp_server.py +63 -0
- databao_context_engine/mcp/retrieve_tool.py +22 -0
- databao_context_engine/pluginlib/__init__.py +0 -0
- databao_context_engine/pluginlib/build_plugin.py +107 -0
- databao_context_engine/pluginlib/config.py +37 -0
- databao_context_engine/pluginlib/plugin_utils.py +68 -0
- databao_context_engine/plugins/__init__.py +0 -0
- databao_context_engine/plugins/athena_db_plugin.py +12 -0
- databao_context_engine/plugins/base_db_plugin.py +45 -0
- databao_context_engine/plugins/clickhouse_db_plugin.py +15 -0
- databao_context_engine/plugins/databases/__init__.py +0 -0
- databao_context_engine/plugins/databases/athena_introspector.py +101 -0
- databao_context_engine/plugins/databases/base_introspector.py +144 -0
- databao_context_engine/plugins/databases/clickhouse_introspector.py +162 -0
- databao_context_engine/plugins/databases/database_chunker.py +69 -0
- databao_context_engine/plugins/databases/databases_types.py +114 -0
- databao_context_engine/plugins/databases/duckdb_introspector.py +325 -0
- databao_context_engine/plugins/databases/introspection_model_builder.py +270 -0
- databao_context_engine/plugins/databases/introspection_scope.py +74 -0
- databao_context_engine/plugins/databases/introspection_scope_matcher.py +103 -0
- databao_context_engine/plugins/databases/mssql_introspector.py +433 -0
- databao_context_engine/plugins/databases/mysql_introspector.py +338 -0
- databao_context_engine/plugins/databases/postgresql_introspector.py +428 -0
- databao_context_engine/plugins/databases/snowflake_introspector.py +287 -0
- databao_context_engine/plugins/duckdb_db_plugin.py +12 -0
- databao_context_engine/plugins/mssql_db_plugin.py +12 -0
- databao_context_engine/plugins/mysql_db_plugin.py +12 -0
- databao_context_engine/plugins/parquet_plugin.py +32 -0
- databao_context_engine/plugins/plugin_loader.py +110 -0
- databao_context_engine/plugins/postgresql_db_plugin.py +12 -0
- databao_context_engine/plugins/resources/__init__.py +0 -0
- databao_context_engine/plugins/resources/parquet_chunker.py +23 -0
- databao_context_engine/plugins/resources/parquet_introspector.py +154 -0
- databao_context_engine/plugins/snowflake_db_plugin.py +12 -0
- databao_context_engine/plugins/unstructured_files_plugin.py +68 -0
- databao_context_engine/project/__init__.py +0 -0
- databao_context_engine/project/datasource_discovery.py +141 -0
- databao_context_engine/project/info.py +44 -0
- databao_context_engine/project/init_project.py +102 -0
- databao_context_engine/project/layout.py +127 -0
- databao_context_engine/project/project_config.py +32 -0
- databao_context_engine/project/resources/examples/src/databases/example_postgres.yaml +7 -0
- databao_context_engine/project/resources/examples/src/files/documentation.md +30 -0
- databao_context_engine/project/resources/examples/src/files/notes.txt +20 -0
- databao_context_engine/project/runs.py +39 -0
- databao_context_engine/project/types.py +134 -0
- databao_context_engine/retrieve_embeddings/__init__.py +0 -0
- databao_context_engine/retrieve_embeddings/internal/__init__.py +0 -0
- databao_context_engine/retrieve_embeddings/internal/export_results.py +12 -0
- databao_context_engine/retrieve_embeddings/internal/retrieve_runner.py +34 -0
- databao_context_engine/retrieve_embeddings/internal/retrieve_service.py +68 -0
- databao_context_engine/retrieve_embeddings/internal/retrieve_wiring.py +29 -0
- databao_context_engine/retrieve_embeddings/public/__init__.py +0 -0
- databao_context_engine/retrieve_embeddings/public/api.py +3 -0
- databao_context_engine/serialisation/__init__.py +0 -0
- databao_context_engine/serialisation/yaml.py +35 -0
- databao_context_engine/services/__init__.py +0 -0
- databao_context_engine/services/chunk_embedding_service.py +104 -0
- databao_context_engine/services/embedding_shard_resolver.py +64 -0
- databao_context_engine/services/factories.py +88 -0
- databao_context_engine/services/models.py +12 -0
- databao_context_engine/services/persistence_service.py +61 -0
- databao_context_engine/services/run_name_policy.py +8 -0
- databao_context_engine/services/table_name_policy.py +15 -0
- databao_context_engine/storage/__init__.py +0 -0
- databao_context_engine/storage/connection.py +32 -0
- databao_context_engine/storage/exceptions/__init__.py +0 -0
- databao_context_engine/storage/exceptions/exceptions.py +6 -0
- databao_context_engine/storage/migrate.py +127 -0
- databao_context_engine/storage/migrations/V01__init.sql +63 -0
- databao_context_engine/storage/models.py +51 -0
- databao_context_engine/storage/repositories/__init__.py +0 -0
- databao_context_engine/storage/repositories/chunk_repository.py +130 -0
- databao_context_engine/storage/repositories/datasource_run_repository.py +136 -0
- databao_context_engine/storage/repositories/embedding_model_registry_repository.py +87 -0
- databao_context_engine/storage/repositories/embedding_repository.py +113 -0
- databao_context_engine/storage/repositories/factories.py +35 -0
- databao_context_engine/storage/repositories/run_repository.py +157 -0
- databao_context_engine/storage/repositories/vector_search_repository.py +63 -0
- databao_context_engine/storage/transaction.py +14 -0
- databao_context_engine/system/__init__.py +0 -0
- databao_context_engine/system/properties.py +13 -0
- databao_context_engine/templating/__init__.py +0 -0
- databao_context_engine/templating/renderer.py +29 -0
- databao_context_engine-0.1.1.dist-info/METADATA +186 -0
- databao_context_engine-0.1.1.dist-info/RECORD +135 -0
- databao_context_engine-0.1.1.dist-info/WHEEL +4 -0
- databao_context_engine-0.1.1.dist-info/entry_points.txt +4 -0
databao_context_engine/pluginlib/build_plugin.py
@@ -0,0 +1,107 @@
```python
from dataclasses import dataclass
from io import BufferedReader
from typing import Any, Protocol, runtime_checkable


@dataclass
class EmbeddableChunk:
    """
    A chunk that will be embedded as a vector and used when searching context from a given AI prompt.
    """

    embeddable_text: str
    """
    The text to embed as a vector for search usage.
    """
    content: Any
    """
    The content to return as a response when the embedding has been selected in a search.
    """


class BaseBuildPlugin(Protocol):
    id: str
    name: str

    def supported_types(self) -> set[str]:
        """
        Returns the set of all supported types for this plugin.

        If the plugin supports multiple types, it should check the type given in the
        `full_type` argument when `execute` is called.
        """

    def divide_context_into_chunks(self, context: Any) -> list[EmbeddableChunk]:
        """
        Divides the data source context into meaningful chunks that will be used when
        searching the context from an AI prompt.
        """


@runtime_checkable
class BuildDatasourcePlugin[T](BaseBuildPlugin, Protocol):
    config_file_type: type[T]

    def build_context(self, full_type: str, datasource_name: str, file_config: T) -> Any:
        """
        Called when a config file has been found for a data source supported by this plugin.
        """

    def check_connection(self, full_type: str, datasource_name: str, file_config: T) -> None:
        """
        Checks whether the configuration for the datasource is working.

        The function is expected to succeed without a result if the connection is working.
        If something is wrong with the connection, the function should raise an exception.
        """
        raise NotSupportedError("This method is not implemented for this plugin")


class DefaultBuildDatasourcePlugin(BuildDatasourcePlugin[dict[str, Any]], Protocol):
    """
    Use this as a base class for plugins that don't need a specific config file type.
    """

    config_file_type: type[dict[str, Any]] = dict[str, Any]


@runtime_checkable
class BuildFilePlugin(BaseBuildPlugin, Protocol):
    def build_file_context(self, full_type: str, file_name: str, file_buffer: BufferedReader) -> Any:
        """
        Called when a file has been found as a data source supported by this plugin.
        """


class NotSupportedError(RuntimeError):
    """Exception raised by methods not supported by a plugin"""


BuildPlugin = BuildDatasourcePlugin | BuildFilePlugin


@dataclass(kw_only=True, frozen=True)
class DatasourceType:
    full_type: str

    def __post_init__(self):
        type_segments = self.full_type.split("/")
        if len(type_segments) != 2:
            raise ValueError(f"Invalid DatasourceType: {self.full_type}")

    @property
    def main_type(self) -> str:
        return self.full_type.split("/")[0]

    @property
    def config_folder(self) -> str:
        return self.main_type

    @property
    def subtype(self) -> str:
        return self.full_type.split("/")[1]

    @staticmethod
    def from_main_and_subtypes(main_type: str, subtype: str) -> "DatasourceType":
        return DatasourceType(full_type=f"{main_type}/{subtype}")
```
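These protocols define the whole plugin surface: every plugin exposes `id`, `name`, `supported_types()` and a chunker, and file plugins add `build_file_context`. As a rough illustration, a minimal file plugin could look like the sketch below (the class name, plugin id, and paragraph-based chunking are invented for this example, not part of the package):

```python
from io import BufferedReader
from typing import Any

from databao_context_engine.pluginlib.build_plugin import BuildFilePlugin, EmbeddableChunk


class PlainTextFilePlugin:
    # Hypothetical plugin: id, name and supported type are illustrative
    id = "example/plain-text"
    name = "Plain Text File Plugin"

    def supported_types(self) -> set[str]:
        return {"files/txt"}

    def build_file_context(self, full_type: str, file_name: str, file_buffer: BufferedReader) -> Any:
        # The context can be any object; here, the decoded file text
        return {"file_name": file_name, "text": file_buffer.read().decode("utf-8")}

    def divide_context_into_chunks(self, context: Any) -> list[EmbeddableChunk]:
        # One chunk per paragraph; the raw paragraph doubles as the returned content
        return [
            EmbeddableChunk(embeddable_text=paragraph, content=paragraph)
            for paragraph in context["text"].split("\n\n")
            if paragraph.strip()
        ]


# BuildFilePlugin is runtime_checkable, so structural conformance can be asserted
assert isinstance(PlainTextFilePlugin(), BuildFilePlugin)
```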
databao_context_engine/pluginlib/config.py
@@ -0,0 +1,37 @@
```python
from dataclasses import dataclass
from typing import Any, Protocol, runtime_checkable

from pydantic import BaseModel, Field


class DuckDBSecret(BaseModel):
    name: str | None = Field(default=None)
    type: str = Field(
        description="DuckDB secret type. Examples: s3, postgres, iceberg, etc. See https://duckdb.org/docs/stable/configuration/secrets_manager#types-of-secrets"
    )
    properties: dict[str, Any] = Field(
        default={},
        description="Key/Value pairs which will be used to create a duckdb secret. "
        "See https://duckdb.org/docs/stable/configuration/secrets_manager",
    )


@dataclass(kw_only=True)
class ConfigPropertyDefinition:
    property_key: str
    required: bool
    property_type: type | None = str
    default_value: str | None = None
    nested_properties: list["ConfigPropertyDefinition"] | None = None


@dataclass(kw_only=True)
class ConfigPropertyAnnotation:
    required: bool = False
    default_value: str | None = None
    ignored_for_config_wizard: bool = False


@runtime_checkable
class CustomiseConfigProperties(Protocol):
    def get_config_file_properties(self) -> list[ConfigPropertyDefinition]: ...
```
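`CustomiseConfigProperties` gives a plugin a way to describe its config file to tooling such as the config wizard hinted at by `ConfigPropertyAnnotation.ignored_for_config_wizard`. A sketch of a conforming implementation (the property keys, types, and defaults below are invented for illustration):

```python
from databao_context_engine.pluginlib.config import ConfigPropertyDefinition, CustomiseConfigProperties


class ExampleDbConfigProperties:
    # Hypothetical: describes a config file with a nested "connection" mapping
    def get_config_file_properties(self) -> list[ConfigPropertyDefinition]:
        return [
            ConfigPropertyDefinition(property_key="type", required=True, default_value="databases/example"),
            ConfigPropertyDefinition(
                property_key="connection",
                required=True,
                property_type=dict,
                nested_properties=[
                    ConfigPropertyDefinition(property_key="host", required=True),
                    ConfigPropertyDefinition(property_key="port", required=False, default_value="5432"),
                ],
            ),
        ]


assert isinstance(ExampleDbConfigProperties(), CustomiseConfigProperties)
```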
databao_context_engine/pluginlib/plugin_utils.py
@@ -0,0 +1,68 @@
```python
import json
import logging
import os
from pathlib import Path
from typing import Any, Mapping

from pydantic import TypeAdapter

from databao_context_engine.pluginlib.build_plugin import BuildDatasourcePlugin, BuildFilePlugin, DatasourceType

logger = logging.getLogger(__name__)


def execute_datasource_plugin(
    plugin: BuildDatasourcePlugin, datasource_type: DatasourceType, config: Mapping[str, Any], datasource_name: str
) -> Any:
    if not isinstance(plugin, BuildDatasourcePlugin):
        raise ValueError("This method can only execute a BuildDatasourcePlugin")

    validated_config = _validate_datasource_config_file(config, plugin)

    return plugin.build_context(
        full_type=datasource_type.full_type,
        datasource_name=datasource_name,
        file_config=validated_config,
    )


def check_connection_for_datasource(
    plugin: BuildDatasourcePlugin, datasource_type: DatasourceType, config: Mapping[str, Any], datasource_name: str
) -> None:
    if not isinstance(plugin, BuildDatasourcePlugin):
        raise ValueError("Connection checks can only be performed on BuildDatasourcePlugin")

    validated_config = _validate_datasource_config_file(config, plugin)

    plugin.check_connection(
        full_type=datasource_type.full_type,
        datasource_name=datasource_name,
        file_config=validated_config,
    )


def _validate_datasource_config_file(config: Mapping[str, Any], plugin: BuildDatasourcePlugin) -> Any:
    return TypeAdapter(plugin.config_file_type).validate_python(config)


def execute_file_plugin(plugin: BuildFilePlugin, datasource_type: DatasourceType, file_path: Path) -> Any:
    with file_path.open("rb") as fh:
        return plugin.build_file_context(
            full_type=datasource_type.full_type,
            file_name=file_path.name,
            file_buffer=fh,
        )


def generate_json_schema(plugin: BuildDatasourcePlugin, pretty_print: bool = True) -> str | None:
    if plugin.config_file_type == dict[str, Any]:
        logger.debug(f"Skipping json schema generation for plugin {plugin.id}: no custom config_file_type provided")
        return None

    json_schema = TypeAdapter(plugin.config_file_type).json_schema(mode="serialization")

    return json.dumps(json_schema, indent=4 if pretty_print else None)


def format_json_schema_for_output(plugin: BuildDatasourcePlugin, json_schema: str) -> str:
    return os.linesep.join([f"JSON Schema for plugin {plugin.id}:", json_schema])
```
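`execute_datasource_plugin` first re-validates the raw config mapping against the plugin's `config_file_type` via pydantic's `TypeAdapter`, then hands the validated object to `build_context`. A hedged end-to-end sketch (the `EchoDatasourcePlugin` below is invented for this example):

```python
from typing import Any

from databao_context_engine.pluginlib.build_plugin import DatasourceType, EmbeddableChunk
from databao_context_engine.pluginlib.plugin_utils import execute_datasource_plugin, generate_json_schema


class EchoDatasourcePlugin:
    # Hypothetical plugin using the plain-dict config type
    id = "example/echo"
    name = "Echo Datasource Plugin"
    config_file_type = dict[str, Any]

    def supported_types(self) -> set[str]:
        return {"examples/echo"}

    def build_context(self, full_type: str, datasource_name: str, file_config: dict[str, Any]) -> Any:
        return {"datasource": datasource_name, "config": file_config}

    def check_connection(self, full_type: str, datasource_name: str, file_config: dict[str, Any]) -> None:
        return None

    def divide_context_into_chunks(self, context: Any) -> list[EmbeddableChunk]:
        return [EmbeddableChunk(embeddable_text=str(context), content=context)]


context = execute_datasource_plugin(
    plugin=EchoDatasourcePlugin(),
    datasource_type=DatasourceType(full_type="examples/echo"),
    config={"greeting": "hello"},
    datasource_name="demo",
)
# Plain dict[str, Any] configs are skipped by schema generation, so this is None
assert generate_json_schema(EchoDatasourcePlugin()) is None
```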
databao_context_engine/plugins/__init__.py
File without changes
databao_context_engine/plugins/athena_db_plugin.py
@@ -0,0 +1,12 @@
```python
from databao_context_engine.plugins.base_db_plugin import BaseDatabasePlugin
from databao_context_engine.plugins.databases.athena_introspector import AthenaConfigFile, AthenaIntrospector


class AthenaDbPlugin(BaseDatabasePlugin[AthenaConfigFile]):
    id = "jetbrains/athena"
    name = "Athena DB Plugin"
    supported = {"databases/athena"}
    config_file_type = AthenaConfigFile

    def __init__(self):
        super().__init__(AthenaIntrospector())
```
databao_context_engine/plugins/base_db_plugin.py
@@ -0,0 +1,45 @@
```python
from __future__ import annotations

from typing import Any, TypeVar

from pydantic import BaseModel, ConfigDict, Field

from databao_context_engine.pluginlib.build_plugin import (
    BuildDatasourcePlugin,
    EmbeddableChunk,
)
from databao_context_engine.plugins.databases.base_introspector import BaseIntrospector
from databao_context_engine.plugins.databases.database_chunker import build_database_chunks
from databao_context_engine.plugins.databases.introspection_scope import IntrospectionScope


class BaseDatabaseConfigFile(BaseModel):
    model_config = ConfigDict(populate_by_name=True)
    name: str | None = Field(default=None)
    type: str
    introspection_scope: IntrospectionScope | None = Field(default=None, alias="introspection-scope")


T = TypeVar("T", bound=BaseDatabaseConfigFile)


class BaseDatabasePlugin(BuildDatasourcePlugin[T]):
    name: str
    supported: set[str]

    def __init__(self, introspector: BaseIntrospector):
        self._introspector = introspector

    def supported_types(self) -> set[str]:
        return self.supported

    def build_context(self, full_type: str, datasource_name: str, file_config: T) -> Any:
        introspection_result = self._introspector.introspect_database(file_config)

        return introspection_result

    def check_connection(self, full_type: str, datasource_name: str, file_config: T) -> None:
        self._introspector.check_connection(file_config)

    def divide_context_into_chunks(self, context: Any) -> list[EmbeddableChunk]:
        return build_database_chunks(context)
```
databao_context_engine/plugins/clickhouse_db_plugin.py
@@ -0,0 +1,15 @@
```python
from databao_context_engine.plugins.base_db_plugin import BaseDatabasePlugin
from databao_context_engine.plugins.databases.clickhouse_introspector import (
    ClickhouseConfigFile,
    ClickhouseIntrospector,
)


class ClickhouseDbPlugin(BaseDatabasePlugin[ClickhouseConfigFile]):
    id = "jetbrains/clickhouse"
    name = "Clickhouse DB Plugin"
    supported = {"databases/clickhouse"}
    config_file_type = ClickhouseConfigFile

    def __init__(self):
        super().__init__(ClickhouseIntrospector())
```
databao_context_engine/plugins/databases/__init__.py
File without changes
databao_context_engine/plugins/databases/athena_introspector.py
@@ -0,0 +1,101 @@
```python
from __future__ import annotations

from typing import Any, Mapping

from pyathena import connect
from pyathena.cursor import DictCursor
from pydantic import Field

from databao_context_engine.plugins.base_db_plugin import BaseDatabaseConfigFile
from databao_context_engine.plugins.databases.base_introspector import BaseIntrospector, SQLQuery
from databao_context_engine.plugins.databases.databases_types import DatabaseSchema
from databao_context_engine.plugins.databases.introspection_model_builder import IntrospectionModelBuilder


class AthenaConfigFile(BaseDatabaseConfigFile):
    type: str = Field(default="databases/athena")
    connection: dict[str, Any] = Field(
        description="Connection parameters for the Athena database. It can contain any of the keys supported by the Athena connection library"
    )


class AthenaIntrospector(BaseIntrospector[AthenaConfigFile]):
    _IGNORED_SCHEMAS = {
        "information_schema",
    }
    supports_catalogs = True

    def _connect(self, file_config: AthenaConfigFile):
        connection = file_config.connection
        if not isinstance(connection, Mapping):
            raise ValueError("Invalid YAML config: 'connection' must be a mapping of connection parameters")

        return connect(**connection, cursor_class=DictCursor)

    def _fetchall_dicts(self, connection, sql: str, params) -> list[dict]:
        with connection.cursor() as cur:
            cur.execute(sql, params or {})
            return cur.fetchall()

    def _get_catalogs(self, connection, file_config: AthenaConfigFile) -> list[str]:
        catalog = file_config.connection.get("catalog", self._resolve_pseudo_catalog_name(file_config))
        return [catalog]

    def _connect_to_catalog(self, file_config: AthenaConfigFile, catalog: str):
        # Return the connection: the base class uses it as a context manager
        return self._connect(file_config)

    def _sql_list_schemas(self, catalogs: list[str] | None) -> SQLQuery:
        if not catalogs:
            return SQLQuery("SELECT schema_name, catalog_name FROM information_schema.schemata", None)
        catalog = catalogs[0]
        sql = "SELECT schema_name, catalog_name FROM information_schema.schemata WHERE catalog_name = %(catalog)s"
        return SQLQuery(sql, {"catalog": catalog})

    # TODO: Incomplete plugin. Awaiting AWS access permissions to develop properly
    def collect_catalog_model(self, connection, catalog: str, schemas: list[str]) -> list[DatabaseSchema] | None:
        if not schemas:
            return []

        comps = {"columns": self._sql_columns(catalog, schemas)}
        results: dict[str, list[dict]] = {}

        for name, q in comps.items():
            results[name] = self._fetchall_dicts(connection, q.sql, q.params)

        return IntrospectionModelBuilder.build_schemas_from_components(
            schemas=schemas,
            rels=results.get("relations", []),
            cols=results.get("columns", []),
            pk_cols=[],
            uq_cols=[],
            checks=[],
            fk_cols=[],
            idx_cols=[],
        )

    def _sql_columns(self, catalog: str, schemas: list[str]) -> SQLQuery:
        # Bind one named parameter per schema so the IN clause stays parameterised
        placeholders = ", ".join(f"%(schema_{i})s" for i in range(len(schemas)))
        params = {f"schema_{i}": schema for i, schema in enumerate(schemas)}
        sql = f"""
            SELECT
                table_schema AS schema_name,
                table_name,
                column_name,
                ordinal_position,
                data_type,
                is_nullable
            FROM
                {catalog}.information_schema.columns
            WHERE
                table_schema IN ({placeholders})
            ORDER BY
                table_schema,
                table_name,
                ordinal_position
        """
        return SQLQuery(sql, params)

    def _resolve_pseudo_catalog_name(self, file_config: AthenaConfigFile) -> str:
        return "awsdatacatalog"

    def _sql_sample_rows(self, catalog: str, schema: str, table: str, limit: int) -> SQLQuery:
        sql = f'SELECT * FROM "{schema}"."{table}" LIMIT %(limit)s'
        return SQLQuery(sql, {"limit": limit})
```
databao_context_engine/plugins/databases/base_introspector.py
@@ -0,0 +1,144 @@
```python
from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Mapping, Protocol, Sequence, Union

from databao_context_engine.plugins.databases.databases_types import (
    DatabaseCatalog,
    DatabaseIntrospectionResult,
    DatabaseSchema,
)
from databao_context_engine.plugins.databases.introspection_scope import IntrospectionScope
from databao_context_engine.plugins.databases.introspection_scope_matcher import IntrospectionScopeMatcher

logger = logging.getLogger(__name__)


class SupportsIntrospectionScope(Protocol):
    introspection_scope: IntrospectionScope | None


class BaseIntrospector[T: SupportsIntrospectionScope](ABC):
    supports_catalogs: bool = True
    _IGNORED_SCHEMAS: set[str] = {"information_schema"}
    _SAMPLE_LIMIT: int = 5

    def check_connection(self, file_config: T) -> None:
        with self._connect(file_config) as connection:
            self._fetchall_dicts(connection, "SELECT 1 as test", None)

    def introspect_database(self, file_config: T) -> DatabaseIntrospectionResult:
        scope_matcher = IntrospectionScopeMatcher(
            file_config.introspection_scope,
            ignored_schemas=self._ignored_schemas(),
        )

        with self._connect(file_config) as root_connection:
            catalogs = self._get_catalogs_adapted(root_connection, file_config)

        discovered_schemas_per_catalog: dict[str, list[str]] = {}
        for catalog in catalogs:
            with self._connect_to_catalog(file_config, catalog) as conn:
                discovered_schemas_per_catalog[catalog] = self._list_schemas_for_catalog(conn, catalog)
        scope = scope_matcher.filter_scopes(catalogs, discovered_schemas_per_catalog)

        introspected_catalogs: list[DatabaseCatalog] = []
        for catalog in scope.catalogs:
            schemas_to_introspect = scope.schemas_per_catalog.get(catalog, [])
            if not schemas_to_introspect:
                continue

            with self._connect_to_catalog(file_config, catalog) as catalog_connection:
                introspected_schemas = self.collect_catalog_model(catalog_connection, catalog, schemas_to_introspect)

                if not introspected_schemas:
                    continue

                for schema in introspected_schemas:
                    for table in schema.tables:
                        table.samples = self._collect_samples_for_table(
                            catalog_connection, catalog, schema.name, table.name
                        )

                introspected_catalogs.append(DatabaseCatalog(name=catalog, schemas=introspected_schemas))
        return DatabaseIntrospectionResult(catalogs=introspected_catalogs)

    def _get_catalogs_adapted(self, connection, file_config: T) -> list[str]:
        if self.supports_catalogs:
            return self._get_catalogs(connection, file_config)
        return [self._resolve_pseudo_catalog_name(file_config)]

    def _sql_list_schemas(self, catalogs: list[str] | None) -> SQLQuery:
        if self.supports_catalogs:
            sql = "SELECT catalog_name, schema_name FROM information_schema.schemata WHERE catalog_name = ANY(%s)"
            return SQLQuery(sql, (catalogs,))
        else:
            sql = "SELECT schema_name FROM information_schema.schemata"
            return SQLQuery(sql, None)

    def _list_schemas_for_catalog(self, connection: Any, catalog: str) -> list[str]:
        sql_query = self._sql_list_schemas([catalog] if self.supports_catalogs else None)
        rows = self._fetchall_dicts(connection, sql_query.sql, sql_query.params)

        schemas: list[str] = []
        for row in rows:
            schema_name = row.get("schema_name")
            if schema_name:
                schemas.append(schema_name)

        return schemas

    @abstractmethod
    def collect_catalog_model(self, connection, catalog: str, schemas: list[str]) -> list[DatabaseSchema] | None:
        raise NotImplementedError

    def _collect_samples_for_table(self, connection, catalog: str, schema: str, table: str) -> list[dict[str, Any]]:
        samples: list[dict[str, Any]] = []
        if self._SAMPLE_LIMIT > 0:
            try:
                sql_query = self._sql_sample_rows(catalog, schema, table, self._SAMPLE_LIMIT)
                samples = self._fetchall_dicts(connection, sql_query.sql, sql_query.params)
            except NotImplementedError:
                samples = []
            except Exception as e:
                logger.warning("Failed to fetch samples for %s.%s (catalog=%s): %s", schema, table, catalog, e)
                samples = []
        return samples

    @abstractmethod
    def _connect(self, file_config: T):
        raise NotImplementedError

    @abstractmethod
    def _fetchall_dicts(self, connection, sql: str, params) -> list[dict]:
        raise NotImplementedError

    @abstractmethod
    def _get_catalogs(self, connection, file_config: T) -> list[str]:
        raise NotImplementedError

    @abstractmethod
    def _connect_to_catalog(self, file_config: T, catalog: str):
        """Return a connection scoped to `catalog`. For engines that
        don't need a new connection, return a connection with the
        session set/USE'd to that catalog."""

    def _sql_sample_rows(self, catalog: str, schema: str, table: str, limit: int) -> SQLQuery:
        raise NotImplementedError

    def _resolve_pseudo_catalog_name(self, file_config: T) -> str:
        return "default"

    def _ignored_schemas(self) -> set[str]:
        return self._IGNORED_SCHEMAS


@dataclass
class SQLQuery:
    sql: str
    params: ParamsType = None


ParamsType = Union[Mapping[str, Any], Sequence[Any], None]
```
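`BaseIntrospector` is a template method: subclasses supply connection handling and engine-specific SQL, while `introspect_database` drives catalog discovery, scope filtering, model collection and row sampling. A minimal sketch of a concrete subclass backed by the standard library's sqlite3 (everything here is illustrative: the `SqliteConfig` dataclass is invented, and the `DatabaseSchema(name=..., tables=...)` constructor shape is an assumption inferred from how the base class reads those attributes, not confirmed by this diff):

```python
import sqlite3
from contextlib import closing
from dataclasses import dataclass

from databao_context_engine.plugins.databases.base_introspector import BaseIntrospector, SQLQuery
from databao_context_engine.plugins.databases.databases_types import DatabaseSchema
from databao_context_engine.plugins.databases.introspection_scope import IntrospectionScope


@dataclass
class SqliteConfig:
    # Hypothetical config satisfying SupportsIntrospectionScope
    path: str
    introspection_scope: IntrospectionScope | None = None


class SqliteIntrospector(BaseIntrospector[SqliteConfig]):
    supports_catalogs = False  # sqlite has no catalog concept

    def _connect(self, file_config: SqliteConfig):
        # closing() provides the context-manager behaviour the base class expects
        return closing(sqlite3.connect(file_config.path))

    def _connect_to_catalog(self, file_config: SqliteConfig, catalog: str):
        # Single implicit catalog: a plain connection is already "scoped"
        return self._connect(file_config)

    def _fetchall_dicts(self, connection, sql: str, params) -> list[dict]:
        connection.row_factory = sqlite3.Row
        cursor = connection.execute(sql, params or ())
        return [dict(row) for row in cursor.fetchall()]

    def _get_catalogs(self, connection, file_config: SqliteConfig) -> list[str]:
        # Unused when supports_catalogs is False, but required by the ABC
        return [self._resolve_pseudo_catalog_name(file_config)]

    def _sql_list_schemas(self, catalogs: list[str] | None) -> SQLQuery:
        # sqlite has no information_schema; its only schema is "main"
        return SQLQuery("SELECT 'main' AS schema_name", None)

    def _sql_sample_rows(self, catalog: str, schema: str, table: str, limit: int) -> SQLQuery:
        return SQLQuery(f'SELECT * FROM "{table}" LIMIT ?', (limit,))

    def collect_catalog_model(self, connection, catalog: str, schemas: list[str]):
        # A real implementation would walk sqlite_master and PRAGMA table_info
        # to build tables and columns; assumed constructor shape, tables left empty
        return [DatabaseSchema(name=schema, tables=[]) for schema in schemas]
```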