kairo-code 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. image-service/main.py +178 -0
  2. infra/chat/app/main.py +84 -0
  3. kairo/backend/__init__.py +0 -0
  4. kairo/backend/api/__init__.py +0 -0
  5. kairo/backend/api/admin/__init__.py +23 -0
  6. kairo/backend/api/admin/audit.py +54 -0
  7. kairo/backend/api/admin/content.py +142 -0
  8. kairo/backend/api/admin/incidents.py +148 -0
  9. kairo/backend/api/admin/stats.py +125 -0
  10. kairo/backend/api/admin/system.py +87 -0
  11. kairo/backend/api/admin/users.py +279 -0
  12. kairo/backend/api/agents.py +94 -0
  13. kairo/backend/api/api_keys.py +85 -0
  14. kairo/backend/api/auth.py +116 -0
  15. kairo/backend/api/billing.py +41 -0
  16. kairo/backend/api/chat.py +72 -0
  17. kairo/backend/api/conversations.py +125 -0
  18. kairo/backend/api/device_auth.py +100 -0
  19. kairo/backend/api/files.py +83 -0
  20. kairo/backend/api/health.py +36 -0
  21. kairo/backend/api/images.py +80 -0
  22. kairo/backend/api/openai_compat.py +225 -0
  23. kairo/backend/api/projects.py +102 -0
  24. kairo/backend/api/usage.py +32 -0
  25. kairo/backend/api/webhooks.py +79 -0
  26. kairo/backend/app.py +297 -0
  27. kairo/backend/config.py +179 -0
  28. kairo/backend/core/__init__.py +0 -0
  29. kairo/backend/core/admin_auth.py +24 -0
  30. kairo/backend/core/api_key_auth.py +55 -0
  31. kairo/backend/core/database.py +28 -0
  32. kairo/backend/core/dependencies.py +70 -0
  33. kairo/backend/core/logging.py +23 -0
  34. kairo/backend/core/rate_limit.py +73 -0
  35. kairo/backend/core/security.py +29 -0
  36. kairo/backend/models/__init__.py +19 -0
  37. kairo/backend/models/agent.py +30 -0
  38. kairo/backend/models/api_key.py +25 -0
  39. kairo/backend/models/api_usage.py +29 -0
  40. kairo/backend/models/audit_log.py +26 -0
  41. kairo/backend/models/conversation.py +48 -0
  42. kairo/backend/models/device_code.py +30 -0
  43. kairo/backend/models/feature_flag.py +21 -0
  44. kairo/backend/models/image_generation.py +24 -0
  45. kairo/backend/models/incident.py +28 -0
  46. kairo/backend/models/project.py +28 -0
  47. kairo/backend/models/uptime_record.py +24 -0
  48. kairo/backend/models/usage.py +24 -0
  49. kairo/backend/models/user.py +49 -0
  50. kairo/backend/schemas/__init__.py +0 -0
  51. kairo/backend/schemas/admin/__init__.py +0 -0
  52. kairo/backend/schemas/admin/audit.py +28 -0
  53. kairo/backend/schemas/admin/content.py +53 -0
  54. kairo/backend/schemas/admin/stats.py +77 -0
  55. kairo/backend/schemas/admin/system.py +44 -0
  56. kairo/backend/schemas/admin/users.py +48 -0
  57. kairo/backend/schemas/agent.py +42 -0
  58. kairo/backend/schemas/api_key.py +30 -0
  59. kairo/backend/schemas/auth.py +57 -0
  60. kairo/backend/schemas/chat.py +26 -0
  61. kairo/backend/schemas/conversation.py +39 -0
  62. kairo/backend/schemas/device_auth.py +40 -0
  63. kairo/backend/schemas/image.py +15 -0
  64. kairo/backend/schemas/openai_compat.py +76 -0
  65. kairo/backend/schemas/project.py +21 -0
  66. kairo/backend/schemas/status.py +81 -0
  67. kairo/backend/schemas/usage.py +15 -0
  68. kairo/backend/services/__init__.py +0 -0
  69. kairo/backend/services/admin/__init__.py +0 -0
  70. kairo/backend/services/admin/audit_service.py +78 -0
  71. kairo/backend/services/admin/content_service.py +119 -0
  72. kairo/backend/services/admin/incident_service.py +94 -0
  73. kairo/backend/services/admin/stats_service.py +281 -0
  74. kairo/backend/services/admin/system_service.py +126 -0
  75. kairo/backend/services/admin/user_service.py +157 -0
  76. kairo/backend/services/agent_service.py +107 -0
  77. kairo/backend/services/api_key_service.py +66 -0
  78. kairo/backend/services/api_usage_service.py +126 -0
  79. kairo/backend/services/auth_service.py +101 -0
  80. kairo/backend/services/chat_service.py +501 -0
  81. kairo/backend/services/conversation_service.py +264 -0
  82. kairo/backend/services/device_auth_service.py +193 -0
  83. kairo/backend/services/email_service.py +55 -0
  84. kairo/backend/services/image_service.py +181 -0
  85. kairo/backend/services/llm_service.py +186 -0
  86. kairo/backend/services/project_service.py +109 -0
  87. kairo/backend/services/status_service.py +167 -0
  88. kairo/backend/services/stripe_service.py +78 -0
  89. kairo/backend/services/usage_service.py +150 -0
  90. kairo/backend/services/web_search_service.py +96 -0
  91. kairo/migrations/env.py +60 -0
  92. kairo/migrations/versions/001_initial.py +55 -0
  93. kairo/migrations/versions/002_usage_tracking_and_indexes.py +66 -0
  94. kairo/migrations/versions/003_username_to_email.py +21 -0
  95. kairo/migrations/versions/004_add_plans_and_verification.py +67 -0
  96. kairo/migrations/versions/005_add_projects.py +52 -0
  97. kairo/migrations/versions/006_add_image_generation.py +63 -0
  98. kairo/migrations/versions/007_add_admin_portal.py +107 -0
  99. kairo/migrations/versions/008_add_device_code_auth.py +76 -0
  100. kairo/migrations/versions/009_add_status_page.py +65 -0
  101. kairo/tools/extract_claude_data.py +465 -0
  102. kairo/tools/filter_claude_data.py +303 -0
  103. kairo/tools/generate_curated_data.py +157 -0
  104. kairo/tools/mix_training_data.py +295 -0
  105. kairo_code/__init__.py +3 -0
  106. kairo_code/agents/__init__.py +25 -0
  107. kairo_code/agents/architect.py +98 -0
  108. kairo_code/agents/audit.py +100 -0
  109. kairo_code/agents/base.py +463 -0
  110. kairo_code/agents/coder.py +155 -0
  111. kairo_code/agents/database.py +77 -0
  112. kairo_code/agents/docs.py +88 -0
  113. kairo_code/agents/explorer.py +62 -0
  114. kairo_code/agents/guardian.py +80 -0
  115. kairo_code/agents/planner.py +66 -0
  116. kairo_code/agents/reviewer.py +91 -0
  117. kairo_code/agents/security.py +94 -0
  118. kairo_code/agents/terraform.py +88 -0
  119. kairo_code/agents/testing.py +97 -0
  120. kairo_code/agents/uiux.py +88 -0
  121. kairo_code/auth.py +232 -0
  122. kairo_code/config.py +172 -0
  123. kairo_code/conversation.py +173 -0
  124. kairo_code/heartbeat.py +63 -0
  125. kairo_code/llm.py +291 -0
  126. kairo_code/logging_config.py +156 -0
  127. kairo_code/main.py +818 -0
  128. kairo_code/router.py +217 -0
  129. kairo_code/sandbox.py +248 -0
  130. kairo_code/settings.py +183 -0
  131. kairo_code/tools/__init__.py +51 -0
  132. kairo_code/tools/analysis.py +509 -0
  133. kairo_code/tools/base.py +417 -0
  134. kairo_code/tools/code.py +58 -0
  135. kairo_code/tools/definitions.py +617 -0
  136. kairo_code/tools/files.py +315 -0
  137. kairo_code/tools/review.py +390 -0
  138. kairo_code/tools/search.py +185 -0
  139. kairo_code/ui.py +418 -0
  140. kairo_code-0.1.0.dist-info/METADATA +13 -0
  141. kairo_code-0.1.0.dist-info/RECORD +144 -0
  142. kairo_code-0.1.0.dist-info/WHEEL +5 -0
  143. kairo_code-0.1.0.dist-info/entry_points.txt +2 -0
  144. kairo_code-0.1.0.dist-info/top_level.txt +4 -0
