kairo_code-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. image-service/main.py +178 -0
  2. infra/chat/app/main.py +84 -0
  3. kairo/backend/__init__.py +0 -0
  4. kairo/backend/api/__init__.py +0 -0
  5. kairo/backend/api/admin/__init__.py +23 -0
  6. kairo/backend/api/admin/audit.py +54 -0
  7. kairo/backend/api/admin/content.py +142 -0
  8. kairo/backend/api/admin/incidents.py +148 -0
  9. kairo/backend/api/admin/stats.py +125 -0
  10. kairo/backend/api/admin/system.py +87 -0
  11. kairo/backend/api/admin/users.py +279 -0
  12. kairo/backend/api/agents.py +94 -0
  13. kairo/backend/api/api_keys.py +85 -0
  14. kairo/backend/api/auth.py +116 -0
  15. kairo/backend/api/billing.py +41 -0
  16. kairo/backend/api/chat.py +72 -0
  17. kairo/backend/api/conversations.py +125 -0
  18. kairo/backend/api/device_auth.py +100 -0
  19. kairo/backend/api/files.py +83 -0
  20. kairo/backend/api/health.py +36 -0
  21. kairo/backend/api/images.py +80 -0
  22. kairo/backend/api/openai_compat.py +225 -0
  23. kairo/backend/api/projects.py +102 -0
  24. kairo/backend/api/usage.py +32 -0
  25. kairo/backend/api/webhooks.py +79 -0
  26. kairo/backend/app.py +297 -0
  27. kairo/backend/config.py +179 -0
  28. kairo/backend/core/__init__.py +0 -0
  29. kairo/backend/core/admin_auth.py +24 -0
  30. kairo/backend/core/api_key_auth.py +55 -0
  31. kairo/backend/core/database.py +28 -0
  32. kairo/backend/core/dependencies.py +70 -0
  33. kairo/backend/core/logging.py +23 -0
  34. kairo/backend/core/rate_limit.py +73 -0
  35. kairo/backend/core/security.py +29 -0
  36. kairo/backend/models/__init__.py +19 -0
  37. kairo/backend/models/agent.py +30 -0
  38. kairo/backend/models/api_key.py +25 -0
  39. kairo/backend/models/api_usage.py +29 -0
  40. kairo/backend/models/audit_log.py +26 -0
  41. kairo/backend/models/conversation.py +48 -0
  42. kairo/backend/models/device_code.py +30 -0
  43. kairo/backend/models/feature_flag.py +21 -0
  44. kairo/backend/models/image_generation.py +24 -0
  45. kairo/backend/models/incident.py +28 -0
  46. kairo/backend/models/project.py +28 -0
  47. kairo/backend/models/uptime_record.py +24 -0
  48. kairo/backend/models/usage.py +24 -0
  49. kairo/backend/models/user.py +49 -0
  50. kairo/backend/schemas/__init__.py +0 -0
  51. kairo/backend/schemas/admin/__init__.py +0 -0
  52. kairo/backend/schemas/admin/audit.py +28 -0
  53. kairo/backend/schemas/admin/content.py +53 -0
  54. kairo/backend/schemas/admin/stats.py +77 -0
  55. kairo/backend/schemas/admin/system.py +44 -0
  56. kairo/backend/schemas/admin/users.py +48 -0
  57. kairo/backend/schemas/agent.py +42 -0
  58. kairo/backend/schemas/api_key.py +30 -0
  59. kairo/backend/schemas/auth.py +57 -0
  60. kairo/backend/schemas/chat.py +26 -0
  61. kairo/backend/schemas/conversation.py +39 -0
  62. kairo/backend/schemas/device_auth.py +40 -0
  63. kairo/backend/schemas/image.py +15 -0
  64. kairo/backend/schemas/openai_compat.py +76 -0
  65. kairo/backend/schemas/project.py +21 -0
  66. kairo/backend/schemas/status.py +81 -0
  67. kairo/backend/schemas/usage.py +15 -0
  68. kairo/backend/services/__init__.py +0 -0
  69. kairo/backend/services/admin/__init__.py +0 -0
  70. kairo/backend/services/admin/audit_service.py +78 -0
  71. kairo/backend/services/admin/content_service.py +119 -0
  72. kairo/backend/services/admin/incident_service.py +94 -0
  73. kairo/backend/services/admin/stats_service.py +281 -0
  74. kairo/backend/services/admin/system_service.py +126 -0
  75. kairo/backend/services/admin/user_service.py +157 -0
  76. kairo/backend/services/agent_service.py +107 -0
  77. kairo/backend/services/api_key_service.py +66 -0
  78. kairo/backend/services/api_usage_service.py +126 -0
  79. kairo/backend/services/auth_service.py +101 -0
  80. kairo/backend/services/chat_service.py +501 -0
  81. kairo/backend/services/conversation_service.py +264 -0
  82. kairo/backend/services/device_auth_service.py +193 -0
  83. kairo/backend/services/email_service.py +55 -0
  84. kairo/backend/services/image_service.py +181 -0
  85. kairo/backend/services/llm_service.py +186 -0
  86. kairo/backend/services/project_service.py +109 -0
  87. kairo/backend/services/status_service.py +167 -0
  88. kairo/backend/services/stripe_service.py +78 -0
  89. kairo/backend/services/usage_service.py +150 -0
  90. kairo/backend/services/web_search_service.py +96 -0
  91. kairo/migrations/env.py +60 -0
  92. kairo/migrations/versions/001_initial.py +55 -0
  93. kairo/migrations/versions/002_usage_tracking_and_indexes.py +66 -0
  94. kairo/migrations/versions/003_username_to_email.py +21 -0
  95. kairo/migrations/versions/004_add_plans_and_verification.py +67 -0
  96. kairo/migrations/versions/005_add_projects.py +52 -0
  97. kairo/migrations/versions/006_add_image_generation.py +63 -0
  98. kairo/migrations/versions/007_add_admin_portal.py +107 -0
  99. kairo/migrations/versions/008_add_device_code_auth.py +76 -0
  100. kairo/migrations/versions/009_add_status_page.py +65 -0
  101. kairo/tools/extract_claude_data.py +465 -0
  102. kairo/tools/filter_claude_data.py +303 -0
  103. kairo/tools/generate_curated_data.py +157 -0
  104. kairo/tools/mix_training_data.py +295 -0
  105. kairo_code/__init__.py +3 -0
  106. kairo_code/agents/__init__.py +25 -0
  107. kairo_code/agents/architect.py +98 -0
  108. kairo_code/agents/audit.py +100 -0
  109. kairo_code/agents/base.py +463 -0
  110. kairo_code/agents/coder.py +155 -0
  111. kairo_code/agents/database.py +77 -0
  112. kairo_code/agents/docs.py +88 -0
  113. kairo_code/agents/explorer.py +62 -0
  114. kairo_code/agents/guardian.py +80 -0
  115. kairo_code/agents/planner.py +66 -0
  116. kairo_code/agents/reviewer.py +91 -0
  117. kairo_code/agents/security.py +94 -0
  118. kairo_code/agents/terraform.py +88 -0
  119. kairo_code/agents/testing.py +97 -0
  120. kairo_code/agents/uiux.py +88 -0
  121. kairo_code/auth.py +232 -0
  122. kairo_code/config.py +172 -0
  123. kairo_code/conversation.py +173 -0
  124. kairo_code/heartbeat.py +63 -0
  125. kairo_code/llm.py +291 -0
  126. kairo_code/logging_config.py +156 -0
  127. kairo_code/main.py +818 -0
  128. kairo_code/router.py +217 -0
  129. kairo_code/sandbox.py +248 -0
  130. kairo_code/settings.py +183 -0
  131. kairo_code/tools/__init__.py +51 -0
  132. kairo_code/tools/analysis.py +509 -0
  133. kairo_code/tools/base.py +417 -0
  134. kairo_code/tools/code.py +58 -0
  135. kairo_code/tools/definitions.py +617 -0
  136. kairo_code/tools/files.py +315 -0
  137. kairo_code/tools/review.py +390 -0
  138. kairo_code/tools/search.py +185 -0
  139. kairo_code/ui.py +418 -0
  140. kairo_code-0.1.0.dist-info/METADATA +13 -0
  141. kairo_code-0.1.0.dist-info/RECORD +144 -0
  142. kairo_code-0.1.0.dist-info/WHEEL +5 -0
  143. kairo_code-0.1.0.dist-info/entry_points.txt +2 -0
  144. kairo_code-0.1.0.dist-info/top_level.txt +4 -0
