kairo-code 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- image-service/main.py +178 -0
- infra/chat/app/main.py +84 -0
- kairo/backend/__init__.py +0 -0
- kairo/backend/api/__init__.py +0 -0
- kairo/backend/api/admin/__init__.py +23 -0
- kairo/backend/api/admin/audit.py +54 -0
- kairo/backend/api/admin/content.py +142 -0
- kairo/backend/api/admin/incidents.py +148 -0
- kairo/backend/api/admin/stats.py +125 -0
- kairo/backend/api/admin/system.py +87 -0
- kairo/backend/api/admin/users.py +279 -0
- kairo/backend/api/agents.py +94 -0
- kairo/backend/api/api_keys.py +85 -0
- kairo/backend/api/auth.py +116 -0
- kairo/backend/api/billing.py +41 -0
- kairo/backend/api/chat.py +72 -0
- kairo/backend/api/conversations.py +125 -0
- kairo/backend/api/device_auth.py +100 -0
- kairo/backend/api/files.py +83 -0
- kairo/backend/api/health.py +36 -0
- kairo/backend/api/images.py +80 -0
- kairo/backend/api/openai_compat.py +225 -0
- kairo/backend/api/projects.py +102 -0
- kairo/backend/api/usage.py +32 -0
- kairo/backend/api/webhooks.py +79 -0
- kairo/backend/app.py +297 -0
- kairo/backend/config.py +179 -0
- kairo/backend/core/__init__.py +0 -0
- kairo/backend/core/admin_auth.py +24 -0
- kairo/backend/core/api_key_auth.py +55 -0
- kairo/backend/core/database.py +28 -0
- kairo/backend/core/dependencies.py +70 -0
- kairo/backend/core/logging.py +23 -0
- kairo/backend/core/rate_limit.py +73 -0
- kairo/backend/core/security.py +29 -0
- kairo/backend/models/__init__.py +19 -0
- kairo/backend/models/agent.py +30 -0
- kairo/backend/models/api_key.py +25 -0
- kairo/backend/models/api_usage.py +29 -0
- kairo/backend/models/audit_log.py +26 -0
- kairo/backend/models/conversation.py +48 -0
- kairo/backend/models/device_code.py +30 -0
- kairo/backend/models/feature_flag.py +21 -0
- kairo/backend/models/image_generation.py +24 -0
- kairo/backend/models/incident.py +28 -0
- kairo/backend/models/project.py +28 -0
- kairo/backend/models/uptime_record.py +24 -0
- kairo/backend/models/usage.py +24 -0
- kairo/backend/models/user.py +49 -0
- kairo/backend/schemas/__init__.py +0 -0
- kairo/backend/schemas/admin/__init__.py +0 -0
- kairo/backend/schemas/admin/audit.py +28 -0
- kairo/backend/schemas/admin/content.py +53 -0
- kairo/backend/schemas/admin/stats.py +77 -0
- kairo/backend/schemas/admin/system.py +44 -0
- kairo/backend/schemas/admin/users.py +48 -0
- kairo/backend/schemas/agent.py +42 -0
- kairo/backend/schemas/api_key.py +30 -0
- kairo/backend/schemas/auth.py +57 -0
- kairo/backend/schemas/chat.py +26 -0
- kairo/backend/schemas/conversation.py +39 -0
- kairo/backend/schemas/device_auth.py +40 -0
- kairo/backend/schemas/image.py +15 -0
- kairo/backend/schemas/openai_compat.py +76 -0
- kairo/backend/schemas/project.py +21 -0
- kairo/backend/schemas/status.py +81 -0
- kairo/backend/schemas/usage.py +15 -0
- kairo/backend/services/__init__.py +0 -0
- kairo/backend/services/admin/__init__.py +0 -0
- kairo/backend/services/admin/audit_service.py +78 -0
- kairo/backend/services/admin/content_service.py +119 -0
- kairo/backend/services/admin/incident_service.py +94 -0
- kairo/backend/services/admin/stats_service.py +281 -0
- kairo/backend/services/admin/system_service.py +126 -0
- kairo/backend/services/admin/user_service.py +157 -0
- kairo/backend/services/agent_service.py +107 -0
- kairo/backend/services/api_key_service.py +66 -0
- kairo/backend/services/api_usage_service.py +126 -0
- kairo/backend/services/auth_service.py +101 -0
- kairo/backend/services/chat_service.py +501 -0
- kairo/backend/services/conversation_service.py +264 -0
- kairo/backend/services/device_auth_service.py +193 -0
- kairo/backend/services/email_service.py +55 -0
- kairo/backend/services/image_service.py +181 -0
- kairo/backend/services/llm_service.py +186 -0
- kairo/backend/services/project_service.py +109 -0
- kairo/backend/services/status_service.py +167 -0
- kairo/backend/services/stripe_service.py +78 -0
- kairo/backend/services/usage_service.py +150 -0
- kairo/backend/services/web_search_service.py +96 -0
- kairo/migrations/env.py +60 -0
- kairo/migrations/versions/001_initial.py +55 -0
- kairo/migrations/versions/002_usage_tracking_and_indexes.py +66 -0
- kairo/migrations/versions/003_username_to_email.py +21 -0
- kairo/migrations/versions/004_add_plans_and_verification.py +67 -0
- kairo/migrations/versions/005_add_projects.py +52 -0
- kairo/migrations/versions/006_add_image_generation.py +63 -0
- kairo/migrations/versions/007_add_admin_portal.py +107 -0
- kairo/migrations/versions/008_add_device_code_auth.py +76 -0
- kairo/migrations/versions/009_add_status_page.py +65 -0
- kairo/tools/extract_claude_data.py +465 -0
- kairo/tools/filter_claude_data.py +303 -0
- kairo/tools/generate_curated_data.py +157 -0
- kairo/tools/mix_training_data.py +295 -0
- kairo_code/__init__.py +3 -0
- kairo_code/agents/__init__.py +25 -0
- kairo_code/agents/architect.py +98 -0
- kairo_code/agents/audit.py +100 -0
- kairo_code/agents/base.py +463 -0
- kairo_code/agents/coder.py +155 -0
- kairo_code/agents/database.py +77 -0
- kairo_code/agents/docs.py +88 -0
- kairo_code/agents/explorer.py +62 -0
- kairo_code/agents/guardian.py +80 -0
- kairo_code/agents/planner.py +66 -0
- kairo_code/agents/reviewer.py +91 -0
- kairo_code/agents/security.py +94 -0
- kairo_code/agents/terraform.py +88 -0
- kairo_code/agents/testing.py +97 -0
- kairo_code/agents/uiux.py +88 -0
- kairo_code/auth.py +232 -0
- kairo_code/config.py +172 -0
- kairo_code/conversation.py +173 -0
- kairo_code/heartbeat.py +63 -0
- kairo_code/llm.py +291 -0
- kairo_code/logging_config.py +156 -0
- kairo_code/main.py +818 -0
- kairo_code/router.py +217 -0
- kairo_code/sandbox.py +248 -0
- kairo_code/settings.py +183 -0
- kairo_code/tools/__init__.py +51 -0
- kairo_code/tools/analysis.py +509 -0
- kairo_code/tools/base.py +417 -0
- kairo_code/tools/code.py +58 -0
- kairo_code/tools/definitions.py +617 -0
- kairo_code/tools/files.py +315 -0
- kairo_code/tools/review.py +390 -0
- kairo_code/tools/search.py +185 -0
- kairo_code/ui.py +418 -0
- kairo_code-0.1.0.dist-info/METADATA +13 -0
- kairo_code-0.1.0.dist-info/RECORD +144 -0
- kairo_code-0.1.0.dist-info/WHEEL +5 -0
- kairo_code-0.1.0.dist-info/entry_points.txt +2 -0
- kairo_code-0.1.0.dist-info/top_level.txt +4 -0
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import select, func, delete
|
|
4
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
5
|
+
from sqlalchemy.orm import selectinload
|
|
6
|
+
|
|
7
|
+
from backend.models.conversation import Conversation, Message
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)


