qmdr 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. package/.claude-plugin/marketplace.json +29 -0
  2. package/.env.example +85 -0
  3. package/.gitattributes +3 -0
  4. package/.github/workflows/release.yml +77 -0
  5. package/AI-SETUP.md +466 -0
  6. package/LICENSE +22 -0
  7. package/README.md +78 -0
  8. package/bun.lock +637 -0
  9. package/docs/README-zh.md +78 -0
  10. package/docs/refactor-checklist.md +54 -0
  11. package/docs/setup-openclaw.md +139 -0
  12. package/example-index.yml +33 -0
  13. package/finetune/BALANCED_DISTRIBUTION.md +157 -0
  14. package/finetune/DATA_IMPROVEMENTS.md +218 -0
  15. package/finetune/Justfile +43 -0
  16. package/finetune/Modelfile +16 -0
  17. package/finetune/README.md +299 -0
  18. package/finetune/SCORING.md +286 -0
  19. package/finetune/configs/accelerate_multi_gpu.yaml +17 -0
  20. package/finetune/configs/grpo.yaml +49 -0
  21. package/finetune/configs/sft.yaml +42 -0
  22. package/finetune/configs/sft_local.yaml +40 -0
  23. package/finetune/convert_gguf.py +221 -0
  24. package/finetune/data/best_glm_prompt.txt +17 -0
  25. package/finetune/data/gepa_generated.prompts.json +32 -0
  26. package/finetune/data/qmd_expansion_balanced_deduped.jsonl +413 -0
  27. package/finetune/data/qmd_expansion_diverse_addon.jsonl +386 -0
  28. package/finetune/data/qmd_expansion_handcrafted.jsonl +65 -0
  29. package/finetune/data/qmd_expansion_handcrafted_only.jsonl +336 -0
  30. package/finetune/data/qmd_expansion_locations.jsonl +64 -0
  31. package/finetune/data/qmd_expansion_people.jsonl +46 -0
  32. package/finetune/data/qmd_expansion_short_nontech.jsonl +200 -0
  33. package/finetune/data/qmd_expansion_v2.jsonl +1498 -0
  34. package/finetune/data/qmd_only_sampled.jsonl +399 -0
  35. package/finetune/dataset/analyze_data.py +369 -0
  36. package/finetune/dataset/clean_data.py +906 -0
  37. package/finetune/dataset/generate_balanced.py +823 -0
  38. package/finetune/dataset/generate_data.py +714 -0
  39. package/finetune/dataset/generate_data_offline.py +206 -0
  40. package/finetune/dataset/generate_diverse.py +441 -0
  41. package/finetune/dataset/generate_ollama.py +326 -0
  42. package/finetune/dataset/prepare_data.py +197 -0
  43. package/finetune/dataset/schema.py +73 -0
  44. package/finetune/dataset/score_data.py +115 -0
  45. package/finetune/dataset/validate_schema.py +104 -0
  46. package/finetune/eval.py +196 -0
  47. package/finetune/evals/queries.txt +56 -0
  48. package/finetune/gepa/__init__.py +1 -0
  49. package/finetune/gepa/best_prompt.txt +31 -0
  50. package/finetune/gepa/best_prompt_glm.txt +1 -0
  51. package/finetune/gepa/dspy_gepa.py +204 -0
  52. package/finetune/gepa/example.py +117 -0
  53. package/finetune/gepa/generate.py +129 -0
  54. package/finetune/gepa/gepa_outputs.jsonl +10 -0
  55. package/finetune/gepa/gepa_outputs_glm.jsonl +20 -0
  56. package/finetune/gepa/model.json +19 -0
  57. package/finetune/gepa/optimizer.py +70 -0
  58. package/finetune/gepa/score.py +84 -0
  59. package/finetune/jobs/eval.py +490 -0
  60. package/finetune/jobs/eval_common.py +354 -0
  61. package/finetune/jobs/eval_verbose.py +113 -0
  62. package/finetune/jobs/grpo.py +141 -0
  63. package/finetune/jobs/quantize.py +244 -0
  64. package/finetune/jobs/sft.py +121 -0
  65. package/finetune/pyproject.toml +23 -0
  66. package/finetune/reward.py +610 -0
  67. package/finetune/train.py +611 -0
  68. package/finetune/uv.lock +4070 -0
  69. package/flake.lock +61 -0
  70. package/flake.nix +83 -0
  71. package/migrate-schema.ts +162 -0
  72. package/package.json +56 -0
  73. package/skills/qmdr/SKILL.md +172 -0
  74. package/skills/qmdr/references/mcp-setup.md +88 -0
  75. package/src/app/commands/collection.ts +55 -0
  76. package/src/app/commands/context.ts +82 -0
  77. package/src/app/commands/document.ts +46 -0
  78. package/src/app/commands/maintenance.ts +60 -0
  79. package/src/app/commands/search.ts +45 -0
  80. package/src/app/ports/llm.ts +13 -0
  81. package/src/app/services/llm-service.ts +145 -0
  82. package/src/cli.test.ts +963 -0
  83. package/src/collections.ts +390 -0
  84. package/src/eval.test.ts +412 -0
  85. package/src/formatter.ts +427 -0
  86. package/src/llm.test.ts +559 -0
  87. package/src/llm.ts +1990 -0
  88. package/src/mcp.test.ts +889 -0
  89. package/src/mcp.ts +626 -0
  90. package/src/qmd.ts +3330 -0
  91. package/src/store/collections.ts +7 -0
  92. package/src/store/context.ts +10 -0
  93. package/src/store/db.ts +5 -0
  94. package/src/store/documents.ts +26 -0
  95. package/src/store/maintenance.ts +15 -0
  96. package/src/store/path.ts +13 -0
  97. package/src/store/search.ts +10 -0
  98. package/src/store-paths.test.ts +395 -0
  99. package/src/store.test.ts +2483 -0
  100. package/src/store.ts +2813 -0
  101. package/test/eval-harness.ts +223 -0
  102. package/tsconfig.json +29 -0
