qmdr 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.claude-plugin/marketplace.json +29 -0
- package/.env.example +85 -0
- package/.gitattributes +3 -0
- package/.github/workflows/release.yml +77 -0
- package/AI-SETUP.md +466 -0
- package/LICENSE +22 -0
- package/README.md +78 -0
- package/bun.lock +637 -0
- package/docs/README-zh.md +78 -0
- package/docs/refactor-checklist.md +54 -0
- package/docs/setup-openclaw.md +139 -0
- package/example-index.yml +33 -0
- package/finetune/BALANCED_DISTRIBUTION.md +157 -0
- package/finetune/DATA_IMPROVEMENTS.md +218 -0
- package/finetune/Justfile +43 -0
- package/finetune/Modelfile +16 -0
- package/finetune/README.md +299 -0
- package/finetune/SCORING.md +286 -0
- package/finetune/configs/accelerate_multi_gpu.yaml +17 -0
- package/finetune/configs/grpo.yaml +49 -0
- package/finetune/configs/sft.yaml +42 -0
- package/finetune/configs/sft_local.yaml +40 -0
- package/finetune/convert_gguf.py +221 -0
- package/finetune/data/best_glm_prompt.txt +17 -0
- package/finetune/data/gepa_generated.prompts.json +32 -0
- package/finetune/data/qmd_expansion_balanced_deduped.jsonl +413 -0
- package/finetune/data/qmd_expansion_diverse_addon.jsonl +386 -0
- package/finetune/data/qmd_expansion_handcrafted.jsonl +65 -0
- package/finetune/data/qmd_expansion_handcrafted_only.jsonl +336 -0
- package/finetune/data/qmd_expansion_locations.jsonl +64 -0
- package/finetune/data/qmd_expansion_people.jsonl +46 -0
- package/finetune/data/qmd_expansion_short_nontech.jsonl +200 -0
- package/finetune/data/qmd_expansion_v2.jsonl +1498 -0
- package/finetune/data/qmd_only_sampled.jsonl +399 -0
- package/finetune/dataset/analyze_data.py +369 -0
- package/finetune/dataset/clean_data.py +906 -0
- package/finetune/dataset/generate_balanced.py +823 -0
- package/finetune/dataset/generate_data.py +714 -0
- package/finetune/dataset/generate_data_offline.py +206 -0
- package/finetune/dataset/generate_diverse.py +441 -0
- package/finetune/dataset/generate_ollama.py +326 -0
- package/finetune/dataset/prepare_data.py +197 -0
- package/finetune/dataset/schema.py +73 -0
- package/finetune/dataset/score_data.py +115 -0
- package/finetune/dataset/validate_schema.py +104 -0
- package/finetune/eval.py +196 -0
- package/finetune/evals/queries.txt +56 -0
- package/finetune/gepa/__init__.py +1 -0
- package/finetune/gepa/best_prompt.txt +31 -0
- package/finetune/gepa/best_prompt_glm.txt +1 -0
- package/finetune/gepa/dspy_gepa.py +204 -0
- package/finetune/gepa/example.py +117 -0
- package/finetune/gepa/generate.py +129 -0
- package/finetune/gepa/gepa_outputs.jsonl +10 -0
- package/finetune/gepa/gepa_outputs_glm.jsonl +20 -0
- package/finetune/gepa/model.json +19 -0
- package/finetune/gepa/optimizer.py +70 -0
- package/finetune/gepa/score.py +84 -0
- package/finetune/jobs/eval.py +490 -0
- package/finetune/jobs/eval_common.py +354 -0
- package/finetune/jobs/eval_verbose.py +113 -0
- package/finetune/jobs/grpo.py +141 -0
- package/finetune/jobs/quantize.py +244 -0
- package/finetune/jobs/sft.py +121 -0
- package/finetune/pyproject.toml +23 -0
- package/finetune/reward.py +610 -0
- package/finetune/train.py +611 -0
- package/finetune/uv.lock +4070 -0
- package/flake.lock +61 -0
- package/flake.nix +83 -0
- package/migrate-schema.ts +162 -0
- package/package.json +56 -0
- package/skills/qmdr/SKILL.md +172 -0
- package/skills/qmdr/references/mcp-setup.md +88 -0
- package/src/app/commands/collection.ts +55 -0
- package/src/app/commands/context.ts +82 -0
- package/src/app/commands/document.ts +46 -0
- package/src/app/commands/maintenance.ts +60 -0
- package/src/app/commands/search.ts +45 -0
- package/src/app/ports/llm.ts +13 -0
- package/src/app/services/llm-service.ts +145 -0
- package/src/cli.test.ts +963 -0
- package/src/collections.ts +390 -0
- package/src/eval.test.ts +412 -0
- package/src/formatter.ts +427 -0
- package/src/llm.test.ts +559 -0
- package/src/llm.ts +1990 -0
- package/src/mcp.test.ts +889 -0
- package/src/mcp.ts +626 -0
- package/src/qmd.ts +3330 -0
- package/src/store/collections.ts +7 -0
- package/src/store/context.ts +10 -0
- package/src/store/db.ts +5 -0
- package/src/store/documents.ts +26 -0
- package/src/store/maintenance.ts +15 -0
- package/src/store/path.ts +13 -0
- package/src/store/search.ts +10 -0
- package/src/store-paths.test.ts +395 -0
- package/src/store.test.ts +2483 -0
- package/src/store.ts +2813 -0
- package/test/eval-harness.ts +223 -0
- package/tsconfig.json +29 -0
|
@@ -0,0 +1,610 @@
|
|
|
1
|
+
# /// script
|
|
2
|
+
# requires-python = ">=3.10"
|
|
3
|
+
# dependencies = []
|
|
4
|
+
# ///
|
|
5
|
+
"""
|
|
6
|
+
QMD Query Expansion Reward Function
|
|
7
|
+
|
|
8
|
+
Single source of truth for scoring query expansions. Used by:
|
|
9
|
+
- GRPO training (as the RL reward signal)
|
|
10
|
+
- Evaluation scripts (for scoring model outputs)
|
|
11
|
+
|
|
12
|
+
Scores expansions on five dimensions:
|
|
13
|
+
Format (30) - Has lex/vec lines, no invalid lines
|
|
14
|
+
Diversity (30) - Multiple types, diverse content, no echoes
|
|
15
|
+
HyDE (20) - Optional bonus for hypothetical document passage
|
|
16
|
+
Quality (20) - Lex shorter than vec, natural language, key terms
|
|
17
|
+
Entity (20) - Named entity preservation in lex/vec lines
|
|
18
|
+
|
|
19
|
+
Returns 0.0-1.0 for RL rewards, or a detailed breakdown dict for evaluation.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
import re
|
|
23
|
+
from collections import Counter
|
|
24
|
+
|
|
25
|
+
# =============================================================================
# Constants
# =============================================================================

# "only:" mode patterns - when query ends with these, expect only that type
# Format: "query /only:lex" (slash prefix, no space after colon)
ONLY_MODE_PATTERN = re.compile(r'\s+/only:(lex|vec|hyde)\s*$', re.IGNORECASE)

# Words too common to count toward word_repetition_penalty.
STOPWORDS = frozenset({
    'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in',
    'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
})

# Question/filler words ignored when extracting "key terms" from a query
# (see get_key_terms and extract_named_entities).
KEY_TERM_STOPWORDS = frozenset({
    'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
    'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
    'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell',
})

# Filler phrases that make a lex (keyword) line useless for lexical search.
GENERIC_LEX_PHRASES = frozenset({
    'find information about', 'search for', 'look up', 'get information',
    'learn about', 'information on', 'details about', 'find out about',
    'what is', 'how to', 'guide to', 'help with',
})

# Chat template tokens that indicate a broken output
CHAT_TEMPLATE_TOKENS = frozenset({
    '<|im_start|>', '<|im_end|>', '<|endoftext|>',
    '\nassistant\n', '\nuser\n',
})


# =============================================================================
# Parsing
# =============================================================================
|
|
60
|
+
|
|
61
|
+
def parse_expansion(text: str) -> dict:
    """Parse a multi-line expansion into {lex, vec, hyde, invalid} lists.

    Each non-empty line is bucketed by its prefix ("lex:", "vec:", "hyde:");
    any line without a recognized prefix lands in "invalid".
    """
    buckets = {"lex": [], "vec": [], "hyde": [], "invalid": []}
    prefixes = (("lex", "lex:"), ("vec", "vec:"), ("hyde", "hyde:"))
    for raw_line in text.strip().split("\n"):
        stripped = raw_line.strip()
        if not stripped:
            continue
        for key, prefix in prefixes:
            if stripped.startswith(prefix):
                buckets[key].append(stripped[len(prefix):].strip())
                break
        else:
            buckets["invalid"].append(stripped)
    return buckets
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def detect_only_mode(query: str) -> tuple[str | None, str]:
|
|
80
|
+
"""Detect if query ends with 'only: lex/vec/hyde'.
|
|
81
|
+
|
|
82
|
+
Returns (only_type, base_query) where only_type is None for normal queries.
|
|
83
|
+
"""
|
|
84
|
+
match = ONLY_MODE_PATTERN.search(query)
|
|
85
|
+
if match:
|
|
86
|
+
only_type = match.group(1).lower()
|
|
87
|
+
base_query = query[:match.start()].strip()
|
|
88
|
+
return only_type, base_query
|
|
89
|
+
return None, query
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def clean_model_output(text: str) -> tuple[str, bool]:
    """Strip chat-template artifacts from raw model output.

    Removes '<|im_end|>' tokens and any <think>...</think> blocks.
    Returns (cleaned_text, used_thinking) where used_thinking is True
    when a complete think block was present.
    """
    cleaned = text.replace('<|im_end|>', '').strip()
    thought = '<think>' in cleaned and '</think>' in cleaned
    if not thought:
        return cleaned, False
    without_think = re.sub(r'<think>.*?</think>', '', cleaned, flags=re.DOTALL)
    return without_think.strip(), True
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
# =============================================================================
|
|
108
|
+
# Helpers
|
|
109
|
+
# =============================================================================
|
|
110
|
+
|
|
111
|
+
def extract_named_entities(query: str) -> set:
    """Extract likely named entities from a query via heuristics.

    Detects: ALL-CAPS acronyms (TDS, API), non-initial capitalized words
    (React), terms with special characters (node.js, C++), CamelCase words
    (JavaScript), and words continuing a preceding entity (TDS motorsports
    -> both words). All detected entities are stored lowercased.
    """
    found = set()
    previous_hit = False

    for position, token in enumerate(query.split()):
        word = token.strip('.,!?:;()[]"\'')
        if not word:
            previous_hit = False
            continue

        lowered = word.lower()
        hit = True
        if word.isupper() and len(word) >= 2:
            pass  # ALL-CAPS acronym
        elif position > 0 and word[0].isupper() and lowered not in KEY_TERM_STOPWORDS:
            pass  # capitalized proper noun (not sentence-initial)
        elif any(ch in word for ch in '.+-#@') and len(word) >= 2:
            pass  # technical term with special characters
        elif len(word) > 1 and any(ch.isupper() for ch in word[1:]) and word[0].isupper():
            pass  # CamelCase
        elif previous_hit and lowered not in KEY_TERM_STOPWORDS:
            pass  # continuation of a compound entity name
        else:
            hit = False

        if hit:
            found.add(lowered)
        previous_hit = hit

    return found
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def get_key_terms(query: str) -> set:
    """Return the query's lowercased words minus common filler stopwords."""
    return {word for word in query.lower().split() if word not in KEY_TERM_STOPWORDS}
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
    """True when the lex line shares at least one key term with the query.

    Vacuously True when the query has no key terms at all.
    """
    terms = get_key_terms(query)
    if not terms:
        return True
    line_words = set(lex_line.lower().split())
    return not terms.isdisjoint(line_words)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def lex_preserves_entities(line: str, entities: set) -> bool:
    """True when the line mentions at least one entity (or none were given)."""
    if not entities:
        return True
    haystack = line.lower()
    for entity in entities:
        if entity in haystack:
            return True
    return False
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def lex_is_generic(lex_line: str) -> bool:
    """True when the lex line is essentially a generic filler phrase.

    A line counts as generic if, after removing the words of a known
    filler phrase, fewer than 3 characters of real content remain.
    """
    lowered = lex_line.lower().strip()
    for phrase in GENERIC_LEX_PHRASES:
        first_word = phrase.split()[0]
        if phrase not in lowered and not lowered.startswith(first_word):
            continue
        leftover = lowered
        for word in phrase.split():
            leftover = leftover.replace(word, '', 1).strip()
        if len(leftover) < 3:
            return True
    return False
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def word_set_distance(a: str, b: str) -> int:
    """Count words unique to exactly one of the two strings (symmetric diff)."""
    words_a = set(a.lower().split())
    words_b = set(b.lower().split())
    return len(words_a.symmetric_difference(words_b))
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
    """True when two strings differ enough (and neither contains the other)."""
    left = a.lower().strip()
    right = b.lower().strip()
    if left == right or left in right or right in left:
        return False
    # Symmetric word-set difference, inlined from word_set_distance.
    unique_words = set(left.split()) ^ set(right.split())
    return len(unique_words) >= min_distance
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def echoes_query(expansion: str, query: str) -> bool:
    """True when the expansion merely restates the original query."""
    e = expansion.lower().strip()
    q = query.lower().strip()
    if e == q:
        return True
    # A near-verbatim superset (query plus < 10 extra chars) also counts.
    return q in e and len(e) < len(q) + 10
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def word_repetition_penalty(text: str) -> int:
    """Penalty for non-stopword words (length > 2) repeated 3+ times.

    Each offending word contributes (count - 2) * 2 to the penalty.
    """
    penalty = 0
    for word, count in Counter(re.findall(r'\b\w+\b', text.lower())).items():
        if count >= 3 and len(word) > 2 and word not in STOPWORDS:
            penalty += (count - 2) * 2
    return penalty
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
# =============================================================================
|
|
212
|
+
# Scoring
|
|
213
|
+
# =============================================================================
|
|
214
|
+
|
|
215
|
+
def _score_only_mode(query: str, base_query: str, text: str, used_thinking: bool, only_type: str) -> dict:
    """Score an 'only:' mode expansion. Expects ONLY the requested type.

    Args:
        query: The full query, still containing the '/only:' directive.
        base_query: The query with the '/only:' directive stripped.
        text: Cleaned model output (chat artifacts already removed).
        used_thinking: Whether the model emitted <think> blocks.
        only_type: The requested expansion type: 'lex', 'vec', or 'hyde'.

    Returns:
        Breakdown dict in the same shape as score_expansion_detailed,
        with "only_mode" set to only_type. Total is out of 120:
        format 30 + diversity 30 + quality 20 + entity 20 + think 20.
    """
    parsed = parse_expansion(text)
    deductions = []

    # Expected type must be present
    expected_items = parsed.get(only_type, [])
    if not expected_items:
        # Hard fail: the single requested type is absent.
        # NOTE(review): max_possible here is 100 while the scored path uses
        # 120 — harmless since total is 0, but inconsistent; confirm intended.
        return {
            "format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0,
            "think_bonus": 0, "total": 0, "max_possible": 100,
            "percentage": 0.0, "rating": "Failed",
            "deductions": [f"missing expected {only_type}: output"],
            "parsed": parsed,
            "entities_detected": [],
            "only_mode": only_type,
        }

    # Penalize presence of OTHER types
    other_types = {"lex", "vec", "hyde"} - {only_type}
    unwanted_count = sum(len(parsed.get(t, [])) for t in other_types)
    if unwanted_count > 0:
        deductions.append(f"contains unwanted types (expected only {only_type})")

    # --- Format (0-30) ---
    # -10 per unwanted line of another type, floored at 0.
    format_score = 30 if unwanted_count == 0 else max(0, 30 - unwanted_count * 10)

    # --- Diversity (0-30) ---
    diversity_score = 0
    if len(expected_items) >= 2:
        diversity_score += 15
        # Check for diversity among items
        # Pairwise comparison: -5 per near-duplicate pair.
        div_score = 15
        for i, a in enumerate(expected_items):
            for b in expected_items[i+1:]:
                if not is_diverse(a, b, 2):
                    div_score -= 5
                    deductions.append(f"{only_type} duplicate: {a[:20]}...")
        diversity_score += max(0, div_score)
    elif len(expected_items) == 1:
        diversity_score = 15  # One item is fine for single-type output

    # Check for echoes
    # Items that merely restate the base query lose 5 each.
    for exp in expected_items:
        if echoes_query(exp, base_query):
            diversity_score -= 5
            deductions.append(f"echoes query: {exp[:20]}...")
    diversity_score = max(0, diversity_score)

    # --- Type-specific quality (0-20) ---
    quality_score = 10  # base
    entities = extract_named_entities(base_query)

    if only_type == "lex":
        # Lex should be short keyword phrases with key terms
        with_terms = sum(1 for l in expected_items if lex_preserves_key_terms(l, base_query))
        if with_terms == len(expected_items):
            quality_score += 5
        # Check for generic phrases
        generic = sum(1 for l in expected_items if lex_is_generic(l))
        if generic == 0:
            quality_score += 5
        else:
            deductions.append(f"{generic} generic lex phrases")

    elif only_type == "vec":
        # Vec should be natural language sentences
        natural = sum(1 for v in expected_items if " " in v and len(v) > 15)
        if natural == len(expected_items):
            quality_score += 10
        else:
            quality_score += 5
            deductions.append("vec not all natural language")

    elif only_type == "hyde":
        # Hyde should be a document snippet (50-200 chars)
        # Only the first hyde line is length-checked.
        hyde_text = expected_items[0]
        hyde_len = len(hyde_text)
        if 50 <= hyde_len <= 200:
            quality_score += 10
        elif 30 <= hyde_len <= 300:
            quality_score += 5
            deductions.append(f"hyde length {hyde_len} (ideal: 50-200)")
        else:
            deductions.append(f"hyde length {hyde_len} out of range")

    # --- Entity preservation (0-20) ---
    # Base 10 granted even when the query has no detectable entities.
    entity_score = 10  # base
    if entities:
        with_entities = sum(1 for item in expected_items if lex_preserves_entities(item, entities))
        if with_entities == len(expected_items):
            entity_score += 10
        elif with_entities > 0:
            entity_score += 5
        else:
            entity_score = 0
            deductions.append(f"missing entities: {entities}")

    # --- Think bonus (0-20) ---
    # Rewards NOT using thinking mode (fast, direct output).
    think_bonus = 0 if used_thinking else 20

    # --- Total ---
    total = format_score + diversity_score + quality_score + entity_score + think_bonus
    max_possible = 120
    percentage = max(0.0, min(100.0, total / max_possible * 100))

    # Rating thresholds mirror score_expansion_detailed.
    if percentage >= 80:
        rating = "Excellent"
    elif percentage >= 60:
        rating = "Good"
    elif percentage >= 40:
        rating = "Acceptable"
    elif percentage >= 20:
        rating = "Poor"
    else:
        rating = "Failed"

    return {
        "format": format_score,
        "diversity": diversity_score,
        "hyde": 0,  # not used in only mode (quality covers it)
        "quality": quality_score,
        "entity": entity_score,
        "think_bonus": think_bonus,
        "total": total,
        "max_possible": max_possible,
        "percentage": round(percentage, 1),
        "rating": rating,
        "deductions": deductions,
        "parsed": parsed,
        "entities_detected": list(entities) if entities else [],
        "only_mode": only_type,
    }
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def score_expansion_detailed(query: str, expansion: str) -> dict:
    """Score an expansion with full breakdown. Returns dict with all dimensions.

    Pipeline: clean chat artifacts, hard-fail on template leakage or
    unprefixed lines, delegate '/only:' queries to _score_only_mode, then
    score the normal case on format (30), diversity (30), optional HyDE
    bonus (20), quality (20), entity preservation (-45..+20), and a
    no-thinking bonus (20). max_possible is 140 with hyde, 120 without.
    """
    text, used_thinking = clean_model_output(expansion.strip())
    deductions = []

    # Detect "only:" mode
    only_type, base_query = detect_only_mode(query)

    def _fail(reason):
        # Zeroed-out breakdown used for hard failures below.
        return {
            "format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0,
            "think_bonus": 0, "total": 0, "max_possible": 100,
            "percentage": 0.0, "rating": "Failed",
            "deductions": [reason],
            "parsed": parse_expansion(expansion),
            "entities_detected": [],
            "only_mode": only_type,
        }

    # Hard fail: remaining chat template tokens
    if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
        return _fail("CHAT TEMPLATE LEAKAGE")

    # Hard fail: every non-empty line must have a valid prefix
    for line in text.split("\n"):
        line = line.strip()
        if line and not line.startswith(("lex:", "vec:", "hyde:")):
            return _fail(f"INVALID LINE: {line[:50]}")

    # --- Handle "only:" mode separately ---
    if only_type:
        return _score_only_mode(query, base_query, text, used_thinking, only_type)

    parsed = parse_expansion(text)

    # --- Format (0-30) ---
    format_score = 10  # no invalid lines (guaranteed by hard fail)
    if parsed["lex"]:
        format_score += 10
    else:
        deductions.append("missing lex:")
    if parsed["vec"]:
        format_score += 10
    else:
        deductions.append("missing vec:")

    # --- Diversity (0-30) ---
    diversity_score = 0

    # +10 for having both lex and vec lines.
    types_present = sum(1 for t in ("lex", "vec") if parsed[t])
    if types_present >= 2:
        diversity_score += 10
    else:
        deductions.append("only one type")

    # +5 for producing at least two expansions overall.
    if len(parsed["lex"]) + len(parsed["vec"]) >= 2:
        diversity_score += 5

    # Pairwise near-duplicate checks: -2 per duplicate pair, floored at 0.
    lex_div = 5
    for i, a in enumerate(parsed["lex"]):
        for b in parsed["lex"][i+1:]:
            if not is_diverse(a, b, 2):
                lex_div -= 2
                deductions.append(f"lex duplicate: {a[:20]}...")
    diversity_score += max(0, lex_div)

    # Vec lines need a larger word-set distance (3) to count as diverse.
    vec_div = 5
    for i, a in enumerate(parsed["vec"]):
        for b in parsed["vec"][i+1:]:
            if not is_diverse(a, b, 3):
                vec_div -= 2
                deductions.append(f"vec duplicate: {a[:20]}...")
    diversity_score += max(0, vec_div)

    echo = 5
    lex_echo_count = 0
    for exp in parsed["lex"]:
        if echoes_query(exp, query):
            lex_echo_count += 1
            deductions.append(f"lex echoes query: {exp[:20]}...")
    # Harsh penalty for lex echoes - they're useless
    if lex_echo_count > 0:
        echo -= lex_echo_count * 10  # -10 per echo

    for exp in parsed["vec"]:
        if echoes_query(exp, query):
            echo -= 3  # vec echoes less severe (natural language overlap ok)
            deductions.append(f"vec echoes query: {exp[:20]}...")
    diversity_score += max(-10, echo)  # can go negative

    # --- HyDE (0-20, optional bonus) ---
    hyde_score = 0
    if parsed["hyde"]:
        # Only the first hyde line is scored.
        hyde_text = parsed["hyde"][0]
        hyde_score += 5
        hyde_len = len(hyde_text)
        if 50 <= hyde_len <= 200:
            hyde_score += 5
        elif hyde_len < 50:
            hyde_score += 2
            deductions.append(f"hyde too short ({hyde_len})")
        else:
            deductions.append(f"hyde too long ({hyde_len})")
        # +5 if single-paragraph, plus up to +5 minus repetition penalty.
        if "\n" not in hyde_text:
            hyde_score += 5
        hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))

    # --- Quality (0-20) ---
    quality_score = 5  # base relevance
    if parsed["lex"] and parsed["vec"]:
        # Lex lines should be shorter (keywords) than vec lines (sentences).
        avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
        avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
        if avg_lex <= avg_vec:
            quality_score += 5
        else:
            deductions.append("lex longer than vec")
    if parsed["vec"]:
        # "Natural language" proxy: multi-word and longer than 15 chars.
        natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
        quality_score += 5 if natural == len(parsed["vec"]) else 2
    if parsed["lex"]:
        with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
        if with_terms == len(parsed["lex"]):
            quality_score += 5
        elif with_terms > 0:
            quality_score += 2
        else:
            deductions.append("lex missing key terms")

    # --- Entity Preservation (-45 to +20) ---
    entity_score = 0
    entities = extract_named_entities(query)
    if entities and parsed["lex"]:
        with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
        if with_entities == len(parsed["lex"]):
            entity_score += 15
        elif with_entities > 0:
            entity_score += 5
        else:
            entity_score -= 30
            deductions.append(f"lex missing entities: {entities}")

    # Generic filler lex lines are penalized heavily (-15 each).
    generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
    if generic_count:
        entity_score -= generic_count * 15
        deductions.append(f"{generic_count} generic lex phrases")

    if parsed["vec"]:
        # With no entities detected, lex_preserves_entities is vacuously
        # True, so any vec line earns the +5 here.
        vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
        if vec_with > 0:
            entity_score += 5
    elif not entities:
        # NOTE(review): with no vec lines and no entities, this overwrites
        # any generic-phrase penalty accrued above — confirm intended.
        entity_score = 10

    # --- Think bonus (0-20): reward NOT using thinking mode ---
    think_bonus = 0 if used_thinking else 20

    # --- Total ---
    total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus
    max_possible = 140 if parsed["hyde"] else 120
    percentage = max(0.0, min(100.0, total / max_possible * 100))

    # Hard cap: lex echoes are unacceptable - cap at 50%
    if lex_echo_count > 0:
        percentage = min(percentage, 50.0)
        deductions.insert(0, f"CAPPED: {lex_echo_count} lex echo(es)")

    if percentage >= 80:
        rating = "Excellent"
    elif percentage >= 60:
        rating = "Good"
    elif percentage >= 40:
        rating = "Acceptable"
    elif percentage >= 20:
        rating = "Poor"
    else:
        rating = "Failed"

    return {
        "format": format_score,
        "diversity": diversity_score,
        "hyde": hyde_score,
        "quality": quality_score,
        # Reported entity/total are clamped at 0; raw values feed percentage.
        "entity": max(0, entity_score),
        "think_bonus": think_bonus,
        "total": max(0, total),
        "max_possible": max_possible,
        "percentage": round(percentage, 1),
        "rating": rating,
        "deductions": deductions,
        "parsed": parsed,
        "entities_detected": list(entities) if entities else [],
        "only_mode": None,
    }
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
def score_expansion(query: str, expansion: str) -> float:
    """Score expansion as a float in [0.0, 1.0] for use as RL reward."""
    detail = score_expansion_detailed(query, expansion)
    ratio = detail["total"] / detail["max_possible"]
    return min(1.0, max(0.0, ratio))
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
def extract_query_from_prompt(prompt: str) -> str:
    """Extract the query string from a chat-formatted prompt.

    Takes whatever follows the last "Expand this search query:" marker,
    truncated at any "<|im_end|>" token. Falls back to the stripped
    prompt when the marker is absent.
    """
    marker = "Expand this search query:"
    if marker not in prompt:
        return prompt.strip()
    tail = prompt.rsplit(marker, 1)[1].strip()
    end_token = "<|im_end|>"
    if end_token in tail:
        tail = tail.partition(end_token)[0].strip()
    return tail
