kairo-code 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- image-service/main.py +178 -0
- infra/chat/app/main.py +84 -0
- kairo/backend/__init__.py +0 -0
- kairo/backend/api/__init__.py +0 -0
- kairo/backend/api/admin/__init__.py +23 -0
- kairo/backend/api/admin/audit.py +54 -0
- kairo/backend/api/admin/content.py +142 -0
- kairo/backend/api/admin/incidents.py +148 -0
- kairo/backend/api/admin/stats.py +125 -0
- kairo/backend/api/admin/system.py +87 -0
- kairo/backend/api/admin/users.py +279 -0
- kairo/backend/api/agents.py +94 -0
- kairo/backend/api/api_keys.py +85 -0
- kairo/backend/api/auth.py +116 -0
- kairo/backend/api/billing.py +41 -0
- kairo/backend/api/chat.py +72 -0
- kairo/backend/api/conversations.py +125 -0
- kairo/backend/api/device_auth.py +100 -0
- kairo/backend/api/files.py +83 -0
- kairo/backend/api/health.py +36 -0
- kairo/backend/api/images.py +80 -0
- kairo/backend/api/openai_compat.py +225 -0
- kairo/backend/api/projects.py +102 -0
- kairo/backend/api/usage.py +32 -0
- kairo/backend/api/webhooks.py +79 -0
- kairo/backend/app.py +297 -0
- kairo/backend/config.py +179 -0
- kairo/backend/core/__init__.py +0 -0
- kairo/backend/core/admin_auth.py +24 -0
- kairo/backend/core/api_key_auth.py +55 -0
- kairo/backend/core/database.py +28 -0
- kairo/backend/core/dependencies.py +70 -0
- kairo/backend/core/logging.py +23 -0
- kairo/backend/core/rate_limit.py +73 -0
- kairo/backend/core/security.py +29 -0
- kairo/backend/models/__init__.py +19 -0
- kairo/backend/models/agent.py +30 -0
- kairo/backend/models/api_key.py +25 -0
- kairo/backend/models/api_usage.py +29 -0
- kairo/backend/models/audit_log.py +26 -0
- kairo/backend/models/conversation.py +48 -0
- kairo/backend/models/device_code.py +30 -0
- kairo/backend/models/feature_flag.py +21 -0
- kairo/backend/models/image_generation.py +24 -0
- kairo/backend/models/incident.py +28 -0
- kairo/backend/models/project.py +28 -0
- kairo/backend/models/uptime_record.py +24 -0
- kairo/backend/models/usage.py +24 -0
- kairo/backend/models/user.py +49 -0
- kairo/backend/schemas/__init__.py +0 -0
- kairo/backend/schemas/admin/__init__.py +0 -0
- kairo/backend/schemas/admin/audit.py +28 -0
- kairo/backend/schemas/admin/content.py +53 -0
- kairo/backend/schemas/admin/stats.py +77 -0
- kairo/backend/schemas/admin/system.py +44 -0
- kairo/backend/schemas/admin/users.py +48 -0
- kairo/backend/schemas/agent.py +42 -0
- kairo/backend/schemas/api_key.py +30 -0
- kairo/backend/schemas/auth.py +57 -0
- kairo/backend/schemas/chat.py +26 -0
- kairo/backend/schemas/conversation.py +39 -0
- kairo/backend/schemas/device_auth.py +40 -0
- kairo/backend/schemas/image.py +15 -0
- kairo/backend/schemas/openai_compat.py +76 -0
- kairo/backend/schemas/project.py +21 -0
- kairo/backend/schemas/status.py +81 -0
- kairo/backend/schemas/usage.py +15 -0
- kairo/backend/services/__init__.py +0 -0
- kairo/backend/services/admin/__init__.py +0 -0
- kairo/backend/services/admin/audit_service.py +78 -0
- kairo/backend/services/admin/content_service.py +119 -0
- kairo/backend/services/admin/incident_service.py +94 -0
- kairo/backend/services/admin/stats_service.py +281 -0
- kairo/backend/services/admin/system_service.py +126 -0
- kairo/backend/services/admin/user_service.py +157 -0
- kairo/backend/services/agent_service.py +107 -0
- kairo/backend/services/api_key_service.py +66 -0
- kairo/backend/services/api_usage_service.py +126 -0
- kairo/backend/services/auth_service.py +101 -0
- kairo/backend/services/chat_service.py +501 -0
- kairo/backend/services/conversation_service.py +264 -0
- kairo/backend/services/device_auth_service.py +193 -0
- kairo/backend/services/email_service.py +55 -0
- kairo/backend/services/image_service.py +181 -0
- kairo/backend/services/llm_service.py +186 -0
- kairo/backend/services/project_service.py +109 -0
- kairo/backend/services/status_service.py +167 -0
- kairo/backend/services/stripe_service.py +78 -0
- kairo/backend/services/usage_service.py +150 -0
- kairo/backend/services/web_search_service.py +96 -0
- kairo/migrations/env.py +60 -0
- kairo/migrations/versions/001_initial.py +55 -0
- kairo/migrations/versions/002_usage_tracking_and_indexes.py +66 -0
- kairo/migrations/versions/003_username_to_email.py +21 -0
- kairo/migrations/versions/004_add_plans_and_verification.py +67 -0
- kairo/migrations/versions/005_add_projects.py +52 -0
- kairo/migrations/versions/006_add_image_generation.py +63 -0
- kairo/migrations/versions/007_add_admin_portal.py +107 -0
- kairo/migrations/versions/008_add_device_code_auth.py +76 -0
- kairo/migrations/versions/009_add_status_page.py +65 -0
- kairo/tools/extract_claude_data.py +465 -0
- kairo/tools/filter_claude_data.py +303 -0
- kairo/tools/generate_curated_data.py +157 -0
- kairo/tools/mix_training_data.py +295 -0
- kairo_code/__init__.py +3 -0
- kairo_code/agents/__init__.py +25 -0
- kairo_code/agents/architect.py +98 -0
- kairo_code/agents/audit.py +100 -0
- kairo_code/agents/base.py +463 -0
- kairo_code/agents/coder.py +155 -0
- kairo_code/agents/database.py +77 -0
- kairo_code/agents/docs.py +88 -0
- kairo_code/agents/explorer.py +62 -0
- kairo_code/agents/guardian.py +80 -0
- kairo_code/agents/planner.py +66 -0
- kairo_code/agents/reviewer.py +91 -0
- kairo_code/agents/security.py +94 -0
- kairo_code/agents/terraform.py +88 -0
- kairo_code/agents/testing.py +97 -0
- kairo_code/agents/uiux.py +88 -0
- kairo_code/auth.py +232 -0
- kairo_code/config.py +172 -0
- kairo_code/conversation.py +173 -0
- kairo_code/heartbeat.py +63 -0
- kairo_code/llm.py +291 -0
- kairo_code/logging_config.py +156 -0
- kairo_code/main.py +818 -0
- kairo_code/router.py +217 -0
- kairo_code/sandbox.py +248 -0
- kairo_code/settings.py +183 -0
- kairo_code/tools/__init__.py +51 -0
- kairo_code/tools/analysis.py +509 -0
- kairo_code/tools/base.py +417 -0
- kairo_code/tools/code.py +58 -0
- kairo_code/tools/definitions.py +617 -0
- kairo_code/tools/files.py +315 -0
- kairo_code/tools/review.py +390 -0
- kairo_code/tools/search.py +185 -0
- kairo_code/ui.py +418 -0
- kairo_code-0.1.0.dist-info/METADATA +13 -0
- kairo_code-0.1.0.dist-info/RECORD +144 -0
- kairo_code-0.1.0.dist-info/WHEEL +5 -0
- kairo_code-0.1.0.dist-info/entry_points.txt +2 -0
- kairo_code-0.1.0.dist-info/top_level.txt +4 -0
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from datetime import datetime, timezone
|
|
3
|
+
|
|
4
|
+
from sqlalchemy import func, select
|
|
5
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
6
|
+
|
|
7
|
+
from backend.models.api_key import ApiKey
|
|
8
|
+
from backend.models.conversation import Conversation
|
|
9
|
+
from backend.models.image_generation import ImageGeneration
|
|
10
|
+
from backend.models.usage import UsageRecord
|
|
11
|
+
from backend.models.user import User
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)


