aegra-api 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aegra_api/__init__.py +3 -0
- aegra_api/api/__init__.py +1 -0
- aegra_api/api/assistants.py +235 -0
- aegra_api/api/runs.py +1110 -0
- aegra_api/api/store.py +200 -0
- aegra_api/api/threads.py +761 -0
- aegra_api/config.py +204 -0
- aegra_api/constants.py +5 -0
- aegra_api/core/__init__.py +0 -0
- aegra_api/core/app_loader.py +91 -0
- aegra_api/core/auth_ctx.py +65 -0
- aegra_api/core/auth_deps.py +186 -0
- aegra_api/core/auth_handlers.py +248 -0
- aegra_api/core/auth_middleware.py +331 -0
- aegra_api/core/database.py +123 -0
- aegra_api/core/health.py +131 -0
- aegra_api/core/orm.py +165 -0
- aegra_api/core/route_merger.py +69 -0
- aegra_api/core/serializers/__init__.py +7 -0
- aegra_api/core/serializers/base.py +22 -0
- aegra_api/core/serializers/general.py +54 -0
- aegra_api/core/serializers/langgraph.py +102 -0
- aegra_api/core/sse.py +178 -0
- aegra_api/main.py +303 -0
- aegra_api/middleware/__init__.py +4 -0
- aegra_api/middleware/double_encoded_json.py +74 -0
- aegra_api/middleware/logger_middleware.py +95 -0
- aegra_api/models/__init__.py +76 -0
- aegra_api/models/assistants.py +81 -0
- aegra_api/models/auth.py +62 -0
- aegra_api/models/enums.py +29 -0
- aegra_api/models/errors.py +29 -0
- aegra_api/models/runs.py +124 -0
- aegra_api/models/store.py +67 -0
- aegra_api/models/threads.py +152 -0
- aegra_api/observability/__init__.py +1 -0
- aegra_api/observability/base.py +88 -0
- aegra_api/observability/otel.py +133 -0
- aegra_api/observability/setup.py +27 -0
- aegra_api/observability/targets/__init__.py +11 -0
- aegra_api/observability/targets/base.py +18 -0
- aegra_api/observability/targets/langfuse.py +33 -0
- aegra_api/observability/targets/otlp.py +38 -0
- aegra_api/observability/targets/phoenix.py +24 -0
- aegra_api/services/__init__.py +0 -0
- aegra_api/services/assistant_service.py +569 -0
- aegra_api/services/base_broker.py +59 -0
- aegra_api/services/broker.py +141 -0
- aegra_api/services/event_converter.py +157 -0
- aegra_api/services/event_store.py +196 -0
- aegra_api/services/graph_streaming.py +433 -0
- aegra_api/services/langgraph_service.py +456 -0
- aegra_api/services/streaming_service.py +362 -0
- aegra_api/services/thread_state_service.py +128 -0
- aegra_api/settings.py +124 -0
- aegra_api/utils/__init__.py +3 -0
- aegra_api/utils/assistants.py +23 -0
- aegra_api/utils/run_utils.py +60 -0
- aegra_api/utils/setup_logging.py +122 -0
- aegra_api/utils/sse_utils.py +26 -0
- aegra_api/utils/status_compat.py +57 -0
- aegra_api-0.1.0.dist-info/METADATA +244 -0
- aegra_api-0.1.0.dist-info/RECORD +64 -0
- aegra_api-0.1.0.dist-info/WHEEL +4 -0
aegra_api/core/health.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
"""Health check endpoints"""
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
|
|
5
|
+
from fastapi import APIRouter, HTTPException, Request
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
from sqlalchemy import text
|
|
8
|
+
|
|
9
|
+
from aegra_api.core.database import db_manager
|
|
10
|
+
|
|
11
|
+
# Shared router; mounted by the main application to expose the health endpoints.
router = APIRouter()
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class HealthResponse(BaseModel):
    """Health check response model"""

    # Overall status: "healthy" or "unhealthy" (503 is raised before an
    # unhealthy body is ever returned).
    status: str
    # Database probe result: "connected", "not_initialized", or "error: ...".
    database: str
    # LangGraph checkpointer probe result: "connected" or "error: ...".
    langgraph_checkpointer: str
    # LangGraph store probe result: "connected" or "error: ...".
    langgraph_store: str
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class InfoResponse(BaseModel):
    """Info endpoint response model"""

    # Service name (the /info handler returns the constant "Aegra").
    name: str
    # Service version string.
    version: str
    # Human-readable service description.
    description: str
    # Run state indicator; /info always reports "running".
    status: str
    # Feature flags exposed to clients (e.g. assistants, crons).
    flags: dict
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@router.get("/info", response_model=InfoResponse)
async def info(_request: Request) -> InfoResponse:
    """Return static service metadata (name, version, status, feature flags)."""
    payload = {
        "name": "Aegra",
        "version": "0.1.0",
        "description": "Production-ready Agent Protocol server built on LangGraph",
        "status": "running",
        "flags": {"assistants": True, "crons": False},
    }
    return InfoResponse(**payload)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@router.get("/health", response_model=HealthResponse)
