gaard-api 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,55 @@
1
+ from fastapi import APIRouter
2
+ from pydantic import BaseModel, Field
3
+
4
+ from gaard_connectors.sqlalchemy.introspector import SQLAlchemySchemaIntrospector
5
+ from gaard_core.prompt_compiler.models import CompiledPrompt
6
+ from gaard_core.prompt_compiler.models import SqlGenerationPromptRequest
7
+ from gaard_core.prompt_compiler.sql_generation_prompt import SqlGenerationPromptCompiler
8
+ from gaard_core.schema.context import SchemaContextService
9
+
10
+ from gaard_api.admin.prompt_runtime import get_sql_generation_prompt_compiler
11
+ from gaard_api.admin.services import get_datasource_schema_context_safe, get_query_runtime_config_safe
12
+ from gaard_api.api.v1.schema import get_schema_cache_key
13
+ from gaard_api.core.schema_cache import schema_context_cache
14
+ from gaard_api.core.settings import settings
15
+
16
# Router holding the prompt-compilation endpoint(s) declared in this module;
# it must be included into an application/parent router to be served.
router = APIRouter()
17
+
18
+
19
class CompileSqlGenerationPromptApiRequest(BaseModel):
    """Request body for compiling a SQL-generation prompt."""

    # Natural-language question to embed in the prompt. `...` marks the field
    # as required; `min_length=1` rejects the empty string.
    question: str = Field(..., min_length=1)
21
+
22
+
23
def _resolve_schema_and_dialect():
    """Return ``(formatted_schema, sql_dialect)`` for prompt compilation.

    Prefers the datasource context from the admin services. When that is
    unavailable (``None``), falls back to introspecting
    ``settings.gaard_datasource_url`` through the shared schema-context cache,
    paired with ``settings.gaard_sql_dialect``.
    """
    datasource_context = get_datasource_schema_context_safe()
    if datasource_context is not None:
        connector, cached_schema = datasource_context
        # Tuple elements evaluate left to right: schema first, then dialect.
        return cached_schema.formatted_schema, connector.sql_dialect

    # Read the configured dialect before any introspection work happens.
    fallback_dialect = settings.gaard_sql_dialect
    schema_service = SchemaContextService(
        introspector=SQLAlchemySchemaIntrospector(
            database_url=settings.gaard_datasource_url,
        ),
        cache=schema_context_cache,
    )
    schema_context = schema_service.get_schema_context(get_schema_cache_key())
    return schema_context.formatted_schema, fallback_dialect


@router.post("/prompts/sql-generation", response_model=CompiledPrompt)
def compile_sql_generation_prompt(
    request: CompileSqlGenerationPromptApiRequest,
) -> CompiledPrompt:
    """Compile the SQL-generation prompt for a natural-language question.

    Uses the schema and SQL dialect of the configured datasource when one is
    available, and otherwise introspects the datasource URL from settings.
    The row limit comes from the query runtime configuration.
    """
    formatted_schema, sql_dialect = _resolve_schema_and_dialect()

    # A falsy result from the runtime accessor means "use the default compiler".
    prompt_compiler = get_sql_generation_prompt_compiler() or SqlGenerationPromptCompiler()
    row_limit = get_query_runtime_config_safe().query_max_rows

    prompt_request = SqlGenerationPromptRequest(
        question=request.question,
        formatted_schema=formatted_schema,
        dialect=sql_dialect,
        max_rows=row_limit,
    )
    return prompt_compiler.compile(prompt_request)