class ConversationService:
    """Persistence operations for conversations and their messages.

    When ``user_id`` is set, conversation reads, listings and searches are
    restricted to conversations owned by that user; with ``user_id=None`` no
    ownership filter is applied.
    """

    # Escape character used with LIKE/ILIKE so user search text is matched literally.
    _LIKE_ESCAPE = "\\"

    def __init__(self, db: AsyncSession, user_id: str | None = None):
        self.db = db
        self.user_id = user_id

    # ------------------------------------------------------------------
    # Query helpers
    # ------------------------------------------------------------------

    def _scoped(self, stmt, project_id: str | None = None):
        """Restrict ``stmt`` to the current user (if any) and optionally one project."""
        if self.user_id:
            stmt = stmt.where(Conversation.user_id == self.user_id)
        if project_id is not None:
            stmt = stmt.where(Conversation.project_id == project_id)
        return stmt

    @classmethod
    def _like_pattern(cls, query: str) -> str:
        """Return a ``%query%`` substring pattern with LIKE wildcards in ``query`` escaped.

        Without escaping, a search for ``50%`` or ``a_b`` would treat ``%`` and
        ``_`` as wildcards and match unrelated rows. Use the result together with
        ``escape=cls._LIKE_ESCAPE``.
        """
        esc = cls._LIKE_ESCAPE
        escaped = (
            query.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )
        return f"%{escaped}%"

    @staticmethod
    def _summary_query():
        """Select one summary row per conversation (with message count), newest first."""
        return (
            select(
                Conversation.id,
                Conversation.title,
                Conversation.model,
                Conversation.updated_at,
                Conversation.project_id,
                func.count(Message.id).label("message_count"),
            )
            .outerjoin(Message)
            .group_by(Conversation.id)
            .order_by(Conversation.updated_at.desc())
        )

    @staticmethod
    def _summary_dict(row) -> dict:
        """Convert a row from ``_summary_query`` into the dict returned by listings."""
        return {
            "id": row.id,
            "title": row.title,
            "model": row.model,
            "updated_at": row.updated_at,
            "message_count": row.message_count,
            "project_id": row.project_id,
        }

    def _detail_query(self, conversation_id: str):
        """Select one (scoped) conversation with messages and project eagerly loaded."""
        return self._scoped(
            select(Conversation)
            .options(selectinload(Conversation.messages), selectinload(Conversation.project))
            .where(Conversation.id == conversation_id)
        )

    # ------------------------------------------------------------------
    # Conversations
    # ------------------------------------------------------------------

    async def create(self, model: str = "nyx", title: str = "New Conversation", project_id: str | None = None) -> Conversation:
        """Create, commit and return a new conversation owned by ``self.user_id``."""
        conv = Conversation(model=model, title=title, user_id=self.user_id, project_id=project_id)
        self.db.add(conv)
        await self.db.commit()
        await self.db.refresh(conv)
        return conv

    async def list_all(self, limit: int = 50, offset: int = 0, project_id: str | None = None) -> tuple[list[dict], int]:
        """Return one page of conversation summaries and the total matching count.

        The page and ``total`` apply the same user scope and optional
        ``project_id`` filter, so ``total`` always describes the set being
        paginated.
        """
        ids_stmt = self._scoped(select(Conversation.id), project_id)
        count_stmt = select(func.count()).select_from(ids_stmt.subquery())
        total = (await self.db.execute(count_stmt)).scalar() or 0

        stmt = self._scoped(self._summary_query(), project_id).limit(limit).offset(offset)
        rows = (await self.db.execute(stmt)).all()
        return [self._summary_dict(r) for r in rows], total

    async def search(self, query: str, limit: int = 50, offset: int = 0) -> tuple[list[dict], int]:
        """Search conversations by title or message content (case-insensitive substring).

        ``query`` is matched literally; LIKE wildcards in it are escaped.
        Returns ``(items, total)`` in the same shape as ``list_all``.
        """
        pattern = self._like_pattern(query)
        esc = self._LIKE_ESCAPE

        # IDs of (scoped) conversations whose title or any message matches.
        matching_ids_stmt = self._scoped(
            select(Conversation.id)
            .outerjoin(Message)
            .where(
                Conversation.title.ilike(pattern, escape=esc)
                | Message.content.ilike(pattern, escape=esc)
            )
            .group_by(Conversation.id)
        )

        count_stmt = select(func.count()).select_from(matching_ids_stmt.subquery())
        total = (await self.db.execute(count_stmt)).scalar() or 0

        # Message counts must cover all messages, not just matching ones, so the
        # summary query is re-run over the matched IDs.
        stmt = (
            self._scoped(self._summary_query().where(Conversation.id.in_(matching_ids_stmt)))
            .limit(limit)
            .offset(offset)
        )
        rows = (await self.db.execute(stmt)).all()
        return [self._summary_dict(r) for r in rows], total

    async def get(self, conversation_id: str) -> Conversation | None:
        """Return the conversation with messages and project loaded, or ``None`` if absent or not owned."""
        result = await self.db.execute(self._detail_query(conversation_id))
        return result.scalar_one_or_none()

    async def get_fresh(self, conversation_id: str) -> Conversation | None:
        """Get conversation with fresh data, bypassing the identity map cache."""
        stmt = self._detail_query(conversation_id).execution_options(populate_existing=True)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def delete(self, conversation_id: str) -> bool:
        """Delete the conversation; return ``False`` if it is absent or not owned."""
        conv = await self.get(conversation_id)
        if not conv:
            return False
        await self.db.delete(conv)
        await self.db.commit()
        return True

    async def rename(self, conversation_id: str, title: str) -> Conversation | None:
        """Set the conversation title; return the refreshed conversation or ``None``."""
        conv = await self.get(conversation_id)
        if not conv:
            return None
        conv.title = title
        await self.db.commit()
        await self.db.refresh(conv)
        return conv

    # ------------------------------------------------------------------
    # Messages
    # ------------------------------------------------------------------

    async def add_message(self, conversation_id: str, role: str, content: str, image_url: str | None = None) -> Message:
        """Append and commit a message.

        NOTE(review): ownership of ``conversation_id`` is not checked here; callers
        are expected to have verified it.
        """
        msg = Message(conversation_id=conversation_id, role=role, content=content, image_url=image_url)
        self.db.add(msg)
        await self.db.commit()
        await self.db.refresh(msg)
        return msg

    async def get_message(self, message_id: str) -> Message | None:
        """Return a message by id without any ownership check."""
        stmt = select(Message).where(Message.id == message_id)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def edit_message(self, message_id: str, new_content: str) -> Message | None:
        """Edit a message and delete all subsequent messages in the conversation.

        Returns ``None`` if the message does not exist or its conversation is not
        owned by the current user.
        """
        msg = await self.get_message(message_id)
        if not msg:
            return None

        # Verify user owns this conversation
        conv = await self.get(msg.conversation_id)
        if not conv:
            return None

        # Delete all messages after this one. "After" is by created_at, so
        # messages sharing the exact same timestamp are kept.
        await self.db.execute(
            delete(Message).where(
                Message.conversation_id == msg.conversation_id,
                Message.created_at > msg.created_at,
            )
        )

        msg.content = new_content
        await self.db.commit()
        await self.db.refresh(msg)
        return msg

    async def search_messages(self, query: str, limit: int = 20, offset: int = 0) -> list[dict]:
        """Search message content and return snippets with conversation context.

        Matching is a case-insensitive literal substring match. Each result holds
        the message id/role/created_at, a snippet of about 50 characters on each
        side of the first match, and the owning conversation's id and title.
        Results are newest first.
        """
        pattern = self._like_pattern(query)
        stmt = self._scoped(
            select(
                Message.id,
                Message.role,
                Message.content,
                Message.created_at,
                Conversation.id.label("conversation_id"),
                Conversation.title.label("conversation_title"),
            )
            .join(Conversation, Message.conversation_id == Conversation.id)
            .where(Message.content.ilike(pattern, escape=self._LIKE_ESCAPE))
            .order_by(Message.created_at.desc())
        ).limit(limit).offset(offset)

        rows = (await self.db.execute(stmt)).all()
        needle = query.lower()
        items = []
        for r in rows:
            content = r.content
            idx = content.lower().find(needle)
            if idx >= 0:
                start = max(0, idx - 50)
                end = min(len(content), idx + len(query) + 50)
                snippet = ("..." if start > 0 else "") + content[start:end] + ("..." if end < len(content) else "")
            else:
                # The DB matched, but Python's lower() disagreed; fall back to a prefix.
                snippet = content[:100] + ("..." if len(content) > 100 else "")
            items.append({
                "message_id": r.id,
                "role": r.role,
                "snippet": snippet,
                "created_at": r.created_at,
                "conversation_id": r.conversation_id,
                "conversation_title": r.conversation_title,
            })
        return items

    async def delete_last_assistant_message(self, conversation_id: str) -> bool:
        """Delete the last assistant message in a conversation for regeneration.

        Returns ``False`` if the conversation is absent, not owned, empty, or has
        no assistant message.
        """
        conv = await self.get(conversation_id)
        if not conv or not conv.messages:
            return False

        messages = sorted(conv.messages, key=lambda m: m.created_at)
        last_assistant = None
        for m in reversed(messages):
            if m.role == "assistant":
                last_assistant = m
                break

        if not last_assistant:
            return False

        await self.db.delete(last_assistant)
        await self.db.commit()
        return True
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import logging
|
|
3
|
+
import random
|
|
4
|
+
import string
|
|
5
|
+
import uuid
|
|
6
|
+
from datetime import UTC, datetime, timedelta
|
|
7
|
+
|
|
8
|
+
from sqlalchemy import delete, select
|
|
9
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
10
|
+
|
|
11
|
+
from backend.models.api_key import ApiKey
|
|
12
|
+
from backend.models.device_code import DeviceCode
|
|
13
|
+
from backend.models.user import User
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)


