sera-2 1.12.3__py3-none-any.whl → 1.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
sera/exports/schema.py ADDED
@@ -0,0 +1,157 @@
1
+ from pathlib import Path
2
+ from typing import Annotated
3
+
4
+ import typer
5
+
6
+ from sera.models import Cardinality, Class, DataProperty, Schema, parse_schema
7
+ from sera.models._datatype import DataType
8
+
9
+
10
def get_prisma_field_type(datatype: DataType) -> str:
    """Translate a Sera datatype into the name of the matching Prisma scalar type.

    The decision is made on the string ``datatype.get_python_type().type``:

    * a plain scalar such as ``"int"`` maps straight to its Prisma name
      (``"Int"``);
    * ``"list[<scalar>]"`` maps to the Prisma scalar-list form
      (``"list[int]"`` -> ``"Int[]"``), one level deep only.

    Raises:
        ValueError: when the Python type has no Prisma counterpart here.
    """
    scalar_names = {
        "str": "String",
        "int": "Int",
        "float": "Float",
        "bool": "Boolean",
        "bytes": "Bytes",
        "dict": "Json",
        "datetime": "DateTime",
    }
    pytype = datatype.get_python_type().type

    if pytype in scalar_names:
        return scalar_names[pytype]

    # `list[X]` becomes `<Prisma X>[]`, but only when X itself is a known scalar.
    list_prefix = "list["
    if pytype.startswith(list_prefix) and pytype.endswith("]"):
        element = pytype[len(list_prefix) : -1]
        if element in scalar_names:
            return scalar_names[element] + "[]"

    raise ValueError(f"Unsupported data type for Prisma: {pytype}")
42
+
43
+
44
def _prisma_field_line(name: str, type_: str, attrs: str = "") -> str:
    """Format one column-aligned Prisma field line without trailing whitespace."""
    return f" {name.ljust(30)} {type_.ljust(10)} {attrs}".rstrip()


def to_prisma_model(schema: Schema, cls: Class, lines: list[str]):
    """Append the Prisma ``model`` block for ``cls`` to ``lines`` (mutated in place).

    The block contains, in order:

    * a synthetic ``id Int @id @default(autoincrement())`` key when ``cls.db``
      is None (the class has no database mapping);
    * one scalar field per ``DataProperty`` (``?`` suffix when optional,
      ``@id`` when the property is the primary key);
    * for each object property, either a list relation (many-to-many) or a
      relation field ``<name>_`` plus its scalar foreign key ``<name>``;
    * one back-relation field per upstream class that references ``cls``.
    """
    lines.append(f"model {cls.name} {{")

    if cls.db is None:
        # No database mapping means no declared primary key: synthesize one.
        lines.append(_prisma_field_line("id", "Int", "@id @default(autoincrement())"))

    for prop in cls.properties.values():
        if isinstance(prop, DataProperty):
            proptype = get_prisma_field_type(prop.datatype)
            if prop.is_optional:
                proptype = f"{proptype}?"
            is_pk = prop.db is not None and prop.db.is_primary_key
            lines.append(_prisma_field_line(prop.name, proptype, "@id" if is_pk else ""))
            continue

        if prop.cardinality == Cardinality.MANY_TO_MANY:
            # Many-to-many: only a list-typed relation field is emitted here,
            # no scalar foreign key.
            lines.append(_prisma_field_line(prop.name, prop.target.name + "[]"))
            continue

        # NOTE(review): assumes the target's key is an `Int` field named `id`
        # (true for the synthetic key above); confirm for classes with a db mapping.
        lines.append(
            _prisma_field_line(
                prop.name + "_",
                prop.target.name,
                f"@relation(fields: [{prop.name}], references: [id])",
            )
        )
        # For MANY_TO_ONE, many rows of `cls` may reference the same target and
        # the target model's back-relation (see the upstream loop) is a list,
        # so the foreign key must NOT be unique. Other cardinalities get a
        # single optional back-relation and keep `@unique`.
        fk_attrs = "" if prop.cardinality == Cardinality.MANY_TO_ONE else "@unique"
        lines.append(_prisma_field_line(prop.name, "Int", fk_attrs))

    lines.append("")
    for upstream_cls, reverse_upstream_prop in schema.get_upstream_classes(cls):
        # An upstream class holding a MANY_TO_ONE / MANY_TO_MANY property to
        # `cls` yields a list back-relation; anything else a single optional one.
        if reverse_upstream_prop.cardinality in (
            Cardinality.MANY_TO_ONE,
            Cardinality.MANY_TO_MANY,
        ):
            proptype = f"{upstream_cls.name}[]"
        else:
            proptype = upstream_cls.name + "?"
        # NOTE(review): the field name depends only on the upstream class, so two
        # properties of one upstream class pointing at `cls` would collide
        # (Prisma would then also need named @relation attributes).
        lines.append(_prisma_field_line(upstream_cls.name.lower(), proptype))

    lines.append("}\n")
