kairo-code 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. image-service/main.py +178 -0
  2. infra/chat/app/main.py +84 -0
  3. kairo/backend/__init__.py +0 -0
  4. kairo/backend/api/__init__.py +0 -0
  5. kairo/backend/api/admin/__init__.py +23 -0
  6. kairo/backend/api/admin/audit.py +54 -0
  7. kairo/backend/api/admin/content.py +142 -0
  8. kairo/backend/api/admin/incidents.py +148 -0
  9. kairo/backend/api/admin/stats.py +125 -0
  10. kairo/backend/api/admin/system.py +87 -0
  11. kairo/backend/api/admin/users.py +279 -0
  12. kairo/backend/api/agents.py +94 -0
  13. kairo/backend/api/api_keys.py +85 -0
  14. kairo/backend/api/auth.py +116 -0
  15. kairo/backend/api/billing.py +41 -0
  16. kairo/backend/api/chat.py +72 -0
  17. kairo/backend/api/conversations.py +125 -0
  18. kairo/backend/api/device_auth.py +100 -0
  19. kairo/backend/api/files.py +83 -0
  20. kairo/backend/api/health.py +36 -0
  21. kairo/backend/api/images.py +80 -0
  22. kairo/backend/api/openai_compat.py +225 -0
  23. kairo/backend/api/projects.py +102 -0
  24. kairo/backend/api/usage.py +32 -0
  25. kairo/backend/api/webhooks.py +79 -0
  26. kairo/backend/app.py +297 -0
  27. kairo/backend/config.py +179 -0
  28. kairo/backend/core/__init__.py +0 -0
  29. kairo/backend/core/admin_auth.py +24 -0
  30. kairo/backend/core/api_key_auth.py +55 -0
  31. kairo/backend/core/database.py +28 -0
  32. kairo/backend/core/dependencies.py +70 -0
  33. kairo/backend/core/logging.py +23 -0
  34. kairo/backend/core/rate_limit.py +73 -0
  35. kairo/backend/core/security.py +29 -0
  36. kairo/backend/models/__init__.py +19 -0
  37. kairo/backend/models/agent.py +30 -0
  38. kairo/backend/models/api_key.py +25 -0
  39. kairo/backend/models/api_usage.py +29 -0
  40. kairo/backend/models/audit_log.py +26 -0
  41. kairo/backend/models/conversation.py +48 -0
  42. kairo/backend/models/device_code.py +30 -0
  43. kairo/backend/models/feature_flag.py +21 -0
  44. kairo/backend/models/image_generation.py +24 -0
  45. kairo/backend/models/incident.py +28 -0
  46. kairo/backend/models/project.py +28 -0
  47. kairo/backend/models/uptime_record.py +24 -0
  48. kairo/backend/models/usage.py +24 -0
  49. kairo/backend/models/user.py +49 -0
  50. kairo/backend/schemas/__init__.py +0 -0
  51. kairo/backend/schemas/admin/__init__.py +0 -0
  52. kairo/backend/schemas/admin/audit.py +28 -0
  53. kairo/backend/schemas/admin/content.py +53 -0
  54. kairo/backend/schemas/admin/stats.py +77 -0
  55. kairo/backend/schemas/admin/system.py +44 -0
  56. kairo/backend/schemas/admin/users.py +48 -0
  57. kairo/backend/schemas/agent.py +42 -0
  58. kairo/backend/schemas/api_key.py +30 -0
  59. kairo/backend/schemas/auth.py +57 -0
  60. kairo/backend/schemas/chat.py +26 -0
  61. kairo/backend/schemas/conversation.py +39 -0
  62. kairo/backend/schemas/device_auth.py +40 -0
  63. kairo/backend/schemas/image.py +15 -0
  64. kairo/backend/schemas/openai_compat.py +76 -0
  65. kairo/backend/schemas/project.py +21 -0
  66. kairo/backend/schemas/status.py +81 -0
  67. kairo/backend/schemas/usage.py +15 -0
  68. kairo/backend/services/__init__.py +0 -0
  69. kairo/backend/services/admin/__init__.py +0 -0
  70. kairo/backend/services/admin/audit_service.py +78 -0
  71. kairo/backend/services/admin/content_service.py +119 -0
  72. kairo/backend/services/admin/incident_service.py +94 -0
  73. kairo/backend/services/admin/stats_service.py +281 -0
  74. kairo/backend/services/admin/system_service.py +126 -0
  75. kairo/backend/services/admin/user_service.py +157 -0
  76. kairo/backend/services/agent_service.py +107 -0
  77. kairo/backend/services/api_key_service.py +66 -0
  78. kairo/backend/services/api_usage_service.py +126 -0
  79. kairo/backend/services/auth_service.py +101 -0
  80. kairo/backend/services/chat_service.py +501 -0
  81. kairo/backend/services/conversation_service.py +264 -0
  82. kairo/backend/services/device_auth_service.py +193 -0
  83. kairo/backend/services/email_service.py +55 -0
  84. kairo/backend/services/image_service.py +181 -0
  85. kairo/backend/services/llm_service.py +186 -0
  86. kairo/backend/services/project_service.py +109 -0
  87. kairo/backend/services/status_service.py +167 -0
  88. kairo/backend/services/stripe_service.py +78 -0
  89. kairo/backend/services/usage_service.py +150 -0
  90. kairo/backend/services/web_search_service.py +96 -0
  91. kairo/migrations/env.py +60 -0
  92. kairo/migrations/versions/001_initial.py +55 -0
  93. kairo/migrations/versions/002_usage_tracking_and_indexes.py +66 -0
  94. kairo/migrations/versions/003_username_to_email.py +21 -0
  95. kairo/migrations/versions/004_add_plans_and_verification.py +67 -0
  96. kairo/migrations/versions/005_add_projects.py +52 -0
  97. kairo/migrations/versions/006_add_image_generation.py +63 -0
  98. kairo/migrations/versions/007_add_admin_portal.py +107 -0
  99. kairo/migrations/versions/008_add_device_code_auth.py +76 -0
  100. kairo/migrations/versions/009_add_status_page.py +65 -0
  101. kairo/tools/extract_claude_data.py +465 -0
  102. kairo/tools/filter_claude_data.py +303 -0
  103. kairo/tools/generate_curated_data.py +157 -0
  104. kairo/tools/mix_training_data.py +295 -0
  105. kairo_code/__init__.py +3 -0
  106. kairo_code/agents/__init__.py +25 -0
  107. kairo_code/agents/architect.py +98 -0
  108. kairo_code/agents/audit.py +100 -0
  109. kairo_code/agents/base.py +463 -0
  110. kairo_code/agents/coder.py +155 -0
  111. kairo_code/agents/database.py +77 -0
  112. kairo_code/agents/docs.py +88 -0
  113. kairo_code/agents/explorer.py +62 -0
  114. kairo_code/agents/guardian.py +80 -0
  115. kairo_code/agents/planner.py +66 -0
  116. kairo_code/agents/reviewer.py +91 -0
  117. kairo_code/agents/security.py +94 -0
  118. kairo_code/agents/terraform.py +88 -0
  119. kairo_code/agents/testing.py +97 -0
  120. kairo_code/agents/uiux.py +88 -0
  121. kairo_code/auth.py +232 -0
  122. kairo_code/config.py +172 -0
  123. kairo_code/conversation.py +173 -0
  124. kairo_code/heartbeat.py +63 -0
  125. kairo_code/llm.py +291 -0
  126. kairo_code/logging_config.py +156 -0
  127. kairo_code/main.py +818 -0
  128. kairo_code/router.py +217 -0
  129. kairo_code/sandbox.py +248 -0
  130. kairo_code/settings.py +183 -0
  131. kairo_code/tools/__init__.py +51 -0
  132. kairo_code/tools/analysis.py +509 -0
  133. kairo_code/tools/base.py +417 -0
  134. kairo_code/tools/code.py +58 -0
  135. kairo_code/tools/definitions.py +617 -0
  136. kairo_code/tools/files.py +315 -0
  137. kairo_code/tools/review.py +390 -0
  138. kairo_code/tools/search.py +185 -0
  139. kairo_code/ui.py +418 -0
  140. kairo_code-0.1.0.dist-info/METADATA +13 -0
  141. kairo_code-0.1.0.dist-info/RECORD +144 -0
  142. kairo_code-0.1.0.dist-info/WHEEL +5 -0
  143. kairo_code-0.1.0.dist-info/entry_points.txt +2 -0
  144. kairo_code-0.1.0.dist-info/top_level.txt +4 -0
