mocksql 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80) hide show
  1. app/__init__.py +0 -0
  2. app/api/__init__.py +0 -0
  3. app/api/endpoints/__init__.py +0 -0
  4. app/api/endpoints/messages.py +114 -0
  5. app/api/endpoints/models.py +117 -0
  6. app/api/endpoints/projects.py +141 -0
  7. app/api/endpoints/query.py +413 -0
  8. app/api/endpoints/users.py +44 -0
  9. app/exceptions/__init__.py +0 -0
  10. app/exceptions/exceptions.py +4 -0
  11. app/services/__init__.py +0 -0
  12. app/services/query_service.py +2 -0
  13. build_query/__init__.py +0 -0
  14. build_query/constraint_simplifier.py +1066 -0
  15. build_query/converstion_history.py +336 -0
  16. build_query/examples_executor.py +870 -0
  17. build_query/examples_generator.py +424 -0
  18. build_query/other.py +62 -0
  19. build_query/profile_checker.py +342 -0
  20. build_query/profiler.py +2135 -0
  21. build_query/prompt_tools.py +864 -0
  22. build_query/query_chain.py +172 -0
  23. build_query/query_executor.py +125 -0
  24. build_query/routing.py +113 -0
  25. build_query/schema_fetcher.py +138 -0
  26. build_query/state.py +48 -0
  27. build_query/test_evaluator.py +124 -0
  28. build_query/validator.py +546 -0
  29. cli/__init__.py +0 -0
  30. cli/generate.py +388 -0
  31. cli/main.py +321 -0
  32. cli/test_runner.py +323 -0
  33. common_vars.py +365 -0
  34. fetch_secrets.py +34 -0
  35. init/__init__.py +0 -0
  36. init/add_column.py +153 -0
  37. init/add_table.py +129 -0
  38. init/create_user.py +72 -0
  39. init/grant_access_to_db.py +31 -0
  40. init/init_db.py +154 -0
  41. mocksql-0.1.0.dist-info/METADATA +232 -0
  42. mocksql-0.1.0.dist-info/RECORD +80 -0
  43. mocksql-0.1.0.dist-info/WHEEL +4 -0
  44. mocksql-0.1.0.dist-info/entry_points.txt +10 -0
  45. models/__init__.py +0 -0
  46. models/database.py +66 -0
  47. models/db_pool.py +161 -0
  48. models/env_variables.py +41 -0
  49. models/message_service.py +286 -0
  50. models/model.py +2 -0
  51. models/model_service.py +2 -0
  52. models/permissions.py +66 -0
  53. models/schemas.py +122 -0
  54. models/session_service.py +105 -0
  55. models/user_service.py +99 -0
  56. server.py +95 -0
  57. sql_functions/__init__.py +0 -0
  58. sql_functions/functions.py +50 -0
  59. sql_functions/helpers.py +276 -0
  60. storage/__init__.py +0 -0
  61. storage/config.py +47 -0
  62. storage/test_repository.py +257 -0
  63. utils/__init__.py +19 -0
  64. utils/bigquery_test_helper.py +416 -0
  65. utils/duckdb_test_helper.py +235 -0
  66. utils/errors.py +376 -0
  67. utils/examples.py +754 -0
  68. utils/find_grains.py +822 -0
  69. utils/insert_examples.py +366 -0
  70. utils/llm_errors.py +25 -0
  71. utils/logger.py +0 -0
  72. utils/models.py +141 -0
  73. utils/msg_types.py +25 -0
  74. utils/postgres_db_utils.py +39 -0
  75. utils/postgres_test_helper.py +178 -0
  76. utils/prompt_utils.py +27 -0
  77. utils/query_services.py +17 -0
  78. utils/saver.py +270 -0
  79. utils/schema_utils.py +125 -0
  80. utils/sql_code.py +224 -0
