danweiyuan-eapi 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26) hide show
  1. danweiyuan_eapi-0.1.0/PKG-INFO +19 -0
  2. danweiyuan_eapi-0.1.0/README.md +20 -0
  3. danweiyuan_eapi-0.1.0/pyproject.toml +44 -0
  4. danweiyuan_eapi-0.1.0/setup.cfg +4 -0
  5. danweiyuan_eapi-0.1.0/src/danweiyuan_eapi/__init__.py +14 -0
  6. danweiyuan_eapi-0.1.0/src/danweiyuan_eapi/cache.py +30 -0
  7. danweiyuan_eapi-0.1.0/src/danweiyuan_eapi/config.py +23 -0
  8. danweiyuan_eapi-0.1.0/src/danweiyuan_eapi/database.py +42 -0
  9. danweiyuan_eapi-0.1.0/src/danweiyuan_eapi/dependencies.py +16 -0
  10. danweiyuan_eapi-0.1.0/src/danweiyuan_eapi/exceptions.py +75 -0
  11. danweiyuan_eapi-0.1.0/src/danweiyuan_eapi/pagination.py +25 -0
  12. danweiyuan_eapi-0.1.0/src/danweiyuan_eapi/response.py +19 -0
  13. danweiyuan_eapi-0.1.0/src/danweiyuan_eapi/security.py +48 -0
  14. danweiyuan_eapi-0.1.0/src/danweiyuan_eapi.egg-info/PKG-INFO +19 -0
  15. danweiyuan_eapi-0.1.0/src/danweiyuan_eapi.egg-info/SOURCES.txt +24 -0
  16. danweiyuan_eapi-0.1.0/src/danweiyuan_eapi.egg-info/dependency_links.txt +1 -0
  17. danweiyuan_eapi-0.1.0/src/danweiyuan_eapi.egg-info/requires.txt +14 -0
  18. danweiyuan_eapi-0.1.0/src/danweiyuan_eapi.egg-info/top_level.txt +1 -0
  19. danweiyuan_eapi-0.1.0/tests/test_cache.py +28 -0
  20. danweiyuan_eapi-0.1.0/tests/test_config.py +65 -0
  21. danweiyuan_eapi-0.1.0/tests/test_database.py +42 -0
  22. danweiyuan_eapi-0.1.0/tests/test_dependencies.py +70 -0
  23. danweiyuan_eapi-0.1.0/tests/test_exceptions.py +93 -0
  24. danweiyuan_eapi-0.1.0/tests/test_pagination.py +38 -0
  25. danweiyuan_eapi-0.1.0/tests/test_response.py +43 -0
  26. danweiyuan_eapi-0.1.0/tests/test_security.py +51 -0
@@ -0,0 +1,19 @@
1
+ Metadata-Version: 2.4
2
+ Name: danweiyuan-eapi
3
+ Version: 0.1.0
4
+ Summary: Lightweight FastAPI infrastructure — exceptions, database, config, security, cache, pagination
5
+ License: MIT
6
+ Requires-Python: >=3.11
7
+ Requires-Dist: fastapi>=0.100.0
8
+ Requires-Dist: sqlalchemy[asyncio]>=2.0.0
9
+ Requires-Dist: pydantic>=2.0.0
10
+ Requires-Dist: pydantic-settings>=2.0.0
11
+ Requires-Dist: python-jose[cryptography]>=3.3.0
12
+ Requires-Dist: bcrypt>=4.0.0
13
+ Requires-Dist: redis>=5.0.0
14
+ Provides-Extra: dev
15
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
16
+ Requires-Dist: pytest-asyncio>=0.24.0; extra == "dev"
17
+ Requires-Dist: httpx>=0.27.0; extra == "dev"
18
+ Requires-Dist: aiosqlite>=0.20.0; extra == "dev"
19
+ Requires-Dist: ruff>=0.6.0; extra == "dev"
@@ -0,0 +1,20 @@
1
+ # danweiyuan-eapi
2
+
3
+ Lightweight FastAPI infrastructure package for shared use across projects.
4
+
5
+ ## Install
6
+
7
+ ```bash
8
+ pip install danweiyuan-eapi
9
+ ```
10
+
11
+ ## Modules
12
+
13
+ - `config` — BaseSettings with required fields (database_url, redis_url, secret_key)
14
+ - `exceptions` — AppError hierarchy + FastAPI exception handlers
15
+ - `database` — Async SQLAlchemy engine/session factory + Base + TimestampMixin
16
+ - `dependencies` — `create_get_db()` factory that builds a `get_db()` FastAPI dependency
17
+ - `security` — JWT create/verify + bcrypt hash/verify (stateless)
18
+ - `cache` — Async Redis connection manager
19
+ - `response` — Unified API response helpers
20
+ - `pagination` — PaginationParams + paginate helper
@@ -0,0 +1,44 @@
1
+ [project]
2
+ name = "danweiyuan-eapi"
3
+ version = "0.1.0"
4
+ description = "Lightweight FastAPI infrastructure — exceptions, database, config, security, cache, pagination"
5
+ requires-python = ">=3.11"
6
+ license = {text = "MIT"}
7
+ dependencies = [
8
+ "fastapi>=0.100.0",
9
+ "sqlalchemy[asyncio]>=2.0.0",
10
+ "pydantic>=2.0.0",
11
+ "pydantic-settings>=2.0.0",
12
+ "python-jose[cryptography]>=3.3.0",
13
+ "bcrypt>=4.0.0",
14
+ "redis>=5.0.0",
15
+ ]
16
+
17
+ [project.optional-dependencies]
18
+ dev = [
19
+ "pytest>=8.0.0",
20
+ "pytest-asyncio>=0.24.0",
21
+ "httpx>=0.27.0",
22
+ "aiosqlite>=0.20.0",
23
+ "ruff>=0.6.0",
24
+ ]
25
+
26
+ [build-system]
27
+ requires = ["setuptools>=68.0"]
28
+ build-backend = "setuptools.build_meta"
29
+
30
+ [tool.setuptools.packages.find]
31
+ where = ["src"]
32
+
33
+ [tool.ruff]
34
+ target-version = "py311"
35
+ line-length = 120
36
+
37
+ [tool.ruff.lint]
38
+ select = ["E", "W", "F", "I", "N", "UP", "B", "SIM", "RUF"]
39
+
40
+ [tool.ruff.lint.isort]
41
+ known-first-party = ["danweiyuan_eapi"]
42
+
43
+ [tool.pytest.ini_options]
44
+ asyncio_mode = "auto"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,14 @@
1
"""danweiyuan-eapi — Lightweight FastAPI infrastructure.

Modules:
    config       — BaseSettings with required fields
    exceptions   — AppError hierarchy + FastAPI exception handlers
    database     — Async SQLAlchemy engine/session factory + Base + TimestampMixin
    dependencies — create_get_db() factory that builds a get_db() FastAPI dependency
    security     — JWT + bcrypt helpers (stateless)
    cache        — Async Redis connection manager
    response     — Unified API response helpers
    pagination   — PaginationParams + paginate helper
"""

