qmdr 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.claude-plugin/marketplace.json +29 -0
- package/.env.example +85 -0
- package/.gitattributes +3 -0
- package/.github/workflows/release.yml +77 -0
- package/AI-SETUP.md +466 -0
- package/LICENSE +22 -0
- package/README.md +78 -0
- package/bun.lock +637 -0
- package/docs/README-zh.md +78 -0
- package/docs/refactor-checklist.md +54 -0
- package/docs/setup-openclaw.md +139 -0
- package/example-index.yml +33 -0
- package/finetune/BALANCED_DISTRIBUTION.md +157 -0
- package/finetune/DATA_IMPROVEMENTS.md +218 -0
- package/finetune/Justfile +43 -0
- package/finetune/Modelfile +16 -0
- package/finetune/README.md +299 -0
- package/finetune/SCORING.md +286 -0
- package/finetune/configs/accelerate_multi_gpu.yaml +17 -0
- package/finetune/configs/grpo.yaml +49 -0
- package/finetune/configs/sft.yaml +42 -0
- package/finetune/configs/sft_local.yaml +40 -0
- package/finetune/convert_gguf.py +221 -0
- package/finetune/data/best_glm_prompt.txt +17 -0
- package/finetune/data/gepa_generated.prompts.json +32 -0
- package/finetune/data/qmd_expansion_balanced_deduped.jsonl +413 -0
- package/finetune/data/qmd_expansion_diverse_addon.jsonl +386 -0
- package/finetune/data/qmd_expansion_handcrafted.jsonl +65 -0
- package/finetune/data/qmd_expansion_handcrafted_only.jsonl +336 -0
- package/finetune/data/qmd_expansion_locations.jsonl +64 -0
- package/finetune/data/qmd_expansion_people.jsonl +46 -0
- package/finetune/data/qmd_expansion_short_nontech.jsonl +200 -0
- package/finetune/data/qmd_expansion_v2.jsonl +1498 -0
- package/finetune/data/qmd_only_sampled.jsonl +399 -0
- package/finetune/dataset/analyze_data.py +369 -0
- package/finetune/dataset/clean_data.py +906 -0
- package/finetune/dataset/generate_balanced.py +823 -0
- package/finetune/dataset/generate_data.py +714 -0
- package/finetune/dataset/generate_data_offline.py +206 -0
- package/finetune/dataset/generate_diverse.py +441 -0
- package/finetune/dataset/generate_ollama.py +326 -0
- package/finetune/dataset/prepare_data.py +197 -0
- package/finetune/dataset/schema.py +73 -0
- package/finetune/dataset/score_data.py +115 -0
- package/finetune/dataset/validate_schema.py +104 -0
- package/finetune/eval.py +196 -0
- package/finetune/evals/queries.txt +56 -0
- package/finetune/gepa/__init__.py +1 -0
- package/finetune/gepa/best_prompt.txt +31 -0
- package/finetune/gepa/best_prompt_glm.txt +1 -0
- package/finetune/gepa/dspy_gepa.py +204 -0
- package/finetune/gepa/example.py +117 -0
- package/finetune/gepa/generate.py +129 -0
- package/finetune/gepa/gepa_outputs.jsonl +10 -0
- package/finetune/gepa/gepa_outputs_glm.jsonl +20 -0
- package/finetune/gepa/model.json +19 -0
- package/finetune/gepa/optimizer.py +70 -0
- package/finetune/gepa/score.py +84 -0
- package/finetune/jobs/eval.py +490 -0
- package/finetune/jobs/eval_common.py +354 -0
- package/finetune/jobs/eval_verbose.py +113 -0
- package/finetune/jobs/grpo.py +141 -0
- package/finetune/jobs/quantize.py +244 -0
- package/finetune/jobs/sft.py +121 -0
- package/finetune/pyproject.toml +23 -0
- package/finetune/reward.py +610 -0
- package/finetune/train.py +611 -0
- package/finetune/uv.lock +4070 -0
- package/flake.lock +61 -0
- package/flake.nix +83 -0
- package/migrate-schema.ts +162 -0
- package/package.json +56 -0
- package/skills/qmdr/SKILL.md +172 -0
- package/skills/qmdr/references/mcp-setup.md +88 -0
- package/src/app/commands/collection.ts +55 -0
- package/src/app/commands/context.ts +82 -0
- package/src/app/commands/document.ts +46 -0
- package/src/app/commands/maintenance.ts +60 -0
- package/src/app/commands/search.ts +45 -0
- package/src/app/ports/llm.ts +13 -0
- package/src/app/services/llm-service.ts +145 -0
- package/src/cli.test.ts +963 -0
- package/src/collections.ts +390 -0
- package/src/eval.test.ts +412 -0
- package/src/formatter.ts +427 -0
- package/src/llm.test.ts +559 -0
- package/src/llm.ts +1990 -0
- package/src/mcp.test.ts +889 -0
- package/src/mcp.ts +626 -0
- package/src/qmd.ts +3330 -0
- package/src/store/collections.ts +7 -0
- package/src/store/context.ts +10 -0
- package/src/store/db.ts +5 -0
- package/src/store/documents.ts +26 -0
- package/src/store/maintenance.ts +15 -0
- package/src/store/path.ts +13 -0
- package/src/store/search.ts +10 -0
- package/src/store-paths.test.ts +395 -0
- package/src/store.test.ts +2483 -0
- package/src/store.ts +2813 -0
- package/test/eval-harness.ts +223 -0
- package/tsconfig.json +29 -0
package/finetune/dataset/generate_ollama.py
@@ -0,0 +1,326 @@
#!/usr/bin/env python3
"""Generate synthetic training data for QMD query expansion using local Ollama."""

import argparse
import json
import random
import sys
import time

from dataset.schema import normalize_output_items, parse_output_text
from pathlib import Path

try:
    import requests
except ImportError:
    print("Install requests: pip install requests")
    exit(1)

# Diverse query seeds across many domains
QUERY_SEEDS = [
    # Programming & Tech
    "async await javascript",
    "rust ownership borrow checker",
    "kubernetes pod networking",
    "docker compose volumes",
    "nginx reverse proxy",
    "postgresql index optimization",
    "redis caching strategies",
    "graphql mutations",
    "websocket authentication",
    "terraform state management",
    "ansible playbook variables",
    "prometheus alerting rules",
    "elasticsearch aggregations",
    "kafka consumer groups",
    "grpc streaming",
    "oauth2 refresh tokens",
    "jwt token expiration",
    "cors preflight requests",
    "css grid layout",
    "react hooks useEffect",
    "vue composition api",
    "svelte stores",
    "nextjs middleware",
    "webpack code splitting",
    "typescript generics constraints",
    "python asyncio gather",
    "go goroutines channels",
    "java streams filter map",
    "c++ smart pointers",
    "swift optionals unwrapping",
    # DevOps & Infrastructure
    "ci cd pipeline best practices",
    "blue green deployment",
    "canary release strategy",
    "infrastructure as code",
    "secrets management vault",
    "load balancer health checks",
    "ssl certificate renewal",
    "dns propagation time",
    "cdn cache invalidation",
    "container orchestration",
    "service mesh istio",
    "observability tracing",
    "log aggregation elk",
    "metrics dashboards grafana",
    "incident response runbook",
    # Data & ML
    "pandas dataframe groupby",
    "numpy array broadcasting",
    "scikit learn pipeline",
    "pytorch autograd",
    "tensorflow keras layers",
    "huggingface transformers",
    "feature engineering techniques",
    "hyperparameter tuning",
    "model evaluation metrics",
    "data preprocessing normalization",
    "time series forecasting",
    "anomaly detection",
    "recommendation systems",
    "natural language processing",
    "computer vision cnn",
    "reinforcement learning",
    "transfer learning",
    "model deployment mlops",
    # Databases
    "sql join types explained",
    "database normalization forms",
    "acid transactions",
    "database sharding strategies",
    "read replicas setup",
    "connection pooling",
    "query optimization explain",
    "stored procedures triggers",
    "database migrations",
    "nosql document model",
    "graph database queries",
    "vector database similarity",
    # Security
    "xss prevention sanitization",
    "sql injection prepared statements",
    "csrf tokens",
    "content security policy",
    "rate limiting api",
    "input validation patterns",
    "password hashing bcrypt",
    "two factor authentication",
    "penetration testing",
    "security headers http",
    "vulnerability scanning",
    "audit logging",
    # System Administration
    "linux file permissions",
    "systemd service unit",
    "cron job scheduling",
    "ssh key management",
    "firewall rules iptables",
    "process monitoring",
    "disk space management",
    "memory leak debugging",
    "network troubleshooting",
    "backup restore strategies",
    "log rotation configuration",
    "performance profiling",
    # General Knowledge
    "climate change effects",
    "renewable energy sources",
    "electric vehicles",
    "artificial intelligence ethics",
    "blockchain technology",
    "quantum computing basics",
    "space exploration mars",
    "gene editing crispr",
    "vaccine development",
    "economic indicators gdp",
    "stock market investing",
    "cryptocurrency trading",
    "mental health awareness",
    "nutrition diet tips",
    "exercise fitness routine",
    "meditation mindfulness",
    "sleep hygiene habits",
    "stress management",
    "time management productivity",
    "remote work tips",
    "team collaboration",
    "project management agile",
    "design thinking process",
    "user experience research",
    # Short/Ambiguous Queries (important for training)
    "cache",
    "proxy",
    "queue",
    "mutex",
    "semaphore",
    "deadlock",
    "heap",
    "stack",
    "tree",
    "graph",
    "hash",
    "sort",
    "api",
    "sdk",
    "cli",
    "gui",
    "orm",
    "cdn",
    "auth",
    "cors",
    "csrf",
    "xss",
    "jwt",
    "ssh",
]

PROMPT_TEMPLATE = """Generate search query expansions for: {query}

Output EXACTLY this format (3 lex, 2 vec, 1 hyde):
lex: keyword phrase 1
lex: keyword phrase 2
lex: keyword phrase 3
vec: natural language search query
vec: alternative semantic query
hyde: A specific 2-sentence document passage answering this query.

Output:"""


def generate_with_ollama(
    query: str, model: str = "gemma3:4b", base_url: str = "http://localhost:11434"
) -> str | None:
    """Generate query expansion using Ollama API."""
    try:
        response = requests.post(
            f"{base_url}/api/generate",
            json={
                "model": model,
                "prompt": PROMPT_TEMPLATE.format(query=query),
                "stream": False,
                "options": {
                    "temperature": 0.7,
                    "top_p": 0.9,
                    "num_predict": 800,  # More tokens for thinking models
                },
            },
            timeout=120,
        )
        response.raise_for_status()
        return response.json().get("response", "").strip()
    except Exception as e:
        print(f"Error generating for '{query}': {e}", file=sys.stderr)
        return None


def parse_expansion(output: str) -> list[list[str]] | None:
    """Parse the model output into structured format."""
    items = normalize_output_items(parse_output_text(output))
    lex_count = sum(1 for kind, _ in items if kind == "lex")
    vec_count = sum(1 for kind, _ in items if kind == "vec")
    hyde_count = sum(1 for kind, _ in items if kind == "hyde")
    if lex_count >= 2 and vec_count >= 1 and hyde_count >= 1:
        return items
    return None


def generate_query_variations(seed: str) -> list[str]:
    """Generate variations of a seed query."""
    variations = [seed]

    # Add question forms
    if not seed.startswith(("how", "what", "why", "when", "where")):
        variations.append(f"how to {seed}")
        variations.append(f"what is {seed}")

    # Add context
    variations.append(f"{seed} tutorial")
    variations.append(f"{seed} best practices")
    variations.append(f"{seed} examples")

    return variations


def main():
    parser = argparse.ArgumentParser(description="Generate training data using Ollama")
    parser.add_argument(
        "--output", "-o", default="data/qmd_expansion_ollama.jsonl", help="Output file"
    )
    parser.add_argument(
        "--count", "-n", type=int, default=1000, help="Number of examples to generate"
    )
    parser.add_argument("--model", "-m", default="gemma3:4b", help="Ollama model name")
    parser.add_argument(
        "--base-url", default="http://localhost:11434", help="Ollama base URL"
    )
    parser.add_argument(
        "--resume", action="store_true", help="Resume from existing file"
    )
    args = parser.parse_args()

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Load existing if resuming
    existing_queries = set()
    if args.resume and output_path.exists():
        with open(output_path) as f:
            for line in f:
                obj = json.loads(line)
                existing_queries.add(obj.get("query", obj.get("input", "")).lower())
        print(
            f"Resuming with {len(existing_queries)} existing examples", file=sys.stderr
        )

    # Generate query pool
    all_queries = []
    for seed in QUERY_SEEDS:
        all_queries.extend(generate_query_variations(seed))

    # Shuffle and filter
    random.shuffle(all_queries)
    queries_to_process = [q for q in all_queries if q.lower() not in existing_queries]

    print(
        f"Processing {min(args.count, len(queries_to_process))} queries with {args.model}...",
        file=sys.stderr,
    )

    generated = 0
    errors = 0

    mode = "a" if args.resume else "w"
    with open(output_path, mode) as f:
        for i, query in enumerate(queries_to_process):
            if generated >= args.count:
                break

            output = generate_with_ollama(query, args.model, args.base_url)
            if output:
                parsed = parse_expansion(output)
                if parsed:
                    example = {"query": query, "output": parsed}
                    f.write(json.dumps(example) + "\n")
                    f.flush()
                    generated += 1

                    if generated % 10 == 0:
                        print(
                            f"Generated {generated}/{args.count} ({errors} errors)",
                            file=sys.stderr,
                        )
                else:
                    errors += 1
            else:
                errors += 1

            # Small delay to avoid overwhelming the API
            time.sleep(0.1)

    print(f"\nDone! Generated {generated} examples, {errors} errors", file=sys.stderr)
    print(f"Output: {output_path}", file=sys.stderr)


if __name__ == "__main__":
    main()
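Note: the script keeps a generated example only when the model's raw response passes the parse_expansion gate (at least 2 lex, 1 vec, and 1 hyde line). Below is a minimal standalone sketch of that counting check applied to an illustrative response string; the string is hypothetical and the snippet re-implements the counting rather than importing the package's dataset.schema helpers.

sample = """lex: js async await syntax
lex: javascript promises async functions
lex: await keyword error handling
vec: how do async and await work in javascript
vec: writing asynchronous javascript with async functions
hyde: The async keyword marks a function as returning a promise. Inside it, await pauses execution until the awaited promise settles, so asynchronous code reads like synchronous code."""

# Count lines per prefix, mirroring parse_expansion's tally over lex/vec/hyde items
counts = {"lex": 0, "vec": 0, "hyde": 0}
for line in sample.strip().splitlines():
    prefix = line.split(":", 1)[0].strip()
    if prefix in counts:
        counts[prefix] += 1

# Same thresholds as parse_expansion: >= 2 lex, >= 1 vec, >= 1 hyde
accepted = counts["lex"] >= 2 and counts["vec"] >= 1 and counts["hyde"] >= 1
print(counts, accepted)  # {'lex': 3, 'vec': 2, 'hyde': 1} True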
package/finetune/dataset/prepare_data.py
@@ -0,0 +1,197 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "transformers>=4.45.0",
#     "jinja2",
# ]
# ///
"""Prepare QMD query expansion data for training.

See PROMPT_FORMAT.md for format specification.
"""

import argparse
import json
import random
import sys
import os
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))
from dataset.schema import (
    normalize_output_items,
    output_items_to_text,
    parse_output_text,
    has_type,
)

from transformers import AutoTokenizer

_tokenizer = None
_tokenizer_model = None


def get_tokenizer():
    global _tokenizer, _tokenizer_model
    model_name = os.environ.get("QMD_BASE_MODEL", "Qwen/Qwen3-1.7B")
    if _tokenizer is None or _tokenizer_model != model_name:
        _tokenizer = AutoTokenizer.from_pretrained(model_name)
        _tokenizer_model = model_name
    return _tokenizer


def format_for_training(query_text: str, output_items: list[list[str]]) -> dict:
    """Format a single example for SFT training using Qwen chat format."""
    tokenizer = get_tokenizer()
    output_text = output_items_to_text(output_items)

    # Use /no_think to disable thinking mode - we want direct output
    messages = [
        {
            "role": "user",
            "content": f"/no_think Expand this search query: {query_text}",
        },
        {"role": "assistant", "content": output_text},
    ]

    # Use tokenizer to generate proper chat format with special tokens
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )

    # Strip empty <think> tags - we don't want thinking mode
    # The template adds "<think>\n\n</think>\n\n" which we remove
    text = text.replace("<think>\n\n</think>\n\n", "")

    return {
        "text": text,
        "messages": messages,
    }


def main():
    parser = argparse.ArgumentParser(description="Prepare data for training")
    parser.add_argument(
        "--input",
        type=str,
        default="data/*.jsonl",
        help="Input JSONL file(s) - supports glob patterns",
    )
    parser.add_argument(
        "--output", type=str, default="data/train", help="Output directory"
    )
    parser.add_argument(
        "--split", type=float, default=0.1, help="Validation split ratio"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Shuffle seed (default: 42)",
    )
    args = parser.parse_args()

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Support glob patterns for input
    import glob

    if "*" in args.input:
        input_files = sorted(glob.glob(args.input))
        if not input_files:
            print(f"Error: No files found matching: {args.input}")
            exit(1)
        print(
            f"Found {len(input_files)} input files: {[Path(f).name for f in input_files]}"
        )
    else:
        input_path = Path(args.input)
        if not input_path.exists():
            print(f"Error: Input file not found: {input_path}")
            exit(1)
        input_files = [str(input_path)]

    # Load all examples from all input files
    examples = []

    for input_file in input_files:
        file_count = 0
        with open(input_file) as f:
            for line in f:
                if line.strip():
                    ex = json.loads(line)

                    # Normalize legacy format
                    if "query" not in ex and "input" in ex:
                        ex["query"] = ex.pop("input")
                    if isinstance(ex.get("output"), str):
                        ex["output"] = parse_output_text(ex["output"])
                    ex["output"] = normalize_output_items(ex.get("output", []))

                    examples.append(ex)
                    file_count += 1
        print(f"  {Path(input_file).name}: {file_count} examples")

    print(f"Loaded {len(examples)} examples total")

    # Combine and shuffle
    all_examples = examples
    random.seed(args.seed)
    random.shuffle(all_examples)

    # Format for training
    formatted = [format_for_training(ex["query"], ex["output"]) for ex in all_examples]

    # Split into train/val
    split_idx = int(len(formatted) * (1 - args.split))
    train_data = formatted[:split_idx]
    val_data = formatted[split_idx:]

    # Write train set
    train_path = output_dir / "train.jsonl"
    with open(train_path, "w") as f:
        for item in train_data:
            f.write(json.dumps(item) + "\n")

    # Write validation set
    val_path = output_dir / "val.jsonl"
    with open(val_path, "w") as f:
        for item in val_data:
            f.write(json.dumps(item) + "\n")

    # Write chat format (for TRL)
    chat_path = output_dir / "train_chat.jsonl"
    with open(chat_path, "w") as f:
        for item in train_data:
            f.write(json.dumps({"messages": item["messages"]}) + "\n")

    # Stats
    short_final = sum(1 for ex in all_examples if len(ex["query"].split()) <= 2)

    print(f"\n=== Summary ===")
    print(f"Total examples: {len(all_examples)}")
    print(
        f"Short queries: {short_final} ({100 * short_final / len(all_examples):.1f}%)"
    )
    print(f"Train: {len(train_data)}, Val: {len(val_data)}")
    print(f"Output: {output_dir}")

    # Dataset info
    dataset_info = {
        "dataset_name": "qmd-query-expansion",
        "train_samples": len(train_data),
        "val_samples": len(val_data),
        "short_query_pct": round(100 * short_final / len(all_examples), 1),
        "columns": ["prompt", "completion", "text", "messages"],
    }
    with open(output_dir / "dataset_info.json", "w") as f:
        json.dump(dataset_info, f, indent=2)


if __name__ == "__main__":
    main()
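Note: each prepared record pairs a "/no_think"-prefixed user turn with the assistant's prefixed expansion text, and train_chat.jsonl stores just the messages list per line (the "text" field is then rendered by the Qwen tokenizer's chat template with its special tokens). A small sketch of that record shape, with a hypothetical query and expansion:

import json

query = "redis caching strategies"
output_text = "\n".join([
    "lex: redis cache eviction policies",
    "lex: redis ttl expire keys",
    "lex: cache aside write through redis",
    "vec: how to design a caching strategy with redis",
    "vec: choosing between cache aside and write through caching",
    "hyde: Redis supports several eviction policies, such as allkeys-lru, that decide which keys are dropped when memory fills up. A common pattern is cache-aside, where the application reads from Redis first and falls back to the database on a miss.",
])

# Same message structure format_for_training builds before templating
messages = [
    {"role": "user", "content": f"/no_think Expand this search query: {query}"},
    {"role": "assistant", "content": output_text},
]

# One line of train_chat.jsonl (the TRL-style chat format output)
print(json.dumps({"messages": messages}))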
package/finetune/dataset/schema.py
@@ -0,0 +1,73 @@
#!/usr/bin/env python3
"""Schema helpers for QMD training JSONL data."""

from __future__ import annotations

from typing import Iterable

VALID_OUTPUT_TYPES = {"hyde", "lex", "vec"}


def parse_output_text(text: str) -> list[list[str]]:
    """Parse prefixed output text into list pairs.

    Returns: [["hyde", "..."], ["lex", "..."], ...]
    """
    items: list[list[str]] = []
    for raw_line in text.strip().split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith("lex:"):
            items.append(["lex", line[4:].strip()])
        elif line.startswith("vec:"):
            items.append(["vec", line[4:].strip()])
        elif line.startswith("hyde:"):
            items.append(["hyde", line[5:].strip()])
    return items


def output_items_to_text(items: Iterable[Iterable[str]]) -> str:
    """Render output list pairs to prefixed text lines."""
    lines = []
    for item in items:
        if not item:
            continue
        try:
            kind, text = item[0], item[1]
        except Exception:
            continue
        if kind not in VALID_OUTPUT_TYPES:
            continue
        if text is None:
            continue
        text = str(text).strip()
        if not text:
            continue
        lines.append(f"{kind}: {text}")
    return "\n".join(lines)


def normalize_output_items(items: Iterable[Iterable[str]]) -> list[list[str]]:
    """Normalize output list pairs (filter invalid, trim whitespace)."""
    normalized: list[list[str]] = []
    for item in items:
        if not item:
            continue
        try:
            kind, text = item[0], item[1]
        except Exception:
            continue
        if kind not in VALID_OUTPUT_TYPES:
            continue
        if text is None:
            continue
        text = str(text).strip()
        if not text:
            continue
        normalized.append([kind, text])
    return normalized


def has_type(items: Iterable[Iterable[str]], kind: str) -> bool:
    return any(item and item[0] == kind for item in items)
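Note: the helpers above round-trip between the prefixed text form and the [kind, text] pair form; lines with unknown prefixes are dropped and surrounding whitespace is trimmed. A short illustrative round-trip (the sample text is made up; the import assumes the snippet is run from the finetune directory, as the sibling scripts arrange via sys.path):

from dataset.schema import (
    normalize_output_items,
    output_items_to_text,
    parse_output_text,
)

raw = """lex: docker compose named volumes
vec: how do volumes work in docker compose
note: this line has an unrecognized prefix and is ignored
hyde:   Named volumes are declared under the top-level volumes key.  """

items = normalize_output_items(parse_output_text(raw))
print(items)
# [['lex', 'docker compose named volumes'],
#  ['vec', 'how do volumes work in docker compose'],
#  ['hyde', 'Named volumes are declared under the top-level volumes key.']]

# Re-render the pairs back into the prefixed text form used in prompts and rewards
print(output_items_to_text(items))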
package/finetune/dataset/score_data.py
@@ -0,0 +1,115 @@
#!/usr/bin/env python3
"""Score JSONL datasets with the reward function."""

from __future__ import annotations

import argparse
import json
import statistics
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))
from dataset.schema import (
    normalize_output_items,
    output_items_to_text,
    parse_output_text,
)
from reward import score_expansion_detailed


def score_file(path: Path) -> tuple[int, int, list[float], dict]:
    total = 0
    errors = 0
    scores: list[float] = []
    ratings: dict[str, int] = {}

    with path.open("r", encoding="utf-8") as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            total += 1
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                errors += 1
                continue

            query = obj.get("query") or obj.get("input")
            output = obj.get("output")
            if not isinstance(query, str) or not query.strip():
                errors += 1
                continue
            if output is None:
                errors += 1
                continue

            if isinstance(output, str):
                output_items = normalize_output_items(parse_output_text(output))
            else:
                output_items = normalize_output_items(output)

            output_text = output_items_to_text(output_items)
            if not output_text:
                errors += 1
                continue

            detail = score_expansion_detailed(query, output_text)
            score = detail["percentage"]
            scores.append(score)
            rating = detail["rating"]
            ratings[rating] = ratings.get(rating, 0) + 1

    return total, errors, scores, ratings


def main() -> int:
    parser = argparse.ArgumentParser(description="Score QMD datasets")
    parser.add_argument(
        "paths",
        nargs="*",
        default=["finetune/data/*.jsonl"],
        help="JSONL files or glob patterns (default: finetune/data/*.jsonl)",
    )
    args = parser.parse_args()

    repo_root = Path(__file__).parent.parent.parent
    files: list[Path] = []
    for pattern in args.paths:
        if "*" in pattern:
            files.extend(repo_root.glob(pattern))
        else:
            files.append(repo_root / pattern)

    files = [p for p in files if p.exists()]
    if not files:
        print("No files found to score.")
        return 1

    for path in sorted(files):
        total, errors, scores, ratings = score_file(path)
        if scores:
            avg = statistics.mean(scores)
            median = statistics.median(scores)
            min_score = min(scores)
            max_score = max(scores)
            above_70 = sum(1 for s in scores if s >= 70.0)
            pct_70 = above_70 / len(scores) * 100
            print(
                f"{path}: {len(scores)} scored, {errors} errors, "
                f"avg {avg:.1f}, median {median:.1f}, min {min_score:.1f}, "
                f"max {max_score:.1f}, >=70 {pct_70:.1f}%"
            )
        else:
            print(f"{path}: 0 scored, {errors} errors")

        if ratings:
            rating_parts = [f"{k}:{v}" for k, v in sorted(ratings.items())]
            print(f"  ratings: {', '.join(rating_parts)}")

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
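Note: score_file returns (total lines seen, lines rejected, per-example percentages, counts per rating label), which main() then summarizes per file. A rough sketch of calling it directly on one of the bundled datasets; it assumes the snippet is run from the finetune directory so that dataset.score_data and its reward import resolve:

from pathlib import Path

from dataset.score_data import score_file  # module path assumed from the layout above

total, errors, scores, ratings = score_file(Path("data/qmd_expansion_handcrafted.jsonl"))
if scores:
    print(f"{len(scores)}/{total} scored, {errors} skipped, max {max(scores):.1f}%")
print(ratings)  # per-rating counts as labelled by score_expansion_detailed in reward.py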