mocksql 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- app/__init__.py +0 -0
- app/api/__init__.py +0 -0
- app/api/endpoints/__init__.py +0 -0
- app/api/endpoints/messages.py +114 -0
- app/api/endpoints/models.py +117 -0
- app/api/endpoints/projects.py +141 -0
- app/api/endpoints/query.py +413 -0
- app/api/endpoints/users.py +44 -0
- app/exceptions/__init__.py +0 -0
- app/exceptions/exceptions.py +4 -0
- app/services/__init__.py +0 -0
- app/services/query_service.py +2 -0
- build_query/__init__.py +0 -0
- build_query/constraint_simplifier.py +1066 -0
- build_query/converstion_history.py +336 -0
- build_query/examples_executor.py +870 -0
- build_query/examples_generator.py +424 -0
- build_query/other.py +62 -0
- build_query/profile_checker.py +342 -0
- build_query/profiler.py +2135 -0
- build_query/prompt_tools.py +864 -0
- build_query/query_chain.py +172 -0
- build_query/query_executor.py +125 -0
- build_query/routing.py +113 -0
- build_query/schema_fetcher.py +138 -0
- build_query/state.py +48 -0
- build_query/test_evaluator.py +124 -0
- build_query/validator.py +546 -0
- cli/__init__.py +0 -0
- cli/generate.py +388 -0
- cli/main.py +321 -0
- cli/test_runner.py +323 -0
- common_vars.py +365 -0
- fetch_secrets.py +34 -0
- init/__init__.py +0 -0
- init/add_column.py +153 -0
- init/add_table.py +129 -0
- init/create_user.py +72 -0
- init/grant_access_to_db.py +31 -0
- init/init_db.py +154 -0
- mocksql-0.1.0.dist-info/METADATA +232 -0
- mocksql-0.1.0.dist-info/RECORD +80 -0
- mocksql-0.1.0.dist-info/WHEEL +4 -0
- mocksql-0.1.0.dist-info/entry_points.txt +10 -0
- models/__init__.py +0 -0
- models/database.py +66 -0
- models/db_pool.py +161 -0
- models/env_variables.py +41 -0
- models/message_service.py +286 -0
- models/model.py +2 -0
- models/model_service.py +2 -0
- models/permissions.py +66 -0
- models/schemas.py +122 -0
- models/session_service.py +105 -0
- models/user_service.py +99 -0
- server.py +95 -0
- sql_functions/__init__.py +0 -0
- sql_functions/functions.py +50 -0
- sql_functions/helpers.py +276 -0
- storage/__init__.py +0 -0
- storage/config.py +47 -0
- storage/test_repository.py +257 -0
- utils/__init__.py +19 -0
- utils/bigquery_test_helper.py +416 -0
- utils/duckdb_test_helper.py +235 -0
- utils/errors.py +376 -0
- utils/examples.py +754 -0
- utils/find_grains.py +822 -0
- utils/insert_examples.py +366 -0
- utils/llm_errors.py +25 -0
- utils/logger.py +0 -0
- utils/models.py +141 -0
- utils/msg_types.py +25 -0
- utils/postgres_db_utils.py +39 -0
- utils/postgres_test_helper.py +178 -0
- utils/prompt_utils.py +27 -0
- utils/query_services.py +17 -0
- utils/saver.py +270 -0
- utils/schema_utils.py +125 -0
- utils/sql_code.py +224 -0
app/__init__.py
ADDED
|
File without changes
|
app/api/__init__.py
ADDED
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import List, Any, Optional
|
|
3
|
+
|
|
4
|
+
from fastapi import APIRouter, HTTPException
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
from storage.config import is_initialized
|
|
8
|
+
from storage.test_repository import get_test, update_test
|
|
9
|
+
from utils.saver import common_history_retriever
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
router = APIRouter()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class MessageRequest(BaseModel):
    """Request body for ``POST /getMessages``."""

    # Despite its name, this value is used as the session key: it is passed
    # to both common_history_retriever() and get_test() by the endpoint.
    modelId: str
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class PatchTestsRequest(BaseModel):
    """Request body for ``PATCH /models/tests``."""

    # Identifier of the stored test/session to update.
    sessionId: str
    # Replacement list stored under the "test_cases" key; items are not
    # validated here (typed as Any).
    tests: List[Any]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class PatchSqlRequest(BaseModel):
    """Request body for ``PATCH /models/sql``.

    ``sql`` and ``optimized_sql`` are always written by the endpoint; the
    optional fields are only written when they are not ``None``.
    """

    # Identifier of the stored test/session to update.
    sessionId: str
    sql: str
    optimized_sql: str = ""
    # Both `tests` and `test_results` are persisted under "test_cases";
    # when both are supplied the endpoint lets `test_results` win.
    tests: Optional[List[Any]] = None
    test_results: Optional[List[Any]] = None
    # An empty string is stored as None by the endpoint.
    restored_message_id: Optional[str] = None
    last_error: Optional[str] = None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@router.post("/getMessages")
