gaard-api 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,237 @@
1
+ from gaard_core.errors import ConfigurationError
2
+ from gaard_core.investigation.models import InvestigationContext
3
+ from gaard_core.json_utils import json_dumps
4
+ from gaard_core.prompt_compiler.models import CompiledPrompt, SqlGenerationPromptRequest
5
+ from gaard_core.prompt_compiler.schema_formatter import SchemaPromptFormatter
6
+ from gaard_core.query_pipeline.models import QueryRequest, QueryResult
7
+
8
+ from gaard_api.admin.models import PromptTemplate
9
+ from gaard_api.admin.services import get_active_prompt_template_safe
10
+
11
+
12
class MetadataSqlGenerationPromptCompiler:
    """Compile SQL-generation prompts from a stored ``PromptTemplate``.

    ``prompt_template.system_prompt`` is rendered with ``dialect`` and
    ``max_rows``. ``prompt_template.user_prompt_template`` is rendered with
    ``schema``, ``question``, ``dialect`` and ``max_rows``.
    """

    def __init__(
        self,
        prompt_template: PromptTemplate,
        schema_formatter: SchemaPromptFormatter | None = None,
    ) -> None:
        self.prompt_template = prompt_template
        # A falsy/absent formatter is replaced by a default SchemaPromptFormatter.
        if schema_formatter:
            self.schema_formatter = schema_formatter
        else:
            self.schema_formatter = SchemaPromptFormatter()

    def compile(self, request: SqlGenerationPromptRequest) -> CompiledPrompt:
        """Render the system and user prompts for *request*.

        Returns a ``CompiledPrompt`` whose metadata records the dialect, row
        limit, template key/version, where the schema text came from, and the
        number of tables in ``request.database_schema`` (``None`` if absent).

        Raises:
            ConfigurationError: when *request* has neither ``formatted_schema``
                nor ``database_schema``.
            KeyError: (from ``str.format``) if a template references a
                placeholder that is not supplied here.
        """
        schema_text = self._resolve_formatted_schema(request)
        template = self.prompt_template
        shared = {"dialect": request.dialect, "max_rows": request.max_rows}

        rendered_system = template.system_prompt.format(**shared)
        rendered_user = template.user_prompt_template.format(
            schema=schema_text,
            question=request.question,
            **shared,
        )

        if request.formatted_schema is None:
            source = "database_schema"
        else:
            source = "formatted_schema"

        db_schema = request.database_schema
        table_total = None if db_schema is None else len(db_schema.tables)

        return CompiledPrompt(
            system_prompt=rendered_system,
            user_prompt=rendered_user,
            metadata={
                **shared,
                "prompt_key": template.prompt_key,
                "prompt_version": template.version,
                "schema_source": source,
                "tables_count": table_total,
            },
        )

    def _resolve_formatted_schema(self, request: SqlGenerationPromptRequest) -> str:
        """Pick the schema text to embed in the prompt.

        A caller-supplied ``formatted_schema`` takes precedence. Otherwise
        ``database_schema`` is rendered through ``self.schema_formatter``.
        """
        precomputed = request.formatted_schema
        if precomputed is not None:
            return precomputed

        db_schema = request.database_schema
        if db_schema is None:
            raise ConfigurationError(
                "Either database_schema or formatted_schema must be provided."
            )
        return self.schema_formatter.format(db_schema)
62
+
63
+
64
class MetadataIntentClassificationPromptCompiler:
    """Compile intent-classification prompts from a stored ``PromptTemplate``.

    The system prompt is used verbatim. The user prompt template may reference
    ``{payload}`` (indented JSON of the fields below) or the individual fields
    ``{question}``, ``{datasource_id}`` and ``{user_id}``.
    """

    def __init__(self, prompt_template: PromptTemplate) -> None:
        self.prompt_template = prompt_template

    def compile(self, request: QueryRequest) -> CompiledPrompt:
        """Render the intent-classification prompt for *request*."""
        fields = {
            "question": request.question,
            "datasource_id": request.datasource_id,
            "user_id": request.user_id,
        }
        serialized = json_dumps(fields, ensure_ascii=False, indent=2)

        template = self.prompt_template
        system_text = template.system_prompt
        user_text = template.user_prompt_template.format(payload=serialized, **fields)

        return CompiledPrompt(
            system_prompt=system_text,
            user_prompt=user_text,
            metadata={
                "prompt_key": template.prompt_key,
                "prompt_version": template.version,
            },
        )