|
|
559
|
+
|
|
560
|
+
|
|
561
|
+
# =============================================================================
|
|
562
|
+
# TRL-compatible reward class
|
|
563
|
+
# =============================================================================
|
|
564
|
+
|
|
565
|
+
class QMDRewardFunction:
    """Reward function compatible with TRL's GRPOTrainer.

    The trainer treats reward callables like named functions, so this
    class exposes a ``__name__`` attribute for logging.
    """
    # Name TRL uses when logging this reward.
    __name__ = "qmd_scoring_reward"

    # Fix: annotation said `list[str]` while the default is None; declare
    # the optional type explicitly (file already uses `X | None` syntax).
    def __call__(self, completions: list[str], prompts: list[str] | None = None, **kwargs) -> list[float]:
        """Score each completion against the query from its paired prompt.

        Args:
            completions: Model outputs to score.
            prompts: Chat-formatted prompts aligned with ``completions``.
                When absent (or shorter than completions) the query falls
                back to an empty string.
            **kwargs: Extra arguments passed by the trainer (ignored).

        Returns:
            One reward in [0.0, 1.0] per completion.
        """
        rewards = []
        for i, completion in enumerate(completions):
            query = ""
            if prompts and i < len(prompts):
                query = extract_query_from_prompt(prompts[i])
            rewards.append(score_expansion(query, completion))
        return rewards