# Package version; keep in sync with `version` in pyproject.toml.
__version__ = "0.1.0"
@@ -0,0 +1,30 @@
1
+ """Async Redis connection manager."""
2
+
3
+ import redis.asyncio as aioredis
4
+
5
# Redis URL recorded by configure(); None until configure() has been called.
_redis_url: str | None = None
# Shared client created lazily by get_redis() and reset to None by close_redis().
_redis: aioredis.Redis | None = None
7
+
8
+
9
def configure(redis_url: str) -> None:
    """Remember *redis_url* as the address get_redis() will connect to.

    Only the URL is stored here; no connection is opened until get_redis() runs.
    """
    global _redis_url
    _redis_url = redis_url
13
+
14
+
15
async def get_redis() -> aioredis.Redis:
    """Return the shared async Redis client, creating it on first use.

    The client is built from the URL passed to configure(), with
    ``decode_responses=True`` so commands return ``str`` rather than ``bytes``.

    Raises:
        RuntimeError: if configure() has not been called yet.
    """
    global _redis
    url = _redis_url
    if url is None:
        raise RuntimeError("danweiyuan_eapi.cache not configured — call cache.configure(redis_url) first")
    if _redis is None:
        # Lazily create exactly one client and reuse it for later calls.
        _redis = aioredis.from_url(url, decode_responses=True)
    return _redis
23
+
24
+
25
async def close_redis() -> None:
    """Close and forget the shared Redis client; a no-op when none exists.

    The module-level reference is cleared *before* awaiting the close. If closing
    fails (e.g. a broken connection), the half-closed client therefore does not
    stay cached and get returned by the next get_redis() call.
    """
    global _redis
    client = _redis
    if client is None:
        return
    _redis = None
    # redis-py >= 5.0.1 provides aclose(); close() is its deprecated alias.
    # Fall back to close() for older 5.0.x releases allowed by the dependency pin.
    closer = getattr(client, "aclose", None) or client.close
    await closer()
