svc-infra 0.1.595__py3-none-any.whl → 0.1.706__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of svc-infra might be problematic. Click here for more details.
- svc_infra/__init__.py +58 -2
- svc_infra/apf_payments/models.py +133 -42
- svc_infra/apf_payments/provider/aiydan.py +121 -47
- svc_infra/apf_payments/provider/base.py +30 -9
- svc_infra/apf_payments/provider/stripe.py +156 -62
- svc_infra/apf_payments/schemas.py +18 -9
- svc_infra/apf_payments/service.py +98 -41
- svc_infra/apf_payments/settings.py +5 -1
- svc_infra/api/__init__.py +61 -0
- svc_infra/api/fastapi/__init__.py +15 -0
- svc_infra/api/fastapi/admin/__init__.py +3 -0
- svc_infra/api/fastapi/admin/add.py +245 -0
- svc_infra/api/fastapi/apf_payments/router.py +128 -70
- svc_infra/api/fastapi/apf_payments/setup.py +13 -6
- svc_infra/api/fastapi/auth/__init__.py +65 -0
- svc_infra/api/fastapi/auth/_cookies.py +6 -2
- svc_infra/api/fastapi/auth/add.py +17 -14
- svc_infra/api/fastapi/auth/gaurd.py +45 -16
- svc_infra/api/fastapi/auth/mfa/models.py +3 -1
- svc_infra/api/fastapi/auth/mfa/pre_auth.py +10 -6
- svc_infra/api/fastapi/auth/mfa/router.py +15 -8
- svc_infra/api/fastapi/auth/mfa/security.py +1 -2
- svc_infra/api/fastapi/auth/mfa/utils.py +2 -1
- svc_infra/api/fastapi/auth/mfa/verify.py +9 -2
- svc_infra/api/fastapi/auth/policy.py +0 -1
- svc_infra/api/fastapi/auth/providers.py +3 -1
- svc_infra/api/fastapi/auth/routers/apikey_router.py +6 -6
- svc_infra/api/fastapi/auth/routers/oauth_router.py +146 -52
- svc_infra/api/fastapi/auth/routers/session_router.py +6 -2
- svc_infra/api/fastapi/auth/security.py +31 -10
- svc_infra/api/fastapi/auth/sender.py +8 -1
- svc_infra/api/fastapi/auth/state.py +3 -1
- svc_infra/api/fastapi/auth/ws_security.py +275 -0
- svc_infra/api/fastapi/billing/router.py +73 -0
- svc_infra/api/fastapi/billing/setup.py +19 -0
- svc_infra/api/fastapi/cache/add.py +9 -5
- svc_infra/api/fastapi/db/__init__.py +5 -1
- svc_infra/api/fastapi/db/http.py +3 -1
- svc_infra/api/fastapi/db/nosql/__init__.py +39 -1
- svc_infra/api/fastapi/db/nosql/mongo/add.py +47 -32
- svc_infra/api/fastapi/db/nosql/mongo/crud_router.py +30 -11
- svc_infra/api/fastapi/db/sql/__init__.py +5 -1
- svc_infra/api/fastapi/db/sql/add.py +71 -26
- svc_infra/api/fastapi/db/sql/crud_router.py +210 -22
- svc_infra/api/fastapi/db/sql/health.py +3 -1
- svc_infra/api/fastapi/db/sql/session.py +18 -0
- svc_infra/api/fastapi/db/sql/users.py +18 -6
- svc_infra/api/fastapi/dependencies/ratelimit.py +78 -14
- svc_infra/api/fastapi/docs/add.py +173 -0
- svc_infra/api/fastapi/docs/landing.py +4 -2
- svc_infra/api/fastapi/docs/scoped.py +62 -15
- svc_infra/api/fastapi/dual/__init__.py +12 -2
- svc_infra/api/fastapi/dual/dualize.py +1 -1
- svc_infra/api/fastapi/dual/protected.py +126 -4
- svc_infra/api/fastapi/dual/public.py +25 -0
- svc_infra/api/fastapi/dual/router.py +40 -13
- svc_infra/api/fastapi/dx.py +33 -2
- svc_infra/api/fastapi/ease.py +10 -2
- svc_infra/api/fastapi/http/concurrency.py +2 -1
- svc_infra/api/fastapi/http/conditional.py +3 -1
- svc_infra/api/fastapi/middleware/debug.py +4 -1
- svc_infra/api/fastapi/middleware/errors/catchall.py +6 -2
- svc_infra/api/fastapi/middleware/errors/exceptions.py +1 -1
- svc_infra/api/fastapi/middleware/errors/handlers.py +54 -8
- svc_infra/api/fastapi/middleware/graceful_shutdown.py +104 -0
- svc_infra/api/fastapi/middleware/idempotency.py +197 -70
- svc_infra/api/fastapi/middleware/idempotency_store.py +187 -0
- svc_infra/api/fastapi/middleware/optimistic_lock.py +42 -0
- svc_infra/api/fastapi/middleware/ratelimit.py +125 -28
- svc_infra/api/fastapi/middleware/ratelimit_store.py +43 -10
- svc_infra/api/fastapi/middleware/request_id.py +27 -11
- svc_infra/api/fastapi/middleware/request_size_limit.py +3 -3
- svc_infra/api/fastapi/middleware/timeout.py +177 -0
- svc_infra/api/fastapi/openapi/apply.py +5 -3
- svc_infra/api/fastapi/openapi/conventions.py +9 -2
- svc_infra/api/fastapi/openapi/mutators.py +165 -20
- svc_infra/api/fastapi/openapi/pipeline.py +1 -1
- svc_infra/api/fastapi/openapi/security.py +3 -1
- svc_infra/api/fastapi/ops/add.py +75 -0
- svc_infra/api/fastapi/pagination.py +47 -20
- svc_infra/api/fastapi/routers/__init__.py +43 -15
- svc_infra/api/fastapi/routers/ping.py +1 -0
- svc_infra/api/fastapi/setup.py +188 -57
- svc_infra/api/fastapi/tenancy/add.py +19 -0
- svc_infra/api/fastapi/tenancy/context.py +112 -0
- svc_infra/api/fastapi/versioned.py +101 -0
- svc_infra/app/README.md +5 -5
- svc_infra/app/__init__.py +3 -1
- svc_infra/app/env.py +69 -1
- svc_infra/app/logging/add.py +9 -2
- svc_infra/app/logging/formats.py +12 -5
- svc_infra/billing/__init__.py +23 -0
- svc_infra/billing/async_service.py +147 -0
- svc_infra/billing/jobs.py +241 -0
- svc_infra/billing/models.py +177 -0
- svc_infra/billing/quotas.py +103 -0
- svc_infra/billing/schemas.py +36 -0
- svc_infra/billing/service.py +123 -0
- svc_infra/bundled_docs/README.md +5 -0
- svc_infra/bundled_docs/__init__.py +1 -0
- svc_infra/bundled_docs/getting-started.md +6 -0
- svc_infra/cache/__init__.py +9 -0
- svc_infra/cache/add.py +170 -0
- svc_infra/cache/backend.py +7 -6
- svc_infra/cache/decorators.py +81 -15
- svc_infra/cache/demo.py +2 -2
- svc_infra/cache/keys.py +24 -4
- svc_infra/cache/recache.py +26 -14
- svc_infra/cache/resources.py +14 -5
- svc_infra/cache/tags.py +19 -44
- svc_infra/cache/utils.py +3 -1
- svc_infra/cli/__init__.py +52 -8
- svc_infra/cli/__main__.py +4 -0
- svc_infra/cli/cmds/__init__.py +39 -2
- svc_infra/cli/cmds/db/nosql/mongo/mongo_cmds.py +7 -4
- svc_infra/cli/cmds/db/nosql/mongo/mongo_scaffold_cmds.py +7 -5
- svc_infra/cli/cmds/db/ops_cmds.py +270 -0
- svc_infra/cli/cmds/db/sql/alembic_cmds.py +103 -18
- svc_infra/cli/cmds/db/sql/sql_export_cmds.py +88 -0
- svc_infra/cli/cmds/db/sql/sql_scaffold_cmds.py +3 -3
- svc_infra/cli/cmds/docs/docs_cmds.py +142 -0
- svc_infra/cli/cmds/dx/__init__.py +12 -0
- svc_infra/cli/cmds/dx/dx_cmds.py +116 -0
- svc_infra/cli/cmds/health/__init__.py +179 -0
- svc_infra/cli/cmds/health/health_cmds.py +8 -0
- svc_infra/cli/cmds/help.py +4 -0
- svc_infra/cli/cmds/jobs/__init__.py +1 -0
- svc_infra/cli/cmds/jobs/jobs_cmds.py +47 -0
- svc_infra/cli/cmds/obs/obs_cmds.py +36 -15
- svc_infra/cli/cmds/sdk/__init__.py +0 -0
- svc_infra/cli/cmds/sdk/sdk_cmds.py +112 -0
- svc_infra/cli/foundation/runner.py +6 -2
- svc_infra/data/add.py +61 -0
- svc_infra/data/backup.py +58 -0
- svc_infra/data/erasure.py +45 -0
- svc_infra/data/fixtures.py +42 -0
- svc_infra/data/retention.py +61 -0
- svc_infra/db/__init__.py +15 -0
- svc_infra/db/crud_schema.py +9 -9
- svc_infra/db/inbox.py +67 -0
- svc_infra/db/nosql/__init__.py +3 -0
- svc_infra/db/nosql/core.py +30 -9
- svc_infra/db/nosql/indexes.py +3 -1
- svc_infra/db/nosql/management.py +1 -1
- svc_infra/db/nosql/mongo/README.md +13 -13
- svc_infra/db/nosql/mongo/client.py +19 -2
- svc_infra/db/nosql/mongo/settings.py +6 -2
- svc_infra/db/nosql/repository.py +35 -15
- svc_infra/db/nosql/resource.py +20 -3
- svc_infra/db/nosql/scaffold.py +9 -3
- svc_infra/db/nosql/service.py +3 -1
- svc_infra/db/nosql/types.py +6 -2
- svc_infra/db/ops.py +384 -0
- svc_infra/db/outbox.py +108 -0
- svc_infra/db/sql/apikey.py +37 -9
- svc_infra/db/sql/authref.py +9 -3
- svc_infra/db/sql/constants.py +12 -8
- svc_infra/db/sql/core.py +2 -2
- svc_infra/db/sql/management.py +11 -8
- svc_infra/db/sql/repository.py +99 -26
- svc_infra/db/sql/resource.py +5 -0
- svc_infra/db/sql/scaffold.py +6 -2
- svc_infra/db/sql/service.py +15 -5
- svc_infra/db/sql/templates/models_schemas/auth/models.py.tmpl +7 -56
- svc_infra/db/sql/templates/setup/env_async.py.tmpl +34 -12
- svc_infra/db/sql/templates/setup/env_sync.py.tmpl +29 -7
- svc_infra/db/sql/tenant.py +88 -0
- svc_infra/db/sql/uniq_hooks.py +9 -3
- svc_infra/db/sql/utils.py +138 -51
- svc_infra/db/sql/versioning.py +14 -0
- svc_infra/deploy/__init__.py +538 -0
- svc_infra/documents/__init__.py +100 -0
- svc_infra/documents/add.py +264 -0
- svc_infra/documents/ease.py +233 -0
- svc_infra/documents/models.py +114 -0
- svc_infra/documents/storage.py +264 -0
- svc_infra/dx/add.py +65 -0
- svc_infra/dx/changelog.py +74 -0
- svc_infra/dx/checks.py +68 -0
- svc_infra/exceptions.py +141 -0
- svc_infra/health/__init__.py +864 -0
- svc_infra/http/__init__.py +13 -0
- svc_infra/http/client.py +105 -0
- svc_infra/jobs/builtins/outbox_processor.py +40 -0
- svc_infra/jobs/builtins/webhook_delivery.py +95 -0
- svc_infra/jobs/easy.py +33 -0
- svc_infra/jobs/loader.py +50 -0
- svc_infra/jobs/queue.py +116 -0
- svc_infra/jobs/redis_queue.py +256 -0
- svc_infra/jobs/runner.py +79 -0
- svc_infra/jobs/scheduler.py +53 -0
- svc_infra/jobs/worker.py +40 -0
- svc_infra/loaders/__init__.py +186 -0
- svc_infra/loaders/base.py +142 -0
- svc_infra/loaders/github.py +311 -0
- svc_infra/loaders/models.py +147 -0
- svc_infra/loaders/url.py +235 -0
- svc_infra/logging/__init__.py +374 -0
- svc_infra/mcp/svc_infra_mcp.py +91 -33
- svc_infra/obs/README.md +2 -0
- svc_infra/obs/add.py +65 -9
- svc_infra/obs/cloud_dash.py +2 -1
- svc_infra/obs/grafana/dashboards/http-overview.json +45 -0
- svc_infra/obs/metrics/__init__.py +3 -4
- svc_infra/obs/metrics/asgi.py +13 -7
- svc_infra/obs/metrics/http.py +9 -5
- svc_infra/obs/metrics/sqlalchemy.py +13 -9
- svc_infra/obs/metrics.py +6 -5
- svc_infra/obs/settings.py +6 -2
- svc_infra/security/add.py +217 -0
- svc_infra/security/audit.py +92 -10
- svc_infra/security/audit_service.py +4 -3
- svc_infra/security/headers.py +15 -2
- svc_infra/security/hibp.py +14 -4
- svc_infra/security/jwt_rotation.py +74 -22
- svc_infra/security/lockout.py +11 -5
- svc_infra/security/models.py +54 -12
- svc_infra/security/oauth_models.py +73 -0
- svc_infra/security/org_invites.py +5 -3
- svc_infra/security/passwords.py +3 -1
- svc_infra/security/permissions.py +25 -2
- svc_infra/security/session.py +1 -1
- svc_infra/security/signed_cookies.py +21 -1
- svc_infra/storage/__init__.py +93 -0
- svc_infra/storage/add.py +253 -0
- svc_infra/storage/backends/__init__.py +11 -0
- svc_infra/storage/backends/local.py +339 -0
- svc_infra/storage/backends/memory.py +216 -0
- svc_infra/storage/backends/s3.py +353 -0
- svc_infra/storage/base.py +239 -0
- svc_infra/storage/easy.py +185 -0
- svc_infra/storage/settings.py +195 -0
- svc_infra/testing/__init__.py +685 -0
- svc_infra/utils.py +7 -3
- svc_infra/webhooks/__init__.py +69 -0
- svc_infra/webhooks/add.py +339 -0
- svc_infra/webhooks/encryption.py +115 -0
- svc_infra/webhooks/fastapi.py +39 -0
- svc_infra/webhooks/router.py +55 -0
- svc_infra/webhooks/service.py +70 -0
- svc_infra/webhooks/signing.py +34 -0
- svc_infra/websocket/__init__.py +79 -0
- svc_infra/websocket/add.py +140 -0
- svc_infra/websocket/client.py +282 -0
- svc_infra/websocket/config.py +69 -0
- svc_infra/websocket/easy.py +76 -0
- svc_infra/websocket/exceptions.py +61 -0
- svc_infra/websocket/manager.py +344 -0
- svc_infra/websocket/models.py +49 -0
- svc_infra-0.1.706.dist-info/LICENSE +21 -0
- svc_infra-0.1.706.dist-info/METADATA +356 -0
- svc_infra-0.1.706.dist-info/RECORD +357 -0
- svc_infra-0.1.595.dist-info/METADATA +0 -80
- svc_infra-0.1.595.dist-info/RECORD +0 -253
- {svc_infra-0.1.595.dist-info → svc_infra-0.1.706.dist-info}/WHEEL +0 -0
- {svc_infra-0.1.595.dist-info → svc_infra-0.1.706.dist-info}/entry_points.txt +0 -0
svc_infra/db/sql/management.py
CHANGED
|
@@ -15,20 +15,19 @@ def _sa_columns(model: type[object]) -> list[Column]:
|
|
|
15
15
|
def _py_type(col: Column) -> type:
|
|
16
16
|
# Prefer SQLAlchemy-provided python_type when available
|
|
17
17
|
if getattr(col.type, "python_type", None):
|
|
18
|
-
return col.type.python_type
|
|
18
|
+
return col.type.python_type
|
|
19
19
|
|
|
20
20
|
from datetime import date, datetime
|
|
21
|
-
from typing import Any as _Any
|
|
22
21
|
from uuid import UUID
|
|
23
22
|
|
|
24
23
|
from sqlalchemy import JSON, Boolean, Date, DateTime, Integer, String, Text
|
|
25
24
|
|
|
26
25
|
try:
|
|
27
26
|
from sqlalchemy.dialects.postgresql import JSONB
|
|
28
|
-
from sqlalchemy.dialects.postgresql import UUID as PG_UUID
|
|
27
|
+
from sqlalchemy.dialects.postgresql import UUID as PG_UUID
|
|
29
28
|
except Exception: # pragma: no cover
|
|
30
|
-
PG_UUID = None # type: ignore
|
|
31
|
-
JSONB = None # type: ignore
|
|
29
|
+
PG_UUID = None # type: ignore[misc,assignment]
|
|
30
|
+
JSONB = None # type: ignore[misc,assignment]
|
|
32
31
|
|
|
33
32
|
t = col.type
|
|
34
33
|
if PG_UUID is not None and isinstance(t, PG_UUID):
|
|
@@ -47,7 +46,7 @@ def _py_type(col: Column) -> type:
|
|
|
47
46
|
return dict
|
|
48
47
|
if JSONB is not None and isinstance(t, JSONB):
|
|
49
48
|
return dict
|
|
50
|
-
return
|
|
49
|
+
return object # fallback type for unknown column types
|
|
51
50
|
|
|
52
51
|
|
|
53
52
|
def _exclude_from_create(col: Column) -> bool:
|
|
@@ -101,9 +100,13 @@ def make_crud_schemas(
|
|
|
101
100
|
name=name,
|
|
102
101
|
typ=T,
|
|
103
102
|
required_for_create=bool(
|
|
104
|
-
is_required
|
|
103
|
+
is_required
|
|
104
|
+
and name not in explicit_excludes
|
|
105
|
+
and not _exclude_from_create(col)
|
|
106
|
+
),
|
|
107
|
+
exclude_from_create=bool(
|
|
108
|
+
name in explicit_excludes or _exclude_from_create(col)
|
|
105
109
|
),
|
|
106
|
-
exclude_from_create=bool(name in explicit_excludes or _exclude_from_create(col)),
|
|
107
110
|
exclude_from_read=bool(name in read_ex),
|
|
108
111
|
exclude_from_update=bool(name in update_ex),
|
|
109
112
|
)
|
svc_infra/db/sql/repository.py
CHANGED
|
@@ -1,11 +1,24 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
import inspect
|
|
4
|
+
import logging
|
|
5
|
+
from typing import Any, Iterable, Optional, Sequence, Set, cast
|
|
4
6
|
|
|
5
7
|
from sqlalchemy import Select, String, and_, func, or_, select
|
|
6
8
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
7
9
|
from sqlalchemy.orm import InstrumentedAttribute, class_mapper
|
|
8
10
|
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _escape_ilike(q: str) -> str:
|
|
15
|
+
"""Escape special characters for ILIKE pattern matching.
|
|
16
|
+
|
|
17
|
+
Prevents SQL injection via wildcard characters that could match
|
|
18
|
+
unintended data (e.g., % matches any string, _ matches any char).
|
|
19
|
+
"""
|
|
20
|
+
return q.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
|
21
|
+
|
|
9
22
|
|
|
10
23
|
class SqlRepository:
|
|
11
24
|
"""
|
|
@@ -34,8 +47,8 @@ class SqlRepository:
|
|
|
34
47
|
def _model_columns(self) -> set[str]:
|
|
35
48
|
return {c.key for c in class_mapper(self.model).columns}
|
|
36
49
|
|
|
37
|
-
def _id_column(self) -> InstrumentedAttribute:
|
|
38
|
-
return getattr(self.model, self.id_attr)
|
|
50
|
+
def _id_column(self) -> InstrumentedAttribute[Any]:
|
|
51
|
+
return cast(InstrumentedAttribute[Any], getattr(self.model, self.id_attr))
|
|
39
52
|
|
|
40
53
|
def _base_select(self) -> Select:
|
|
41
54
|
stmt = select(self.model)
|
|
@@ -43,8 +56,12 @@ class SqlRepository:
|
|
|
43
56
|
# Filter out soft-deleted rows by timestamp and/or active flag
|
|
44
57
|
if hasattr(self.model, self.soft_delete_field):
|
|
45
58
|
stmt = stmt.where(getattr(self.model, self.soft_delete_field).is_(None))
|
|
46
|
-
if self.soft_delete_flag_field and hasattr(
|
|
47
|
-
|
|
59
|
+
if self.soft_delete_flag_field and hasattr(
|
|
60
|
+
self.model, self.soft_delete_flag_field
|
|
61
|
+
):
|
|
62
|
+
stmt = stmt.where(
|
|
63
|
+
getattr(self.model, self.soft_delete_flag_field).is_(True)
|
|
64
|
+
)
|
|
48
65
|
return stmt
|
|
49
66
|
|
|
50
67
|
# basic ops
|
|
@@ -56,20 +73,37 @@ class SqlRepository:
|
|
|
56
73
|
limit: int,
|
|
57
74
|
offset: int,
|
|
58
75
|
order_by: Optional[Sequence[Any]] = None,
|
|
76
|
+
where: Optional[Sequence[Any]] = None,
|
|
59
77
|
) -> Sequence[Any]:
|
|
60
|
-
stmt = self._base_select()
|
|
78
|
+
stmt = self._base_select()
|
|
79
|
+
if where:
|
|
80
|
+
stmt = stmt.where(and_(*where))
|
|
81
|
+
stmt = stmt.limit(limit).offset(offset)
|
|
61
82
|
if order_by:
|
|
62
83
|
stmt = stmt.order_by(*order_by)
|
|
63
|
-
|
|
64
|
-
return
|
|
65
|
-
|
|
66
|
-
async def count(
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
84
|
+
result = (await session.execute(stmt)).scalars().all()
|
|
85
|
+
return list(result)
|
|
86
|
+
|
|
87
|
+
async def count(
|
|
88
|
+
self, session: AsyncSession, *, where: Optional[Sequence[Any]] = None
|
|
89
|
+
) -> int:
|
|
90
|
+
base = self._base_select()
|
|
91
|
+
if where:
|
|
92
|
+
base = base.where(and_(*where))
|
|
93
|
+
stmt = select(func.count()).select_from(base.subquery())
|
|
94
|
+
return int((await session.execute(stmt)).scalar_one())
|
|
95
|
+
|
|
96
|
+
async def get(
|
|
97
|
+
self,
|
|
98
|
+
session: AsyncSession,
|
|
99
|
+
id_value: Any,
|
|
100
|
+
*,
|
|
101
|
+
where: Optional[Sequence[Any]] = None,
|
|
102
|
+
) -> Any | None:
|
|
71
103
|
# honors soft-delete if configured
|
|
72
104
|
stmt = self._base_select().where(self._id_column() == id_value)
|
|
105
|
+
if where:
|
|
106
|
+
stmt = stmt.where(and_(*where))
|
|
73
107
|
return (await session.execute(stmt)).scalars().first()
|
|
74
108
|
|
|
75
109
|
async def create(self, session: AsyncSession, data: dict[str, Any]) -> Any:
|
|
@@ -78,12 +112,18 @@ class SqlRepository:
|
|
|
78
112
|
obj = self.model(**filtered)
|
|
79
113
|
session.add(obj)
|
|
80
114
|
await session.flush()
|
|
115
|
+
await session.refresh(obj)
|
|
81
116
|
return obj
|
|
82
117
|
|
|
83
118
|
async def update(
|
|
84
|
-
self,
|
|
119
|
+
self,
|
|
120
|
+
session: AsyncSession,
|
|
121
|
+
id_value: Any,
|
|
122
|
+
data: dict[str, Any],
|
|
123
|
+
*,
|
|
124
|
+
where: Optional[Sequence[Any]] = None,
|
|
85
125
|
) -> Any | None:
|
|
86
|
-
obj = await self.get(session, id_value)
|
|
126
|
+
obj = await self.get(session, id_value, where=where)
|
|
87
127
|
if not obj:
|
|
88
128
|
return None
|
|
89
129
|
valid = self._model_columns()
|
|
@@ -91,21 +131,40 @@ class SqlRepository:
|
|
|
91
131
|
if k in valid and k not in self.immutable_fields:
|
|
92
132
|
setattr(obj, k, v)
|
|
93
133
|
await session.flush()
|
|
134
|
+
await session.refresh(obj)
|
|
94
135
|
return obj
|
|
95
136
|
|
|
96
|
-
async def delete(
|
|
97
|
-
|
|
137
|
+
async def delete(
|
|
138
|
+
self,
|
|
139
|
+
session: AsyncSession,
|
|
140
|
+
id_value: Any,
|
|
141
|
+
*,
|
|
142
|
+
where: Optional[Sequence[Any]] = None,
|
|
143
|
+
) -> bool:
|
|
144
|
+
# Fast path: when no extra filters provided, use session.get for simplicity (matches tests)
|
|
145
|
+
if not where:
|
|
146
|
+
obj = await session.get(self.model, id_value)
|
|
147
|
+
else:
|
|
148
|
+
# Respect soft-delete and optional tenant/extra filters by selecting through base select
|
|
149
|
+
stmt = self._base_select().where(self._id_column() == id_value)
|
|
150
|
+
stmt = stmt.where(and_(*where))
|
|
151
|
+
obj = (await session.execute(stmt)).scalars().first()
|
|
98
152
|
if not obj:
|
|
99
153
|
return False
|
|
100
154
|
if self.soft_delete:
|
|
101
155
|
# Prefer timestamp, also optionally set flag to False
|
|
102
|
-
|
|
156
|
+
# Check attributes on the instance to support test doubles without class-level fields
|
|
157
|
+
if hasattr(obj, self.soft_delete_field):
|
|
103
158
|
setattr(obj, self.soft_delete_field, func.now())
|
|
104
|
-
if self.soft_delete_flag_field and hasattr(
|
|
159
|
+
if self.soft_delete_flag_field and hasattr(
|
|
160
|
+
obj, self.soft_delete_flag_field
|
|
161
|
+
):
|
|
105
162
|
setattr(obj, self.soft_delete_flag_field, False)
|
|
106
163
|
await session.flush()
|
|
107
164
|
return True
|
|
108
|
-
session.delete(obj)
|
|
165
|
+
delete_result = session.delete(obj)
|
|
166
|
+
if inspect.isawaitable(delete_result):
|
|
167
|
+
await delete_result
|
|
109
168
|
await session.flush()
|
|
110
169
|
return True
|
|
111
170
|
|
|
@@ -118,18 +177,22 @@ class SqlRepository:
|
|
|
118
177
|
limit: int,
|
|
119
178
|
offset: int,
|
|
120
179
|
order_by: Optional[Sequence[Any]] = None,
|
|
180
|
+
where: Optional[Sequence[Any]] = None,
|
|
121
181
|
) -> Sequence[Any]:
|
|
122
|
-
ilike = f"%{q}%"
|
|
182
|
+
ilike = f"%{_escape_ilike(q)}%"
|
|
123
183
|
conditions = []
|
|
124
184
|
for f in fields:
|
|
125
185
|
col = getattr(self.model, f, None)
|
|
126
186
|
if col is not None:
|
|
127
187
|
try:
|
|
128
188
|
conditions.append(col.cast(String).ilike(ilike))
|
|
129
|
-
except Exception:
|
|
189
|
+
except Exception as e:
|
|
130
190
|
# skip columns that cannot be used in ilike even with cast
|
|
191
|
+
logger.debug("Column %s cannot be cast for ILIKE search: %s", f, e)
|
|
131
192
|
continue
|
|
132
193
|
stmt = self._base_select()
|
|
194
|
+
if where:
|
|
195
|
+
stmt = stmt.where(and_(*where))
|
|
133
196
|
if conditions:
|
|
134
197
|
stmt = stmt.where(or_(*conditions))
|
|
135
198
|
stmt = stmt.limit(limit).offset(offset)
|
|
@@ -137,17 +200,27 @@ class SqlRepository:
|
|
|
137
200
|
stmt = stmt.order_by(*order_by)
|
|
138
201
|
return (await session.execute(stmt)).scalars().all()
|
|
139
202
|
|
|
140
|
-
async def count_filtered(
|
|
141
|
-
|
|
203
|
+
async def count_filtered(
|
|
204
|
+
self,
|
|
205
|
+
session: AsyncSession,
|
|
206
|
+
*,
|
|
207
|
+
q: str,
|
|
208
|
+
fields: Sequence[str],
|
|
209
|
+
where: Optional[Sequence[Any]] = None,
|
|
210
|
+
) -> int:
|
|
211
|
+
ilike = f"%{_escape_ilike(q)}%"
|
|
142
212
|
conditions = []
|
|
143
213
|
for f in fields:
|
|
144
214
|
col = getattr(self.model, f, None)
|
|
145
215
|
if col is not None:
|
|
146
216
|
try:
|
|
147
217
|
conditions.append(col.cast(String).ilike(ilike))
|
|
148
|
-
except Exception:
|
|
218
|
+
except Exception as e:
|
|
219
|
+
logger.debug("Column %s cannot be cast for ILIKE search: %s", f, e)
|
|
149
220
|
continue
|
|
150
221
|
stmt = self._base_select()
|
|
222
|
+
if where:
|
|
223
|
+
stmt = stmt.where(and_(*where))
|
|
151
224
|
if conditions:
|
|
152
225
|
stmt = stmt.where(or_(*conditions))
|
|
153
226
|
# SELECT COUNT(*) FROM (<stmt>) as t
|
svc_infra/db/sql/resource.py
CHANGED
|
@@ -34,3 +34,8 @@ class SqlResource:
|
|
|
34
34
|
|
|
35
35
|
# Only a type reference; no runtime dependency on FastAPI layer
|
|
36
36
|
service_factory: Optional[Callable[[SqlRepository], "SqlService"]] = None
|
|
37
|
+
|
|
38
|
+
# Tenancy
|
|
39
|
+
tenant_field: Optional[str] = (
|
|
40
|
+
None # when set, CRUD router will require TenantId and scope by field
|
|
41
|
+
)
|
svc_infra/db/sql/scaffold.py
CHANGED
|
@@ -9,7 +9,9 @@ from svc_infra.utils import ensure_init_py, render_template, write
|
|
|
9
9
|
|
|
10
10
|
# ---------------- helpers ----------------
|
|
11
11
|
|
|
12
|
-
_INIT_CONTENT_PAIRED =
|
|
12
|
+
_INIT_CONTENT_PAIRED = (
|
|
13
|
+
'from . import models, schemas\n\n__all__ = ["models", "schemas"]\n'
|
|
14
|
+
)
|
|
13
15
|
_INIT_CONTENT_MINIMAL = "# package marker; add explicit exports here if desired\n"
|
|
14
16
|
|
|
15
17
|
|
|
@@ -102,7 +104,9 @@ def scaffold_core(
|
|
|
102
104
|
},
|
|
103
105
|
)
|
|
104
106
|
|
|
105
|
-
tenant_schema_field =
|
|
107
|
+
tenant_schema_field = (
|
|
108
|
+
" tenant_id: Optional[str] = None\n" if include_tenant else ""
|
|
109
|
+
)
|
|
106
110
|
schemas_txt = render_template(
|
|
107
111
|
tmpl_dir="svc_infra.db.sql.templates.models_schemas.entity",
|
|
108
112
|
name="schemas.py.tmpl",
|
svc_infra/db/sql/service.py
CHANGED
|
@@ -24,8 +24,12 @@ class SqlService:
|
|
|
24
24
|
async def pre_update(self, data: dict[str, Any]) -> dict[str, Any]:
|
|
25
25
|
return data
|
|
26
26
|
|
|
27
|
-
async def list(
|
|
28
|
-
|
|
27
|
+
async def list(
|
|
28
|
+
self, session: AsyncSession, *, limit: int, offset: int, order_by=None
|
|
29
|
+
):
|
|
30
|
+
return await self.repo.list(
|
|
31
|
+
session, limit=limit, offset=offset, order_by=order_by
|
|
32
|
+
)
|
|
29
33
|
|
|
30
34
|
async def count(self, session: AsyncSession) -> int:
|
|
31
35
|
return await self.repo.count(session)
|
|
@@ -41,9 +45,13 @@ class SqlService:
|
|
|
41
45
|
# unique constraint or not-null -> 409/400 instead of 500
|
|
42
46
|
msg = str(e.orig) if getattr(e, "orig", None) else str(e)
|
|
43
47
|
if "duplicate key value" in msg or "UniqueViolation" in msg:
|
|
44
|
-
raise HTTPException(
|
|
48
|
+
raise HTTPException(
|
|
49
|
+
status_code=409, detail="Record already exists."
|
|
50
|
+
) from e
|
|
45
51
|
if "not-null" in msg or "NotNullViolation" in msg:
|
|
46
|
-
raise HTTPException(
|
|
52
|
+
raise HTTPException(
|
|
53
|
+
status_code=400, detail="Missing required field."
|
|
54
|
+
) from e
|
|
47
55
|
raise # unknown, let your error middleware turn into 500
|
|
48
56
|
|
|
49
57
|
async def update(self, session: AsyncSession, id_value: Any, data: dict[str, Any]):
|
|
@@ -67,7 +75,9 @@ class SqlService:
|
|
|
67
75
|
session, q=q, fields=fields, limit=limit, offset=offset, order_by=order_by
|
|
68
76
|
)
|
|
69
77
|
|
|
70
|
-
async def count_filtered(
|
|
78
|
+
async def count_filtered(
|
|
79
|
+
self, session: AsyncSession, *, q: str, fields: Sequence[str]
|
|
80
|
+
) -> int:
|
|
71
81
|
return await self.repo.count_filtered(session, q=q, fields=fields)
|
|
72
82
|
|
|
73
83
|
async def exists(self, session: AsyncSession, *, where):
|
|
@@ -129,62 +129,13 @@ for _ix in make_unique_sql_indexes(
|
|
|
129
129
|
# Registered with Table metadata (alembic/autogenerate will pick them up)
|
|
130
130
|
pass
|
|
131
131
|
|
|
132
|
-
#
|
|
133
|
-
|
|
134
|
-
class
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
- Optionally stores tokens for later API calls (refresh_token encrypted at rest)
|
|
140
|
-
"""
|
|
141
|
-
__tablename__ = "provider_accounts"
|
|
142
|
-
|
|
143
|
-
id: Mapped[uuid.UUID] = mapped_column(GUID(), primary_key=True, default=uuid.uuid4)
|
|
144
|
-
|
|
145
|
-
user_id: Mapped[uuid.UUID] = mapped_column(
|
|
146
|
-
GUID(), ForeignKey("${auth_table_name}.id", ondelete="CASCADE"), nullable=False
|
|
147
|
-
)
|
|
148
|
-
user: Mapped["${AuthEntity}"] = relationship(
|
|
149
|
-
back_populates="provider_accounts",
|
|
150
|
-
lazy="selectin",
|
|
151
|
-
)
|
|
152
|
-
|
|
153
|
-
provider: Mapped[str] = mapped_column(String(50), nullable=False) # "google"|"github"|"linkedin"|"microsoft"|...
|
|
154
|
-
provider_account_id: Mapped[str] = mapped_column(String(255), nullable=False) # sub/oid (OIDC) or id (github/linkedin)
|
|
155
|
-
|
|
156
|
-
# Optional token material
|
|
157
|
-
access_token: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
|
158
|
-
|
|
159
|
-
# Store encrypted refresh_token in the same column name for DB compatibility.
|
|
160
|
-
_refresh_token: Mapped[Optional[str]] = mapped_column("refresh_token", Text, nullable=True)
|
|
161
|
-
|
|
162
|
-
@property
|
|
163
|
-
def refresh_token(self) -> Optional[str]:
|
|
164
|
-
return _decrypt(self._refresh_token)
|
|
165
|
-
|
|
166
|
-
@refresh_token.setter
|
|
167
|
-
def refresh_token(self, value: Optional[str]) -> None:
|
|
168
|
-
self._refresh_token = _encrypt(value)
|
|
169
|
-
|
|
170
|
-
expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
|
171
|
-
raw_claims: Mapped[Optional[dict]] = mapped_column(MutableDict.as_mutable(JSON), nullable=True)
|
|
172
|
-
|
|
173
|
-
created_at = mapped_column(
|
|
174
|
-
DateTime(timezone=True), server_default=text("CURRENT_TIMESTAMP"), nullable=False
|
|
175
|
-
)
|
|
176
|
-
updated_at = mapped_column(
|
|
177
|
-
DateTime(timezone=True), server_default=text("CURRENT_TIMESTAMP"),
|
|
178
|
-
onupdate=text("CURRENT_TIMESTAMP"), nullable=False
|
|
179
|
-
)
|
|
180
|
-
|
|
181
|
-
__table_args__ = (
|
|
182
|
-
UniqueConstraint("provider", "provider_account_id", name="uq_provider_account"),
|
|
183
|
-
Index("ix_provider_accounts_user_id", "user_id"),
|
|
184
|
-
)
|
|
185
|
-
|
|
186
|
-
def __repr__(self) -> str:
|
|
187
|
-
return f"<ProviderAccount provider={self.provider!r} provider_account_id={self.provider_account_id!r} user_id={self.user_id}>"
|
|
132
|
+
# NOTE: ProviderAccount model is imported from svc_infra.security.oauth_models
|
|
133
|
+
# It's an opt-in OAuth model that links users to providers (Google, GitHub, etc.)
|
|
134
|
+
# The relationship 'provider_accounts' is defined above in the ${AuthEntity} class.
|
|
135
|
+
# To enable OAuth in your project:
|
|
136
|
+
# 1. Set ALEMBIC_ENABLE_OAUTH=true in your .env
|
|
137
|
+
# 2. Pass provider_account_model=ProviderAccount to add_auth_users()
|
|
138
|
+
# 3. Import: from svc_infra.security.oauth_models import ProviderAccount
|
|
188
139
|
|
|
189
140
|
# --- Auth service factory ------------------------------------------------------
|
|
190
141
|
|
|
@@ -6,13 +6,10 @@ from typing import List, Tuple
|
|
|
6
6
|
import sys, pathlib, importlib, pkgutil, traceback
|
|
7
7
|
|
|
8
8
|
from alembic import context
|
|
9
|
+
from sqlalchemy import MetaData
|
|
9
10
|
from sqlalchemy.engine import make_url, URL
|
|
10
|
-
from sqlalchemy.ext.asyncio import create_async_engine
|
|
11
11
|
|
|
12
|
-
from svc_infra.db.sql.utils import
|
|
13
|
-
get_database_url_from_env,
|
|
14
|
-
_ensure_ssl_default_async as _ensure_ssl_default,
|
|
15
|
-
)
|
|
12
|
+
from svc_infra.db.sql.utils import get_database_url_from_env
|
|
16
13
|
|
|
17
14
|
try:
|
|
18
15
|
from svc_infra.db.sql.types import GUID as _GUID # type: ignore
|
|
@@ -105,7 +102,6 @@ def _coerce_to_async(u: URL) -> URL:
|
|
|
105
102
|
|
|
106
103
|
u = make_url(effective_url)
|
|
107
104
|
u = _coerce_to_async(u)
|
|
108
|
-
u = _ensure_ssl_default(u)
|
|
109
105
|
config.set_main_option("sqlalchemy.url", u.render_as_string(hide_password=False))
|
|
110
106
|
|
|
111
107
|
# feature flags
|
|
@@ -131,15 +127,16 @@ def _collect_metadata() -> list[object]:
|
|
|
131
127
|
|
|
132
128
|
def _maybe_add(obj: object) -> None:
|
|
133
129
|
md = getattr(obj, "metadata", None) or obj
|
|
134
|
-
|
|
130
|
+
# Strict check: must be actual MetaData instance
|
|
131
|
+
if isinstance(md, MetaData) and md.tables:
|
|
135
132
|
found.append(md)
|
|
136
133
|
|
|
137
134
|
def _scan_module_objects(mod: object) -> None:
|
|
138
135
|
try:
|
|
139
136
|
for val in vars(mod).values():
|
|
140
|
-
|
|
141
|
-
if
|
|
142
|
-
found.append(
|
|
137
|
+
# Strict check: must be actual MetaData instance
|
|
138
|
+
if isinstance(val, MetaData) and val.tables:
|
|
139
|
+
found.append(val)
|
|
143
140
|
except Exception:
|
|
144
141
|
pass
|
|
145
142
|
|
|
@@ -177,8 +174,16 @@ def _collect_metadata() -> list[object]:
|
|
|
177
174
|
if name not in pkgs:
|
|
178
175
|
pkgs.append(name)
|
|
179
176
|
|
|
177
|
+
# Only attempt bare 'models' import if it is discoverable to avoid noisy tracebacks
|
|
180
178
|
if "models" not in pkgs:
|
|
181
|
-
|
|
179
|
+
try:
|
|
180
|
+
spec = getattr(importlib, "util", None)
|
|
181
|
+
if spec is not None and getattr(spec, "find_spec", None) is not None:
|
|
182
|
+
if spec.find_spec("models") is not None:
|
|
183
|
+
pkgs.append("models")
|
|
184
|
+
except Exception:
|
|
185
|
+
# Best-effort; if discovery fails, skip adding bare 'models'
|
|
186
|
+
pass
|
|
182
187
|
|
|
183
188
|
def _import_and_collect(modname: str):
|
|
184
189
|
try:
|
|
@@ -221,6 +226,21 @@ def _collect_metadata() -> list[object]:
|
|
|
221
226
|
except Exception:
|
|
222
227
|
_note("ModelBase import", False, traceback.format_exc())
|
|
223
228
|
|
|
229
|
+
# Core security models (AuthSession, RefreshToken, etc.)
|
|
230
|
+
try:
|
|
231
|
+
import svc_infra.security.models # noqa: F401
|
|
232
|
+
_note("svc_infra.security.models", True, None)
|
|
233
|
+
except Exception:
|
|
234
|
+
_note("svc_infra.security.models", False, traceback.format_exc())
|
|
235
|
+
|
|
236
|
+
# OAuth models (opt-in via environment variable)
|
|
237
|
+
if os.getenv("ALEMBIC_ENABLE_OAUTH", "").lower() in {"1", "true", "yes"}:
|
|
238
|
+
try:
|
|
239
|
+
import svc_infra.security.oauth_models # noqa: F401
|
|
240
|
+
_note("svc_infra.security.oauth_models", True, None)
|
|
241
|
+
except Exception:
|
|
242
|
+
_note("svc_infra.security.oauth_models", False, traceback.format_exc())
|
|
243
|
+
|
|
224
244
|
try:
|
|
225
245
|
from svc_infra.db.sql.apikey import try_autobind_apikey_model
|
|
226
246
|
try_autobind_apikey_model(require_env=False)
|
|
@@ -352,7 +372,9 @@ def _do_run_migrations(connection):
|
|
|
352
372
|
|
|
353
373
|
async def run_migrations_online() -> None:
|
|
354
374
|
url = config.get_main_option("sqlalchemy.url")
|
|
355
|
-
|
|
375
|
+
# Use build_engine to ensure proper driver-specific handling (e.g., asyncpg SSL)
|
|
376
|
+
from svc_infra.db.sql.utils import build_engine
|
|
377
|
+
engine = build_engine(url)
|
|
356
378
|
async with engine.connect() as connection:
|
|
357
379
|
await connection.run_sync(_do_run_migrations)
|
|
358
380
|
await engine.dispose()
|
|
@@ -6,11 +6,10 @@ from typing import List, Tuple
|
|
|
6
6
|
import sys, pathlib, importlib, pkgutil, traceback
|
|
7
7
|
|
|
8
8
|
from alembic import context
|
|
9
|
+
from sqlalchemy import MetaData
|
|
9
10
|
from sqlalchemy.engine import make_url, URL
|
|
10
11
|
|
|
11
12
|
from svc_infra.db.sql.utils import (
|
|
12
|
-
_coerce_sync_driver,
|
|
13
|
-
_ensure_ssl_default,
|
|
14
13
|
get_database_url_from_env,
|
|
15
14
|
build_engine,
|
|
16
15
|
)
|
|
@@ -103,7 +102,6 @@ if not effective_url:
|
|
|
103
102
|
|
|
104
103
|
u = make_url(effective_url)
|
|
105
104
|
u = _coerce_sync_driver(u)
|
|
106
|
-
u = _ensure_ssl_default(u)
|
|
107
105
|
config.set_main_option("sqlalchemy.url", u.render_as_string(hide_password=False))
|
|
108
106
|
|
|
109
107
|
|
|
@@ -142,14 +140,16 @@ def _collect_metadata() -> list[object]:
|
|
|
142
140
|
|
|
143
141
|
def _maybe_add(obj: object) -> None:
|
|
144
142
|
md = getattr(obj, "metadata", None) or obj
|
|
145
|
-
|
|
143
|
+
# Strict check: must be actual MetaData instance
|
|
144
|
+
if isinstance(md, MetaData) and md.tables:
|
|
146
145
|
found.append(md)
|
|
147
146
|
|
|
148
147
|
def _scan_module_objects(mod: object) -> None:
|
|
149
148
|
try:
|
|
150
149
|
for val in vars(mod).values():
|
|
151
150
|
md = getattr(val, "metadata", None) or None
|
|
152
|
-
if
|
|
151
|
+
# Only add if it's a SQLAlchemy MetaData object (has tables dict, not a callable/generator)
|
|
152
|
+
if md is not None and hasattr(md, "tables") and isinstance(getattr(md, "tables", None), dict):
|
|
153
153
|
found.append(md)
|
|
154
154
|
except Exception:
|
|
155
155
|
pass
|
|
@@ -191,9 +191,16 @@ def _collect_metadata() -> list[object]:
|
|
|
191
191
|
if name not in pkgs:
|
|
192
192
|
pkgs.append(name)
|
|
193
193
|
|
|
194
|
-
#
|
|
194
|
+
# Only attempt a bare 'models' import if discoverable to avoid noisy tracebacks
|
|
195
195
|
if "models" not in pkgs:
|
|
196
|
-
|
|
196
|
+
try:
|
|
197
|
+
spec = getattr(importlib, "util", None)
|
|
198
|
+
if spec is not None and getattr(spec, "find_spec", None) is not None:
|
|
199
|
+
if spec.find_spec("models") is not None:
|
|
200
|
+
pkgs.append("models")
|
|
201
|
+
except Exception:
|
|
202
|
+
# If discovery fails, skip adding bare 'models'
|
|
203
|
+
pass
|
|
197
204
|
|
|
198
205
|
def _import_and_collect(modname: str):
|
|
199
206
|
try:
|
|
@@ -239,6 +246,21 @@ def _collect_metadata() -> list[object]:
|
|
|
239
246
|
except Exception:
|
|
240
247
|
_note("ModelBase import", False, traceback.format_exc())
|
|
241
248
|
|
|
249
|
+
# Core security models (AuthSession, RefreshToken, etc.)
|
|
250
|
+
try:
|
|
251
|
+
import svc_infra.security.models # noqa: F401
|
|
252
|
+
_note("svc_infra.security.models", True, None)
|
|
253
|
+
except Exception:
|
|
254
|
+
_note("svc_infra.security.models", False, traceback.format_exc())
|
|
255
|
+
|
|
256
|
+
# OAuth models (opt-in via environment variable)
|
|
257
|
+
if os.getenv("ALEMBIC_ENABLE_OAUTH", "").lower() in {"1", "true", "yes"}:
|
|
258
|
+
try:
|
|
259
|
+
import svc_infra.security.oauth_models # noqa: F401
|
|
260
|
+
_note("svc_infra.security.oauth_models", True, None)
|
|
261
|
+
except Exception:
|
|
262
|
+
_note("svc_infra.security.oauth_models", False, traceback.format_exc())
|
|
263
|
+
|
|
242
264
|
# Optional: autobind API key model
|
|
243
265
|
try:
|
|
244
266
|
from svc_infra.db.sql.apikey import try_autobind_apikey_model
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Sequence
|
|
4
|
+
|
|
5
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
6
|
+
|
|
7
|
+
from .service import SqlService
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TenantSqlService(SqlService):
    """
    SQL service wrapper that automatically scopes operations to a tenant.

    - Adds a where filter (model.tenant_field == tenant_id) for list/get/update/delete/search/count.
    - On create, if the model has the tenant field and it's not set in data, injects tenant_id.

    Constructor args:
        repo: Repository handed to ``SqlService``; its ``model`` attribute is
            inspected for the tenant column and its methods receive ``where=``.
        tenant_id: Tenant identifier that every scoped operation filters by.
        tenant_field: Name of the model attribute holding the tenant id.

    NOTE(review): scoping is fail-open. If ``repo.model`` has no attribute
    named ``tenant_field``, ``_where()`` yields no clauses and operations run
    unscoped. Also, ``create``/``update`` accept an explicit ``tenant_field``
    value inside ``data``, so a caller can write a different tenant id.
    Confirm both behaviors are intended.
    """

    def __init__(self, repo, *, tenant_id: str, tenant_field: str = "tenant_id"):
        super().__init__(repo)
        self.tenant_id = tenant_id
        self.tenant_field = tenant_field

    def _where(self) -> Sequence[Any]:
        """Return the tenant filter clauses to pass as ``where=`` to the repo.

        Returns ``[model.<tenant_field> == tenant_id]`` when the repo's model
        exposes that attribute, otherwise an empty list (no filtering).
        """
        col = getattr(self.repo.model, self.tenant_field, None)
        if col is None:
            # No tenant column on the model: nothing to filter on (fail-open).
            return []
        return [col == self.tenant_id]

    async def list(
        self, session: AsyncSession, *, limit: int, offset: int, order_by=None
    ):
        """List rows belonging to this tenant, paginated by limit/offset."""
        return await self.repo.list(
            session, limit=limit, offset=offset, order_by=order_by, where=self._where()
        )

    async def count(self, session: AsyncSession) -> int:
        """Count rows belonging to this tenant."""
        return await self.repo.count(session, where=self._where())

    async def get(self, session: AsyncSession, id_value: Any):
        """Fetch one row by id, restricted to this tenant."""
        return await self.repo.get(session, id_value, where=self._where())

    async def create(self, session: AsyncSession, data: dict[str, Any]):
        """Create a row, injecting ``tenant_id`` when the model supports it.

        The tenant value is added only when ``tenant_field`` is one of the
        repo's model columns and ``data`` (after ``pre_create``) does not
        already contain it. The mapping passed in is never mutated.
        """
        data = await self.pre_create(data)
        if (
            self.tenant_field in self.repo._model_columns()
            and self.tenant_field not in data
        ):
            # Build a new dict instead of assigning in place: pre_create may
            # return the caller's own dict (or a read-only mapping).
            data = {**data, self.tenant_field: self.tenant_id}
        return await self.repo.create(session, data)

    async def update(self, session: AsyncSession, id_value: Any, data: dict[str, Any]):
        """Update a row by id, restricted to this tenant."""
        data = await self.pre_update(data)
        return await self.repo.update(session, id_value, data, where=self._where())

    async def delete(self, session: AsyncSession, id_value: Any) -> bool:
        """Delete a row by id, restricted to this tenant; returns the repo's result."""
        return await self.repo.delete(session, id_value, where=self._where())

    async def search(
        self,
        session: AsyncSession,
        *,
        q: str,
        fields: Sequence[str],
        limit: int,
        offset: int,
        order_by=None,
    ):
        """Search ``fields`` for ``q`` within this tenant's rows, paginated."""
        return await self.repo.search(
            session,
            q=q,
            fields=fields,
            limit=limit,
            offset=offset,
            order_by=order_by,
            where=self._where(),
        )

    async def count_filtered(
        self, session: AsyncSession, *, q: str, fields: Sequence[str]
    ) -> int:
        """Count this tenant's rows matching ``q`` across ``fields``."""
        return await self.repo.count_filtered(
            session, q=q, fields=fields, where=self._where()
        )
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
__all__ = ["TenantSqlService"]
|