kairo-code 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. image-service/main.py +178 -0
  2. infra/chat/app/main.py +84 -0
  3. kairo/backend/__init__.py +0 -0
  4. kairo/backend/api/__init__.py +0 -0
  5. kairo/backend/api/admin/__init__.py +23 -0
  6. kairo/backend/api/admin/audit.py +54 -0
  7. kairo/backend/api/admin/content.py +142 -0
  8. kairo/backend/api/admin/incidents.py +148 -0
  9. kairo/backend/api/admin/stats.py +125 -0
  10. kairo/backend/api/admin/system.py +87 -0
  11. kairo/backend/api/admin/users.py +279 -0
  12. kairo/backend/api/agents.py +94 -0
  13. kairo/backend/api/api_keys.py +85 -0
  14. kairo/backend/api/auth.py +116 -0
  15. kairo/backend/api/billing.py +41 -0
  16. kairo/backend/api/chat.py +72 -0
  17. kairo/backend/api/conversations.py +125 -0
  18. kairo/backend/api/device_auth.py +100 -0
  19. kairo/backend/api/files.py +83 -0
  20. kairo/backend/api/health.py +36 -0
  21. kairo/backend/api/images.py +80 -0
  22. kairo/backend/api/openai_compat.py +225 -0
  23. kairo/backend/api/projects.py +102 -0
  24. kairo/backend/api/usage.py +32 -0
  25. kairo/backend/api/webhooks.py +79 -0
  26. kairo/backend/app.py +297 -0
  27. kairo/backend/config.py +179 -0
  28. kairo/backend/core/__init__.py +0 -0
  29. kairo/backend/core/admin_auth.py +24 -0
  30. kairo/backend/core/api_key_auth.py +55 -0
  31. kairo/backend/core/database.py +28 -0
  32. kairo/backend/core/dependencies.py +70 -0
  33. kairo/backend/core/logging.py +23 -0
  34. kairo/backend/core/rate_limit.py +73 -0
  35. kairo/backend/core/security.py +29 -0
  36. kairo/backend/models/__init__.py +19 -0
  37. kairo/backend/models/agent.py +30 -0
  38. kairo/backend/models/api_key.py +25 -0
  39. kairo/backend/models/api_usage.py +29 -0
  40. kairo/backend/models/audit_log.py +26 -0
  41. kairo/backend/models/conversation.py +48 -0
  42. kairo/backend/models/device_code.py +30 -0
  43. kairo/backend/models/feature_flag.py +21 -0
  44. kairo/backend/models/image_generation.py +24 -0
  45. kairo/backend/models/incident.py +28 -0
  46. kairo/backend/models/project.py +28 -0
  47. kairo/backend/models/uptime_record.py +24 -0
  48. kairo/backend/models/usage.py +24 -0
  49. kairo/backend/models/user.py +49 -0
  50. kairo/backend/schemas/__init__.py +0 -0
  51. kairo/backend/schemas/admin/__init__.py +0 -0
  52. kairo/backend/schemas/admin/audit.py +28 -0
  53. kairo/backend/schemas/admin/content.py +53 -0
  54. kairo/backend/schemas/admin/stats.py +77 -0
  55. kairo/backend/schemas/admin/system.py +44 -0
  56. kairo/backend/schemas/admin/users.py +48 -0
  57. kairo/backend/schemas/agent.py +42 -0
  58. kairo/backend/schemas/api_key.py +30 -0
  59. kairo/backend/schemas/auth.py +57 -0
  60. kairo/backend/schemas/chat.py +26 -0
  61. kairo/backend/schemas/conversation.py +39 -0
  62. kairo/backend/schemas/device_auth.py +40 -0
  63. kairo/backend/schemas/image.py +15 -0
  64. kairo/backend/schemas/openai_compat.py +76 -0
  65. kairo/backend/schemas/project.py +21 -0
  66. kairo/backend/schemas/status.py +81 -0
  67. kairo/backend/schemas/usage.py +15 -0
  68. kairo/backend/services/__init__.py +0 -0
  69. kairo/backend/services/admin/__init__.py +0 -0
  70. kairo/backend/services/admin/audit_service.py +78 -0
  71. kairo/backend/services/admin/content_service.py +119 -0
  72. kairo/backend/services/admin/incident_service.py +94 -0
  73. kairo/backend/services/admin/stats_service.py +281 -0
  74. kairo/backend/services/admin/system_service.py +126 -0
  75. kairo/backend/services/admin/user_service.py +157 -0
  76. kairo/backend/services/agent_service.py +107 -0
  77. kairo/backend/services/api_key_service.py +66 -0
  78. kairo/backend/services/api_usage_service.py +126 -0
  79. kairo/backend/services/auth_service.py +101 -0
  80. kairo/backend/services/chat_service.py +501 -0
  81. kairo/backend/services/conversation_service.py +264 -0
  82. kairo/backend/services/device_auth_service.py +193 -0
  83. kairo/backend/services/email_service.py +55 -0
  84. kairo/backend/services/image_service.py +181 -0
  85. kairo/backend/services/llm_service.py +186 -0
  86. kairo/backend/services/project_service.py +109 -0
  87. kairo/backend/services/status_service.py +167 -0
  88. kairo/backend/services/stripe_service.py +78 -0
  89. kairo/backend/services/usage_service.py +150 -0
  90. kairo/backend/services/web_search_service.py +96 -0
  91. kairo/migrations/env.py +60 -0
  92. kairo/migrations/versions/001_initial.py +55 -0
  93. kairo/migrations/versions/002_usage_tracking_and_indexes.py +66 -0
  94. kairo/migrations/versions/003_username_to_email.py +21 -0
  95. kairo/migrations/versions/004_add_plans_and_verification.py +67 -0
  96. kairo/migrations/versions/005_add_projects.py +52 -0
  97. kairo/migrations/versions/006_add_image_generation.py +63 -0
  98. kairo/migrations/versions/007_add_admin_portal.py +107 -0
  99. kairo/migrations/versions/008_add_device_code_auth.py +76 -0
  100. kairo/migrations/versions/009_add_status_page.py +65 -0
  101. kairo/tools/extract_claude_data.py +465 -0
  102. kairo/tools/filter_claude_data.py +303 -0
  103. kairo/tools/generate_curated_data.py +157 -0
  104. kairo/tools/mix_training_data.py +295 -0
  105. kairo_code/__init__.py +3 -0
  106. kairo_code/agents/__init__.py +25 -0
  107. kairo_code/agents/architect.py +98 -0
  108. kairo_code/agents/audit.py +100 -0
  109. kairo_code/agents/base.py +463 -0
  110. kairo_code/agents/coder.py +155 -0
  111. kairo_code/agents/database.py +77 -0
  112. kairo_code/agents/docs.py +88 -0
  113. kairo_code/agents/explorer.py +62 -0
  114. kairo_code/agents/guardian.py +80 -0
  115. kairo_code/agents/planner.py +66 -0
  116. kairo_code/agents/reviewer.py +91 -0
  117. kairo_code/agents/security.py +94 -0
  118. kairo_code/agents/terraform.py +88 -0
  119. kairo_code/agents/testing.py +97 -0
  120. kairo_code/agents/uiux.py +88 -0
  121. kairo_code/auth.py +232 -0
  122. kairo_code/config.py +172 -0
  123. kairo_code/conversation.py +173 -0
  124. kairo_code/heartbeat.py +63 -0
  125. kairo_code/llm.py +291 -0
  126. kairo_code/logging_config.py +156 -0
  127. kairo_code/main.py +818 -0
  128. kairo_code/router.py +217 -0
  129. kairo_code/sandbox.py +248 -0
  130. kairo_code/settings.py +183 -0
  131. kairo_code/tools/__init__.py +51 -0
  132. kairo_code/tools/analysis.py +509 -0
  133. kairo_code/tools/base.py +417 -0
  134. kairo_code/tools/code.py +58 -0
  135. kairo_code/tools/definitions.py +617 -0
  136. kairo_code/tools/files.py +315 -0
  137. kairo_code/tools/review.py +390 -0
  138. kairo_code/tools/search.py +185 -0
  139. kairo_code/ui.py +418 -0
  140. kairo_code-0.1.0.dist-info/METADATA +13 -0
  141. kairo_code-0.1.0.dist-info/RECORD +144 -0
  142. kairo_code-0.1.0.dist-info/WHEEL +5 -0
  143. kairo_code-0.1.0.dist-info/entry_points.txt +2 -0
  144. kairo_code-0.1.0.dist-info/top_level.txt +4 -0