def _escape_like(term: str, escape: str = "/") -> str:
    """Escape SQL LIKE wildcards in *term* so it is matched literally.

    The escape character itself is doubled first, then ``%`` and ``_`` are
    prefixed with it. The same *escape* must be passed to
    ``ilike(..., escape=...)``. ``/`` is used rather than a backslash to avoid
    backend-specific backslash handling in string literals.
    """
    return (
        term.replace(escape, escape + escape)
        .replace("%", escape + "%")
        .replace("_", escape + "_")
    )


class AdminUserService:
    """Admin-portal reads and writes over users and the records they own.

    Every method runs against the injected ``AsyncSession``. The ``change_*``
    mutators commit immediately and return the refreshed ``User``, or
    ``None`` when no user has the given id.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    async def list_users(
        self,
        search: str | None = None,
        plan: str | None = None,
        status: str | None = None,
        cursor: str | None = None,
        limit: int = 50,
    ) -> list[User]:
        """Return one page of users, newest first.

        Args:
            search: Case-insensitive substring matched against ``User.email``.
                LIKE wildcards (``%``, ``_``) in the term match literally.
            plan: Exact match on ``User.plan``.
            status: Exact match on ``User.status``.
            cursor: Id of the last user on the previous page. An id that does
                not exist is ignored and the first page is returned.
            limit: Maximum number of users to return.

        Returns:
            The matching users, ordered by ``created_at`` then ``id``, both
            descending.
        """
        from sqlalchemy import and_, or_

        # ``id`` breaks ties on ``created_at`` so the keyset cursor below never
        # skips users created in the same instant as the cursor user.
        stmt = select(User).order_by(User.created_at.desc(), User.id.desc())

        if search:
            pattern = f"%{_escape_like(search)}%"
            stmt = stmt.where(User.email.ilike(pattern, escape="/"))
        if plan:
            stmt = stmt.where(User.plan == plan)
        if status:
            stmt = stmt.where(User.status == status)

        if cursor:
            cursor_user = await self.db.get(User, cursor)
            if cursor_user:
                # Rows strictly "after" the cursor in (created_at DESC, id DESC) order.
                stmt = stmt.where(
                    or_(
                        User.created_at < cursor_user.created_at,
                        and_(
                            User.created_at == cursor_user.created_at,
                            User.id < cursor_user.id,
                        ),
                    )
                )

        stmt = stmt.limit(limit)
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def get_user_detail(self, user_id: str) -> User | None:
        """Return the user whose primary key is *user_id*, or ``None``."""
        return await self.db.get(User, user_id)

    async def change_plan(self, user_id: str, new_plan: str) -> User | None:
        """Set ``User.plan``; return the refreshed user or ``None`` if absent."""
        return await self._set_user_field(user_id, "plan", new_plan)

    async def change_status(self, user_id: str, new_status: str) -> User | None:
        """Set ``User.status``; return the refreshed user or ``None`` if absent."""
        return await self._set_user_field(user_id, "status", new_status)

    async def change_role(self, user_id: str, new_role: str) -> User | None:
        """Set ``User.role``; return the refreshed user or ``None`` if absent."""
        return await self._set_user_field(user_id, "role", new_role)

    async def _set_user_field(self, user_id: str, field: str, value: str) -> User | None:
        """Assign *value* to ``User.<field>``, commit, refresh and log it.

        This is the shared body of the ``change_*`` methods. *value* is written
        as given: no validation against allowed plans, statuses or roles
        happens here.
        """
        user = await self.db.get(User, user_id)
        if not user:
            return None
        setattr(user, field, value)
        await self.db.commit()
        await self.db.refresh(user)
        # Renders as e.g. "User <id> plan changed to <value>".
        logger.info("User %s %s changed to %s", user_id, field, value)
        return user

    async def get_user_usage(self, user_id: str) -> list[dict]:
        """Return daily token totals for *user_id* over a rolling 30-day window.

        Each item is ``{"date": str, "total_tokens": int}``, ordered by date
        ascending. Days without usage are omitted. The window starts exactly
        30 days before now (UTC), so the earliest day's bucket may be partial.
        """
        from datetime import timedelta  # not imported at module level in this file

        start_30d = datetime.now(timezone.utc) - timedelta(days=30)
        stmt = (
            select(
                func.date(UsageRecord.created_at).label("date"),
                func.sum(UsageRecord.prompt_tokens + UsageRecord.completion_tokens).label("total_tokens"),
            )
            .where(UsageRecord.user_id == user_id)
            .where(UsageRecord.created_at >= start_30d)
            .group_by(func.date(UsageRecord.created_at))
            .order_by(func.date(UsageRecord.created_at))
        )
        result = await self.db.execute(stmt)
        return [{"date": str(r.date), "total_tokens": r.total_tokens or 0} for r in result.all()]

    async def get_user_conversations(
        self,
        user_id: str,
        cursor: str | None = None,
        limit: int = 50,
    ) -> list[Conversation]:
        """Return a page of *user_id*'s conversations, most recently updated first.

        *cursor* is the id of the last conversation on the previous page. An
        unknown id is ignored.
        """
        stmt = (
            select(Conversation)
            .where(Conversation.user_id == user_id)
            .order_by(Conversation.updated_at.desc())
        )
        if cursor:
            cursor_conv = await self.db.get(Conversation, cursor)
            if cursor_conv:
                # NOTE(review): rows sharing cursor_conv.updated_at are skipped;
                # add a primary-key tiebreaker like list_users once the
                # Conversation PK column name is confirmed.
                stmt = stmt.where(Conversation.updated_at < cursor_conv.updated_at)
        stmt = stmt.limit(limit)
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def get_user_images(
        self,
        user_id: str,
        cursor: str | None = None,
        limit: int = 50,
    ) -> list[ImageGeneration]:
        """Return a page of *user_id*'s image generations, newest first.

        *cursor* is the id of the last image on the previous page. An unknown
        id is ignored.
        """
        stmt = (
            select(ImageGeneration)
            .where(ImageGeneration.user_id == user_id)
            .order_by(ImageGeneration.created_at.desc())
        )
        if cursor:
            cursor_img = await self.db.get(ImageGeneration, cursor)
            if cursor_img:
                # NOTE(review): same same-timestamp skipping caveat as
                # get_user_conversations.
                stmt = stmt.where(ImageGeneration.created_at < cursor_img.created_at)
        stmt = stmt.limit(limit)
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def get_user_api_keys(self, user_id: str) -> list[dict]:
        """Return *user_id*'s API keys as JSON-ready dicts, newest first.

        Only the stored ``key_prefix`` is exposed. Timestamps are ISO-8601
        strings, and optional ones are ``None`` when unset.
        """
        stmt = (
            select(ApiKey)
            .where(ApiKey.user_id == user_id)
            .order_by(ApiKey.created_at.desc())
        )
        result = await self.db.execute(stmt)
        keys = result.scalars().all()
        # Return masked keys — only show key_prefix
        return [
            {
                "id": k.id,
                "name": k.name,
                "key_prefix": k.key_prefix,
                "is_active": k.is_active,
                "last_used_at": k.last_used_at.isoformat() if k.last_used_at else None,
                "created_at": k.created_at.isoformat(),
                "expires_at": k.expires_at.isoformat() if k.expires_at else None,
            }
            for k in keys
        ]
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
from datetime import datetime, UTC
|
|
4
|
+
|
|
5
|
+
from sqlalchemy import select
|
|
6
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
7
|
+
|
|
8
|
+
from backend.models.agent import Agent
|
|
9
|
+
from backend.schemas.agent import RegisterAgentRequest, UpdateAgentRequest
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)


class AgentService:
    """CRUD and liveness tracking for user-owned agents.

    Single-agent lookups are always filtered by owner (``user_id``). An agent
    that belongs to someone else behaves exactly like a missing one.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    async def register(self, user_id: str, req: RegisterAgentRequest) -> Agent:
        """Persist a new agent owned by *user_id* and return the refreshed row.

        ``req.tools`` is stored JSON-encoded in ``tools_config``. A falsy value
        (``None`` or empty) is stored as ``None``.
        """
        encoded_tools = json.dumps(req.tools) if req.tools else None
        new_agent = Agent(
            user_id=user_id,
            name=req.name,
            description=req.description,
            system_prompt=req.system_prompt,
            model_preference=req.model_preference,
            tools_config=encoded_tools,
        )
        self.db.add(new_agent)
        await self.db.commit()
        await self.db.refresh(new_agent)
        logger.info("Agent registered: user=%s name=%s id=%s", user_id, req.name, new_agent.id)
        return new_agent

    async def list_agents(self, user_id: str) -> list[Agent]:
        """Return all agents owned by *user_id*, most recently created first."""
        query = select(Agent).where(Agent.user_id == user_id).order_by(Agent.created_at.desc())
        rows = await self.db.execute(query)
        return list(rows.scalars().all())

    async def get_agent(self, user_id: str, agent_id: str) -> Agent | None:
        """Return agent *agent_id* if *user_id* owns it, otherwise ``None``."""
        query = select(Agent).where(Agent.id == agent_id).where(Agent.user_id == user_id)
        rows = await self.db.execute(query)
        return rows.scalar_one_or_none()

    async def update_agent(self, user_id: str, agent_id: str, req: UpdateAgentRequest) -> Agent | None:
        """Apply every non-``None`` field of *req* to the caller's agent.

        ``req.tools`` is JSON-encoded into ``tools_config`` whenever it is not
        ``None``, so an empty collection is stored encoded here. ``updated_at``
        is always bumped. Returns the refreshed agent, or ``None`` when the
        agent is missing or owned by someone else.
        """
        target = await self.get_agent(user_id, agent_id)
        if not target:
            return None

        for attr in ("name", "description", "system_prompt", "model_preference"):
            value = getattr(req, attr)
            if value is not None:
                setattr(target, attr, value)
        if req.tools is not None:
            target.tools_config = json.dumps(req.tools)

        target.updated_at = datetime.now(UTC)
        await self.db.commit()
        await self.db.refresh(target)
        return target

    async def delete_agent(self, user_id: str, agent_id: str) -> bool:
        """Delete the caller's agent. Returns ``False`` if it was not found."""
        target = await self.get_agent(user_id, agent_id)
        if not target:
            return False
        await self.db.delete(target)
        await self.db.commit()
        logger.info("Agent deleted: id=%s user=%s", agent_id, user_id)
        return True

    async def heartbeat(self, agent_id: str, user_id: str, status: str) -> Agent | None:
        """Record a heartbeat and the reported *status* for an owned agent.

        *user_id* is the owner of the API key that sent the heartbeat. Returns
        the refreshed agent, or ``None`` if that user does not own *agent_id*.
        """
        target = await self.get_agent(user_id, agent_id)
        if not target:
            return None
        target.status = status
        target.last_heartbeat_at = datetime.now(UTC)
        target.updated_at = datetime.now(UTC)
        await self.db.commit()
        await self.db.refresh(target)
        return target

    async def mark_stale_agents_offline(self, threshold_seconds: int) -> int:
        """Flip non-offline agents with an old or missing heartbeat to ``"offline"``.

        An agent is stale when ``last_heartbeat_at`` is ``NULL`` or older than
        *threshold_seconds* before now. The session is committed only if at
        least one agent changed. Returns the number of agents updated.
        """
        from datetime import timedelta

        cutoff = datetime.now(UTC) - timedelta(seconds=threshold_seconds)
        is_stale = (Agent.last_heartbeat_at < cutoff) | (Agent.last_heartbeat_at.is_(None))
        query = select(Agent).where(Agent.status != "offline").where(is_stale)

        rows = await self.db.execute(query)
        stale_agents = list(rows.scalars().all())
        for stale in stale_agents:
            stale.status = "offline"
            stale.updated_at = datetime.now(UTC)

        if not stale_agents:
            return 0
        await self.db.commit()
        logger.info("Marked %d agents as offline", len(stale_agents))
        return len(stale_agents)
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import logging
|
|
3
|
+
import uuid
|
|
4
|
+
from datetime import datetime, timedelta, UTC
|
|
5
|
+
|
|
6
|
+
from sqlalchemy import select
|
|
7
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
8
|
+
|
|
9
|
+
from backend.models.api_key import ApiKey
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _generate_raw_key() -> str:
|
|
15
|
+
return f"sk-kairo-{uuid.uuid4().hex}"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ApiKeyService:
    """Issue, list and revoke per-user API keys.

    Each key is persisted only as a SHA-256 hex digest plus its first 20
    characters (for display). The plaintext key is available only in the
    return value of :meth:`create_key`.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    async def create_key(
        self, user_id: str, name: str, expires_in_days: int | None = None
    ) -> tuple[ApiKey, str]:
        """Create and store a new API key for *user_id*.

        Args:
            user_id: Owner of the key.
            name: Human-readable label.
            expires_in_days: Lifetime in days from now (UTC). Any falsy value
                (``None`` or ``0``) creates a key with no expiry.

        Returns:
            ``(record, raw_key)``. Show ``raw_key`` to the user once; it is not
            recoverable from the stored record.
        """
        raw_key = _generate_raw_key()
        key_prefix = raw_key[:20]
        digest = hashlib.sha256(raw_key.encode()).hexdigest()

        expires_at = None
        if expires_in_days:
            expires_at = datetime.now(UTC) + timedelta(days=expires_in_days)

        record = ApiKey(
            user_id=user_id,
            name=name,
            key_prefix=key_prefix,
            key_hash=digest,
            expires_at=expires_at,
        )
        self.db.add(record)
        await self.db.commit()
        await self.db.refresh(record)
        logger.info("API key created: user=%s name=%s prefix=%s", user_id, name, key_prefix)
        return record, raw_key

    async def list_keys(self, user_id: str) -> list[ApiKey]:
        """Return all of *user_id*'s keys (active and revoked), newest first."""
        query = select(ApiKey).where(ApiKey.user_id == user_id).order_by(ApiKey.created_at.desc())
        rows = await self.db.execute(query)
        return list(rows.scalars().all())

    async def revoke_key(self, user_id: str, key_id: str) -> bool:
        """Deactivate key *key_id* if *user_id* owns it.

        The row is kept with ``is_active = False``. Returns ``False`` when no
        such key exists for this user. Revoking an already-inactive key still
        returns ``True``.
        """
        query = select(ApiKey).where(ApiKey.id == key_id).where(ApiKey.user_id == user_id)
        key = (await self.db.execute(query)).scalar_one_or_none()
        if not key:
            return False
        key.is_active = False
        await self.db.commit()
        logger.info("API key revoked: key=%s user=%s", key_id, user_id)
        return True
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from datetime import datetime, timezone, timedelta
|
|
3
|
+
|
|
4
|
+
from sqlalchemy import select, func
|
|
5
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
6
|
+
|
|
7
|
+
from backend.config import settings
|
|
8
|
+
from backend.models.api_usage import ApiUsageRecord
|
|
9
|
+
from backend.models.user import PlanType, User
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)