def _generate_user_code() -> str:
    """Generate a human-readable code like ABCD-1234.

    The code is what a signed-in user submits to approve a device, so it is
    drawn from the ``secrets`` CSPRNG rather than the predictable ``random``
    module.
    """
    import secrets

    letters = "".join(secrets.choice(string.ascii_uppercase) for _ in range(4))
    digits = "".join(secrets.choice(string.digits) for _ in range(4))
    return f"{letters}-{digits}"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _generate_device_code() -> str:
    """Generate a 64-character hex device code.

    ``secrets.token_hex(32)`` gives 256 random bits. Two ``uuid4`` hex strings
    give only 244, because each fixes its version and variant bits.
    """
    import secrets

    return secrets.token_hex(32)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class DeviceAuthService:
    """Device-code login flow used by the Kairo Code CLI.

    A client requests a long device code (used for polling) and a short user
    code. A signed-in user approves the user code, which mints a CLI API key.
    The polling client then receives that raw key exactly once.
    """

    EXPIRY_MINUTES = 10  # lifetime of a pending device code
    POLL_INTERVAL = 5  # stored on each DeviceCode; presumably seconds — confirm with the CLI

    def __init__(self, db: AsyncSession):
        self.db = db

    async def create_device_code(self, client_name: str) -> DeviceCode:
        """Create and commit a pending device code that expires after ``EXPIRY_MINUTES``."""
        user_code = await self._unique_user_code()

        device = DeviceCode(
            device_code=_generate_device_code(),
            user_code=user_code,
            client_name=client_name,
            expires_at=datetime.now(UTC) + timedelta(minutes=self.EXPIRY_MINUTES),
            interval=self.POLL_INTERVAL,
        )
        self.db.add(device)
        await self.db.commit()
        await self.db.refresh(device)
        logger.info("Device code created: user_code=%s client=%s", user_code, client_name)
        return device

    async def approve(self, user_code: str, user: User) -> tuple[str, DeviceCode] | None:
        """Approve a pending device code.

        Creates a CLI API key for ``user`` and stores the raw key on the device
        row until the client polls for it.

        Returns (raw_api_key, device_code) or None if not found / expired.
        """
        device = await self._find_pending_device(user_code)
        if not device:
            return None

        raw_key = self._create_raw_cli_key()
        api_key = self._build_api_key(raw_key, user.id, device.client_name)
        self.db.add(api_key)
        await self.db.flush()  # assigns api_key.id before it is linked to the device

        self._mark_approved(device, user.id, api_key.id, raw_key)
        await self.db.commit()

        logger.info("Device code approved: user_code=%s user=%s", user_code, user.id)
        return raw_key, device

    async def poll_token(
        self, device_code: str
    ) -> tuple[str, dict | None]:
        """Poll for token status.

        Returns (status, user_info_or_none).
        status: authorization_pending | expired_token | access_denied | approved
        user_info: {"raw_key": ..., "plan": ..., "email": ...} when approved.

        Once the key has been handed out, further polls for the same device code
        return ``expired_token`` so the client stops polling instead of waiting
        forever.
        """
        device = await self._find_device_by_code(device_code)
        if not device:
            return "expired_token", None

        if self._is_expired_pending(device):
            device.status = "expired"
            await self.db.commit()
            return "expired_token", None

        if device.status in ("denied", "expired"):
            status = "access_denied" if device.status == "denied" else "expired_token"
            return status, None

        if device.status == "pending":
            return "authorization_pending", None

        if device.status == "approved":
            if device.cli_token:
                return await self._consume_approved_token(device)
            # The key was already retrieved by an earlier poll; this code is spent.
            return "expired_token", None

        return "authorization_pending", None

    async def cleanup_expired(self) -> int:
        """Delete device codes created more than 1 hour ago, whatever their status.

        Returns the number of rows deleted.
        """
        cutoff = datetime.now(UTC) - timedelta(hours=1)
        stmt = delete(DeviceCode).where(DeviceCode.created_at < cutoff)
        result = await self.db.execute(stmt)
        await self.db.commit()
        deleted = result.rowcount or 0
        if deleted:
            logger.info("Cleaned up %d expired device codes", deleted)
        return deleted

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    async def _unique_user_code(self) -> str:
        """Generate a user code that does not collide with pending codes.

        Tries up to 5 times; after that it returns an unchecked code.
        """
        for _ in range(5):
            user_code = _generate_user_code()
            stmt = select(DeviceCode).where(
                DeviceCode.user_code == user_code,
                DeviceCode.status == "pending",
            )
            result = await self.db.execute(stmt)
            if not result.scalar_one_or_none():
                return user_code
        # Extremely unlikely fallback
        return _generate_user_code()

    async def _find_pending_device(self, user_code: str) -> DeviceCode | None:
        """Return the pending, unexpired device with ``user_code``, if any."""
        stmt = select(DeviceCode).where(
            DeviceCode.user_code == user_code,
            DeviceCode.status == "pending",
            DeviceCode.expires_at > datetime.now(UTC),
        )
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def _find_device_by_code(self, device_code: str) -> DeviceCode | None:
        """Return the device row for ``device_code`` regardless of status."""
        stmt = select(DeviceCode).where(DeviceCode.device_code == device_code)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    @staticmethod
    def _create_raw_cli_key() -> str:
        """Return a new raw CLI key of the form ``sk-kairo-cli-<32 hex chars>``."""
        return f"sk-kairo-cli-{uuid.uuid4().hex}"

    @staticmethod
    def _build_api_key(raw_key: str, user_id: str, client_name: str | None) -> ApiKey:
        """Build (not persist) an ApiKey row storing only the prefix and SHA-256 of ``raw_key``."""
        prefix = raw_key[:24]
        key_hash = hashlib.sha256(raw_key.encode()).hexdigest()
        return ApiKey(
            user_id=user_id,
            name=f"Kairo Code CLI ({client_name or 'cli'})",
            key_prefix=prefix,
            key_hash=key_hash,
            key_type="cli",
        )

    @staticmethod
    def _mark_approved(
        device: DeviceCode, user_id: str, api_key_id: str, raw_key: str
    ) -> None:
        """Mark ``device`` approved and stash the raw key for one-time retrieval."""
        device.status = "approved"
        device.user_id = user_id
        device.cli_api_key_id = api_key_id
        device.cli_token = raw_key
        device.approved_at = datetime.now(UTC)

    @staticmethod
    def _is_expired_pending(device: DeviceCode) -> bool:
        """True when ``device`` is still pending but past its expiry time."""
        if device.status != "pending":
            return False
        expires_at = device.expires_at
        if expires_at.tzinfo is None:
            # Some drivers return naive datetimes; values are written as UTC above.
            expires_at = expires_at.replace(tzinfo=UTC)
        return expires_at < datetime.now(UTC)

    async def _consume_approved_token(
        self, device: DeviceCode
    ) -> tuple[str, dict | None]:
        """Return the token and clear it (single-use retrieval).

        NOTE(review): not atomic. Concurrent polls could both read the token
        before either commits. Consider a conditional UPDATE if that matters.
        """
        raw_key = device.cli_token
        device.cli_token = None  # single-use: clear after first poll
        await self.db.commit()

        user = await self.db.get(User, device.user_id)
        if not user:
            return "access_denied", None

        return "approved", {
            "raw_key": raw_key,
            "plan": user.plan,
            "email": user.email,
        }
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import smtplib
|
|
3
|
+
from email.mime.multipart import MIMEMultipart
|
|
4
|
+
from email.mime.text import MIMEText
|
|
5
|
+
|
|
6
|
+
from backend.config import settings
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)