89
+
90
+
91
class MetadataInvestigationReadinessPromptCompiler:
    """Compile investigation-readiness prompts from a stored ``PromptTemplate``.

    The system prompt is used verbatim. The user prompt template may reference
    ``{payload}`` (indented JSON of the fields below) or any of ``{question}``,
    ``{datasource_id}``, ``{user_id}``, ``{schema}`` and ``{business_logic}``.
    """

    def __init__(self, prompt_template: PromptTemplate) -> None:
        self.prompt_template = prompt_template

    def compile(self, context: InvestigationContext) -> CompiledPrompt:
        """Render the readiness-check prompt for *context*."""
        fields = {
            "question": context.question,
            "datasource_id": context.datasource_id,
            "user_id": context.user_id,
            # The prompt field is called "schema" but is sourced from formatted_schema.
            "schema": context.formatted_schema,
            "business_logic": context.business_logic,
        }
        serialized = json_dumps(fields, ensure_ascii=False, indent=2)

        template = self.prompt_template
        system_text = template.system_prompt
        user_text = template.user_prompt_template.format(payload=serialized, **fields)

        return CompiledPrompt(
            system_prompt=system_text,
            user_prompt=user_text,
            metadata={
                "prompt_key": template.prompt_key,
                "prompt_version": template.version,
            },
        )
120
+
121
+
122
class MetadataResultInterpretationPromptCompiler:
    """Compile result-interpretation prompts from a stored ``PromptTemplate``.

    The system prompt is used verbatim. The user prompt template may reference
    ``{payload}`` (indented JSON of question, sql, columns and rows),
    ``{question}``, ``{sql}``, or compact JSON renderings ``{columns}`` and
    ``{rows}``.
    """

    def __init__(self, prompt_template: PromptTemplate) -> None:
        self.prompt_template = prompt_template

    def compile(
        self,
        request: QueryRequest,
        sql: str,
        result: QueryResult,
    ) -> CompiledPrompt:
        """Render a prompt asking the model to interpret *result* of *sql*.

        All rows in *result* are embedded; metadata records the row and column
        counts alongside the template key/version.
        """
        question = request.question
        columns = result.columns
        rows = result.rows

        serialized = json_dumps(
            {"question": question, "sql": sql, "columns": columns, "rows": rows},
            ensure_ascii=False,
            indent=2,
        )

        template = self.prompt_template
        system_text = template.system_prompt
        user_text = template.user_prompt_template.format(
            payload=serialized,
            question=question,
            sql=sql,
            columns=json_dumps(columns, ensure_ascii=False),
            rows=json_dumps(rows, ensure_ascii=False),
        )

        return CompiledPrompt(
            system_prompt=system_text,
            user_prompt=user_text,
            metadata={
                "rows_count": len(rows),
                "columns_count": len(columns),
                "prompt_key": template.prompt_key,
                "prompt_version": template.version,
            },
        )
156
+
157
+
158
class MetadataResultClassificationPromptCompiler:
    """Compile result-classification prompts from a stored ``PromptTemplate``.

    The system prompt is used verbatim. The user prompt template may reference
    ``{payload}`` (indented JSON of question and answer), ``{question}`` or
    ``{answer}``.
    """

    def __init__(self, prompt_template: PromptTemplate) -> None:
        self.prompt_template = prompt_template

    def compile(
        self,
        request: QueryRequest,
        answer: str,
    ) -> CompiledPrompt:
        """Render a prompt asking the model to classify *answer* for *request*."""
        fields = {"question": request.question, "answer": answer}
        serialized = json_dumps(fields, ensure_ascii=False, indent=2)

        template = self.prompt_template
        system_text = template.system_prompt
        user_text = template.user_prompt_template.format(payload=serialized, **fields)

        return CompiledPrompt(
            system_prompt=system_text,
            user_prompt=user_text,
            metadata={
                "prompt_key": template.prompt_key,
                "prompt_version": template.version,
            },
        )
185
+
186
+
187
def get_sql_generation_prompt_compiler() -> MetadataSqlGenerationPromptCompiler | None:
    """Build a compiler for the active ``sql_generation`` prompt template.

    Returns ``None`` when ``get_active_prompt_template_safe`` yields no template.
    """
    template = get_active_prompt_template_safe("sql_generation")
    return (
        None
        if template is None
        else MetadataSqlGenerationPromptCompiler(prompt_template=template)
    )
194
+
195
+
196
def get_intent_classification_prompt_compiler() -> (
    MetadataIntentClassificationPromptCompiler | None
):
    """Build a compiler for the active ``intent_classification`` template.

    Returns ``None`` when ``get_active_prompt_template_safe`` yields no template.
    """
    template = get_active_prompt_template_safe("intent_classification")
    return (
        None
        if template is None
        else MetadataIntentClassificationPromptCompiler(prompt_template=template)
    )