# Per-plan API token budgets, keyed by the plan's *string value*
# (``PlanType.<X>.value``). Plans not listed here fall back to the FREE
# budget in ``_get_api_limits``.
API_PLAN_LIMITS: dict[str, dict[str, int]] = {
    PlanType.FREE.value: {
        "daily": settings.API_DAILY_TOKEN_LIMIT,
        "monthly": settings.API_MONTHLY_TOKEN_LIMIT,
    },
    PlanType.PRO.value: {
        "daily": settings.API_PRO_DAILY_TOKEN_LIMIT,
        "monthly": settings.API_PRO_MONTHLY_TOKEN_LIMIT,
    },
}


def _get_api_limits(plan: str) -> tuple[int, int]:
    """Return ``(daily_limit, monthly_limit)`` in tokens for *plan*.

    *plan* may be either the plan's string value or a ``PlanType`` member.
    Members are normalized through ``.value`` because the table above is
    keyed by values. Without this, an enum member would miss the lookup and
    silently get FREE limits. Unknown plans also get the FREE limits.
    """
    # NOTE(review): depends on how User.plan is mapped (str vs PlanType) — confirm.
    key = getattr(plan, "value", plan)
    limits = API_PLAN_LIMITS.get(key, API_PLAN_LIMITS[PlanType.FREE.value])
    return limits["daily"], limits["monthly"]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ApiUsageService:
    """Record API token usage and enforce per-plan daily/monthly budgets."""

    def __init__(self, db: AsyncSession):
        self.db = db

    async def record(
        self,
        api_key_id: str,
        user_id: str,
        model: str,
        prompt_tokens: int,
        completion_tokens: int,
        endpoint: str,
        agent_id: str | None = None,
    ) -> ApiUsageRecord:
        """Persist one API call's token counts, commit, and return the refreshed row."""
        record = ApiUsageRecord(
            api_key_id=api_key_id,
            user_id=user_id,
            agent_id=agent_id,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            endpoint=endpoint,
        )
        self.db.add(record)
        await self.db.commit()
        await self.db.refresh(record)
        logger.info(
            "API usage recorded: key=%s user=%s prompt=%d completion=%d",
            api_key_id, user_id, prompt_tokens, completion_tokens,
        )
        return record

    async def _tokens_since(self, user_id: str, since: datetime) -> int:
        """Sum prompt + completion tokens recorded for *user_id* at or after *since*.

        Returns 0 when there are no matching rows. The result is coerced to
        ``int`` because some backends return ``Decimal`` for ``SUM``.
        """
        stmt = (
            select(
                func.coalesce(
                    func.sum(ApiUsageRecord.prompt_tokens + ApiUsageRecord.completion_tokens), 0
                )
            )
            .where(ApiUsageRecord.user_id == user_id)
            .where(ApiUsageRecord.created_at >= since)
        )
        result = await self.db.execute(stmt)
        return int(result.scalar() or 0)

    async def get_daily_usage(self, user_id: str) -> int:
        """Return tokens used by *user_id* since 00:00 UTC today."""
        now = datetime.now(timezone.utc)
        start_of_day = now.replace(hour=0, minute=0, second=0, microsecond=0)
        return await self._tokens_since(user_id, start_of_day)

    async def get_monthly_usage(self, user_id: str) -> int:
        """Return tokens used by *user_id* since 00:00 UTC on the 1st of this month."""
        now = datetime.now(timezone.utc)
        start_of_month = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        return await self._tokens_since(user_id, start_of_month)

    async def check_limits(self, user_id: str) -> tuple[bool, str]:
        """Check *user_id* against their plan's API token budgets.

        Returns:
            ``(True, "")`` when under both limits. Otherwise ``(False, reason)``,
            where the daily limit is checked before the monthly one. A missing
            user is also reported as ``(False, "User not found")``.
        """
        user = await self.db.get(User, user_id)
        if not user:
            return False, "User not found"

        daily_limit, monthly_limit = _get_api_limits(user.plan)

        daily = await self.get_daily_usage(user_id)
        if daily >= daily_limit:
            return False, "API daily token limit reached."

        monthly = await self.get_monthly_usage(user_id)
        if monthly >= monthly_limit:
            return False, "API monthly token limit reached."

        return True, ""

    async def get_key_usage_summary(self, api_key_id: str, days: int = 30) -> list[dict]:
        """Return per-day token and request counts for one API key.

        Covers the last *days* days, measured as a rolling window from now
        (UTC). Items are ``{"date", "total_tokens", "requests"}``, ordered by
        date ascending. Days without traffic are omitted.
        """
        start = datetime.now(timezone.utc) - timedelta(days=days)
        stmt = (
            select(
                func.date(ApiUsageRecord.created_at).label("date"),
                func.sum(ApiUsageRecord.prompt_tokens + ApiUsageRecord.completion_tokens).label("total_tokens"),
                func.count().label("request_count"),
            )
            .where(ApiUsageRecord.api_key_id == api_key_id)
            .where(ApiUsageRecord.created_at >= start)
            .group_by(func.date(ApiUsageRecord.created_at))
            .order_by(func.date(ApiUsageRecord.created_at))
        )
        result = await self.db.execute(stmt)
        return [
            {"date": str(r.date), "total_tokens": r.total_tokens or 0, "requests": r.request_count}
            for r in result.all()
        ]
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import secrets
|
|
3
|
+
from datetime import datetime, timedelta, timezone
|
|
4
|
+
|
|
5
|
+
from sqlalchemy import select
|
|
6
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
7
|
+
|
|
8
|
+
from backend.core.security import hash_password, verify_password
|
|
9
|
+
from backend.models.user import User
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)