class EmailService:
    """Send Kairo account emails (verification, password reset) via SMTP."""

    # Seconds before an SMTP connection or read is abandoned, so a stalled mail
    # server cannot hang the caller indefinitely.
    SMTP_TIMEOUT = 30

    def send_verification_email(self, to: str, token: str) -> None:
        """Email ``to`` a link to ``/verify-email`` carrying ``token``.

        Raises whatever ``_send`` raises on SMTP failure.
        """
        link = self._action_link("/verify-email", token)
        subject = "Verify your Kairo email"
        html_body = (
            "<h2>Welcome to Kairo</h2>"
            "<p>Click the link below to verify your email address:</p>"
            + self._link_html(link, "Verify Email")
            + "<p>If you didn't create a Kairo account, you can ignore this email.</p>"
        )
        self._send(to, subject, html_body)

    def send_password_reset_email(self, to: str, token: str) -> None:
        """Email ``to`` a link to ``/reset-password`` carrying ``token``.

        Raises whatever ``_send`` raises on SMTP failure.
        """
        link = self._action_link("/reset-password", token)
        subject = "Reset your Kairo password"
        html_body = (
            "<h2>Password Reset</h2>"
            "<p>Click the link below to reset your password. This link expires in 1 hour.</p>"
            + self._link_html(link, "Reset Password")
            + "<p>If you didn't request this, you can ignore this email.</p>"
        )
        self._send(to, subject, html_body)

    @staticmethod
    def _action_link(path: str, token: str) -> str:
        """Build ``APP_BASE_URL + path`` with ``token`` percent-encoded as the ``token`` query parameter."""
        from urllib.parse import quote

        return f"{settings.APP_BASE_URL}{path}?token={quote(token, safe='')}"

    @staticmethod
    def _link_html(link: str, label: str) -> str:
        """Render ``link`` as an anchor plus a copyable URL, HTML-escaped."""
        from html import escape

        safe_link = escape(link)
        return f'<p><a href="{safe_link}">{label}</a></p><p>Or copy this URL: {safe_link}</p>'

    def _send(self, to: str, subject: str, html_body: str) -> None:
        """Send an HTML email to ``to`` over SMTP with STARTTLS.

        Only logs a warning and returns when ``SMTP_USERNAME`` is not configured.
        Any sending error is logged and re-raised.
        """
        import ssl

        if not settings.SMTP_USERNAME:
            logger.warning("SMTP not configured, skipping email to %s", to)
            return

        msg = MIMEMultipart("alternative")
        msg["Subject"] = subject
        msg["From"] = f"{settings.SMTP_FROM_NAME} <{settings.SMTP_FROM_EMAIL}>"
        msg["To"] = to
        msg.attach(MIMEText(html_body, "html"))

        try:
            with smtplib.SMTP(settings.SMTP_HOST, settings.SMTP_PORT, timeout=self.SMTP_TIMEOUT) as server:
                # smtplib's default STARTTLS context does not verify the server
                # certificate. Pass a verifying context so the login credentials
                # cannot be intercepted by a man-in-the-middle.
                server.starttls(context=ssl.create_default_context())
                server.login(settings.SMTP_USERNAME, settings.SMTP_PASSWORD)
                server.sendmail(settings.SMTP_FROM_EMAIL, to, msg.as_string())
                logger.info("Email sent to %s: %s", to, subject)
        except Exception:
            logger.exception("Failed to send email to %s", to)
            raise
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import uuid
|
|
3
|
+
from datetime import datetime, timezone
|
|
4
|
+
|
|
5
|
+
import boto3
|
|
6
|
+
import httpx
|
|
7
|
+
from sqlalchemy import select, func
|
|
8
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
9
|
+
|
|
10
|
+
from backend.config import settings
|
|
11
|
+
from backend.models.image_generation import ImageGeneration
|
|
12
|
+
from backend.models.user import PlanType, User
|
|
13
|
+
from backend.services.conversation_service import ConversationService
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)