async def health_check(_request: Request) -> dict[str, str]:
    """Core health check handler logic.

    Probes the database engine, the LangGraph checkpointer, and the LangGraph
    store. Raises HTTP 503 when any component fails; otherwise returns the
    per-component status map.
    """
    report = {
        "status": "healthy",
        "database": "unknown",
        "langgraph_checkpointer": "unknown",
        "langgraph_store": "unknown",
    }

    def _mark_unhealthy(component: str, detail: str) -> None:
        # Record a component failure and flip the overall status.
        report[component] = detail
        report["status"] = "unhealthy"

    # Database connectivity: trivial round-trip query against the engine.
    try:
        if not db_manager.engine:
            _mark_unhealthy("database", "not_initialized")
        else:
            async with db_manager.engine.begin() as conn:
                await conn.execute(text("SELECT 1"))
            report["database"] = "connected"
    except Exception as e:
        _mark_unhealthy("database", f"error: {str(e)}")

    # LangGraph checkpointer (lazy-init). The probe itself is best-effort:
    # a missing tuple is fine; only failure to obtain the component counts.
    try:
        checkpointer = db_manager.get_checkpointer()
        with contextlib.suppress(Exception):
            await checkpointer.aget_tuple({"configurable": {"thread_id": "health-check"}})
        report["langgraph_checkpointer"] = "connected"
    except Exception as e:
        _mark_unhealthy("langgraph_checkpointer", f"error: {str(e)}")

    # LangGraph store (lazy-init), same best-effort probe.
    try:
        store = db_manager.get_store()
        with contextlib.suppress(Exception):
            await store.aget(("health",), "check")
        report["langgraph_store"] = "connected"
    except Exception as e:
        _mark_unhealthy("langgraph_store", f"error: {str(e)}")

    if report["status"] == "unhealthy":
        raise HTTPException(status_code=503, detail="Service unhealthy")

    return report
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@router.get("/ready")
async def readiness_check(_request: Request) -> dict[str, str]:
    """Kubernetes readiness probe endpoint.

    Fails fast with HTTP 503 when the database engine is missing or
    unreachable, or when the LangGraph components cannot be instantiated.
    """
    engine = db_manager.engine
    # No engine at all means startup has not completed.
    if not engine:
        raise HTTPException(
            status_code=503,
            detail="Service not ready - database engine not initialized",
        )

    # The engine must answer a trivial query.
    try:
        async with engine.begin() as conn:
            await conn.execute(text("SELECT 1"))
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"Service not ready - database error: {str(e)}") from e

    # LangGraph components must be constructible (lazy init). The probes
    # themselves are best-effort and individually suppressed.
    try:
        checkpointer = db_manager.get_checkpointer()
        store = db_manager.get_store()
        with contextlib.suppress(Exception):
            await checkpointer.aget_tuple({"configurable": {"thread_id": "ready-check"}})
        with contextlib.suppress(Exception):
            await store.aget(("ready",), "check")
    except Exception as e:
        raise HTTPException(
            status_code=503,
            detail=f"Service not ready - components unavailable: {str(e)}",
        ) from e

    return {"status": "ready"}
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@router.get("/live")
async def liveness_check(_request: Request) -> dict[str, str]:
    """Kubernetes liveness probe endpoint.

    Always succeeds: the process being able to answer is the signal.
    """
    payload = {"status": "alive"}
    return payload