@@ -0,0 +1,610 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = []
4
+ # ///
5
+ """
6
+ QMD Query Expansion Reward Function
7
+
8
+ Single source of truth for scoring query expansions. Used by:
9
+ - GRPO training (as the RL reward signal)
10
+ - Evaluation scripts (for scoring model outputs)
11
+
12
+ Scores expansions on five dimensions:
13
+ Format (30) - Has lex/vec lines, no invalid lines
14
+ Diversity (30) - Multiple types, diverse content, no echoes
15
+ HyDE (20) - Optional bonus for hypothetical document passage
16
+ Quality (20) - Lex shorter than vec, natural language, key terms
17
+ Entity (20) - Named entity preservation in lex/vec lines
18
+
19
+ Returns 0.0-1.0 for RL rewards, or a detailed breakdown dict for evaluation.
20
+ """
21
+
22
+ import re
23
+ from collections import Counter
24
+
25
+ # =============================================================================
26
+ # Constants
27
+ # =============================================================================
28
+
29
# "only:" mode patterns - when query ends with these, expect only that type
# Format: "query /only:lex" (slash prefix, no space after colon)
ONLY_MODE_PATTERN = re.compile(r'\s+/only:(lex|vec|hyde)\s*$', re.IGNORECASE)

# Function words ignored by word_repetition_penalty when counting repeats.
STOPWORDS = frozenset({
    'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in',
    'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
})

# Question/filler words stripped out when extracting a query's key terms
# (see get_key_terms and the entity heuristics in extract_named_entities).
KEY_TERM_STOPWORDS = frozenset({
    'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
    'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
    'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell',
})

# Filler phrases that make a lex line worthless on its own (see lex_is_generic).
GENERIC_LEX_PHRASES = frozenset({
    'find information about', 'search for', 'look up', 'get information',
    'learn about', 'information on', 'details about', 'find out about',
    'what is', 'how to', 'guide to', 'help with',
})

