qmdr 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.claude-plugin/marketplace.json +29 -0
- package/.env.example +85 -0
- package/.gitattributes +3 -0
- package/.github/workflows/release.yml +77 -0
- package/AI-SETUP.md +466 -0
- package/LICENSE +22 -0
- package/README.md +78 -0
- package/bun.lock +637 -0
- package/docs/README-zh.md +78 -0
- package/docs/refactor-checklist.md +54 -0
- package/docs/setup-openclaw.md +139 -0
- package/example-index.yml +33 -0
- package/finetune/BALANCED_DISTRIBUTION.md +157 -0
- package/finetune/DATA_IMPROVEMENTS.md +218 -0
- package/finetune/Justfile +43 -0
- package/finetune/Modelfile +16 -0
- package/finetune/README.md +299 -0
- package/finetune/SCORING.md +286 -0
- package/finetune/configs/accelerate_multi_gpu.yaml +17 -0
- package/finetune/configs/grpo.yaml +49 -0
- package/finetune/configs/sft.yaml +42 -0
- package/finetune/configs/sft_local.yaml +40 -0
- package/finetune/convert_gguf.py +221 -0
- package/finetune/data/best_glm_prompt.txt +17 -0
- package/finetune/data/gepa_generated.prompts.json +32 -0
- package/finetune/data/qmd_expansion_balanced_deduped.jsonl +413 -0
- package/finetune/data/qmd_expansion_diverse_addon.jsonl +386 -0
- package/finetune/data/qmd_expansion_handcrafted.jsonl +65 -0
- package/finetune/data/qmd_expansion_handcrafted_only.jsonl +336 -0
- package/finetune/data/qmd_expansion_locations.jsonl +64 -0
- package/finetune/data/qmd_expansion_people.jsonl +46 -0
- package/finetune/data/qmd_expansion_short_nontech.jsonl +200 -0
- package/finetune/data/qmd_expansion_v2.jsonl +1498 -0
- package/finetune/data/qmd_only_sampled.jsonl +399 -0
- package/finetune/dataset/analyze_data.py +369 -0
- package/finetune/dataset/clean_data.py +906 -0
- package/finetune/dataset/generate_balanced.py +823 -0
- package/finetune/dataset/generate_data.py +714 -0
- package/finetune/dataset/generate_data_offline.py +206 -0
- package/finetune/dataset/generate_diverse.py +441 -0
- package/finetune/dataset/generate_ollama.py +326 -0
- package/finetune/dataset/prepare_data.py +197 -0
- package/finetune/dataset/schema.py +73 -0
- package/finetune/dataset/score_data.py +115 -0
- package/finetune/dataset/validate_schema.py +104 -0
- package/finetune/eval.py +196 -0
- package/finetune/evals/queries.txt +56 -0
- package/finetune/gepa/__init__.py +1 -0
- package/finetune/gepa/best_prompt.txt +31 -0
- package/finetune/gepa/best_prompt_glm.txt +1 -0
- package/finetune/gepa/dspy_gepa.py +204 -0
- package/finetune/gepa/example.py +117 -0
- package/finetune/gepa/generate.py +129 -0
- package/finetune/gepa/gepa_outputs.jsonl +10 -0
- package/finetune/gepa/gepa_outputs_glm.jsonl +20 -0
- package/finetune/gepa/model.json +19 -0
- package/finetune/gepa/optimizer.py +70 -0
- package/finetune/gepa/score.py +84 -0
- package/finetune/jobs/eval.py +490 -0
- package/finetune/jobs/eval_common.py +354 -0
- package/finetune/jobs/eval_verbose.py +113 -0
- package/finetune/jobs/grpo.py +141 -0
- package/finetune/jobs/quantize.py +244 -0
- package/finetune/jobs/sft.py +121 -0
- package/finetune/pyproject.toml +23 -0
- package/finetune/reward.py +610 -0
- package/finetune/train.py +611 -0
- package/finetune/uv.lock +4070 -0
- package/flake.lock +61 -0
- package/flake.nix +83 -0
- package/migrate-schema.ts +162 -0
- package/package.json +56 -0
- package/skills/qmdr/SKILL.md +172 -0
- package/skills/qmdr/references/mcp-setup.md +88 -0
- package/src/app/commands/collection.ts +55 -0
- package/src/app/commands/context.ts +82 -0
- package/src/app/commands/document.ts +46 -0
- package/src/app/commands/maintenance.ts +60 -0
- package/src/app/commands/search.ts +45 -0
- package/src/app/ports/llm.ts +13 -0
- package/src/app/services/llm-service.ts +145 -0
- package/src/cli.test.ts +963 -0
- package/src/collections.ts +390 -0
- package/src/eval.test.ts +412 -0
- package/src/formatter.ts +427 -0
- package/src/llm.test.ts +559 -0
- package/src/llm.ts +1990 -0
- package/src/mcp.test.ts +889 -0
- package/src/mcp.ts +626 -0
- package/src/qmd.ts +3330 -0
- package/src/store/collections.ts +7 -0
- package/src/store/context.ts +10 -0
- package/src/store/db.ts +5 -0
- package/src/store/documents.ts +26 -0
- package/src/store/maintenance.ts +15 -0
- package/src/store/path.ts +13 -0
- package/src/store/search.ts +10 -0
- package/src/store-paths.test.ts +395 -0
- package/src/store.test.ts +2483 -0
- package/src/store.ts +2813 -0
- package/test/eval-harness.ts +223 -0
- package/tsconfig.json +29 -0
|
@@ -0,0 +1,490 @@
|
|
|
1
|
+
# /// script
|
|
2
|
+
# requires-python = ">=3.10"
|
|
3
|
+
# dependencies = [
|
|
4
|
+
# "transformers>=4.45.0",
|
|
5
|
+
# "peft>=0.7.0",
|
|
6
|
+
# "torch",
|
|
7
|
+
# "huggingface_hub>=0.20.0",
|
|
8
|
+
# "accelerate",
|
|
9
|
+
# ]
|
|
10
|
+
# ///
|
|
11
|
+
"""
|
|
12
|
+
Evaluate QMD query expansion models on HuggingFace Jobs.
|
|
13
|
+
|
|
14
|
+
Self-contained script — inlines the reward function and test queries.
|
|
15
|
+
|
|
16
|
+
hf jobs uv run --flavor a10g-small --secrets HF_TOKEN --timeout 30m jobs/eval.py
|
|
17
|
+
hf jobs uv run --flavor a10g-small --secrets HF_TOKEN --timeout 30m jobs/eval.py -- --sft-only
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
import argparse
|
|
21
|
+
import csv
|
|
22
|
+
import io
|
|
23
|
+
import json
|
|
24
|
+
import os
|
|
25
|
+
import re
|
|
26
|
+
import sys
|
|
27
|
+
from collections import Counter
|
|
28
|
+
|
|
29
|
+
import torch
|
|
30
|
+
from huggingface_hub import HfApi, login
|
|
31
|
+
from peft import PeftModel
|
|
32
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
33
|
+
|
|
34
|
+
# --- Config ---
# Base checkpoint plus the LoRA adapters produced by the SFT and GRPO
# training stages (HuggingFace Hub repo ids).
BASE_MODEL = "Qwen/Qwen3-1.7B"
SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
GRPO_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"
|
|
38
|
+
|
|
39
|
+
# --- Test queries (inlined from evals/queries.txt) ---
# Fixed benchmark set spanning the query categories the expansion model
# must handle; each is scored independently by score_expansion_detailed.
QUERIES = [
    # Technical documentation
    "how to configure authentication",
    "typescript async await",
    "docker compose networking",
    "git rebase vs merge",
    "react useEffect cleanup",
    # Short/ambiguous
    "auth",
    "config",
    "setup",
    "api",
    # Named entities
    "who is TDS motorsports",
    "React hooks tutorial",
    "Docker container networking",
    "Kubernetes pod deployment",
    "AWS Lambda functions",
    # Personal notes / journals
    "meeting notes project kickoff",
    "ideas for new feature",
    "todo list app architecture",
    # Research / learning
    "what is dependency injection",
    "difference between sql and nosql",
    "kubernetes vs docker swarm",
    # Error/debugging
    "connection timeout error",
    "memory leak debugging",
    "cors error fix",
    # Temporal / recency
    "recent news about Shopify",
    "latest AI developments",
    "best laptops right now",
    "what changed in kubernetes latest version",
    # Complex
    "how to implement caching with redis in nodejs",
    "best practices for api rate limiting",
    "setting up ci cd pipeline with github actions",
]
|
|
80
|
+
|
|
81
|
+
# =============================================================================
|
|
82
|
+
# Reward function (inlined from reward.py)
|
|
83
|
+
# =============================================================================
|
|
84
|
+
|
|
85
|
+
# Common function words ignored by word_repetition_penalty when counting
# repeated words in a HyDE passage.
STOPWORDS = frozenset({
    'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in',
    'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
})
|
|
89
|
+
|
|
90
|
+
# Question/filler words excluded when extracting a query's "key terms"
# (get_key_terms) and when deciding whether a capitalized token is a
# named entity (extract_named_entities).
KEY_TERM_STOPWORDS = frozenset({
    'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
    'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
    'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell',
})
|
|
95
|
+
|
|
96
|
+
# Boilerplate search phrasing; a lex: line that is essentially just one of
# these (see lex_is_generic) is penalized as content-free.
GENERIC_LEX_PHRASES = frozenset({
    'find information about', 'search for', 'look up', 'get information',
    'learn about', 'information on', 'details about', 'find out about',
    'what is', 'how to', 'guide to', 'help with',
})
|
|
101
|
+
|
|
102
|
+
# Chat-template markers that must never appear in a cleaned expansion;
# their presence is treated as a hard scoring failure (template leakage).
CHAT_TEMPLATE_TOKENS = frozenset({
    '<|im_start|>', '<|im_end|>', '<|endoftext|>',
    '\nassistant\n', '\nuser\n',
})
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def parse_expansion(text):
    """Split model output into lex:/vec:/hyde: buckets.

    Non-empty lines without a recognized prefix land in "invalid".
    """
    buckets = {"lex": [], "vec": [], "hyde": [], "invalid": []}
    prefixes = (("lex:", "lex"), ("vec:", "vec"), ("hyde:", "hyde"))
    for raw in text.strip().split("\n"):
        stripped = raw.strip()
        if not stripped:
            continue
        for prefix, key in prefixes:
            if stripped.startswith(prefix):
                buckets[key].append(stripped[len(prefix):].strip())
                break
        else:
            buckets["invalid"].append(stripped)
    return buckets
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def clean_model_output(text):
    """Strip the im_end marker and any complete <think>...</think> block.

    Returns (cleaned_text, used_thinking) where used_thinking flags
    whether a full think block was found and removed.
    """
    stripped = text.replace('<|im_end|>', '').strip()
    has_think = '<think>' in stripped and '</think>' in stripped
    if has_think:
        stripped = re.sub(r'<think>.*?</think>', '', stripped, flags=re.DOTALL).strip()
    return stripped, has_think
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def extract_named_entities(query):
    """Heuristically collect named entities from a query, lowercased.

    A token counts as an entity when it is an acronym (all caps, len >= 2),
    capitalized mid-sentence, contains symbol chars (.+-#@), is CamelCase,
    or directly follows another entity token (multi-word names).
    """
    found = set()
    previous_hit = False
    for idx, token in enumerate(query.split()):
        term = token.strip('.,!?:;()[]"\'')
        if not term:
            previous_hit = False
            continue
        lowered = term.lower()
        hit = (
            (term.isupper() and len(term) >= 2)
            or (idx > 0 and term[0].isupper() and lowered not in KEY_TERM_STOPWORDS)
            or (len(term) >= 2 and any(ch in term for ch in '.+-#@'))
            or (len(term) > 1 and term[0].isupper() and any(ch.isupper() for ch in term[1:]))
            or (previous_hit and lowered not in KEY_TERM_STOPWORDS)
        )
        if hit:
            found.add(lowered)
        previous_hit = hit
    return found
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def get_key_terms(query):
    """Return the query's content words: lowercase tokens minus filler stopwords."""
    return {token for token in query.lower().split() if token not in KEY_TERM_STOPWORDS}
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def lex_preserves_key_terms(lex_line, query):
    """True when lex_line shares at least one key term with query (vacuously true if none)."""
    terms = get_key_terms(query)
    if not terms:
        return True
    return any(word in terms for word in lex_line.lower().split())
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def lex_preserves_entities(line, entities):
    """True if line mentions any detected entity; vacuously true when entities is empty."""
    if not entities:
        return True
    lowered = line.lower()
    return any(entity in lowered for entity in entities)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def lex_is_generic(lex_line):
    """Detect lex lines that are little more than a boilerplate search phrase.

    A line is generic when, after removing the words of a matched generic
    phrase once each, fewer than three characters of content remain.
    """
    normalized = lex_line.lower().strip()
    for phrase in GENERIC_LEX_PHRASES:
        words = phrase.split()
        if phrase not in normalized and not normalized.startswith(words[0]):
            continue
        leftover = normalized
        for word in words:
            leftover = leftover.replace(word, '', 1).strip()
        if len(leftover) < 3:
            return True
    return False
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def word_set_distance(a, b):
    """Size of the symmetric difference between the two strings' word sets."""
    left = set(a.lower().split())
    right = set(b.lower().split())
    return len(left.symmetric_difference(right))
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def is_diverse(a, b, min_distance=2):
    """Two expansions are diverse when neither contains the other and their word sets differ enough."""
    left = a.lower().strip()
    right = b.lower().strip()
    overlapping = left == right or left in right or right in left
    return not overlapping and word_set_distance(left, right) >= min_distance
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def echoes_query(expansion, query):
    """True when the expansion merely restates the query.

    An echo is an exact (case-insensitive) match, or a superset that adds
    fewer than ten characters on top of the original query.
    """
    exp_norm = expansion.lower().strip()
    query_norm = query.lower().strip()
    if exp_norm == query_norm:
        return True
    return query_norm in exp_norm and len(exp_norm) < len(query_norm) + 10
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def word_repetition_penalty(text):
    """Penalty for repeated words: 2 points per occurrence beyond two.

    Only non-stopwords longer than two characters appearing three or more
    times contribute.
    """
    frequencies = Counter(re.findall(r'\b\w+\b', text.lower()))
    penalty = 0
    for word, count in frequencies.items():
        if count >= 3 and word not in STOPWORDS and len(word) > 2:
            penalty += (count - 2) * 2
    return penalty
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def score_expansion_detailed(query, expansion):
    """Score a single query expansion across several axes.

    Axes: format (prefixes present), diversity (variety, no echoes),
    hyde (optional passage quality), quality (lex/vec shape, key terms),
    entity (named-entity preservation), plus a no-thinking bonus.

    Returns a dict with per-axis scores, total, max_possible, percentage,
    a coarse rating, deduction notes, and the entities detected.
    Hard failures (template leakage, unrecognized lines) short-circuit to zero.
    """
    text, used_thinking = clean_model_output(expansion.strip())
    deductions = []

    def _fail(reason):
        # Zero-score result. Fix: include "entities_detected" so the failure
        # dict has the same schema as the success path (previously missing,
        # which could KeyError in consumers iterating a fixed column set).
        return {
            "format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0,
            "think_bonus": 0, "total": 0, "max_possible": 100,
            "percentage": 0.0, "rating": "Failed", "deductions": [reason],
            "entities_detected": [],
        }

    # Hard failure: chat-template tokens leaked into the output.
    if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
        return _fail("CHAT TEMPLATE LEAKAGE")
    # Hard failure: any non-empty line without a recognized prefix.
    for line in text.split("\n"):
        line = line.strip()
        if line and not line.startswith(("lex:", "vec:", "hyde:")):
            return _fail(f"INVALID LINE: {line[:50]}")

    parsed = parse_expansion(text)

    # --- Format (up to 30): base 10, +10 each for having lex and vec. ---
    format_score = 10
    if parsed["lex"]: format_score += 10
    else: deductions.append("missing lex:")
    if parsed["vec"]: format_score += 10
    else: deductions.append("missing vec:")

    # --- Diversity (up to 30): both types, multiple expansions, pairwise
    # variety within lex and vec, and no near-verbatim echoes of the query. ---
    diversity_score = 0
    types_present = sum(1 for t in ("lex", "vec") if parsed[t])
    if types_present >= 2: diversity_score += 10
    if len(parsed["lex"]) + len(parsed["vec"]) >= 2: diversity_score += 5
    lex_div = 5
    for i, a in enumerate(parsed["lex"]):
        for b in parsed["lex"][i+1:]:
            if not is_diverse(a, b, 2): lex_div -= 2
    diversity_score += max(0, lex_div)
    vec_div = 5
    for i, a in enumerate(parsed["vec"]):
        for b in parsed["vec"][i+1:]:
            if not is_diverse(a, b, 3): vec_div -= 2
    diversity_score += max(0, vec_div)
    echo = 5
    for exp in parsed["lex"] + parsed["vec"]:
        if echoes_query(exp, query): echo -= 3
    diversity_score += max(0, echo)

    # --- HyDE (up to 20, only when a hyde: passage exists): rewards a
    # single-line passage of 50-200 chars without heavy word repetition. ---
    hyde_score = 0
    if parsed["hyde"]:
        hyde_text = parsed["hyde"][0]
        hyde_score += 5
        hyde_len = len(hyde_text)
        if 50 <= hyde_len <= 200: hyde_score += 5
        elif hyde_len < 50: hyde_score += 2
        if "\n" not in hyde_text: hyde_score += 5
        hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))

    # --- Quality (up to 20): lex shorter than vec on average, vec lines
    # reading as natural language, lex lines keeping the query's key terms. ---
    quality_score = 5
    if parsed["lex"] and parsed["vec"]:
        avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
        avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
        if avg_lex <= avg_vec: quality_score += 5
    if parsed["vec"]:
        natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
        quality_score += 5 if natural == len(parsed["vec"]) else 2
    if parsed["lex"]:
        with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
        if with_terms == len(parsed["lex"]): quality_score += 5
        elif with_terms > 0: quality_score += 2

    # --- Entity (up to 20, can go negative before clamping): lex lines must
    # keep detected named entities; generic boilerplate is penalized hard. ---
    entity_score = 0
    entities = extract_named_entities(query)
    if entities and parsed["lex"]:
        with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
        if with_entities == len(parsed["lex"]): entity_score += 15
        elif with_entities > 0: entity_score += 5
        else: entity_score -= 30
        generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
        if generic_count: entity_score -= generic_count * 15
        if parsed["vec"]:
            vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
            if vec_with > 0: entity_score += 5
    elif not entities:
        # No entities in the query: award full entity credit unconditionally.
        entity_score = 10

    # Bonus for answering directly (model honored /no_think).
    think_bonus = 0 if used_thinking else 20
    total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus
    max_possible = 140 if parsed["hyde"] else 120
    percentage = max(0.0, min(100.0, total / max_possible * 100))

    if percentage >= 80: rating = "Excellent"
    elif percentage >= 60: rating = "Good"
    elif percentage >= 40: rating = "Acceptable"
    elif percentage >= 20: rating = "Poor"
    else: rating = "Failed"

    return {
        "format": format_score, "diversity": diversity_score, "hyde": hyde_score,
        "quality": quality_score, "entity": max(0, entity_score),
        "think_bonus": think_bonus, "total": max(0, total),
        "max_possible": max_possible, "percentage": round(percentage, 1),
        "rating": rating, "deductions": deductions,
        "entities_detected": list(entities) if entities else [],
    }
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
# =============================================================================
|
|
309
|
+
# Model loading and generation
|
|
310
|
+
# =============================================================================
|
|
311
|
+
|
|
312
|
+
def load_model(base, sft=None, grpo=None):
    """Load tokenizer and base model, optionally merging an SFT adapter
    and then stacking a GRPO adapter on top.

    Returns (model, tokenizer) with the model in eval mode.
    """
    print(f"Loading tokenizer from {base}...")
    tok = AutoTokenizer.from_pretrained(base)
    if tok.pad_token is None:
        # Some checkpoints ship without a pad token; reuse EOS for padding.
        tok.pad_token = tok.eos_token

    print(f"Loading base model {base}...")
    net = AutoModelForCausalLM.from_pretrained(
        base,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    if sft:
        # Merge SFT weights into the base so a GRPO adapter can stack cleanly.
        print(f"Loading and merging SFT adapter {sft}...")
        net = PeftModel.from_pretrained(net, sft).merge_and_unload()

    if grpo:
        print(f"Loading GRPO adapter {grpo}...")
        net = PeftModel.from_pretrained(net, grpo)

    net.eval()
    return net, tok
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def generate_expansion(model, tokenizer, query, max_new_tokens=200):
    """Generate one expansion for query.

    Strips the prompt echo (via the assistant marker when present) and any
    residual <think> block before returning the raw expansion text.
    """
    messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        generated = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    # Prefer splitting on the assistant marker; fall back to slicing off the prompt.
    for marker in ("\nassistant\n", "assistant\n"):
        if marker in decoded:
            answer = decoded.split(marker)[-1].strip()
            break
    else:
        answer = decoded[len(prompt):].strip()

    if "<think>" in answer:
        answer = re.sub(r'<think>.*?</think>', '', answer, flags=re.DOTALL).strip()
    return answer
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
# =============================================================================
|
|
363
|
+
# Main
|
|
364
|
+
# =============================================================================
|
|
365
|
+
|
|
366
|
+
def results_to_csv(results, label):
    """Render eval results as a CSV string: one header row, one row per query."""
    header = [
        "model", "query", "expansion", "score_pct", "rating",
        "format", "diversity", "hyde", "quality", "entity", "think_bonus",
        "total", "max_possible", "deductions",
    ]
    out = io.StringIO()
    rows = csv.writer(out)
    rows.writerow(header)
    for entry in results:
        sc = entry["scores"]
        rows.writerow([
            label, entry["query"], entry["expansion"], sc["percentage"], sc["rating"],
            sc["format"], sc["diversity"], sc["hyde"], sc["quality"], sc["entity"],
            sc["think_bonus"], sc["total"], sc["max_possible"],
            "; ".join(sc.get("deductions", [])),
        ])
    return out.getvalue()
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
def upload_csv(results, label, repo_id, api):
    """Upload eval results as a CSV file to a HuggingFace Hub model repo.

    The file is named eval_<tag>.csv, where <tag> is derived from the last
    path segment of label (spaces replaced, lowercased).
    """
    csv_data = results_to_csv(results, label)
    tag = label.split("/")[-1].replace(" ", "_").lower()
    filename = f"eval_{tag}.csv"
    # Fix: both log messages previously printed the literal text "(unknown)"
    # instead of interpolating the computed filename.
    print(f"  Uploading {filename} to {repo_id}...")
    api.upload_file(
        path_or_fileobj=csv_data.encode("utf-8"),
        path_in_repo=filename,
        repo_id=repo_id,
        repo_type="model",
    )
    print(f"  Uploaded: https://huggingface.co/{repo_id}/blob/main/{filename}")
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def evaluate_model(model, tokenizer, label):
    """Run every benchmark query through the model and score the expansions.

    Prints a per-query line, a rating histogram, and the five worst queries.
    Returns (results, average_percentage) where results is a list of
    {"query", "expansion", "scores"} dicts.
    """
    divider = "=" * 70
    print(f"\n{divider}")
    print(f" EVALUATING: {label}")
    print(f"{divider}")

    results = []
    for i, query in enumerate(QUERIES, 1):
        expansion = generate_expansion(model, tokenizer, query)
        scores = score_expansion_detailed(query, expansion)
        results.append({"query": query, "expansion": expansion, "scores": scores})

        pct = scores["percentage"]
        if pct >= 80:
            marker = "+"
        elif pct < 60:
            marker = "-"
        else:
            marker = "~"
        print(f" [{marker}] {i:2d}/{len(QUERIES)} {pct:5.1f}% {scores['rating']:10s} {query}")

    avg = sum(item["scores"]["percentage"] for item in results) / len(results)
    rating_counts = Counter(item["scores"]["rating"] for item in results)

    print(f"\n {'─'*50}")
    print(f" Average score: {avg:.1f}%")
    print(f" Ratings:")
    for rating in ("Excellent", "Good", "Acceptable", "Poor", "Failed"):
        tally = rating_counts.get(rating, 0)
        if tally > 0:
            print(f" {rating:10s}: {tally:2d} {'█' * tally}")

    # Show worst queries
    print(f"\n Bottom 5:")
    for item in sorted(results, key=lambda r: r["scores"]["percentage"])[:5]:
        print(f" {item['scores']['percentage']:5.1f}% {item['query']}")
        if item["scores"]["deductions"]:
            print(f" {', '.join(item['scores']['deductions'][:3])}")

    return results, avg
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
def main():
    """Entry point: evaluate the SFT checkpoint, then (unless --sft-only)
    the GRPO checkpoint, uploading per-model and comparison CSVs to the Hub."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--sft-only", action="store_true", help="Only evaluate SFT model")
    parser.add_argument("--upload-repo", default="tobil/qmd-query-expansion-evals",
                        help="HF repo to upload CSV results")
    args = parser.parse_args()

    # Authenticate when a token is provided (required on HF Jobs).
    token = os.environ.get("HF_TOKEN")
    if token:
        login(token=token)

    api = HfApi()
    api.create_repo(repo_id=args.upload_repo, repo_type="model", exist_ok=True)

    # Evaluate SFT
    model, tokenizer = load_model(BASE_MODEL, sft=SFT_MODEL)
    sft_results, sft_avg = evaluate_model(model, tokenizer, f"SFT: {SFT_MODEL}")
    upload_csv(sft_results, "sft", args.upload_repo, api)

    if args.sft_only:
        return

    # For GRPO: reload base, merge SFT, then load GRPO adapter
    del model
    torch.cuda.empty_cache()
    model, tokenizer = load_model(BASE_MODEL, sft=SFT_MODEL, grpo=GRPO_MODEL)
    grpo_results, grpo_avg = evaluate_model(model, tokenizer, f"GRPO: {GRPO_MODEL}")
    upload_csv(grpo_results, "grpo", args.upload_repo, api)

    # Upload combined comparison CSV (second CSV contributes data rows only).
    combined = results_to_csv(sft_results, "sft") + results_to_csv(grpo_results, "grpo").split("\n", 1)[1]
    api.upload_file(
        path_or_fileobj=combined.encode("utf-8"),
        path_in_repo="eval_comparison.csv",
        repo_id=args.upload_repo,
        repo_type="model",
    )
    print(f" Uploaded: eval_comparison.csv")

    # Comparison
    bar = "=" * 70
    print(f"\n{bar}")
    print(f" COMPARISON")
    print(f"{bar}")
    print(f" SFT average: {sft_avg:.1f}%")
    print(f" GRPO average: {grpo_avg:.1f}%")
    print(f" Delta: {grpo_avg - sft_avg:+.1f}%")

    improved = 0
    regressed = 0
    for s, g in zip(sft_results, grpo_results):
        if g["scores"]["percentage"] > s["scores"]["percentage"]:
            improved += 1
        elif g["scores"]["percentage"] < s["scores"]["percentage"]:
            regressed += 1
    print(f" Improved: {improved}/{len(QUERIES)}, Regressed: {regressed}/{len(QUERIES)}")
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
# Run the evaluation pipeline when executed as a script (e.g. via `hf jobs uv run`).
if __name__ == "__main__":
    main()
|