@@ -0,0 +1,186 @@
1
+ import json
2
+ import logging
3
+ from collections.abc import AsyncGenerator
4
+
5
+ import httpx
6
+
7
+ from backend.config import settings
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class LLMService:
13
+ def __init__(self):
14
+ self.base_url = settings.VLLM_BASE_URL.rstrip("/")
15
+ self.lite_base_url = settings.VLLM_LITE_BASE_URL.rstrip("/") if settings.VLLM_LITE_BASE_URL else ""
16
+ self.model_map = settings.MODEL_MAP
17
+ headers = {}
18
+ if settings.VLLM_API_KEY:
19
+ headers["Authorization"] = f"Bearer {settings.VLLM_API_KEY}"
20
+ self._client = httpx.AsyncClient(
21
+ timeout=httpx.Timeout(120.0, connect=10.0),
22
+ headers=headers,
23
+ )
24
+ self._lite_client: httpx.AsyncClient | None = None
25
+ if self.lite_base_url:
26
+ self._lite_client = httpx.AsyncClient(
27
+ timeout=httpx.Timeout(120.0, connect=10.0),
28
+ headers=headers,
29
+ )
30
+ self.primary_healthy = True
31
+ self.lite_healthy = False
32
+
33
+ async def close(self):
34
+ await self._client.aclose()
35
+ if self._lite_client:
36
+ await self._lite_client.aclose()
37
+
38
+ def resolve_model(self, name: str) -> str:
39
+ return self.model_map.get(name, name)
40
+
41
+ async def check_primary_health(self) -> bool:
42
+ try:
43
+ resp = await self._client.get(f"{self.base_url}/health")
44
+ healthy = resp.status_code == 200
45
+ if self.primary_healthy != healthy:
46
+ logger.info("Primary vLLM health changed: %s → %s", self.primary_healthy, healthy)
47
+ self.primary_healthy = healthy
48
+ return healthy
49
+ except httpx.HTTPError:
50
+ if self.primary_healthy:
51
+ logger.warning("Primary vLLM became unhealthy")
52
+ self.primary_healthy = False
53
+ return False
54
+
55
+ async def check_lite_health(self) -> bool:
56
+ if not self._lite_client or not self.lite_base_url:
57
+ self.lite_healthy = False
58
+ return False
59
+ try:
60
+ resp = await self._lite_client.get(f"{self.lite_base_url}/health")
61
+ healthy = resp.status_code == 200
62
+ if self.lite_healthy != healthy:
63
+ logger.info("Lite vLLM health changed: %s → %s", self.lite_healthy, healthy)
64
+ self.lite_healthy = healthy
65
+ return healthy
66
+ except httpx.HTTPError:
67
+ if self.lite_healthy:
68
+ logger.warning("Lite vLLM became unhealthy")
69
+ self.lite_healthy = False
70
+ return False
71
+
72
+ def _select_backend(self, model: str) -> tuple[httpx.AsyncClient, str, str, bool]:
73
+ """Select the best available backend for the given model.
74
+ Returns (client, base_url, model_id, is_fallback).
75
+ """
76
+ # User explicitly selected nyx-lite
77
+ if model == "nyx-lite" and self._lite_client and self.lite_base_url:
78
+ return self._lite_client, self.lite_base_url, "nyx-lite", False
79
+ # Auto-fallback: nyx requested but primary is down
80
+ if model == "nyx" and not self.primary_healthy and self._lite_client and self.lite_base_url:
81
+ return self._lite_client, self.lite_base_url, "nyx-lite", True
82
+ return self._client, self.base_url, self.resolve_model(model), False
83
+
84
+ async def stream_chat(
85
+ self,
86
+ messages: list[dict],
87
+ model: str,
88
+ temperature: float = 0.7,
89
+ max_tokens: int = 2048,
90
+ tools: list[dict] | None = None,
91
+ ) -> AsyncGenerator[str | dict, None]:
92
+ client, base_url, model_id, is_fallback = self._select_backend(model)
93
+
94
+ if is_fallback:
95
+ yield {"type": "fallback", "model": "nyx-lite"}
96
+
97
+ payload = {
98
+ "model": model_id,
99
+ "messages": messages,
100
+ "temperature": temperature,
101
+ "max_tokens": max_tokens,
102
+ "stream": True,
103
+ "stream_options": {"include_usage": True},
104
+ }
105
+ if tools:
106
+ payload["tools"] = tools
107
+ payload["tool_choice"] = "auto"
108
+
109
+ logger.info("LLM request: model=%s (resolved=%s fallback=%s) msgs=%d temp=%.1f tools=%s",
110
+ model, model_id, is_fallback, len(messages), temperature, bool(tools))
111
+
112
+ async with client.stream(
113
+ "POST",
114
+ f"{base_url}/v1/chat/completions",
115
+ json=payload,
116
+ ) as response:
117
+ response.raise_for_status()
118
+ usage_data = None
119
+ tool_calls_acc: dict[int, dict] = {}
120
+ async for line in response.aiter_lines():
121
+ if not line.startswith("data: "):
122
+ continue
123
+ data = line[6:]
124
+ if data.strip() == "[DONE]":
125
+ break
126
+ try:
127
+ chunk = json.loads(data)
128
+ # Capture usage from final chunk
129
+ if "usage" in chunk and chunk["usage"]:
130
+ usage_data = chunk["usage"]
131
+ if not chunk.get("choices"):
132
+ continue
133
+ choice = chunk["choices"][0]
134
+ delta = choice.get("delta", {})
135
+
136
+ # Accumulate tool calls across streaming deltas
137
+ if "tool_calls" in delta:
138
+ for tc in delta["tool_calls"]:
139
+ idx = tc["index"]
140
+ if idx not in tool_calls_acc:
141
+ tool_calls_acc[idx] = {"id": "", "name": "", "arguments": ""}
142
+ if tc.get("id"):
143
+ tool_calls_acc[idx]["id"] = tc["id"]
144
+ func = tc.get("function", {})
145
+ if func.get("name"):
146
+ tool_calls_acc[idx]["name"] = func["name"]
147
+ if func.get("arguments"):
148
+ tool_calls_acc[idx]["arguments"] += func["arguments"]
149
+
150
+ content = delta.get("content")
151
+ if content:
152
+ yield content
153
+
154
+ # Check finish reason for tool calls
155
+ finish_reason = choice.get("finish_reason")
156
+ if finish_reason == "tool_calls" and tool_calls_acc:
157
+ yield {"type": "tool_calls", "calls": list(tool_calls_acc.values())}
158
+ tool_calls_acc = {}
159
+ except (json.JSONDecodeError, KeyError, IndexError):
160
+ continue
161
+ # Yield usage as a special sentinel dict (not a string)
162
+ if usage_data:
163
+ yield usage_data # type: ignore[misc]
164
+
165
+ async def check_health(self) -> bool:
166
+ try:
167
+ resp = await self._client.get(f"{self.base_url}/health")
168
+ ok = resp.status_code == 200
169
+ if not ok:
170
+ logger.warning("vLLM health check failed: status %d", resp.status_code)
171
+ return ok
172
+ except httpx.HTTPError as e:
173
+ logger.warning("vLLM health check error: %s", e)
174
+ return False
175
+
176
+ async def list_available_models(self) -> list[dict]:
177
+ primary_ok = await self.check_health()
178
+ lite_ok = await self.check_lite_health()
179
+ result = []
180
+ for info in settings.MODEL_INFO:
181
+ if info["id"] == "nyx-lite":
182
+ if self._lite_client:
183
+ result.append({**info, "available": lite_ok})
184
+ else:
185
+ result.append({**info, "available": primary_ok})
186
+ return result
@@ -0,0 +1,109 @@
1
+ import logging
2
+
3
+ from sqlalchemy import select, func
4
+ from sqlalchemy.ext.asyncio import AsyncSession
5
+ from sqlalchemy.orm import selectinload
6
+
7
+ from backend.models.conversation import Conversation
8
+ from backend.models.project import Project
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
class ProjectService:
    """Project CRUD and conversation assignment, scoped to a single user.

    Every query filters on ``user_id`` so a caller can only see or modify
    projects and conversations owned by that user.
    """

    # Identity/ownership columns. update() must never overwrite these: doing so
    # would let a caller move a project out of the user scope enforced elsewhere.
    _PROTECTED_FIELDS = frozenset({"id", "user_id"})

    def __init__(self, db: AsyncSession, user_id: str):
        self.db = db
        self.user_id = user_id

    async def create(self, name: str) -> Project:
        """Insert a new project owned by the current user and return it."""
        project = Project(name=name, user_id=self.user_id)
        self.db.add(project)
        await self.db.commit()
        await self.db.refresh(project)
        return project

    async def get(self, project_id: str) -> Project | None:
        """Return the user's project with its conversations loaded, or None."""
        stmt = (
            select(Project)
            .options(selectinload(Project.conversations))
            .where(Project.id == project_id, Project.user_id == self.user_id)
        )
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def list_all(self) -> list[dict]:
        """Return the user's projects, most recently updated first.

        Each entry is a plain dict that includes ``conversation_count``.
        """
        stmt = (
            select(
                Project.id,
                Project.name,
                Project.instructions,
                Project.created_at,
                Project.updated_at,
                func.count(Conversation.id).label("conversation_count"),
            )
            # Outer join so projects with no conversations are still listed (count 0).
            .outerjoin(Conversation, Conversation.project_id == Project.id)
            .where(Project.user_id == self.user_id)
            .group_by(Project.id)
            .order_by(Project.updated_at.desc())
        )
        result = await self.db.execute(stmt)
        rows = result.all()
        return [
            {
                "id": r.id,
                "name": r.name,
                "instructions": r.instructions,
                "created_at": r.created_at,
                "updated_at": r.updated_at,
                "conversation_count": r.conversation_count,
            }
            for r in rows
        ]

    async def update(self, project_id: str, **kwargs) -> Project | None:
        """Apply ``kwargs`` as attribute updates to the user's project.

        Keys that are not attributes of the project are ignored, as are the
        protected identity/ownership fields. Returns the refreshed project, or
        None when the project does not exist for this user.
        """
        project = await self.get(project_id)
        if not project:
            return None
        for key, value in kwargs.items():
            if key in self._PROTECTED_FIELDS:
                logger.warning("Ignoring attempt to update protected project field %r", key)
                continue
            if hasattr(project, key):
                setattr(project, key, value)
        await self.db.commit()
        await self.db.refresh(project)
        return project

    async def delete(self, project_id: str) -> bool:
        """Delete the user's project; returns False when it does not exist."""
        project = await self.get(project_id)
        if not project:
            return False
        await self.db.delete(project)
        await self.db.commit()
        return True

    async def _get_owned_conversation(self, conversation_id: str) -> Conversation | None:
        """Return the conversation only if it belongs to the current user."""
        stmt = (
            select(Conversation)
            .where(Conversation.id == conversation_id, Conversation.user_id == self.user_id)
        )
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def add_conversation(self, project_id: str, conversation_id: str) -> bool:
        """Attach one of the user's conversations to one of the user's projects.

        Returns False when either the project or the conversation is not found.
        """
        project = await self.get(project_id)
        if not project:
            return False
        conv = await self._get_owned_conversation(conversation_id)
        if not conv:
            return False
        conv.project_id = project_id
        await self.db.commit()
        return True

    async def remove_conversation(self, conversation_id: str) -> bool:
        """Detach the user's conversation from its project; False when not found."""
        conv = await self._get_owned_conversation(conversation_id)
        if not conv:
            return False
        conv.project_id = None
        await self.db.commit()
        return True