# Chat template tokens that indicate a broken output
CHAT_TEMPLATE_TOKENS = frozenset({
    '<|im_start|>', '<|im_end|>', '<|endoftext|>',
    '\nassistant\n', '\nuser\n',
})
55
+
56
+
57
+ # =============================================================================
58
+ # Parsing
59
+ # =============================================================================
60
+
61
def parse_expansion(text: str) -> dict:
    """Split a multi-line expansion into typed buckets.

    Each non-empty line is routed by its prefix ("lex:", "vec:", "hyde:");
    lines without a recognized prefix are collected under "invalid".
    """
    buckets = {"lex": [], "vec": [], "hyde": [], "invalid": []}
    for raw in text.strip().split("\n"):
        stripped = raw.strip()
        if not stripped:
            continue
        for kind in ("lex", "vec", "hyde"):
            prefix = kind + ":"
            if stripped.startswith(prefix):
                buckets[kind].append(stripped[len(prefix):].strip())
                break
        else:
            buckets["invalid"].append(stripped)
    return buckets
77
+
78
+
79
def detect_only_mode(query: str) -> tuple[str | None, str]:
    """Split a trailing "/only:lex", "/only:vec" or "/only:hyde" directive.

    Returns (only_type, base_query). only_type is None when the query has
    no directive, in which case base_query is the query unchanged.
    """
    match = ONLY_MODE_PATTERN.search(query)
    if match is None:
        return None, query
    return match.group(1).lower(), query[:match.start()].strip()
90
+
91
+
92
def clean_model_output(text: str) -> tuple[str, bool]:
    """Remove chat-template artifacts from a raw model completion.

    Drops any '<|im_end|>' markers, then deletes <think>...</think>
    blocks when both tags are present. Returns (cleaned_text, used_thinking)
    where used_thinking reports whether a think block was found.
    """
    cleaned = text.replace('<|im_end|>', '').strip()
    has_think = '<think>' in cleaned and '</think>' in cleaned
    if not has_think:
        return cleaned, False
    cleaned = re.sub(r'<think>.*?</think>', '', cleaned, flags=re.DOTALL)
    return cleaned.strip(), True
105
+
106
+
107
+ # =============================================================================
108
+ # Helpers
109
+ # =============================================================================
110
+
111
def extract_named_entities(query: str) -> set[str]:
    """Extract named entities using heuristics.

    Detects: ALL-CAPS acronyms (TDS, API), capitalized proper nouns (React),
    technical terms with special chars (node.js, C++), CamelCase (JavaScript),
    and compound names (TDS motorsports -> both words).

    Returns the detected entity words, lower-cased.
    """
    entities = set()
    words = query.split()
    # Tracks whether the previous word was flagged, so trailing words of a
    # compound name ("TDS motorsports") get swept in too.
    prev_was_entity = False

    for i, word in enumerate(words):
        # Strip surrounding punctuation before testing the word's shape.
        clean = word.strip('.,!?:;()[]"\'')
        if not clean:
            prev_was_entity = False
            continue

        is_entity = False

        # ALL-CAPS acronym of length >= 2 (e.g. "TDS", "API").
        if clean.isupper() and len(clean) >= 2:
            entities.add(clean.lower())
            is_entity = True
        # Capitalized, not sentence-initial (i > 0), and not a stopword:
        # treated as a proper noun such as "React".
        elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
            entities.add(clean.lower())
            is_entity = True
        # Technical token containing special characters (node.js, C++, C#, @user).
        elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
            entities.add(clean.lower())
            is_entity = True
        # CamelCase: leading capital plus at least one interior capital.
        elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
            entities.add(clean.lower())
            is_entity = True
        # Continuation of a compound entity started by the previous word.
        elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
            entities.add(clean.lower())
            is_entity = True

        prev_was_entity = is_entity

    return entities
