bulldb 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bulldb-1.0.0/PKG-INFO ADDED
@@ -0,0 +1,19 @@
1
+ Metadata-Version: 2.4
2
+ Name: bulldb
3
+ Version: 1.0.0
4
+ Summary: One Model. Every Database.
5
+ Classifier: Programming Language :: Python :: 3
6
+ Classifier: License :: OSI Approved :: MIT License
7
+ Classifier: Operating System :: OS Independent
8
+ Requires-Python: >=3.9
9
+ Description-Content-Type: text/markdown
10
+ Requires-Dist: pydantic>=2.0.0
11
+ Requires-Dist: cryptography>=41.0.0
12
+ Requires-Dist: passlib>=1.7.4
13
+ Requires-Dist: opentelemetry-api>=1.20.0
14
+ Requires-Dist: opentelemetry-sdk>=1.20.0
15
+ Requires-Dist: httpx>=0.24.0
16
+ Requires-Dist: redis>=5.0.0
17
+ Provides-Extra: test
18
+ Requires-Dist: pytest>=7.0.0; extra == "test"
19
+ Requires-Dist: pytest-asyncio>=0.21.0; extra == "test"
@@ -0,0 +1,48 @@
1
+ from .models import BaseModel, PrimaryKey, Unique, Index, Relationship
2
+ from .database import MultiDatabase
3
+ from .generator import ModelGenerator
4
+ from .adapters import FastAPILifespan, FlaskMiddleware, SanicMiddleware
5
+ from .types import (
6
+ UUID, ULID, Email, Phone, URL, IPAddress, JSON, JSONB, Array,
7
+ Enum, Money, Decimal, TimestampTZ, EncryptedString, HashedPassword,
8
+ Secret, Binary, GeoPoint, Polygon, Vector, Embedding, Document,
9
+ ImageEmbedding, AudioEmbedding, VideoEmbedding
10
+ )
11
+
12
# Public API of the package, grouped by the module each name is imported from.
__all__ = [
    # Modelling primitives (.models)
    "BaseModel", "PrimaryKey", "Unique", "Index", "Relationship",
    # Connection routing (.database) and code generation (.generator)
    "MultiDatabase", "ModelGenerator",
    # Web-framework integrations (.adapters)
    "FastAPILifespan", "FlaskMiddleware", "SanicMiddleware",
    # Field types (.types)
    "UUID", "ULID", "Email", "Phone", "URL", "IPAddress", "JSON", "JSONB", "Array",
    "Enum", "Money", "Decimal", "TimestampTZ", "EncryptedString", "HashedPassword",
    "Secret", "Binary", "GeoPoint", "Polygon", "Vector", "Embedding", "Document",
    "ImageEmbedding", "AudioEmbedding", "VideoEmbedding",
]
@@ -0,0 +1,89 @@
1
+ from contextlib import asynccontextmanager
2
+ from typing import Callable, Any
3
+
4
class FastAPILifespan:
    """Lifespan handler that opens BullDB pools on startup and closes them on shutdown.

    Pass the bound ``lifespan`` method as the application's lifespan context
    manager; it receives the app object but does not use it.
    """

    def __init__(self, db_client: Any):
        # Client exposing async connect_all()/disconnect_all() (e.g. MultiDatabase).
        self.db_client = db_client

    @asynccontextmanager
    async def lifespan(self, app: Any):
        """Connect all pools, register the client on ``BaseModel``, run the app, then clean up.

        ``disconnect_all()`` runs in a ``finally`` block. This covers the case where
        startup fails part-way and the case where the application exits with an
        exception. Previously the cleanup after ``yield`` was skipped in both
        cases, leaking open connections.
        """
        try:
            await self.db_client.connect_all()
            # Imported lazily to avoid a circular import with the models module.
            from .models import BaseModel
            BaseModel.set_db(self.db_client)

            yield  # Application executes
        finally:
            await self.db_client.disconnect_all()
20
+
21
+
22
class FlaskMiddleware:
    """Attach a BullDB client to a Flask application.

    Construction immediately binds ``db_client`` to ``BaseModel`` and registers
    a ``before_request`` and a ``teardown_request`` hook on ``app``. Both hooks
    are currently empty placeholders. This class does not call any
    connect/disconnect method on the client.
    """

    def __init__(self, app: Any, db_client: Any):
        self.app = app
        self.db_client = db_client
        self.setup_lifecycle()

    def setup_lifecycle(self):
        """Bind the client to ``BaseModel`` and register the placeholder request hooks."""
        from .models import BaseModel

        BaseModel.set_db(self.db_client)

        def before_request_hook():
            """Placeholder hook run before each request; intentionally empty."""

        def teardown_request_hook(exception=None):
            """Placeholder hook run at request teardown; intentionally empty."""

        self.app.before_request(before_request_hook)
        self.app.teardown_request(teardown_request_hook)
40
+
41
+
42
class SanicMiddleware:
    """Wire BullDB connection management into a Sanic application's server listeners."""

    @staticmethod
    def register(app: Any, db_client: Any):
        """Register startup/shutdown listeners on *app* for *db_client*.

        ``before_server_start`` connects every pool and binds the client to
        ``BaseModel``; ``after_server_stop`` disconnects every pool.
        """
        from .models import BaseModel

        async def setup_db(app_instance, loop):
            await db_client.connect_all()
            BaseModel.set_db(db_client)

        async def close_db(app_instance, loop):
            await db_client.disconnect_all()

        app.listener("before_server_start")(setup_db)
        app.listener("after_server_stop")(close_db)