@@ -0,0 +1,157 @@
1
+ import logging
2
+ from datetime import datetime, timezone
3
+
4
+ from sqlalchemy import func, select
5
+ from sqlalchemy.ext.asyncio import AsyncSession
6
+
7
+ from backend.models.api_key import ApiKey
8
+ from backend.models.conversation import Conversation
9
+ from backend.models.image_generation import ImageGeneration
10
+ from backend.models.usage import UsageRecord
11
+ from backend.models.user import User
12
+
13
logger = logging.getLogger(__name__)


class AdminUserService:
    """Admin-facing queries and mutations over users and their related records.

    List methods use keyset ("cursor") pagination: ``cursor`` is the id of the
    last item of the previous page, and the next page holds the items that sort
    strictly after it in the method's descending order. An unknown cursor id is
    ignored, which yields the first page.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    @staticmethod
    def _escape_like(term: str) -> str:
        """Escape LIKE wildcards in ``term`` so it matches literally (escape char: backslash)."""
        return term.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    @staticmethod
    def _after_cursor(sort_col, id_col, cursor_sort_value, cursor_id):
        """Build a predicate selecting rows after the cursor in ``(sort_col DESC, id_col DESC)`` order.

        Comparing on ``sort_col`` alone would silently drop rows whose sort value
        ties with the cursor row's; ``id_col`` breaks those ties deterministically.
        """
        # NOTE(review): assumes each paginated model's primary-key attribute is named ``id``.
        return (sort_col < cursor_sort_value) | (
            (sort_col == cursor_sort_value) & (id_col < cursor_id)
        )

    async def list_users(
        self,
        search: str | None = None,
        plan: str | None = None,
        status: str | None = None,
        cursor: str | None = None,
        limit: int = 50,
    ) -> list[User]:
        """Return up to ``limit`` users, newest first, optionally filtered.

        Args:
            search: case-insensitive substring matched against ``User.email``;
                ``%`` and ``_`` in it are matched literally.
            plan: exact match on ``User.plan``.
            status: exact match on ``User.status``.
            cursor: id of the last user from the previous page.
            limit: maximum number of users returned.
        """
        stmt = select(User).order_by(User.created_at.desc(), User.id.desc())

        if search:
            pattern = f"%{self._escape_like(search)}%"
            stmt = stmt.where(User.email.ilike(pattern, escape="\\"))
        if plan:
            stmt = stmt.where(User.plan == plan)
        if status:
            stmt = stmt.where(User.status == status)

        if cursor:
            cursor_user = await self.db.get(User, cursor)
            if cursor_user:
                stmt = stmt.where(
                    self._after_cursor(
                        User.created_at, User.id, cursor_user.created_at, cursor_user.id
                    )
                )

        stmt = stmt.limit(limit)
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def get_user_detail(self, user_id: str) -> User | None:
        """Return the user with primary key ``user_id``, or None if absent."""
        return await self.db.get(User, user_id)

    async def change_plan(self, user_id: str, new_plan: str) -> User | None:
        """Set ``user.plan``, commit, and return the refreshed user (None if not found)."""
        user = await self.db.get(User, user_id)
        if not user:
            return None
        user.plan = new_plan
        await self.db.commit()
        await self.db.refresh(user)
        logger.info("User %s plan changed to %s", user_id, new_plan)
        return user

    async def change_status(self, user_id: str, new_status: str) -> User | None:
        """Set ``user.status``, commit, and return the refreshed user (None if not found)."""
        user = await self.db.get(User, user_id)
        if not user:
            return None
        user.status = new_status
        await self.db.commit()
        await self.db.refresh(user)
        logger.info("User %s status changed to %s", user_id, new_status)
        return user

    async def change_role(self, user_id: str, new_role: str) -> User | None:
        """Set ``user.role``, commit, and return the refreshed user (None if not found)."""
        user = await self.db.get(User, user_id)
        if not user:
            return None
        user.role = new_role
        await self.db.commit()
        await self.db.refresh(user)
        logger.info("User %s role changed to %s", user_id, new_role)
        return user

    async def get_user_usage(self, user_id: str) -> list[dict]:
        """Return per-day token totals for the user over the last 30 days.

        Each item is ``{"date": str(<db date>), "total_tokens": int}``, ordered by
        date ascending. Days without any usage records are omitted.
        """
        from datetime import timedelta

        since = datetime.now(timezone.utc) - timedelta(days=30)
        day = func.date(UsageRecord.created_at)
        stmt = (
            select(
                day.label("date"),
                func.sum(UsageRecord.prompt_tokens + UsageRecord.completion_tokens).label("total_tokens"),
            )
            .where(UsageRecord.user_id == user_id)
            .where(UsageRecord.created_at >= since)
            .group_by(day)
            .order_by(day)
        )
        result = await self.db.execute(stmt)
        return [{"date": str(r.date), "total_tokens": r.total_tokens or 0} for r in result.all()]

    async def get_user_conversations(
        self,
        user_id: str,
        cursor: str | None = None,
        limit: int = 50,
    ) -> list[Conversation]:
        """Return up to ``limit`` of the user's conversations, most recently updated first."""
        stmt = (
            select(Conversation)
            .where(Conversation.user_id == user_id)
            .order_by(Conversation.updated_at.desc(), Conversation.id.desc())
        )
        if cursor:
            cursor_conv = await self.db.get(Conversation, cursor)
            if cursor_conv:
                stmt = stmt.where(
                    self._after_cursor(
                        Conversation.updated_at,
                        Conversation.id,
                        cursor_conv.updated_at,
                        cursor_conv.id,
                    )
                )
        stmt = stmt.limit(limit)
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def get_user_images(
        self,
        user_id: str,
        cursor: str | None = None,
        limit: int = 50,
    ) -> list[ImageGeneration]:
        """Return up to ``limit`` of the user's image generations, newest first."""
        stmt = (
            select(ImageGeneration)
            .where(ImageGeneration.user_id == user_id)
            .order_by(ImageGeneration.created_at.desc(), ImageGeneration.id.desc())
        )
        if cursor:
            cursor_img = await self.db.get(ImageGeneration, cursor)
            if cursor_img:
                stmt = stmt.where(
                    self._after_cursor(
                        ImageGeneration.created_at,
                        ImageGeneration.id,
                        cursor_img.created_at,
                        cursor_img.id,
                    )
                )
        stmt = stmt.limit(limit)
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def get_user_api_keys(self, user_id: str) -> list[dict]:
        """Return the user's API keys, newest first, as JSON-ready dicts.

        Only the stored ``key_prefix`` is exposed; timestamps are ISO-8601 strings
        (or None where the column is unset).
        """
        stmt = (
            select(ApiKey)
            .where(ApiKey.user_id == user_id)
            .order_by(ApiKey.created_at.desc())
        )
        result = await self.db.execute(stmt)
        keys = result.scalars().all()
        return [
            {
                "id": k.id,
                "name": k.name,
                "key_prefix": k.key_prefix,
                "is_active": k.is_active,
                "last_used_at": k.last_used_at.isoformat() if k.last_used_at else None,
                "created_at": k.created_at.isoformat(),
                "expires_at": k.expires_at.isoformat() if k.expires_at else None,
            }
            for k in keys
        ]