149
+
150
+
151
def get_key_terms(query: str) -> set:
    """Return the query's lower-cased words minus common question/stop words."""
    words = query.lower().split()
    return {w for w in words if w not in KEY_TERM_STOPWORDS}
154
+
155
+
156
def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
    """True when the lex line shares at least one key term with the query.

    Queries with no key terms at all (pure stopwords) trivially pass.
    """
    key_terms = get_key_terms(query)
    if not key_terms:
        return True
    lex_words = set(lex_line.lower().split())
    return not key_terms.isdisjoint(lex_words)
162
+
163
+
164
def lex_preserves_entities(line: str, entities: set) -> bool:
    """True when the line mentions at least one entity (or there are none)."""
    if not entities:
        return True
    haystack = line.lower()
    for entity in entities:
        if entity in haystack:
            return True
    return False
170
+
171
+
172
def lex_is_generic(lex_line: str) -> bool:
    """True when the lex line is essentially a content-free filler phrase.

    A line counts as generic when, after removing the words of a matching
    filler phrase (one occurrence each), fewer than 3 characters remain.
    """
    lowered = lex_line.lower().strip()
    for phrase in GENERIC_LEX_PHRASES:
        words = phrase.split()
        if phrase not in lowered and not lowered.startswith(words[0]):
            continue
        leftover = lowered
        for word in words:
            leftover = leftover.replace(word, '', 1).strip()
        if len(leftover) < 3:
            return True
    return False
183
+
184
+
185
def word_set_distance(a: str, b: str) -> int:
    """Count the words that appear in exactly one of the two strings."""
    words_a = set(a.lower().split())
    words_b = set(b.lower().split())
    return len(words_a.symmetric_difference(words_b))
188
+
189
+
190
def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
    """True when two strings differ enough to both be worth keeping.

    Equal strings, or one containing the other, are never diverse;
    otherwise they must differ in at least min_distance words.
    """
    left = a.lower().strip()
    right = b.lower().strip()
    if left == right or left in right or right in left:
        return False
    return word_set_distance(left, right) >= min_distance
196
+
197
+
198
def echoes_query(expansion: str, query: str) -> bool:
    """True when the expansion merely restates the original query.

    An expansion echoes if it equals the query, or contains it while being
    fewer than 10 characters longer.
    """
    exp = expansion.lower().strip()
    q = query.lower().strip()
    if exp == q:
        return True
    return q in exp and len(exp) < len(q) + 10
202
+
203
+
204
def word_repetition_penalty(text: str) -> int:
    """Penalty points for words repeated 3+ times.

    Stopwords and words of length <= 2 are exempt; each occurrence beyond
    the second costs 2 points.
    """
    counts = Counter(re.findall(r'\b\w+\b', text.lower()))
    penalty = 0
    for word, count in counts.items():
        if count >= 3 and len(word) > 2 and word not in STOPWORDS:
            penalty += (count - 2) * 2
    return penalty