55
+
56
+
57
class DjangoMiddleware:
    """Django middleware that binds a BullDB client to ``BaseModel`` for each request.

    Django constructs middleware with ``get_response`` only. ``db_client`` is an
    optional extra for manual construction. Without it, the client is resolved
    lazily from the package's ``database`` module: its ``db`` attribute is used
    if present. Otherwise a ``MultiDatabase()`` is created and stored there as
    ``database.db``, so other integrations reuse the same instance.

    The original code did ``from .database import db``. That module defines no
    ``db``, so every request raised ``ImportError``.
    """

    def __init__(self, get_response, db_client: Any = None):
        self.get_response = get_response
        self.db_client = db_client

    def _resolve_db(self) -> Any:
        """Return the client to bind, creating and caching a shared default if needed."""
        if self.db_client is None:
            from . import database
            shared = getattr(database, "db", None)
            if shared is None:
                shared = database.MultiDatabase()
                database.db = shared
            self.db_client = shared
        return self.db_client

    def __call__(self, request):
        from .models import BaseModel

        BaseModel.set_db(self._resolve_db())
        return self.get_response(request)
68
+
69
+
70
class DjangoLifecycle:
    """Bind a BullDB client to ``BaseModel`` at the start of every Django request via signals."""

    @staticmethod
    def register(db_client: Any = None):
        """Connect a ``request_started`` receiver that calls ``BaseModel.set_db``.

        Args:
            db_client: Client to bind. If omitted, the package ``database``
                module's ``db`` attribute is used when it exists. Otherwise a
                ``MultiDatabase()`` is created and stored there as the shared default.

        Returns quietly when Django is not installed. Other errors now propagate.
        Previously a blanket ``except Exception: pass`` also hid the
        ``ImportError`` raised by the non-existent ``database.db``, so
        registration silently never happened.
        """
        try:
            from django.core.signals import request_started
        except ImportError:
            return  # Django not installed: nothing to hook into.

        from .models import BaseModel

        target_db = db_client
        if target_db is None:
            from . import database
            target_db = getattr(database, "db", None)
            if target_db is None:
                target_db = database.MultiDatabase()
                database.db = target_db

        def on_request_started(sender, **kwargs):
            BaseModel.set_db(target_db)

        # Django stores receivers as weak references by default. This receiver is a
        # local function, so without weak=False it would be garbage-collected as
        # soon as register() returns and would never fire.
        request_started.connect(on_request_started, weak=False)
@@ -0,0 +1,121 @@
1
+ import os
2
+ from typing import List, Dict, Any, Optional
3
+ import httpx
4
+
5
class AIEngine:
    """Generate text embeddings from a provider, with a deterministic offline fallback.

    Providers, selected by ``provider``:

    * ``"openai"``: ``text-embedding-3-small``; requires ``OPENAI_API_KEY``.
    * ``"gemini"``: ``text-embedding-004``; requires ``GEMINI_API_KEY``.
    * ``"ollama"``: ``nomic-embed-text`` at ``OLLAMA_URL``
      (default ``http://localhost:11434``).

    An unknown provider, a missing API key, or any request/parsing error yields
    the hash-based vector from ``_generate_mock_vector``.

    NOTE(review): provider vector lengths are not checked against the
    1536-element zero/mock vectors. Confirm the embedding column tolerates this.
    """

    # Real (non-mock) embeddings keyed by (provider, cleaned text). The provider is
    # part of the key so a vector from one provider is never returned for another.
    _embedding_cache: Dict[tuple, List[float]] = {}
    # Maximum number of cached embeddings; the oldest entry is evicted first.
    _embedding_cache_max_size: int = 10_000

    @classmethod
    async def generate_embeddings(cls, text: str, provider: str = "openai") -> List[float]:
        """Return an embedding vector for *text*.

        Newlines become spaces and surrounding whitespace is stripped first.
        Blank text yields a 1536-element zero vector.

        Args:
            text: Text to embed.
            provider: ``"openai"``, ``"gemini"`` or ``"ollama"``.

        Returns:
            The provider's embedding, which is cached. If the provider is
            unavailable or the request fails, returns the deterministic mock
            vector, which is not cached.
        """
        text = text.replace("\n", " ").strip()
        if not text:
            return [0.0] * 1536

        cache_key = (provider, text)
        cached = cls._embedding_cache.get(cache_key)
        if cached is not None:
            return cached

        try:
            vector = await cls._request_embedding(text, provider)
        except Exception as exc:  # transport, HTTP status, or unexpected response body
            import logging
            logging.getLogger("bulldb.ai").warning(
                "Embedding request to provider %r failed (%s); using deterministic mock vector.",
                provider,
                exc,
            )
            return cls._generate_mock_vector(text)

        if vector is None:
            # Unknown provider or missing credentials: deterministic offline fallback.
            return cls._generate_mock_vector(text)

        cls._cache_embedding(cache_key, vector)
        return vector

    @classmethod
    async def _request_embedding(cls, text: str, provider: str) -> Optional[List[float]]:
        """Fetch an embedding from *provider*; return None if it is unknown or unconfigured.

        Propagates ``httpx`` errors, and lookup/type errors if the response body
        has an unexpected shape.
        """
        if provider == "openai":
            api_key = os.getenv("OPENAI_API_KEY")
            if not api_key:
                return None
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.post(
                    "https://api.openai.com/v1/embeddings",
                    headers={"Authorization": f"Bearer {api_key}"},
                    json={"input": text, "model": "text-embedding-3-small"},
                )
                resp.raise_for_status()
                return resp.json()["data"][0]["embedding"]

        if provider == "gemini":
            gemini_key = os.getenv("GEMINI_API_KEY")
            if not gemini_key:
                return None
            async with httpx.AsyncClient(timeout=10.0) as client:
                # Send the key in a header instead of a ``?key=`` query parameter so
                # the secret is not part of the request URL, which can end up in
                # exception messages and logs.
                resp = await client.post(
                    "https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent",
                    headers={"x-goog-api-key": gemini_key},
                    json={"content": {"parts": [{"text": text}]}},
                )
                resp.raise_for_status()
                return resp.json()["embedding"]["values"]

        if provider == "ollama":
            ollama_url = os.getenv("OLLAMA_URL", "http://localhost:11434")
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.post(
                    f"{ollama_url}/api/embeddings",
                    json={"model": "nomic-embed-text", "prompt": text},
                )
                resp.raise_for_status()
                return resp.json()["embedding"]

        return None

    @classmethod
    def _cache_embedding(cls, key: tuple, vector: List[float]) -> None:
        """Store *vector* under *key*, evicting the oldest entries beyond the size limit."""
        limit = cls._embedding_cache_max_size
        if limit <= 0:
            return
        cache = cls._embedding_cache
        # dicts preserve insertion order, so the first key is the oldest entry.
        while len(cache) >= limit:
            del cache[next(iter(cache))]
        cache[key] = vector

    @classmethod
    def _generate_mock_vector(cls, text: str, dimension: int = 1536) -> List[float]:
        """Return a deterministic, L2-normalised pseudo-embedding derived from SHA-256 of *text*.

        Element ``i`` is taken from digest byte ``(i * 7) % 32``, mapped to
        ``[-0.5, 0.5]``. The values therefore repeat with period 32; this is a
        placeholder, not a semantic embedding.
        """
        import hashlib
        h = hashlib.sha256(text.encode("utf-8")).digest()
        vector = []
        for i in range(dimension):
            byte_idx = (i * 7) % len(h)
            weight = (h[byte_idx] / 255.0) - 0.5
            vector.append(weight)

        # L2 Normalize
        import math
        norm = math.sqrt(sum(x*x for x in vector))
        if norm > 0:
            vector = [x / norm for x in vector]
        return vector
