asft 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- asft/__init__.py +30 -0
- asft/accuracy/__init__.py +16 -0
- asft/accuracy/confidence_scorer.py +167 -0
- asft/accuracy/multi_pass_reasoner.py +185 -0
- asft/accuracy/self_critique.py +187 -0
- asft/accuracy/verification_layer.py +277 -0
- asft/accuracy/verifier.py +333 -0
- asft/api/__init__.py +1 -0
- asft/api/middleware.py +157 -0
- asft/api/schemas.py +340 -0
- asft/api/server.py +414 -0
- asft/api/websockets.py +82 -0
- asft/benchmark/__init__.py +1 -0
- asft/benchmark/reporter.py +148 -0
- asft/benchmark/runner.py +196 -0
- asft/cli/__init__.py +1 -0
- asft/cli/main.py +462 -0
- asft/compute/__init__.py +1 -0
- asft/compute/adaptive_compute.py +241 -0
- asft/continual/__init__.py +1 -0
- asft/continual/ewc_trainer.py +261 -0
- asft/core/__init__.py +7 -0
- asft/core/config.py +199 -0
- asft/core/events.py +100 -0
- asft/core/exceptions.py +234 -0
- asft/core/hardware_profiler.py +344 -0
- asft/core/interfaces.py +346 -0
- asft/core/registry.py +162 -0
- asft/core/settings.py +250 -0
- asft/dataset/__init__.py +8 -0
- asft/dataset/clusterer.py +100 -0
- asft/dataset/compressor.py +147 -0
- asft/dataset/deduplicator.py +96 -0
- asft/dataset/quality_scorer.py +131 -0
- asft/dataset/representative_selector.py +110 -0
- asft/dataset/streaming_compressor.py +258 -0
- asft/db/database.py +25 -0
- asft/db/maintenance.py +138 -0
- asft/db/models.py +109 -0
- asft/distillation/__init__.py +1 -0
- asft/distillation/knowledge_distiller.py +424 -0
- asft/evaluation/benchmark_manager.py +117 -0
- asft/evaluation/harness.py +51 -0
- asft/hardware/__init__.py +1 -0
- asft/hardware/optimizer.py +187 -0
- asft/improvement/__init__.py +1 -0
- asft/layers/__init__.py +1 -0
- asft/memory/__init__.py +9 -0
- asft/memory/backends/faiss_adapter.py +122 -0
- asft/memory/backends/qdrant.py +111 -0
- asft/memory/backends/secure_qdrant.py +139 -0
- asft/memory/consolidator.py +55 -0
- asft/memory/episodic_memory.py +337 -0
- asft/memory/long_term_memory.py +159 -0
- asft/memory/memory_manager.py +226 -0
- asft/memory/semantic_memory.py +187 -0
- asft/memory/vector_memory.py +245 -0
- asft/memory/working_memory.py +95 -0
- asft/observability/logging.py +105 -0
- asft/observability/metrics.py +84 -0
- asft/optimizer/__init__.py +1 -0
- asft/optimizer/auto_optimizer.py +448 -0
- asft/optimizer/cost_estimator.py +397 -0
- asft/optimizer/decision_engine.py +242 -0
- asft/plugins/loader.py +68 -0
- asft/security/auth.py +186 -0
- asft/security/input_validator.py +187 -0
- asft/security/rbac.py +78 -0
- asft/security/sandbox.py +134 -0
- asft/selection/__init__.py +1 -0
- asft/selection/parameter_selector.py +340 -0
- asft/selection/sample_selector.py +336 -0
- asft/skills/__init__.py +7 -0
- asft/skills/packs/__init__.py +1 -0
- asft/skills/packs/automation.py +84 -0
- asft/skills/packs/coding.py +75 -0
- asft/skills/packs/mathematics.py +94 -0
- asft/skills/packs/planning.py +71 -0
- asft/skills/packs/research.py +61 -0
- asft/skills/packs/trading.py +78 -0
- asft/skills/skill_pack.py +95 -0
- asft/skills/skill_router.py +305 -0
- asft/sparse/__init__.py +6 -0
- asft/sparse/activation_analyzer.py +191 -0
- asft/sparse/dynamic_sparse.py +305 -0
- asft/sparse/lora_adapter.py +145 -0
- asft/sparse/neuron_selector.py +222 -0
- asft/sparse/sparse_trainer.py +160 -0
- asft/training/checkpoint_manager.py +124 -0
- asft/training/job_store.py +297 -0
- asft/training/peft_trainer.py +296 -0
- asft/workers/__init__.py +1 -0
- asft/workers/celery_app.py +30 -0
- asft/workers/process_pool.py +123 -0
- asft/workers/tasks.py +224 -0
- asft-0.1.0.dist-info/METADATA +138 -0
- asft-0.1.0.dist-info/RECORD +100 -0
- asft-0.1.0.dist-info/WHEEL +4 -0
- asft-0.1.0.dist-info/entry_points.txt +2 -0
- asft-0.1.0.dist-info/licenses/LICENSE +201 -0
asft/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ASFT — Adaptive Sparse Fine-Tuning Framework
|
|
3
|
+
=============================================
|
|
4
|
+
|
|
5
|
+
A next-generation, hardware-adaptive AI learning framework that dramatically
|
|
6
|
+
reduces training cost and time while improving accuracy, reliability, and
|
|
7
|
+
reasoning quality.
|
|
8
|
+
|
|
9
|
+
Learning Priority Hierarchy:
|
|
10
|
+
Memory → Workflow Optimization → Tool Learning →
|
|
11
|
+
Skill Packs → Sparse Fine-Tuning → Full Fine-Tuning
|
|
12
|
+
|
|
13
|
+
Full retraining is always the last resort.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
__version__ = "0.1.0"
|
|
17
|
+
__author__ = "ASFT Contributors"
|
|
18
|
+
__license__ = "MIT"
|
|
19
|
+
|
|
20
|
+
from asft.core.config import ASFTConfig
|
|
21
|
+
from asft.core.hardware_profiler import HardwareProfile, detect_hardware
|
|
22
|
+
from asft.core.registry import Registry
|
|
23
|
+
|
|
24
|
+
__all__ = [
|
|
25
|
+
"__version__",
|
|
26
|
+
"ASFTConfig",
|
|
27
|
+
"detect_hardware",
|
|
28
|
+
"HardwareProfile",
|
|
29
|
+
"Registry",
|
|
30
|
+
]
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""ASFT Accuracy Package."""
|
|
2
|
+
|
|
3
|
+
from asft.accuracy.confidence_scorer import ConfidenceScore, ConfidenceScorer
|
|
4
|
+
from asft.accuracy.multi_pass_reasoner import MultiPassReasoner, ReasoningResult
|
|
5
|
+
from asft.accuracy.self_critique import CritiqueResult, SelfCritiqueEngine
|
|
6
|
+
from asft.accuracy.verification_layer import VerificationLayer
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"ConfidenceScorer",
|
|
10
|
+
"ConfidenceScore",
|
|
11
|
+
"MultiPassReasoner",
|
|
12
|
+
"ReasoningResult",
|
|
13
|
+
"SelfCritiqueEngine",
|
|
14
|
+
"CritiqueResult",
|
|
15
|
+
"VerificationLayer",
|
|
16
|
+
]
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Confidence Scorer — Assigns confidence, reliability, and verification scores
|
|
3
|
+
to every model output. Low scores trigger additional reasoning passes.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
import re
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class ConfidenceScore:
    """Quality scores attached to one model output.

    The three component scores and their weighted ``composite`` are plain
    floats; ``flags`` lists the quality issues that were detected while
    scoring. ``composite`` alone drives :attr:`needs_extra_pass` and
    :attr:`label`.
    """

    confidence: float  # 0–1: overall output confidence
    reliability: float  # 0–1: factual reliability estimate
    verification: float  # 0–1: how well the output can be verified
    composite: float  # weighted combination
    flags: list[str]  # detected quality issues

    @property
    def needs_extra_pass(self) -> bool:
        """Whether the composite score is below the fixed 0.7 cut-off."""
        return self.composite < 0.7

    @property
    def label(self) -> str:
        """Bucket the composite score into ``HIGH`` / ``MEDIUM`` / ``LOW``."""
        overall = self.composite
        if overall >= 0.85:
            return "HIGH"
        if overall >= 0.65:
            return "MEDIUM"
        return "LOW"

    def __str__(self) -> str:
        """Render a one-line summary; flags are appended only when present."""
        rendered = (
            f"ConfidenceScore [{self.label}] "
            f"conf={self.confidence:.2f} rel={self.reliability:.2f} "
            f"ver={self.verification:.2f} composite={self.composite:.2f}"
        )
        if self.flags:
            rendered += f" flags={self.flags}"
        return rendered
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Patterns that indicate uncertainty or low confidence.
# ConfidenceScorer compiles these case-insensitively and counts how many of
# the patterns match at least once (not how many times each one matches).
_UNCERTAINTY_PATTERNS = [
    r"\bi (think|believe|guess|assume)\b",
    r"\b(probably|possibly|maybe|perhaps|might|could be)\b",
    r"\b(i('m| am) not sure|i don't know|uncertain|unclear)\b",
    r"\b(approximately|roughly|around|about)\b",
]

# Patterns that indicate hallucination risk: vague appeals to studies,
# experts, or common knowledge. Also compiled case-insensitively.
_HALLUCINATION_PATTERNS = [
    r"\bspecific(ally)? (in|on|at) \d{4}\b",  # specific years without context
    r"\baccording to (a|the) (recent|new) study\b",
    r"\bexperts (say|agree|claim)\b",
    r"\bit('s| is) (widely|commonly|generally) known\b",
]

# Patterns that indicate verifiable content. ConfidenceScorer compiles these
# without re.IGNORECASE.
# NOTE(review): any text matching the "years" pattern also matches the
# "numbers" pattern, so a four-digit number is counted by two patterns —
# confirm that this double weighting is intended.
_VERIFIABLE_PATTERNS = [
    r"\d+(\.\d+)?",  # numbers
    r"```[\s\S]+?```",  # code blocks
    r"https?://\S+",  # URLs
    r"\b\d{4}\b",  # years
]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class ConfidenceScorer:
    """
    Multi-dimensional output quality scorer.

    Analyzes:
    - Linguistic uncertainty markers
    - Hallucination risk patterns
    - Verifiability signals
    - Output length and structure
    - Model log-probabilities, when supplied

    ``task_type`` is accepted by the public methods for interface
    compatibility; the heuristics in this class do not vary by task.
    """

    def __init__(self, threshold: float = 0.7):
        """Compile the module-level pattern lists once per scorer.

        Args:
            threshold: stored on the instance; none of the scoring methods
                in this class read it.
        """
        self._threshold = threshold
        self._uncertainty_re = [re.compile(p, re.IGNORECASE) for p in _UNCERTAINTY_PATTERNS]
        self._hallucination_re = [re.compile(p, re.IGNORECASE) for p in _HALLUCINATION_PATTERNS]
        # Verifiable patterns are digits/URLs/code fences, so case is irrelevant.
        self._verifiable_re = [re.compile(p) for p in _VERIFIABLE_PATTERNS]

    def score(
        self, output: str, task_type: str = "general", model_logprobs: list[float] | None = None
    ) -> ConfidenceScore:
        """Score a model output.

        Args:
            output: text produced by the model.
            task_type: domain hint; currently unused by the heuristics.
            model_logprobs: optional per-token log-probabilities; their mean
                contributes a bonus of at most 0.1.

        Returns:
            A ConfidenceScore whose components are rounded to 3 decimals and
            whose composite is clamped to [0, 1]. Empty or near-empty output
            (fewer than 5 non-whitespace-trimmed characters) yields 0.1 for
            every component and the ``empty_output`` flag.
        """
        flags: list[str] = []

        if not output or len(output.strip()) < 5:
            return ConfidenceScore(0.1, 0.1, 0.1, 0.1, ["empty_output"])

        # 1. Confidence: based on uncertainty language
        confidence = self._score_confidence(output, flags)

        # 2. Reliability: based on hallucination risk patterns
        reliability = self._score_reliability(output, flags)

        # 3. Verification: how verifiable is the output
        verification = self._score_verification(output)

        # 4. Length/structure bonus (length bonus saturates at 200 characters)
        length_bonus = min(0.1, len(output) / 2000)
        has_structure = any(marker in output for marker in ("\n", "##", "- ", "1."))
        structure_bonus = 0.05 if has_structure else 0.0

        # 5. Log-prob bonus if available: a mean of -5 or lower earns nothing,
        #    a mean of 0 earns the full 0.1.
        logprob_bonus = 0.0
        if model_logprobs:
            avg_logprob = sum(model_logprobs) / len(model_logprobs)
            logprob_bonus = min(0.1, max(0.0, (avg_logprob + 5) / 50))

        # Composite (weighted); the bonuses are additive, hence the clamp.
        composite = (
            0.40 * confidence
            + 0.30 * reliability
            + 0.20 * verification
            + length_bonus
            + structure_bonus
            + logprob_bonus
        )
        composite = min(1.0, max(0.0, composite))

        return ConfidenceScore(
            confidence=round(confidence, 3),
            reliability=round(reliability, 3),
            verification=round(verification, 3),
            composite=round(composite, 3),
            flags=flags,
        )

    @staticmethod
    def _count_hits(patterns, text: str) -> int:
        """Count how many of *patterns* match *text* at least once."""
        return sum(1 for pattern in patterns if pattern.search(text))

    def _score_confidence(self, text: str, flags: list[str]) -> float:
        """Map uncertainty-pattern hits to a confidence value.

        Appends ``high_uncertainty`` (3+ patterns hit) or ``some_uncertainty``
        (1–2 patterns hit) to *flags* as a side effect.
        """
        uncertainty_hits = self._count_hits(self._uncertainty_re, text)
        if uncertainty_hits >= 3:
            flags.append("high_uncertainty")
            return 0.4
        if uncertainty_hits >= 1:
            flags.append("some_uncertainty")
            return 0.7
        return 0.9

    def _score_reliability(self, text: str, flags: list[str]) -> float:
        """Map hallucination-pattern hits to a reliability value.

        Appends ``hallucination_risk`` (2+ patterns hit) or
        ``minor_hallucination_risk`` (exactly 1) to *flags* as a side effect.
        """
        hallucination_hits = self._count_hits(self._hallucination_re, text)
        if hallucination_hits >= 2:
            flags.append("hallucination_risk")
            return 0.35
        if hallucination_hits >= 1:
            flags.append("minor_hallucination_risk")
            return 0.65
        return 0.85

    def _score_verification(self, text: str) -> float:
        """Return 0.3 plus 0.15 per verifiable pattern that matches, capped at 1.0."""
        verifiable_hits = self._count_hits(self._verifiable_re, text)
        return min(1.0, 0.3 + verifiable_hits * 0.15)

    def batch_score(self, outputs: list[str], task_type: str = "general") -> list[ConfidenceScore]:
        """Score every output independently (no log-probabilities are used)."""
        return [self.score(o, task_type) for o in outputs]

    def best_output(
        self, outputs: list[str], task_type: str = "general"
    ) -> tuple[str, ConfidenceScore]:
        """Select the highest-confidence output from a list.

        Ties are resolved in favour of the earliest candidate.

        Returns:
            ``(output, score)`` for the candidate with the highest composite.

        Raises:
            ValueError: if *outputs* is empty. Previously this surfaced as an
                opaque ``max() arg is an empty sequence`` error.
        """
        if not outputs:
            raise ValueError("best_output() requires at least one candidate output")
        scores = self.batch_score(outputs, task_type)
        # max() keeps the first maximal element, matching index-based argmax.
        return max(zip(outputs, scores), key=lambda pair: pair[1].composite)
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Multi-Pass Reasoner — Generates K candidate solutions, scores each,
|
|
3
|
+
and returns the highest-confidence answer. Implements self-consistency reasoning.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
from collections.abc import Callable
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
|
|
12
|
+
from asft.accuracy.confidence_scorer import ConfidenceScore, ConfidenceScorer
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class ReasoningResult:
    """Outcome of one multi-pass reasoning run.

    Holds the selected answer and its score alongside every candidate that
    was generated, so callers can inspect the alternatives.
    """

    best_output: str  # the selected candidate text
    best_score: ConfidenceScore  # score of the selected candidate
    all_outputs: list[str] = field(default_factory=list)  # every candidate generated
    all_scores: list[ConfidenceScore] = field(default_factory=list)  # scores for the candidates
    consensus_output: str | None = None  # set only by the self-consistency strategy
    passes_used: int = 1  # number of candidates that were generated
    task_type: str = "general"  # domain hint the run was scored with

    def summary(self) -> str:
        """Return a one-line, pipe-separated description of the run."""
        text = f"MultiPassReasoning: {self.passes_used} passes | "
        text += f"best={self.best_score.composite:.3f} | "
        text += f"task={self.task_type}"
        return text
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class MultiPassReasoner:
    """
    Generates multiple candidate solutions and selects the best one.

    Strategies:
    - best_of_k: Generate K outputs, pick highest-confidence
    - self_consistency: Generate K outputs, pick highest-confidence and also
      report the candidate with the highest agreement across candidates
    - escalating: Start with 1 pass, escalate if confidence is low

    Any other strategy name falls back to ``best_of_k``.
    """

    def __init__(
        self,
        k: int = 3,
        min_confidence: float = 0.7,
        strategy: str = "best_of_k",
    ):
        """
        Args:
            k: number of candidates to request from the generator.
            min_confidence: composite score at which the escalating strategy
                accepts the first pass; also forwarded to the scorer.
            strategy: one of ``best_of_k``, ``self_consistency``, ``escalating``.
        """
        self._k = k
        self._min_confidence = min_confidence
        self._strategy = strategy
        self._scorer = ConfidenceScorer(threshold=min_confidence)

    def reason(
        self,
        generate_fn: Callable[[int], list[str]],
        task_type: str = "general",
    ) -> ReasoningResult:
        """
        Run multi-pass reasoning.

        Args:
            generate_fn: callable(n_samples) → List[str] — model generation function
            task_type: domain hint for scoring

        Returns:
            ReasoningResult with best output and all candidates
        """
        if self._strategy == "escalating":
            return self._escalating_pass(generate_fn, task_type)
        if self._strategy == "self_consistency":
            return self._self_consistency(generate_fn, task_type)
        return self._best_of_k(generate_fn, task_type)

    @staticmethod
    def _no_output_result(task_type: str) -> ReasoningResult:
        """Build the all-zero result returned when the generator yields nothing."""
        return ReasoningResult(
            best_output="",
            best_score=ConfidenceScore(0, 0, 0, 0, ["no_output"]),
            task_type=task_type,
        )

    @staticmethod
    def _best_index(scores) -> int:
        """Return the index of the highest composite score (first one on ties)."""
        return max(range(len(scores)), key=lambda i: scores[i].composite)

    def _best_of_k(self, generate_fn, task_type: str) -> ReasoningResult:
        """Generate K outputs, score all, return highest-confidence."""
        outputs = generate_fn(self._k)
        if not outputs:
            return self._no_output_result(task_type)

        scores = self._scorer.batch_score(outputs, task_type)
        best_idx = self._best_index(scores)

        logger.debug(
            "BestOfK[%d]: scores=%s best=%.3f",
            self._k,
            [round(s.composite, 3) for s in scores],
            scores[best_idx].composite,
        )

        return ReasoningResult(
            best_output=outputs[best_idx],
            best_score=scores[best_idx],
            all_outputs=outputs,
            all_scores=scores,
            passes_used=len(outputs),
            task_type=task_type,
        )

    def _self_consistency(self, generate_fn, task_type: str) -> ReasoningResult:
        """Generate K outputs and report both the best-scored and the consensus one.

        ``best_output`` is the highest-composite candidate; the candidate with
        the most token overlap with the others is returned separately in
        ``consensus_output``.
        """
        outputs = generate_fn(self._k)
        if not outputs:
            return self._no_output_result(task_type)

        scores = self._scorer.batch_score(outputs, task_type)

        # Find consensus: output most similar (by token overlap) to others
        consensus = outputs[self._find_consensus(outputs)]
        best_idx = self._best_index(scores)

        return ReasoningResult(
            best_output=outputs[best_idx],
            best_score=scores[best_idx],
            all_outputs=outputs,
            all_scores=scores,
            consensus_output=consensus,
            passes_used=len(outputs),
            task_type=task_type,
        )

    def _escalating_pass(self, generate_fn, task_type: str) -> ReasoningResult:
        """
        Start with 1 pass. If confidence is below threshold, escalate to K.
        Minimizes compute for easy tasks.

        When ``k <= 1`` there is nothing to escalate to, so the first pass is
        returned as-is instead of asking the generator for zero (or a negative
        number of) samples.
        """
        # First pass
        outputs = generate_fn(1)
        if not outputs:
            return self._no_output_result(task_type)

        score = self._scorer.score(outputs[0], task_type)
        extra_needed = self._k - 1
        confident = score.composite >= self._min_confidence
        if confident or extra_needed <= 0:
            if confident:
                logger.debug("Escalating: single pass sufficient (score=%.3f)", score.composite)
            else:
                logger.debug(
                    "Escalating: low confidence (%.3f) but k=%d leaves nothing to escalate to",
                    score.composite,
                    self._k,
                )
            return ReasoningResult(
                best_output=outputs[0],
                best_score=score,
                all_outputs=outputs,
                all_scores=[score],
                passes_used=1,
                task_type=task_type,
            )

        # Escalate: generate K-1 more
        logger.debug(
            "Escalating: low confidence (%.3f) — generating %d more", score.composite, extra_needed
        )
        # Normalise to lists so a generator that returns None or a tuple
        # cannot break the concatenation.
        extra = generate_fn(extra_needed)
        all_outputs = list(outputs) + list(extra or [])
        all_scores = self._scorer.batch_score(all_outputs, task_type)
        best_idx = self._best_index(all_scores)

        return ReasoningResult(
            best_output=all_outputs[best_idx],
            best_score=all_scores[best_idx],
            all_outputs=all_outputs,
            all_scores=all_scores,
            passes_used=len(all_outputs),
            task_type=task_type,
        )

    def _find_consensus(self, outputs: list[str]) -> int:
        """Find the output with highest token overlap to all others.

        Overlap is the Jaccard similarity of lower-cased, whitespace-split
        token sets, summed over every other candidate. Ties go to the earliest
        candidate.
        """
        # Tokenise each candidate once instead of once per comparison.
        token_sets = [set(text.lower().split()) for text in outputs]

        def jaccard(left: set, right: set) -> float:
            return len(left & right) / max(1, len(left | right))

        totals = [
            sum(jaccard(tokens, other) for j, other in enumerate(token_sets) if j != i)
            for i, tokens in enumerate(token_sets)
        ]
        return int(max(range(len(totals)), key=totals.__getitem__))
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Self-Critique Engine — Reviews model outputs before finalizing.
|
|
3
|
+
Detects logical errors, contradictions, hallucinations, and weak reasoning.
|
|
4
|
+
Automatically revises answers when issues are detected.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import logging
|
|
10
|
+
import re
|
|
11
|
+
from collections.abc import Callable
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class CritiqueResult:
    """Record of one self-critique run over a model output."""

    original_output: str  # text as first supplied
    revised_output: str  # text after the final accepted revision
    issues_found: list[str]  # issue labels reported by the detectors
    was_revised: bool  # True when at least one revision was accepted
    critique_rounds: int  # number of revisions that were accepted
    quality_improvement: float  # estimated improvement 0–1

    @property
    def is_clean(self) -> bool:
        """True when the issue list is empty."""
        return not len(self.issues_found)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class SelfCritiqueEngine:
    """
    Reviews and revises model outputs to detect and fix:
    - Logical errors and contradictions
    - Hallucination indicators
    - Missing information
    - Weak or unsupported reasoning
    - Inconsistent statements

    Uses a generate_fn (model inference callable) to produce revisions.
    All detectors are lightweight text heuristics; they report labels and
    never modify the text themselves.
    """

    # (compiled pattern, issue label) pairs, compiled once at class creation
    # instead of on every detector call.
    _HALLUCINATION_MARKERS = tuple(
        (re.compile(pattern, re.IGNORECASE), label)
        for pattern, label in (
            (r"\brecently published\b.*\bstudy\b", "unsourced_study_claim"),
            (r"\bexperts agree\b", "vague_expert_claim"),
            (r"\b(all|every|always|never)\b.*\b(are|is|do|does)\b", "absolute_claim"),
            (r"\bscientifically proven\b", "unverified_scientific_claim"),
        )
    )

    # Whole-word matching: a plain substring test would see "so" inside
    # "also"/"some" and "since" inside "sincere".
    _CONCLUSION_RE = re.compile(r"\b(?:therefore|thus|hence|so|conclusion)\b", re.IGNORECASE)
    _REASONING_RE = re.compile(
        r"\b(?:because|since|given\s+that|due\s+to|as\s+a\s+result)\b", re.IGNORECASE
    )

    def __init__(self, max_rounds: int = 2):
        """
        Args:
            max_rounds: upper bound on detect-and-revise iterations.
        """
        self._max_rounds = max_rounds
        self._issue_detectors = [
            self._detect_contradictions,
            self._detect_hallucination_markers,
            self._detect_logical_gaps,
            self._detect_incomplete_response,
        ]

    def critique(
        self,
        output: str,
        original_task: str,
        generate_fn: Callable[[str], str] | None = None,
    ) -> CritiqueResult:
        """
        Critique an output and optionally revise it.

        Args:
            output: the original model output to critique
            original_task: the original task/question
            generate_fn: optional callable(prompt) → str for generating revisions

        Returns:
            CritiqueResult with original, revised, and issue list. Issues are
            de-duplicated in first-seen order; ``critique_rounds`` counts the
            revisions that were accepted.
        """
        current = output
        all_issues: list[str] = []
        rounds = 0
        was_revised = False

        for round_num in range(self._max_rounds):
            issues = self._detect_issues(current)
            if not issues:
                logger.debug("SelfCritique round %d: no issues found", round_num)
                break

            all_issues.extend(issues)
            logger.info(
                "SelfCritique round %d: %d issues found: %s", round_num, len(issues), issues
            )

            if generate_fn is None:
                # No model available — mark issues but cannot revise
                break

            # Generate a revised version
            revision_prompt = self._build_revision_prompt(
                original_task=original_task,
                current_output=current,
                issues=issues,
            )
            revised = generate_fn(revision_prompt)
            # Stop as soon as the model returns nothing new; looping again
            # would only repeat the same prompt.
            if not (revised and revised.strip() and revised != current):
                break
            current = revised
            was_revised = True
            rounds += 1

        quality_improvement = min(0.3, 0.1 * rounds) if was_revised else 0.0

        return CritiqueResult(
            original_output=output,
            revised_output=current,
            # dict.fromkeys de-duplicates while keeping a stable order, unlike
            # set(), whose string ordering changes between interpreter runs.
            issues_found=list(dict.fromkeys(all_issues)),
            was_revised=was_revised,
            critique_rounds=rounds,
            quality_improvement=quality_improvement,
        )

    def _detect_issues(self, text: str) -> list[str]:
        """Run every registered detector and concatenate their findings."""
        issues: list[str] = []
        for detector in self._issue_detectors:
            issues.extend(detector(text))
        return issues

    def _detect_contradictions(self, text: str) -> list[str]:
        """Detect obvious contradictions using negation patterns.

        Two sentences within four positions of each other are flagged when
        they share more than three words and exactly one contains "not".
        """
        sentences = [s.strip() for s in re.split(r"[.!?]", text) if len(s.strip()) > 10]
        word_sets = [set(sentence.lower().split()) for sentence in sentences]
        issues = []
        for i, a_words in enumerate(word_sets):
            # Only compare against the next four sentences.
            for b_words in word_sets[i + 1 : i + 5]:
                shared = a_words & b_words
                if len(shared) > 3 and ("not" in b_words) != ("not" in a_words):
                    issues.append(f"potential_contradiction: '{sentences[i][:40]}...'")
                    break
        return issues[:2]  # Cap at 2 to avoid noise

    def _detect_hallucination_markers(self, text: str) -> list[str]:
        """Detect patterns commonly associated with hallucinations."""
        return [label for pattern, label in self._HALLUCINATION_MARKERS if pattern.search(text)]

    def _detect_logical_gaps(self, text: str) -> list[str]:
        """Detect reasoning gaps — conclusions without supporting reasoning.

        Flags texts longer than 100 characters that contain a conclusion word
        but no reasoning connective. Both are matched as whole words.
        """
        if len(text) <= 100:
            return []
        has_conclusion = self._CONCLUSION_RE.search(text) is not None
        has_reasoning = self._REASONING_RE.search(text) is not None
        if has_conclusion and not has_reasoning:
            return ["conclusion_without_reasoning"]
        return []

    def _detect_incomplete_response(self, text: str) -> list[str]:
        """Detect suspiciously short or abruptly cut-off responses.

        Both checks use the whitespace-stripped text, so trailing padding
        cannot make a short answer look truncated.
        """
        stripped = text.strip()
        issues = []
        if len(stripped) < 20:
            issues.append("response_too_short")
        # Not just a very short answer, and no closing punctuation/quote.
        if len(stripped) > 50 and stripped[-1] not in ".!?\"'`":
            issues.append("response_appears_truncated")
        return issues

    def _build_revision_prompt(
        self, original_task: str, current_output: str, issues: list[str]
    ) -> str:
        """Assemble the prompt that asks the model to fix the listed issues."""
        issues_str = "\n".join(f"  - {issue}" for issue in issues)
        return (
            f"Review and improve this response. The following issues were detected:\n"
            f"{issues_str}\n\n"
            f"Original task: {original_task}\n\n"
            f"Current response:\n{current_output}\n\n"
            f"Provide an improved response that fixes the identified issues. "
            f"Be accurate, clear, and complete."
        )
|