209
+
210
+
211
+ # =============================================================================
212
+ # Scoring
213
+ # =============================================================================
214
+
215
def _score_only_mode(query: str, base_query: str, text: str, used_thinking: bool, only_type: str) -> dict:
    """Score an 'only:' mode expansion. Expects ONLY the requested type.

    Args:
        query: Original query including the /only: directive.
        base_query: Query with the directive stripped (from detect_only_mode).
        text: Cleaned model output (chat artifacts already removed).
        used_thinking: Whether the model emitted <think> blocks.
        only_type: "lex", "vec" or "hyde" - the single expected line type.

    Returns the same breakdown dict shape as score_expansion_detailed,
    with "only_mode" set to only_type.
    """
    parsed = parse_expansion(text)
    deductions = []

    # Expected type must be present
    expected_items = parsed.get(only_type, [])
    if not expected_items:
        # Hard fail: the one requested type is absent entirely.
        return {
            "format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0,
            "think_bonus": 0, "total": 0, "max_possible": 100,
            "percentage": 0.0, "rating": "Failed",
            "deductions": [f"missing expected {only_type}: output"],
            "parsed": parsed,
            "entities_detected": [],
            "only_mode": only_type,
        }

    # Penalize presence of OTHER types
    other_types = {"lex", "vec", "hyde"} - {only_type}
    unwanted_count = sum(len(parsed.get(t, [])) for t in other_types)
    if unwanted_count > 0:
        deductions.append(f"contains unwanted types (expected only {only_type})")

    # --- Format (0-30) --- each unwanted line costs 10, floored at 0.
    format_score = 30 if unwanted_count == 0 else max(0, 30 - unwanted_count * 10)

    # --- Diversity (0-30) ---
    diversity_score = 0
    if len(expected_items) >= 2:
        diversity_score += 15
        # Check for diversity among items: -5 per near-duplicate pair.
        div_score = 15
        for i, a in enumerate(expected_items):
            for b in expected_items[i+1:]:
                if not is_diverse(a, b, 2):
                    div_score -= 5
                    deductions.append(f"{only_type} duplicate: {a[:20]}...")
        diversity_score += max(0, div_score)
    elif len(expected_items) == 1:
        diversity_score = 15  # One item is fine for single-type output

    # Check for echoes of the base query (-5 each), floored at 0.
    for exp in expected_items:
        if echoes_query(exp, base_query):
            diversity_score -= 5
            deductions.append(f"echoes query: {exp[:20]}...")
    diversity_score = max(0, diversity_score)

    # --- Type-specific quality (0-20) ---
    quality_score = 10  # base
    entities = extract_named_entities(base_query)

    if only_type == "lex":
        # Lex should be short keyword phrases with key terms
        with_terms = sum(1 for l in expected_items if lex_preserves_key_terms(l, base_query))
        if with_terms == len(expected_items):
            quality_score += 5
        # Check for generic phrases
        generic = sum(1 for l in expected_items if lex_is_generic(l))
        if generic == 0:
            quality_score += 5
        else:
            deductions.append(f"{generic} generic lex phrases")

    elif only_type == "vec":
        # Vec should be natural language sentences
        natural = sum(1 for v in expected_items if " " in v and len(v) > 15)
        if natural == len(expected_items):
            quality_score += 10
        else:
            quality_score += 5
            deductions.append("vec not all natural language")

    elif only_type == "hyde":
        # Hyde should be a document snippet (50-200 chars); only the first
        # hyde line is measured.
        hyde_text = expected_items[0]
        hyde_len = len(hyde_text)
        if 50 <= hyde_len <= 200:
            quality_score += 10
        elif 30 <= hyde_len <= 300:
            quality_score += 5
            deductions.append(f"hyde length {hyde_len} (ideal: 50-200)")
        else:
            deductions.append(f"hyde length {hyde_len} out of range")

    # --- Entity preservation (0-20) ---
    entity_score = 10  # base (kept when the query has no detectable entities)
    if entities:
        with_entities = sum(1 for item in expected_items if lex_preserves_entities(item, entities))
        if with_entities == len(expected_items):
            entity_score += 10
        elif with_entities > 0:
            entity_score += 5
        else:
            # All items dropped the entities: zero out the base as well.
            entity_score = 0
            deductions.append(f"missing entities: {entities}")

    # --- Think bonus (0-20): reward NOT using thinking mode ---
    think_bonus = 0 if used_thinking else 20

    # --- Total --- max is 30+30+20+20+20 = 120.
    total = format_score + diversity_score + quality_score + entity_score + think_bonus
    max_possible = 120
    percentage = max(0.0, min(100.0, total / max_possible * 100))

    if percentage >= 80:
        rating = "Excellent"
    elif percentage >= 60:
        rating = "Good"
    elif percentage >= 40:
        rating = "Acceptable"
    elif percentage >= 20:
        rating = "Poor"
    else:
        rating = "Failed"

    return {
        "format": format_score,
        "diversity": diversity_score,
        "hyde": 0,  # not used in only mode (quality covers it)
        "quality": quality_score,
        "entity": entity_score,
        "think_bonus": think_bonus,
        "total": total,
        "max_possible": max_possible,
        "percentage": round(percentage, 1),
        "rating": rating,
        "deductions": deductions,
        "parsed": parsed,
        "entities_detected": list(entities) if entities else [],
        "only_mode": only_type,
    }