app/__init__.py ADDED
File without changes
app/api/__init__.py ADDED
File without changes
File without changes
@@ -0,0 +1,114 @@
1
+ import logging
2
+ from typing import List, Any, Optional
3
+
4
+ from fastapi import APIRouter, HTTPException
5
+ from pydantic import BaseModel
6
+
7
+ from storage.config import is_initialized
8
+ from storage.test_repository import get_test, update_test
9
+ from utils.saver import common_history_retriever
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ router = APIRouter()
14
+
15
+
16
class MessageRequest(BaseModel):
    """Request body for ``POST /getMessages``."""

    # Passed to both the history retriever and the test lookup by get_messages.
    modelId: str
18
+
19
+
20
class PatchTestsRequest(BaseModel):
    """Request body for ``PATCH /models/tests``."""

    # Identifier of the stored test record to update.
    sessionId: str
    # Written to the record's "test_cases" field as-is.
    tests: List[Any]
23
+
24
+
25
class PatchSqlRequest(BaseModel):
    """Request body for ``PATCH /models/sql``.

    Only ``sessionId`` and ``sql`` are required; optional fields left at
    ``None`` are not written by the handler.
    """

    # Identifier of the stored test record to update.
    sessionId: str
    sql: str
    optimized_sql: str = ""
    # Both `tests` and `test_results` are stored under "test_cases" by
    # patch_model_sql; when both are provided, `test_results` wins.
    tests: Optional[List[Any]] = None
    test_results: Optional[List[Any]] = None
    restored_message_id: Optional[str] = None
    last_error: Optional[str] = None
33
+
34
+
35
@router.post("/getMessages")
async def get_messages(body: MessageRequest):
    """Return the message history and the stored test state for a session.

    ``body.modelId`` is used as the identifier for both the history lookup
    and the stored test lookup.

    Returns:
        A dict with the keys ``messages``, ``sql``, ``optimized_sql``,
        ``test_results``, ``restored_message_id``, ``last_error`` and
        ``sql_history`` (always an empty list here).

    Raises:
        HTTPException: 400 when the project is not initialized, 404 when the
            history retriever returns ``None``, 500 on any other failure.
    """
    if not is_initialized():
        raise HTTPException(
            status_code=400,
            detail="Projet non initialisé. Lancez 'mocksql init' dans votre répertoire de travail pour commencer.",
        )
    try:
        history = await common_history_retriever(body.modelId, filtered_types=[])
        if history is None:
            raise HTTPException(status_code=404, detail="Session not found")

        # A missing test record behaves like an empty one for every lookup below.
        stored = get_test(body.modelId) or {}
        test_results = stored.get("test_cases", [])

        # Fallback: extract from the last "results" message in the history.
        if not test_results:
            import json

            for msg in reversed(history):
                if msg.additional_kwargs.get("type") != "results":
                    continue
                try:
                    test_results = json.loads(msg.content)
                except Exception:
                    # Best effort: keep the empty results, but leave a trace
                    # instead of swallowing the decode failure silently.
                    logger.warning(
                        "Could not decode the last results message for session %s",
                        body.modelId,
                        exc_info=True,
                    )
                # Only the most recent results message is considered.
                break

        return {
            "messages": history,
            "sql": stored.get("sql"),
            "optimized_sql": stored.get("optimized_sql"),
            "test_results": test_results,
            "restored_message_id": stored.get("restored_message_id"),
            "last_error": stored.get("last_error") or "",
            "sql_history": [],
        }
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 404 above) through untouched.
        raise
    except Exception as e:
        logger.exception(
            "Erreur lors du chargement des messages pour la session %s", body.modelId
        )
        raise HTTPException(status_code=500, detail=str(e))
