activemodel 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,110 @@
+ import os
+ from typing import Iterable
+
+ import pytest
  from sqlmodel import SQLModel

  from ..logger import logger
  from ..session_manager import get_engine
+ from pytest import Config
+ import typing as t
+
+ T = t.TypeVar("T")
+
+
+ def _normalize_to_list_of_strings(str_or_list: list[str] | str) -> list[str]:
+     if isinstance(str_or_list, list):
+         return str_or_list
+
+     raw_list = str_or_list.split(",")
+     return [entry.strip() for entry in raw_list if entry and entry.strip()]
+
+
+ def _get_pytest_option(
+     config: Config, key: str, *, cast: t.Callable[[t.Any], T] | None = str
+ ) -> T | None:
+     if not config:
+         return None
+
+     try:
+         val = config.getoption(key)
+     except ValueError:
+         val = None
+
+     if val is None:
+         val = config.getini(key)
+
+     if val is not None:
+         if cast:
+             return cast(val)
+
+         return val
+
+     return None
+
+
+ def _normalize_preserve_tables(raw: Iterable[str]) -> list[str]:
+     """Normalize user supplied table list: strip, dedupe (order not preserved).
+
+     Returns a sorted list (case-insensitive sort while preserving original casing
+     for readability in logs).
+     """
+
+     cleaned = {name.strip() for name in raw if name and name.strip()}
+     # deterministic order: casefold sort
+     return sorted(cleaned, key=lambda s: s.casefold())
+
+
+ def _get_excluded_tables(
+     pytest_config: Config | None, preserve_tables: list[str] | None
+ ) -> list[str]:
+     """Resolve list of tables to exclude (i.e. *preserve* / NOT truncate).
+
+     Precedence (lowest -> highest):
+     1. pytest ini option ``activemodel_preserve_tables`` (if available)
+     2. Environment variable ``ACTIVEMODEL_PRESERVE_TABLES`` (comma separated)
+     3. Function argument ``preserve_tables``
+
+     Behavior:
+     * If user supplies nothing via any channel, defaults to ["alembic_version"].
+     * Case-insensitive matching during truncation; returned list is normalized
+       (deduped, sorted) for deterministic logging.
+     * Emits a warning only when the ini option is *explicitly* specified but empty after normalization.
+     """
+
+     # 1. pytest ini option (registered as type="linelist" -> typically list[str])
+     ini_tables = (
+         _get_pytest_option(
+             pytest_config,
+             "activemodel_preserve_tables",
+             cast=_normalize_to_list_of_strings,
+         )
+         or []
+     )
+
+     # 2. environment variable
+     env_var = os.getenv("ACTIVEMODEL_PRESERVE_TABLES", "")
+     env_tables = _normalize_to_list_of_strings(env_var)
+
+     # 3. function argument
+     arg_tables = preserve_tables or []
+
+     # Consider customization only if any non-empty source provided values OR the function arg explicitly passed
+     combined_raw = [*ini_tables, *env_tables, *arg_tables]
+
+     # if no user customization, force alembic_version
+     if not combined_raw:
+         return ["alembic_version"]
+
+     normalized = _normalize_preserve_tables(combined_raw)
+     logger.debug(f"excluded tables for truncation: {normalized}")
+
+     return normalized


- def database_reset_truncate():
+ def database_reset_truncate(
+     preserve_tables: list[str] | None = None, pytest_config: Config | None = None
+ ):
      """
      Transaction is most likely the better way to go, but there are some scenarios where the session override
      logic does not work properly and you need to truncate tables back to their original state.
@@ -28,18 +128,19 @@ def database_reset_truncate():

      logger.info("truncating database")

-     # TODO get additonal tables to preserve from config
-     exception_tables = ["alembic_version"]
+     # Determine excluded (preserved) tables and build case-insensitive lookup set
+     exception_tables = _get_excluded_tables(pytest_config, preserve_tables)
+     exception_lookup = {t.lower() for t in exception_tables}

-     assert (
-         SQLModel.metadata.sorted_tables
-     ), "No model metadata. Ensure model metadata is imported before running truncate_db"
+     assert SQLModel.metadata.sorted_tables, (
+         "No model metadata. Ensure model metadata is imported before running truncate_db"
+     )

      with get_engine().connect() as connection:
          for table in reversed(SQLModel.metadata.sorted_tables):
              transaction = connection.begin()

-             if table.name not in exception_tables:
+             if table.name.lower() not in exception_lookup:
                  logger.debug(f"truncating table={table.name}")
                  connection.execute(table.delete())
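The new `database_reset_truncate` signature can be driven from any of the three channels described in the docstring above: the pytest ini option, the environment variable, or a direct argument. A minimal usage sketch follows; the fixture layout, the `seed_data` table name, and the `activemodel.pytest` import path are assumptions rather than something shown in this diff, and the ini option is expected to be registered by the package's pytest plugin (per the `type="linelist"` comment above).

    import pytest

    from activemodel.pytest import database_reset_truncate  # assumed import path


    @pytest.fixture(autouse=True)
    def truncate_database(pytestconfig):
        # "seed_data" is a hypothetical table to keep alongside alembic_version;
        # matching against table names is case-insensitive
        database_reset_truncate(
            preserve_tables=["alembic_version", "seed_data"],
            pytest_config=pytestconfig,
        )


    # equivalent configuration via the environment:
    #   ACTIVEMODEL_PRESERVE_TABLES="alembic_version, seed_data" pytest
    #
    # or via the pytest ini option, e.g. in pyproject.toml:
    #   [tool.pytest.ini_options]
    #   activemodel_preserve_tables = ["alembic_version", "seed_data"]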
 
@@ -1,11 +1,13 @@
  import sqlmodel as sm
  from sqlmodel.sql.expression import SelectOfScalar

+ from activemodel.types.sqlalchemy_protocol import SQLAlchemyQueryMethods
+
  from .session_manager import get_session
  from .utils import compile_sql


- class QueryWrapper[T: sm.SQLModel]:
+ class QueryWrapper[T: sm.SQLModel](SQLAlchemyQueryMethods[T]):
      """
      Make it easy to run queries off of a model
      """
@@ -46,13 +48,20 @@ class QueryWrapper[T: sm.SQLModel]:
          with get_session() as session:
              return session.scalar(sm.select(sm.func.count()).select_from(self.target))

+     def scalar(self):
+         """
+         >>>
+         """
+         with get_session() as session:
+             return session.scalar(self.target)
+
      def exec(self):
          with get_session() as session:
              return session.exec(self.target)

      def delete(self):
          with get_session() as session:
-             session.delete(self.target)
+             return session.delete(self.target)

      def __getattr__(self, name):
          """
@@ -87,4 +96,5 @@ class QueryWrapper[T: sm.SQLModel]:
          return compile_sql(self.target)

      def __repr__(self) -> str:
+         # TODO we should improve structure of this a bit more, maybe wrap in <> or something?
          return f"{self.__class__.__name__}: Current SQL:\n{self.sql()}"
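For context on how these pieces compose: unknown attribute lookups fall through `__getattr__` to the underlying statement (the methods enumerated by the new protocol), while `count()`, `scalar()`, `exec()`, and `delete()` execute against a session. The sketch below only uses methods visible in this diff; the `User` model and `users_query` are placeholders, since this diff does not show how a `QueryWrapper` is constructed.

    # users_query is assumed to be a QueryWrapper[User] obtained from the application's models
    filtered = users_query.where(User.name == "ada").limit(1)  # forwarded via __getattr__
    print(filtered)                  # __repr__ prints the compiled SQL via sql()
    first_match = filtered.scalar()  # new in this release: first scalar result, or None
    total = users_query.count()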
@@ -13,6 +13,8 @@ from pydantic import BaseModel
  from sqlalchemy import Connection, Engine
  from sqlmodel import Session, create_engine

+ ACTIVEMODEL_LOG_SQL = config("ACTIVEMODEL_LOG_SQL", cast=bool, default=False)
+

  def _serialize_pydantic_model(model: BaseModel | list[BaseModel] | None) -> str | None:
      """
@@ -47,21 +49,29 @@ class SessionManager:
      "singleton instance of SessionManager"

      session_connection: Connection | None
-     "optionally specify a specific session connection to use for all get_session() calls, useful for testing"
+     "optionally specify a specific session connection to use for all get_session() calls, useful for testing and migrations"

      @classmethod
-     def get_instance(cls, database_url: str | None = None) -> "SessionManager":
+     def get_instance(
+         cls,
+         database_url: str | None = None,
+         *,
+         engine_options: dict[str, t.Any] | None = None,
+     ) -> "SessionManager":
          if cls._instance is None:
              assert database_url is not None, (
                  "Database URL required for first initialization"
              )
-             cls._instance = cls(database_url)
+             cls._instance = cls(database_url, engine_options=engine_options)

          return cls._instance

-     def __init__(self, database_url: str):
+     def __init__(
+         self, database_url: str, *, engine_options: dict[str, t.Any] | None = None
+     ):
          self._database_url = database_url
          self._engine = None
+         self._engine_options: dict = engine_options or {}

          self.session_connection = None

@@ -72,16 +82,19 @@ class SessionManager:
              self._database_url,
              # NOTE very important! This enables pydantic models to be serialized for JSONB columns
              json_serializer=_serialize_pydantic_model,
-             # TODO move to a constants area
-             echo=config("ACTIVEMODEL_LOG_SQL", cast=bool, default=False),
+             echo=ACTIVEMODEL_LOG_SQL,
+             echo_pool=ACTIVEMODEL_LOG_SQL,
              # https://docs.sqlalchemy.org/en/20/core/pooling.html#disconnect-handling-pessimistic
              pool_pre_ping=True,
              # some implementations include `future=True` but it's not required anymore
+             **self._engine_options,
          )

          return self._engine

      def get_session(self):
+         "get a new database session, respecting any globally set sessions"
+
          if gsession := _session_context.get():

              @contextlib.contextmanager
@@ -97,33 +110,70 @@ class SessionManager:
          return Session(self.get_engine())


- def init(database_url: str):
+ # TODO would be great one day to type engine_options as the SQLAlchemy EngineOptions
+ def init(database_url: str, *, engine_options: dict[str, t.Any] | None = None):
      "configure activemodel to connect to a specific database"
-     return SessionManager.get_instance(database_url)
+     return SessionManager.get_instance(database_url, engine_options=engine_options)


  def get_engine():
+     "alias to get the database engine without importing SessionManager"
      return SessionManager.get_instance().get_engine()


  def get_session():
+     "alias to get a database session without importing SessionManager"
      return SessionManager.get_instance().get_session()


- # contextvars must be at the top-level of a module! You will not get a warning if you don't do this.
- # ContextVar is implemented in C, so it's very special and is both thread-safe and asyncio safe. This variable gives us
- # a place to persist a session to use globally across the application.
  _session_context = contextvars.ContextVar[Session | None](
      "session_context", default=None
  )
+ """
+ This is a VERY important ContextVar, it sets a global session to be used across all ActiveModel operations by default
+ and ensures get_session() uses this session as well.
+
+ contextvars must be at the top-level of a module! You will not get a warning if you don't do this.
+ ContextVar is implemented in C, so it's very special and is both thread-safe and asyncio safe. This variable gives us
+ a place to persist a session to use globally across the application.
+ """


  @contextlib.contextmanager
- def global_session():
+ def global_session(session: Session | None = None):
+     """
+     Generate a session and share it across all activemodel calls.
+
+     Alternatively, you can pass a session to use globally into the context manager, which is helpful for migrations
+     and testing.
+
+     This may only be called a single time per callstack. There is one exception: if you call this multiple times
+     and pass in the same session reference, it will result in a noop.
+
+     Args:
+         session: Use an existing session instead of creating a new one
+     """
+
+     if session is not None and _session_context.get() is session:
+         yield session
+         return
+
      if _session_context.get() is not None:
-         raise RuntimeError("global session already set")
+         raise RuntimeError("ActiveModel: global session already set")

-     with SessionManager.get_instance().get_session() as s:
+     @contextlib.contextmanager
+     def manage_existing_session():
+         "if an existing session already exists, use it without triggering another __enter__"
+         yield session
+
+     # Use provided session or create a new one
+     session_context = (
+         manage_existing_session()
+         if session is not None
+         else SessionManager.get_instance().get_session()
+     )
+
+     with session_context as s:
          token = _session_context.set(s)

          try:
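A short sketch of how the new `engine_options` pass-through and the extended `global_session()` fit together; the database URL and pool settings are illustrative, and the imports come from `activemodel.session_manager` where these functions are defined in this diff.

    from sqlmodel import Session

    from activemodel.session_manager import get_engine, global_session, init

    # extra keyword arguments are forwarded verbatim to sqlalchemy's create_engine()
    init(
        "postgresql://localhost/app_dev",  # illustrative URL
        engine_options={"pool_size": 5, "max_overflow": 10},
    )

    # create a session and make it the global default for all activemodel calls
    with global_session():
        ...  # get_session() inside this block reuses the shared session

    # or hand in an externally managed session (useful for migrations and tests);
    # re-entering with the same session object is a noop per the docstring above
    with Session(get_engine()) as migration_session:
        with global_session(migration_session):
            ...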
@@ -0,0 +1,10 @@
+ # IMPORTANT: This file is auto-generated. Do not edit directly.
+
+ from typing import Protocol, TypeVar, Any, Generic
+ import sqlmodel as sm
+ from sqlalchemy.sql.base import _NoArg
+ from typing import TYPE_CHECKING
+
+
+ class SQLAlchemyQueryMethods[T: sm.SQLModel](Protocol):
+     pass
@@ -0,0 +1,132 @@
+ # IMPORTANT: This file is auto-generated. Do not edit directly.
+
+ from typing import Protocol, TypeVar, Any, Generic
+ import sqlmodel as sm
+ from sqlalchemy.sql.base import _NoArg
+
+ from ..query_wrapper import QueryWrapper
+
+
+ class SQLAlchemyQueryMethods[T: sm.SQLModel](Protocol):
+     """Protocol defining SQLAlchemy query methods forwarded by QueryWrapper.__getattr__"""
+
+     def add_columns(self, *entities: Any) -> "QueryWrapper[T]": ...
+     def add_cte(self, *ctes: Any, nest_here: Any = False) -> "QueryWrapper[T]": ...
+     def alias(self, name: Any = None, flat: Any = False) -> "QueryWrapper[T]": ...
+     def argument_for(self, argument_name: Any, default: Any) -> "QueryWrapper[T]": ...
+     def as_scalar(
+         self,
+     ) -> "QueryWrapper[T]": ...
+     def column(self, column: Any) -> "QueryWrapper[T]": ...
+     def compare(self, other: Any, **kw: Any) -> "QueryWrapper[T]": ...
+     def compile(
+         self, bind: Any = None, dialect: Any = None, **kw: Any
+     ) -> "QueryWrapper[T]": ...
+     def correlate(self, *fromclauses: Any) -> "QueryWrapper[T]": ...
+     def correlate_except(self, *fromclauses: Any) -> "QueryWrapper[T]": ...
+     def corresponding_column(
+         self, column: Any, require_embedded: Any = False
+     ) -> "QueryWrapper[T]": ...
+     def cte(
+         self, name: Any = None, recursive: Any = False, nesting: Any = False
+     ) -> "QueryWrapper[T]": ...
+     def distinct(self, *expr: Any) -> "QueryWrapper[T]": ...
+     def except_(self, *other: Any) -> "QueryWrapper[T]": ...
+     def except_all(self, *other: Any) -> "QueryWrapper[T]": ...
+     def execution_options(self, **kw: Any) -> "QueryWrapper[T]": ...
+     def exists(
+         self,
+     ) -> "QueryWrapper[T]": ...
+     def fetch(
+         self,
+         count: Any,
+         with_ties: Any = False,
+         percent: Any = False,
+         **dialect_kw: Any,
+     ) -> "QueryWrapper[T]": ...
+     def filter(self, *criteria: Any) -> "QueryWrapper[T]": ...
+     def filter_by(self, **kwargs: Any) -> "QueryWrapper[T]": ...
+     def from_statement(self, statement: Any) -> "QueryWrapper[T]": ...
+     def get_children(self, **kw: Any) -> "QueryWrapper[T]": ...
+     def get_execution_options(
+         self,
+     ) -> "QueryWrapper[T]": ...
+     def get_final_froms(
+         self,
+     ) -> "QueryWrapper[T]": ...
+     def get_label_style(
+         self,
+     ) -> "QueryWrapper[T]": ...
+     def group_by(
+         self, _GenerativeSelect__first: Any = _NoArg.NO_ARG, *clauses: Any
+     ) -> "QueryWrapper[T]": ...
+     def having(self, *having: Any) -> "QueryWrapper[T]": ...
+     def intersect(self, *other: Any) -> "QueryWrapper[T]": ...
+     def intersect_all(self, *other: Any) -> "QueryWrapper[T]": ...
+     def is_derived_from(self, fromclause: Any) -> "QueryWrapper[T]": ...
+     def join(
+         self, target: Any, onclause: Any = None, isouter: Any = False, full: Any = False
+     ) -> "QueryWrapper[T]": ...
+     def join_from(
+         self,
+         from_: Any,
+         target: Any,
+         onclause: Any = None,
+         isouter: Any = False,
+         full: Any = False,
+     ) -> "QueryWrapper[T]": ...
+     def label(self, name: Any) -> "QueryWrapper[T]": ...
+     def lateral(self, name: Any = None) -> "QueryWrapper[T]": ...
+     def limit(self, limit: Any) -> "QueryWrapper[T]": ...
+     def memoized_instancemethod(
+         self,
+     ) -> "QueryWrapper[T]": ...
+     def offset(self, offset: Any) -> "QueryWrapper[T]": ...
+     def options(self, *options: Any) -> "QueryWrapper[T]": ...
+     def order_by(
+         self, _GenerativeSelect__first: Any = _NoArg.NO_ARG, *clauses: Any
+     ) -> QueryWrapper[T]: ...
+     def outerjoin(
+         self, target: Any, onclause: Any = None, full: Any = False
+     ) -> "QueryWrapper[T]": ...
+     def outerjoin_from(
+         self, from_: Any, target: Any, onclause: Any = None, full: Any = False
+     ) -> "QueryWrapper[T]": ...
+     def params(
+         self, _ClauseElement__optionaldict: Any = None, **kwargs: Any
+     ) -> "QueryWrapper[T]": ...
+     def prefix_with(self, *prefixes: Any, dialect: Any = "*") -> "QueryWrapper[T]": ...
+     def reduce_columns(self, only_synonyms: Any = True) -> "QueryWrapper[T]": ...
+     def replace_selectable(self, old: Any, alias: Any) -> "QueryWrapper[T]": ...
+     def scalar_subquery(
+         self,
+     ) -> "QueryWrapper[T]": ...
+     def select(self, *arg: Any, **kw: Any) -> "QueryWrapper[T]": ...
+     def select_from(self, *froms: Any) -> "QueryWrapper[T]": ...
+     def self_group(self, against: Any = None) -> "QueryWrapper[T]": ...
+     def set_label_style(self, style: Any) -> "QueryWrapper[T]": ...
+     def slice(self, start: Any, stop: Any) -> "QueryWrapper[T]": ...
+     def subquery(self, name: Any = None) -> "QueryWrapper[T]": ...
+     def suffix_with(self, *suffixes: Any, dialect: Any = "*") -> "QueryWrapper[T]": ...
+     def union(self, *other: Any) -> "QueryWrapper[T]": ...
+     def union_all(self, *other: Any) -> "QueryWrapper[T]": ...
+     def unique_params(
+         self, _ClauseElement__optionaldict: Any = None, **kwargs: Any
+     ) -> "QueryWrapper[T]": ...
+     def where(self, *whereclause: Any) -> "QueryWrapper[T]": ...
+     def with_for_update(
+         self,
+         nowait: Any = False,
+         read: Any = False,
+         of: Any = None,
+         skip_locked: Any = False,
+         key_share: Any = False,
+     ) -> "QueryWrapper[T]": ...
+     def with_hint(
+         self, selectable: Any, text: Any, dialect_name: Any = "*"
+     ) -> "QueryWrapper[T]": ...
+     def with_only_columns(
+         self, *entities: Any, maintain_column_froms: Any = False, **_Select__kw: Any
+     ) -> "QueryWrapper[T]": ...
+     def with_statement_hint(
+         self, text: Any, dialect_name: Any = "*"
+     ) -> "QueryWrapper[T]": ...
@@ -2,7 +2,6 @@
  Lifted from: https://github.com/akhundMurad/typeid-python/blob/main/examples/sqlalchemy.py
  """

- from typing import Optional
  from uuid import UUID

  from pydantic import (
@@ -19,25 +18,32 @@ from activemodel.errors import TypeIDValidationError
  class TypeIDType(types.TypeDecorator):
      """
      A SQLAlchemy TypeDecorator that allows storing TypeIDs in the database.
-     The prefix will not be persisted, instead the database-native UUID field will be used.
-     At retrieval time a TypeID will be constructed based on the configured prefix and the
-     UUID value from the database.
-
-     Usage:
-         # will result in TypeIDs such as "user_01h45ytscbebyvny4gc8cr8ma2"
-         id = mapped_column(
-             TypeIDType("user"),
-             primary_key=True,
-             default=lambda: TypeID("user")
-         )
+
+     The prefix will not be persisted to the database, instead the database-native UUID field will be used.
+     At retrieval time a TypeID will be constructed (in python) based on the configured prefix and the UUID
+     value from the database.
+
+     For example:
+
+     >>> id = mapped_column(
+     >>>     TypeIDType("user"),
+     >>>     primary_key=True,
+     >>>     default=lambda: TypeID("user")
+     >>> )
+
+     Will result in TypeIDs such as "user_01h45ytscbebyvny4gc8cr8ma2". There's a mixin provided to make it easy
+     to add a `id` pk field to your model with a specific prefix.
      """

+     # TODO are we sure we wouldn't use TypeID here?
      impl = types.Uuid
+     # TODO why the types version?
      # impl = uuid.UUID
+
      cache_ok = True
-     prefix: Optional[str] = None
+     prefix: str

-     def __init__(self, prefix: Optional[str], *args, **kwargs):
+     def __init__(self, prefix: str, *args, **kwargs):
          self.prefix = prefix
          super().__init__(*args, **kwargs)

@@ -70,6 +76,7 @@ class TypeIDType(types.TypeDecorator):
          if isinstance(value, str):
              # no prefix, raw UUID, let's coerce it into a UUID which SQLAlchemy can handle
              # ex: '01942886-7afc-7129-8f57-db09137ed002'
+             # if an invalid uuid is passed, `ValueError('badly formed hexadecimal UUID string')` will be raised
              return UUID(value)

          if isinstance(value, TypeID):
@@ -90,6 +97,8 @@ class TypeIDType(types.TypeDecorator):
          raise ValueError("Unexpected input type")

      def process_result_value(self, value, dialect):
+         "convert a raw UUID, without a prefix, to a TypeID with the correct prefix"
+
          if value is None:
              return None

@@ -123,13 +132,19 @@ class TypeIDType(types.TypeDecorator):
          - https://github.com/alice-biometrics/petisco/blob/b01ef1b84949d156f73919e126ed77aa8e0b48dd/petisco/base/domain/model/uuid.py#L50
          """

+         def convert_from_string(value: str | TypeID) -> TypeID:
+             if isinstance(value, TypeID):
+                 return value
+
+             return TypeID.from_string(value)
+
          from_uuid_schema = core_schema.chain_schema(
              [
                  # TODO not sure how this is different from the UUID schema, should try it out.
                  # core_schema.is_instance_schema(TypeID),
                  # core_schema.uuid_schema(),
                  core_schema.no_info_plain_validator_function(
-                     TypeID.from_string,
+                     convert_from_string,
                      json_schema_input_schema=core_schema.str_schema(),
                  ),
              ]
@@ -151,11 +166,10 @@ class TypeIDType(types.TypeDecorator):
              #     )
              # },
              python_schema=core_schema.union_schema([from_uuid_schema]),
-             serialization=core_schema.plain_serializer_function_ser_schema(
-                 lambda x: str(x)
-             ),
+             serialization=core_schema.plain_serializer_function_ser_schema(str),
          )

+     # TODO I have a feeling that the `serialization` param in the above method solves this for us.
      @classmethod
      def __get_pydantic_json_schema__(
          cls, schema: CoreSchema, handler: GetJsonSchemaHandler
@@ -164,18 +178,18 @@ class TypeIDType(types.TypeDecorator):
          Called when generating the openapi schema. This overrides the `function-plain` type which
          is generated by the `no_info_plain_validator_function`.

-         This logis seems to be a hot part of the codebase, so I'd expect this to break as pydantic
+         This logic seems to be a hot part of the codebase, so I'd expect this to break as pydantic
          fastapi continue to evolve.

          Note that this method can return multiple types. A return value can be as simple as:

-             {"type": "string"}
+         >>> {"type": "string"}

          Or, you could return a more specific JSON schema type:

-             core_schema.uuid_schema()
+         >>> core_schema.uuid_schema()

-         The problem with using something like uuid_schema is the specifi patterns
+         The problem with using something like uuid_schema is the specific patterns

          https://github.com/BeanieODM/beanie/blob/2190cd9d1fc047af477d5e6897cc283799f54064/beanie/odm/fields.py#L153
          """
@@ -0,0 +1,22 @@
+ from typing import Any, Type
+
+ from pydantic import GetCoreSchemaHandler
+ from pydantic_core import CoreSchema, core_schema
+
+ from typeid import TypeID
+
+
+ @classmethod
+ def get_pydantic_core_schema(
+     cls: Type[TypeID], source_type: Any, handler: GetCoreSchemaHandler
+ ) -> CoreSchema:
+     return core_schema.union_schema(
+         [
+             core_schema.str_schema(),
+             core_schema.is_instance_schema(cls),
+         ],
+         serialization=core_schema.plain_serializer_function_ser_schema(str),
+     )
+
+
+ TypeID.__get_pydantic_core_schema__ = get_pydantic_core_schema
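Once this monkeypatch has been imported, `TypeID` can be used directly as a pydantic field type: instances validate via `is_instance_schema` and serialize through `str`. A minimal sketch, assuming the patch module above has already been imported at startup (its module path is not shown in this diff); the `user` prefix and the example output value are illustrative.

    from pydantic import BaseModel
    from typeid import TypeID

    # assumes the monkeypatch above has already been applied via an import


    class UserRef(BaseModel):
        id: TypeID


    ref = UserRef(id=TypeID("user"))
    print(ref.model_dump_json())  # e.g. {"id":"user_01h45ytscbebyvny4gc8cr8ma2"}, serialized via str()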
activemodel/utils.py CHANGED
@@ -1,18 +1,12 @@
- import inspect
- import pkgutil
- import sys
- from types import ModuleType
-
  from sqlalchemy import text
- from sqlmodel import SQLModel
  from sqlmodel.sql.expression import SelectOfScalar

- from .logger import logger
  from .session_manager import get_engine, get_session


- def compile_sql(target: SelectOfScalar):
-     "convert a query into SQL, helpful for debugging"
+ def compile_sql(target: SelectOfScalar) -> str:
+     "convert a query into SQL, helpful for debugging sqlalchemy/sqlmodel queries"
+
      dialect = get_engine().dialect
      # TODO I wonder if we could store the dialect to avoid getting an engine reference
      compiled = target.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
@@ -25,36 +19,6 @@ def raw_sql_exec(raw_query: str):
          session.execute(text(raw_query))


- def find_all_sqlmodels(module: ModuleType):
-     """Import all model classes from module and submodules into current namespace."""
-
-     logger.debug(f"Starting model import from module: {module.__name__}")
-     model_classes = {}
-
-     # Walk through all submodules
-     for loader, module_name, is_pkg in pkgutil.walk_packages(module.__path__):
-         full_name = f"{module.__name__}.{module_name}"
-         logger.debug(f"Importing submodule: {full_name}")
-
-         # Check if module is already imported
-         if full_name in sys.modules:
-             submodule = sys.modules[full_name]
-         else:
-             logger.warning(
-                 f"Module not found in sys.modules, not importing: {full_name}"
-             )
-             continue
-
-         # Get all classes from module
-         for name, obj in inspect.getmembers(submodule):
-             if inspect.isclass(obj) and issubclass(obj, SQLModel) and obj != SQLModel:
-                 logger.debug(f"Found model class: {name}")
-                 model_classes[name] = obj
-
-     logger.debug(f"Completed model import. Found {len(model_classes)} models")
-     return model_classes
-
-
  def hash_function_code(func):
      "get sha of a function to easily assert that it hasn't changed"