@@ -0,0 +1,167 @@
1
+ import logging
2
+ import time
3
+ from collections import defaultdict
4
+ from datetime import datetime, timedelta, UTC
5
+
6
+ import httpx
7
+ from sqlalchemy import select
8
+ from sqlalchemy.ext.asyncio import AsyncSession
9
+
10
+ from backend.config import settings
11
+ from backend.models.incident import Incident
12
+ from backend.models.uptime_record import UptimeRecord
13
+ from backend.services.llm_service import LLMService
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class StatusService:
    """Public-facing service that aggregates health data for the status page."""

    def __init__(self, db: AsyncSession, llm_service: LLMService):
        self.db = db
        self.llm = llm_service

    async def get_status_page(self) -> dict:
        """Assemble the complete status-page payload.

        Combines live component statuses, the 10 most recent incidents and
        90 days of uptime history with the derived overall status and uptime.
        """
        component_list = await self._get_component_statuses()
        incident_list = await self._get_recent_incidents(limit=10)
        history = await self._get_uptime_history(days=90)
        return {
            "status": self._determine_overall_status(component_list),
            "overall_uptime": self._calculate_overall_uptime(history),
            "checked_at": datetime.now(UTC).isoformat(),
            "components": component_list,
            "recent_incidents": incident_list,
            "uptime_history": history,
        }

    async def _get_component_statuses(self) -> list[dict]:
        """Return one ``{name, status, latency_ms}`` dict per component.

        Live checks are applied first; unresolved incidents may then escalate
        a component's status.
        """

        def entry(name: str, status: str, latency_ms: int | None = None) -> dict:
            return {"name": name, "status": status, "latency_ms": latency_ms}

        # Reaching this code at all means the API process is serving requests.
        components = [entry("API", "operational")]

        # Models: primary healthy -> operational, only lite healthy -> degraded.
        primary_up = getattr(self.llm, "primary_healthy", False)
        lite_up = getattr(self.llm, "lite_healthy", False)
        if primary_up:
            components.append(entry("Models", "operational"))
        else:
            components.append(entry("Models", "degraded" if lite_up else "down"))

        # Image generation: probed over HTTP when the feature is enabled and configured.
        # NOTE(review): a disabled/unconfigured feature is reported as "down", which
        # makes the overall status "outage" — confirm this is intended.
        if not (settings.FEATURE_IMAGE_GEN_ENABLED and settings.FLUX_BASE_URL):
            components.append(entry("Image Generation", "down"))
        else:
            try:
                async with httpx.AsyncClient(timeout=5.0) as client:
                    started = time.monotonic()
                    resp = await client.get(f"{settings.FLUX_BASE_URL}/health")
                    elapsed_ms = int((time.monotonic() - started) * 1000)
                    image_state = "operational" if resp.status_code == 200 else "degraded"
                    components.append(entry("Image Generation", image_state, elapsed_ms))
            except Exception:
                components.append(entry("Image Generation", "down"))

        # The incident query below would have failed if the database were unreachable.
        components.append(entry("Database", "operational"))

        # Unresolved incidents can escalate the status of the component they name.
        incident_targets = {
            "api": "API",
            "chat": "API",
            "models": "Models",
            "images": "Image Generation",
            "database": "Database",
        }
        for incident in await self._get_active_incidents():
            label = incident_targets.get(incident.component)
            if not label:
                continue
            for comp in (c for c in components if c["name"] == label):
                if incident.severity == "critical":
                    comp["status"] = "down"
                elif incident.severity == "warning" and comp["status"] == "operational":
                    comp["status"] = "degraded"

        return components

    async def _get_active_incidents(self) -> list[Incident]:
        """Return every incident whose status is not ``"resolved"``, newest first."""
        query = (
            select(Incident)
            .where(Incident.status != "resolved")
            .order_by(Incident.started_at.desc())
        )
        outcome = await self.db.execute(query)
        return list(outcome.scalars().all())

    async def _get_recent_incidents(self, limit: int = 10) -> list[dict]:
        """Return up to ``limit`` incidents as plain dicts, newest first."""
        query = select(Incident).order_by(Incident.started_at.desc()).limit(limit)
        outcome = await self.db.execute(query)
        serialized = []
        for incident in outcome.scalars().all():
            resolved = incident.resolved_at.isoformat() if incident.resolved_at else None
            serialized.append(
                {
                    "id": incident.id,
                    "title": incident.title,
                    "description": incident.description,
                    "severity": incident.severity,
                    "component": incident.component,
                    "status": incident.status,
                    "started_at": incident.started_at.isoformat(),
                    "resolved_at": resolved,
                }
            )
        return serialized

    async def _get_uptime_history(self, days: int = 90) -> list[dict]:
        """Return per-day uptime for the last ``days`` days, newest day first.

        Each day's value is the mean of all records stored for that date,
        rounded to two decimals.
        """
        cutoff = (datetime.now(UTC) - timedelta(days=days)).date()
        query = (
            select(UptimeRecord)
            .where(UptimeRecord.date >= cutoff)
            .order_by(UptimeRecord.date.desc())
        )
        outcome = await self.db.execute(query)

        # Bucket the individual records by their date string.
        per_day: dict[str, list[float]] = defaultdict(list)
        for record in outcome.scalars().all():
            per_day[str(record.date)].append(record.uptime_percent)

        return [
            {"date": day, "uptime_percent": round(sum(per_day[day]) / len(per_day[day]), 2)}
            for day in sorted(per_day, reverse=True)
        ]

    def _determine_overall_status(self, components: list[dict]) -> str:
        """Collapse component statuses into ``outage``, ``degraded`` or ``operational``."""
        observed = {component["status"] for component in components}
        if "down" in observed:
            return "outage"
        if "degraded" in observed:
            return "degraded"
        return "operational"

    def _calculate_overall_uptime(self, history: list[dict]) -> float:
        """Return the mean ``uptime_percent`` over ``history``; 100.0 when it is empty."""
        if not history:
            return 100.0
        total = sum(day["uptime_percent"] for day in history)
        return round(total / len(history), 2)