348
+
349
+
350
def score_expansion_detailed(query: str, expansion: str) -> dict:
    """Score an expansion with full breakdown. Returns dict with all dimensions.

    Dimensions: format (0-30), diversity (can dip negative via echo
    penalties), hyde (0-20 optional bonus), quality (0-20), entity
    (-45..+20 before clamping), think_bonus (0 or 20). Hard-fails (score 0)
    on chat-template leakage or any line without a lex:/vec:/hyde: prefix.
    Queries carrying a "/only:" directive are delegated to _score_only_mode.
    """
    text, used_thinking = clean_model_output(expansion.strip())
    deductions = []

    # Detect "only:" mode
    only_type, base_query = detect_only_mode(query)

    def _fail(reason):
        # Uniform zero-score result for hard failures.
        return {
            "format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0,
            "think_bonus": 0, "total": 0, "max_possible": 100,
            "percentage": 0.0, "rating": "Failed",
            "deductions": [reason],
            "parsed": parse_expansion(expansion),
            "entities_detected": [],
            "only_mode": only_type,
        }

    # Hard fail: remaining chat template tokens
    if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
        return _fail("CHAT TEMPLATE LEAKAGE")

    # Hard fail: every non-empty line must have a valid prefix
    for line in text.split("\n"):
        line = line.strip()
        if line and not line.startswith(("lex:", "vec:", "hyde:")):
            return _fail(f"INVALID LINE: {line[:50]}")

    # --- Handle "only:" mode separately ---
    if only_type:
        return _score_only_mode(query, base_query, text, used_thinking, only_type)

    parsed = parse_expansion(text)

    # --- Format (0-30) ---
    format_score = 10  # no invalid lines (guaranteed by hard fail)
    if parsed["lex"]:
        format_score += 10
    else:
        deductions.append("missing lex:")
    if parsed["vec"]:
        format_score += 10
    else:
        deductions.append("missing vec:")

    # --- Diversity (0-30) ---
    diversity_score = 0

    types_present = sum(1 for t in ("lex", "vec") if parsed[t])
    if types_present >= 2:
        diversity_score += 10
    else:
        deductions.append("only one type")

    if len(parsed["lex"]) + len(parsed["vec"]) >= 2:
        diversity_score += 5

    # Near-duplicate lex lines: -2 each from a 5-point pool.
    lex_div = 5
    for i, a in enumerate(parsed["lex"]):
        for b in parsed["lex"][i+1:]:
            if not is_diverse(a, b, 2):
                lex_div -= 2
                deductions.append(f"lex duplicate: {a[:20]}...")
    diversity_score += max(0, lex_div)

    # Vec lines use a stricter distance (3 words) since they are longer.
    vec_div = 5
    for i, a in enumerate(parsed["vec"]):
        for b in parsed["vec"][i+1:]:
            if not is_diverse(a, b, 3):
                vec_div -= 2
                deductions.append(f"vec duplicate: {a[:20]}...")
    diversity_score += max(0, vec_div)

    echo = 5
    lex_echo_count = 0
    for exp in parsed["lex"]:
        if echoes_query(exp, query):
            lex_echo_count += 1
            deductions.append(f"lex echoes query: {exp[:20]}...")
    # Harsh penalty for lex echoes - they're useless
    if lex_echo_count > 0:
        echo -= lex_echo_count * 10  # -10 per echo

    for exp in parsed["vec"]:
        if echoes_query(exp, query):
            echo -= 3  # vec echoes less severe (natural language overlap ok)
            deductions.append(f"vec echoes query: {exp[:20]}...")
    diversity_score += max(-10, echo)  # can go negative

    # --- HyDE (0-20, optional bonus) --- only the first hyde line is scored.
    hyde_score = 0
    if parsed["hyde"]:
        hyde_text = parsed["hyde"][0]
        hyde_score += 5
        hyde_len = len(hyde_text)
        if 50 <= hyde_len <= 200:
            hyde_score += 5
        elif hyde_len < 50:
            hyde_score += 2
            deductions.append(f"hyde too short ({hyde_len})")
        else:
            deductions.append(f"hyde too long ({hyde_len})")
        if "\n" not in hyde_text:
            hyde_score += 5
        # Remaining 5 points eroded by repeated words in the passage.
        hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))

    # --- Quality (0-20) ---
    quality_score = 5  # base relevance
    if parsed["lex"] and parsed["vec"]:
        # Lex lines should be keyword-terse, vec lines sentence-length.
        avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
        avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
        if avg_lex <= avg_vec:
            quality_score += 5
        else:
            deductions.append("lex longer than vec")
    if parsed["vec"]:
        natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
        quality_score += 5 if natural == len(parsed["vec"]) else 2
    if parsed["lex"]:
        with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
        if with_terms == len(parsed["lex"]):
            quality_score += 5
        elif with_terms > 0:
            quality_score += 2
        else:
            deductions.append("lex missing key terms")

    # --- Entity Preservation (-45 to +20) ---
    # NOTE(review): nesting reconstructed from the elif pairing - the generic
    # phrase and vec checks apply only when entities were detected AND lex
    # lines exist; confirm against the original file's indentation.
    entity_score = 0
    entities = extract_named_entities(query)
    if entities and parsed["lex"]:
        with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
        if with_entities == len(parsed["lex"]):
            entity_score += 15
        elif with_entities > 0:
            entity_score += 5
        else:
            entity_score -= 30
            deductions.append(f"lex missing entities: {entities}")

        # Generic filler lex lines cost 15 each.
        generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
        if generic_count:
            entity_score -= generic_count * 15
            deductions.append(f"{generic_count} generic lex phrases")

        if parsed["vec"]:
            vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
            if vec_with > 0:
                entity_score += 5
    elif not entities:
        # No entities to preserve: grant the full neutral score.
        entity_score = 10

    # --- Think bonus (0-20): reward NOT using thinking mode ---
    think_bonus = 0 if used_thinking else 20

    # --- Total ---
    total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus
    max_possible = 140 if parsed["hyde"] else 120
    percentage = max(0.0, min(100.0, total / max_possible * 100))

    # Hard cap: lex echoes are unacceptable - cap at 50%
    if lex_echo_count > 0:
        percentage = min(percentage, 50.0)
        deductions.insert(0, f"CAPPED: {lex_echo_count} lex echo(es)")

    if percentage >= 80:
        rating = "Excellent"
    elif percentage >= 60:
        rating = "Good"
    elif percentage >= 40:
        rating = "Acceptable"
    elif percentage >= 20:
        rating = "Poor"
    else:
        rating = "Failed"

    return {
        "format": format_score,
        "diversity": diversity_score,
        "hyde": hyde_score,
        "quality": quality_score,
        # Negative components are clamped for reporting; "total" is clamped
        # too, so score_expansion never sees a negative reward.
        "entity": max(0, entity_score),
        "think_bonus": think_bonus,
        "total": max(0, total),
        "max_possible": max_possible,
        "percentage": round(percentage, 1),
        "rating": rating,
        "deductions": deductions,
        "parsed": parsed,
        "entities_detected": list(entities) if entities else [],
        "only_mode": None,
    }