91
+
92
+
93
def export_prisma_schema(schema: Schema, outfile: Path, provider: str = "postgresql"):
    """Write a complete Prisma schema for ``schema`` to ``outfile``.

    The output contains, in order:

    * a ``datasource`` block whose URL is read from ``DATABASE_URL``;
    * a ``prisma-client-py`` generator block;
    * one ``enum`` block per entry in ``schema.enums``;
    * one ``model`` block per class, in ``schema.topological_sort()`` order.

    Args:
        schema: parsed Sera schema to export.
        outfile: destination file; overwritten and written as UTF-8.
        provider: Prisma datasource provider name (default ``"postgresql"``).
    """
    lines: list[str] = []

    # Datasource
    lines.append("datasource db {")
    lines.append(f' provider = "{provider}"')
    lines.append(' url = env("DATABASE_URL")')
    lines.append("}\n")

    # Generator
    lines.append("generator client {")
    lines.append(' provider = "prisma-client-py"')
    lines.append(" recursive_type_depth = 5")
    lines.append("}\n")

    # Enums: each item produced by iterating `enum_def.values` is written as one
    # member line. NOTE(review): confirm `values` yields the member names.
    if schema.enums:
        for enum_name, enum_def in schema.enums.items():
            lines.append(f"enum {enum_name} {{")
            for val_str in enum_def.values:
                lines.append(f" {val_str}")
            # A real newline (not an escaped backslash + "n", which Prisma
            # cannot parse) so a blank line separates this block from the next.
            lines.append("}\n")

    # Models, in the order returned by `schema.topological_sort()`.
    for cls in schema.topological_sort():
        to_prisma_model(schema, cls, lines)

    with outfile.open("w", encoding="utf-8") as f:
        f.write("\n".join(lines))
126
+
127
+
128
app = typer.Typer(pretty_exceptions_short=True, pretty_exceptions_enable=False)


@app.command()
def cli(
    schema_files: Annotated[
        list[Path],
        typer.Option(
            "-s", help="YAML schema files. Multiple files are merged automatically"
        ),
    ],
    outfile: Annotated[
        Path,
        typer.Option(
            "-o", "--output", help="Output file for the Prisma schema", writable=True
        ),
    ],
):
    # Parse and merge all given schema files under the "sera" name, then
    # render the merged schema as a Prisma schema file at `outfile`.
    # (Comments, not a docstring: Typer would turn a docstring into --help text.)
    merged_schema = parse_schema("sera", schema_files)
    export_prisma_schema(merged_schema, outfile)


# Launch the CLI only when the module is executed, not when it is imported.
if __name__ == "__main__":
    app()
sera/exports/test.py ADDED
@@ -0,0 +1,70 @@
1
"""Scratch example: print the CREATE TABLE DDL of a small SQLAlchemy Core schema.

Run the module directly to see the output. Importing it (e.g. during pytest
collection, since the file name matches ``test*.py``) has no side effects.
"""

from sqlalchemy import (
    Column,
    ForeignKey,
    Integer,
    MetaData,
    String,
    Table,
    create_engine,
)
from sqlalchemy.dialects import postgresql, sqlite
from sqlalchemy.schema import CreateTable

# No engine or live connection is needed to render DDL: only a dialect object.
# An engine (e.g. create_engine("sqlite:///:memory:")) would work as well,
# via its `.dialect` attribute.

metadata_obj = MetaData()

user_table = Table(
    "users",
    metadata_obj,
    Column("id", Integer, primary_key=True),
    Column("name", String(50)),
    Column("email", String(100), unique=True),
)

address_table = Table(
    "addresses",
    metadata_obj,
    Column("id", Integer, primary_key=True),
    Column("user_id", Integer, ForeignKey("users.id"), nullable=False),
    Column("street_name", String(100)),
    Column("city", String(50)),
)


def generate_schema_ddl(metadata, engine_dialect):
    """Print one ``CREATE TABLE`` statement per table in ``metadata``.

    Tables are visited in ``metadata.sorted_tables`` order (foreign-key
    dependency order), so referenced tables are printed before the tables
    that point at them. Each statement is compiled for ``engine_dialect``.
    """
    for table in metadata.sorted_tables:
        create_table_ddl = CreateTable(table).compile(dialect=engine_dialect)
        print(str(create_table_ddl).strip() + ";\n")