@@ -0,0 +1,107 @@
1
+ import json
2
+ import logging
3
+ from datetime import datetime, UTC
4
+
5
+ from sqlalchemy import select
6
+ from sqlalchemy.ext.asyncio import AsyncSession
7
+
8
+ from backend.models.agent import Agent
9
+ from backend.schemas.agent import RegisterAgentRequest, UpdateAgentRequest
10
+
11
logger = logging.getLogger(__name__)


class AgentService:
    """CRUD and liveness tracking for a user's registered agents.

    Reads and writes of a single agent go through ``get_agent``, which filters
    by both ``agent_id`` and ``user_id``.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    async def register(self, user_id: str, req: RegisterAgentRequest) -> Agent:
        """Create and persist an agent owned by ``user_id``.

        ``req.tools`` is stored JSON-encoded in ``tools_config``; an empty or
        missing tools value is stored as None.
        """
        agent = Agent(
            user_id=user_id,
            name=req.name,
            description=req.description,
            system_prompt=req.system_prompt,
            model_preference=req.model_preference,
            tools_config=json.dumps(req.tools) if req.tools else None,
        )
        self.db.add(agent)
        await self.db.commit()
        await self.db.refresh(agent)
        logger.info("Agent registered: user=%s name=%s id=%s", user_id, req.name, agent.id)
        return agent

    async def list_agents(self, user_id: str) -> list[Agent]:
        """Return all of the user's agents, newest first."""
        stmt = (
            select(Agent)
            .where(Agent.user_id == user_id)
            .order_by(Agent.created_at.desc())
        )
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def get_agent(self, user_id: str, agent_id: str) -> Agent | None:
        """Return the agent if it exists AND belongs to ``user_id``; otherwise None."""
        stmt = select(Agent).where(Agent.id == agent_id, Agent.user_id == user_id)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def update_agent(self, user_id: str, agent_id: str, req: UpdateAgentRequest) -> Agent | None:
        """Apply the non-None fields of ``req`` to the agent; None if not found.

        Unlike ``register``, an explicitly empty ``req.tools`` is stored as its
        JSON encoding rather than None.
        """
        agent = await self.get_agent(user_id, agent_id)
        if not agent:
            return None
        if req.name is not None:
            agent.name = req.name
        if req.description is not None:
            agent.description = req.description
        if req.system_prompt is not None:
            agent.system_prompt = req.system_prompt
        if req.model_preference is not None:
            agent.model_preference = req.model_preference
        if req.tools is not None:
            agent.tools_config = json.dumps(req.tools)
        agent.updated_at = datetime.now(UTC)
        await self.db.commit()
        await self.db.refresh(agent)
        return agent

    async def delete_agent(self, user_id: str, agent_id: str) -> bool:
        """Delete the user's agent. Returns False if it was not found."""
        agent = await self.get_agent(user_id, agent_id)
        if not agent:
            return False
        await self.db.delete(agent)
        await self.db.commit()
        logger.info("Agent deleted: id=%s user=%s", agent_id, user_id)
        return True

    async def heartbeat(self, agent_id: str, user_id: str, status: str) -> Agent | None:
        """Update agent heartbeat. user_id comes from the API key owner.

        Sets ``status`` and stamps ``last_heartbeat_at``/``updated_at`` with the
        same instant. Returns None if the agent is not owned by ``user_id``.
        """
        agent = await self.get_agent(user_id, agent_id)
        if not agent:
            return None
        # One timestamp so both columns describe the same event.
        now = datetime.now(UTC)
        agent.status = status
        agent.last_heartbeat_at = now
        agent.updated_at = now
        await self.db.commit()
        await self.db.refresh(agent)
        return agent

    async def mark_stale_agents_offline(self, threshold_seconds: int) -> int:
        """Mark agents as offline if no heartbeat within threshold. Returns count.

        An agent is stale when its last heartbeat is older than
        ``threshold_seconds`` or it has never sent one. Commits only when at
        least one agent changed.
        """
        from datetime import timedelta

        now = datetime.now(UTC)
        cutoff = now - timedelta(seconds=threshold_seconds)
        # NOTE(review): in SQL, NULL != 'offline' is not true, so agents whose
        # status is NULL are not swept here — confirm status is non-nullable.
        stmt = (
            select(Agent)
            .where(Agent.status != "offline")
            .where(
                (Agent.last_heartbeat_at < cutoff) | (Agent.last_heartbeat_at.is_(None))
            )
        )
        result = await self.db.execute(stmt)
        agents = list(result.scalars().all())
        for agent in agents:
            agent.status = "offline"
            agent.updated_at = now
        if agents:
            await self.db.commit()
            logger.info("Marked %d agents as offline", len(agents))
        return len(agents)