@@ -0,0 +1,78 @@
1
+ import logging
2
+
3
+ import stripe
4
+
5
+ from backend.config import settings
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
class StripeService:
    """Thin wrapper around the Stripe SDK for checkout, billing portal and webhooks."""

    def __init__(self) -> None:
        # The Stripe SDK is configured through a module-level API key.
        stripe.api_key = settings.STRIPE_SECRET_KEY

    def create_checkout_session(self, user_id: str, email: str, stripe_customer_id: str | None) -> tuple[str, str]:
        """Create a Stripe Checkout session. Returns (checkout_url, customer_id).

        A Stripe customer is created first when ``stripe_customer_id`` is not
        given. Raises ``RuntimeError`` if Stripe returns a session without a URL.
        """
        if stripe_customer_id:
            customer_id = stripe_customer_id
        else:
            customer = stripe.Customer.create(email=email, metadata={"kairo_user_id": user_id})
            customer_id = customer.id

        session = stripe.checkout.Session.create(
            customer=customer_id,
            mode="subscription",
            line_items=[{"price": settings.STRIPE_PRO_PRICE_ID, "quantity": 1}],
            success_url=f"{settings.APP_BASE_URL}/account?checkout=success",
            cancel_url=f"{settings.APP_BASE_URL}/pricing?checkout=cancelled",
            # Echoed back on the checkout.session.completed webhook event.
            metadata={"kairo_user_id": user_id},
        )
        if not session.url:
            raise RuntimeError("Stripe did not return a checkout URL")
        return session.url, customer_id

    def create_billing_portal_session(self, stripe_customer_id: str) -> str:
        """Create a Stripe Billing Portal session. Returns the portal URL."""
        session = stripe.billing_portal.Session.create(
            customer=stripe_customer_id,
            return_url=f"{settings.APP_BASE_URL}/account",
        )
        return session.url

    def handle_webhook_event(self, payload: bytes, sig: str) -> dict:
        """Verify and parse a Stripe webhook event. Returns action dict.

        The returned dict always has an ``"action"`` key: ``"upgrade"``,
        ``"sync_active"``, ``"downgrade"``, ``"payment_failed"`` or ``"noop"``.
        Exceptions raised by ``stripe.Webhook.construct_event`` propagate to
        the caller.
        """
        event = stripe.Webhook.construct_event(payload, sig, settings.STRIPE_WEBHOOK_SECRET)
        event_type = event["type"]
        data = event["data"]["object"]

        if event_type == "checkout.session.completed":
            customer_id = data["customer"]
            subscription_id = data["subscription"]
            # "metadata" can be present but null; `.get("metadata", {})` would then
            # return None and the chained `.get` would raise AttributeError.
            metadata = data.get("metadata") or {}
            user_id = metadata.get("kairo_user_id")
            logger.info("Checkout completed: customer=%s user=%s", customer_id, user_id)
            return {
                "action": "upgrade",
                "customer_id": customer_id,
                "subscription_id": subscription_id,
                "user_id": user_id,
            }

        if event_type == "customer.subscription.updated":
            customer_id = data["customer"]
            status = data["status"]
            logger.info("Subscription updated: customer=%s status=%s", customer_id, status)
            # Only a transition to "active" is acted on; every other status is a no-op here.
            if status == "active":
                return {"action": "sync_active", "customer_id": customer_id, "subscription_id": data["id"]}
            return {"action": "noop"}

        if event_type == "customer.subscription.deleted":
            customer_id = data["customer"]
            logger.info("Subscription deleted: customer=%s", customer_id)
            return {"action": "downgrade", "customer_id": customer_id}

        if event_type == "invoice.payment_failed":
            customer_id = data["customer"]
            logger.warning("Payment failed: customer=%s", customer_id)
            return {"action": "payment_failed", "customer_id": customer_id}

        return {"action": "noop"}
