reflex 0.8.14.post1 → 0.8.15a1 (wheels: reflex-0.8.14.post1-py3-none-any.whl → reflex-0.8.15a1-py3-none-any.whl)

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their public registries.

Potentially problematic release: this version of reflex has been flagged as possibly problematic. See the registry's advisory page for details.

Files changed (49):
  1. reflex/__init__.py +12 -7
  2. reflex/__init__.pyi +11 -3
  3. reflex/app.py +5 -2
  4. reflex/base.py +58 -33
  5. reflex/components/datadisplay/dataeditor.py +17 -2
  6. reflex/components/datadisplay/dataeditor.pyi +6 -2
  7. reflex/components/field.py +3 -1
  8. reflex/components/lucide/icon.py +2 -1
  9. reflex/components/lucide/icon.pyi +2 -1
  10. reflex/components/markdown/markdown.py +101 -27
  11. reflex/components/sonner/toast.py +3 -2
  12. reflex/components/sonner/toast.pyi +3 -2
  13. reflex/constants/base.py +5 -0
  14. reflex/constants/installer.py +3 -3
  15. reflex/environment.py +9 -1
  16. reflex/event.py +3 -0
  17. reflex/istate/manager/__init__.py +120 -0
  18. reflex/istate/manager/disk.py +210 -0
  19. reflex/istate/manager/memory.py +76 -0
  20. reflex/istate/{manager.py → manager/redis.py} +5 -372
  21. reflex/istate/proxy.py +35 -24
  22. reflex/model.py +534 -511
  23. reflex/plugins/tailwind_v4.py +2 -2
  24. reflex/reflex.py +16 -10
  25. reflex/state.py +35 -34
  26. reflex/testing.py +12 -14
  27. reflex/utils/build.py +11 -1
  28. reflex/utils/compat.py +51 -48
  29. reflex/utils/misc.py +2 -1
  30. reflex/utils/monitoring.py +1 -2
  31. reflex/utils/prerequisites.py +19 -4
  32. reflex/utils/processes.py +3 -1
  33. reflex/utils/redir.py +21 -37
  34. reflex/utils/serializers.py +21 -20
  35. reflex/utils/telemetry.py +0 -2
  36. reflex/utils/templates.py +4 -4
  37. reflex/utils/types.py +82 -90
  38. reflex/vars/base.py +108 -41
  39. reflex/vars/color.py +28 -8
  40. reflex/vars/datetime.py +6 -2
  41. reflex/vars/dep_tracking.py +2 -2
  42. reflex/vars/number.py +26 -0
  43. reflex/vars/object.py +51 -7
  44. reflex/vars/sequence.py +32 -1
  45. {reflex-0.8.14.post1.dist-info → reflex-0.8.15a1.dist-info}/METADATA +8 -3
  46. {reflex-0.8.14.post1.dist-info → reflex-0.8.15a1.dist-info}/RECORD +49 -46
  47. {reflex-0.8.14.post1.dist-info → reflex-0.8.15a1.dist-info}/WHEEL +0 -0
  48. {reflex-0.8.14.post1.dist-info → reflex-0.8.15a1.dist-info}/entry_points.txt +0 -0
  49. {reflex-0.8.14.post1.dist-info → reflex-0.8.15a1.dist-info}/licenses/LICENSE +0 -0
reflex/model.py CHANGED
@@ -5,69 +5,22 @@ from __future__ import annotations
5
5
  import re
6
6
  from collections import defaultdict
7
7
  from contextlib import suppress
8
- from typing import Any, ClassVar
9
-
10
- import alembic.autogenerate
11
- import alembic.command
12
- import alembic.config
13
- import alembic.operations.ops
14
- import alembic.runtime.environment
15
- import alembic.script
16
- import alembic.util
17
- import sqlalchemy
18
- import sqlalchemy.exc
19
- import sqlalchemy.ext.asyncio
20
- import sqlalchemy.orm
21
- from alembic.runtime.migration import MigrationContext
22
- from alembic.script.base import Script
23
-
24
- from reflex.base import Base
8
+ from importlib.util import find_spec
9
+ from typing import TYPE_CHECKING, Any, ClassVar
10
+
25
11
  from reflex.config import get_config
26
12
  from reflex.environment import environment
27
13
  from reflex.utils import console
28
- from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key
29
-
30
- _ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
31
- _ASYNC_ENGINE: dict[str, sqlalchemy.ext.asyncio.AsyncEngine] = {}
32
- _AsyncSessionLocal: dict[str | None, sqlalchemy.ext.asyncio.async_sessionmaker] = {}
33
-
34
- # Import AsyncSession _after_ reflex.utils.compat
35
- from sqlmodel.ext.asyncio.session import AsyncSession # noqa: E402
36
-
37
-
38
- def format_revision(
39
- rev: Script,
40
- current_rev: str | None,
41
- current_reached_ref: list[bool],
42
- ) -> str:
43
- """Format a single revision for display.
44
-
45
- Args:
46
- rev: The alembic script object
47
- current_rev: The currently applied revision ID
48
- current_reached_ref: Mutable reference to track if we've reached current revision
49
-
50
- Returns:
51
- Formatted string for display
52
- """
53
- current = rev.revision
54
- message = rev.doc
55
-
56
- # Determine if this migration is applied
57
- if current_rev is None:
58
- is_applied = False
59
- elif current == current_rev:
60
- is_applied = True
61
- current_reached_ref[0] = True
62
- else:
63
- is_applied = not current_reached_ref[0]
14
+ from reflex.utils.compat import sqlmodel_field_has_primary_key
15
+ from reflex.utils.serializers import serializer
64
16
 