if __name__ == "__main__":
    print("--- Generating DDL for PostgreSQL ---")
    generate_schema_ddl(metadata_obj, postgresql.dialect())

    # Any other dialect works the same way, e.g. SQLite:
    # print("\n--- Generating DDL for SQLite ---")
    # generate_schema_ddl(metadata_obj, sqlite.dialect())
sera/libs/base_orm.py CHANGED
@@ -4,10 +4,12 @@ from typing import Optional
4
4
 
5
5
  import orjson
6
6
  from msgspec.json import decode, encode
7
- from sera.typing import UNSET
8
7
  from sqlalchemy import LargeBinary, TypeDecorator
9
8
  from sqlalchemy import create_engine as sqlalchemy_create_engine
10
9
  from sqlalchemy import update
10
+ from sqlalchemy.ext.asyncio import create_async_engine as sqlalchemy_create_async_engine
11
+
12
+ from sera.typing import UNSET
11
13
 
12
14
 
13
15
  class BaseORM:
@@ -104,11 +106,27 @@ class DictDataclassType(TypeDecorator):
104
106
  def create_engine(
105
107
  dbconn: str,
106
108
  connect_args: Optional[dict] = None,
107
- debug: bool = False,
109
+ echo: bool = False,
110
+ ):
111
+ if dbconn.startswith("sqlite"):
112
+ connect_args = {"check_same_thread": False}
113
+ else:
114
+ connect_args = {}
115
+ engine = sqlalchemy_create_engine(dbconn, connect_args=connect_args, echo=echo)
116
+ return engine
117
+
118
+
119
+ def create_async_engine(
120
+ dbconn: str,
121
+ connect_args: Optional[dict] = None,
122
+ echo: bool = False,
108
123
  ):
109
124
  if dbconn.startswith("sqlite"):
110
125
  connect_args = {"check_same_thread": False}
111
126
  else:
112
127
  connect_args = {}
113
- engine = sqlalchemy_create_engine(dbconn, connect_args=connect_args, echo=debug)
128
+
129
+ engine = sqlalchemy_create_async_engine(
130
+ dbconn, connect_args=connect_args, echo=echo
131
+ )
114
132
  return engine
sera/libs/base_service.py CHANGED
@@ -4,8 +4,11 @@ from enum import Enum
4
4
  from math import dist
5
5
  from typing import Annotated, Any, Generic, NamedTuple, Optional, Sequence, TypeVar
6
6
 
7
- from sqlalchemy import Result, Select, exists, func, select
8
- from sqlalchemy.orm import Session, load_only
7
+ from litestar.exceptions import HTTPException
8
+ from sqlalchemy import Result, Select, delete, exists, func, select
9
+ from sqlalchemy.exc import IntegrityError
10
+ from sqlalchemy.ext.asyncio import AsyncSession
11
+ from sqlalchemy.orm import load_only
9
12
 
10
13
  from sera.libs.base_orm import BaseORM
11
14
  from sera.misc import assert_not_null
@@ -41,7 +44,7 @@ class QueryResult(NamedTuple, Generic[R]):
41
44
  total: int
42
45
 
43
46
 
44
- class BaseService(Generic[ID, R]):
47
+ class BaseAsyncService(Generic[ID, R]):
45
48
 
46
49
  instance = None
47
50
 
@@ -51,7 +54,7 @@ class BaseService(Generic[ID, R]):
51
54
  self.id_prop = assert_not_null(cls.get_id_property())
52
55
 
53
56
  self._cls_id_prop = getattr(self.orm_cls, self.id_prop.name)
54
- self.is_id_auto_increment = self.id_prop.db.is_auto_increment
57
+ self.is_id_auto_increment = assert_not_null(self.id_prop.db).is_auto_increment
55
58
 
56
59
  @classmethod
57
60
  def get_instance(cls):
@@ -59,10 +62,10 @@ class BaseService(Generic[ID, R]):
59
62
  if cls.instance is None:
60
63
  # assume that the subclass overrides the __init__ method
61
64
  # so that we don't need to pass the class and orm_cls
62
- cls.instance = cls()
65
+ cls.instance = cls() # type: ignore[call-arg]
63
66
  return cls.instance
64
67
 