@@ -0,0 +1,150 @@
1
+ import logging
2
+ from datetime import datetime, timezone, timedelta
3
+
4
+ from sqlalchemy import select, func
5
+ from sqlalchemy.ext.asyncio import AsyncSession
6
+
7
+ from backend.config import settings
8
+ from backend.models.usage import UsageRecord
9
+ from backend.models.user import PlanType, User
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
# Daily and monthly token budgets for each plan, keyed by the plan's string value.
PLAN_LIMITS: dict[str, dict[str, int]] = {
    PlanType.FREE.value: dict(
        daily=settings.DEFAULT_DAILY_TOKEN_LIMIT,
        monthly=settings.DEFAULT_MONTHLY_TOKEN_LIMIT,
    ),
    PlanType.PRO.value: dict(
        daily=settings.PRO_DAILY_TOKEN_LIMIT,
        monthly=settings.PRO_MONTHLY_TOKEN_LIMIT,
    ),
    PlanType.MAX.value: dict(
        daily=settings.MAX_DAILY_TOKEN_LIMIT,
        monthly=settings.MAX_MONTHLY_TOKEN_LIMIT,
    ),
}
27
+
28
+
29
def _get_plan_limits(plan: str) -> tuple[int, int]:
    """Return ``(daily, monthly)`` token limits for ``plan``.

    A plan with no entry in ``PLAN_LIMITS`` gets the free-plan limits.
    """
    try:
        budget = PLAN_LIMITS[plan]
    except KeyError:
        budget = PLAN_LIMITS[PlanType.FREE.value]
    return budget["daily"], budget["monthly"]