|
aegra_api/core/orm.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
"""SQLAlchemy ORM setup for persistent assistant/thread/run records.
|
|
2
|
+
|
|
3
|
+
This module creates:
|
|
4
|
+
• `Base` – the declarative base used by our models.
|
|
5
|
+
• `Assistant`, `Thread`, `Run` – ORM models mirroring the bootstrap tables
|
|
6
|
+
already created in ``DatabaseManager._create_metadata_tables``.
|
|
7
|
+
• `async_session_maker` – a factory that hands out `AsyncSession` objects
|
|
8
|
+
bound to the shared engine managed by `db_manager`.
|
|
9
|
+
• `get_session` – FastAPI dependency helper for routers.
|
|
10
|
+
|
|
11
|
+
Nothing is auto-imported by FastAPI yet; routers will `from ...core.db import get_session`.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
from collections.abc import AsyncIterator
|
|
17
|
+
from datetime import datetime
|
|
18
|
+
|
|
19
|
+
from sqlalchemy import (
|
|
20
|
+
TIMESTAMP,
|
|
21
|
+
ForeignKey,
|
|
22
|
+
Index,
|
|
23
|
+
Integer,
|
|
24
|
+
Text,
|
|
25
|
+
text,
|
|
26
|
+
)
|
|
27
|
+
from sqlalchemy.dialects.postgresql import JSONB
|
|
28
|
+
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|
29
|
+
from sqlalchemy.orm import Mapped, declarative_base, mapped_column
|
|
30
|
+
|
|
31
|
+
# Declarative base shared by every ORM model in this module.
Base = declarative_base()
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class Assistant(Base):
    """ORM model for the `assistant` table (one row per stored assistant)."""

    __tablename__ = "assistant"

    # TEXT PK with DB-side generation using uuid_generate_v4()::text
    assistant_id: Mapped[str] = mapped_column(
        Text, primary_key=True, server_default=text("public.uuid_generate_v4()::text")
    )
    name: Mapped[str] = mapped_column(Text, nullable=False)
    description: Mapped[str | None] = mapped_column(Text)
    # Identifier of the graph definition this assistant runs.
    graph_id: Mapped[str] = mapped_column(Text, nullable=False)
    config: Mapped[dict] = mapped_column(JSONB, server_default=text("'{}'::jsonb"))
    context: Mapped[dict] = mapped_column(JSONB, server_default=text("'{}'::jsonb"))
    # Owning user; the lookup indexes below are all user-scoped.
    user_id: Mapped[str] = mapped_column(Integer, nullable=False) if False else mapped_column(Text, nullable=False)
    version: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("1"))
    # Python attribute is `metadata_dict` because `metadata` is reserved on
    # SQLAlchemy declarative classes; the DB column is still named "metadata".
    metadata_dict: Mapped[dict] = mapped_column(JSONB, server_default=text("'{}'::jsonb"), name="metadata")
    created_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True), server_default=text("now()"))
    updated_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True), server_default=text("now()"))

    # Indexes for performance
    __table_args__ = (
        Index("idx_assistant_user", "user_id"),
        # At most one row per (user, assistant) pair.
        Index("idx_assistant_user_assistant", "user_id", "assistant_id", unique=True),
        # Deduplicates assistants: one per (user, graph, config) combination.
        Index(
            "idx_assistant_user_graph_config",
            "user_id",
            "graph_id",
            "config",
            unique=True,
        ),
    )
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class AssistantVersion(Base):
    """ORM model for `assistant_versions`: an immutable snapshot per version.

    Composite primary key (assistant_id, version); rows are removed with the
    parent assistant via ON DELETE CASCADE.
    """

    __tablename__ = "assistant_versions"

    assistant_id: Mapped[str] = mapped_column(
        Text, ForeignKey("assistant.assistant_id", ondelete="CASCADE"), primary_key=True
    )
    version: Mapped[int] = mapped_column(Integer, primary_key=True)
    graph_id: Mapped[str] = mapped_column(Text, nullable=False)
    config: Mapped[dict | None] = mapped_column(JSONB)
    context: Mapped[dict | None] = mapped_column(JSONB)
    created_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True), server_default=text("now()"))
    # Attribute named metadata_dict because "metadata" is reserved by
    # SQLAlchemy's declarative base; the DB column stays "metadata".
    metadata_dict: Mapped[dict] = mapped_column(JSONB, server_default=text("'{}'::jsonb"), name="metadata")
    name: Mapped[str | None] = mapped_column(Text)
    description: Mapped[str | None] = mapped_column(Text)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class Thread(Base):
    """ORM model for the `thread` table (conversation threads, user-scoped)."""

    __tablename__ = "thread"

    # No server-side default, unlike Run/Assistant: callers must supply the id
    # (presumably client-generated — confirm against the create-thread route).
    thread_id: Mapped[str] = mapped_column(Text, primary_key=True)
    status: Mapped[str] = mapped_column(Text, server_default=text("'idle'"))
    # Database column is 'metadata_json' (per database.py). ORM attribute 'metadata_json' must map to that column.
    metadata_json: Mapped[dict] = mapped_column("metadata_json", JSONB, server_default=text("'{}'::jsonb"))
    user_id: Mapped[str] = mapped_column(Text, nullable=False)
    created_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True), server_default=text("now()"))
    updated_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True), server_default=text("now()"))

    # Indexes for performance
    __table_args__ = (Index("idx_thread_user", "user_id"),)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class Run(Base):
    """ORM model for the `runs` table: one execution within a thread."""

    __tablename__ = "runs"

    # TEXT PK with DB-side generation using uuid_generate_v4()::text
    run_id: Mapped[str] = mapped_column(Text, primary_key=True, server_default=text("public.uuid_generate_v4()::text"))
    # Owning thread; runs are deleted with their thread (ON DELETE CASCADE).
    thread_id: Mapped[str] = mapped_column(Text, ForeignKey("thread.thread_id", ondelete="CASCADE"), nullable=False)
    # Optional link to the assistant; cascades when the assistant is deleted.
    assistant_id: Mapped[str | None] = mapped_column(Text, ForeignKey("assistant.assistant_id", ondelete="CASCADE"))
    status: Mapped[str] = mapped_column(Text, server_default=text("'pending'"))
    input: Mapped[dict | None] = mapped_column(JSONB, server_default=text("'{}'::jsonb"))
    # Some environments may not yet have a 'config' column; make it nullable without default to match existing DB.
    # If migrations add this column later, it's already represented here.
    config: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    context: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    output: Mapped[dict | None] = mapped_column(JSONB)
    error_message: Mapped[str | None] = mapped_column(Text)
    user_id: Mapped[str] = mapped_column(Text, nullable=False)
    created_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True), server_default=text("now()"))
    updated_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True), server_default=text("now()"))

    # Indexes for performance
    __table_args__ = (
        Index("idx_runs_thread_id", "thread_id"),
        Index("idx_runs_user", "user_id"),
        Index("idx_runs_status", "status"),
        Index("idx_runs_assistant_id", "assistant_id"),
        Index("idx_runs_created_at", "created_at"),
    )
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class RunEvent(Base):
    """ORM model for `run_events`: an append-only, per-run event log.

    NOTE(review): run_id carries no ForeignKey constraint in this mapping —
    confirm whether that matches the bootstrap DDL or is an omission.
    """

    __tablename__ = "run_events"

    id: Mapped[str] = mapped_column(Text, primary_key=True)
    run_id: Mapped[str] = mapped_column(Text, nullable=False)
    # Monotonic sequence number within a run (indexed with run_id below).
    seq: Mapped[int] = mapped_column(Integer, nullable=False)
    event: Mapped[str] = mapped_column(Text, nullable=False)
    data: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    created_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True), server_default=text("now()"))

    # Indexes for performance
    __table_args__ = (
        Index("idx_run_events_run_id", "run_id"),
        # Supports ordered replay lookups by (run_id, seq).
        Index("idx_run_events_seq", "run_id", "seq"),
    )
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
# ---------------------------------------------------------------------------
|
|
144
|
+
# Session factory
|
|
145
|
+
# ---------------------------------------------------------------------------
|
|
146
|
+
|
|
147
|
+
# Module-level cache for the session factory; stays None until first use.
async_session_maker: async_sessionmaker[AsyncSession] | None = None


def _get_session_maker() -> async_sessionmaker[AsyncSession]:
    """Return a cached async_sessionmaker bound to db_manager.engine."""
    global async_session_maker
    if async_session_maker is None:
        # Imported here (not at module top) — presumably to avoid a circular
        # import with aegra_api.core.database; confirm before moving it.
        from aegra_api.core.database import db_manager

        engine = db_manager.get_engine()
        # expire_on_commit=False keeps ORM instances readable after commit
        # without triggering an implicit refresh query.
        async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
    return async_session_maker
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
async def get_session() -> AsyncIterator[AsyncSession]:
    """FastAPI dependency that yields an AsyncSession.

    The session is request-scoped: opened on entry, closed when the
    dependency context exits.
    """
    session_factory = _get_session_maker()
    async with session_factory() as db_session:
        yield db_session
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""Route merging utilities for combining custom apps with Aegra core routes"""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from contextlib import asynccontextmanager
|
|
5
|
+
|
|
6
|
+
import structlog
|
|
7
|
+
from fastapi import FastAPI
|
|
8
|
+
|
|
9
|
+
# Module-level structured logger named after this module.
logger = structlog.get_logger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def merge_lifespans(user_app: FastAPI, core_lifespan: Callable) -> FastAPI:
    """Merge user lifespan with Aegra's core lifespan.

    Both lifespans run, with the core lifespan wrapping the user one: Aegra's
    initialization (database, services) happens first and its cleanup last.

    Args:
        user_app: User's FastAPI/Starlette application
        core_lifespan: Aegra's core lifespan context manager

    Returns:
        Modified user_app with merged lifespan

    Raises:
        ValueError: if the app registers deprecated on_startup/on_shutdown
            handlers, which cannot be composed with a lifespan context.
    """
    router = user_app.router

    # Deprecated startup/shutdown hooks cannot be nested inside a lifespan
    # context manager, so reject them outright.
    if router.on_startup or router.on_shutdown:
        raise ValueError(
            f"Cannot merge lifespans with on_startup or on_shutdown handlers. "
            f"Please use lifespan context manager instead. "
            f"Found: on_startup={user_app.router.on_startup}, "
            f"on_shutdown={user_app.router.on_shutdown}"
        )

    user_lifespan = router.lifespan_context

    @asynccontextmanager
    async def combined_lifespan(app):
        # Core startup runs first; the user lifespan nests inside it, so user
        # shutdown completes before core shutdown (reverse order on exit).
        async with core_lifespan(app):
            if not user_lifespan:
                yield
            else:
                async with user_lifespan(app):
                    yield

    router.lifespan_context = combined_lifespan
    return user_app
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def merge_exception_handlers(user_app: FastAPI, core_exception_handlers: dict[type, Callable]) -> FastAPI:
    """Merge core exception handlers with user exception handlers.

    A core handler is registered only when the user has not already mapped
    that exception type; user-defined handlers always win.

    Args:
        user_app: User's FastAPI/Starlette application
        core_exception_handlers: Aegra's core exception handlers

    Returns:
        Modified user_app with merged exception handlers
    """
    registered = user_app.exception_handlers
    for exc_type, handler in core_exception_handlers.items():
        if exc_type in registered:
            # The user-supplied handler takes precedence; just note the override.
            logger.debug(f"User app overrides exception handler for {exc_type}")
        else:
            registered[exc_type] = handler

    return user_app
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
"""Serialization layer for LangGraph and general objects"""
|
|
2
|
+
|
|
3
|
+
from aegra_api.core.serializers.base import Serializer
|
|
4
|
+
from aegra_api.core.serializers.general import GeneralSerializer
|
|
5
|
+
from aegra_api.core.serializers.langgraph import LangGraphSerializer
|
|
6
|
+
|
|
7
|
+
__all__ = ["Serializer", "GeneralSerializer", "LangGraphSerializer"]
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""Base serialization interface"""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Serializer(ABC):
    """Abstract base class for object serialization"""

    @abstractmethod
    def serialize(self, obj: Any) -> Any:
        """Serialize an object to a JSON-compatible format.

        Args:
            obj: Any Python object.

        Returns:
            A structure built from JSON-safe primitives.
        """
        pass
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SerializationError(Exception):
    """Raised when serialization fails.

    Attributes:
        obj_type: Class name of the object that failed to serialize.
        original_error: The underlying exception, when one was caught.
    """

    def __init__(self, message: str, obj_type: str, original_error: Exception | None = None):
        # Stash context first, then delegate the message to Exception.
        self.obj_type = obj_type
        self.original_error = original_error
        super().__init__(message)
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""General-purpose object serialization for complex objects"""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from aegra_api.core.serializers.base import SerializationError, Serializer
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class GeneralSerializer(Serializer):
    """Simple object serializer for complex Python objects"""

    def serialize(self, obj: Any) -> Any:
        """Serialize any object to JSON-compatible format.

        Raises:
            SerializationError: wrapping whatever the conversion raised.
        """
        try:
            return self._serialize_object(obj)
        except Exception as e:
            raise SerializationError(f"Failed to serialize object: {str(e)}", obj.__class__.__name__, e) from e

    def _serialize_object(self, obj: Any) -> Any:
        """Recursively convert *obj* into JSON-safe primitives.

        Dispatch order matters: duck-typed checks (model_dump/dict/_asdict)
        run before generic container checks because NamedTuples are tuples.
        """
        # Pydantic v2 models expose model_dump().
        dump = getattr(obj, "model_dump", None)
        if callable(dump):
            return dump()

        # LangChain objects and Pydantic v1 models expose dict().
        as_dict = getattr(obj, "dict", None)
        if callable(as_dict):
            return as_dict()

        # LangGraph Interrupt objects (no .dict() method).
        if obj.__class__.__name__ == "Interrupt" and hasattr(obj, "value") and hasattr(obj, "id"):
            return {"value": self._serialize_object(obj.value), "id": obj.id}

        # NamedTuples (like PregelTask) expose _asdict().
        to_pairs = getattr(obj, "_asdict", None)
        if callable(to_pairs):
            return {key: self._serialize_object(value) for key, value in to_pairs().items()}

        # Sets have no JSON equivalent; degrade to lists.
        if isinstance(obj, (set, frozenset)):
            return list(obj)

        if isinstance(obj, (tuple, list)):
            return [self._serialize_object(element) for element in obj]

        if isinstance(obj, dict):
            return {key: self._serialize_object(value) for key, value in obj.items()}

        # Already JSON-safe scalars pass through untouched.
        if isinstance(obj, (str, int, float, bool, type(None))):
            return obj

        # Fallback to string representation for unknown types.
        return str(obj)
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
"""LangGraph-specific serialization"""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import structlog
|
|
7
|
+
|
|
8
|
+
from aegra_api.core.serializers.base import SerializationError, Serializer
|
|
9
|
+
from aegra_api.core.serializers.general import GeneralSerializer
|
|
10
|
+
|
|
11
|
+
# Use the canonical structlog API: get_logger. (getLogger is only a
# stdlib-compatibility alias; the rest of the codebase calls get_logger.)
logger = structlog.get_logger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class LangGraphSerializer(Serializer):
    """Handles serialization of LangGraph objects (tasks, interrupts, snapshots)"""

    def __init__(self):
        # Delegate unknown objects to the general-purpose serializer.
        self.general_serializer = GeneralSerializer()

    def serialize(self, obj: Any) -> Any:
        """Main serialization entry point.

        Round-trips through json so the result is guaranteed to be plain
        JSON-compatible data; non-JSON objects are handled by the general
        serializer via json.dumps(default=...).
        """
        return json.loads(json.dumps(obj, default=self.general_serializer.serialize))

    def serialize_task(self, task: Any) -> dict[str, Any]:
        """Serialize a LangGraph task to ThreadTask format.

        Raises:
            SerializationError: if the task cannot be converted to a dict.
        """
        try:
            if hasattr(task, "id") and hasattr(task, "name"):
                # Proper task object: project the known fields explicitly.
                task_dict = {
                    "id": getattr(task, "id", ""),
                    "name": getattr(task, "name", ""),
                    "error": getattr(task, "error", None),
                    "interrupts": [],
                    "checkpoint": None,
                    "state": getattr(task, "state", None),
                    "result": getattr(task, "result", None),
                }

                # Handle task interrupts
                if hasattr(task, "interrupts") and task.interrupts:
                    task_dict["interrupts"] = self.serialize(task.interrupts)

                return task_dict

            # Raw task data - serialize as-is but safely
            serialized_task = self.serialize(task)
            if isinstance(serialized_task, dict):
                return serialized_task
            raise SerializationError(
                f"Task serialization resulted in non-dict: {type(serialized_task)}",
                task.__class__.__name__,
            )
        except SerializationError:
            # Already our error type; don't re-wrap it.
            raise
        except Exception as e:
            raise SerializationError(f"Failed to serialize task: {str(e)}", task.__class__.__name__, e) from e

    def serialize_interrupt(self, interrupt: Any) -> dict[str, Any]:
        """Serialize a LangGraph interrupt.

        Raises:
            SerializationError: wrapping whatever the conversion raised.
        """
        try:
            return self.serialize(interrupt)
        except Exception as e:
            raise SerializationError(
                f"Failed to serialize interrupt: {str(e)}",
                interrupt.__class__.__name__,
                e,
            ) from e

    def extract_tasks_from_snapshot(self, snapshot: Any) -> list[dict[str, Any]]:
        """Extract and serialize tasks from a snapshot.

        Tasks that fail to serialize are logged and skipped rather than
        aborting the whole extraction.
        """
        tasks: list[dict[str, Any]] = []

        if not (hasattr(snapshot, "tasks") and snapshot.tasks):
            return tasks

        for task in snapshot.tasks:
            try:
                tasks.append(self.serialize_task(task))
            except SerializationError as e:
                logger.warning(
                    f"Task serialization failed, skipping task: {e} "
                    f"(task_type={type(task).__name__}, task_id={getattr(task, 'id', 'unknown')})"
                )
                continue

        return tasks

    def extract_interrupts_from_snapshot(self, snapshot: Any) -> list[dict[str, Any]]:
        """Extract and serialize interrupts from a snapshot.

        Always returns a list (possibly empty), even when serialization
        yields a non-list value or fails.
        """
        if not (hasattr(snapshot, "interrupts") and snapshot.interrupts):
            return []

        try:
            serialized = self.serialize(snapshot.interrupts)
        except Exception as e:
            logger.warning(
                f"Snapshot interrupt serialization failed: {e} (snapshot_type={type(snapshot).__name__})"
            )
            return []

        if isinstance(serialized, list):
            return serialized
        # Bug fix: previously a truthy non-list (e.g. a single serialized
        # interrupt dict) was returned as-is, violating the declared
        # list[dict] return type; wrap it so callers always get a list.
        return [serialized] if serialized else []
|