@@ -0,0 +1,66 @@
1
+ import hashlib
2
+ import logging
3
+ import uuid
4
+ from datetime import datetime, timedelta, UTC
5
+
6
+ from sqlalchemy import select
7
+ from sqlalchemy.ext.asyncio import AsyncSession
8
+
9
+ from backend.models.api_key import ApiKey
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def _generate_raw_key() -> str:
15
+ return f"sk-kairo-{uuid.uuid4().hex}"
16
+
17
+
18
class ApiKeyService:
    """Create, list and revoke a user's API keys.

    Persisted per key: a SHA-256 hex digest of the raw key and its first 20
    characters as a display prefix. The raw key itself is only returned once,
    from ``create_key``.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    async def create_key(
        self, user_id: str, name: str, expires_in_days: int | None = None
    ) -> tuple[ApiKey, str]:
        """Create an API key. Returns (record, raw_key). Raw key shown only once.

        A falsy ``expires_in_days`` (None or 0) creates a key with no expiry.
        """
        secret = _generate_raw_key()
        display_prefix = secret[:20]
        digest = hashlib.sha256(secret.encode()).hexdigest()

        expiry = None
        if expires_in_days:
            expiry = datetime.now(UTC) + timedelta(days=expires_in_days)

        record = ApiKey(
            user_id=user_id,
            name=name,
            key_prefix=display_prefix,
            key_hash=digest,
            expires_at=expiry,
        )
        self.db.add(record)
        await self.db.commit()
        await self.db.refresh(record)
        logger.info("API key created: user=%s name=%s prefix=%s", user_id, name, display_prefix)
        return record, secret

    async def list_keys(self, user_id: str) -> list[ApiKey]:
        """Return every key owned by ``user_id`` (active or not), newest first."""
        query = (
            select(ApiKey)
            .where(ApiKey.user_id == user_id)
            .order_by(ApiKey.created_at.desc())
        )
        rows = await self.db.execute(query)
        return [*rows.scalars().all()]

    async def revoke_key(self, user_id: str, key_id: str) -> bool:
        """Deactivate the user's key ``key_id``; False if the user owns no such key."""
        query = select(ApiKey).where(ApiKey.id == key_id, ApiKey.user_id == user_id)
        found = (await self.db.execute(query)).scalar_one_or_none()
        if not found:
            return False
        found.is_active = False
        await self.db.commit()
        logger.info("API key revoked: key=%s user=%s", key_id, user_id)
        return True