205
+
206
+
207
def get_investigation_readiness_prompt_compiler() -> (
    MetadataInvestigationReadinessPromptCompiler | None
):
    """Build a compiler for the active ``investigation_readiness`` template.

    Returns ``None`` when ``get_active_prompt_template_safe`` yields no template.
    """
    template = get_active_prompt_template_safe("investigation_readiness")
    return (
        None
        if template is None
        else MetadataInvestigationReadinessPromptCompiler(prompt_template=template)
    )
216
+
217
+
218
def get_result_interpretation_prompt_compiler() -> (
    MetadataResultInterpretationPromptCompiler | None
):
    """Build a compiler for the active ``result_interpretation`` template.

    Returns ``None`` when ``get_active_prompt_template_safe`` yields no template.
    """
    template = get_active_prompt_template_safe("result_interpretation")
    return (
        None
        if template is None
        else MetadataResultInterpretationPromptCompiler(prompt_template=template)
    )
227
+
228
+
229
def get_result_classification_prompt_compiler() -> (
    MetadataResultClassificationPromptCompiler | None
):
    """Build a compiler for the active ``result_classification`` template.

    Returns ``None`` when ``get_active_prompt_template_safe`` yields no template.
    """
    template = get_active_prompt_template_safe("result_classification")
    return (
        None
        if template is None
        else MetadataResultClassificationPromptCompiler(prompt_template=template)
    )
@@ -0,0 +1,45 @@
1
+ import hashlib
2
+ import hmac
3
+ import secrets
4
+
5
+
6
# PBKDF2 iteration count recorded in every hash produced by hash_password.
PASSWORD_ITERATIONS = 260_000


def hash_password(password: str) -> str:
    """Hash *password* with PBKDF2-HMAC-SHA256 and a fresh random salt.

    The result has the form ``pbkdf2_sha256$<iterations>$<salt>$<digest>``:
    the salt is a 32-character hex string (16 random bytes from ``secrets``)
    whose UTF-8 bytes are fed to PBKDF2, and the digest is the hex-encoded
    derived key.
    """
    scheme = "pbkdf2_sha256"
    salt_hex = secrets.token_hex(16)
    derived_key = hashlib.pbkdf2_hmac(
        "sha256",
        password.encode("utf-8"),
        salt_hex.encode("utf-8"),
        PASSWORD_ITERATIONS,
    )
    return "$".join((scheme, str(PASSWORD_ITERATIONS), salt_hex, derived_key.hex()))
19
+
20
+
21
def verify_password(password: str, password_hash: str) -> bool:
    """Check *password* against a stored hash in ``hash_password``'s format.

    The stored value must look like
    ``pbkdf2_sha256$<iterations>$<salt>$<hex digest>``. Any malformed or
    unsupported stored value yields ``False`` rather than raising. This covers
    a wrong field count, an unknown algorithm, a non-numeric or non-positive
    iteration count, and a non-ASCII digest.

    Returns:
        True only if the recomputed digest matches the stored one
        (compared in constant time).
    """
    try:
        algorithm, iterations_value, salt, expected_digest = password_hash.split("$", 3)
    except ValueError:
        return False

    if algorithm != "pbkdf2_sha256":
        return False

    # A corrupt iteration field must not crash the caller: int() rejects
    # non-numeric text, and pbkdf2_hmac raises ValueError for counts < 1.
    try:
        iterations = int(iterations_value)
    except ValueError:
        return False
    if iterations <= 0:
        return False

    try:
        password_bytes = password.encode("utf-8")
        salt_bytes = salt.encode("utf-8")
        expected_bytes = expected_digest.encode("utf-8")
    except UnicodeEncodeError:
        # Text with lone surrogates cannot be UTF-8 encoded. hash_password
        # would have raised on such a password, so it can never match.
        return False

    digest_bytes = hashlib.pbkdf2_hmac(
        "sha256",
        password_bytes,
        salt_bytes,
        iterations,
    ).hex().encode("ascii")

    # Compare bytes: hmac.compare_digest raises TypeError for str arguments
    # containing non-ASCII characters, e.g. a corrupted stored digest.
    return hmac.compare_digest(digest_bytes, expected_bytes)
38
+
39
+
40
def create_session_token() -> str:
    """Return a new URL-safe random session token built from 32 random bytes."""
    entropy_bytes = 32
    token = secrets.token_urlsafe(entropy_bytes)
    return token
42
+
43
+
44
def hash_token(token: str) -> str:
    """Return the lowercase hex SHA-256 digest of *token*'s UTF-8 bytes."""
    hasher = hashlib.sha256()
    hasher.update(token.encode("utf-8"))
    return hasher.hexdigest()