@@ -0,0 +1,264 @@
1
+ import logging
2
+
3
+ from sqlalchemy import select, func, delete
4
+ from sqlalchemy.ext.asyncio import AsyncSession
5
+ from sqlalchemy.orm import selectinload
6
+
7
+ from backend.models.conversation import Conversation, Message
8
+
9
logger = logging.getLogger(__name__)


class ConversationService:
    """Persistence and search operations for conversations and their messages.

    When ``user_id`` is given, reads are restricted to conversations owned by
    that user; with ``user_id=None`` queries are not scoped to any user.
    Mutating methods commit the session before returning.
    """

    # Escape character used in LIKE/ILIKE patterns built from user search input.
    _LIKE_ESCAPE = "/"

    def __init__(self, db: AsyncSession, user_id: str | None = None):
        self.db = db
        self.user_id = user_id

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @classmethod
    def _like_pattern(cls, query: str) -> str:
        """Return a ``%query%`` pattern that matches *query* literally.

        ``%`` and ``_`` are LIKE wildcards. They, and the escape character
        itself, are escaped so that a search for ``50%`` does not match every
        text containing ``50``. The pattern must be used with
        ``escape=cls._LIKE_ESCAPE``.
        """
        esc = cls._LIKE_ESCAPE
        escaped = (
            query.replace(esc, esc * 2)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )
        return f"%{escaped}%"

    def _summary_stmt(self):
        """Select per-conversation summary columns plus a message count, newest first.

        The statement is owner-scoped when ``self.user_id`` is set. Callers add
        their own filters and pagination.
        """
        stmt = (
            select(
                Conversation.id,
                Conversation.title,
                Conversation.model,
                Conversation.updated_at,
                Conversation.project_id,
                func.count(Message.id).label("message_count"),
            )
            .outerjoin(Message)
            .group_by(Conversation.id)
            .order_by(Conversation.updated_at.desc())
        )
        if self.user_id:
            stmt = stmt.where(Conversation.user_id == self.user_id)
        return stmt

    @staticmethod
    def _summary_to_dict(row) -> dict:
        """Convert a row produced by ``_summary_stmt`` into a summary dict."""
        return {
            "id": row.id,
            "title": row.title,
            "model": row.model,
            "updated_at": row.updated_at,
            "message_count": row.message_count,
            "project_id": row.project_id,
        }

    def _conversation_stmt(self, conversation_id: str):
        """Select one conversation, owner-scoped, with messages and project eagerly loaded."""
        stmt = (
            select(Conversation)
            .options(selectinload(Conversation.messages), selectinload(Conversation.project))
            .where(Conversation.id == conversation_id)
        )
        if self.user_id:
            stmt = stmt.where(Conversation.user_id == self.user_id)
        return stmt

    # ------------------------------------------------------------------
    # Conversations
    # ------------------------------------------------------------------

    async def create(self, model: str = "nyx", title: str = "New Conversation", project_id: str | None = None) -> Conversation:
        """Create, commit and return a new conversation owned by ``self.user_id``."""
        conv = Conversation(model=model, title=title, user_id=self.user_id, project_id=project_id)
        self.db.add(conv)
        await self.db.commit()
        await self.db.refresh(conv)
        return conv

    async def list_all(self, limit: int = 50, offset: int = 0, project_id: str | None = None) -> tuple[list[dict], int]:
        """Return one page of conversation summaries and the total matching count.

        Args:
            limit: Maximum number of summaries to return.
            offset: Number of summaries to skip.
            project_id: If given, only conversations in this project are listed
                and counted.

        Returns:
            ``(items, total)``. ``total`` counts all matching conversations,
            not just this page.
        """
        # The count must apply exactly the same filters as the data query;
        # otherwise `total` disagrees with the paginated items.
        base = select(Conversation)
        if self.user_id:
            base = base.where(Conversation.user_id == self.user_id)
        if project_id is not None:
            base = base.where(Conversation.project_id == project_id)
        count_stmt = select(func.count()).select_from(base.subquery())
        total = (await self.db.execute(count_stmt)).scalar() or 0

        stmt = self._summary_stmt().limit(limit).offset(offset)
        if project_id is not None:
            stmt = stmt.where(Conversation.project_id == project_id)

        rows = (await self.db.execute(stmt)).all()
        return [self._summary_to_dict(r) for r in rows], total

    async def search(self, query: str, limit: int = 50, offset: int = 0) -> tuple[list[dict], int]:
        """Search conversations by title or message content (case-insensitive, literal).

        Returns ``(items, total)`` in the same summary format as ``list_all``.
        """
        pattern = self._like_pattern(query)
        esc = self._LIKE_ESCAPE

        # IDs of conversations whose title or any message matches.
        matching_ids_stmt = (
            select(Conversation.id)
            .outerjoin(Message)
            .where(
                Conversation.title.ilike(pattern, escape=esc)
                | Message.content.ilike(pattern, escape=esc)
            )
            .group_by(Conversation.id)
        )
        if self.user_id:
            matching_ids_stmt = matching_ids_stmt.where(Conversation.user_id == self.user_id)

        count_stmt = select(func.count()).select_from(matching_ids_stmt.subquery())
        total = (await self.db.execute(count_stmt)).scalar() or 0

        stmt = (
            self._summary_stmt()
            .where(Conversation.id.in_(matching_ids_stmt))
            .limit(limit)
            .offset(offset)
        )
        rows = (await self.db.execute(stmt)).all()
        return [self._summary_to_dict(r) for r in rows], total

    async def get(self, conversation_id: str) -> Conversation | None:
        """Return the conversation with messages and project loaded, or None if not found or not owned."""
        result = await self.db.execute(self._conversation_stmt(conversation_id))
        return result.scalar_one_or_none()

    async def get_fresh(self, conversation_id: str) -> Conversation | None:
        """Like ``get``, but overwrite any stale copy already in the session's identity map."""
        stmt = self._conversation_stmt(conversation_id).execution_options(populate_existing=True)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def delete(self, conversation_id: str) -> bool:
        """Delete a conversation. Return False if it is not found or not owned."""
        conv = await self.get(conversation_id)
        if not conv:
            return False
        await self.db.delete(conv)
        await self.db.commit()
        return True

    async def rename(self, conversation_id: str, title: str) -> Conversation | None:
        """Set a conversation's title. Return the updated row, or None if not found or not owned."""
        conv = await self.get(conversation_id)
        if not conv:
            return None
        conv.title = title
        await self.db.commit()
        await self.db.refresh(conv)
        return conv

    # ------------------------------------------------------------------
    # Messages
    # ------------------------------------------------------------------

    async def add_message(self, conversation_id: str, role: str, content: str, image_url: str | None = None) -> Message:
        """Append a message to a conversation and commit.

        NOTE(review): this does not verify that ``self.user_id`` owns
        ``conversation_id``; callers are expected to have checked via ``get``.
        """
        msg = Message(conversation_id=conversation_id, role=role, content=content, image_url=image_url)
        self.db.add(msg)
        await self.db.commit()
        await self.db.refresh(msg)
        return msg

    async def get_message(self, message_id: str) -> Message | None:
        """Return a message by id, or None.

        When ``self.user_id`` is set, messages in conversations owned by other
        users are not returned (consistent with ``get``).
        """
        stmt = select(Message).where(Message.id == message_id)
        if self.user_id:
            stmt = stmt.join(Conversation, Message.conversation_id == Conversation.id).where(
                Conversation.user_id == self.user_id
            )
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def edit_message(self, message_id: str, new_content: str) -> Message | None:
        """Replace a message's content and delete all later messages in its conversation.

        Returns the edited message, or None if the message or its conversation
        is not found or not owned.
        """
        msg = await self.get_message(message_id)
        if not msg:
            return None

        # Verify the caller owns the conversation.
        conv = await self.get(msg.conversation_id)
        if not conv:
            return None

        # Strictly-later timestamps only. Messages sharing msg's exact
        # created_at are kept.
        await self.db.execute(
            delete(Message).where(
                Message.conversation_id == msg.conversation_id,
                Message.created_at > msg.created_at,
            )
        )

        msg.content = new_content
        await self.db.commit()
        await self.db.refresh(msg)
        return msg

    async def search_messages(self, query: str, limit: int = 20, offset: int = 0) -> list[dict]:
        """Search message content literally (case-insensitive) and return snippets with conversation context."""
        pattern = self._like_pattern(query)
        stmt = (
            select(
                Message.id,
                Message.role,
                Message.content,
                Message.created_at,
                Conversation.id.label("conversation_id"),
                Conversation.title.label("conversation_title"),
            )
            .join(Conversation, Message.conversation_id == Conversation.id)
            .where(Message.content.ilike(pattern, escape=self._LIKE_ESCAPE))
            .order_by(Message.created_at.desc())
            .limit(limit)
            .offset(offset)
        )
        if self.user_id:
            stmt = stmt.where(Conversation.user_id == self.user_id)

        rows = (await self.db.execute(stmt)).all()
        needle = query.lower()
        items = []
        for r in rows:
            # Snippet: the match plus ~50 chars of context on each side.
            content = r.content
            idx = content.lower().find(needle)
            if idx >= 0:
                start = max(0, idx - 50)
                end = min(len(content), idx + len(query) + 50)
                snippet = ("..." if start > 0 else "") + content[start:end] + ("..." if end < len(content) else "")
            else:
                # DB case folding can differ from str.lower(); fall back to a prefix.
                snippet = content[:100] + ("..." if len(content) > 100 else "")
            items.append({
                "message_id": r.id,
                "role": r.role,
                "snippet": snippet,
                "created_at": r.created_at,
                "conversation_id": r.conversation_id,
                "conversation_title": r.conversation_title,
            })
        return items

    async def delete_last_assistant_message(self, conversation_id: str) -> bool:
        """Delete the most recent assistant message (used for regeneration).

        Returns False if the conversation is not found, not owned, empty, or
        has no assistant message.
        """
        conv = await self.get(conversation_id)
        if not conv or not conv.messages:
            return False

        messages = sorted(conv.messages, key=lambda m: m.created_at)
        last_assistant = None
        for m in reversed(messages):
            if m.role == "assistant":
                last_assistant = m
                break

        if not last_assistant:
            return False

        await self.db.delete(last_assistant)
        await self.db.commit()
        return True