@@ -0,0 +1,126 @@
1
+ import logging
2
+ from datetime import datetime, timezone, timedelta
3
+
4
+ from sqlalchemy import select, func
5
+ from sqlalchemy.ext.asyncio import AsyncSession
6
+
7
+ from backend.config import settings
8
+ from backend.models.api_usage import ApiUsageRecord
9
+ from backend.models.user import PlanType, User
10
+
11
logger = logging.getLogger(__name__)

# API token budgets per plan value: {"daily": <tokens>, "monthly": <tokens>}.
# Plans without an entry use the FREE budgets (see _get_api_limits).
API_PLAN_LIMITS: dict[str, dict[str, int]] = {
    PlanType.FREE.value: dict(
        daily=settings.API_DAILY_TOKEN_LIMIT,
        monthly=settings.API_MONTHLY_TOKEN_LIMIT,
    ),
    PlanType.PRO.value: dict(
        daily=settings.API_PRO_DAILY_TOKEN_LIMIT,
        monthly=settings.API_PRO_MONTHLY_TOKEN_LIMIT,
    ),
}


def _get_api_limits(plan: str) -> tuple[int, int]:
    """Return ``(daily_limit, monthly_limit)`` for ``plan``, falling back to FREE."""
    # NOTE(review): this is a dict-key lookup; a PlanType member (rather than
    # its .value) only matches if PlanType is a str-based enum — confirm
    # against the type of User.plan.
    fallback = API_PLAN_LIMITS[PlanType.FREE.value]
    budget = API_PLAN_LIMITS.get(plan, fallback)
    daily, monthly = budget["daily"], budget["monthly"]
    return daily, monthly
