synth-ai 0.2.2.dev0__py3-none-any.whl → 0.2.4.dev2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the versions exactly as they appear in their public registries.
- synth_ai/cli/__init__.py +66 -0
- synth_ai/cli/balance.py +205 -0
- synth_ai/cli/calc.py +70 -0
- synth_ai/cli/demo.py +74 -0
- synth_ai/{cli.py → cli/legacy_root_backup.py} +60 -15
- synth_ai/cli/man.py +103 -0
- synth_ai/cli/recent.py +126 -0
- synth_ai/cli/root.py +184 -0
- synth_ai/cli/status.py +126 -0
- synth_ai/cli/traces.py +136 -0
- synth_ai/cli/watch.py +508 -0
- synth_ai/config/base_url.py +53 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +252 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_duckdb_v2_backup.py +413 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +760 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_synth.py +34 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/test_crafter_react_agent_lm_synth.py +1740 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/test_crafter_react_agent_lm_synth_v2_backup.py +1318 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_duckdb_v2_backup.py +386 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v2_backup.py +1352 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +4 -4
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/test_crafter_react_agent_openai_v2_backup.py +2551 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1 -1
- synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +1 -1
- synth_ai/environments/examples/crafter_classic/agent_demos/old/traces/session_crafter_episode_16_15227b68-2906-416f-acc4-d6a9b4fa5828_20250725_001154.json +1363 -1
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +3 -3
- synth_ai/environments/examples/crafter_classic/environment.py +1 -1
- synth_ai/environments/examples/crafter_custom/environment.py +1 -1
- synth_ai/environments/examples/enron/dataset/corbt___enron_emails_sample_questions/default/0.0.0/293c9fe8170037e01cc9cf5834e0cd5ef6f1a6bb/dataset_info.json +1 -0
- synth_ai/environments/examples/nethack/helpers/achievements.json +64 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +1 -1
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +5 -5
- synth_ai/environments/examples/red/units/test_movement_debug.py +2 -2
- synth_ai/environments/examples/red/units/test_retry_movement.py +1 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/available_envs.json +122 -0
- synth_ai/environments/examples/sokoban/verified_puzzles.json +54987 -0
- synth_ai/environments/service/core_routes.py +1 -1
- synth_ai/experimental/synth_oss.py +446 -0
- synth_ai/learning/core.py +21 -0
- synth_ai/learning/gateway.py +4 -0
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/mipro.py +8 -0
- synth_ai/lm/__init__.py +3 -0
- synth_ai/lm/core/main.py +4 -0
- synth_ai/lm/core/main_v3.py +238 -122
- synth_ai/lm/core/vendor_clients.py +4 -0
- synth_ai/lm/provider_support/openai.py +11 -2
- synth_ai/lm/vendors/base.py +7 -0
- synth_ai/lm/vendors/openai_standard.py +339 -4
- synth_ai/lm/vendors/openai_standard_responses.py +243 -0
- synth_ai/lm/vendors/synth_client.py +155 -5
- synth_ai/lm/warmup.py +54 -17
- synth_ai/tracing/__init__.py +18 -0
- synth_ai/tracing_v1/__init__.py +29 -14
- synth_ai/tracing_v3/__init__.py +2 -2
- synth_ai/tracing_v3/abstractions.py +62 -17
- synth_ai/tracing_v3/config.py +13 -7
- synth_ai/tracing_v3/db_config.py +6 -6
- synth_ai/tracing_v3/hooks.py +1 -1
- synth_ai/tracing_v3/llm_call_record_helpers.py +350 -0
- synth_ai/tracing_v3/lm_call_record_abstractions.py +257 -0
- synth_ai/tracing_v3/session_tracer.py +5 -5
- synth_ai/tracing_v3/tests/test_concurrent_operations.py +1 -1
- synth_ai/tracing_v3/tests/test_llm_call_records.py +672 -0
- synth_ai/tracing_v3/tests/test_session_tracer.py +43 -9
- synth_ai/tracing_v3/tests/test_turso_manager.py +1 -1
- synth_ai/tracing_v3/turso/manager.py +18 -11
- synth_ai/tracing_v3/turso/models.py +1 -0
- synth_ai/tui/__main__.py +13 -0
- synth_ai/tui/dashboard.py +329 -0
- synth_ai/v0/tracing/__init__.py +0 -0
- synth_ai/{tracing → v0/tracing}/base_client.py +3 -3
- synth_ai/{tracing → v0/tracing}/client_manager.py +1 -1
- synth_ai/{tracing → v0/tracing}/context.py +1 -1
- synth_ai/{tracing → v0/tracing}/decorators.py +11 -11
- synth_ai/v0/tracing/events/__init__.py +0 -0
- synth_ai/{tracing → v0/tracing}/events/manage.py +4 -4
- synth_ai/{tracing → v0/tracing}/events/scope.py +6 -6
- synth_ai/{tracing → v0/tracing}/events/store.py +3 -3
- synth_ai/{tracing → v0/tracing}/immediate_client.py +6 -6
- synth_ai/{tracing → v0/tracing}/log_client_base.py +2 -2
- synth_ai/{tracing → v0/tracing}/retry_queue.py +3 -3
- synth_ai/{tracing → v0/tracing}/trackers.py +2 -2
- synth_ai/{tracing → v0/tracing}/upload.py +4 -4
- synth_ai/v0/tracing_v1/__init__.py +16 -0
- synth_ai/{tracing_v1 → v0/tracing_v1}/base_client.py +3 -3
- synth_ai/{tracing_v1 → v0/tracing_v1}/client_manager.py +1 -1
- synth_ai/{tracing_v1 → v0/tracing_v1}/context.py +1 -1
- synth_ai/{tracing_v1 → v0/tracing_v1}/decorators.py +11 -11
- synth_ai/v0/tracing_v1/events/__init__.py +0 -0
- synth_ai/{tracing_v1 → v0/tracing_v1}/events/manage.py +4 -4
- synth_ai/{tracing_v1 → v0/tracing_v1}/events/scope.py +6 -6
- synth_ai/{tracing_v1 → v0/tracing_v1}/events/store.py +3 -3
- synth_ai/{tracing_v1 → v0/tracing_v1}/immediate_client.py +6 -6
- synth_ai/{tracing_v1 → v0/tracing_v1}/log_client_base.py +2 -2
- synth_ai/{tracing_v1 → v0/tracing_v1}/retry_queue.py +3 -3
- synth_ai/{tracing_v1 → v0/tracing_v1}/trackers.py +2 -2
- synth_ai/{tracing_v1 → v0/tracing_v1}/upload.py +4 -4
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/METADATA +100 -5
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/RECORD +115 -75
- /synth_ai/{tracing/events/__init__.py → compound/cais.py} +0 -0
- /synth_ai/{tracing_v1/events/__init__.py → environments/examples/crafter_classic/debug_translation.py} +0 -0
- /synth_ai/{tracing → v0/tracing}/abstractions.py +0 -0
- /synth_ai/{tracing → v0/tracing}/config.py +0 -0
- /synth_ai/{tracing → v0/tracing}/local.py +0 -0
- /synth_ai/{tracing → v0/tracing}/utils.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/abstractions.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/config.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/local.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/utils.py +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/top_level.txt +0 -0
synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py (new file)
@@ -0,0 +1,252 @@
#!/usr/bin/env python3
"""
Run Crafter agent and analyze semantic map words - output as markdown tables only.

This script:
1. Runs a Crafter agent for multiple episodes
2. Extracts all unique words from the semantic map observations
3. Outputs analysis as markdown tables (no plotting dependencies)

Usage:
    python analyze_semantic_words_markdown.py --model gemini-1.5-flash --episodes 3
"""

import asyncio
import argparse
import json
import re
from collections import Counter
from pathlib import Path
from typing import Dict, List, Set
from datetime import datetime

# Import the Crafter agent
import sys
sys.path.append(str(Path(__file__).parent))
from test_crafter_react_agent import run_crafter_episodes

def extract_words_from_semantic_map(observation: str) -> Set[str]:
    """Extract meaningful words from a semantic map observation string."""
    if not observation or "semantic_map" not in observation.lower():
        return set()

    # Look for patterns like object names in the semantic map
    # Common Crafter objects/entities
    crafter_words = {
        # Resources
        'wood', 'stone', 'coal', 'iron', 'diamond', 'water',
        # Animals
        'cow', 'pig', 'skeleton', 'zombie',
        # Structures/Objects
        'tree', 'grass', 'furnace', 'table', 'bed', 'chest',
        'house', 'fence', 'door', 'wall',
        # Tools
        'axe', 'pickaxe', 'sword', 'shovel',
        # Food
        'bread', 'meat', 'apple',
        # Environment
        'mountain', 'river', 'forest', 'desert', 'cave',
        'lava', 'sand', 'dirt', 'path',
        # Actions/States
        'crafting', 'mining', 'building', 'farming',
        'health', 'hunger', 'energy'
    }

    # Extract words using regex - look for alphabetic words
    words = re.findall(r'\b[a-zA-Z]{3,}\b', observation.lower())

    # Filter to keep only meaningful Crafter-related words
    found_words = set()
    for word in words:
        if word in crafter_words:
            found_words.add(word)
        # Also check for partial matches for compound words
        elif any(cw in word for cw in crafter_words):
            found_words.add(word)

    return found_words

def analyze_episode_traces(traces_data: List[Dict]) -> Dict[str, int]:
    """Analyze traces to extract semantic map words."""
    word_counter = Counter()

    for episode_data in traces_data:
        if 'observations' in episode_data:
            for obs in episode_data['observations']:
                if isinstance(obs, dict):
                    # Look for semantic map in observation
                    obs_str = str(obs)
                    words = extract_words_from_semantic_map(obs_str)
                    word_counter.update(words)
                elif isinstance(obs, str):
                    words = extract_words_from_semantic_map(obs)
                    word_counter.update(words)

    return dict(word_counter)

def generate_markdown_report(word_counts: Dict[str, int], model: str, episodes: int) -> str:
    """Generate a markdown report of the semantic map analysis."""
    if not word_counts:
        return "# Semantic Map Analysis\n\n**No words found in semantic maps!**\n"

    total_words = sum(word_counts.values())
    unique_words = len(word_counts)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Sort words by frequency
    sorted_words = sorted(word_counts.items(), key=lambda x: x[1], reverse=True)

    # Generate markdown
    md = f"""# Semantic Map Word Analysis

**Model:** {model}
**Episodes:** {episodes}
**Generated:** {timestamp}

## Summary

- **Total word occurrences:** {total_words}
- **Unique words discovered:** {unique_words}
- **Average occurrences per word:** {total_words/unique_words:.1f}

## Top Words by Frequency

| Rank | Word | Count | Percentage |
|------|------|-------|------------|
"""

    # Top 15 words table
    for i, (word, count) in enumerate(sorted_words[:15], 1):
        percentage = (count / total_words) * 100
        md += f"| {i:2d} | {word} | {count} | {percentage:.1f}% |\n"

    # Word categories
    categories = {
        "Resources": ['wood', 'stone', 'coal', 'iron', 'diamond', 'water'],
        "Animals": ['cow', 'pig', 'skeleton', 'zombie'],
        "Structures": ['tree', 'furnace', 'table', 'house', 'chest', 'fence', 'door'],
        "Tools": ['axe', 'pickaxe', 'sword', 'shovel'],
        "Environment": ['mountain', 'river', 'forest', 'desert', 'cave', 'lava', 'grass'],
        "Food": ['bread', 'meat', 'apple']
    }

    md += "\n## Words by Category\n\n"

    for category, words in categories.items():
        found_words = [(w, word_counts[w]) for w in words if w in word_counts]
        if found_words:
            md += f"### {category}\n\n"
            md += "| Word | Count |\n|------|-------|\n"
            for word, count in sorted(found_words, key=lambda x: x[1], reverse=True):
                md += f"| {word} | {count} |\n"
            md += "\n"

    # Frequency distribution
    freq_counts = Counter(word_counts.values())
    md += "## Frequency Distribution\n\n"
    md += "| Frequency | Number of Words |\n|-----------|----------------|\n"
    for freq in sorted(freq_counts.keys(), reverse=True):
        md += f"| {freq} | {freq_counts[freq]} |\n"

    # All words alphabetically
    md += "\n## All Words (Alphabetical)\n\n"
    md += "| Word | Count |\n|------|-------|\n"
    for word in sorted(word_counts.keys()):
        md += f"| {word} | {word_counts[word]} |\n"

    return md

async def main():
    parser = argparse.ArgumentParser(description="Analyze semantic map words - markdown output only")
    parser.add_argument("--model", default="gemini-1.5-flash",
                        help="Model to use for agent (default: gemini-1.5-flash)")
    parser.add_argument("--episodes", type=int, default=3,
                        help="Number of episodes to run (default: 3)")
    parser.add_argument("--max-turns", type=int, default=50,
                        help="Maximum turns per episode (default: 50)")
    parser.add_argument("--output-dir", default="semantic_analysis",
                        help="Directory to save analysis results")

    args = parser.parse_args()

    print(f"🚀 Running {args.episodes} episodes with {args.model}")
    print(f"📊 Will analyze semantic map words and generate markdown report")

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(exist_ok=True)

    # Run the agent episodes
    try:
        print("\n🎮 Starting Crafter episodes...")
        traces_result = await run_crafter_episodes(
            model_name=args.model,
            num_episodes=args.episodes,
            max_turns=args.max_turns,
            difficulty="easy",
            base_seed=1000
        )

        print(f"✅ Completed {args.episodes} episodes")

        # Analyze semantic map words
        print("\n🔍 Analyzing semantic map words...")
        word_counts = analyze_episode_traces(traces_result)

        # Generate markdown report
        print("\n📝 Generating markdown report...")
        markdown_report = generate_markdown_report(word_counts, args.model, args.episodes)

        # Save markdown report
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        report_file = output_dir / f"semantic_analysis_{args.model}_{timestamp}.md"

        with open(report_file, 'w') as f:
            f.write(markdown_report)

        print(f"💾 Markdown report saved to: {report_file}")

        # Also save raw data as JSON
        analysis_data = {
            "model": args.model,
            "episodes": args.episodes,
            "timestamp": timestamp,
            "word_counts": word_counts,
            "total_unique_words": len(word_counts),
            "total_word_occurrences": sum(word_counts.values())
        }

        json_file = output_dir / f"word_data_{args.model}_{timestamp}.json"
        with open(json_file, 'w') as f:
            json.dump(analysis_data, f, indent=2)

        print(f"💾 Raw data saved to: {json_file}")

        # Print summary to console
        print("\n" + "="*60)
        print("SEMANTIC MAP WORD ANALYSIS SUMMARY")
        print("="*60)

        if word_counts:
            total_words = sum(word_counts.values())
            unique_words = len(word_counts)
            print(f"Total word occurrences: {total_words}")
            print(f"Unique words discovered: {unique_words}")

            # Top 10 most common words
            sorted_words = sorted(word_counts.items(), key=lambda x: x[1], reverse=True)
            print(f"\nTop 10 most frequent words:")
            for i, (word, count) in enumerate(sorted_words[:10], 1):
                print(f"{i:2d}. {word:<12} ({count} times)")
        else:
            print("No semantic map words found!")

        print(f"\n📄 Full analysis available in: {report_file}")
        print("\n🎉 Analysis complete!")

    except Exception as e:
        print(f"❌ Error during analysis: {e}")
        raise

if __name__ == "__main__":
    asyncio.run(main())
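For a quick sanity check of the word-extraction logic above, here is a minimal usage sketch (not part of the diff) that imports the new module and feeds it a made-up observation string; the observation text is hypothetical:

    from analyze_semantic_words_markdown import extract_words_from_semantic_map

    # The guard clause requires the literal substring "semantic_map" somewhere
    # in the observation, so label the synthetic input accordingly.
    obs = "semantic_map: tree tree stone cow zombie grass"
    print(sorted(extract_words_from_semantic_map(obs)))
    # ['cow', 'grass', 'stone', 'tree', 'zombie'] -- duplicates collapse into a set

    # Observations without "semantic_map" short-circuit to an empty set.
    print(extract_words_from_semantic_map("inventory: wood x3"))
    # set()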
synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_duckdb_v2_backup.py (new file)
@@ -0,0 +1,413 @@
#!/usr/bin/env python3
"""
Filter traces from DuckDB to create Modal/Synth SFT-ready .jsonl files
Supports two modes:
1. Trajectory-level filtering: Include entire trajectories above a score threshold
2. Window-based filtering: Extract high-scoring windows of actions

This is identical to the OpenAI version but configured for Modal/Synth fine-tuning.
"""

import json
import argparse
from pathlib import Path
from typing import List, Dict, Any, Tuple, Optional
from collections import defaultdict
import numpy as np
import os
import sys
import toml

# Add synth_ai to path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent.parent))

from synth_ai.tracing_v2.duckdb.ft_utils import FinetuningDataExtractor
from synth_ai.tracing_v2.duckdb.manager import DuckDBTraceManager


def create_histogram(data: List[float], bins: int = 20, width: int = 60, height: int = 15,
                     title: str = "", x_label: str = "", y_label: str = "") -> str:
    """Create a beautiful ASCII histogram."""
    if not data:
        return "No data to display"

    # Create histogram
    counts, edges = np.histogram(data, bins=bins)
    max_count = max(counts) if len(counts) > 0 else 1

    # Normalize heights
    if max_count > 0:
        heights = [int(c * height / max_count) for c in counts]
    else:
        heights = [0] * len(counts)

    # Build the plot
    lines = []

    # Title
    if title:
        lines.append(f"\n{title.center(width + 10)}")
        lines.append("=" * (width + 10))

    # Y-axis label
    if y_label:
        lines.append(f"{y_label}")

    # Plot area with y-axis
    for y in range(height, 0, -1):
        # Y-axis value
        y_val = int(max_count * y / height)
        line = f"{y_val:>6} │"

        # Bars
        for h in heights:
            if h >= y:
                line += "█"
            else:
                line += " "

        lines.append(line)

    # X-axis
    lines.append(f"{'':>6} └" + "─" * len(heights))

    # X-axis labels
    x_labels_line = " " * 8
    min_val, max_val = min(data), max(data)

    # Add labels at key positions
    label_positions = [0, len(heights)//4, len(heights)//2, 3*len(heights)//4, len(heights)-1]
    for i, pos in enumerate(label_positions):
        if pos < len(edges) - 1:
            val = edges[pos]
            label = f"{val:.1f}"
            # Calculate position
            target_pos = 8 + pos
            if i == 0:
                x_labels_line = label + x_labels_line[len(label):]
            elif i == len(label_positions) - 1:
                start = max(0, target_pos - len(label))
                x_labels_line = x_labels_line[:start] + label
            else:
                start = max(0, target_pos - len(label)//2)
                end = min(len(x_labels_line), start + len(label))
                if start < len(x_labels_line):
                    x_labels_line = x_labels_line[:start] + label[:end-start] + x_labels_line[end:]

    lines.append(x_labels_line)

    # X-axis label
    if x_label:
        lines.append(f"\n{x_label.center(width + 10)}")

    return "\n".join(lines)


def create_bar_chart(categories: List[str], values: List[int], width: int = 60,
                     title: str = "", show_values: bool = True) -> str:
    """Create a horizontal bar chart."""
    if not categories or not values:
        return "No data to display"

    max_val = max(values) if values else 1
    lines = []

    # Title
    if title:
        lines.append(f"\n{title}")
        lines.append("=" * (width + 20))

    # Find longest category name for alignment
    max_cat_len = max(len(cat) for cat in categories)

    # Create bars
    for cat, val in zip(categories, values):
        # Normalize bar length
        bar_len = int(val * width / max_val) if max_val > 0 else 0
        bar = "█" * bar_len

        # Format line
        if show_values:
            line = f"{cat:<{max_cat_len}} │ {bar} {val}"
        else:
            line = f"{cat:<{max_cat_len}} │ {bar}"

        lines.append(line)

    return "\n".join(lines)


def filter_traces_from_duckdb(
    db_path: str,
    output_path: str,
    config: Dict[str, Any]
) -> Tuple[int, Dict[str, Any]]:
    """
    Filter traces from DuckDB based on configuration.

    Returns:
        Tuple of (num_examples, statistics_dict)
    """
    mode = config.get("mode", "trajectory")
    filters = config.get("filters", {})

    # Extract filtering parameters
    min_reward = filters.get("min_total_reward", 0.0)
    min_achievements = filters.get("min_achievements", 0)
    max_cost = filters.get("max_cost", float('inf'))
    max_tokens = filters.get("max_tokens", float('inf'))

    # Modal/Synth specific: filter by model if specified
    target_models = filters.get("models", [])

    statistics = {
        "total_sessions": 0,
        "filtered_sessions": 0,
        "total_examples": 0,
        "reward_distribution": [],
        "token_distribution": [],
        "cost_distribution": [],
        "model_distribution": defaultdict(int)
    }

    with FinetuningDataExtractor(db_path) as extractor:
        # Get all sessions
        all_sessions_query = "SELECT session_id FROM session_traces"
        all_sessions = extractor.db_manager.query_traces(all_sessions_query)
        statistics["total_sessions"] = len(all_sessions)

        # Filter sessions based on criteria
        filtered_sessions = []

        for session_id in all_sessions['session_id']:
            metrics = extractor.get_session_metrics(session_id)

            # Apply filters
            if metrics['total_reward'] < min_reward:
                continue
            if metrics['total_cost'] > max_cost:
                continue
            if metrics['total_tokens'] > max_tokens:
                continue

            # Check achievements if required
            if min_achievements > 0:
                achievement_sessions = extractor.filter_by_achievements(min_achievements)
                if session_id not in achievement_sessions:
                    continue

            # Check model filter if specified
            if target_models:
                model_query = f"""
                    SELECT DISTINCT model_name
                    FROM events
                    WHERE session_id = '{session_id}'
                    AND event_type = 'cais'
                    AND model_name IS NOT NULL
                """
                session_models = extractor.db_manager.query_traces(model_query)
                if not any(model in target_models for model in session_models['model_name']):
                    continue

            filtered_sessions.append(session_id)

            # Collect statistics
            statistics["reward_distribution"].append(metrics['total_reward'])
            statistics["token_distribution"].append(metrics['total_tokens'])
            statistics["cost_distribution"].append(metrics['total_cost'])

        statistics["filtered_sessions"] = len(filtered_sessions)

        # Extract training data
        if mode == "trajectory":
            training_data = extractor.extract_openai_format(
                session_ids=filtered_sessions,
                min_reward=min_reward
            )
        else:  # window mode
            # For window mode, we need to implement window extraction
            # For now, use trajectory mode
            training_data = extractor.extract_openai_format(
                session_ids=filtered_sessions,
                min_reward=min_reward
            )

        statistics["total_examples"] = len(training_data)

        # Write to output file
        output_file = Path(output_path)
        output_file.parent.mkdir(exist_ok=True)

        with open(output_file, 'w') as f:
            for example in training_data:
                f.write(json.dumps(example) + '\n')

        # Get model distribution
        model_query = """
            SELECT model_name, COUNT(*) as count
            FROM events
            WHERE event_type = 'cais' AND model_name IS NOT NULL
            GROUP BY model_name
        """
        model_stats = extractor.db_manager.query_traces(model_query)
        for _, row in model_stats.iterrows():
            statistics["model_distribution"][row['model_name']] = row['count']

    return len(training_data), statistics


def print_statistics(stats: Dict[str, Any]):
    """Print filtering statistics with visualizations."""
    print("\n" + "="*80)
    print("FILTERING STATISTICS (Modal/Synth)")
    print("="*80)

    # Basic counts
    print(f"\nTotal sessions in database: {stats['total_sessions']}")
    print(f"Sessions after filtering: {stats['filtered_sessions']}")
    print(f"Training examples generated: {stats['total_examples']}")

    filter_rate = (stats['filtered_sessions'] / stats['total_sessions'] * 100) if stats['total_sessions'] > 0 else 0
    print(f"Filter pass rate: {filter_rate:.1f}%")

    # Reward distribution
    if stats['reward_distribution'] and any(not np.isnan(x) for x in stats['reward_distribution']):
        valid_rewards = [x for x in stats['reward_distribution'] if not np.isnan(x)]
        if valid_rewards:
            print(create_histogram(
                valid_rewards,
                bins=20,
                title="Reward Distribution",
                x_label="Total Reward",
                y_label="Count"
            ))

            print(f"\nReward statistics:")
            print(f"  Min: {min(valid_rewards):.2f}")
            print(f"  Max: {max(valid_rewards):.2f}")
            print(f"  Mean: {np.mean(valid_rewards):.2f}")
            print(f"  Median: {np.median(valid_rewards):.2f}")
    else:
        print("\nNo valid reward data to display.")

    # Token distribution
    if stats['token_distribution'] and any(not np.isnan(x) for x in stats['token_distribution']):
        valid_tokens = [x for x in stats['token_distribution'] if not np.isnan(x)]
        if valid_tokens:
            print(create_histogram(
                valid_tokens,
                bins=20,
                title="Token Usage Distribution",
                x_label="Total Tokens",
                y_label="Count"
            ))

    # Model distribution
    if stats['model_distribution']:
        models = list(stats['model_distribution'].keys())
        counts = list(stats['model_distribution'].values())
        print(create_bar_chart(
            models,
            counts,
            title="Model Usage",
            show_values=True
        ))

    print("\n" + "="*80)


def main():
    parser = argparse.ArgumentParser(
        description="Filter traces from DuckDB for Modal/Synth fine-tuning",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Example usage:
  # Use default config
  python filter_traces_sft_duckdb.py -d crafter_traces.duckdb -o ft_data/training.jsonl

  # Use custom config file
  python filter_traces_sft_duckdb.py -d crafter_traces.duckdb -c filter_config.toml

  # Override config parameters
  python filter_traces_sft_duckdb.py -d crafter_traces.duckdb --min-reward 5.0 --max-cost 0.1

  # Filter by model
  python filter_traces_sft_duckdb.py -d crafter_traces.duckdb --models "Qwen/Qwen2.5-7B-Instruct"
        """
    )

    parser.add_argument('-d', '--database', required=True, help='Path to DuckDB database')
    parser.add_argument('-o', '--output', default='ft_data/training_modal.jsonl', help='Output JSONL file')
    parser.add_argument('-c', '--config', help='Configuration TOML file')

    # Filter overrides
    parser.add_argument('--mode', choices=['trajectory', 'window'], help='Filtering mode')
    parser.add_argument('--min-reward', type=float, help='Minimum total reward')
    parser.add_argument('--min-achievements', type=int, help='Minimum achievements')
    parser.add_argument('--max-cost', type=float, help='Maximum cost')
    parser.add_argument('--max-tokens', type=int, help='Maximum tokens')
    parser.add_argument('--models', nargs='+', help='Filter by model names (e.g., Qwen/Qwen2.5-7B-Instruct)')

    parser.add_argument('--dry-run', action='store_true', help='Show statistics without writing output')

    args = parser.parse_args()

    # Load config
    config = {
        "mode": "trajectory",
        "filters": {
            "min_total_reward": 1.0,
            "min_achievements": 0,
            "max_cost": 10.0,
            "max_tokens": 100000,
            "models": []  # Empty means all models
        }
    }

    if args.config:
        with open(args.config, 'r') as f:
            loaded_config = toml.load(f)
            config.update(loaded_config)

    # Apply command-line overrides
    if args.mode:
        config["mode"] = args.mode
    if args.min_reward is not None:
        config["filters"]["min_total_reward"] = args.min_reward
    if args.min_achievements is not None:
        config["filters"]["min_achievements"] = args.min_achievements
    if args.max_cost is not None:
        config["filters"]["max_cost"] = args.max_cost
    if args.max_tokens is not None:
        config["filters"]["max_tokens"] = args.max_tokens
    if args.models:
        config["filters"]["models"] = args.models

    print(f"🤖 Modal/Synth Fine-Tuning Data Filter")
    print(f"Using database: {args.database}")
    print(f"Output file: {args.output}")
    print(f"Mode: {config['mode']}")
    print(f"Filters: {json.dumps(config['filters'], indent=2)}")

    if args.dry_run:
        print("\n🔍 DRY RUN - No output will be written")

    # Run filtering
    num_examples, stats = filter_traces_from_duckdb(
        args.database,
        args.output if not args.dry_run else "/dev/null",
        config
    )

    # Print statistics
    print_statistics(stats)

    if not args.dry_run:
        print(f"\n✅ Successfully wrote {num_examples} training examples to {args.output}")
        print(f"   Ready for Modal/Synth fine-tuning!")
    else:
        print(f"\n✅ Would write {num_examples} training examples (dry run)")


if __name__ == "__main__":
    main()
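Because the script loads the optional --config file with toml.load() and merges it via config.update(loaded_config), a config file for it would mirror the default dict built in main(). A hypothetical filter_config.toml sketch (not shipped in the package):

    # filter_config.toml -- keys match what filter_traces_from_duckdb() reads
    mode = "trajectory"

    [filters]
    min_total_reward = 2.0
    min_achievements = 1
    max_cost = 5.0
    max_tokens = 50000
    models = ["Qwen/Qwen2.5-7B-Instruct"]

Note that dict.update() replaces the entire "filters" table, so any filter key omitted from the TOML falls back to the filters.get(...) defaults inside filter_traces_from_duckdb (e.g. min_total_reward becomes 0.0), not to the defaults shown in main().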