65
- def get(
68
+ async def get(
66
69
  self,
67
70
  query: Query,
68
71
  limit: int,
@@ -71,7 +74,7 @@ class BaseService(Generic[ID, R]):
71
74
  sorted_by: list[str],
72
75
  group_by: list[str],
73
76
  fields: list[str],
74
- session: Session,
77
+ session: AsyncSession,
75
78
  ) -> QueryResult[R]:
76
79
  """Retrieving records matched a query.
77
80
 
@@ -103,35 +106,37 @@ class BaseService(Generic[ID, R]):
103
106
 
104
107
  cq = select(func.count()).select_from(q.subquery())
105
108
  rq = q.limit(limit).offset(offset)
106
- records = self._process_result(session.execute(rq)).scalars().all()
107
- total = session.execute(cq).scalar_one()
109
+ records = self._process_result(await session.execute(rq)).scalars().all()
110
+ total = (await session.execute(cq)).scalar_one()
108
111
  return QueryResult(records, total)
109
112
 
110
- def get_by_id(self, id: ID, session: Session) -> Optional[R]:
113
+ async def get_by_id(self, id: ID, session: AsyncSession) -> Optional[R]:
111
114
  """Retrieving a record by ID."""
112
115
  q = self._select().where(self._cls_id_prop == id)
113
- result = self._process_result(session.execute(q)).scalar_one_or_none()
116
+ result = self._process_result(await session.execute(q)).scalar_one_or_none()
114
117
  return result
115
118
 
116
- def has_id(self, id: ID, session: Session) -> bool:
119
+ async def has_id(self, id: ID, session: AsyncSession) -> bool:
117
120
  """Check whether we have a record with the given ID."""
118
- q = exists().where(self._cls_id_prop == id)
119
- result = session.query(q).scalar()
121
+ q = exists().where(self._cls_id_prop == id).select()
122
+ result = (await session.execute(q)).scalar()
120
123
  return bool(result)
121
124
 
122
- def create(self, record: R, session: Session) -> R:
125
+ async def create(self, record: R, session: AsyncSession) -> R:
123
126
  """Create a new record."""
124
127
  if self.is_id_auto_increment:
125
128
  setattr(record, self.id_prop.name, None)
126
129
 
127
- session.add(record)
128
- session.commit()
130
+ try:
131
+ session.add(record)
132
+ await session.flush()
133
+ except IntegrityError:
134
+ raise HTTPException(detail="Invalid request", status_code=409)
129
135
  return record
130
136
 
131
- def update(self, record: R, session: Session) -> R:
137
+ async def update(self, record: R, session: AsyncSession) -> R:
132
138
  """Update an existing record."""
133
- session.execute(record.get_update_query())
134
- session.commit()
139
+ await session.execute(record.get_update_query())
135
140
  return record
136
141
 
137
142
  def _select(self) -> Select:
@@ -141,3 +146,7 @@ class BaseService(Generic[ID, R]):
141
146
  def _process_result(self, result: SqlResult) -> SqlResult:
142
147
  """Process the result of a query."""
143
148
  return result
149
+
150
+ async def truncate(self, session: AsyncSession) -> None:
151
+ """Truncate the table."""
152
+ await session.execute(delete(self.orm_cls))
@@ -1,16 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from datetime import datetime, timezone
4
- from typing import Callable, Generator, Generic, Optional, Sequence, Type
4
+ from typing import Awaitable, Callable, Generic, Sequence
5
5
 
6
6
  from litestar import Request
7
7
  from litestar.connection import ASGIConnection
8
8
  from litestar.exceptions import NotAuthorizedException
9
9
  from litestar.middleware import AbstractAuthenticationMiddleware, AuthenticationResult
10
10
  from litestar.types import ASGIApp, Method, Scopes
11
- from litestar.types.composite_types import Dependencies
12
- from sqlalchemy import select
13
- from sqlalchemy.orm import Session
14
11
 
15
12
  from sera.typing import T
16
13
 
@@ -26,7 +23,7 @@ class AuthMiddleware(AbstractAuthenticationMiddleware, Generic[T]):
26
23
  def __init__(
27
24
  self,
28
25
  app: ASGIApp,
29
- user_handler: Callable[[str], T],
26
+ user_handler: Callable[[str], Awaitable[T]],
30
27
  exclude: str | list[str] | None = None,
31
28
  exclude_from_auth_key: str = "exclude_from_auth",
32
29
  exclude_http_methods: Sequence[Method] | None = None,
@@ -59,7 +56,7 @@ class AuthMiddleware(AbstractAuthenticationMiddleware, Generic[T]):
59
56
  detail="Credentials expired",
60
57
  )
61
58
 
62
- user = self.user_handler(userid)
59
+ user = await self.user_handler(userid)
63
60
  if user is None:
64
61
  raise NotAuthorizedException(
65
62
  detail="User not found",
sera/make/__main__.py CHANGED
@@ -49,4 +49,5 @@ def cli(
49
49
  make_app(app_dir, schema_files, api_collections, language, referenced_schema)
50
50
 
51
51
 
52
- app()
52
# Start the Typer CLI only when this module is executed (e.g. `python -m sera.make`);
# merely importing it must not trigger command-line parsing.
if __name__ == "__main__":
    app()