starspring 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,365 @@
1
+ """
2
+ Entity decorators and base classes for ORM
3
+
4
+ Provides Spring Boot-style entity annotations for database models.
5
+ """
6
+
7
+ from typing import Any, Optional, Type, TypeVar, List, get_type_hints
8
+ from datetime import datetime
9
+ from enum import Enum
10
+ import inspect
11
+
12
+
13
# Generic type variable used by the decorator signatures below (Type[T] -> Type[T]).
T = TypeVar("T")
14
+
15
+
16
class GenerationType(Enum):
    """Primary-key generation strategies, named after JPA's GenerationType."""

    AUTO = "auto"          # let the framework choose
    IDENTITY = "identity"  # database identity / auto-increment column
    SEQUENCE = "sequence"  # database sequence
    UUID = "uuid"          # UUID-based identifier
22
+
23
+
24
class ColumnMetadata:
    """
    Description of a single mapped column.

    Instances are created by Column(), Id() and GeneratedValue(). The Entity
    decorator fills in ``name`` and ``type`` from the attribute name and its
    annotation when they were left as None.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        type: Optional[Type] = None,
        nullable: bool = True,
        unique: bool = False,
        default: Any = None,
        length: Optional[int] = None,
        primary_key: bool = False,
        auto_increment: bool = False
    ):
        # What the column is called and which Python type it holds.
        self.name = name
        self.type = type

        # Constraints.
        self.nullable = nullable
        self.unique = unique

        # Python-side default value (may be a zero-argument callable).
        self.default = default

        # Maximum length, meaningful for string columns.
        self.length = length

        # Key / generation flags.
        self.primary_key = primary_key
        self.auto_increment = auto_increment
46
+
47
+
48
class RelationshipMetadata:
    """
    Description of a relationship from one entity class to another.

    ManyToOne(), OneToMany() and ManyToMany() create these with the
    ``relationship_type`` tags "many_to_one", "one_to_many" and
    "many_to_many" respectively.
    """

    def __init__(
        self,
        target_entity: Type,
        relationship_type: str,
        mapped_by: Optional[str] = None,
        cascade: Optional[List[str]] = None,
        lazy: bool = True
    ):
        self.target_entity = target_entity
        self.relationship_type = relationship_type
        # Field on the target entity that owns the relationship, if given.
        self.mapped_by = mapped_by
        # Any falsy value (None or an empty list) becomes a fresh empty list;
        # a non-empty list is stored as-is (not copied).
        self.cascade = cascade if cascade else []
        self.lazy = lazy
64
+
65
+
66
class EntityMetadata:
    """
    Mapping information for one entity class.

    The Entity decorator stores an instance on the class as
    ``_entity_metadata`` and registers each column under its attribute name.
    """

    def __init__(self, table_name: str):
        # Database table backing the entity.
        self.table_name = table_name
        # attribute name -> ColumnMetadata
        self.columns: dict[str, ColumnMetadata] = {}
        # attribute name -> RelationshipMetadata
        # NOTE(review): nothing in this module populates this mapping yet.
        self.relationships: dict[str, RelationshipMetadata] = {}
        # Attribute name of the primary-key column; None until one is found.
        self.primary_key: Optional[str] = None
74
+
75
+
76
+ # SQLAlchemy integration
77
+ from sqlalchemy.orm import registry as SARegistry
78
+ from sqlalchemy import Table, Column as SAColumn, Integer, String, Boolean, DateTime, MetaData, ForeignKey
79
+ from sqlalchemy.types import TypeEngine
80
+
81
# Global registry for imperative mapping. Entity() creates a Table for each
# decorated class in ``mapper_registry.metadata`` and maps the class onto it
# with ``mapper_registry.map_imperatively``. That MetaData is presumably what
# callers pass to create_all() — TODO confirm.
mapper_registry = SARegistry()
83
+
84
def _get_sa_type(py_type: Type) -> TypeEngine:
    """
    Map a Python type annotation to a SQLAlchemy column type.

    ``Optional[X]`` (and ``X | None`` on Python 3.10+) is unwrapped to ``X``
    first. Without this, nullable annotations returned by get_type_hints,
    such as ``Optional[int]``, would fall through to String.

    Supported types: int, str, bool, float, datetime and date. Anything else
    falls back to String.

    Returns:
        The SQLAlchemy type class (not an instance); SQLAlchemy's Column
        accepts either form.
    """
    import types
    from datetime import date
    from typing import Union, get_args, get_origin

    from sqlalchemy import Date, Float

    # Unwrap Optional[X] / Union[X, None] / X | None to X.
    origin = get_origin(py_type)
    union_type = getattr(types, "UnionType", None)  # only exists on 3.10+
    if origin is Union or (union_type is not None and origin is union_type):
        members = [arg for arg in get_args(py_type) if arg is not type(None)]
        if len(members) == 1:
            py_type = members[0]

    # Compare with == (as before) rather than a dict lookup, so arbitrary,
    # possibly unhashable, objects passed as Column(type=...) cannot raise.
    for candidate, sa_type in (
        (int, Integer),
        (str, String),
        (bool, Boolean),
        (float, Float),
        (datetime, DateTime),
        (date, Date),
    ):
        if py_type == candidate:
            return sa_type
    return String  # Default fallback
95
+
96
def _default_table_name(class_name: str) -> str:
    """Convert a CamelCase class name to a snake_case table name.

    ``UserProfile`` -> ``user_profile``; ``HTTPRequest`` -> ``http_request``.
    """
    import re
    name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', class_name)
    return re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()


def _build_sa_column(column: "ColumnMetadata") -> "SAColumn":
    """Build a SQLAlchemy Column from a ColumnMetadata whose name/type are resolved."""
    sa_type = _get_sa_type(column.type)
    # Honour Column(length=...): a bare String has no length, which some
    # dialects (e.g. MySQL VARCHAR) reject when emitting DDL.
    if column.length is not None and sa_type is String:
        sa_type = String(column.length)

    kwargs = {
        # A primary key can never hold NULL. ColumnMetadata defaults
        # nullable=True (as used by Id()/GeneratedValue()), so force it off here.
        'nullable': column.nullable and not column.primary_key,
        'unique': column.unique,
        'primary_key': column.primary_key,
    }
    if column.auto_increment:
        kwargs['autoincrement'] = True

    # Python-side defaults are applied by BaseEntity.__init__. No
    # server_default is emitted, which keeps a single source of defaults.
    return SAColumn(column.name, sa_type, **kwargs)


def Entity(table_name: Optional[str] = None):
    """
    Mark a class as a database entity (similar to JPA's @Entity).

    The returned decorator does the following:

    * Collects every annotated attribute whose class-level value is a
      ColumnMetadata. Inherited annotations count, because
      ``typing.get_type_hints`` includes base classes such as BaseEntity.
    * Records those columns in an EntityMetadata stored as
      ``cls._entity_metadata`` and sets ``cls._is_entity = True``.
    * Builds a SQLAlchemy Table in ``mapper_registry.metadata`` and maps the
      class onto it imperatively.

    Args:
        table_name: Explicit table name. Defaults to the class name in
            snake_case (``UserProfile`` -> ``user_profile``).

    Returns:
        A class decorator that returns the same class object.
    """
    def decorator(cls: Type[T]) -> Type[T]:
        import copy
        import logging

        final_table_name = (
            table_name if table_name is not None else _default_table_name(cls.__name__)
        )

        # Entity metadata (also read by BaseEntity for defaults / to_dict).
        metadata = EntityMetadata(final_table_name)
        cls._entity_metadata = metadata  # type: ignore
        cls._is_entity = True  # type: ignore

        sa_columns = []
        if hasattr(cls, '__annotations__'):
            for field_name, field_type in get_type_hints(cls).items():
                declared = getattr(cls, field_name, None)
                if not isinstance(declared, ColumnMetadata):
                    continue

                # Work on a copy. Inherited ColumnMetadata objects (e.g.
                # BaseEntity.id) are shared by every subclass, so filling in
                # name/type in place would leak one entity's resolution into
                # all the others.
                column = copy.copy(declared)
                column.name = column.name or field_name
                column.type = column.type or field_type
                metadata.columns[field_name] = column
                if column.primary_key:
                    metadata.primary_key = field_name

                sa_columns.append(_build_sa_column(column))

                # Remove the placeholder attribute so SQLAlchemy can install its
                # instrumented descriptor. A leftover ColumnMetadata could be
                # taken as a user-defined class attribute. Inherited placeholders
                # live on the base class and are left alone.
                if field_name in cls.__dict__:
                    delattr(cls, field_name)

        try:
            table = Table(final_table_name, mapper_registry.metadata, *sa_columns, extend_existing=True)
            mapper_registry.map_imperatively(cls, table)
        except Exception:
            # Best-effort mapping. This can fail e.g. when the class is
            # already mapped (module reload) or, presumably, when no primary key
            # column was declared. Log instead of silently swallowing so an
            # unmapped entity is visible.
            logging.getLogger(__name__).warning(
                "Could not map entity %s to table %r",
                cls.__name__, final_table_name, exc_info=True,
            )

        return cls

    return decorator
178
+
179
+
180
def Column(
    name: Optional[str] = None,
    type: Optional[Type] = None,
    nullable: bool = True,
    unique: bool = False,
    default: Any = None,
    length: Optional[int] = None
):
    """
    Declare a database column (similar to JPA's @Column).

    Assign the result to an annotated class attribute. The attribute name and
    annotation supply ``name`` and ``type`` when they are omitted.

    Args:
        name: Column name in the database.
        type: Python type of the column.
        nullable: Whether the column can be NULL.
        unique: Whether values must be unique.
        default: Default value or zero-argument callable. It is applied in
            Python by BaseEntity.__init__, not emitted as a server default.
        length: Maximum length for string columns.

    Returns:
        A ColumnMetadata describing the column.

    Example:
        email: str = Column(name="email_address", unique=True, nullable=False)
    """
    column = ColumnMetadata(
        name=name,
        type=type,
        nullable=nullable,
        unique=unique,
        default=default,
        length=length,
    )
    return column
213
+
214
+
215
def Id():
    """
    Declare a primary-key column (similar to JPA's @Id).

    Returns a ColumnMetadata with ``primary_key=True``. Every other option
    keeps its ColumnMetadata default. Assign it to an annotated attribute
    rather than using it as a decorator.

    Example:
        id: int = Id()
    """
    pk = ColumnMetadata(primary_key=True)
    return pk
227
+
228
+
229
def GeneratedValue(strategy: GenerationType = GenerationType.AUTO):
    """
    Declare an auto-generated primary-key column (similar to JPA's @GeneratedValue).

    Args:
        strategy: Requested generation strategy. NOTE: it is currently
            accepted but not used; every strategy produces the same
            metadata (``primary_key=True, auto_increment=True``).

    Returns:
        A ColumnMetadata flagged as an auto-incrementing primary key.

    Example:
        id: int = GeneratedValue(strategy=GenerationType.IDENTITY)
    """
    generated = ColumnMetadata(primary_key=True, auto_increment=True)
    return generated
245
+
246
+
247
def ManyToOne(target_entity: Type, lazy: bool = True):
    """
    Declare a many-to-one relationship (similar to JPA's @ManyToOne).

    Args:
        target_entity: The related entity class.
        lazy: Whether the relationship should be loaded lazily.

    Returns:
        A RelationshipMetadata tagged "many_to_one".

    NOTE(review): the Entity decorator in this module only maps
    ColumnMetadata attributes; relationship metadata is not yet turned into
    a SQLAlchemy relationship.

    Example:
        author: User = ManyToOne(target_entity=User)
    """
    relation = RelationshipMetadata(
        relationship_type="many_to_one",
        target_entity=target_entity,
        lazy=lazy,
    )
    return relation
266
+
267
+
268
def OneToMany(target_entity: Type, mapped_by: str, cascade: Optional[List[str]] = None):
    """
    Declare a one-to-many relationship (similar to JPA's @OneToMany).

    Args:
        target_entity: The related entity class.
        mapped_by: Field on the target entity that owns the relationship.
        cascade: Cascade operation names. None means no cascades.

    Returns:
        A RelationshipMetadata tagged "one_to_many".

    NOTE(review): the Entity decorator in this module only maps
    ColumnMetadata attributes; relationship metadata is not yet turned into
    a SQLAlchemy relationship.

    Example:
        comments: List[Comment] = OneToMany(target_entity=Comment, mapped_by="post")
    """
    relation = RelationshipMetadata(
        relationship_type="one_to_many",
        target_entity=target_entity,
        mapped_by=mapped_by,
        cascade=cascade or [],
    )
    return relation
289
+
290
+
291
def ManyToMany(target_entity: Type, mapped_by: Optional[str] = None):
    """
    Declare a many-to-many relationship (similar to JPA's @ManyToMany).

    Args:
        target_entity: The related entity class.
        mapped_by: Field on the target entity, when this is the inverse side.

    Returns:
        A RelationshipMetadata tagged "many_to_many".

    NOTE(review): the Entity decorator in this module only maps
    ColumnMetadata attributes; relationship metadata is not yet turned into
    a SQLAlchemy relationship.

    Example:
        tags: List[Tag] = ManyToMany(target_entity=Tag)
    """
    relation = RelationshipMetadata(
        relationship_type="many_to_many",
        target_entity=target_entity,
        mapped_by=mapped_by,
    )
    return relation
310
+
311
+
312
class BaseEntity:
    """
    Optional base class for entities, in the style of Spring Data JPA.

    Declares an auto-generated ``id`` and two non-nullable timestamp columns,
    ``created_at`` and ``updated_at``, whose default is ``datetime.now``.

    NOTE(review): nothing in this module refreshes ``updated_at`` when an
    entity is modified.
    """

    id: int = GeneratedValue()
    created_at: datetime = Column(type=datetime, nullable=False, default=datetime.now)
    updated_at: datetime = Column(type=datetime, nullable=False, default=datetime.now)

    def __init__(self, **kwargs):
        """
        Set every keyword argument as an attribute, then fill in column defaults.

        A column qualifies for a default when all of these hold:

        * it is registered in the class's ``_entity_metadata``;
        * it was not passed in ``kwargs``;
        * its current value is None or still the class-level ColumnMetadata
          placeholder.

        For a qualifying column, a callable default is called, any other
        non-None default is assigned as-is, and a nullable column without a
        default is set to None.
        """
        for attr, val in kwargs.items():
            setattr(self, attr, val)

        cls = self.__class__
        if not hasattr(cls, '_entity_metadata'):
            return

        metadata: EntityMetadata = cls._entity_metadata
        for col_name, col_meta in metadata.columns.items():
            if col_name in kwargs:
                continue

            existing = getattr(self, col_name, None)
            # Anything other than None / the placeholder counts as already set.
            if existing is not None and not isinstance(existing, ColumnMetadata):
                continue

            default = col_meta.default
            if default is not None:
                setattr(self, col_name, default() if callable(default) else default)
            elif col_meta.nullable:
                setattr(self, col_name, None)

    def __repr__(self) -> str:
        entity_id = getattr(self, 'id', None)
        return f"<{self.__class__.__name__}(id={entity_id})>"

    def to_dict(self) -> dict:
        """
        Return the entity's column values keyed by attribute name.

        datetime values are rendered with ``isoformat()``. Returns an empty
        dict when the class has no ``_entity_metadata``.
        """
        cls = self.__class__
        if not hasattr(cls, '_entity_metadata'):
            return {}

        snapshot = {}
        for attr in cls._entity_metadata.columns:
            if not hasattr(self, attr):
                continue
            val = getattr(self, attr)
            snapshot[attr] = val.isoformat() if isinstance(val, datetime) else val
        return snapshot
@@ -0,0 +1,256 @@
1
+ """
2
+ ORM Gateway abstraction
3
+
4
+ Provides a unified interface for different ORMs.
5
+ """
6
+
7
+ from typing import TypeVar, Generic, List, Optional, Any, Type
8
+ from abc import ABC, abstractmethod
9
+
10
+
11
# Entity type handled by a gateway (ORMGateway is Generic over it).
T = TypeVar("T")
12
+
13
+
14
class ORMGateway(ABC, Generic[T]):
    """
    ORM-agnostic persistence interface.

    Entity operations are coroutines. The transaction-control methods
    (begin_transaction / commit / rollback) are plain synchronous methods.
    A concrete gateway must override every method declared here.
    """

    # ---- entity operations ----------------------------------------------

    @abstractmethod
    async def save(self, entity: T) -> T:
        """Persist a new or changed *entity* and return it."""

    @abstractmethod
    async def find_by_id(self, entity_class: Type[T], id: Any) -> Optional[T]:
        """Return the *entity_class* instance whose key is *id*, or None."""

    @abstractmethod
    async def find_all(self, entity_class: Type[T]) -> List[T]:
        """Return every stored instance of *entity_class*."""

    @abstractmethod
    async def delete(self, entity: T) -> None:
        """Remove *entity* from storage."""

    @abstractmethod
    async def update(self, entity: T) -> T:
        """Write the changes to *entity* and return the stored instance."""

    @abstractmethod
    async def exists(self, entity_class: Type[T], id: Any) -> bool:
        """Return whether an *entity_class* instance with key *id* exists."""

    @abstractmethod
    async def execute_query(self, sql: str, params: dict, entity_class: Type[T], operation: str) -> Any:
        """
        Run a custom SQL query.

        Args:
            sql: SQL text.
            params: Query parameters.
            entity_class: Entity class used to map result rows.
            operation: Kind of query (find, count, delete, exists).

        Returns:
            A result whose shape depends on *operation*.
        """

    # ---- transaction control ---------------------------------------------

    @abstractmethod
    def begin_transaction(self):
        """Open a transaction."""

    @abstractmethod
    def commit(self):
        """Commit the current transaction."""

    @abstractmethod
    def rollback(self):
        """Roll back the current transaction."""
81
+
82
+
83
class SQLAlchemyGateway(ORMGateway[T]):
    """
    SQLAlchemy implementation of ORMGateway.

    Uses one Session, created lazily from ``session_factory``.

    Writes (save / update / delete and DELETE queries) are committed
    immediately unless a transaction was opened with begin_transaction().
    Nested begin_transaction() calls use SAVEPOINTs.

    NOTE: the async methods call the synchronous Session API directly, so the
    event loop is blocked while each database call runs.
    """

    def __init__(self, session_factory):
        """
        Initialize with a SQLAlchemy session factory.

        Args:
            session_factory: Zero-argument callable returning a Session,
                e.g. a ``sessionmaker`` instance.
        """
        self.session_factory = session_factory
        self._session = None
        # Transactions opened through begin_transaction(). Index 0 is the root;
        # later entries are SAVEPOINTs. Empty means "auto-commit mode".
        self._transaction_stack = []

    @property
    def session(self):
        """Return the gateway's Session, creating it on first access."""
        if self._session is None:
            self._session = self.session_factory()
        return self._session

    def _auto_commit(self):
        """Commit unless the caller opened a transaction via begin_transaction().

        Only our own transaction stack is consulted. SQLAlchemy sessions
        autobegin a transaction on first use (the flush in save() already
        starts one), so ``session.in_transaction()`` is nearly always True
        here and cannot tell an explicit transaction from an implicit one.
        """
        if not self._transaction_stack:
            self.session.commit()

    def _rollback_if_unmanaged(self):
        """Roll back after a failed auto-commit write so the session stays usable.

        Inside an explicit transaction, recovery is left to the caller's
        rollback().
        """
        if not self._transaction_stack:
            self.session.rollback()

    async def save(self, entity: T) -> T:
        """Add *entity*, flush, refresh it from the database, then auto-commit.

        The flush makes the database assign generated values such as the id;
        the refresh reloads the row onto the instance.
        """
        try:
            self.session.add(entity)
            self.session.flush()
            self.session.refresh(entity)
            self._auto_commit()
        except Exception:
            self._rollback_if_unmanaged()
            raise
        return entity

    async def find_by_id(self, entity_class: Type[T], id: Any) -> Optional[T]:
        """Return the entity with primary key *id*, or None (via Session.get)."""
        return self.session.get(entity_class, id)

    async def find_all(self, entity_class: Type[T]) -> List[T]:
        """Return all rows of *entity_class* as a list of entities."""
        from sqlalchemy import select
        stmt = select(entity_class)
        result = self.session.execute(stmt)
        return list(result.scalars().all())

    async def delete(self, entity: T) -> None:
        """Mark *entity* for deletion, then auto-commit."""
        try:
            self.session.delete(entity)
            self._auto_commit()
        except Exception:
            self._rollback_if_unmanaged()
            raise

    async def update(self, entity: T) -> T:
        """Merge *entity* (possibly detached) into the session, flush, auto-commit.

        Returns the session-bound instance produced by ``Session.merge``,
        which may be a different object from *entity*.
        """
        try:
            merged = self.session.merge(entity)
            self.session.flush()
            self._auto_commit()
        except Exception:
            self._rollback_if_unmanaged()
            raise
        return merged

    async def exists(self, entity_class: Type[T], id: Any) -> bool:
        """Return whether an entity with primary key *id* exists."""
        return self.session.get(entity_class, id) is not None

    async def execute_query(self, sql: str, params: dict, entity_class: Type[T], operation: str) -> Any:
        """
        Execute a raw SQL string produced by the query builder.

        Positional ``?`` placeholders are rewritten to named ``:param`` binds,
        in the iteration order of *params*.

        Args:
            sql: SQL text using ``?`` placeholders.
            params: Parameter name -> value, ordered like the placeholders.
            entity_class: Mapped entity class used for FIND results.
            operation: A QueryOperation member, or its string value
                ('find', 'count', 'exists', 'delete').

        Returns:
            * FIND: list of entity_class instances loaded through the ORM.
            * COUNT: the scalar result.
            * EXISTS: whether the scalar result is > 0.
            * DELETE: None; committed unless an explicit transaction is open.
            * Anything else: executed, but neither committed nor mapped;
              returns None.
        """
        from sqlalchemy import select, text
        from starspring.data.query_builder import QueryOperation

        def is_op(member, name):
            # Accept the enum member or its plain string value. A str has no
            # ``.value``, so reading operation.value directly would raise.
            return operation == member or getattr(operation, 'value', operation) == name

        # NOTE(review): assumes params are ordered like the placeholders and
        # that '?' never appears inside a SQL string literal.
        named_sql = sql
        for param_name in params:
            named_sql = named_sql.replace('?', f':{param_name}', 1)
        bind_values = dict(params)

        if is_op(QueryOperation.FIND, 'find'):
            # Let the ORM map the rows. Columns are matched by their database
            # names (which may differ from attribute names via Column(name=...)),
            # and results are persistent, session-tracked instances rather than
            # detached copies that a later save() would try to INSERT again.
            stmt = select(entity_class).from_statement(text(named_sql))
            return list(self.session.execute(stmt, bind_values).scalars().all())

        result = self.session.execute(text(named_sql), bind_values)

        if is_op(QueryOperation.COUNT, 'count'):
            return result.scalar()

        if is_op(QueryOperation.EXISTS, 'exists'):
            # scalar() is None when the query returns no rows.
            return (result.scalar() or 0) > 0

        if is_op(QueryOperation.DELETE, 'delete'):
            self._auto_commit()
            return None

        return None

    def begin_transaction(self):
        """
        Begin a transaction; nested calls open SAVEPOINTs.

        The outermost call becomes the root transaction. If the session has
        already autobegun a transaction (e.g. after a read), that transaction
        is adopted as the root instead of opening a SAVEPOINT inside it.
        Otherwise the matching commit() would only release the savepoint and
        the outer work would never be committed.
        """
        if self._transaction_stack:
            txn = self.session.begin_nested()
        elif self.session.in_transaction():
            txn = self.session.get_transaction()
        else:
            txn = self.session.begin()
        self._transaction_stack.append(txn)

    def commit(self):
        """Commit the innermost transaction opened by begin_transaction().

        For a SAVEPOINT this releases it; for the root it commits. With no
        open transaction, the session is committed directly.
        """
        if self._transaction_stack:
            self._transaction_stack.pop().commit()
        else:
            self.session.commit()

    def rollback(self):
        """Roll back the innermost transaction opened by begin_transaction().

        With no open transaction, the session is rolled back directly.
        """
        if self._transaction_stack:
            self._transaction_stack.pop().rollback()
        else:
            self.session.rollback()

    def close(self):
        """Close the session, if any, and forget all tracked transactions.

        The next access to ``session`` creates a fresh one.
        """
        if self._session:
            self._session.close()
            self._session = None
        self._transaction_stack = []
239
+
240
+
241
# Global ORM gateway instance. Set via set_orm_gateway() and read via
# get_orm_gateway(); None until set_orm_gateway() is called.
_orm_gateway: Optional[ORMGateway] = None
243
+
244
+
245
def get_orm_gateway() -> ORMGateway:
    """
    Return the gateway registered with set_orm_gateway().

    Raises:
        RuntimeError: If no gateway has been registered yet.
    """
    # Read-only access to the module global: no ``global`` statement needed.
    gateway = _orm_gateway
    if gateway is not None:
        return gateway
    raise RuntimeError("ORM gateway not initialized. Call set_orm_gateway() first.")
251
+
252
+
253
def set_orm_gateway(gateway: ORMGateway) -> None:
    """Register *gateway* as the module-level instance returned by get_orm_gateway().

    Any previously registered gateway is replaced.
    """
    global _orm_gateway
    _orm_gateway = gateway