@@ -0,0 +1,65 @@
1
+ """Add status page tables (incidents, uptime_records)
2
+
3
+ Revision ID: 009
4
+ Revises: 008
5
+ Create Date: 2026-01-31 00:00:00.000000
6
+ """
7
+
8
+ from alembic import op
9
+ import sqlalchemy as sa
10
+
11
# Alembic revision identifiers: this migration is "009" and applies on
# top of "008" (device-code auth).
revision = "009"
down_revision = "008"
branch_labels = None
depends_on = None
15
+
16
+
17
def upgrade() -> None:
    """Create the status-page tables (incidents, uptime_records) and indexes."""
    op.create_table(
        "incidents",
        sa.Column("id", sa.String(), primary_key=True),
        sa.Column("title", sa.String(), nullable=False),
        sa.Column("description", sa.Text(), nullable=True),
        sa.Column("severity", sa.String(), nullable=False, server_default="warning"),
        sa.Column("component", sa.String(), nullable=False),
        sa.Column("status", sa.String(), nullable=False, server_default="investigating"),
        sa.Column("started_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
        sa.Column("resolved_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
    )
    # Non-unique lookup indexes for the incident list/filter queries.
    for col in ("component", "status", "started_at"):
        op.create_index(f"ix_incidents_{col}", "incidents", [col])

    op.create_table(
        "uptime_records",
        sa.Column("id", sa.String(), primary_key=True),
        sa.Column("component", sa.String(), nullable=False),
        sa.Column("date", sa.Date(), nullable=False),
        sa.Column("uptime_percent", sa.Float(), nullable=False, server_default="100.0"),
        sa.Column("incidents_count", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("avg_latency_ms", sa.Integer(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
    )
    op.create_index("ix_uptime_records_component", "uptime_records", ["component"])
    op.create_index("ix_uptime_records_date", "uptime_records", ["date"])
    # Uniqueness: at most one uptime row per component per day.
    op.create_index(
        "ix_uptime_records_component_date",
        "uptime_records",
        ["component", "date"],
        unique=True,
    )
55
+
56
+
57
def downgrade() -> None:
    """Drop the status-page tables; indexes first, in reverse of upgrade()."""
    for suffix in ("component_date", "date", "component"):
        op.drop_index(f"ix_uptime_records_{suffix}", table_name="uptime_records")
    op.drop_table("uptime_records")
    for suffix in ("started_at", "status", "component"):
        op.drop_index(f"ix_incidents_{suffix}", table_name="incidents")
    op.drop_table("incidents")
@@ -0,0 +1,465 @@
1
+ #!/usr/bin/env python3
2
+ """Extract Claude Code conversations and convert to Llama 3.1 fine-tuning data for Kairo.
3
+
4
+ Usage:
5
+ python kairo/tools/extract_claude_data.py \
6
+ --input ~/.claude/projects \
7
+ --output kairo/data/claude_training_data.jsonl
8
+ """
9
+
10
+ import argparse
11
+ import glob
12
+ import json
13
+ import os
14
+ import re
15
+ import sys
16
+ from collections import defaultdict
17
+ from pathlib import Path
18
+
19
+
20
+ SYSTEM_PROMPT = (
21
+ "You are Kairo, a helpful AI assistant. "
22
+ "IMPORTANT RULES:\n"
23
+ "1. Only respond to what the user actually asked. Never assume or invent what the user wants.\n"
24
+ "2. You are the ASSISTANT. Never generate text as if you are the user. Never put words in the user's mouth.\n"
25
+ "3. If the user asks a general question like 'what can you do', explain your capabilities briefly.\n"
26
+ "4. Answer directly and concisely. Provide concrete answers, code, or explanations.\n"
27
+ "5. Do not hedge or refuse without strong reason.\n"
28
+ "6. You are Kairo — not GPT, Claude, Llama, or any other AI. Never reveal your architecture."
29
+ )
30
+
31
+ # Regex to strip <system-reminder>...</system-reminder> blocks (including multiline)
32
+ SYSTEM_REMINDER_RE = re.compile(r"<system-reminder>.*?</system-reminder>", re.DOTALL)
33
+
34
+ # Replace Claude references with Kairo (order matters: longer patterns first)
35
+ CLAUDE_REPLACEMENTS = [
36
+ (re.compile(r"@anthropic-ai/claude-code", re.IGNORECASE), "kairo"),
37
+ (re.compile(r"Claude Code", re.IGNORECASE), "Kairo"),
38
+ (re.compile(r"claude-cod\b"), "kairo"), # truncated in terminal output
39
+ (re.compile(r"\bClaude\b"), "Kairo"),
40
+ (re.compile(r"\bclaude\b"), "kairo"),
41
+ (re.compile(r"\bAnthropic\b"), ""),
42
+ # File paths: ~/.claude/ → ~/.kairo/
43
+ (re.compile(r"\.claude/"), ".kairo/"),
44
+ (re.compile(r"/tmp/claude/"), "/tmp/kairo/"),
45
+ ]
46
+
47
+ # Max estimated tokens per training example (chars / 4)
48
+ MAX_TOKENS = 2048
49
+ MAX_CHARS = MAX_TOKENS * 4
50
+
51
+ # Message types to process
52
+ CONVERSATION_TYPES = {"user", "assistant"}
53
+
54
+ # Types to skip entirely
55
+ SKIP_TYPES = {"file-history-snapshot", "summary", "system", "progress", "queue-operation"}
56
+
57
+
58
def discover_files(base_path: str) -> list[str]:
    """Find all .jsonl conversation files under base_path, sorted for idempotency.

    Files living under a "tool-results" directory are excluded — those hold
    tool output, not conversations.
    """
    candidates = glob.glob(
        os.path.join(base_path, "**", "*.jsonl"), recursive=True
    )
    return sorted(p for p in candidates if "tool-results" not in p)
66
+
67
+
68
def parse_messages(path: str) -> list[dict]:
    """Stream a JSONL file and return the list of conversation messages.

    Only records whose "type" is in CONVERSATION_TYPES (user/assistant) are
    kept; blank lines, malformed JSON, and records whose "message" payload is
    not a dict are silently skipped.

    Args:
        path: Path to one .jsonl conversation file.

    Returns:
        Dicts with keys uuid/parentUuid/role/content/type, in file order.
    """
    messages = []
    # errors="replace" keeps the stream readable even with stray invalid bytes.
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        # Fix: the original enumerated lines but never used the line number.
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                # Partial/corrupt lines are expected in live session logs.
                continue

            msg_type = data.get("type")
            if msg_type not in CONVERSATION_TYPES:
                continue

            message = data.get("message", {})
            if not isinstance(message, dict):
                continue

            messages.append({
                "uuid": data.get("uuid"),
                "parentUuid": data.get("parentUuid"),
                # Fall back to the record type when the payload lacks a role.
                "role": message.get("role", msg_type),
                "content": message.get("content", ""),
                "type": msg_type,
            })
    return messages
100
+
101
+
102
def build_threads(messages: list[dict]) -> list[list[dict]]:
    """Link messages via parentUuid into linear conversation threads.

    - Orphan messages (missing parent) are treated as thread roots.
    - When a node has multiple children, follow the first child only.

    Args:
        messages: Message dicts as produced by parse_messages().

    Returns:
        A list of threads; each thread is a list of message dicts walked from
        root to last descendant along the first-child chain.
    """
    if not messages:
        return []

    # Index by uuid
    by_uuid: dict[str, dict] = {}
    children: dict[str | None, list[str]] = defaultdict(list)

    for msg in messages:
        uid = msg["uuid"]
        if uid is None:
            # Messages without a uuid cannot be linked or serve as roots.
            continue
        by_uuid[uid] = msg
        children[msg["parentUuid"]].append(uid)

    # Find roots: messages whose parentUuid is None or not in by_uuid
    roots = []
    for msg in messages:
        parent = msg["parentUuid"]
        if parent is None or parent not in by_uuid:
            if msg["uuid"] is not None:
                roots.append(msg["uuid"])

    # Walk each root to build a linear thread
    threads = []
    # visited guards against cycles and keeps a node from appearing in two
    # threads when chains overlap.
    visited = set()

    for root_uuid in roots:
        if root_uuid in visited:
            continue
        thread = []
        current = root_uuid
        while current and current not in visited:
            visited.add(current)
            node = by_uuid.get(current)
            if node is None:
                break
            thread.append(node)
            # Follow first child
            kids = children.get(current, [])
            current = kids[0] if kids else None
        if thread:
            threads.append(thread)

    return threads
152
+
153
+
154
def extract_text(message: dict) -> str | None:
    """Extract clean text from a message, or None if no usable text.

    Handles plain-string content and block-list content (only "text" blocks
    are kept; tool_use / tool_result / thinking blocks are dropped), strips
    <system-reminder> blocks, and rewrites Claude references to Kairo.

    Args:
        message: A message dict with at least a "content" key.

    Returns:
        Cleaned text, or None when nothing usable remains.
    """
    content = message.get("content", "")

    # String content (typical for user messages)
    if isinstance(content, str):
        text = content.strip()
    elif isinstance(content, list):
        # Extract only text blocks, skip tool_use, tool_result, thinking
        text_parts = []
        for block in content:
            if isinstance(block, dict) and block.get("type") == "text":
                t = block.get("text", "").strip()
                if t:
                    text_parts.append(t)
            elif isinstance(block, str):
                text_parts.append(block.strip())
        text = "\n\n".join(text_parts)
    else:
        return None

    if not text:
        return None

    # Strip system-reminder blocks
    text = SYSTEM_REMINDER_RE.sub("", text).strip()
    if not text:
        return None

    # Replace Claude references in all messages
    for pattern, replacement in CLAUDE_REPLACEMENTS:
        text = pattern.sub(replacement, text)

    # Collapse doubled spaces left by the ""-replacements, but only interior
    # runs between non-space characters. The original re.sub(r" +", " ") also
    # flattened leading indentation after every newline, which mangled code
    # snippets inside training text.
    text = re.sub(r"(?<=\S) {2,}(?=\S)", " ", text)
    text = text.strip()

    return text if text else None
193
+
194
+
195
def chunk_thread(thread: list[dict], max_chars: int = MAX_CHARS) -> list[list[dict]]:
    """Split a thread into chunks that fit within the token budget.

    Each new chunk is seeded with the previous chunk's last 1-2 turns so the
    model keeps local context across the split.
    """
    if not thread:
        return []

    # Reserve room for the system prompt plus template-tag overhead.
    budget = max_chars - (len(SYSTEM_PROMPT) + 100)

    chunks: list[list[dict]] = []
    pending: list[dict] = []
    pending_chars = 0

    for message in thread:
        cleaned = extract_text(message)
        if cleaned is None:
            continue

        cost = len(cleaned) + 50  # per-turn template-tag overhead
        turn = {"role": message["role"], "text": cleaned}

        if pending and pending_chars + cost > budget:
            chunks.append(pending)
            # Carry the tail of the finished chunk forward as context.
            carry = pending[-2:] if len(pending) >= 2 else pending[-1:]
            pending = list(carry)
            pending_chars = sum(len(t["text"]) + 50 for t in pending)

        pending.append(turn)
        pending_chars += cost

    if pending:
        chunks.append(pending)

    return chunks
233
+
234
+
235
def format_llama31(turns: list[dict]) -> str:
    """Render turns into a single Llama 3.1 chat-template string."""
    rendered = [
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n",
        SYSTEM_PROMPT,
        "<|eot_id|>",
    ]
    for turn in turns:
        rendered.append(f"<|start_header_id|>{turn['role']}<|end_header_id|>\n\n")
        rendered.append(turn["text"])
        rendered.append("<|eot_id|>")
    return "".join(rendered)
249
+
250
+
251
def validate_example(turns: list[dict]) -> tuple[bool, str]:
    """Check whether a set of turns meets the quality thresholds.

    A valid example needs at least one user turn of >= 10 chars and one
    assistant turn of >= 20 chars.

    Returns:
        (is_valid, skip_reason) — skip_reason is "" for valid examples.
    """
    if not turns:
        return False, "empty"

    substantial_user = any(
        t["role"] == "user" and len(t["text"]) >= 10 for t in turns
    )
    substantial_assistant = any(
        t["role"] == "assistant" and len(t["text"]) >= 20 for t in turns
    )

    if not (substantial_user and substantial_assistant):
        return False, "too_short"
    return True, ""
277
+
278
+
279
def has_claude_reference(text: str) -> bool:
    """Check whether text still contains a standalone Claude reference."""
    match = re.search(r"\bClaude\b", text, re.IGNORECASE)
    return match is not None
282
+
283
+
284
def has_tool_use(text: str) -> bool:
    """Check whether text contains leaked tool_use JSON markers."""
    markers = ('"type": "tool_use"', '"type":"tool_use"')
    return any(marker in text for marker in markers)
287
+
288
+
289
def has_system_reminder(text: str) -> bool:
    """Report whether any system-reminder markup survived cleaning."""
    return text.find("<system-reminder>") >= 0
292
+
293
+
294
def process_file(path: str) -> tuple[list[str], dict]:
    """Process a single JSONL file into training examples.

    Args:
        path: Path to one conversation .jsonl file.

    Returns:
        (examples, stats): formatted Llama 3.1 example strings, and a stats
        dict with thread/example/skip counters and total character volume.

    NOTE(review): main() re-implements this loop inline rather than calling
    this function — keep the two in sync, or refactor main() to use it.
    """
    stats = {
        "threads": 0,
        "examples": 0,
        "skipped": 0,
        "skip_reasons": defaultdict(int),
        "total_chars": 0,
    }

    messages = parse_messages(path)
    if not messages:
        return [], stats

    threads = build_threads(messages)
    stats["threads"] = len(threads)

    examples = []
    for thread in threads:
        for chunk in chunk_thread(thread):
            is_valid, skip_reason = validate_example(chunk)
            if not is_valid:
                stats["skipped"] += 1
                stats["skip_reasons"][skip_reason] += 1
                continue

            formatted = format_llama31(chunk)

            # Safety check: never ship leaked system-reminder content.
            if has_system_reminder(formatted):
                stats["skipped"] += 1
                stats["skip_reasons"]["system_reminder"] += 1
                continue

            examples.append(formatted)
            stats["total_chars"] += len(formatted)

    # Fix: set the final count once. The original accumulated
    # `stats["examples"] += len(examples)` inside the loop (flagged
    # "will fix below") and then overwrote it here anyway.
    stats["examples"] = len(examples)
    return examples, stats
338
+
339
+
340
def main():
    """CLI entry point: convert Claude Code session logs into Llama 3.1
    fine-tuning JSONL and write a stats sidecar JSON file.

    Exits with status 1 when no input files are found.
    """
    parser = argparse.ArgumentParser(
        description="Extract Claude Code conversations into Llama 3.1 training data"
    )
    parser.add_argument(
        "--input",
        default=os.path.expanduser("~/.claude/projects"),
        help="Path to Claude projects directory (default: ~/.claude/projects)",
    )
    parser.add_argument(
        "--output",
        default="kairo/data/claude_training_data.jsonl",
        help="Output JSONL file path (default: kairo/data/claude_training_data.jsonl)",
    )
    parser.add_argument(
        "--stats",
        default=None,
        help="Stats JSON output path (default: <output_dir>/claude_training_stats.json)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=MAX_TOKENS,
        help=f"Max estimated tokens per example (default: {MAX_TOKENS})",
    )
    args = parser.parse_args()

    if args.stats is None:
        args.stats = os.path.join(os.path.dirname(args.output), "claude_training_stats.json")

    # ~4 chars per token heuristic, mirroring MAX_CHARS.
    max_chars = args.max_tokens * 4

    # Discover files
    files = discover_files(args.input)
    if not files:
        print(f"No .jsonl files found in {args.input}", file=sys.stderr)
        sys.exit(1)

    print(f"Found {len(files)} JSONL files in {args.input}")

    # Aggregate stats
    total_stats = {
        "files_processed": 0,
        "files_skipped": 0,
        "threads_extracted": 0,
        "examples_generated": 0,
        "examples_skipped": 0,
        "skip_reasons": defaultdict(int),
        "total_chars": 0,
    }

    # Ensure output directories exist. Fix: guard against an empty dirname —
    # os.makedirs("") raises FileNotFoundError when --output or --stats is a
    # bare filename in the current directory.
    for target in (args.output, args.stats):
        parent = os.path.dirname(target)
        if parent:
            os.makedirs(parent, exist_ok=True)

    total_examples = 0
    with open(args.output, "w", encoding="utf-8") as out_f:
        for i, fpath in enumerate(files):
            # Progress heartbeat: first file, then every 50th.
            if (i + 1) % 50 == 0 or i == 0:
                print(f" Processing file {i + 1}/{len(files)}...")

            try:
                messages = parse_messages(fpath)
                if not messages:
                    total_stats["files_skipped"] += 1
                    continue

                threads = build_threads(messages)
                total_stats["files_processed"] += 1
                total_stats["threads_extracted"] += len(threads)

                for thread in threads:
                    chunks = chunk_thread(thread, max_chars=max_chars)
                    for chunk in chunks:
                        is_valid, skip_reason = validate_example(chunk)
                        if not is_valid:
                            total_stats["examples_skipped"] += 1
                            total_stats["skip_reasons"][skip_reason] += 1
                            continue

                        formatted = format_llama31(chunk)

                        # Safety net: never emit leaked system-reminder text.
                        if has_system_reminder(formatted):
                            total_stats["examples_skipped"] += 1
                            total_stats["skip_reasons"]["system_reminder"] += 1
                            continue

                        out_f.write(json.dumps({"text": formatted}) + "\n")
                        total_examples += 1
                        total_stats["total_chars"] += len(formatted)

            except Exception as e:
                # Best-effort batch job: report the failure and keep going.
                print(f" WARNING: Error processing {fpath}: {e}", file=sys.stderr)
                total_stats["files_skipped"] += 1

    total_stats["examples_generated"] = total_examples

    # Compute average tokens (chars / 4 heuristic)
    avg_tokens = 0
    if total_examples > 0:
        avg_tokens = int(total_stats["total_chars"] / total_examples / 4)
    total_stats["avg_tokens_per_example"] = avg_tokens

    # Convert defaultdict for JSON serialization
    total_stats["skip_reasons"] = dict(total_stats["skip_reasons"])
    del total_stats["total_chars"]

    # Write stats
    with open(args.stats, "w", encoding="utf-8") as f:
        json.dump(total_stats, f, indent=2)

    # Print summary
    print(f"\nDone!")
    print(f" Files processed: {total_stats['files_processed']}")
    print(f" Files skipped: {total_stats['files_skipped']}")
    print(f" Threads found: {total_stats['threads_extracted']}")
    print(f" Examples output: {total_stats['examples_generated']}")
    print(f" Examples skipped: {total_stats['examples_skipped']}")
    print(f" Skip reasons: {total_stats['skip_reasons']}")
    print(f" Avg tokens/ex: {avg_tokens}")
    print(f"\nOutput: {args.output}")
    print(f"Stats: {args.stats}")
462
+
463
+
464
# Script entry point: run the extractor when invoked directly.
if __name__ == "__main__":
    main()