79
+
80
class TextChunker:
    """Split text into overlapping windows of whitespace-delimited words."""

    @staticmethod
    def chunk_text(text: str, chunk_size: int = 500, chunk_overlap: int = 50) -> List[str]:
        """Split *text* into chunks of at most *chunk_size* words.

        Consecutive chunks share *chunk_overlap* words. Chunking stops once a
        chunk reaches the final word, so there is never a trailing chunk made
        only of words already contained in the previous chunk.

        Args:
            text: Input text; any whitespace separates words.
            chunk_size: Maximum words per chunk (>= 1).
            chunk_overlap: Words shared between neighbours (0 <= overlap < size).

        Returns:
            The chunks, each re-joined with single spaces. Returns ``[]`` for
            empty or whitespace-only text.

        Raises:
            ValueError: If the size/overlap combination cannot make progress.
                Such values previously returned only the first chunk and
                silently dropped the rest of the text.
        """
        if not text:
            return []
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        if not 0 <= chunk_overlap < chunk_size:
            raise ValueError(
                f"chunk_overlap must satisfy 0 <= chunk_overlap < chunk_size, "
                f"got chunk_overlap={chunk_overlap}, chunk_size={chunk_size}"
            )

        words = text.split()
        step = chunk_size - chunk_overlap
        chunks: List[str] = []
        for start in range(0, len(words), step):
            end = start + chunk_size
            chunks.append(" ".join(words[start:end]))
            if end >= len(words):
                break  # this chunk already reaches the last word
        return chunks
96
+
97
class RAGPipeline:
    """Ingest documents as chunked model instances and query them by vector similarity.

    ``model_class`` is assumed (not verified here) to accept field values as
    keyword arguments, to expose an async ``save()``, and to provide
    ``find().vector_search(field, vector, limit).execute()``.
    """

    def __init__(self, model_class: Any, embedding_field: str, text_field: str):
        self.model_class = model_class
        # Field that query_similarity searches against.
        self.embedding_field = embedding_field
        # Field that receives each chunk's text.
        self.text_field = text_field

    async def ingest_document(self, text: str, extra_meta: Optional[dict] = None) -> Any:
        """Chunk *text* with ``TextChunker.chunk_text`` and save one model instance per chunk.

        Args:
            text: Document text.
            extra_meta: Extra field values copied onto every instance. The chunk
                text always wins over an ``extra_meta`` key equal to
                ``text_field``. Previously such a key overwrote the text of every
                chunk.

        Returns:
            The saved instances, in chunk order.
        """
        meta = dict(extra_meta) if extra_meta else {}
        inserted_instances = []
        for chunk in TextChunker.chunk_text(text):
            payload = {**meta, self.text_field: chunk}
            # Presumably save() triggers the model's auto-embedding hook
            # (per the original comment); confirm in the models module.
            instance = self.model_class(**payload)
            await instance.save()
            inserted_instances.append(instance)
        return inserted_instances

    async def query_similarity(self, query_text: str, limit: int = 5, provider: str = "openai") -> List[Any]:
        """Embed *query_text* and return the *limit* nearest instances by ``embedding_field``.

        Args:
            query_text: Text to search for.
            limit: Maximum number of results.
            provider: Embedding provider for the query vector. It should match
                the provider used for the stored embeddings.
        """
        query_vector = await AIEngine.generate_embeddings(query_text, provider=provider)
        return await self.model_class.find().vector_search(self.embedding_field, query_vector, limit).execute()