543
+
544
+
545
def score_expansion(query: str, expansion: str) -> float:
    """Collapse the detailed breakdown into a [0.0, 1.0] RL reward."""
    detail = score_expansion_detailed(query, expansion)
    ratio = detail["total"] / detail["max_possible"]
    return max(0.0, min(1.0, ratio))
549
+
550
+
551
def extract_query_from_prompt(prompt: str) -> str:
    """Pull the bare query out of a chat-formatted training prompt.

    Takes the text after the last "Expand this search query:" marker and
    trims any trailing chat-template terminator; prompts without the
    marker are returned stripped.
    """
    marker = "Expand this search query:"
    if marker not in prompt:
        return prompt.strip()
    _, _, tail = prompt.rpartition(marker)
    query = tail.strip()
    if "<|im_end|>" in query:
        query = query.split("<|im_end|>")[0].strip()
    return query
559
+
560
+
561
+ # =============================================================================
562
+ # TRL-compatible reward class
563
+ # =============================================================================
564
+
565
class QMDRewardFunction:
    """Reward function compatible with TRL's GRPOTrainer.

    Callable object mapping a batch of completions (and their prompts) to
    per-sample rewards in [0.0, 1.0] via score_expansion.
    """
    # Instances mimic a plain function; TRL reads __name__ for logging.
    __name__ = "qmd_scoring_reward"

    def __call__(self, completions: list[str], prompts: list[str] | None = None, **kwargs) -> list[float]:
        """Score each completion against the query extracted from its prompt.

        Completions without a corresponding prompt are scored against an
        empty query. Extra trainer kwargs are accepted and ignored.
        """
        rewards = []
        for i, completion in enumerate(completions):
            query = ""
            if prompts and i < len(prompts):
                query = extract_query_from_prompt(prompts[i])
            rewards.append(score_expansion(query, completion))
        return rewards