# Per-plan image generation quotas (daily and monthly counts), keyed by the
# PlanType value stored on the user.
IMAGE_LIMITS: dict[str, dict[str, int]] = {
    plan.value: {"daily": daily_limit, "monthly": monthly_limit}
    for plan, daily_limit, monthly_limit in (
        (PlanType.PRO, settings.PRO_DAILY_IMAGE_LIMIT, settings.PRO_MONTHLY_IMAGE_LIMIT),
        (PlanType.MAX, settings.MAX_DAILY_IMAGE_LIMIT, settings.MAX_MONTHLY_IMAGE_LIMIT),
    )
}
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ImageService:
    """Image generation through the FLUX service for Pro/Max users.

    Each result is uploaded to S3 and attached to a conversation as a
    prompt/answer message pair. It is also recorded in ``ImageGeneration``,
    which ``check_access`` counts against the plan quotas.
    """

    def __init__(self, db: AsyncSession, user_id: str):
        self.db = db
        self.user_id = user_id
        self.conv_service = ConversationService(db, user_id=user_id)

    async def check_access(self) -> tuple[bool, str]:
        """Check plan tier and image usage limits.

        Returns ``(True, "")`` when the user may generate an image, otherwise
        ``(False, reason)`` with a user-facing reason. Day and month boundaries
        are computed in UTC.
        """
        user = await self.db.get(User, self.user_id)
        if not user:
            return False, "User not found."

        if user.plan not in (PlanType.PRO.value, PlanType.MAX.value):
            return False, "Image generation requires a Pro or Max plan."

        limits = IMAGE_LIMITS.get(user.plan)
        if not limits:
            return False, "No image limits configured for your plan."

        now = datetime.now(timezone.utc)
        start_of_day = now.replace(hour=0, minute=0, second=0, microsecond=0)
        if await self._count_since(start_of_day) >= limits["daily"]:
            return False, "Daily image generation limit reached."

        start_of_month = start_of_day.replace(day=1)
        if await self._count_since(start_of_month) >= limits["monthly"]:
            return False, "Monthly image generation limit reached."

        return True, ""

    async def _count_since(self, since: datetime) -> int:
        """Count this user's recorded image generations created at or after ``since``."""
        stmt = (
            select(func.count())
            .select_from(ImageGeneration)
            .where(ImageGeneration.user_id == self.user_id)
            .where(ImageGeneration.created_at >= since)
        )
        return (await self.db.execute(stmt)).scalar() or 0

    async def _upload_png(self, image_bytes: bytes) -> str:
        """Upload PNG bytes to the images bucket and return the object's public-style URL."""
        import asyncio

        image_id = str(uuid.uuid4())
        s3_key = f"generated/{image_id}.png"
        s3 = boto3.client("s3", region_name=settings.S3_IMAGES_REGION)
        # put_object is a blocking network call; run it in a worker thread so the
        # event loop keeps serving other requests during the upload.
        await asyncio.to_thread(
            s3.put_object,
            Bucket=settings.S3_IMAGES_BUCKET,
            Key=s3_key,
            Body=image_bytes,
            ContentType="image/png",
        )
        return f"https://{settings.S3_IMAGES_BUCKET}.s3.{settings.S3_IMAGES_REGION}.amazonaws.com/{s3_key}"

    async def _save_result(
        self,
        image_bytes: bytes,
        prompt: str,
        conversation_id: str | None,
        width: int = 1024,
        height: int = 1024,
    ) -> dict:
        """Upload image to S3, save to conversation, record usage.

        If ``conversation_id`` is missing or not owned by the user, a new
        conversation is created. Returns the image URL, the assistant message id,
        the conversation id and the prompt.
        """
        image_url = await self._upload_png(image_bytes)

        conv = None
        if conversation_id:
            conv = await self.conv_service.get(conversation_id)
        if not conv:
            conv = await self.conv_service.create(model="flux-dev")

        # The prompt becomes the user turn; the image is attached to the assistant turn.
        await self.conv_service.add_message(conv.id, "user", prompt)
        msg = await self.conv_service.add_message(
            conv.id, "assistant", "",
            image_url=image_url,
        )
        msg_id = msg.id  # read now; later commits may expire the instance

        record = ImageGeneration(
            user_id=self.user_id,
            conversation_id=conv.id,
            prompt=prompt,
            image_url=image_url,
            width=width,
            height=height,
        )
        self.db.add(record)
        await self.db.commit()

        # Auto-title if new conversation
        if conv.title == "New Conversation":
            title = prompt[:47] + "..." if len(prompt) > 50 else prompt
            await self.conv_service.rename(conv.id, title)

        return {
            "image_url": image_url,
            "message_id": msg_id,
            "conversation_id": conv.id,
            "prompt": prompt,
        }

    async def _call_flux(self, path: str, **request_kwargs) -> bytes:
        """POST to the FLUX service at ``path`` and return the raw response body.

        Raises ``httpx.HTTPStatusError`` on a non-2xx response.
        """
        base_url = settings.FLUX_BASE_URL.rstrip("/")
        headers = {}
        if settings.FLUX_API_KEY:
            headers["Authorization"] = f"Bearer {settings.FLUX_API_KEY}"

        # Generation is slow: allow up to 300 s overall but fail fast on connect.
        async with httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=10.0)) as client:
            resp = await client.post(f"{base_url}{path}", headers=headers, **request_kwargs)
            resp.raise_for_status()
            return resp.content

    async def generate(
        self,
        prompt: str,
        conversation_id: str | None,
        width: int = 1024,
        height: int = 1024,
    ) -> dict:
        """Call FLUX text2img service, upload to S3, save to conversation."""
        image_bytes = await self._call_flux(
            "/generate",
            json={"prompt": prompt, "width": width, "height": height},
        )
        return await self._save_result(image_bytes, prompt, conversation_id, width, height)

    async def generate_img2img(
        self,
        image_bytes: bytes,
        prompt: str,
        conversation_id: str | None,
        strength: float = 0.75,
    ) -> dict:
        """Call FLUX img2img service, upload to S3, save to conversation.

        NOTE(review): the output size is not known here, so the usage record
        stores ``_save_result``'s default 1024x1024 dimensions.
        """
        result_bytes = await self._call_flux(
            "/img2img",
            files={"image": ("input.png", image_bytes, "image/png")},
            data={"prompt": prompt, "strength": str(strength)},
        )
        return await self._save_result(result_bytes, prompt, conversation_id)
|