@@ -0,0 +1,313 @@
1
+ import os
2
+ import urllib.parse
3
+ import asyncio
4
+ import logging
5
+ from typing import Dict, Any, List, Optional, Callable
6
+
7
# Module-wide logger for circuit-breaker, driver and retry diagnostics.
logger: logging.Logger = logging.getLogger("bulldb.database")
8
+
9
class CircuitBreakerOpenException(Exception):
    """Raised when a driver's circuit breaker rejects a request because it is OPEN."""
11
+
12
class CircuitBreaker:
    """Count consecutive failures and temporarily reject requests once a threshold is hit.

    States:
      * ``CLOSED``: requests allowed.
      * ``OPEN``: requests rejected until ``recovery_timeout`` seconds have
        elapsed, then the breaker moves to ``HALF-OPEN``.
      * ``HALF-OPEN``: requests allowed. A success closes the breaker. A
        failure reopens it, because ``failure_count`` is not reset on entering
        this state.

    Elapsed time is measured with ``time.monotonic()``. Wall-clock time can
    jump when the system clock is adjusted, which would hold the breaker open
    far too long or release it early.
    """

    def __init__(self, failure_threshold: int = 5, recovery_timeout: float = 10.0):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.failure_count = 0
        self.state = "CLOSED"  # CLOSED, OPEN, HALF-OPEN
        # time.monotonic() value recorded at the most recent transition to OPEN.
        self.last_state_change = 0.0

    def record_success(self):
        """Reset the failure count and close the breaker."""
        self.failure_count = 0
        self.state = "CLOSED"

    def record_failure(self):
        """Count a failure; open the breaker once ``failure_threshold`` is reached."""
        import time
        self.failure_count += 1
        if self.failure_count >= self.failure_threshold:
            self.state = "OPEN"
            self.last_state_change = time.monotonic()
            logger.warning(f"Circuit breaker tripped! State set to OPEN. Threshold: {self.failure_threshold}")

    def allow_request(self) -> bool:
        """Return True if a request may proceed, moving OPEN to HALF-OPEN after the timeout."""
        import time
        if self.state == "CLOSED":
            return True
        if self.state == "OPEN":
            if time.monotonic() - self.last_state_change >= self.recovery_timeout:
                self.state = "HALF-OPEN"
                logger.info("Circuit breaker entering HALF-OPEN state, checking system viability.")
                return True
            return False
        if self.state == "HALF-OPEN":
            return True
        return False
45
+
46
class DatabaseDriver:
    """Base class for backend drivers.

    Holds the driver name, the raw URL, its ``urllib.parse.urlparse`` result and
    a per-driver ``CircuitBreaker``. Every operation here is a placeholder:
    connect/disconnect do nothing, ``ping`` reports healthy, and the data
    operations return empty results without touching a database. Concrete
    backends override what they support.
    """

    def __init__(self, name: str, url: str):
        self.name, self.url = name, url
        self.parsed_url = urllib.parse.urlparse(url)
        self.circuit_breaker = CircuitBreaker()

    async def connect(self):
        """Open the underlying connection (nothing to do in the base class)."""

    async def disconnect(self):
        """Close the underlying connection (nothing to do in the base class)."""

    async def ping(self) -> bool:
        """Report whether the connection is usable; the base class is always healthy."""
        return True

    async def ensure_connected(self):
        """Reconnect when ``ping`` reports an unusable connection.

        If ``ping`` raises, or the first ``connect`` attempt raises, ``connect``
        is tried once more and its exception (if any) propagates.
        """
        try:
            healthy = await self.ping()
            if not healthy:
                await self.connect()
        except Exception:
            await self.connect()

    async def execute(self, query: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]:
        """Run a raw query; the base class returns no rows."""
        return []

    async def insert(self, table: str, payload: Dict[str, Any]) -> Any:
        """Insert *payload* into *table*; the base class returns an empty dict."""
        return {}

    async def update(self, table: str, payload: Dict[str, Any], filters: Dict[str, Any]) -> Any:
        """Update rows matching *filters*; the base class returns an empty dict."""
        return {}

    async def delete(self, table: str, filters: Dict[str, Any]) -> Any:
        """Delete rows matching *filters*; the base class returns an empty dict."""
        return {}

    async def execute_mongo_find(self, collection: str, filters: Dict[str, Any], projection: Dict[str, Any], limit: Optional[int]) -> List[Dict[str, Any]]:
        """Run ``self.db[collection].find(filters, projection)`` and return all documents.

        Returns ``[]`` when there is no ``db`` attribute or when ``db`` is a plain
        dict (the mock mode MongoDriver uses when pymongo is missing). A truthy
        *limit* caps the number of documents.
        """
        if not hasattr(self, "db") or isinstance(self.db, dict):
            return []
        cursor = self.db[collection].find(filters, projection)
        return list(cursor.limit(limit) if limit else cursor)

    async def execute_search(self, index: str, body: Dict[str, Any], limit: int) -> List[Dict[str, Any]]:
        """Full-text search hook; the base class returns no hits."""
        return []

    async def execute_vector_search(self, collection: str, vector: Optional[List[float]], filters: Dict[str, Any], limit: int) -> List[Dict[str, Any]]:
        """Vector similarity search hook; the base class returns no hits."""
        return []