28
+
29
+
30
class ApiUsageService:
    """Record API token usage and check it against per-plan daily/monthly budgets."""

    def __init__(self, db: AsyncSession):
        self.db = db

    async def record(
        self,
        api_key_id: str,
        user_id: str,
        model: str,
        prompt_tokens: int,
        completion_tokens: int,
        endpoint: str,
        agent_id: str | None = None,
    ) -> ApiUsageRecord:
        """Persist one usage record, commit, and return it refreshed from the DB."""
        record = ApiUsageRecord(
            api_key_id=api_key_id,
            user_id=user_id,
            agent_id=agent_id,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            endpoint=endpoint,
        )
        self.db.add(record)
        await self.db.commit()
        await self.db.refresh(record)
        logger.info(
            "API usage recorded: key=%s user=%s prompt=%d completion=%d",
            api_key_id, user_id, prompt_tokens, completion_tokens,
        )
        return record

    async def _tokens_since(self, user_id: str, since: datetime) -> int:
        """Sum prompt + completion tokens of the user's records created at or after ``since``."""
        stmt = (
            select(
                func.coalesce(
                    func.sum(ApiUsageRecord.prompt_tokens + ApiUsageRecord.completion_tokens), 0
                )
            )
            .where(ApiUsageRecord.user_id == user_id)
            .where(ApiUsageRecord.created_at >= since)
        )
        result = await self.db.execute(stmt)
        # SUM() over integer columns can come back as Decimal on some backends
        # (e.g. PostgreSQL sum(bigint) -> numeric); normalise to honour "-> int".
        return int(result.scalar() or 0)

    async def get_daily_usage(self, user_id: str) -> int:
        """Return tokens the user has consumed since 00:00 UTC today."""
        now = datetime.now(timezone.utc)
        start_of_day = now.replace(hour=0, minute=0, second=0, microsecond=0)
        return await self._tokens_since(user_id, start_of_day)

    async def get_monthly_usage(self, user_id: str) -> int:
        """Return tokens the user has consumed since 00:00 UTC on the 1st of this month."""
        now = datetime.now(timezone.utc)
        start_of_month = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        return await self._tokens_since(user_id, start_of_month)

    async def check_limits(self, user_id: str) -> tuple[bool, str]:
        """Return ``(allowed, reason)``; ``reason`` is empty when allowed.

        The daily budget is checked first, so the monthly query is skipped when
        the daily limit is already reached.
        """
        user = await self.db.get(User, user_id)
        if not user:
            return False, "User not found"

        daily_limit, monthly_limit = _get_api_limits(user.plan)

        daily = await self.get_daily_usage(user_id)
        if daily >= daily_limit:
            return False, "API daily token limit reached."

        monthly = await self.get_monthly_usage(user_id)
        if monthly >= monthly_limit:
            return False, "API monthly token limit reached."

        return True, ""

    async def get_key_usage_summary(self, api_key_id: str, days: int = 30) -> list[dict]:
        """Return per-day totals for one API key over the last ``days`` days.

        Items are ``{"date": str, "total_tokens": int, "requests": int}`` in
        ascending date order; days with no records are omitted.
        """
        start = datetime.now(timezone.utc) - timedelta(days=days)
        stmt = (
            select(
                func.date(ApiUsageRecord.created_at).label("date"),
                func.sum(ApiUsageRecord.prompt_tokens + ApiUsageRecord.completion_tokens).label("total_tokens"),
                func.count().label("request_count"),
            )
            .where(ApiUsageRecord.api_key_id == api_key_id)
            .where(ApiUsageRecord.created_at >= start)
            .group_by(func.date(ApiUsageRecord.created_at))
            .order_by(func.date(ApiUsageRecord.created_at))
        )
        result = await self.db.execute(stmt)
        return [
            {"date": str(r.date), "total_tokens": r.total_tokens or 0, "requests": r.request_count}
            for r in result.all()
        ]