@@ -0,0 +1,193 @@
1
+ import hashlib
2
+ import logging
3
+ import random
4
+ import string
5
+ import uuid
6
+ from datetime import UTC, datetime, timedelta
7
+
8
+ from sqlalchemy import delete, select
9
+ from sqlalchemy.ext.asyncio import AsyncSession
10
+
11
+ from backend.models.api_key import ApiKey
12
+ from backend.models.device_code import DeviceCode
13
+ from backend.models.user import User
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ def _generate_user_code() -> str:
19
+ """Generate a human-readable code like ABCD-1234."""
20
+ letters = "".join(random.choices(string.ascii_uppercase, k=4))
21
+ digits = "".join(random.choices(string.digits, k=4))
22
+ return f"{letters}-{digits}"
23
+
24
+
25
+ def _generate_device_code() -> str:
26
+ """Generate a 64-character hex device code."""
27
+ return uuid.uuid4().hex + uuid.uuid4().hex
28
+
29
+
30
class DeviceAuthService:
    """Device-code login flow for the Kairo Code CLI.

    The CLI obtains a code pair from ``create_device_code`` and polls
    ``poll_token``. A signed-in user approves the user code via ``approve``,
    which mints a CLI API key. The next poll hands that key out exactly once.
    """

    # Minutes a newly issued device/user code stays valid.
    EXPIRY_MINUTES = 10
    # Poll interval stored on the DeviceCode for the client (presumably seconds).
    POLL_INTERVAL = 5

    def __init__(self, db: AsyncSession):
        self.db = db

    async def create_device_code(self, client_name: str) -> DeviceCode:
        """Create, commit and return a new pending device code for the CLI auth flow."""
        user_code = await self._unique_user_code()

        device = DeviceCode(
            device_code=_generate_device_code(),
            user_code=user_code,
            client_name=client_name,
            expires_at=datetime.now(UTC) + timedelta(minutes=self.EXPIRY_MINUTES),
            interval=self.POLL_INTERVAL,
        )
        self.db.add(device)
        await self.db.commit()
        await self.db.refresh(device)
        logger.info("Device code created: user_code=%s client=%s", user_code, client_name)
        return device

    async def approve(self, user_code: str, user: User) -> tuple[str, DeviceCode] | None:
        """Approve a pending device code on behalf of *user*.

        *user_code* is matched case-insensitively, and hyphens and whitespace
        are ignored, since a person types it.

        Returns ``(raw_api_key, device_code)``, or None if no pending,
        unexpired code matches.
        """
        user_code = self._normalize_user_code(user_code)
        device = await self._find_pending_device(user_code)
        if not device:
            return None

        raw_key = self._create_raw_cli_key()
        api_key = self._build_api_key(raw_key, user.id, device.client_name)
        self.db.add(api_key)
        await self.db.flush()  # assigns api_key.id before it is linked to the device

        self._mark_approved(device, user.id, api_key.id, raw_key)
        await self.db.commit()

        logger.info("Device code approved: user_code=%s user=%s", user_code, user.id)
        return raw_key, device

    async def poll_token(
        self, device_code: str
    ) -> tuple[str, dict | None]:
        """Poll for token status.

        Returns ``(status, user_info_or_none)``.

        status: authorization_pending | expired_token | access_denied | approved

        user_info: ``{"raw_key": ..., "plan": ..., "email": ...}`` when approved.

        The key is returned only once. Polling an approved code whose key was
        already handed out yields ``expired_token``.
        """
        device = await self._find_device_by_code(device_code)
        if not device:
            return "expired_token", None

        if self._is_expired_pending(device):
            device.status = "expired"
            await self.db.commit()
            return "expired_token", None

        if device.status == "denied":
            return "access_denied", None
        if device.status == "expired":
            return "expired_token", None
        if device.status == "pending":
            return "authorization_pending", None

        if device.status == "approved":
            if device.cli_token:
                return await self._consume_approved_token(device)
            # Token already consumed by an earlier poll. Tell the client to
            # restart instead of leaving it "pending" forever.
            return "expired_token", None

        return "authorization_pending", None

    async def cleanup_expired(self) -> int:
        """Delete device codes created more than 1 hour ago; return how many were removed."""
        cutoff = datetime.now(UTC) - timedelta(hours=1)
        stmt = delete(DeviceCode).where(DeviceCode.created_at < cutoff)
        result = await self.db.execute(stmt)
        await self.db.commit()
        deleted = result.rowcount or 0
        if deleted:
            logger.info("Cleaned up %d expired device codes", deleted)
        return deleted

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _normalize_user_code(user_code: str) -> str:
        """Canonicalize typed input to the stored ``XXXX-9999`` uppercase form.

        Case, whitespace and hyphens are ignored (cf. RFC 8628 §6.1). Input that
        does not reduce to 8 alphanumerics is just stripped and upper-cased.
        """
        compact = "".join(ch for ch in user_code.upper() if ch.isalnum())
        if len(compact) == 8:
            return f"{compact[:4]}-{compact[4:]}"
        return user_code.strip().upper()

    async def _unique_user_code(self) -> str:
        """Generate a user code that does not collide with any code still marked pending."""
        for _ in range(5):
            user_code = _generate_user_code()
            stmt = select(DeviceCode).where(
                DeviceCode.user_code == user_code,
                DeviceCode.status == "pending",
            )
            result = await self.db.execute(stmt)
            if not result.scalar_one_or_none():
                return user_code
        # Extremely unlikely fallback (26^4 * 10^4 possible codes).
        return _generate_user_code()

    async def _find_pending_device(self, user_code: str) -> DeviceCode | None:
        stmt = select(DeviceCode).where(
            DeviceCode.user_code == user_code,
            DeviceCode.status == "pending",
            DeviceCode.expires_at > datetime.now(UTC),
        )
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def _find_device_by_code(self, device_code: str) -> DeviceCode | None:
        stmt = select(DeviceCode).where(DeviceCode.device_code == device_code)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    @staticmethod
    def _create_raw_cli_key() -> str:
        return f"sk-kairo-cli-{uuid.uuid4().hex}"

    @staticmethod
    def _build_api_key(raw_key: str, user_id: str, client_name: str | None) -> ApiKey:
        """Build (not persist) an ApiKey row storing only a prefix and SHA-256 hash of *raw_key*."""
        prefix = raw_key[:24]
        key_hash = hashlib.sha256(raw_key.encode()).hexdigest()
        return ApiKey(
            user_id=user_id,
            name=f"Kairo Code CLI ({client_name or 'cli'})",
            key_prefix=prefix,
            key_hash=key_hash,
            key_type="cli",
        )

    @staticmethod
    def _mark_approved(
        device: DeviceCode, user_id: str, api_key_id: str, raw_key: str
    ) -> None:
        device.status = "approved"
        device.user_id = user_id
        device.cli_api_key_id = api_key_id
        # Raw key is held only until the CLI's next poll consumes it.
        device.cli_token = raw_key
        device.approved_at = datetime.now(UTC)

    @staticmethod
    def _is_expired_pending(device: DeviceCode) -> bool:
        return device.expires_at < datetime.now(UTC) and device.status == "pending"

    async def _consume_approved_token(
        self, device: DeviceCode
    ) -> tuple[str, dict | None]:
        """Return the token and clear it (single-use retrieval)."""
        raw_key = device.cli_token
        device.cli_token = None  # single-use: clear after first poll
        await self.db.commit()

        user = await self.db.get(User, device.user_id)
        if not user:
            return "access_denied", None

        return "approved", {
            "raw_key": raw_key,
            "plan": user.plan,
            "email": user.email,
        }