82
+
83
+
84
@router.patch("/models/sql")
async def patch_model_sql(body: PatchSqlRequest):
    """Persist the SQL (and optional test state) of a session's test record.

    ``sql`` and ``optimized_sql`` are always written. The optional fields are
    written only when they are not ``None``.

    Returns:
        ``{"ok": True}`` on success.

    Raises:
        HTTPException: 500 with a generic message on any failure.
    """
    try:
        updates: dict = {
            "sql": body.sql,
            "optimized_sql": body.optimized_sql,
        }
        # Both fields target "test_cases"; `test_results` is applied last and
        # therefore takes precedence when both are supplied.
        if body.tests is not None:
            updates["test_cases"] = body.tests
        if body.test_results is not None:
            updates["test_cases"] = body.test_results
        if body.restored_message_id is not None:
            # An empty string is stored as None.
            updates["restored_message_id"] = body.restored_message_id or None
        if body.last_error is not None:
            updates["last_error"] = body.last_error

        update_test(body.sessionId, updates)
        return {"ok": True}
    except Exception:
        # Use the module logger (with traceback) rather than print, matching
        # the error handling of get_messages.
        logger.exception(
            "Failed to update SQL for session %s", body.sessionId
        )
        raise HTTPException(status_code=500, detail="Internal Server Error")
105
+
106
+
107
@router.patch("/models/tests")
async def patch_model_tests(body: PatchTestsRequest):
    """Replace the ``test_cases`` of a session's test record.

    Returns:
        ``{"ok": True}`` on success.

    Raises:
        HTTPException: 500 with a generic message on any failure.
    """
    try:
        update_test(body.sessionId, {"test_cases": body.tests})
        return {"ok": True}
    except Exception:
        # Use the module logger (with traceback) rather than print, matching
        # the error handling of get_messages.
        logger.exception(
            "Failed to update tests for session %s", body.sessionId
        )
        raise HTTPException(status_code=500, detail="Internal Server Error")
@@ -0,0 +1,117 @@
1
+ from fastapi import APIRouter, HTTPException
2
+ from fastapi.responses import JSONResponse
3
+ from pydantic import BaseModel
4
+
5
+ from storage.test_repository import (
6
+ list_models,
7
+ list_all_tests,
8
+ list_tests,
9
+ get_test,
10
+ create_test,
11
+ delete_test,
12
+ _test_path,
13
+ _read_json,
14
+ )
15
+
16
+ router = APIRouter()
17
+
18
+
19
+ # ---------------------------------------------------------------------------
20
+ # SQL Models (fichiers physiques dans models_path)
21
+ # ---------------------------------------------------------------------------
22
+
23
+
24
@router.get("/models")
async def get_models():
    """List the .sql files available in models_path.

    Each entry includes session_id / updated_at / test_name when a test already
    exists for that model, so the frontend can redirect instead of regenerating.
    Tested models are returned first (sorted by updated_at desc), then untested.
    """
    try:
        with_test = []
        without_test = []
        for sql_file in list_models():
            test_file = _test_path(sql_file["name"])
            # Only read the test file when it exists; an empty or unreadable
            # payload counts as "no test".
            payload = _read_json(test_file) if test_file.exists() else None
            if not payload:
                without_test.append(sql_file)
                continue
            with_test.append(
                {
                    **sql_file,
                    "session_id": payload.get("test_id"),
                    "updated_at": payload.get("updated_at"),
                    "test_name": payload.get("test_name"),
                }
            )
        # Entries without a timestamp sort as the empty string, i.e. last.
        with_test.sort(key=lambda entry: entry.get("updated_at") or "", reverse=True)
        return with_test + without_test
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
55
+
56
+
57
@router.get("/tests/all")
async def get_all_tests():
    """List every test session across all models, most recent first.

    Ordering is whatever ``list_all_tests`` returns; any failure is reported
    as a 500 carrying the error text.
    """
    try:
        sessions = list_all_tests()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return sessions
64
+
65
+
66
+ # ---------------------------------------------------------------------------
67
+ # Tests (fichiers dans .mocksql/tests/)
68
+ # ---------------------------------------------------------------------------
69
+
70
+
71
class CreateTestRequest(BaseModel):
    """Request body for ``POST /tests``."""

    # Name of the model the new test is created for.
    model_name: str
73
+
74
+
75
@router.get("/tests")
async def get_tests(model_name: str):
    """List all tests for the given model.

    Any failure is reported as a 500 carrying the error text.
    """
    try:
        tests = list_tests(model_name)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return tests