@@ -0,0 +1,23 @@
1
+ """Base Pydantic Settings class for FastAPI projects."""
2
+
3
+ from pydantic import Field
4
+ from pydantic_settings import BaseSettings as PydanticBaseSettings
5
+ from pydantic_settings import SettingsConfigDict
6
+
7
+
8
class BaseSettings(PydanticBaseSettings):
    """Shared application settings, read from the environment and a ``.env`` file.

    Subclass it and declare project-specific fields next to the shared ones.
    ``database_url``, ``redis_url`` and ``secret_key`` have no default and must
    be supplied. Keys this model does not declare are ignored (``extra="ignore"``).
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # Required — no defaults.
    database_url: str
    redis_url: str
    secret_key: str

    # JWT signing parameters (pass these to the security helpers).
    jwt_algorithm: str = "HS256"
    access_token_expire_minutes: int = 30

    debug: bool = False
    # From the environment this is given as a JSON list, e.g. '["http://a", "http://b"]'.
    allowed_origins: list[str] = Field(default_factory=list)
@@ -0,0 +1,42 @@
1
+ """Async SQLAlchemy engine and session factory, declarative Base, and TimestampMixin."""
2
+
3
+ from datetime import datetime
4
+
5
+ from sqlalchemy import func
6
+ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
7
+ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
8
+
9
+
10
class Base(DeclarativeBase):
    """Shared declarative base; ORM models inherit from it to share one ``metadata``."""
14
+
15
+
16
class TimestampMixin:
    """Mixin adding ``created_at`` / ``updated_at`` columns to an ORM model.

    Both columns get the database's ``now()`` as a server-side default on INSERT.
    ``updated_at`` is also set to ``now()`` whenever SQLAlchemy emits an UPDATE for the row.
    NOTE(review): plain ``Mapped[datetime]`` maps to a timezone-naive column — confirm
    this is intended.
    """

    created_at: Mapped[datetime] = mapped_column(server_default=func.now())
    updated_at: Mapped[datetime] = mapped_column(
        server_default=func.now(),
        onupdate=func.now(),
    )
21
+
22
+
23
def create_async_engine_factory(
    database_url: str,
    pool_size: int = 5,
    max_overflow: int = 10,
    echo: bool = False,
) -> AsyncEngine:
    """Build an async SQLAlchemy engine for *database_url*.

    ``pool_size`` and ``max_overflow`` are forwarded only for non-SQLite URLs.
    For URLs starting with ``"sqlite"`` they are ignored, and the engine uses
    SQLAlchemy's default pool for that dialect. ``echo`` is always forwarded.
    """
    pool_options: dict[str, int] = {}
    if not database_url.startswith("sqlite"):
        pool_options = {"pool_size": pool_size, "max_overflow": max_overflow}
    return create_async_engine(database_url, echo=echo, **pool_options)
38
+
39
+
40
def create_session_factory(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]:
    """Return an ``async_sessionmaker`` bound to *engine*.

    ``expire_on_commit=False`` keeps already-loaded attributes usable after
    ``commit()`` instead of expiring them.
    """
    factory = async_sessionmaker(engine, expire_on_commit=False)
    return factory
@@ -0,0 +1,16 @@
1
+ """FastAPI dependency factories."""
2
+
3
+ from collections.abc import AsyncGenerator
4
+
5
+ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
6
+
7
+
8
def create_get_db(session_factory: async_sessionmaker[AsyncSession]):
    """Build a FastAPI ``get_db`` dependency that draws sessions from *session_factory*.

    Usage::

        get_db = create_get_db(session_factory)

        @app.get("/items")
        async def list_items(db=Depends(get_db)): ...
    """

    async def get_db() -> AsyncGenerator[AsyncSession]:
        """Yield one session per request; leaving the ``async with`` closes it."""
        session_cm = session_factory()
        async with session_cm as session:
            yield session

    return get_db
@@ -0,0 +1,75 @@
1
+ """AppError exception hierarchy and FastAPI exception handler registration."""
2
+
3
+ from fastapi import FastAPI, Request, status
4
+ from fastapi.responses import JSONResponse
5
+
6
+
7
class AppError(Exception):
    """Base class for all application-level errors.

    Attributes:
        message: Human-readable description. For the subclasses it handles,
            ``register_exception_handlers`` puts it in the JSON error body.
        code: Machine-readable error code, e.g. ``"NOT_FOUND"``.
    """

    def __init__(self, message: str, code: str = "UNKNOWN_ERROR") -> None:
        # Forward the message to Exception so str(err), err.args, logging and
        # tracebacks show it instead of an empty string.
        super().__init__(message)
        self.message = message
        self.code = code
13
+
14
+
15
class NotFoundError(AppError):
    """Raised when a requested resource does not exist (mapped to HTTP 404).

    Args:
        resource: Display name of the missing resource. The message becomes
            ``"<resource>不存在"`` ("<resource> does not exist"), code ``"NOT_FOUND"``.
    """

    def __init__(self, resource: str) -> None:
        text = f"{resource}不存在"
        super().__init__(message=text, code="NOT_FOUND")
20
+
21
+
22
class BusinessError(AppError):
    """Raised when a business rule is violated (mapped to HTTP 422).

    Takes the same ``message`` / ``code`` arguments as ``AppError``, e.g.
    ``BusinessError(message="余额不足", code="INSUFFICIENT_BALANCE")``.
    """
26
+
27
+
28
class PermissionDeniedError(AppError):
    """Raised when the caller lacks permission (mapped to HTTP 403).

    Always carries the message "权限不足" ("insufficient permissions") and code
    ``"PERMISSION_DENIED"``.
    """

    def __init__(self) -> None:
        super().__init__("权限不足", "PERMISSION_DENIED")
33
+
34
+
35
class AuthenticationError(AppError):
    """Raised when authentication is missing or invalid (mapped to HTTP 401).

    Always carries the message "认证失败" ("authentication failed") and code
    ``"AUTHENTICATION_FAILED"``.
    """

    def __init__(self) -> None:
        super().__init__("认证失败", "AUTHENTICATION_FAILED")
40
+
41
+
42
def register_exception_handlers(app: FastAPI) -> None:
    """Install JSON error handlers for the AppError subclasses on *app*.

    Each handled exception becomes a response whose body is
    ``{"code": exc.code, "message": exc.message}``, with this status:

    =======================  ======
    NotFoundError            404
    BusinessError            422
    PermissionDeniedError    403
    AuthenticationError      401
    =======================  ======

    A bare ``AppError`` (or a subclass not derived from one of these four) is
    not handled here.
    """
    status_by_error: dict[type[AppError], int] = {
        NotFoundError: status.HTTP_404_NOT_FOUND,
        BusinessError: status.HTTP_422_UNPROCESSABLE_ENTITY,
        PermissionDeniedError: status.HTTP_403_FORBIDDEN,
        AuthenticationError: status.HTTP_401_UNAUTHORIZED,
    }

    def _make_handler(status_code: int):
        # A factory gives each handler its own status_code (avoids late-binding in the loop).
        async def _handle(request: Request, exc: AppError) -> JSONResponse:
            return JSONResponse(
                status_code=status_code,
                content={"code": exc.code, "message": exc.message},
            )

        return _handle

    for error_cls, status_code in status_by_error.items():
        app.add_exception_handler(error_cls, _make_handler(status_code))
@@ -0,0 +1,25 @@
1
+ """Pagination utilities."""
2
+
3
+ from typing import NamedTuple
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
class PaginationParams(BaseModel):
    """Validated pagination query parameters.

    ``page`` is 1-based and must be >= 1 (default 1). ``page_size`` must be
    between 1 and 100 inclusive (default 20).
    """

    page: int = Field(1, ge=1, description="Page number (1-based)")
    page_size: int = Field(20, ge=1, le=100, description="Items per page")
13
+
14
+
15
class OffsetLimit(NamedTuple):
    """Result of paginate() — ready for SQL OFFSET / LIMIT."""

    offset: int  # number of rows to skip
    limit: int  # maximum number of rows to return


def paginate(page: int, page_size: int) -> OffsetLimit:
    """Convert a 1-based *page* and *page_size* into an SQL OFFSET / LIMIT pair.

    Pages below 1 are treated as page 1, so ``paginate(0, 20)`` and
    ``paginate(-5, 20)`` both return ``OffsetLimit(offset=0, limit=20)``.

    Raises:
        ValueError: if *page_size* is less than 1. This matches the ``ge=1`` bound
            on ``PaginationParams.page_size``. A negative size would otherwise yield
            a negative LIMIT/OFFSET: PostgreSQL rejects that, and SQLite treats a
            negative LIMIT as "no limit".
    """
    if page_size < 1:
        raise ValueError(f"page_size must be >= 1, got {page_size}")
    offset = (max(page, 1) - 1) * page_size
    return OffsetLimit(offset=offset, limit=page_size)
@@ -0,0 +1,19 @@
1
+ """Unified API response helpers."""
2
+
3
+ import time
4
+ from typing import Any
5
+
6
+
7
def success(data: Any = None, message: str = "success") -> dict:
    """Wrap *data* in the standard success envelope.

    Keys, in order: ``code`` (always 200), ``message``, ``data``, and
    ``timestamp`` (current Unix time in whole seconds).
    """
    envelope = dict(code=200, message=message, data=data)
    envelope["timestamp"] = int(time.time())
    return envelope
10
+
11
+
12
def fail(code: int = 400, message: str = "fail") -> dict:
    """Wrap an error in the standard envelope: the given ``code`` and ``message``, ``data`` None."""
    now = int(time.time())
    return dict(code=code, message=message, data=None, timestamp=now)
15
+
16
+
17
def paginated(items: list, total: int, page: int, page_size: int) -> dict:
    """Wrap one page of results in the standard success envelope.

    ``data`` holds ``items`` (this page), ``total`` (the caller-supplied total
    count), and the ``page`` / ``page_size`` values passed in.
    """
    page_info = {
        "items": items,
        "total": total,
        "page": page,
        "page_size": page_size,
    }
    return success(data=page_info)
@@ -0,0 +1,48 @@
1
+ """JWT token creation/verification and bcrypt password hashing.
2
+
3
+ All helpers are stateless — pass secrets and algorithm as parameters.
4
+ """
5
+
6
+ from datetime import UTC, datetime, timedelta
7
+ from typing import Any
8
+
9
+ import bcrypt
10
+ from jose import JWTError, jwt
11
+
12
+
13
def hash_password(password: str) -> str:
    """Return a salted bcrypt hash of *password*, decoded as a UTF-8 string.

    bcrypt only uses the first 72 bytes of its input. bcrypt < 5 silently
    truncated longer passwords; bcrypt >= 5 raises ``ValueError`` instead.
    Truncating explicitly keeps the old behavior, including compatibility with
    existing hashes, on every bcrypt version the ``bcrypt>=4.0.0`` pin allows.
    """
    password_bytes = password.encode("utf-8")[:72]  # bcrypt's input limit (see docstring)
    hashed = bcrypt.hashpw(password_bytes, bcrypt.gensalt())
    return hashed.decode("utf-8")
19
+
20
+
21
def verify_password(plain: str, hashed: str) -> bool:
    """Return True if *plain* matches the bcrypt hash *hashed*, otherwise False.

    *plain* is truncated to bcrypt's 72-byte input limit. This mirrors what
    bcrypt < 5 did implicitly; bcrypt >= 5 would raise ``ValueError`` instead.
    A malformed *hashed* value (for which ``bcrypt.checkpw`` raises
    ``ValueError``) is reported as a non-match rather than propagating.
    """
    candidate = plain.encode("utf-8")[:72]
    try:
        return bcrypt.checkpw(candidate, hashed.encode("utf-8"))
    except ValueError:
        return False
24
+
25
+
26
def create_token(
    data: dict[str, Any],
    secret: str,
    expires_minutes: int,
    algorithm: str = "HS256",
) -> str:
    """Sign *data* as a JWT that expires *expires_minutes* from now (UTC).

    The caller's dict is left untouched. The ``exp`` claim is added to a copy
    and overrides any ``exp`` already present in *data*.
    """
    expires_at = datetime.now(UTC) + timedelta(minutes=expires_minutes)
    claims = {**data, "exp": expires_at}
    return jwt.encode(claims, secret, algorithm=algorithm)
37
+
38
+
39
def decode_token(
    token: str,
    secret: str,
    algorithm: str = "HS256",
) -> dict[str, Any] | None:
    """Verify *token* with *secret* and return its claims, or None if it is not valid.

    Only *algorithm* is accepted. Any ``JWTError`` from python-jose (bad signature,
    malformed token, expired ``exp``) yields None instead of propagating.
    """
    allowed = [algorithm]
    try:
        claims = jwt.decode(token, secret, algorithms=allowed)
    except JWTError:
        return None
    return claims
@@ -0,0 +1,19 @@
1
+ Metadata-Version: 2.4
2
+ Name: danweiyuan-eapi
3
+ Version: 0.1.0
4
+ Summary: Lightweight FastAPI infrastructure — exceptions, database, config, security, cache, pagination
5
+ License: MIT
6
+ Requires-Python: >=3.11
7
+ Requires-Dist: fastapi>=0.100.0
8
+ Requires-Dist: sqlalchemy[asyncio]>=2.0.0
9
+ Requires-Dist: pydantic>=2.0.0
10
+ Requires-Dist: pydantic-settings>=2.0.0
11
+ Requires-Dist: python-jose[cryptography]>=3.3.0
12
+ Requires-Dist: bcrypt>=4.0.0
13
+ Requires-Dist: redis>=5.0.0
14
+ Provides-Extra: dev
15
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
16
+ Requires-Dist: pytest-asyncio>=0.24.0; extra == "dev"
17
+ Requires-Dist: httpx>=0.27.0; extra == "dev"
18
+ Requires-Dist: aiosqlite>=0.20.0; extra == "dev"
19
+ Requires-Dist: ruff>=0.6.0; extra == "dev"
@@ -0,0 +1,24 @@
1
+ README.md
2
+ pyproject.toml
3
+ src/danweiyuan_eapi/__init__.py
4
+ src/danweiyuan_eapi/cache.py
5
+ src/danweiyuan_eapi/config.py
6
+ src/danweiyuan_eapi/database.py
7
+ src/danweiyuan_eapi/dependencies.py
8
+ src/danweiyuan_eapi/exceptions.py
9
+ src/danweiyuan_eapi/pagination.py
10
+ src/danweiyuan_eapi/response.py
11
+ src/danweiyuan_eapi/security.py
12
+ src/danweiyuan_eapi.egg-info/PKG-INFO
13
+ src/danweiyuan_eapi.egg-info/SOURCES.txt
14
+ src/danweiyuan_eapi.egg-info/dependency_links.txt
15
+ src/danweiyuan_eapi.egg-info/requires.txt
16
+ src/danweiyuan_eapi.egg-info/top_level.txt
17
+ tests/test_cache.py
18
+ tests/test_config.py
19
+ tests/test_database.py
20
+ tests/test_dependencies.py
21
+ tests/test_exceptions.py
22
+ tests/test_pagination.py
23
+ tests/test_response.py
24
+ tests/test_security.py
@@ -0,0 +1,14 @@
1
+ fastapi>=0.100.0
2
+ sqlalchemy[asyncio]>=2.0.0
3
+ pydantic>=2.0.0
4
+ pydantic-settings>=2.0.0
5
+ python-jose[cryptography]>=3.3.0
6
+ bcrypt>=4.0.0
7
+ redis>=5.0.0
8
+
9
+ [dev]
10
+ pytest>=8.0.0
11
+ pytest-asyncio>=0.24.0
12
+ httpx>=0.27.0
13
+ aiosqlite>=0.20.0
14
+ ruff>=0.6.0
@@ -0,0 +1 @@
1
+ danweiyuan_eapi
@@ -0,0 +1,28 @@
1
+ """Tests for danweiyuan_eapi.cache."""
2
+
3
+ import pytest
4
+
5
+ import danweiyuan_eapi.cache as cache_mod
6
+ from danweiyuan_eapi.cache import close_redis, configure, get_redis
7
+
8
+
9
class TestCacheConfiguration:
    """Configuration and lifecycle behaviour of the cache module that needs no live Redis."""

    @pytest.fixture(autouse=True)
    async def cleanup(self):
        # Reset module-level state after every test so tests stay independent.
        yield
        cache_mod._redis_url = None
        cache_mod._redis = None

    async def test_get_redis_without_configure_raises(self):
        cache_mod._redis_url, cache_mod._redis = None, None
        with pytest.raises(RuntimeError, match="not configured"):
            await get_redis()

    async def test_configure_sets_url(self):
        url = "redis://localhost:6379/15"
        configure(url)
        assert cache_mod._redis_url == url

    async def test_close_redis_is_safe_when_not_connected(self):
        cache_mod._redis = None
        await close_redis()  # must be a no-op, not an error
@@ -0,0 +1,65 @@
1
+ """Tests for danweiyuan_eapi.config."""
2
+
3
+ import pytest
4
+
5
+ from danweiyuan_eapi.config import BaseSettings
6
+
7
+
8
class TestBaseSettings:
    """Shared BaseSettings fields and defaults.

    Every test builds settings with ``_env_file=None`` and scrubs the optional
    variables. A developer's ``.env`` file or shell environment (e.g. ``DEBUG=1``)
    therefore cannot change the outcome.
    """

    REQUIRED_ENV = {
        "DATABASE_URL": "postgresql+asyncpg://u:p@localhost/db",
        "REDIS_URL": "redis://localhost:6379/0",
        "SECRET_KEY": "s",
    }

    @pytest.fixture(autouse=True)
    def _isolate_env(self, monkeypatch):
        for name in ("DEBUG", "ALLOWED_ORIGINS", "JWT_ALGORITHM", "ACCESS_TOKEN_EXPIRE_MINUTES", "CUSTOM_FIELD"):
            monkeypatch.delenv(name, raising=False)

    def _set_required(self, monkeypatch, **overrides):
        for name, value in {**self.REQUIRED_ENV, **overrides}.items():
            monkeypatch.setenv(name, value)

    def test_required_fields_raise_without_env(self, monkeypatch):
        for name in self.REQUIRED_ENV:
            monkeypatch.delenv(name, raising=False)
        with pytest.raises(ValueError):
            BaseSettings(_env_file=None)

    def test_construct_with_all_required(self, monkeypatch):
        self._set_required(monkeypatch, SECRET_KEY="test-secret")
        s = BaseSettings(_env_file=None)
        assert s.database_url == "postgresql+asyncpg://u:p@localhost/db"
        assert s.redis_url == "redis://localhost:6379/0"
        assert s.secret_key == "test-secret"

    def test_debug_defaults_false(self, monkeypatch):
        self._set_required(monkeypatch)
        s = BaseSettings(_env_file=None)
        assert s.debug is False

    def test_allowed_origins_defaults_empty(self, monkeypatch):
        self._set_required(monkeypatch)
        s = BaseSettings(_env_file=None)
        assert s.allowed_origins == []

    def test_allowed_origins_from_env(self, monkeypatch):
        self._set_required(monkeypatch)
        monkeypatch.setenv("ALLOWED_ORIGINS", '["http://localhost:5173","http://example.com"]')
        s = BaseSettings(_env_file=None)
        assert s.allowed_origins == ["http://localhost:5173", "http://example.com"]

    def test_subclass_adds_fields(self, monkeypatch):
        self._set_required(monkeypatch)
        monkeypatch.setenv("CUSTOM_FIELD", "hello")

        class MySettings(BaseSettings):
            custom_field: str = "default"

        s = MySettings(_env_file=None)
        assert s.custom_field == "hello"

    def test_jwt_defaults(self, monkeypatch):
        self._set_required(monkeypatch)
        s = BaseSettings(_env_file=None)
        assert s.jwt_algorithm == "HS256"
        assert s.access_token_expire_minutes == 30
@@ -0,0 +1,42 @@
1
+ """Tests for danweiyuan_eapi.database."""
2
+
3
+ import pytest
4
+ from sqlalchemy import String, select
5
+ from sqlalchemy.ext.asyncio import AsyncSession
6
+ from sqlalchemy.orm import Mapped, mapped_column
7
+
8
+ from danweiyuan_eapi.database import Base, TimestampMixin, create_async_engine_factory, create_session_factory
9
+
10
+
11
class FakeModel(Base, TimestampMixin):
    """Minimal ORM model used to exercise Base together with TimestampMixin."""

    __tablename__ = "fake_items"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(50))
15
+
16
+
17
class TestDatabaseFactory:
    """End-to-end checks of the engine and session factories against in-memory SQLite."""

    @pytest.fixture()
    async def session(self):
        engine = create_async_engine_factory("sqlite+aiosqlite:///:memory:")
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        make_session = create_session_factory(engine)
        async with make_session() as db_session:
            yield db_session
        await engine.dispose()

    async def test_insert_and_query(self, session: AsyncSession):
        session.add(FakeModel(id=1, name="test-item"))
        await session.commit()
        stmt = select(FakeModel).where(FakeModel.id == 1)
        found = (await session.execute(stmt)).scalar_one()
        assert found.name == "test-item"

    async def test_timestamp_mixin_has_columns(self):
        for column in ("created_at", "updated_at"):
            assert hasattr(FakeModel, column)

    async def test_base_is_declarative(self):
        assert hasattr(Base, "metadata")
        assert "fake_items" in Base.metadata.tables
@@ -0,0 +1,70 @@
1
+ """Tests for danweiyuan_eapi.dependencies."""
2
+
3
+ import pytest
4
+ from fastapi import Depends, FastAPI
5
+ from fastapi.testclient import TestClient
6
+ from sqlalchemy import String, select
7
+ from sqlalchemy.orm import Mapped, mapped_column
8
+
9
+ from danweiyuan_eapi.database import Base, create_async_engine_factory, create_session_factory
10
+ from danweiyuan_eapi.dependencies import create_get_db
11
+
12
+
13
class Item(Base):
    """Tiny ORM model used by the get_db dependency tests."""

    __tablename__ = "dep_test_items"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(50))
17
+
18
+
19
class TestGetDb:
    """The get_db dependency from create_get_db, exercised through a real FastAPI app."""

    @pytest.fixture()
    def app(self, tmp_path):
        from collections.abc import AsyncGenerator
        from contextlib import asynccontextmanager

        # A file DB (not :memory:) so separate pooled connections see the same data.
        engine = create_async_engine_factory(f"sqlite+aiosqlite:///{tmp_path / 'test.db'}")
        get_db = create_get_db(create_session_factory(engine))

        @asynccontextmanager
        async def lifespan(app: FastAPI) -> AsyncGenerator[None]:
            async with engine.begin() as conn:
                await conn.run_sync(Base.metadata.create_all)
            yield
            await engine.dispose()

        app = FastAPI(lifespan=lifespan)

        @app.post("/items")
        async def create_item(session=Depends(get_db)):  # noqa: B008
            session.add(Item(id=1, name="test"))
            await session.commit()
            return {"id": 1}

        @app.get("/items/{item_id}")
        async def get_item(item_id: int, session=Depends(get_db)):  # noqa: B008
            stmt = select(Item).where(Item.id == item_id)
            row = (await session.execute(stmt)).scalar_one_or_none()
            if row is None:
                return {"error": "not found"}
            return {"id": row.id, "name": row.name}

        return app

    @pytest.fixture()
    def client(self, app):
        # Entering the TestClient context runs the lifespan (table creation / disposal).
        with TestClient(app) as test_client:
            yield test_client

    def test_create_and_get_via_dependency(self, client):
        created = client.post("/items")
        assert created.status_code == 200
        fetched = client.get("/items/1")
        assert fetched.status_code == 200
        assert fetched.json() == {"id": 1, "name": "test"}

    def test_get_nonexistent_returns_not_found(self, client):
        missing = client.get("/items/999")
        assert missing.json() == {"error": "not found"}
@@ -0,0 +1,93 @@
1
+ """Tests for danweiyuan_eapi.exceptions."""
2
+
3
+ import pytest
4
+ from fastapi import FastAPI
5
+ from fastapi.testclient import TestClient
6
+
7
+ from danweiyuan_eapi.exceptions import (
8
+ AppError,
9
+ AuthenticationError,
10
+ BusinessError,
11
+ NotFoundError,
12
+ PermissionDeniedError,
13
+ register_exception_handlers,
14
+ )
15
+
16
+
17
class TestAppErrorHierarchy:
    """Attribute defaults and inheritance of the AppError family."""

    def test_app_error_defaults(self):
        error = AppError(message="something broke")
        assert (error.message, error.code) == ("something broke", "UNKNOWN_ERROR")

    def test_app_error_custom_code(self):
        assert AppError(message="oops", code="CUSTOM").code == "CUSTOM"

    def test_not_found_error(self):
        error = NotFoundError("用户")
        assert error.message == "用户不存在"
        assert error.code == "NOT_FOUND"
        assert isinstance(error, AppError)

    def test_business_error(self):
        error = BusinessError(message="余额不足", code="INSUFFICIENT_BALANCE")
        assert error.message == "余额不足"
        assert isinstance(error, AppError)

    def test_permission_denied_error(self):
        error = PermissionDeniedError()
        assert (error.message, error.code) == ("权限不足", "PERMISSION_DENIED")

    def test_authentication_error(self):
        error = AuthenticationError()
        assert (error.message, error.code) == ("认证失败", "AUTHENTICATION_FAILED")
47
+
48
+
49
class TestExceptionHandlers:
    """HTTP status and JSON body produced by register_exception_handlers for each error type."""

    @pytest.fixture()
    def app(self):
        app = FastAPI()
        register_exception_handlers(app)

        @app.get("/not-found")
        async def raise_not_found():
            raise NotFoundError("订单")

        @app.get("/business-error")
        async def raise_business():
            raise BusinessError(message="余额不足", code="INSUFFICIENT_BALANCE")

        @app.get("/forbidden")
        async def raise_forbidden():
            raise PermissionDeniedError()

        @app.get("/unauthorized")
        async def raise_auth():
            raise AuthenticationError()

        return app

    @pytest.fixture()
    def client(self, app):
        return TestClient(app)

    def test_not_found_returns_404(self, client):
        resp = client.get("/not-found")
        assert resp.status_code == 404
        assert resp.json() == {"code": "NOT_FOUND", "message": "订单不存在"}

    def test_business_error_returns_422(self, client):
        resp = client.get("/business-error")
        assert resp.status_code == 422
        assert resp.json() == {"code": "INSUFFICIENT_BALANCE", "message": "余额不足"}

    def test_permission_denied_returns_403(self, client):
        resp = client.get("/forbidden")
        assert resp.status_code == 403
        assert resp.json() == {"code": "PERMISSION_DENIED", "message": "权限不足"}

    def test_authentication_error_returns_401(self, client):
        resp = client.get("/unauthorized")
        assert resp.status_code == 401
        assert resp.json() == {"code": "AUTHENTICATION_FAILED", "message": "认证失败"}
@@ -0,0 +1,38 @@
1
+ """Tests for danweiyuan_eapi.pagination."""
2
+
3
+ import pytest
4
+
5
+ from danweiyuan_eapi.pagination import OffsetLimit, PaginationParams, paginate
6
+
7
+
8
class TestPaginate:
    """Offset/limit arithmetic of paginate(), including clamping of pages below 1."""

    def test_first_page(self):
        expected = OffsetLimit(offset=0, limit=20)
        assert paginate(1, 20) == expected

    def test_third_page(self):
        expected = OffsetLimit(offset=20, limit=10)
        assert paginate(3, 10) == expected

    def test_page_zero_treated_as_one(self):
        expected = OffsetLimit(offset=0, limit=20)
        assert paginate(0, 20) == expected

    def test_negative_page_treated_as_one(self):
        expected = OffsetLimit(offset=0, limit=20)
        assert paginate(-5, 20) == expected
20
+
21
+
22
class TestPaginationParams:
    """Defaults and validation bounds of PaginationParams."""

    def test_defaults(self):
        params = PaginationParams()
        assert params.page == 1
        assert params.page_size == 20

    def test_custom_values(self):
        params = PaginationParams(page=3, page_size=50)
        assert params.page == 3
        assert params.page_size == 50

    def test_page_min_is_one(self):
        with pytest.raises(ValueError):
            PaginationParams(page=0)

    def test_page_size_max_is_100(self):
        with pytest.raises(ValueError):
            PaginationParams(page_size=101)
@@ -0,0 +1,43 @@
1
+ """Tests for danweiyuan_eapi.response."""
2
+
3
+ from danweiyuan_eapi.response import fail, paginated, success
4
+
5
+
6
class TestSuccess:
    """Shape of the success() envelope."""

    def test_default_message(self):
        envelope = success(data={"id": 1})
        assert envelope["code"] == 200
        assert envelope["message"] == "success"
        assert envelope["data"] == {"id": 1}
        assert "timestamp" in envelope

    def test_custom_message(self):
        assert success(data=None, message="created")["message"] == "created"

    def test_no_data(self):
        assert success()["data"] is None
21
+
22
+
23
class TestFail:
    """Shape of the fail() envelope."""

    def test_default_values(self):
        result = fail()
        assert result["code"] == 400
        assert result["message"] == "fail"
        assert result["data"] is None

    def test_custom_code_and_message(self):
        result = fail(code=500, message="internal error")
        assert result["code"] == 500
        assert result["message"] == "internal error"
32
+
33
+
34
class TestPaginated:
    """paginated() wraps list metadata inside a success envelope."""

    def test_paginated_structure(self):
        rows = [{"id": 1}, {"id": 2}]
        envelope = paginated(items=rows, total=50, page=1, page_size=20)
        assert envelope["code"] == 200
        assert envelope["data"] == {"items": rows, "total": 50, "page": 1, "page_size": 20}
@@ -0,0 +1,51 @@
1
+ """Tests for danweiyuan_eapi.security."""
2
+
3
+ from danweiyuan_eapi.security import create_token, decode_token, hash_password, verify_password
4
+
5
+
6
class TestPasswordHashing:
    """bcrypt hashing round-trips and per-call salting."""

    def test_hash_returns_bcrypt_string(self):
        assert hash_password("secret123").startswith("$2b$")

    def test_verify_correct_password(self):
        stored = hash_password("mypassword")
        assert verify_password("mypassword", stored) is True

    def test_verify_wrong_password(self):
        stored = hash_password("mypassword")
        assert verify_password("wrongpassword", stored) is False

    def test_different_hashes_for_same_password(self):
        first, second = hash_password("same"), hash_password("same")
        assert first != second
23
+
24
+
25
class TestJWTTokens:
    """create_token / decode_token round-trips and failure modes."""

    SECRET = "test-secret-key-for-unit-tests"

    def test_create_and_decode_token(self):
        claims = {"sub": "42", "username": "alice"}
        token = create_token(claims, secret=self.SECRET, expires_minutes=30)
        decoded = decode_token(token, secret=self.SECRET)
        assert decoded is not None
        assert (decoded["sub"], decoded["username"]) == ("42", "alice")

    def test_decode_with_wrong_secret_returns_none(self):
        token = create_token({"sub": "1"}, secret=self.SECRET, expires_minutes=30)
        assert decode_token(token, secret="wrong-secret") is None

    def test_expired_token_returns_none(self):
        # A negative lifetime puts `exp` in the past.
        token = create_token({"sub": "1"}, secret=self.SECRET, expires_minutes=-1)
        assert decode_token(token, secret=self.SECRET) is None

    def test_invalid_token_string_returns_none(self):
        assert decode_token("not-a-jwt", secret=self.SECRET) is None

    def test_create_does_not_mutate_input(self):
        claims = {"sub": "1"}
        snapshot = dict(claims)
        create_token(claims, secret=self.SECRET, expires_minutes=30)
        assert claims == snapshot