94
+
95
class SQLiteDriver(DatabaseDriver):
    """Driver backed by the standard-library ``sqlite3`` module.

    URL handling:

    * ``sqlite://`` or ``sqlite:///:memory:``: in-memory database
    * ``sqlite:///relative/path.db``: path relative to the working directory
    * ``sqlite:////absolute/path.db``: absolute path ``/absolute/path.db``

    Table and column names are interpolated into the SQL text, because they
    cannot be bound as parameters. They must therefore be plain identifiers,
    optionally dot-qualified; anything else raises ``ValueError``. Values are
    always bound with ``?`` placeholders.

    NOTE: ``sqlite3`` calls are synchronous and block the event loop while they run.
    NOTE: reconnecting to an in-memory database (e.g. via ``ensure_connected``)
    creates a new, empty database.
    """

    def _database_path(self) -> str:
        """Translate the URL path into a ``sqlite3.connect`` target."""
        path = self.parsed_url.path
        # Remove only the single separator that follows the (empty) authority:
        # "sqlite:///a.db" -> "a.db", "sqlite:////abs/a.db" -> "/abs/a.db".
        # Stripping every "/" used to turn absolute paths into relative ones.
        if path.startswith("/"):
            path = path[1:]
        return path or ":memory:"

    @staticmethod
    def _identifier(name: str) -> str:
        """Return *name* if it is a safe SQL identifier, else raise ``ValueError``.

        Accepts identifier-like names (``str.isidentifier``), optionally
        qualified with dots (``schema.table``). This prevents SQL injection
        through table or column names.
        """
        if not isinstance(name, str) or not all(part.isidentifier() for part in name.split(".")):
            raise ValueError(f"Unsafe SQL identifier: {name!r}")
        return name

    async def connect(self):
        """(Re)open the connection; any previous connection is closed first."""
        import sqlite3
        db_path = self._database_path()
        if getattr(self, "conn", None) is not None:
            try:
                self.conn.close()
            except Exception:
                pass  # best effort: the old connection is being replaced anyway
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self.conn.row_factory = sqlite3.Row
        self.cursor = self.conn.cursor()

    async def disconnect(self):
        """Close the connection (errors ignored) and drop the handles."""
        if getattr(self, "conn", None) is not None:
            try:
                self.conn.close()
            except Exception:
                pass
        self.conn = None
        self.cursor = None

    async def ping(self) -> bool:
        """Return True if ``SELECT 1`` succeeds on the current cursor."""
        try:
            if not hasattr(self, "cursor") or self.cursor is None:
                return False
            self.cursor.execute("SELECT 1")
            return True
        except Exception:
            return False

    def _rollback_quietly(self) -> None:
        """Roll back the open transaction after a failed statement, ignoring errors."""
        conn = getattr(self, "conn", None)
        if conn is None:
            return
        try:
            conn.rollback()
        except Exception:
            pass  # the original error is what the caller needs to see

    async def execute(self, query: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]:
        """Execute *query*, commit, and return the result rows as dicts.

        Raises:
            CircuitBreakerOpenException: If the driver's circuit breaker is OPEN.
            Exception: Any ``sqlite3`` error, after recording a failure on the
                breaker and rolling back.
        """
        if not self.circuit_breaker.allow_request():
            raise CircuitBreakerOpenException(f"Database driver {self.name} circuit is OPEN")
        try:
            self.cursor.execute(query, params or ())
            # Fetch before committing so result rows are read from the statement
            # that produced them.
            rows = self.cursor.fetchall()
            self.conn.commit()
        except Exception:
            self.circuit_breaker.record_failure()
            self._rollback_quietly()
            raise
        self.circuit_breaker.record_success()
        return [dict(row) for row in rows]

    async def insert(self, table: str, payload: Dict[str, Any]) -> Any:
        """Insert one row built from *payload* and return *payload*."""
        table_sql = self._identifier(table)
        if payload:
            columns = ", ".join(self._identifier(column) for column in payload)
            placeholders = ", ".join("?" for _ in payload)
            query = f"INSERT INTO {table_sql} ({columns}) VALUES ({placeholders})"
        else:
            # "INSERT INTO t () VALUES ()" is a syntax error in SQLite.
            query = f"INSERT INTO {table_sql} DEFAULT VALUES"
        await self.execute(query, tuple(payload.values()))
        return payload

    async def update(self, table: str, payload: Dict[str, Any], filters: Dict[str, Any]) -> Any:
        """Set *payload* columns on rows equal to every *filters* value; return *payload*.

        Raises:
            ValueError: If *payload* or *filters* is empty. An empty filter set
                would target every row; previously it produced invalid SQL that
                also counted as a circuit-breaker failure.
        """
        if not payload:
            raise ValueError("update() requires at least one column to set")
        if not filters:
            raise ValueError("update() requires at least one filter; refusing to update every row")
        set_clause = ", ".join(f"{self._identifier(k)} = ?" for k in payload)
        where_clause = " AND ".join(f"{self._identifier(k)} = ?" for k in filters)
        query = f"UPDATE {self._identifier(table)} SET {set_clause} WHERE {where_clause}"
        await self.execute(query, (*payload.values(), *filters.values()))
        return payload

    async def delete(self, table: str, filters: Dict[str, Any]) -> Any:
        """Delete rows equal to every *filters* value and return True.

        Raises:
            ValueError: If *filters* is empty (refuses to delete every row).
        """
        if not filters:
            raise ValueError("delete() requires at least one filter; refusing to delete every row")
        where_clause = " AND ".join(f"{self._identifier(k)} = ?" for k in filters)
        query = f"DELETE FROM {self._identifier(table)} WHERE {where_clause}"
        await self.execute(query, tuple(filters.values()))
        return True