82
+
83
+
84
@router.get("/test/{session_id}")
async def get_test_route(session_id: str, model_name: str = None):
    """Return the stored test for ``session_id``.

    ``model_name`` is optional and forwarded to the lookup unchanged.

    Raises:
        HTTPException: 404 when no test is found, 500 when the lookup fails.
    """
    try:
        found = get_test(session_id, model_name)
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    if found is None:
        raise HTTPException(status_code=404, detail="Test not found")
    return found
95
+
96
+
97
@router.post("/tests")
async def create_test_route(body: CreateTestRequest):
    """Create a new test for a model and return it with status 201.

    Any failure (creation or response building) is reported as a 500
    carrying the error text.
    """
    try:
        created = create_test(body.model_name)
        response = JSONResponse(content=created, status_code=201)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return response
105
+
106
+
107
@router.delete("/tests/{session_id}")
async def delete_test_route(session_id: str, model_name: str):
    """Delete the test identified by ``session_id`` for ``model_name``.

    Returns:
        An empty 204 response on success.

    Raises:
        HTTPException: 404 when ``delete_test`` reports nothing was deleted,
            500 on any other failure.
    """
    # Local import: only this handler needs the plain Response class.
    from fastapi import Response

    try:
        deleted = delete_test(session_id, model_name)
        if not deleted:
            raise HTTPException(status_code=404, detail="Test not found")
        # A 204 response must not carry a body; sending JSON content with it
        # yields a Content-Length mismatch at the HTTP layer.
        return Response(status_code=204)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@@ -0,0 +1,141 @@
1
+ import json
2
+ import uuid
3
+ from datetime import datetime
4
+ from typing import Optional
5
+
6
+ from fastapi import APIRouter, HTTPException
7
+ from fastapi.responses import JSONResponse
8
+ from pydantic import BaseModel
9
+
10
+ from common_vars import PROJECTS_TABLE_NAME, USERS_TABLE_NAME
11
+ from models.database import execute, query
12
+ from models.permissions import grant_role
13
+
14
+ router = APIRouter()
15
+
16
+
17
class ProjectRequest(BaseModel):
    """Request body for ``POST /projects`` (create or update a project)."""

    # When omitted, create_project generates a fresh UUID.
    project_id: Optional[str] = None
    name: str
    dialect: str
    description: Optional[str] = None
    service_account_key: Optional[str] = None
    auto_import: Optional[bool] = False
    user_sub: Optional[str] = ""
25
+
26
+
27
class ShareProjectRequest(BaseModel):
    """Request body for ``POST /projects/share``."""

    # Project identifier passed to grant_role.
    project: str
    # Matched against either the email or the user_id column of the users table.
    target: str
30
+
31
+
32
+ def _row_to_project(row: dict) -> dict:
33
+ return {
34
+ "project_id": row["project_id"],
35
+ "name": row["name"],
36
+ "dialect": row["dialect"],
37
+ "description": row.get("description") or "",
38
+ "service_account_key": row.get("service_account_key"),
39
+ "auto_import": row.get("auto_import") or False,
40
+ "schema": json.loads(row["json_schema"]) if row.get("json_schema") else [],
41
+ }
42
+
43
+
44
@router.get("/projects")
async def get_projects():
    """Return every project in the projects table as API dicts."""
    rows = await query(f"SELECT * FROM {PROJECTS_TABLE_NAME}")
    if not rows:
        # The query helper may return None or an empty result.
        return []
    return [_row_to_project(row) for row in rows]
48
+
49
+
50
@router.get("/project/{project_id}")
async def get_project(project_id: str):
    """Return a single project by id, or 404 when it does not exist."""
    select_sql = f"SELECT * FROM {PROJECTS_TABLE_NAME} WHERE project_id = $1"
    matches = await query(select_sql, (project_id,))
    if matches:
        return _row_to_project(matches[0])
    raise HTTPException(status_code=404, detail="Project not found")