577
+
578
+
579
+ # =============================================================================
580
+ # CLI: run standalone to test the reward function
581
+ # =============================================================================
582
+
583
if __name__ == "__main__":
    print("QMD Reward Function Self-Test")
    print("=" * 60)

    # (query, candidate expansion) pairs covering: good output, prose-only
    # output, entity handling, generic filler, thinking-mode leakage,
    # invalid lines, and the "/only:" directive modes.
    tests = [
        ("auth", "lex: auth setup\nlex: authentication config\nvec: how to configure authentication\nhyde: Configure auth by setting AUTH_SECRET."),
        ("auth", "auth is important for security"),
        ("who is TDS motorsports", "lex: TDS motorsports history\nlex: TDS motorsports founders\nvec: information about TDS motorsports company"),
        ("who is TDS motorsports", "lex: find information about\nlex: company details\nvec: who is this company"),
        ("how to use React hooks", "lex: React hooks tutorial\nlex: useEffect useState\nvec: how to use React hooks in functional components"),
        ("auth", "<think>Let me think...</think>\nlex: auth"),
        ("auth", "lex: auth\nThis is some explanation\nvec: more"),
        # "/only:" mode tests (slash prefix)
        ("auth /only:lex", "lex: auth setup\nlex: authentication config\nlex: login credentials"),
        ("auth /only:lex", "lex: auth setup\nvec: how to configure authentication"),  # should fail - has vec
        ("React hooks /only:vec", "vec: how to use React hooks in functional components\nvec: useState and useEffect patterns in React"),
        ("PostgreSQL indexing /only:hyde", "hyde: PostgreSQL uses B-tree indexes by default. Create indexes with CREATE INDEX idx_name ON table(column). EXPLAIN ANALYZE shows whether queries use indexes efficiently."),
    ]

    for q, candidate in tests:
        reward = score_expansion(q, candidate)
        details = score_expansion_detailed(q, candidate)
        mode = details.get("only_mode")
        mode_str = f" [only:{mode}]" if mode else ""
        print(f"\n Query: '{q}'{mode_str}")
        print(f" Score: {reward:.2f} ({details['rating']})")
        if details["deductions"]:
            print(f" Issues: {', '.join(details['deductions'][:3])}")