159
+
160
class MongoDriver(DatabaseDriver):
    """MongoDB driver using ``pymongo`` when it is installed.

    Without ``pymongo`` the driver falls back to a mock mode: ``self.db`` is an
    empty dict and writes are accepted but not persisted. A warning is now
    logged when this happens; previously data was dropped silently.

    NOTE: pymongo calls are synchronous and block the event loop while they run.
    NOTE: pymongo's ``insert_one`` adds an ``_id`` key to the dict it is given,
    so *payload* is mutated by ``insert``.
    """

    async def connect(self):
        """(Re)create the ``MongoClient`` and select the database named in the URL path.

        The database name defaults to ``"bulldb"`` when the URL has no path.
        """
        try:
            from pymongo import MongoClient
        except ImportError:
            logger.warning(
                "pymongo is not installed; MongoDB driver %r runs in mock mode and will not persist data.",
                self.name,
            )
            self.db = {}
            return
        if getattr(self, "client", None) is not None:
            try:
                self.client.close()
            except Exception:
                pass  # best effort: the old client is being replaced anyway
        self.client = MongoClient(self.url)
        db_name = self.parsed_url.path.strip("/") or "bulldb"
        self.db = self.client[db_name]

    async def disconnect(self):
        """Close the client (errors ignored) and drop the reference."""
        client = getattr(self, "client", None)
        if client is not None:
            try:
                client.close()
            except Exception:
                pass
            self.client = None

    async def ping(self) -> bool:
        """Return True in mock mode, else whether the server answers ``ping``."""
        if isinstance(getattr(self, "db", None), dict):
            return True
        client = getattr(self, "client", None)
        if client is None:
            return False
        try:
            client.admin.command('ping')
            return True
        except Exception:
            return False

    async def insert(self, table: str, payload: Dict[str, Any]) -> Any:
        """Insert *payload* into collection *table* (no-op in mock mode); return *payload*."""
        if isinstance(self.db, dict): return payload
        self.db[table].insert_one(payload)
        return payload

    async def update(self, table: str, payload: Dict[str, Any], filters: Dict[str, Any]) -> Any:
        """``$set`` *payload* on every document matching *filters*; return *payload*."""
        if isinstance(self.db, dict): return payload
        self.db[table].update_many(filters, {"$set": payload})
        return payload

    async def delete(self, table: str, filters: Dict[str, Any]) -> Any:
        """Delete every document matching *filters*; return True."""
        if isinstance(self.db, dict): return True
        self.db[table].delete_many(filters)
        return True