59
+
60
+
61
@router.post("/projects")
async def create_project(body: ProjectRequest):
    """Insert a project, or update it when ``project_id`` already exists.

    On conflict, every column except ``project_id``, ``user_sub`` and
    ``created_at`` is overwritten. The stored row is then re-read and
    returned with status 201.
    """
    project_id = body.project_id or str(uuid.uuid4())
    # One timestamp serves as both created_at and updated_at.
    timestamp = datetime.now().isoformat()
    params = (
        project_id,
        body.name,
        body.dialect,
        body.description or "",
        body.service_account_key,
        body.auto_import or False,
        body.user_sub or "",
        timestamp,
        timestamp,
    )
    await execute(
        f"""
        INSERT INTO {PROJECTS_TABLE_NAME}
            (project_id, name, dialect, description, service_account_key, auto_import, user_sub, created_at, updated_at)
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
        ON CONFLICT (project_id) DO UPDATE SET
            name = EXCLUDED.name,
            dialect = EXCLUDED.dialect,
            description = EXCLUDED.description,
            service_account_key = EXCLUDED.service_account_key,
            auto_import = EXCLUDED.auto_import,
            updated_at = EXCLUDED.updated_at
        """,
        *params,
    )
    stored = await query(
        f"SELECT * FROM {PROJECTS_TABLE_NAME} WHERE project_id = $1",
        (project_id,),
    )
    payload = _row_to_project(stored[0])
    return JSONResponse(content=payload, status_code=201)
93
+
94
+
95
@router.delete("/projects/{project_id}")
async def delete_project(project_id: str):
    """Delete a project by id.

    Returns:
        An empty 204 response on success.

    Raises:
        HTTPException: 404 when no project has this id.
    """
    # Local import: only this handler needs the plain Response class.
    from fastapi import Response

    rows = await query(
        f"SELECT project_id FROM {PROJECTS_TABLE_NAME} WHERE project_id = $1",
        (project_id,),
    )
    if not rows:
        raise HTTPException(status_code=404, detail="Project not found")
    await execute(
        f"DELETE FROM {PROJECTS_TABLE_NAME} WHERE project_id = $1",
        project_id,
    )
    # A 204 response must not carry a body; sending JSON content with it
    # yields a Content-Length mismatch at the HTTP layer.
    return Response(status_code=204)
108
+
109
+
110
@router.delete("/projects/{project_id}/table/{table_name}")
async def delete_project_table(project_id: str, table_name: str):
    """Remove a table entry from a project's stored JSON schema.

    Returns the removed table name and the number of remaining entries.
    No error is raised when the table was not in the schema. A 404 is
    raised only when the project itself does not exist.
    """
    found = await query(
        f"SELECT json_schema FROM {PROJECTS_TABLE_NAME} WHERE project_id = $1",
        (project_id,),
    )
    if not found:
        raise HTTPException(status_code=404, detail="Project not found")

    # A missing or empty schema is treated as an empty list.
    raw_schema = found[0].get("json_schema") or "[]"
    tables = json.loads(raw_schema)
    kept = [entry for entry in tables if entry.get("table_name") != table_name]

    serialized = json.dumps(kept, ensure_ascii=False)
    await execute(
        f"UPDATE {PROJECTS_TABLE_NAME} SET json_schema = $1, updated_at = $2 WHERE project_id = $3",
        serialized,
        datetime.now().isoformat(),
        project_id,
    )
    return {"removed": table_name, "remaining": len(kept)}
129
+
130
+
131
@router.post("/projects/share")
async def share_project(body: ShareProjectRequest):
    """Grant the "user" role on a project to another user.

    ``body.target`` is matched against either the email or the user_id
    column of the users table; the first match receives the role.

    Raises:
        HTTPException: 404 when no user matches ``body.target``.
    """
    lookup_sql = (
        f"SELECT user_id FROM {USERS_TABLE_NAME} WHERE email = $1 OR user_id = $1"
    )
    matches = await query(lookup_sql, (body.target,))
    if not matches:
        raise HTTPException(status_code=404, detail=f"User not found: {body.target}")

    recipient_id = matches[0]["user_id"]
    await grant_role(recipient_id, body.project, role="user")
    return {"shared": True, "user_id": recipient_id}