@@ -0,0 +1,55 @@
1
+ import logging
2
+ import smtplib
3
+ from email.mime.multipart import MIMEMultipart
4
+ from email.mime.text import MIMEText
5
+
6
+ from backend.config import settings
7
+
8
logger = logging.getLogger(__name__)


class EmailService:
    """Send Kairo's transactional emails (address verification, password reset) over SMTP."""

    # Seconds to wait on any single SMTP socket operation. Without a timeout,
    # a stalled mail server could block the caller indefinitely.
    SMTP_TIMEOUT_SECONDS = 30

    def send_verification_email(self, to: str, token: str) -> None:
        """Email *to* a link that verifies their address with *token*."""
        from html import escape

        link = escape(self._action_link("/verify-email", token))
        subject = "Verify your Kairo email"
        body = (
            "<h2>Welcome to Kairo</h2>"
            "<p>Click the link below to verify your email address:</p>"
            f'<p><a href="{link}">Verify Email</a></p>'
            f"<p>Or copy this URL: {link}</p>"
            "<p>If you didn't create a Kairo account, you can ignore this email.</p>"
        )
        self._send(to, subject, body)

    def send_password_reset_email(self, to: str, token: str) -> None:
        """Email *to* a password-reset link carrying *token*."""
        from html import escape

        link = escape(self._action_link("/reset-password", token))
        subject = "Reset your Kairo password"
        body = (
            "<h2>Password Reset</h2>"
            "<p>Click the link below to reset your password. This link expires in 1 hour.</p>"
            f'<p><a href="{link}">Reset Password</a></p>'
            f"<p>Or copy this URL: {link}</p>"
            "<p>If you didn't request this, you can ignore this email.</p>"
        )
        self._send(to, subject, body)

    @staticmethod
    def _action_link(path: str, token: str) -> str:
        """Return ``APP_BASE_URL + path`` with *token* percent-encoded as the ``token`` query parameter."""
        from urllib.parse import quote

        base = settings.APP_BASE_URL.rstrip("/")
        return f"{base}{path}?token={quote(token, safe='')}"

    def _send(self, to: str, subject: str, html_body: str) -> None:
        """Send an HTML email via SMTP with STARTTLS.

        If SMTP_USERNAME is unset, this logs a warning and does nothing. Any
        SMTP/TLS failure is logged and re-raised.
        """
        if not settings.SMTP_USERNAME:
            logger.warning("SMTP not configured, skipping email to %s", to)
            return

        import ssl
        from email.utils import formataddr

        msg = MIMEMultipart("alternative")
        msg["Subject"] = subject
        # formataddr quotes/encodes display names containing special or non-ASCII characters.
        msg["From"] = formataddr((settings.SMTP_FROM_NAME, settings.SMTP_FROM_EMAIL))
        msg["To"] = to
        msg.attach(MIMEText(html_body, "html"))

        try:
            with smtplib.SMTP(
                settings.SMTP_HOST, settings.SMTP_PORT, timeout=self.SMTP_TIMEOUT_SECONDS
            ) as server:
                # starttls() without a context does NOT verify the server's
                # certificate. Pass a default context so credentials are not
                # sent to an impostor.
                server.starttls(context=ssl.create_default_context())
                server.login(settings.SMTP_USERNAME, settings.SMTP_PASSWORD)
                server.sendmail(settings.SMTP_FROM_EMAIL, to, msg.as_string())
            logger.info("Email sent to %s: %s", to, subject)
        except Exception:
            logger.exception("Failed to send email to %s", to)
            raise