32
+
33
+
34
class UsageService:
    """Stores per-request token usage and evaluates it against plan limits."""

    def __init__(self, db: AsyncSession):
        self.db = db

    async def record_usage(
        self,
        user_id: str,
        conversation_id: str | None,
        model: str,
        prompt_tokens: int,
        completion_tokens: int,
    ) -> UsageRecord:
        """Persist one usage row, log it, and return the refreshed record."""
        entry = UsageRecord(
            user_id=user_id,
            conversation_id=conversation_id,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        )
        self.db.add(entry)
        await self.db.commit()
        await self.db.refresh(entry)
        logger.info(
            "Usage recorded: user=%s prompt=%d completion=%d",
            user_id, prompt_tokens, completion_tokens,
        )
        return entry

    async def _tokens_used_since(self, user_id: str, since: datetime) -> int:
        """Sum prompt + completion tokens for ``user_id`` from ``since`` onward."""
        token_sum = func.sum(UsageRecord.prompt_tokens + UsageRecord.completion_tokens)
        query = select(func.coalesce(token_sum, 0)).where(
            UsageRecord.user_id == user_id,
            UsageRecord.created_at >= since,
        )
        outcome = await self.db.execute(query)
        return outcome.scalar() or 0

    async def get_daily_usage(self, user_id: str) -> int:
        """Return the tokens used since 00:00 UTC today."""
        midnight = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
        return await self._tokens_used_since(user_id, midnight)

    async def get_monthly_usage(self, user_id: str) -> int:
        """Return the tokens used since 00:00 UTC on the first of this month."""
        month_start = datetime.now(timezone.utc).replace(
            day=1, hour=0, minute=0, second=0, microsecond=0
        )
        return await self._tokens_used_since(user_id, month_start)

    async def check_limits(self, user_id: str) -> tuple[bool, str]:
        """Returns (allowed, reason). allowed=True if under limits."""
        account = await self.db.get(User, user_id)
        if not account:
            return False, "User not found"

        daily_cap, monthly_cap = _get_plan_limits(account.plan)

        # The daily check comes first; the monthly query only runs if it passes.
        if await self.get_daily_usage(user_id) >= daily_cap:
            return False, "Daily token limit reached. Try again tomorrow."

        if await self.get_monthly_usage(user_id) >= monthly_cap:
            return False, "Monthly token limit reached."

        return True, ""

    async def get_usage_summary(self, user_id: str) -> dict:
        """Return daily and monthly usage as percentages of the plan limits.

        Percentages are capped at 100; an unknown user is measured against
        the free plan.
        """
        account = await self.db.get(User, user_id)
        used_today = await self.get_daily_usage(user_id)
        used_this_month = await self.get_monthly_usage(user_id)

        plan = account.plan if account else PlanType.FREE.value
        daily_cap, monthly_cap = _get_plan_limits(plan)

        # A zero limit would divide by zero; report 0 percent in that case.
        if daily_cap:
            daily_share = round((used_today / daily_cap) * 100, 1)
        else:
            daily_share = 0
        if monthly_cap:
            monthly_share = round((used_this_month / monthly_cap) * 100, 1)
        else:
            monthly_share = 0

        return {
            "daily_percent": min(100, daily_share),
            "monthly_percent": min(100, monthly_share),
        }

    async def get_usage_history(self, user_id: str, days: int = 30) -> list[dict]:
        """Return per-day usage over the last ``days`` days, oldest day first.

        Each entry holds the date and that day's tokens as a percentage of
        the plan's daily limit, capped at 100.
        """
        account = await self.db.get(User, user_id)
        plan = account.plan if account else PlanType.FREE.value
        daily_cap, _ = _get_plan_limits(plan)

        window_start = datetime.now(timezone.utc) - timedelta(days=days)
        query = (
            select(
                func.date(UsageRecord.created_at).label("date"),
                func.sum(UsageRecord.prompt_tokens + UsageRecord.completion_tokens).label("total_tokens"),
            )
            .where(UsageRecord.user_id == user_id, UsageRecord.created_at >= window_start)
            .group_by(func.date(UsageRecord.created_at))
            .order_by(func.date(UsageRecord.created_at))
        )
        outcome = await self.db.execute(query)

        history = []
        for row in outcome.all():
            if daily_cap:
                share = round(min(100, ((row.total_tokens or 0) / daily_cap) * 100), 1)
            else:
                share = 0
            history.append({"date": str(row.date), "usage_percent": share})
        return history