@@ -0,0 +1,101 @@
1
+ import logging
2
+ import secrets
3
+ from datetime import datetime, timedelta, timezone
4
+
5
+ from sqlalchemy import select
6
+ from sqlalchemy.ext.asyncio import AsyncSession
7
+
8
+ from backend.core.security import hash_password, verify_password
9
+ from backend.models.user import User
10
+
11
logger = logging.getLogger(__name__)


class AuthService:
    """Email/password registration and login, email verification and password reset."""

    def __init__(self, db: AsyncSession):
        self.db = db

    async def register(self, email: str, password: str) -> User | None:
        """Create a new user with a verification token. Returns None if email taken."""
        existing = await self._get_by_email(email)
        if existing:
            return None
        token = secrets.token_urlsafe(32)
        user = User(
            email=email,
            hashed_password=hash_password(password),
            email_verification_token=token,
        )
        self.db.add(user)
        await self.db.commit()
        await self.db.refresh(user)
        logger.info("Registered user %s (id=%s)", email, user.id)
        return user

    async def authenticate(self, email: str, password: str) -> User | None:
        """Verify credentials. Returns the User on success, otherwise None.

        For an unknown email the password is still hashed once, so that path
        costs roughly as much as a wrong password. This makes account
        enumeration by response timing harder.
        """
        user = await self._get_by_email(email)
        if not user:
            # NOTE(review): relies on hash_password and verify_password doing
            # comparable work (true for bcrypt/argon2-style hashers) — confirm
            # in backend.core.security. The result is intentionally discarded.
            hash_password(password)
            return None
        if not verify_password(password, user.hashed_password):
            return None
        return user

    async def get_user_by_id(self, user_id: str) -> User | None:
        """Return the user with id ``user_id``, or None."""
        result = await self.db.execute(select(User).where(User.id == user_id))
        return result.scalar_one_or_none()

    async def verify_email(self, token: str) -> User | None:
        """Mark email as verified if token is valid. Returns user or None."""
        if not token:
            # Verified users have a NULL token (cleared below); comparing against
            # None would compile to "IS NULL" and match them.
            return None
        result = await self.db.execute(
            select(User).where(User.email_verification_token == token)
        )
        user = result.scalar_one_or_none()
        if not user:
            return None
        user.email_verified = True
        user.email_verification_token = None
        await self.db.commit()
        await self.db.refresh(user)
        logger.info("Email verified for user %s", user.email)
        return user

    async def request_password_reset(self, email: str) -> str | None:
        """Generate a reset token with 1hr expiry. Returns token or None if no user."""
        user = await self._get_by_email(email)
        if not user:
            return None
        token = secrets.token_urlsafe(32)
        user.password_reset_token = token
        user.password_reset_expires = datetime.now(timezone.utc) + timedelta(hours=1)
        await self.db.commit()
        return token

    async def reset_password(self, token: str, new_password: str) -> bool:
        """Reset password using token. Returns True on success.

        Fails (returns False) for an empty/unknown token or an expired one. On
        success the token and its expiry are cleared so the token is single-use.
        """
        if not token:
            # Users without a pending reset have a NULL token; never match them.
            return False
        result = await self.db.execute(
            select(User).where(User.password_reset_token == token)
        )
        user = result.scalar_one_or_none()
        if not user:
            return False
        expires = user.password_reset_expires
        if not expires:
            return False
        if expires.tzinfo is None:
            # NOTE(review): a naive value (e.g. DateTime column without timezone)
            # would otherwise raise TypeError when compared with an aware
            # datetime; it was written as UTC in request_password_reset.
            expires = expires.replace(tzinfo=timezone.utc)
        if expires < datetime.now(timezone.utc):
            return False
        user.hashed_password = hash_password(new_password)
        user.password_reset_token = None
        user.password_reset_expires = None
        await self.db.commit()
        logger.info("Password reset for user %s", user.email)
        return True

    async def regenerate_verification_token(self, user_id: str) -> str | None:
        """Generate a new verification token for resending. Returns token or None.

        Returns None when the user does not exist or is already verified.
        """
        user = await self.get_user_by_id(user_id)
        if not user or user.email_verified:
            return None
        token = secrets.token_urlsafe(32)
        user.email_verification_token = token
        await self.db.commit()
        return token

    async def _get_by_email(self, email: str) -> User | None:
        """Return the user whose email equals ``email`` exactly, or None."""
        result = await self.db.execute(select(User).where(User.email == email))
        return result.scalar_one_or_none()