206
+
207
class MultiDatabase:
    """Registry of named database drivers with shard, replica and primary routing.

    On construction, drivers are registered from environment variables
    (``SQLITE_URL``, ``POSTGRES_URL``, ``DATABASE_URL``, ``MYSQL_URL``,
    ``MONGO_URL``). The first one found becomes the primary. If none is set,
    an in-memory SQLite database is used.
    """

    def __init__(self):
        self.drivers: Dict[str, DatabaseDriver] = {}
        self.primary_name: Optional[str] = None
        self.replicas: List[str] = []
        self.shards: Dict[str, List[str]] = {}
        self.discover_environment()

    def discover_environment(self):
        """Register a driver for each supported environment variable that is set."""
        # NOTE(review): POSTGRES_URL and DATABASE_URL share the name "postgres", so when
        # both are set the DATABASE_URL driver replaces the POSTGRES_URL one.
        env_mappings = {
            "SQLITE_URL": "sqlite",
            "POSTGRES_URL": "postgres",
            "DATABASE_URL": "postgres",
            "MYSQL_URL": "mysql",
            "MONGO_URL": "mongo"
        }
        for env_var, engine in env_mappings.items():
            url = os.getenv(env_var)
            if url:
                self.register_database(engine, url)
                if not self.primary_name:
                    self.primary_name = engine

        # Default SQLite memory driver if no variables exist
        if not self.drivers:
            self.register_database("sqlite", "sqlite:///:memory:")
            self.primary_name = "sqlite"

    def register_database(self, name: str, url: str, is_replica: bool = False, shard_key: Optional[str] = None):
        """Create a driver for *url* under *name*, optionally as a replica and/or shard member.

        The driver class is chosen by URL prefix: ``sqlite`` selects SQLiteDriver
        and ``mongodb`` selects MongoDriver. Any other URL gets the base
        DatabaseDriver, whose operations are no-ops, and a warning is logged.
        Re-registering a name replaces its driver without duplicating its
        replica/shard entries.
        """
        if url.startswith("sqlite"):
            driver = SQLiteDriver(name, url)
        elif url.startswith("mongodb"):
            driver = MongoDriver(name, url)
        else:
            # Only the scheme is logged: the full URL may contain credentials.
            logger.warning(
                "No driver implementation for scheme %r (database %r); its operations will be no-ops.",
                urllib.parse.urlparse(url).scheme,
                name,
            )
            driver = DatabaseDriver(name, url)  # Generic wrapper

        self.drivers[name] = driver
        if is_replica and name not in self.replicas:
            self.replicas.append(name)

        if shard_key:
            members = self.shards.setdefault(shard_key, [])
            if name not in members:
                members.append(name)

    async def connect_all(self):
        """Connect every registered driver, in registration order."""
        for driver in self.drivers.values():
            await driver.connect()

    async def disconnect_all(self):
        """Disconnect every registered driver, in registration order."""
        for driver in self.drivers.values():
            await driver.disconnect()

    def get_route(self, table: str, is_write: bool = False, shard_id: Optional[str] = None) -> DatabaseDriver:
        """Pick a driver: a random shard member, else a random replica for reads, else the primary.

        *table* is currently unused by the routing logic.

        Raises:
            KeyError: If routing falls through to the primary and no primary
                driver is registered.
        """
        import random

        if shard_id and shard_id in self.shards:
            return self.drivers[random.choice(self.shards[shard_id])]

        if not is_write and self.replicas:
            return self.drivers[random.choice(self.replicas)]

        if self.primary_name not in self.drivers:
            raise KeyError(f"No primary database driver registered (primary_name={self.primary_name!r})")
        return self.drivers[self.primary_name]

    async def execute(self, query: str, params: Optional[tuple] = None, is_write: bool = False) -> List[Dict[str, Any]]:
        """Run a raw query on the routed driver, retrying with exponential backoff."""
        driver = self.get_route("", is_write=is_write)
        return await self._retry_with_backoff(driver, driver.execute, query, params)

    async def write(self, table: str, payload: Dict[str, Any], upsert: bool = False) -> Any:
        """Insert *payload* into *table* on the primary route.

        With ``upsert=True`` and a non-None ``id`` in *payload*, an existing row
        with that id is updated instead. The existence check uses SQL with ``?``
        placeholders.
        """
        driver = self.get_route(table, is_write=True)
        await driver.ensure_connected()
        if upsert:
            pk_field = "id"
            pk_val = payload.get(pk_field)
            # `is not None` so falsy keys such as 0 still take the upsert path.
            if pk_val is not None:
                check_query = f"SELECT {pk_field} FROM {table} WHERE {pk_field} = ?"
                exists = await driver.execute(check_query, (pk_val,))
                if exists:
                    filters = {pk_field: pk_val}
                    payload_no_pk = {k: v for k, v in payload.items() if k != pk_field}
                    return await driver.update(table, payload_no_pk, filters)
        return await driver.insert(table, payload)

    async def delete(self, table: str, filters: Dict[str, Any]) -> Any:
        """Delete rows matching *filters* from *table* on the primary route."""
        driver = self.get_route(table, is_write=True)
        await driver.ensure_connected()
        return await driver.delete(table, filters)

    async def _retry_with_backoff(self, driver: DatabaseDriver, func: Callable, *args, retries: int = 3, initial_delay: float = 0.5, **kwargs):
        """Call ``func(*args, **kwargs)`` up to *retries* times, doubling the delay between attempts.

        A ``CircuitBreakerOpenException`` is re-raised immediately. The breaker
        stays open for its recovery timeout, so retrying now cannot succeed. No
        sleep happens after the final attempt.

        Raises:
            The last exception raised by *func*, or a generic ``Exception`` if
            *retries* is less than 1.
        """
        delay = initial_delay
        last_exception = None
        for attempt in range(1, retries + 1):
            try:
                await driver.ensure_connected()
                return await func(*args, **kwargs)
            except CircuitBreakerOpenException:
                raise
            except Exception as e:
                last_exception = e
                if attempt == retries:
                    logger.warning("Database operation failed after %d attempt(s). Error: %s", retries, e)
                    break
                logger.warning("Database operation failed. Retry %d/%d in %ss. Error: %s", attempt, retries, delay, e)
                await asyncio.sleep(delay)
                delay *= 2
        raise last_exception or Exception("Database execution retries exhausted.")