class AuthService:
    """Email/password registration, login, email verification and password reset."""

    def __init__(self, db: AsyncSession):
        self.db = db

    async def register(self, email: str, password: str) -> User | None:
        """Create a new user with a verification token. Returns None if email taken.

        The existence pre-check is not atomic, so two concurrent registrations
        for the same email can both pass it. If the commit then fails with
        ``IntegrityError`` and a user with *email* now exists, the race is
        reported as "email taken" (``None``). Any other integrity failure is
        re-raised after rolling back.
        """
        from sqlalchemy.exc import IntegrityError

        if await self._get_by_email(email):
            return None
        token = secrets.token_urlsafe(32)
        user = User(
            email=email,
            hashed_password=hash_password(password),
            email_verification_token=token,
        )
        self.db.add(user)
        try:
            await self.db.commit()
        except IntegrityError:
            await self.db.rollback()
            # NOTE(review): assumes a unique constraint on User.email — confirm.
            if await self._get_by_email(email):
                return None
            raise
        await self.db.refresh(user)
        logger.info("Registered user %s (id=%s)", email, user.id)
        return user

    async def authenticate(self, email: str, password: str) -> User | None:
        """Verify credentials. Returns the User on success, otherwise None."""
        user = await self._get_by_email(email)
        # NOTE(review): when the email is unknown no password hash is checked,
        # so response time may reveal whether an account exists.
        if not user or not verify_password(password, user.hashed_password):
            return None
        return user

    async def get_user_by_id(self, user_id: str) -> User | None:
        """Return the user with id *user_id*, or ``None``."""
        result = await self.db.execute(select(User).where(User.id == user_id))
        return result.scalar_one_or_none()

    async def verify_email(self, token: str) -> User | None:
        """Mark the token's owner as verified and consume the token.

        Returns the refreshed user, or ``None`` if no user holds *token*.
        """
        result = await self.db.execute(
            select(User).where(User.email_verification_token == token)
        )
        user = result.scalar_one_or_none()
        if not user:
            return None
        user.email_verified = True
        user.email_verification_token = None
        await self.db.commit()
        await self.db.refresh(user)
        logger.info("Email verified for user %s", user.email)
        return user

    async def request_password_reset(self, email: str) -> str | None:
        """Generate a reset token that expires in one hour (UTC).

        Returns the token, or ``None`` if no user has *email*.
        """
        user = await self._get_by_email(email)
        if not user:
            return None
        token = secrets.token_urlsafe(32)
        user.password_reset_token = token
        user.password_reset_expires = datetime.now(timezone.utc) + timedelta(hours=1)
        await self.db.commit()
        return token

    async def reset_password(self, token: str, new_password: str) -> bool:
        """Set a new password using an unexpired reset token.

        The token and its expiry are cleared on success. Returns ``False``
        when the token is unknown, has no expiry, or has expired.
        """
        result = await self.db.execute(
            select(User).where(User.password_reset_token == token)
        )
        user = result.scalar_one_or_none()
        if not user:
            return False

        expires = user.password_reset_expires
        if expires is None:
            return False
        if expires.tzinfo is None:
            # Some backends (e.g. SQLite) may hand back naive datetimes even
            # though request_password_reset stored an aware UTC value.
            # Comparing naive with aware would raise TypeError.
            expires = expires.replace(tzinfo=timezone.utc)
        if expires < datetime.now(timezone.utc):
            return False

        user.hashed_password = hash_password(new_password)
        user.password_reset_token = None
        user.password_reset_expires = None
        await self.db.commit()
        logger.info("Password reset for user %s", user.email)
        return True

    async def regenerate_verification_token(self, user_id: str) -> str | None:
        """Issue a fresh verification token for resending.

        Returns ``None`` if the user is missing or already verified.
        """
        user = await self.get_user_by_id(user_id)
        if not user or user.email_verified:
            return None
        token = secrets.token_urlsafe(32)
        user.email_verification_token = token
        await self.db.commit()
        return token

    async def _get_by_email(self, email: str) -> User | None:
        """Return the user whose email equals *email* exactly, or ``None``."""
        result = await self.db.execute(select(User).where(User.email == email))
        return result.scalar_one_or_none()
|