async def get_messages(body: MessageRequest):
    """Return the message history and stored test state for one session.

    Args:
        body: Request carrying ``modelId``, which is used as the session key
            for both the history lookup and the stored test lookup.

    Returns:
        A dict with the keys ``messages``, ``sql``, ``optimized_sql``,
        ``test_results``, ``restored_message_id``, ``last_error`` (never
        ``None``; falls back to ``""``) and ``sql_history`` (always ``[]``).

    Raises:
        HTTPException: 400 when the project is not initialized, 404 when no
            history exists for the session, 500 on any other failure.
    """
    # Function-scope import: this module has no top-level ``json`` import.
    import json

    if not is_initialized():
        raise HTTPException(
            status_code=400,
            detail="Projet non initialisé. Lancez 'mocksql init' dans votre répertoire de travail pour commencer.",
        )
    try:
        history = await common_history_retriever(body.modelId, filtered_types=[])
        if history is None:
            raise HTTPException(status_code=404, detail="Session not found")

        # A missing test record is treated like an empty one.
        test = get_test(body.modelId) or {}
        sql = test.get("sql")
        optimized_sql = test.get("optimized_sql")
        last_error = test.get("last_error")
        test_results = test.get("test_cases", [])
        restored_message_id = test.get("restored_message_id")

        # Fallback: extract from the most recent "results" message in history.
        if not test_results:
            for msg in reversed(history):
                if msg.additional_kwargs.get("type") != "results":
                    continue
                try:
                    test_results = json.loads(msg.content)
                except Exception:
                    # Still best-effort (the request must not fail because of
                    # one unparsable message), but leave a trace instead of
                    # swallowing the error silently.
                    logger.warning(
                        "Could not parse 'results' message for session %s",
                        body.modelId,
                        exc_info=True,
                    )
                # Only the latest "results" message is considered, whether or
                # not it could be parsed.
                break

        return {
            "messages": history,
            "sql": sql,
            "optimized_sql": optimized_sql,
            "test_results": test_results,
            "restored_message_id": restored_message_id,
            "last_error": last_error or "",
            "sql_history": [],
        }
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 404 above) through untouched.
        raise
    except Exception as e:
        logger.exception(
            "Erreur lors du chargement des messages pour la session %s", body.modelId
        )
        raise HTTPException(status_code=500, detail=str(e))
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@router.patch("/models/sql")
async def patch_model_sql(body: PatchSqlRequest):
    """Persist the SQL (and optional test state) of a session.

    ``sql`` and ``optimized_sql`` are always written. ``tests``,
    ``test_results``, ``restored_message_id`` and ``last_error`` are written
    only when they are not ``None``.

    Returns:
        ``{"ok": True}`` on success.

    Raises:
        HTTPException: 500 with a generic message on any failure; the
            underlying error is logged, not returned to the client.
    """
    try:
        updates: dict = {
            "sql": body.sql,
            "optimized_sql": body.optimized_sql,
        }
        if body.tests is not None:
            updates["test_cases"] = body.tests
        # Same target key as above: when both are supplied, test_results wins.
        if body.test_results is not None:
            updates["test_cases"] = body.test_results
        if body.restored_message_id is not None:
            # An empty string is normalised to None.
            updates["restored_message_id"] = body.restored_message_id or None
        if body.last_error is not None:
            updates["last_error"] = body.last_error

        update_test(body.sessionId, updates)
        return {"ok": True}
    except Exception:
        # Use the module logger (with traceback) instead of print().
        logger.exception("Failed to update SQL for session %s", body.sessionId)
        raise HTTPException(status_code=500, detail="Internal Server Error")
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@router.patch("/models/tests")
async def patch_model_tests(body: PatchTestsRequest):
    """Replace the stored ``test_cases`` of a session with ``body.tests``.

    Returns:
        ``{"ok": True}`` on success.

    Raises:
        HTTPException: 500 with a generic message on any failure; the
            underlying error is logged, not returned to the client.
    """
    try:
        update_test(body.sessionId, {"test_cases": body.tests})
        return {"ok": True}
    except Exception:
        # Use the module logger (with traceback) instead of print().
        logger.exception("Failed to update tests for session %s", body.sessionId)
        raise HTTPException(status_code=500, detail="Internal Server Error")
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
from fastapi import APIRouter, HTTPException
|
|
2
|
+
from fastapi.responses import JSONResponse
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
from storage.test_repository import (
|
|
6
|
+
list_models,
|
|
7
|
+
list_all_tests,
|
|
8
|
+
list_tests,
|
|
9
|
+
get_test,
|
|
10
|
+
create_test,
|
|
11
|
+
delete_test,
|
|
12
|
+
_test_path,
|
|
13
|
+
_read_json,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
router = APIRouter()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# ---------------------------------------------------------------------------
|
|
20
|
+
# SQL models (physical files in models_path)
|
|
21
|
+
# ---------------------------------------------------------------------------
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@router.get("/models")
async def get_models():
    """List the .sql files available in models_path.

    Entries whose test file exists and contains data are enriched with
    ``session_id`` / ``updated_at`` / ``test_name`` so the frontend can
    redirect instead of regenerating. Those tested entries come first,
    sorted by ``updated_at`` descending, followed by the untested ones.

    Raises:
        HTTPException: 500 carrying the error text on any failure.
    """
    try:
        with_tests = []
        without_tests = []
        for sql_file in list_models():
            test_file = _test_path(sql_file["name"])
            payload = _read_json(test_file) if test_file.exists() else None
            if not payload:
                # No test file, or a test file with no usable content.
                without_tests.append(sql_file)
                continue
            with_tests.append(
                {
                    **sql_file,
                    "session_id": payload.get("test_id"),
                    "updated_at": payload.get("updated_at"),
                    "test_name": payload.get("test_name"),
                }
            )
        # Missing timestamps sort as the empty string, i.e. last when reversed.
        with_tests.sort(key=lambda entry: entry.get("updated_at") or "", reverse=True)
        return with_tests + without_tests
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@router.get("/tests/all")
async def get_all_tests():
    """List every test session across all models, newest first.

    The ordering is the one produced by ``list_all_tests``; this endpoint
    returns its result unchanged.

    Raises:
        HTTPException: 500 carrying the error text on any failure.
    """
    try:
        sessions = list_all_tests()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return sessions
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# ---------------------------------------------------------------------------
|
|
67
|
+
# Tests (files in .mocksql/tests/)
|
|
68
|
+
# ---------------------------------------------------------------------------
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class CreateTestRequest(BaseModel):
    """Request body for ``POST /tests``."""

    # Name of the model the new test is created for; forwarded to create_test().
    model_name: str
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@router.get("/tests")
async def get_tests(model_name: str):
    """List all tests for the given model.

    Args:
        model_name: Query parameter forwarded to ``list_tests``.

    Raises:
        HTTPException: 500 carrying the error text on any failure.
    """
    try:
        tests_for_model = list_tests(model_name)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return tests_for_model
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@router.get("/test/{session_id}")
async def get_test_route(session_id: str, model_name: str = None):
    """Fetch a single stored test by session id.

    Args:
        session_id: Path parameter identifying the test.
        model_name: Optional query parameter forwarded to ``get_test``.

    Raises:
        HTTPException: 404 when ``get_test`` returns ``None``, 500 carrying
            the error text on any other failure.
    """
    try:
        found = get_test(session_id, model_name)
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

    if found is None:
        raise HTTPException(status_code=404, detail="Test not found")
    return found
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@router.post("/tests")
async def create_test_route(body: CreateTestRequest):
    """Create a new test for a model and return it with HTTP 201.

    The returned payload is whatever ``create_test`` produces (per the
    original docstring it carries the test_id, which doubles as session_id).

    Raises:
        HTTPException: 500 carrying the error text on any failure, including
            a failure to serialise the created test.
    """
    try:
        created = create_test(body.model_name)
        # Built inside the try block: JSONResponse serialises eagerly, so a
        # non-serialisable payload is reported as a 500 like other errors.
        response = JSONResponse(content=created, status_code=201)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return response
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@router.delete("/tests/{session_id}")
async def delete_test_route(session_id: str, model_name: str):
    """Delete a stored test and answer with an empty HTTP 204.

    Args:
        session_id: Path parameter identifying the test.
        model_name: Query parameter forwarded to ``delete_test``.

    Raises:
        HTTPException: 404 when ``delete_test`` reports nothing was deleted,
            500 carrying the error text on any other failure.
    """
    # Function-scope import: this module does not import Response at top level.
    from fastapi import Response

    try:
        ok = delete_test(session_id, model_name)
        if not ok:
            raise HTTPException(status_code=404, detail="Test not found")
        # 204 No Content must not carry a body, so return a bare Response
        # rather than a JSONResponse with a JSON payload.
        return Response(status_code=204)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import uuid
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from fastapi import APIRouter, HTTPException
|
|
7
|
+
from fastapi.responses import JSONResponse
|
|
8
|
+
from pydantic import BaseModel
|
|
9
|
+
|
|
10
|
+
from common_vars import PROJECTS_TABLE_NAME, USERS_TABLE_NAME
|
|
11
|
+
from models.database import execute, query
|
|
12
|
+
from models.permissions import grant_role
|
|
13
|
+
|
|
14
|
+
router = APIRouter()
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ProjectRequest(BaseModel):
    """Request body for ``POST /projects`` (create or update a project)."""

    # When omitted, the endpoint generates a fresh UUID4.
    project_id: Optional[str] = None
    name: str
    dialect: str
    # Stored as "" when missing.
    description: Optional[str] = None
    service_account_key: Optional[str] = None
    # Stored as False when missing.
    auto_import: Optional[bool] = False
    # Stored as "" when missing.
    user_sub: Optional[str] = ""
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ShareProjectRequest(BaseModel):
    """Request body for ``POST /projects/share``."""

    # Passed to grant_role() as the project to share.
    project: str
    # Matched against either the email or the user_id column of the users table.
    target: str
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _row_to_project(row: dict) -> dict:
|
|
33
|
+
return {
|
|
34
|
+
"project_id": row["project_id"],
|
|
35
|
+
"name": row["name"],
|
|
36
|
+
"dialect": row["dialect"],
|
|
37
|
+
"description": row.get("description") or "",
|
|
38
|
+
"service_account_key": row.get("service_account_key"),
|
|
39
|
+
"auto_import": row.get("auto_import") or False,
|
|
40
|
+
"schema": json.loads(row["json_schema"]) if row.get("json_schema") else [],
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@router.get("/projects")
async def get_projects():
    """Return every row of the projects table in API form.

    A falsy query result (``None`` or empty) yields an empty list.
    """
    rows = await query(f"SELECT * FROM {PROJECTS_TABLE_NAME}")
    if not rows:
        return []
    return [_row_to_project(row) for row in rows]
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@router.get("/project/{project_id}")
async def get_project(project_id: str):
    """Return one project by id.

    Raises:
        HTTPException: 404 when no row matches ``project_id``.
    """
    select_sql = f"SELECT * FROM {PROJECTS_TABLE_NAME} WHERE project_id = $1"
    matches = await query(select_sql, (project_id,))
    if matches:
        return _row_to_project(matches[0])
    raise HTTPException(status_code=404, detail="Project not found")
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@router.post("/projects")
async def create_project(body: ProjectRequest):
    """Insert a project, or update it when ``project_id`` already exists.

    A UUID4 is generated when ``body.project_id`` is not provided. On
    conflict the statement updates name, dialect, description,
    service_account_key, auto_import and updated_at; ``user_sub`` and
    ``created_at`` are left as originally stored.

    Returns:
        The saved project (re-read from the table) with HTTP 201. The same
        status is used for both the insert and the update path.

    Raises:
        HTTPException: 500 when the project cannot be read back after saving.
    """
    project_id = body.project_id or str(uuid.uuid4())
    now = datetime.now().isoformat()
    await execute(
        f"""
        INSERT INTO {PROJECTS_TABLE_NAME}
        (project_id, name, dialect, description, service_account_key, auto_import, user_sub, created_at, updated_at)
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
        ON CONFLICT (project_id) DO UPDATE SET
        name = EXCLUDED.name,
        dialect = EXCLUDED.dialect,
        description = EXCLUDED.description,
        service_account_key = EXCLUDED.service_account_key,
        auto_import = EXCLUDED.auto_import,
        updated_at = EXCLUDED.updated_at
        """,
        project_id,
        body.name,
        body.dialect,
        body.description or "",
        body.service_account_key,
        body.auto_import or False,
        body.user_sub or "",
        now,
        now,
    )
    rows = await query(
        f"SELECT * FROM {PROJECTS_TABLE_NAME} WHERE project_id = $1",
        (project_id,),
    )
    # query() may hand back a falsy result (get_projects guards against it
    # too); fail with an explicit HTTP error rather than an IndexError or
    # TypeError on rows[0].
    if not rows:
        raise HTTPException(
            status_code=500, detail="Project could not be read back after save"
        )
    return JSONResponse(content=_row_to_project(rows[0]), status_code=201)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@router.delete("/projects/{project_id}")
async def delete_project(project_id: str):
    """Delete a project row and answer with an empty HTTP 204.

    Raises:
        HTTPException: 404 when no row matches ``project_id``.
    """
    # Function-scope import: this module does not import Response at top level.
    from fastapi import Response

    rows = await query(
        f"SELECT project_id FROM {PROJECTS_TABLE_NAME} WHERE project_id = $1",
        (project_id,),
    )
    if not rows:
        raise HTTPException(status_code=404, detail="Project not found")
    await execute(
        f"DELETE FROM {PROJECTS_TABLE_NAME} WHERE project_id = $1",
        project_id,
    )
    # 204 No Content must not carry a body, so return a bare Response
    # rather than a JSONResponse with a JSON payload.
    return Response(status_code=204)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@router.delete("/projects/{project_id}/table/{table_name}")
async def delete_project_table(project_id: str, table_name: str):
    """Remove one table entry from a project's stored JSON schema.

    Every schema entry whose ``table_name`` equals the path parameter is
    dropped and the remaining list is written back together with a fresh
    ``updated_at``. The write happens even if no entry matched.

    Returns:
        ``{"removed": <table_name>, "remaining": <entries left>}``.

    Raises:
        HTTPException: 404 when no row matches ``project_id``.
    """
    found = await query(
        f"SELECT json_schema FROM {PROJECTS_TABLE_NAME} WHERE project_id = $1",
        (project_id,),
    )
    if not found:
        raise HTTPException(status_code=404, detail="Project not found")

    # An empty/NULL column is treated as an empty schema.
    stored_schema = found[0].get("json_schema") or "[]"
    kept_tables = []
    for table in json.loads(stored_schema):
        if table.get("table_name") != table_name:
            kept_tables.append(table)

    serialized = json.dumps(kept_tables, ensure_ascii=False)
    timestamp = datetime.now().isoformat()
    await execute(
        f"UPDATE {PROJECTS_TABLE_NAME} SET json_schema = $1, updated_at = $2 WHERE project_id = $3",
        serialized,
        timestamp,
        project_id,
    )
    return {"removed": table_name, "remaining": len(kept_tables)}
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
@router.post("/projects/share")
async def share_project(body: ShareProjectRequest):
    """Grant the "user" role on a project to another user.

    The target is looked up by email or by user_id; the first match is used.

    Returns:
        ``{"shared": True, "user_id": <resolved user id>}``.

    Raises:
        HTTPException: 404 when no user matches ``body.target``.
    """
    lookup_sql = (
        f"SELECT user_id FROM {USERS_TABLE_NAME} WHERE email = $1 OR user_id = $1"
    )
    candidates = await query(lookup_sql, (body.target,))
    if not candidates:
        raise HTTPException(status_code=404, detail=f"User not found: {body.target}")

    recipient_id = candidates[0]["user_id"]
    await grant_role(recipient_id, body.project, role="user")
    return {"shared": True, "user_id": recipient_id}
|