@@ -0,0 +1,99 @@
1
+ import os
2
+ from typing import Any, List, Dict
3
+
4
class ModelGenerator:
    """Generate BullDB model source code by introspecting an existing database."""

    # Lines written at the top of every generated module.
    _HEADER = (
        "from bulldb import BaseModel, PrimaryKey, Unique, Index, UUID, Email, EncryptedString, HashedPassword\n",
        "# Automatically generated by BullDB Reverse-Engineering Generator\n",
    )

    @staticmethod
    async def reverse_engineer(db: Any, output_path: str):
        """Write one model class per table of *db*'s primary database to *output_path*.

        SQLite is introspected via ``sqlite_master``/``PRAGMA table_info``. Other
        databases are introspected via ``information_schema`` (``public`` schema).
        If introspection fails, a placeholder ``users`` table or placeholder
        columns are used, and a warning is logged. Parent directories of
        *output_path* are created as needed.

        Args:
            db: Object exposing ``primary_name``, an async ``execute(sql)``, and
                optionally a ``drivers`` mapping whose entries have a ``url``.
            output_path: Destination ``.py`` file; overwritten if it exists.
        """
        is_sqlite = ModelGenerator._is_sqlite(db)
        tables = await ModelGenerator._fetch_tables(db, is_sqlite)

        lines = list(ModelGenerator._HEADER)
        for table in tables:
            columns = await ModelGenerator._fetch_columns(db, table, is_sqlite)
            lines.extend(ModelGenerator._render_model(table, columns))

        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))

    @staticmethod
    def _warn(message: str, *args: Any) -> None:
        """Log a generator warning (used when falling back to placeholder schema)."""
        import logging
        logging.getLogger("bulldb.generator").warning(message, *args)

    @staticmethod
    def _is_sqlite(db: Any) -> bool:
        """True if the primary driver's name contains "sqlite" or its URL starts with "sqlite"."""
        name = db.primary_name or ""
        drivers = getattr(db, "drivers", None) or {}
        url = getattr(drivers.get(db.primary_name), "url", "") or ""
        return "sqlite" in name or url.startswith("sqlite")

    @staticmethod
    async def _fetch_tables(db: Any, is_sqlite: bool) -> List[str]:
        """List user tables, excluding internal ``sqlite_*`` and ``bulldb_*`` tables."""
        # "_" is a single-character wildcard in LIKE, so it is escaped to match literally.
        try:
            if is_sqlite:
                rows = await db.execute(
                    "SELECT name FROM sqlite_master WHERE type='table' "
                    "AND name NOT LIKE 'sqlite\\_%' ESCAPE '\\' "
                    "AND name NOT LIKE 'bulldb\\_%' ESCAPE '\\'"
                )
                return [r["name"] for r in rows]
            rows = await db.execute(
                "SELECT table_name FROM information_schema.tables WHERE table_schema='public' "
                "AND table_name NOT LIKE 'bulldb\\_%' ESCAPE '\\'"
            )
            return [r["table_name"] for r in rows]
        except Exception as exc:
            ModelGenerator._warn("Could not list tables (%s); using placeholder table 'users'.", exc)
            return ["users"]

    @staticmethod
    async def _fetch_columns(db: Any, table: str, is_sqlite: bool) -> List[Dict[str, Any]]:
        """Return ``{"name", "type", "pk"}`` dicts for *table*'s columns, in table order."""
        try:
            if is_sqlite:
                quoted = '"' + table.replace('"', '""') + '"'
                col_rows = await db.execute(f"PRAGMA table_info({quoted})")
                return [
                    {"name": r["name"], "type": r["type"].upper(), "pk": r["pk"] == 1}
                    for r in col_rows
                ]
            literal = table.replace("'", "''")
            col_rows = await db.execute(
                "SELECT column_name, data_type FROM information_schema.columns "
                f"WHERE table_schema = 'public' AND table_name = '{literal}' "
                "ORDER BY ordinal_position"
            )
            # Primary keys are not detected for information_schema sources.
            return [
                {"name": r["column_name"], "type": r["data_type"].upper(), "pk": False}
                for r in col_rows
            ]
        except Exception as exc:
            ModelGenerator._warn("Could not read columns of %r (%s); using placeholder columns.", table, exc)
            return [
                {"name": "id", "type": "TEXT", "pk": True},
                {"name": "email", "type": "TEXT", "pk": False},
                {"name": "secret_note", "type": "BLOB", "pk": False},
            ]

    @staticmethod
    def _class_name(table: str) -> str:
        """Derive a valid Python class name from *table* (CamelCase, one trailing "s" removed)."""
        import keyword
        name = "".join(part.capitalize() for part in table.split("_"))
        # Naive singularisation: drop ONE trailing "s" ("users" -> "User").
        # rstrip("s") used to drop them all ("access" -> "Acce").
        if name.endswith("s"):
            name = name[:-1]
        name = "".join(ch if ch.isalnum() or ch == "_" else "_" for ch in name)
        if not name:
            return "Model"
        if not name.isidentifier() or keyword.iskeyword(name):
            name = "Model" + name  # e.g. tables starting with a digit, or "None"
        return name

    @staticmethod
    def _field_line(col: Dict[str, Any]) -> str:
        """Render one model field, mapping the database type to a BullDB/Python type."""
        import keyword
        col_name, db_type, is_pk = col["name"], col["type"], col["pk"]
        if not col_name.isidentifier() or keyword.iskeyword(col_name):
            # Emitting it as a field would make the generated module unparsable.
            return f"    # skipped column {col_name!r} ({db_type}): not a valid Python identifier"

        if is_pk:
            mapped_type, decorator = "UUID", " = PrimaryKey()"
        elif "INT" in db_type:
            mapped_type, decorator = "int", ""
        elif "FLOAT" in db_type or "NUMERIC" in db_type or "REAL" in db_type:
            mapped_type, decorator = "float", ""
        elif "BLOB" in db_type or "BYTEA" in db_type:
            mapped_type, decorator = "str", " = EncryptedString()"
        elif "EMAIL" in col_name.upper():
            mapped_type, decorator = "Email", " = Unique()"
        else:
            mapped_type, decorator = "str", ""
        return f"    {col_name}: {mapped_type}{decorator}"

    @staticmethod
    def _render_model(table: str, columns: List[Dict[str, Any]]) -> List[str]:
        """Return the source lines of the model class for *table*."""
        lines = [
            f"class {ModelGenerator._class_name(table)}(BaseModel):",
            f"    __table_name__ = {table!r}\n",
        ]
        lines.extend(ModelGenerator._field_line(col) for col in columns)
        lines.append("\n")
        return lines