65
- # Show checkmark or X with colors
66
- status_icon = "[green]✓[/green]" if is_applied else "[red]✗[/red]"
67
- head_marker = " (head)" if rev.is_head else ""
17
+ if TYPE_CHECKING:
18
+ import sqlalchemy
19
+ import sqlmodel
68
20
 
69
- # Format output with message
70
- return f" [{status_icon}] {current}{head_marker}, {message}"
21
+ SQLModelOrSqlAlchemy = (
22
+ type[sqlmodel.SQLModel] | type[sqlalchemy.orm.DeclarativeBase]
23
+ )
71
24
 
72
25
 
73
26
  def _safe_db_url_for_logging(url: str) -> str:
@@ -82,536 +35,606 @@ def _safe_db_url_for_logging(url: str) -> str:
82
35
  return re.sub(r"://[^@]+@", "://<username>:<password>@", url)
83
36
 
84
37
 
85
- def get_engine_args(url: str | None = None) -> dict[str, Any]:
86
- """Get the database engine arguments.
87
-
88
- Args:
89
- url: The database url.
90
-
91
- Returns:
92
- The database engine arguments as a dict.
93
- """
94
- kwargs: dict[str, Any] = {
95
- # Print the SQL queries if the log level is INFO or lower.
96
- "echo": environment.SQLALCHEMY_ECHO.get(),
97
- # Check connections before returning them.
98
- "pool_pre_ping": environment.SQLALCHEMY_POOL_PRE_PING.get(),
99
- "pool_size": environment.SQLALCHEMY_POOL_SIZE.get(),
100
- "max_overflow": environment.SQLALCHEMY_MAX_OVERFLOW.get(),
101
- "pool_recycle": environment.SQLALCHEMY_POOL_RECYCLE.get(),
102
- "pool_timeout": environment.SQLALCHEMY_POOL_TIMEOUT.get(),
103
- }
104
- conf = get_config()
105
- url = url or conf.db_url
106
- if url is not None and url.startswith("sqlite"):
107
- # Needed for the admin dash on sqlite.
108
- kwargs["connect_args"] = {"check_same_thread": False}
109
- return kwargs
110
-
111
-
112
- def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
113
- """Get the database engine.
114
-
115
- Args:
116
- url: the DB url to use.
117
-
118
- Returns:
119
- The database engine.
120
-
121
- Raises:
122
- ValueError: If the database url is None.
123
- """
124
- conf = get_config()
125
- url = url or conf.db_url
126
- if url is None:
127
- msg = "No database url configured"
128
- raise ValueError(msg)
129
-
130
- global _ENGINE
131
- if url in _ENGINE:
132
- return _ENGINE[url]
133
-
134
- if not environment.ALEMBIC_CONFIG.get().exists():
135
- console.warn(
136
- "Database is not initialized, run [bold]reflex db init[/bold] first."
137
- )
138
- _ENGINE[url] = sqlmodel.create_engine(
139
- url,
140
- **get_engine_args(url),
38
+ def _print_db_not_available(*args, **kwargs):
39
+ msg = (
40
+ "Database is not available. Please install the required packages: "
41
+ "`pip install reflex[db]`."
141
42
  )
142
- return _ENGINE[url]
143
-
43
+ raise ImportError(msg)
144
44
 
145
- def get_async_engine(url: str | None) -> sqlalchemy.ext.asyncio.AsyncEngine:
146
- """Get the async database engine.
147
45
 
148
- Args:
149
- url: The database url.
150
-
151
- Returns:
152
- The async database engine.
153
-
154
- Raises:
155
- ValueError: If the async database url is None.
156
- """
157
- if url is None:
158
- conf = get_config()
159
- url = conf.async_db_url
160
- if url is not None and conf.db_url is not None:
161
- async_db_url_tail = url.partition("://")[2]
162
- db_url_tail = conf.db_url.partition("://")[2]
163
- if async_db_url_tail != db_url_tail:
164
- console.warn(
165
- f"async_db_url `{_safe_db_url_for_logging(url)}` "
166
- "should reference the same database as "
167
- f"db_url `{_safe_db_url_for_logging(conf.db_url)}`."
168
- )
169
- if url is None:
170
- msg = "No async database url configured"
171
- raise ValueError(msg)
172
-
173
- global _ASYNC_ENGINE
174
- if url in _ASYNC_ENGINE:
175
- return _ASYNC_ENGINE[url]
46
+ class _ClassThatErrorsOnInit:
47
+ def __init__(self, *args, **kwargs):
48
+ _print_db_not_available(*args, **kwargs)
176
49
 
177
- if not environment.ALEMBIC_CONFIG.get().exists():
178
- console.warn(
179
- "Database is not initialized, run [bold]reflex db init[/bold] first."
180
- )
181
- _ASYNC_ENGINE[url] = sqlalchemy.ext.asyncio.create_async_engine(
182
- url,
183
- **get_engine_args(url),
184
- )
185
- return _ASYNC_ENGINE[url]
186
50
 
51
+ if find_spec("sqlalchemy"):
52
+ import sqlalchemy
53
+ import sqlalchemy.exc
54
+ import sqlalchemy.ext.asyncio
55
+ import sqlalchemy.orm
187
56
 
188
- async def get_db_status() -> dict[str, bool]:
189
- """Checks the status of the database connection.
57
+ _ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
58
+ _ASYNC_ENGINE: dict[str, sqlalchemy.ext.asyncio.AsyncEngine] = {}
190
59
 
191
- Attempts to connect to the database and execute a simple query to verify connectivity.
60
+ def get_engine_args(url: str | None = None) -> dict[str, Any]:
61
+ """Get the database engine arguments.
192
62
 
193
- Returns:
194
- The status of the database connection.
195
- """
196
- status = True
197
- try:
198
- engine = get_engine()
199
- with engine.connect() as connection:
200
- connection.execute(sqlalchemy.text("SELECT 1"))
201
- except sqlalchemy.exc.OperationalError:
202
- status = False
63
+ Args:
64
+ url: The database url.
203
65
 
204
- return {"db": status}
66
+ Returns:
67
+ The database engine arguments as a dict.
68
+ """
69
+ kwargs: dict[str, Any] = {
70
+ # Print the SQL queries if the log level is INFO or lower.
71
+ "echo": environment.SQLALCHEMY_ECHO.get(),
72
+ # Check connections before returning them.
73
+ "pool_pre_ping": environment.SQLALCHEMY_POOL_PRE_PING.get(),
74
+ "pool_size": environment.SQLALCHEMY_POOL_SIZE.get(),
75
+ "max_overflow": environment.SQLALCHEMY_MAX_OVERFLOW.get(),
76
+ "pool_recycle": environment.SQLALCHEMY_POOL_RECYCLE.get(),
77
+ "pool_timeout": environment.SQLALCHEMY_POOL_TIMEOUT.get(),
78
+ }
79
+ conf = get_config()
80
+ url = url or conf.db_url
81
+ if url is not None and url.startswith("sqlite"):
82
+ # Needed for the admin dash on sqlite.
83
+ kwargs["connect_args"] = {"check_same_thread": False}
84
+ return kwargs
205
85
 
86
+ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
87
+ """Get the database engine.
206
88
 
207
- SQLModelOrSqlAlchemy = type[sqlmodel.SQLModel] | type[sqlalchemy.orm.DeclarativeBase]
89
+ Args:
90
+ url: the DB url to use.
208
91
 
92
+ Returns:
93
+ The database engine.
209
94
 
210
- class ModelRegistry:
211
- """Registry for all models."""
95
+ Raises:
96
+ ValueError: If the database url is None.
97
+ """
98
+ conf = get_config()
99
+ url = url or conf.db_url
100
+ if url is None:
101
+ msg = "No database url configured"
102
+ raise ValueError(msg)
212
103
 
213
- models: ClassVar[set[SQLModelOrSqlAlchemy]] = set()
104
+ global _ENGINE
105
+ if url in _ENGINE:
106
+ return _ENGINE[url]
214
107
 
215
- # Cache the metadata to avoid re-creating it.
216
- _metadata: ClassVar[sqlalchemy.MetaData | None] = None
108
+ if not environment.ALEMBIC_CONFIG.get().exists():
109
+ console.warn(
110
+ "Database is not initialized, run [bold]reflex db init[/bold] first."
111
+ )
112
+ _ENGINE[url] = sqlalchemy.engine.create_engine(
113
+ url,
114
+ **get_engine_args(url),
115
+ )
116
+ return _ENGINE[url]
217
117
 
218
- @classmethod
219
- def register(cls, model: SQLModelOrSqlAlchemy):
220
- """Register a model. Can be used directly or as a decorator.
118
+ def get_async_engine(url: str | None) -> sqlalchemy.ext.asyncio.AsyncEngine:
119
+ """Get the async database engine.
221
120
 
222
121
  Args:
223
- model: The model to register.
122
+ url: The database url.
224
123
 
225
124
  Returns:
226
- The model passed in as an argument (Allows decorator usage)
227
- """
228
- cls.models.add(model)
229
- return model
230
-
231
- @classmethod
232
- def get_models(cls, include_empty: bool = False) -> set[SQLModelOrSqlAlchemy]:
233
- """Get registered models.
234
-
235
- Args:
236
- include_empty: If True, include models with empty metadata.
125
+ The async database engine.
237
126
 
238
- Returns:
239
- The registered models.
127
+ Raises:
128
+ ValueError: If the async database url is None.
240
129
  """
241
- if include_empty:
242
- return cls.models
243
- return {
244
- model for model in cls.models if not cls._model_metadata_is_empty(model)
245
- }
130
+ if url is None:
131
+ conf = get_config()
132
+ url = conf.async_db_url
133
+ if url is not None and conf.db_url is not None:
134
+ async_db_url_tail = url.partition("://")[2]
135
+ db_url_tail = conf.db_url.partition("://")[2]
136
+ if async_db_url_tail != db_url_tail:
137
+ console.warn(
138
+ f"async_db_url `{_safe_db_url_for_logging(url)}` "
139
+ "should reference the same database as "
140
+ f"db_url `{_safe_db_url_for_logging(conf.db_url)}`."
141
+ )
142
+ if url is None:
143
+ msg = "No async database url configured"
144
+ raise ValueError(msg)
145
+
146
+ global _ASYNC_ENGINE
147
+ if url in _ASYNC_ENGINE:
148
+ return _ASYNC_ENGINE[url]
149
+
150
+ if not environment.ALEMBIC_CONFIG.get().exists():
151
+ console.warn(
152
+ "Database is not initialized, run [bold]reflex db init[/bold] first."
153
+ )
154
+ _ASYNC_ENGINE[url] = sqlalchemy.ext.asyncio.create_async_engine(
155
+ url,
156
+ **get_engine_args(url),
157
+ )
158
+ return _ASYNC_ENGINE[url]
246
159
 
247
- @staticmethod
248
- def _model_metadata_is_empty(model: SQLModelOrSqlAlchemy) -> bool:
249
- """Check if the model metadata is empty.
160
+ def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
161
+ """Get a bare sqlalchemy session to interact with the database.
250
162
 
251
163
  Args:
252
- model: The model to check.
164
+ url: The database url.
253
165
 
254
166
  Returns:
255
- True if the model metadata is empty, False otherwise.
167
+ A database session.
256
168
  """
257
- return len(model.metadata.tables) == 0
169
+ return sqlalchemy.orm.Session(get_engine(url))
170
+
171
+ class ModelRegistry:
172
+ """Registry for all models."""
173
+
174
+ models: ClassVar[set[SQLModelOrSqlAlchemy]] = set()
175
+
176
+ # Cache the metadata to avoid re-creating it.
177
+ _metadata: ClassVar[sqlalchemy.MetaData | None] = None
178
+
179
+ @classmethod
180
+ def register(cls, model: SQLModelOrSqlAlchemy):
181
+ """Register a model. Can be used directly or as a decorator.
182
+
183
+ Args:
184
+ model: The model to register.
185
+
186
+ Returns:
187
+ The model passed in as an argument (Allows decorator usage)
188
+ """
189
+ cls.models.add(model)
190
+ return model
191
+
192
+ @classmethod
193
+ def get_models(cls, include_empty: bool = False) -> set[SQLModelOrSqlAlchemy]:
194
+ """Get registered models.
195
+
196
+ Args:
197
+ include_empty: If True, include models with empty metadata.
198
+
199
+ Returns:
200
+ The registered models.
201
+ """
202
+ if include_empty:
203
+ return cls.models
204
+ return {
205
+ model for model in cls.models if not cls._model_metadata_is_empty(model)
206
+ }
207
+
208
+ @staticmethod
209
+ def _model_metadata_is_empty(model: SQLModelOrSqlAlchemy) -> bool:
210
+ """Check if the model metadata is empty.
211
+
212
+ Args:
213
+ model: The model to check.
214
+
215
+ Returns:
216
+ True if the model metadata is empty, False otherwise.
217
+ """
218
+ return len(model.metadata.tables) == 0
219
+
220
+ @classmethod
221
+ def get_metadata(cls) -> sqlalchemy.MetaData:
222
+ """Get the database metadata.
223
+
224
+ Returns:
225
+ The database metadata.
226
+ """
227
+ if cls._metadata is not None:
228
+ return cls._metadata
229
+
230
+ models = cls.get_models(include_empty=False)
231
+
232
+ if len(models) == 1:
233
+ metadata = next(iter(models)).metadata
234
+ else:
235
+ # Merge the metadata from all the models.
236
+ # This allows mixing bare sqlalchemy models with sqlmodel models in one database.
237
+ metadata = sqlalchemy.MetaData()
238
+ for model in cls.get_models():
239
+ for table in model.metadata.tables.values():
240
+ table.to_metadata(metadata)
241
+
242
+ # Cache the metadata
243
+ cls._metadata = metadata
244
+
245
+ return metadata
246
+
247
+ else:
248
+ get_engine_args = _print_db_not_available
249
+ get_engine = _print_db_not_available
250
+ get_async_engine = _print_db_not_available
251
+ sqla_session = _print_db_not_available
252
+ ModelRegistry = _ClassThatErrorsOnInit # pyright: ignore [reportAssignmentType]
253
+
254
+ if find_spec("sqlmodel") and find_spec("sqlalchemy") and find_spec("pydantic"):
255
+ import alembic.autogenerate
256
+ import alembic.command
257
+ import alembic.config
258
+ import alembic.operations.ops
259
+ import alembic.runtime.environment
260
+ import alembic.script
261
+ import sqlmodel
262
+ from alembic.runtime.migration import MigrationContext
263
+ from alembic.script.base import Script
264
+ from sqlmodel.ext.asyncio.session import AsyncSession
265
+
266
+ _AsyncSessionLocal: dict[str | None, sqlalchemy.ext.asyncio.async_sessionmaker] = {}
267
+
268
+ def format_revision(
269
+ rev: Script,
270
+ current_rev: str | None,
271
+ current_reached_ref: list[bool],
272
+ ) -> str:
273
+ """Format a single revision for display.
258
274
 
259
- @classmethod
260
- def get_metadata(cls) -> sqlalchemy.MetaData:
261
- """Get the database metadata.
275
+ Args:
276
+ rev: The alembic script object
277
+ current_rev: The currently applied revision ID
278
+ current_reached_ref: Mutable reference to track if we've reached current revision
262
279
 
263
280
  Returns:
264
- The database metadata.
281
+ Formatted string for display
265
282
  """
266
- if cls._metadata is not None:
267
- return cls._metadata
268
-
269
- models = cls.get_models(include_empty=False)
270
-
271
- if len(models) == 1:
272
- metadata = next(iter(models)).metadata
283
+ current = rev.revision
284
+ message = rev.doc
285
+
286
+ # Determine if this migration is applied
287
+ if current_rev is None:
288
+ is_applied = False
289
+ elif current == current_rev:
290
+ is_applied = True
291
+ current_reached_ref[0] = True
273
292
  else:
274
- # Merge the metadata from all the models.
275
- # This allows mixing bare sqlalchemy models with sqlmodel models in one database.
276
- metadata = sqlalchemy.MetaData()
277
- for model in cls.get_models():
278
- for table in model.metadata.tables.values():
279
- table.to_metadata(metadata)
280
-
281
- # Cache the metadata
282
- cls._metadata = metadata
283
-
284
- return metadata
293
+ is_applied = not current_reached_ref[0]
285
294
 
295
+ # Show checkmark or X with colors
296
+ status_icon = "[green]✓[/green]" if is_applied else "[red]✗[/red]"
297
+ head_marker = " (head)" if rev.is_head else ""
286
298
 
287
- class Model(Base, sqlmodel.SQLModel): # pyright: ignore [reportGeneralTypeIssues,reportIncompatibleVariableOverride]
288
- """Base class to define a table in the database."""
299
+ # Format output with message
300
+ return f" [{status_icon}] {current}{head_marker}, {message}"
289
301
 
290
- # The primary key for the table.
291
- id: int | None = sqlmodel.Field(default=None, primary_key=True)
302
+ async def get_db_status() -> dict[str, bool]:
303
+ """Checks the status of the database connection.
292
304
 
293
- def __init_subclass__(cls):
294
- """Drop the default primary key field if any primary key field is defined."""
295
- non_default_primary_key_fields = [
296
- field_name
297
- for field_name, field in cls.__fields__.items()
298
- if field_name != "id" and sqlmodel_field_has_primary_key(field)
299
- ]
300
- if non_default_primary_key_fields:
301
- cls.__fields__.pop("id", None)
302
-
303
- super().__init_subclass__()
304
-
305
- @classmethod
306
- def _dict_recursive(cls, value: Any):
307
- """Recursively serialize the relationship object(s).
308
-
309
- Args:
310
- value: The value to serialize.
305
+ Attempts to connect to the database and execute a simple query to verify connectivity.
311
306
 
312
307
  Returns:
313
- The serialized value.
308
+ The status of the database connection.
314
309
  """
315
- if hasattr(value, "dict"):
316
- return value.dict()
317
- if isinstance(value, list):
318
- return [cls._dict_recursive(item) for item in value]
319
- return value
310
+ status = True
311
+ try:
312
+ engine = get_engine()
313
+ with engine.connect() as connection:
314
+ connection.execute(sqlalchemy.text("SELECT 1"))
315
+ except sqlalchemy.exc.OperationalError:
316
+ status = False
317
+
318
+ return {"db": status}
320
319
 
321
- def dict(self, **kwargs):
322
- """Convert the object to a dictionary.
320
+ @serializer
321
+ def serialize_sqlmodel(m: sqlmodel.SQLModel) -> dict[str, Any]:
322
+ """Serialize a SQLModel object to a dictionary.
323
323
 
324
324
  Args:
325
- kwargs: Ignored but needed for compatibility.
325
+ m: The SQLModel object to serialize.
326
326
 
327
327
  Returns:
328
- The object as a dictionary.
328
+ The serialized object as a dictionary.
329
329
  """
330
- base_fields = {name: getattr(self, name) for name in self.__fields__}
330
+ base_fields = m.model_dump()
331
331
  relationships = {}
332
332
  # SQLModel relationships do not appear in __fields__, but should be included if present.
333
- for name in self.__sqlmodel_relationships__:
333
+ for name in m.__sqlmodel_relationships__:
334
334
  with suppress(
335
335
  sqlalchemy.orm.exc.DetachedInstanceError # This happens when the relationship was never loaded and the session is closed.
336
336
  ):
337
- relationships[name] = self._dict_recursive(getattr(self, name))
337
+ relationships[name] = getattr(m, name)
338
338
  return {
339
339
  **base_fields,
340
340
  **relationships,
341
341
  }
342
342
 
343
- @staticmethod
344
- def create_all():
345
- """Create all the tables."""
346
- engine = get_engine()
347
- ModelRegistry.get_metadata().create_all(engine)
343
+ class Model(sqlmodel.SQLModel):
344
+ """Base class to define a table in the database."""
348
345
 
349
- @staticmethod
350
- def get_db_engine():
351
- """Get the database engine.
352
-
353
- Returns:
354
- The database engine.
355
- """
356
- return get_engine()
346
+ # The primary key for the table.
347
+ id: int | None = sqlmodel.Field(default=None, primary_key=True)
357
348
 
358
- @staticmethod
359
- def _alembic_config():
360
- """Get the alembic configuration and script_directory.
349
+ model_config = { # pyright: ignore [reportAssignmentType]
350
+ "arbitrary_types_allowed": True,
351
+ "use_enum_values": True,
352
+ "extra": "allow",
353
+ }
361
354
 
362
- Returns:
363
- tuple of (config, script_directory)
364
- """
365
- config = alembic.config.Config(environment.ALEMBIC_CONFIG.get())
366
- if not config.get_main_option("script_location"):
367
- config.set_main_option("script_location", "version")
368
- return config, alembic.script.ScriptDirectory.from_config(config)
355
+ @classmethod
356
+ def __pydantic_init_subclass__(cls):
357
+ """Drop the default primary key field if any primary key field is defined."""
358
+ non_default_primary_key_fields = [
359
+ field_name
360
+ for field_name, field_info in cls.model_fields.items()
361
+ if field_name != "id" and sqlmodel_field_has_primary_key(field_info)
362
+ ]
363
+ if non_default_primary_key_fields:
364
+ cls.model_fields.pop("id", None)
365
+ console.deprecate(
366
+ feature_name="Overriding default primary key",
367
+ reason=(
368
+ "Register sqlmodel.SQLModel classes with `@rx.ModelRegistry.register`"
369
+ ),
370
+ deprecation_version="0.8.15",
371
+ removal_version="0.9.0",
372
+ )
373
+ super().__pydantic_init_subclass__()
374
+
375
+ @staticmethod
376
+ def create_all():
377
+ """Create all the tables."""
378
+ engine = get_engine()
379
+ ModelRegistry.get_metadata().create_all(engine)
380
+
381
+ @staticmethod
382
+ def get_db_engine():
383
+ """Get the database engine.
384
+
385
+ Returns:
386
+ The database engine.
387
+ """
388
+ return get_engine()
389
+
390
+ @staticmethod
391
+ def _alembic_config():
392
+ """Get the alembic configuration and script_directory.
393
+
394
+ Returns:
395
+ tuple of (config, script_directory)
396
+ """
397
+ config = alembic.config.Config(environment.ALEMBIC_CONFIG.get())
398
+ if not config.get_main_option("script_location"):
399
+ config.set_main_option("script_location", "version")
400
+ return config, alembic.script.ScriptDirectory.from_config(config)
401
+
402
+ @staticmethod
403
+ def _alembic_render_item(
404
+ type_: str,
405
+ obj: Any,
406
+ autogen_context: alembic.autogenerate.api.AutogenContext,
407
+ ):
408
+ """Alembic render_item hook call.
369
409
 
370
- @staticmethod
371
- def _alembic_render_item(
372
- type_: str,
373
- obj: Any,
374
- autogen_context: alembic.autogenerate.api.AutogenContext,
375
- ):
376
- """Alembic render_item hook call.
410
+ This method is called to provide python code for the given obj,
411
+ but currently it is only used to add `sqlmodel` to the import list
412
+ when generating migration scripts.
377
413
 
378
- This method is called to provide python code for the given obj,
379
- but currently it is only used to add `sqlmodel` to the import list
380
- when generating migration scripts.
414
+ See https://alembic.sqlalchemy.org/en/latest/api/runtime.html
381
415
 
382
- See https://alembic.sqlalchemy.org/en/latest/api/runtime.html
416
+ Args:
417
+ type_: One of "schema", "table", "column", "index",
418
+ "unique_constraint", or "foreign_key_constraint".
419
+ obj: The object being rendered.
420
+ autogen_context: Shared AutogenContext passed to each render_item call.
383
421
 
384
- Args:
385
- type_: One of "schema", "table", "column", "index",
386
- "unique_constraint", or "foreign_key_constraint".
387
- obj: The object being rendered.
388
- autogen_context: Shared AutogenContext passed to each render_item call.
422
+ Returns:
423
+ False - Indicating that the default rendering should be used.
424
+ """
425
+ autogen_context.imports.add("import sqlmodel")
426
+ return False
389
427
 
390
- Returns:
391
- False - Indicating that the default rendering should be used.
392
- """
393
- autogen_context.imports.add("import sqlmodel")
394
- return False
395
-
396
- @classmethod
397
- def alembic_init(cls):
398
- """Initialize alembic for the project."""
399
- alembic.command.init(
400
- config=alembic.config.Config(environment.ALEMBIC_CONFIG.get()),
401
- directory=str(environment.ALEMBIC_CONFIG.get().parent / "alembic"),
402
- )
428
+ @classmethod
429
+ def alembic_init(cls):
430
+ """Initialize alembic for the project."""
431
+ alembic.command.init(
432
+ config=alembic.config.Config(environment.ALEMBIC_CONFIG.get()),
433
+ directory=str(environment.ALEMBIC_CONFIG.get().parent / "alembic"),
434
+ )
403
435
 
404
- @classmethod
405
- def get_migration_history(cls):
406
- """Get migration history with current database state.
436
+ @classmethod
437
+ def get_migration_history(cls):
438
+ """Get migration history with current database state.
439
+
440
+ Returns:
441
+ tuple: (current_revision, revisions_list) where revisions_list is in chronological order
442
+ """
443
+ # Get current revision from database
444
+ with cls.get_db_engine().connect() as connection:
445
+ context = MigrationContext.configure(connection)
446
+ current_rev = context.get_current_revision()
447
+
448
+ # Get all revisions from base to head
449
+ _, script_dir = cls._alembic_config()
450
+ revisions = list(script_dir.walk_revisions())
451
+ revisions.reverse() # Reverse to get chronological order (base first)
452
+
453
+ return current_rev, revisions
454
+
455
+ @classmethod
456
+ def alembic_autogenerate(
457
+ cls,
458
+ connection: sqlalchemy.engine.Connection,
459
+ message: str | None = None,
460
+ write_migration_scripts: bool = True,
461
+ ) -> bool:
462
+ """Generate migration scripts for alembic-detectable changes.
463
+
464
+ Args:
465
+ connection: SQLAlchemy connection to use when detecting changes.
466
+ message: Human readable identifier describing the generated revision.
467
+ write_migration_scripts: If True, write autogenerated revisions to script directory.
468
+
469
+ Returns:
470
+ True when changes have been detected.
471
+ """
472
+ if not environment.ALEMBIC_CONFIG.get().exists():
473
+ return False
474
+
475
+ config, script_directory = cls._alembic_config()
476
+ revision_context = alembic.autogenerate.api.RevisionContext(
477
+ config=config,
478
+ script_directory=script_directory,
479
+ command_args=defaultdict(
480
+ lambda: None,
481
+ autogenerate=True,
482
+ head="head",
483
+ message=message,
484
+ ),
485
+ )
486
+ writer = alembic.autogenerate.rewriter.Rewriter()
407
487
 
408
- Returns:
409
- tuple: (current_revision, revisions_list) where revisions_list is in chronological order
410
- """
411
- # Get current revision from database
412
- with cls.get_db_engine().connect() as connection:
413
- context = MigrationContext.configure(connection)
414
- current_rev = context.get_current_revision()
415
-
416
- # Get all revisions from base to head
417
- _, script_dir = cls._alembic_config()
418
- revisions = list(script_dir.walk_revisions())
419
- revisions.reverse() # Reverse to get chronological order (base first)
420
-
421
- return current_rev, revisions
422
-
423
- @classmethod
424
- def alembic_autogenerate(
425
- cls,
426
- connection: sqlalchemy.engine.Connection,
427
- message: str | None = None,
428
- write_migration_scripts: bool = True,
429
- ) -> bool:
430
- """Generate migration scripts for alembic-detectable changes.
488
+ @writer.rewrites(alembic.operations.ops.AddColumnOp)
489
+ def render_add_column_with_server_default(
490
+ context: MigrationContext,
491
+ revision: str | None,
492
+ op: Any,
493
+ ):
494
+ # Carry the sqlmodel default as server_default so that newly added
495
+ # columns get the desired default value in existing rows.
496
+ if op.column.default is not None and op.column.server_default is None:
497
+ op.column.server_default = sqlalchemy.DefaultClause(
498
+ sqlalchemy.sql.expression.literal(op.column.default.arg),
499
+ )
500
+ return op
501
+
502
+ def run_autogenerate(rev: str, context: MigrationContext):
503
+ revision_context.run_autogenerate(rev, context)
504
+ return []
505
+
506
+ with alembic.runtime.environment.EnvironmentContext(
507
+ config=config,
508
+ script=script_directory,
509
+ fn=run_autogenerate,
510
+ ) as env:
511
+ env.configure(
512
+ connection=connection,
513
+ target_metadata=ModelRegistry.get_metadata(),
514
+ render_item=cls._alembic_render_item,
515
+ process_revision_directives=writer,
516
+ compare_type=False,
517
+ render_as_batch=True, # for sqlite compatibility
518
+ )
519
+ env.run_migrations()
520
+ changes_detected = False
521
+ if revision_context.generated_revisions:
522
+ upgrade_ops = revision_context.generated_revisions[-1].upgrade_ops
523
+ if upgrade_ops is not None:
524
+ changes_detected = bool(upgrade_ops.ops)
525
+ if changes_detected and write_migration_scripts:
526
+ # Must iterate the generator to actually write the scripts.
527
+ _ = tuple(revision_context.generate_scripts())
528
+ return changes_detected
529
+
530
+ @classmethod
531
+ def _alembic_upgrade(
532
+ cls,
533
+ connection: sqlalchemy.engine.Connection,
534
+ to_rev: str = "head",
535
+ ) -> None:
536
+ """Apply alembic migrations up to the given revision.
537
+
538
+ Args:
539
+ connection: SQLAlchemy connection to use when performing upgrade.
540
+ to_rev: Revision to migrate towards.
541
+ """
542
+ config, script_directory = cls._alembic_config()
543
+
544
+ def run_upgrade(rev: str, context: MigrationContext):
545
+ return script_directory._upgrade_revs(to_rev, rev)
546
+
547
+ with alembic.runtime.environment.EnvironmentContext(
548
+ config=config,
549
+ script=script_directory,
550
+ fn=run_upgrade,
551
+ ) as env:
552
+ env.configure(connection=connection)
553
+ env.run_migrations()
554
+
555
+ @classmethod
556
+ def migrate(cls, autogenerate: bool = False) -> bool | None:
557
+ """Execute alembic migrations for all sqlmodel Model classes.
558
+
559
+ If alembic is not installed or has not been initialized for the project,
560
+ then no action is performed.
561
+
562
+ If there are no revisions currently tracked by alembic, then
563
+ an initial revision will be created based on sqlmodel metadata.
564
+
565
+ If models in the app have changed in incompatible ways that alembic
566
+ cannot automatically generate revisions for, the app may not be able to
567
+ start up until migration scripts have been corrected by hand.
568
+
569
+ Args:
570
+ autogenerate: If True, generate migration script and use it to upgrade schema
571
+ (otherwise, just bring the schema to current "head" revision).
572
+
573
+ Returns:
574
+ True - indicating the process was successful.
575
+ None - indicating the process was skipped.
576
+ """
577
+ if not environment.ALEMBIC_CONFIG.get().exists():
578
+ return None
579
+
580
+ with cls.get_db_engine().connect() as connection:
581
+ cls._alembic_upgrade(connection=connection)
582
+ if autogenerate:
583
+ changes_detected = cls.alembic_autogenerate(connection=connection)
584
+ if changes_detected:
585
+ cls._alembic_upgrade(connection=connection)
586
+ connection.commit()
587
+ return True
588
+
589
+ @classmethod
590
+ def select(cls):
591
+ """Select rows from the table.
592
+
593
+ Returns:
594
+ The select statement.
595
+ """
596
+ return sqlmodel.select(cls)
597
+
598
+ ModelRegistry.register(Model)
599
+
600
+ def session(url: str | None = None) -> sqlmodel.Session:
601
+ """Get a sqlmodel session to interact with the database.
431
602
 
432
603
  Args:
433
- connection: SQLAlchemy connection to use when detecting changes.
434
- message: Human readable identifier describing the generated revision.
435
- write_migration_scripts: If True, write autogenerated revisions to script directory.
604
+ url: The database url.
436
605
 
437
606
  Returns:
438
- True when changes have been detected.
439
- """
440
- if not environment.ALEMBIC_CONFIG.get().exists():
441
- return False
442
-
443
- config, script_directory = cls._alembic_config()
444
- revision_context = alembic.autogenerate.api.RevisionContext(
445
- config=config,
446
- script_directory=script_directory,
447
- command_args=defaultdict(
448
- lambda: None,
449
- autogenerate=True,
450
- head="head",
451
- message=message,
452
- ),
453
- )
454
- writer = alembic.autogenerate.rewriter.Rewriter()
455
-
456
- @writer.rewrites(alembic.operations.ops.AddColumnOp)
457
- def render_add_column_with_server_default(
458
- context: MigrationContext,
459
- revision: str | None,
460
- op: Any,
461
- ):
462
- # Carry the sqlmodel default as server_default so that newly added
463
- # columns get the desired default value in existing rows.
464
- if op.column.default is not None and op.column.server_default is None:
465
- op.column.server_default = sqlalchemy.DefaultClause(
466
- sqlalchemy.sql.expression.literal(op.column.default.arg),
467
- )
468
- return op
469
-
470
- def run_autogenerate(rev: str, context: MigrationContext):
471
- revision_context.run_autogenerate(rev, context)
472
- return []
473
-
474
- with alembic.runtime.environment.EnvironmentContext(
475
- config=config,
476
- script=script_directory,
477
- fn=run_autogenerate,
478
- ) as env:
479
- env.configure(
480
- connection=connection,
481
- target_metadata=ModelRegistry.get_metadata(),
482
- render_item=cls._alembic_render_item,
483
- process_revision_directives=writer,
484
- compare_type=False,
485
- render_as_batch=True, # for sqlite compatibility
486
- )
487
- env.run_migrations()
488
- changes_detected = False
489
- if revision_context.generated_revisions:
490
- upgrade_ops = revision_context.generated_revisions[-1].upgrade_ops
491
- if upgrade_ops is not None:
492
- changes_detected = bool(upgrade_ops.ops)
493
- if changes_detected and write_migration_scripts:
494
- # Must iterate the generator to actually write the scripts.
495
- _ = tuple(revision_context.generate_scripts())
496
- return changes_detected
497
-
498
- @classmethod
499
- def _alembic_upgrade(
500
- cls,
501
- connection: sqlalchemy.engine.Connection,
502
- to_rev: str = "head",
503
- ) -> None:
504
- """Apply alembic migrations up to the given revision.
505
-
506
- Args:
507
- connection: SQLAlchemy connection to use when performing upgrade.
508
- to_rev: Revision to migrate towards.
607
+ A database session.
509
608
  """
510
- config, script_directory = cls._alembic_config()
511
-
512
- def run_upgrade(rev: str, context: MigrationContext):
513
- return script_directory._upgrade_revs(to_rev, rev)
514
-
515
- with alembic.runtime.environment.EnvironmentContext(
516
- config=config,
517
- script=script_directory,
518
- fn=run_upgrade,
519
- ) as env:
520
- env.configure(connection=connection)
521
- env.run_migrations()
609
+ return sqlmodel.Session(get_engine(url))
522
610
 
523
- @classmethod
524
- def migrate(cls, autogenerate: bool = False) -> bool | None:
525
- """Execute alembic migrations for all sqlmodel Model classes.
611
+ def asession(url: str | None = None) -> AsyncSession:
612
+ """Get an async sqlmodel session to interact with the database.
526
613
 
527
- If alembic is not installed or has not been initialized for the project,
528
- then no action is performed.
614
+ async with rx.asession() as asession:
615
+ ...
529
616
 
530
- If there are no revisions currently tracked by alembic, then
531
- an initial revision will be created based on sqlmodel metadata.
532
-
533
- If models in the app have changed in incompatible ways that alembic
534
- cannot automatically generate revisions for, the app may not be able to
535
- start up until migration scripts have been corrected by hand.
617
+ Most operations against the `asession` must be awaited.
536
618
 
537
619
  Args:
538
- autogenerate: If True, generate migration script and use it to upgrade schema
539
- (otherwise, just bring the schema to current "head" revision).
540
-
541
- Returns:
542
- True - indicating the process was successful.
543
- None - indicating the process was skipped.
544
- """
545
- if not environment.ALEMBIC_CONFIG.get().exists():
546
- return None
547
-
548
- with cls.get_db_engine().connect() as connection:
549
- cls._alembic_upgrade(connection=connection)
550
- if autogenerate:
551
- changes_detected = cls.alembic_autogenerate(connection=connection)
552
- if changes_detected:
553
- cls._alembic_upgrade(connection=connection)
554
- connection.commit()
555
- return True
556
-
557
- @classmethod
558
- def select(cls):
559
- """Select rows from the table.
620
+ url: The database url.
560
621
 
561
622
  Returns:
562
- The select statement.
623
+ An async database session.
563
624
  """
564
- return sqlmodel.select(cls)
565
-
566
-
567
- ModelRegistry.register(Model)
568
-
569
-
570
- def session(url: str | None = None) -> sqlmodel.Session:
571
- """Get a sqlmodel session to interact with the database.
572
-
573
- Args:
574
- url: The database url.
575
-
576
- Returns:
577
- A database session.
578
- """
579
- return sqlmodel.Session(get_engine(url))
580
-
581
-
582
- def asession(url: str | None = None) -> AsyncSession:
583
- """Get an async sqlmodel session to interact with the database.
584
-
585
- async with rx.asession() as asession:
586
- ...
587
-
588
- Most operations against the `asession` must be awaited.
589
-
590
- Args:
591
- url: The database url.
592
-
593
- Returns:
594
- An async database session.
595
- """
596
- global _AsyncSessionLocal
597
- if url not in _AsyncSessionLocal:
598
- _AsyncSessionLocal[url] = sqlalchemy.ext.asyncio.async_sessionmaker(
599
- bind=get_async_engine(url),
600
- class_=AsyncSession,
601
- expire_on_commit=False,
602
- autocommit=False,
603
- autoflush=False,
604
- )
605
- return _AsyncSessionLocal[url]()
606
-
607
-
608
- def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
609
- """Get a bare sqlalchemy session to interact with the database.
610
-
611
- Args:
612
- url: The database url.
625
+ global _AsyncSessionLocal
626
+ if url not in _AsyncSessionLocal:
627
+ _AsyncSessionLocal[url] = sqlalchemy.ext.asyncio.async_sessionmaker(
628
+ bind=get_async_engine(url),
629
+ class_=AsyncSession,
630
+ expire_on_commit=False,
631
+ autocommit=False,
632
+ autoflush=False,
633
+ )
634
+ return _AsyncSessionLocal[url]()
613
635
 
614
- Returns:
615
- A database session.
616
- """
617
- return sqlalchemy.orm.Session(get_engine(url))
636
+ else:
637
+ get_db_status = _print_db_not_available
638
+ session = _print_db_not_available
639
+ asession = _print_db_not_available
640
+ Model = _ClassThatErrorsOnInit # pyright: ignore [reportAssignmentType]