|
|
577
|
+
|
|
578
|
+
|
|
579
|
+
# =============================================================================
|
|
580
|
+
# CLI: run standalone to test the reward function
|
|
581
|
+
# =============================================================================
|
|
582
|
+
|
|
583
|
+
# Self-test: runs a handful of representative queries/expansions through the
# scorer and prints the score, rating, and top deductions for each.
if __name__ == "__main__":
    print("QMD Reward Function Self-Test")
    print("=" * 60)

    # (query, expansion) pairs covering good output, echoes, missing
    # entities, thinking-mode output, invalid lines, and '/only:' modes.
    tests = [
        ("auth", "lex: auth setup\nlex: authentication config\nvec: how to configure authentication\nhyde: Configure auth by setting AUTH_SECRET."),
        ("auth", "auth is important for security"),
        ("who is TDS motorsports", "lex: TDS motorsports history\nlex: TDS motorsports founders\nvec: information about TDS motorsports company"),
        ("who is TDS motorsports", "lex: find information about\nlex: company details\nvec: who is this company"),
        ("how to use React hooks", "lex: React hooks tutorial\nlex: useEffect useState\nvec: how to use React hooks in functional components"),
        ("auth", "<think>Let me think...</think>\nlex: auth"),
        ("auth", "lex: auth\nThis is some explanation\nvec: more"),
        # "/only:" mode tests (slash prefix)
        ("auth /only:lex", "lex: auth setup\nlex: authentication config\nlex: login credentials"),
        ("auth /only:lex", "lex: auth setup\nvec: how to configure authentication"),  # should fail - has vec
        ("React hooks /only:vec", "vec: how to use React hooks in functional components\nvec: useState and useEffect patterns in React"),
        ("PostgreSQL indexing /only:hyde", "hyde: PostgreSQL uses B-tree indexes by default. Create indexes with CREATE INDEX idx_name ON table(column). EXPLAIN ANALYZE shows whether queries use indexes efficiently."),
    ]

    for query, expansion in tests:
        score = score_expansion(query, expansion)
        detail = score_expansion_detailed(query, expansion)
        only_mode = detail.get("only_mode")
        mode_str = f" [only:{only_mode}]" if only_mode else ""
        print(f"\n Query: '{query}'{mode_str}")
        print(f" Score: {score:.2f} ({detail['rating']})")
        if detail["deductions"]:
            print(f" Issues: {', '.join(detail['deductions'][:3])}")