aiteamutils 0.2.3__tar.gz → 0.2.9__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aiteamutils
3
- Version: 0.2.3
3
+ Version: 0.2.9
4
4
  Summary: AI Team Utilities
5
5
  Project-URL: Homepage, https://github.com/yourusername/aiteamutils
6
6
  Project-URL: Issues, https://github.com/yourusername/aiteamutils/issues
@@ -20,7 +20,6 @@ from .base_service import BaseService
20
20
  from .base_repository import BaseRepository
21
21
  from .validators import validate_with
22
22
  from .enums import ActivityType
23
- from .cache import CacheManager
24
23
  from .version import __version__
25
24
 
26
25
  __all__ = [
@@ -52,8 +51,5 @@ __all__ = [
52
51
  "validate_with",
53
52
 
54
53
  # Enums
55
- "ActivityType",
56
-
57
- # Cache
58
- "CacheManager"
54
+ "ActivityType"
59
55
  ]
@@ -1,12 +1,13 @@
1
1
  """기본 레포지토리 모듈."""
2
- from typing import TypeVar, Generic, Dict, Any, List, Optional, Type
3
- from sqlalchemy.orm import DeclarativeBase
2
+ from typing import TypeVar, Generic, Dict, Any, List, Optional, Type, Union
3
+ from sqlalchemy.orm import DeclarativeBase, Load
4
4
  from sqlalchemy.exc import IntegrityError, SQLAlchemyError
5
5
  from sqlalchemy import select, or_, and_
6
- from .database import DatabaseManager
6
+ from .database import DatabaseService
7
7
  from .exceptions import CustomException, ErrorCode
8
8
  from sqlalchemy.orm import joinedload
9
9
  from sqlalchemy.sql import Select
10
+ from fastapi import Request
10
11
 
11
12
  ModelType = TypeVar("ModelType", bound=DeclarativeBase)
12
13
 
@@ -14,10 +15,10 @@ class BaseRepository(Generic[ModelType]):
14
15
  ##################
15
16
  # 1. 초기화 영역 #
16
17
  ##################
17
- def __init__(self, db_service: DatabaseManager, model: Type[ModelType]):
18
+ def __init__(self, db_service: DatabaseService, model: Type[ModelType]):
18
19
  """
19
20
  Args:
20
- db_service (DatabaseManager): 데이터베이스 서비스 인스턴스
21
+ db_service (DatabaseService): 데이터베이스 서비스 인스턴스
21
22
  model (Type[ModelType]): 모델 클래스
22
23
  """
23
24
  self.db_service = db_service
@@ -1,9 +1,9 @@
1
1
  """기본 서비스 모듈."""
2
2
  from datetime import datetime
3
3
  from typing import TypeVar, Generic, Dict, Any, List, Optional, Type, Union
4
- from sqlalchemy.orm import DeclarativeBase
4
+ from sqlalchemy.orm import DeclarativeBase, Load
5
5
  from sqlalchemy.exc import IntegrityError, SQLAlchemyError
6
- from .database import DatabaseManager
6
+ from .database import DatabaseService
7
7
  from .exceptions import CustomException, ErrorCode
8
8
  from .base_repository import BaseRepository
9
9
  from .security import hash_password
@@ -1,7 +1,7 @@
1
1
  from typing import Any, Dict, Optional, Type, AsyncGenerator, TypeVar, List, Union
2
2
  from sqlalchemy import select, update, and_, Table
3
3
  from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, AsyncEngine
4
- from sqlalchemy.orm import sessionmaker, Load, joinedload
4
+ from sqlalchemy.orm import sessionmaker
5
5
  from sqlalchemy.exc import IntegrityError, SQLAlchemyError
6
6
  from sqlalchemy.pool import QueuePool
7
7
  from contextlib import asynccontextmanager
@@ -18,32 +18,57 @@ from .config import settings
18
18
  T = TypeVar("T", bound=BaseColumn)
19
19
 
20
20
  class DatabaseService:
21
- def __init__(self, db_url: str):
21
+ def __init__(
22
+ self,
23
+ db_url: str = None,
24
+ session: AsyncSession = None,
25
+ db_echo: bool = False,
26
+ db_pool_size: int = 5,
27
+ db_max_overflow: int = 10,
28
+ db_pool_timeout: int = 30,
29
+ db_pool_recycle: int = 1800
30
+ ):
22
31
  """DatabaseService 초기화.
23
32
 
24
33
  Args:
25
- db_url: 데이터베이스 URL
34
+ db_url (str, optional): 데이터베이스 URL
35
+ session (AsyncSession, optional): 기존 세션
36
+ db_echo (bool, optional): SQL 로깅 여부
37
+ db_pool_size (int, optional): DB 커넥션 풀 크기
38
+ db_max_overflow (int, optional): 최대 초과 커넥션 수
39
+ db_pool_timeout (int, optional): 커넥션 풀 타임아웃
40
+ db_pool_recycle (int, optional): 커넥션 재활용 시간
26
41
  """
27
- self.engine = create_async_engine(
28
- db_url,
29
- echo=settings.DB_ECHO,
30
- pool_size=settings.DB_POOL_SIZE,
31
- max_overflow=settings.DB_MAX_OVERFLOW,
32
- pool_timeout=settings.DB_POOL_TIMEOUT,
33
- pool_recycle=settings.DB_POOL_RECYCLE,
34
- pool_pre_ping=True,
35
- poolclass=QueuePool,
36
- )
37
-
38
- self.async_session = sessionmaker(
39
- bind=self.engine,
40
- class_=AsyncSession,
41
- expire_on_commit=False
42
- )
43
-
42
+ if db_url:
43
+ self.engine = create_async_engine(
44
+ db_url,
45
+ echo=db_echo,
46
+ pool_size=db_pool_size,
47
+ max_overflow=db_max_overflow,
48
+ pool_timeout=db_pool_timeout,
49
+ pool_recycle=db_pool_recycle,
50
+ pool_pre_ping=True,
51
+ poolclass=QueuePool,
52
+ )
53
+ self.async_session = sessionmaker(
54
+ bind=self.engine,
55
+ class_=AsyncSession,
56
+ expire_on_commit=False
57
+ )
58
+ self.db = None
59
+ elif session:
60
+ self.engine = session.bind
61
+ self.async_session = None
62
+ self.db = session
63
+ else:
64
+ raise ValueError("Either db_url or session must be provided")
65
+
44
66
  @asynccontextmanager
45
- async def get_db(self) -> AsyncGenerator[AsyncSession, None]:
46
- """데이터베이스 세션을 생성하고 반환하는 비동기 제너레이터."""
67
+ async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
68
+ """데이터베이스 세션을 생성하고 반환하는 비동기 컨텍스트 매니저."""
69
+ if self.async_session is None:
70
+ raise RuntimeError("Session factory not initialized")
71
+
47
72
  async with self.async_session() as session:
48
73
  try:
49
74
  yield session
@@ -1,10 +1,10 @@
1
- from typing import Type, Dict, Tuple, Any
1
+ from typing import Type, Dict, Tuple, Any, AsyncGenerator, Callable
2
2
  from sqlalchemy.ext.asyncio import AsyncSession
3
3
  from fastapi import Depends, status
4
4
  from fastapi.security import OAuth2PasswordBearer
5
5
  from jose import JWTError, jwt
6
6
 
7
- from .database import get_db, DatabaseService
7
+ from .database import DatabaseService
8
8
  from .exceptions import CustomException, ErrorCode
9
9
 
10
10
  class Settings:
@@ -14,16 +14,42 @@ class Settings:
14
14
  self.JWT_ALGORITHM = jwt_algorithm
15
15
 
16
16
  _settings: Settings | None = None
17
-
18
- def init_settings(jwt_secret: str, jwt_algorithm: str = "HS256"):
17
+ _database_service: DatabaseService | None = None
18
+
19
+ def init_settings(
20
+ jwt_secret: str,
21
+ jwt_algorithm: str = "HS256",
22
+ db_url: str = None,
23
+ db_echo: bool = False,
24
+ db_pool_size: int = 5,
25
+ db_max_overflow: int = 10,
26
+ db_pool_timeout: int = 30,
27
+ db_pool_recycle: int = 1800
28
+ ):
19
29
  """설정 초기화 함수
20
30
 
21
31
  Args:
22
32
  jwt_secret (str): JWT 시크릿 키
23
33
  jwt_algorithm (str, optional): JWT 알고리즘. Defaults to "HS256".
34
+ db_url (str, optional): 데이터베이스 URL
35
+ db_echo (bool, optional): SQL 로깅 여부
36
+ db_pool_size (int, optional): DB 커넥션 풀 크기
37
+ db_max_overflow (int, optional): 최대 초과 커넥션 수
38
+ db_pool_timeout (int, optional): 커넥션 풀 타임아웃
39
+ db_pool_recycle (int, optional): 커넥션 재활용 시간
24
40
  """
25
- global _settings
41
+ global _settings, _database_service
26
42
  _settings = Settings(jwt_secret, jwt_algorithm)
43
+
44
+ if db_url:
45
+ _database_service = DatabaseService(
46
+ db_url=db_url,
47
+ db_echo=db_echo,
48
+ db_pool_size=db_pool_size,
49
+ db_max_overflow=db_max_overflow,
50
+ db_pool_timeout=db_pool_timeout,
51
+ db_pool_recycle=db_pool_recycle
52
+ )
27
53
 
28
54
  def get_settings() -> Settings:
29
55
  """현재 설정을 반환하는 함수
@@ -38,6 +64,39 @@ def get_settings() -> Settings:
38
64
  raise RuntimeError("Settings not initialized. Call init_settings first.")
39
65
  return _settings
40
66
 
67
+ def get_database_service() -> DatabaseService:
68
+ """DatabaseService 인스턴스를 반환하는 함수
69
+
70
+ Returns:
71
+ DatabaseService: DatabaseService 인스턴스
72
+
73
+ Raises:
74
+ RuntimeError: DatabaseService가 초기화되지 않은 경우
75
+ """
76
+ if _database_service is None:
77
+ raise RuntimeError("DatabaseService not initialized. Call init_settings with db_url first.")
78
+ return _database_service
79
+
80
+ async def get_db() -> AsyncGenerator[AsyncSession, None]:
81
+ """데이터베이스 세션을 생성하고 반환하는 비동기 제너레이터."""
82
+ db_service = get_database_service()
83
+ async with db_service.get_session() as session:
84
+ try:
85
+ yield session
86
+ finally:
87
+ await session.close()
88
+
89
+ def get_database_session(db: AsyncSession = Depends(get_db)) -> DatabaseService:
90
+ """DatabaseService 의존성
91
+
92
+ Args:
93
+ db (AsyncSession): 데이터베이스 세션
94
+
95
+ Returns:
96
+ DatabaseService: DatabaseService 인스턴스
97
+ """
98
+ return DatabaseService(session=db)
99
+
41
100
  class ServiceRegistry:
42
101
  """서비스 레지스트리를 관리하는 클래스"""
43
102
  def __init__(self):
@@ -81,17 +140,6 @@ class ServiceRegistry:
81
140
  # ServiceRegistry 초기화
82
141
  service_registry = ServiceRegistry()
83
142
 
84
- def get_database_service(db: AsyncSession = Depends(get_db)) -> DatabaseService:
85
- """DatabaseService 의존성
86
-
87
- Args:
88
- db (AsyncSession): 데이터베이스 세션
89
-
90
- Returns:
91
- DatabaseService: DatabaseService 인스턴스
92
- """
93
- return DatabaseService(db)
94
-
95
143
  def get_service(name: str):
96
144
  """등록된 서비스를 가져오는 의존성 함수
97
145
 
@@ -145,9 +193,16 @@ async def get_current_user(
145
193
  source_function="dependencies.py / get_current_user"
146
194
  )
147
195
 
148
- from app.user.repository import UserRepository
149
- user_repo = UserRepository(db_service)
150
- user = await user_repo.get_user(user_ulid, by="ulid")
196
+ try:
197
+ repository_class, _ = service_registry.get("user")
198
+ user_repo = repository_class(db_service)
199
+ user = await user_repo.get_user(user_ulid, by="ulid")
200
+ except ValueError:
201
+ raise CustomException(
202
+ ErrorCode.SERVICE_NOT_REGISTERED,
203
+ detail="User service is not registered",
204
+ source_function="dependencies.py / get_current_user"
205
+ )
151
206
 
152
207
  if not user:
153
208
  raise CustomException(
@@ -1,13 +1,16 @@
1
1
  """유효성 검사 관련 유틸리티 함수들을 모아둔 모듈입니다."""
2
2
 
3
- from typing import Any, Optional, Dict
3
+ from typing import Type, Dict, Any, Callable, TypeVar, Optional, List, Union
4
+ from functools import wraps
5
+ from sqlalchemy import Table
4
6
  from sqlalchemy.ext.asyncio import AsyncSession
5
- from sqlalchemy import select
6
- from pydantic import field_validator
7
+ from fastapi import Request
8
+ from inspect import signature
9
+ from pydantic import BaseModel, field_validator
7
10
  import re
8
11
 
9
12
  from .exceptions import ErrorCode, CustomException
10
- from .database import DatabaseManager
13
+ from .database import DatabaseService
11
14
  from .base_model import Base
12
15
 
13
16
  def validate_with(validator_func, unique_check=None, skip_if_none=False):
@@ -0,0 +1,2 @@
1
+ """버전 정보"""
2
+ __version__ = "0.2.9"
@@ -1,2 +0,0 @@
1
- """버전 정보"""
2
- __version__ = "0.2.3"
File without changes
File without changes
File without changes
File without changes