@@ -0,0 +1,181 @@
1
+ import logging
2
+ import uuid
3
+ from datetime import datetime, timezone
4
+
5
+ import boto3
6
+ import httpx
7
+ from sqlalchemy import select, func
8
+ from sqlalchemy.ext.asyncio import AsyncSession
9
+
10
+ from backend.config import settings
11
+ from backend.models.image_generation import ImageGeneration
12
+ from backend.models.user import PlanType, User
13
+ from backend.services.conversation_service import ConversationService
14
+
15
logger = logging.getLogger(__name__)

# Image-generation quotas per plan, keyed by the PlanType value string.
# Each entry maps "daily" and "monthly" to the maximum number of generations.
IMAGE_LIMITS: dict[str, dict[str, int]] = {
    PlanType.PRO.value: dict(
        daily=settings.PRO_DAILY_IMAGE_LIMIT,
        monthly=settings.PRO_MONTHLY_IMAGE_LIMIT,
    ),
    PlanType.MAX.value: dict(
        daily=settings.MAX_DAILY_IMAGE_LIMIT,
        monthly=settings.MAX_MONTHLY_IMAGE_LIMIT,
    ),
}
27
+
28
+
29
class ImageService:
    """Generate images via the FLUX service and store them in a conversation.

    Each result is uploaded to S3 and appended to a conversation as a
    user/assistant message pair. It is also recorded in ``ImageGeneration``,
    which ``check_access`` counts against plan quotas.
    """

    def __init__(self, db: AsyncSession, user_id: str):
        self.db = db
        self.user_id = user_id
        self.conv_service = ConversationService(db, user_id=user_id)

    async def check_access(self) -> tuple[bool, str]:
        """Check plan tier and image usage limits.

        Returns ``(allowed, reason)``. ``reason`` is empty when allowed.
        """
        user = await self.db.get(User, self.user_id)
        if not user:
            return False, "User not found."

        if user.plan not in (PlanType.PRO.value, PlanType.MAX.value):
            return False, "Image generation requires a Pro or Max plan."

        limits = IMAGE_LIMITS.get(user.plan)
        if not limits:
            return False, "No image limits configured for your plan."

        now = datetime.now(timezone.utc)
        start_of_day = now.replace(hour=0, minute=0, second=0, microsecond=0)
        start_of_month = start_of_day.replace(day=1)

        # Daily check first, so an exhausted daily quota needs only one count query.
        if await self._count_generations_since(start_of_day) >= limits["daily"]:
            return False, "Daily image generation limit reached."
        if await self._count_generations_since(start_of_month) >= limits["monthly"]:
            return False, "Monthly image generation limit reached."

        return True, ""

    async def _count_generations_since(self, since: datetime) -> int:
        """Count this user's ImageGeneration rows created at or after *since*."""
        stmt = (
            select(func.count())
            .select_from(ImageGeneration)
            .where(ImageGeneration.user_id == self.user_id)
            .where(ImageGeneration.created_at >= since)
        )
        return (await self.db.execute(stmt)).scalar() or 0

    async def _save_result(
        self,
        image_bytes: bytes,
        prompt: str,
        conversation_id: str | None,
        width: int = 1024,
        height: int = 1024,
    ) -> dict:
        """Upload image to S3, save it to a conversation, and record usage.

        If *conversation_id* is missing, unknown, or not owned by the user, a
        new "flux-dev" conversation is created.

        Returns a dict with ``image_url``, ``message_id``, ``conversation_id``
        and ``prompt``.
        """
        import asyncio

        image_id = str(uuid.uuid4())
        s3_key = f"generated/{image_id}.png"
        s3 = boto3.client("s3", region_name=settings.S3_IMAGES_REGION)
        # boto3 is synchronous. Run the upload in a worker thread so it does
        # not block the event loop for every other in-flight request.
        await asyncio.to_thread(
            s3.put_object,
            Bucket=settings.S3_IMAGES_BUCKET,
            Key=s3_key,
            Body=image_bytes,
            ContentType="image/png",
        )
        image_url = f"https://{settings.S3_IMAGES_BUCKET}.s3.{settings.S3_IMAGES_REGION}.amazonaws.com/{s3_key}"

        conv = None
        if conversation_id:
            conv = await self.conv_service.get(conversation_id)
        if not conv:
            conv = await self.conv_service.create(model="flux-dev")

        # Capture plain values up front. ConversationService.add_message
        # commits, and depending on the session's expire_on_commit setting,
        # later attribute access on `conv` could need a reload.
        conv_id = conv.id
        needs_title = conv.title == "New Conversation"

        # User message: the prompt.
        await self.conv_service.add_message(conv_id, "user", prompt)

        # Assistant message: markdown image plus the raw image_url.
        msg = await self.conv_service.add_message(
            conv_id, "assistant", f"![Generated image]({image_url})",
            image_url=image_url,
        )
        msg_id = msg.id

        record = ImageGeneration(
            user_id=self.user_id,
            conversation_id=conv_id,
            prompt=prompt,
            image_url=image_url,
            width=width,
            height=height,
        )
        self.db.add(record)
        await self.db.commit()

        if needs_title:
            title = prompt[:47] + "..." if len(prompt) > 50 else prompt
            await self.conv_service.rename(conv_id, title)

        return {
            "image_url": image_url,
            "message_id": msg_id,
            "conversation_id": conv_id,
            "prompt": prompt,
        }

    @staticmethod
    def _flux_headers() -> dict[str, str]:
        """HTTP headers for the FLUX service; adds a bearer token when FLUX_API_KEY is set."""
        headers: dict[str, str] = {}
        if settings.FLUX_API_KEY:
            headers["Authorization"] = f"Bearer {settings.FLUX_API_KEY}"
        return headers

    async def _flux_post(self, path: str, **request_kwargs) -> bytes:
        """POST to ``FLUX_BASE_URL + path`` and return the response body.

        Raises ``httpx.HTTPStatusError`` on a non-2xx response.
        """
        base_url = settings.FLUX_BASE_URL.rstrip("/")
        async with httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=10.0)) as client:
            resp = await client.post(
                f"{base_url}{path}",
                headers=self._flux_headers(),
                **request_kwargs,
            )
            resp.raise_for_status()
            return resp.content

    async def generate(
        self,
        prompt: str,
        conversation_id: str | None,
        width: int = 1024,
        height: int = 1024,
    ) -> dict:
        """Call FLUX text2img service, upload to S3, save to conversation."""
        image_bytes = await self._flux_post(
            "/generate",
            json={"prompt": prompt, "width": width, "height": height},
        )
        return await self._save_result(image_bytes, prompt, conversation_id, width, height)

    async def generate_img2img(
        self,
        image_bytes: bytes,
        prompt: str,
        conversation_id: str | None,
        strength: float = 0.75,
    ) -> dict:
        """Call FLUX img2img service, upload to S3, save to conversation.

        NOTE(review): the output size is not known here, so the usage record
        stores _save_result's default 1024x1024. Confirm whether that matches
        the service's output.
        """
        result_bytes = await self._flux_post(
            "/img2img",
            files={"image": ("input.png", image_bytes, "image/png")},
            data={"prompt": prompt, "strength": str(strength)},
        )
        return await self._save_result(result_bytes, prompt, conversation_id)