adaptive_harmony-0.1.23-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adaptive_harmony/__init__.py +162 -0
- adaptive_harmony/common/__init__.py +40 -0
- adaptive_harmony/common/callbacks.py +219 -0
- adaptive_harmony/common/checkpointing.py +163 -0
- adaptive_harmony/common/dpo.py +92 -0
- adaptive_harmony/common/env_grpo.py +361 -0
- adaptive_harmony/common/grpo.py +260 -0
- adaptive_harmony/common/gspo.py +70 -0
- adaptive_harmony/common/ppo.py +303 -0
- adaptive_harmony/common/rm.py +79 -0
- adaptive_harmony/common/sft.py +121 -0
- adaptive_harmony/core/__init__.py +0 -0
- adaptive_harmony/core/dataset.py +72 -0
- adaptive_harmony/core/display.py +93 -0
- adaptive_harmony/core/image_utils.py +110 -0
- adaptive_harmony/core/reasoning.py +12 -0
- adaptive_harmony/core/reward_client/__init__.py +19 -0
- adaptive_harmony/core/reward_client/client.py +160 -0
- adaptive_harmony/core/reward_client/reward_types.py +49 -0
- adaptive_harmony/core/reward_client/websocket_utils.py +18 -0
- adaptive_harmony/core/rich_counter.py +351 -0
- adaptive_harmony/core/rl_utils.py +38 -0
- adaptive_harmony/core/schedulers.py +38 -0
- adaptive_harmony/core/structured_output.py +385 -0
- adaptive_harmony/core/utils.py +365 -0
- adaptive_harmony/environment/__init__.py +8 -0
- adaptive_harmony/environment/environment.py +121 -0
- adaptive_harmony/evaluation/__init__.py +1 -0
- adaptive_harmony/evaluation/evaluation_artifact.py +67 -0
- adaptive_harmony/graders/__init__.py +20 -0
- adaptive_harmony/graders/answer_relevancy_judge/__init__.py +3 -0
- adaptive_harmony/graders/answer_relevancy_judge/answer_relevancy_judge.py +102 -0
- adaptive_harmony/graders/answer_relevancy_judge/prompts.py +58 -0
- adaptive_harmony/graders/base_grader.py +265 -0
- adaptive_harmony/graders/binary_judge/__init__.py +8 -0
- adaptive_harmony/graders/binary_judge/binary_judge.py +202 -0
- adaptive_harmony/graders/binary_judge/prompts.py +125 -0
- adaptive_harmony/graders/combined_grader.py +118 -0
- adaptive_harmony/graders/context_relevancy_judge/__init__.py +3 -0
- adaptive_harmony/graders/context_relevancy_judge/context_relevancy_judge.py +128 -0
- adaptive_harmony/graders/context_relevancy_judge/prompts.py +84 -0
- adaptive_harmony/graders/exceptions.py +9 -0
- adaptive_harmony/graders/faithfulness_judge/__init__.py +3 -0
- adaptive_harmony/graders/faithfulness_judge/faithfulness_judge.py +159 -0
- adaptive_harmony/graders/faithfulness_judge/prompts.py +22 -0
- adaptive_harmony/graders/range_judge/__init__.py +7 -0
- adaptive_harmony/graders/range_judge/prompts.py +232 -0
- adaptive_harmony/graders/range_judge/range_judge.py +188 -0
- adaptive_harmony/graders/range_judge/types.py +12 -0
- adaptive_harmony/graders/reward_server_grader.py +36 -0
- adaptive_harmony/graders/templated_prompt_judge.py +237 -0
- adaptive_harmony/graders/utils.py +79 -0
- adaptive_harmony/logging_table.py +1 -0
- adaptive_harmony/metric_logger.py +452 -0
- adaptive_harmony/parameters/__init__.py +2 -0
- adaptive_harmony/py.typed +0 -0
- adaptive_harmony/runtime/__init__.py +2 -0
- adaptive_harmony/runtime/context.py +2 -0
- adaptive_harmony/runtime/data.py +2 -0
- adaptive_harmony/runtime/decorators.py +2 -0
- adaptive_harmony/runtime/model_artifact_save.py +2 -0
- adaptive_harmony/runtime/runner.py +27 -0
- adaptive_harmony/runtime/simple_notifier.py +2 -0
- adaptive_harmony-0.1.23.dist-info/METADATA +37 -0
- adaptive_harmony-0.1.23.dist-info/RECORD +67 -0
- adaptive_harmony-0.1.23.dist-info/WHEEL +5 -0
- adaptive_harmony-0.1.23.dist-info/top_level.txt +1 -0
adaptive_harmony/graders/combined_grader.py
@@ -0,0 +1,118 @@
from asyncio import gather
from typing import Any, Literal, Sequence

from loguru import logger

from adaptive_harmony import StringThread
from adaptive_harmony.core.structured_output import JsonParseError
from adaptive_harmony.graders import BaseGrader, Grade
from adaptive_harmony.graders.exceptions import IgnoreScoreException


class CombinedGrader(BaseGrader):
    """
    Combines grades from multiple graders.
    Aggregates their results using weighted sum or average.
    Ignores failing graders and proceeds to calculate the aggregate grade with the rest.
    """

    def __init__(
        self,
        grader_key: str,
        graders: Sequence[BaseGrader],
        weights: list[float] | None = None,
        aggregation_method: Literal["sum", "mean"] = "mean",
        failure_rate_warning_threshold: float = 0.2,
    ):
        super().__init__(grader_key)
        self.graders = graders
        if weights:
            assert len(weights) == len(graders), "Number of weights must match number of graders"
        self.weights = weights or [1.0] * len(graders)
        self.agg_method = aggregation_method
        self.failure_rate_warning_threshold = failure_rate_warning_threshold

    async def grade(self, sample: StringThread) -> Grade:
        async def separate_success_from_fail_graders(grader: BaseGrader) -> Grade | None:
            try:
                return await grader.grade(sample)
            except (IgnoreScoreException, JsonParseError):
                # return None if the score should be ignored, or on judge output format failure
                return None
            except Exception as e:
                # fail for any other exception
                raise e

        tasks = [separate_success_from_fail_graders(grader) for grader in self.graders]
        results: list[Grade | None] = await gather(*tasks)

        weighted_scores = []
        failed_graders = []

        # Separate successful and failed results
        successful_graders: list[str] = []
        successful_grades: list[Grade] = []
        successful_weights = []

        for result, weight, grader in zip(results, self.weights, self.graders):
            if result is not None:
                # Successful grader
                weighted_score = result.value * weight
                successful_graders.append(grader.grader_key)
                weighted_scores.append(weighted_score)
                successful_grades.append(result)
                successful_weights.append(weight)
            else:
                # Failed grader
                failed_graders.append(grader.grader_key)

        # Fail if no successful graders
        if not successful_grades:
            raise RuntimeError("All graders failed - cannot compute aggregate grade")

        # Warn if more than a set % of scorers failed
        total_graders = len(self.graders)
        failure_rate = len(failed_graders) / total_graders
        if failure_rate > self.failure_rate_warning_threshold:
            logger.warning(f"{len(failed_graders)}/{total_graders} graders failed for sample: {failed_graders}")

        # Aggregate scores
        if self.agg_method == "sum":
            final_score = sum(weighted_scores)
        elif self.agg_method == "mean":
            # For average, we normalize by the sum of successful weights (renormalize)
            final_score = sum(weighted_scores) / sum(successful_weights)
        else:
            raise ValueError(f"Unknown aggregation method: {self.agg_method}")

        # Log the combined score.
        self.add_log({"score": final_score})

        reason = "\n".join(
            [
                f"{grader_key}: {grade.value} - {grade.reasoning}"
                for grader_key, grade in zip(successful_graders, successful_grades)
            ]
        )

        return Grade(
            value=final_score,
            grader_key=self.grader_key,
            reasoning=reason,
        )

    def get_logs(self, clear: bool = False, log_all_samples: bool = False) -> dict[str, Any]:
        combined_logs = super().get_logs(clear=False, log_all_samples=log_all_samples)
        all_logs = combined_logs
        for grader in self.graders:
            scorer_logs = grader.get_logs(clear=clear, log_all_samples=log_all_samples)
            all_logs = all_logs | {f"{grader.grader_key}/{key}": value for key, value in scorer_logs.items()}

        if clear:
            self.clear_logs()
        return all_logs

    def clear_logs(self) -> None:
        super().clear_logs()
        for grader in self.graders:
            grader.clear_logs()
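For orientation, a minimal usage sketch of CombinedGrader (illustrative only, not one of the packaged files). It assumes an already-constructed InferenceModel handle named judge_model and a prepared StringThread sample; both names are assumptions for illustration.

# Illustrative sketch only, not part of the wheel. Assumes `judge_model` is an
# InferenceModel and `sample` is a StringThread prepared elsewhere.
from adaptive_harmony.graders.combined_grader import CombinedGrader
from adaptive_harmony.graders.context_relevancy_judge.context_relevancy_judge import ContextRelevancyGrader
from adaptive_harmony.graders.faithfulness_judge.faithfulness_judge import FaithfulnessGrader


async def grade_rag_sample(judge_model, sample):
    grader = CombinedGrader(
        grader_key="rag_quality",
        graders=[
            ContextRelevancyGrader(judge_model),
            FaithfulnessGrader(judge_model),
        ],
        weights=[0.5, 0.5],          # must match the number of graders
        aggregation_method="mean",   # renormalized over the graders that succeed
    )
    grade = await grader.grade(sample)
    # grade.reasoning concatenates the per-grader reasoning lines
    return grade.value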
adaptive_harmony/graders/context_relevancy_judge/context_relevancy_judge.py
@@ -0,0 +1,128 @@
from __future__ import annotations

import asyncio
from typing import Literal

import numpy as np
import pysbd
from harmony_client import Grade, InferenceModel, StringThread
from pydantic import BaseModel, Field

from adaptive_harmony.core.utils import stringify_thread
from adaptive_harmony.graders.base_grader import BaseGrader
from adaptive_harmony.graders.context_relevancy_judge.prompts import DEFAULT_SHOTS, SYSTEM, USER
from adaptive_harmony.graders.faithfulness_judge.faithfulness_judge import SupportedLanguages
from adaptive_harmony.graders.utils import sample_score_distribution
from adaptive_harmony.logging_table import Table


class DocumentRelevancyResult(BaseModel):
    reason: str = Field(description="The justification for the score given to a document. Keep it short and concise.")
    score: Literal[0, 1] = Field(
        description="The score for the document. A score of 1 if the document contains information relevant to answering the user input, and 0 if the document does not contain information relevant to answering the user input"
    )


class ContextRelevancyGrader(BaseGrader):
    def __init__(
        self,
        model: InferenceModel,
        language: SupportedLanguages = "en",
        grader_key: str = "context_relevancy_judge",
        grader_id: str | None = None,
    ):
        super().__init__(grader_key)
        self.model = model
        self.language = language
        self.grader_id_or_key = grader_id or grader_key
        self.sentence_splitter = pysbd.Segmenter(language=language)
        self.shots = DEFAULT_SHOTS

    async def grade(self, sample: StringThread) -> Grade:
        documents = sample.metadata.get("documents", []) if sample.metadata else []
        if not documents:
            self.add_log(
                {
                    "prompt": stringify_thread(sample, sep=f"\n\n{'-' * 10}\n\n"),
                    "error": "No document turns found in thread",
                }
            )
            raise ValueError("No document turns found in thread")

        user_question = next((turn[1] for turn in reversed(sample.get_turns()) if turn[0] == "user"), None)
        if not user_question:
            self.add_log(
                {"prompt": stringify_thread(sample, sep=f"\n\n{'-' * 10}\n\n"), "error": "No user turn found in thread"}
            )
            raise ValueError("No user turn found in thread")

        judging_threads = [
            (
                StringThread()
                .system(SYSTEM.format(json_schema=self.model.render_schema(DocumentRelevancyResult), shots=self.shots))
                .user(USER.format(user_question=user_question, document=document))
            )
            for document in documents
        ]

        try:
            judge_tasks = [
                self.model.temperature(0.0).generate_and_validate(thread, DocumentRelevancyResult)
                for thread in judging_threads
            ]
            results = await asyncio.gather(*judge_tasks)
        except Exception as e:
            self.add_log(
                {
                    "error": str(e),
                    "number_of_documents": len(documents),
                    "documents": documents,
                    "prompt": stringify_thread(judging_threads[0]),
                }
            )
            raise

        doc_relevancy_results = [result[1] for result in results]

        reason = ""
        for i, (document, doc_result) in enumerate(zip(documents, doc_relevancy_results)):
            emoji = "✅" if doc_result.score == 1 else "❌"
            result = "PASS" if doc_result.score == 1 else "FAIL"
            doc_display = document[:150] + ("..." if len(document) > 150 else "")
            reason += f"{emoji} Document {i}: {result}\n Content: {doc_display}:\nReason: {doc_result.reason}\n\n"

        score = np.mean([float(verdict.score) for verdict in doc_relevancy_results]) if doc_relevancy_results else 0.0
        self.add_log(
            {
                "score": score,
                "reasoning": reason,
                "number_of_documents": len(documents),
                "documents": documents,
                "prompt": stringify_thread(judging_threads[0]),
            }
        )
        return Grade(value=float(score), grader_key=self.grader_id_or_key, reasoning=reason)

    def get_logs(self, clear: bool = False, log_all_samples: bool = False) -> dict[str, float | Table]:
        # Only clear logs at the end if clear is True
        logs = super().get_logs(clear=False)

        successfully_scored_samples = [log for log in self._logs if "score" in log]

        # stratified sample range of scores to see high and low
        if not log_all_samples:
            subset_successfully_scored_samples = sample_score_distribution(successfully_scored_samples, 15)
        else:
            # if we have fewer than 15 samples or we want to log all samples, take them all
            subset_successfully_scored_samples = successfully_scored_samples

        failed_scored_samples = [log for log in self._logs if "error" in log]

        sample_logs = self.get_sample_tables(subset_successfully_scored_samples, failed_scored_samples)

        logs.update(sample_logs)

        if clear:
            self.clear_logs()

        return logs
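As a usage note (illustrative only, not packaged code): the grader reads the retrieved documents from the thread's metadata under a "documents" key and judges each one against the last user turn. How metadata is attached to a StringThread is not shown in this diff, so the helper below is hypothetical.

# Illustrative sketch only. `judge_model` (an InferenceModel) and
# `build_thread_with_documents` (a hypothetical helper that attaches the
# retrieved documents to StringThread.metadata["documents"]) are assumptions.
from adaptive_harmony.graders.context_relevancy_judge.context_relevancy_judge import ContextRelevancyGrader


async def score_retrieval(judge_model, question, documents):
    grader = ContextRelevancyGrader(judge_model, language="en")
    thread = build_thread_with_documents(question, documents)  # hypothetical helper
    grade = await grader.grade(thread)
    # grade.value is the mean of the per-document 0/1 relevance verdicts
    return grade.value, grade.reasoning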
adaptive_harmony/graders/context_relevancy_judge/prompts.py
@@ -0,0 +1,84 @@
DEFAULT_SCORING_SHOTS = """Example:
INPUT
What percentage is considered a good rental yield?

STATEMENTS
0: How are you doing today?
1: Rental yield is how much you could expect to receive in rent each year from your buy to let investment.
2: Rental yield is expressed as a percentage - reflecting your rental income against the property's market value.
3: Anything around the 5-6% mark could be considered a good rental yield.
4: Anything above 6% could be considered a very good rental yield.

```json
{
    "verdicts": [
        {
            "reason": "The statement is unrelated to the input.",
            "score": 0
        },
        {
            "reason": "While the statement discusses rental yields, it does not indicate what constitutes a good rental yield.",
            "score": 0
        },
        {
            "reason": "While the statement mentions that yield is expressed as a percentage, it does not address the user question.",
            "score": 0
        },
        {
            "reason": "The statement addresses the user input, specifying what a good rental yield is.",
            "score": 1
        },
        {
            "reason": "The statement addresses the user input, specifying what a very good rental yield is.",
            "score": 1
        },
    ]
}```"""


SYSTEM = """You are an expert data reviewer.
You will be given a document, and a user input/question.
Your task is to classify whether the document contains information relevant to answering the user input.

IMPORTANT: Please make sure to only return in JSON format and with no further preamble explanation, with the `score` key providing the score, and the `reason` key providing the reason. The score can only be 0 or 1. Keep the reason short and concise.

You always output a JSON object with the following schema, and nothing else before or after:
{json_schema}

Examples:
{shots}"""

USER = """Your real task:
INPUT
{user_question}

DOCUMENT
{document}

```json"""

DEFAULT_SHOTS = """Example 1:
INPUT
What percentage is considered a good rental yield?

DOCUMENT
Rental yield is how much you could expect to receive in rent each year from your buy to let investment. Rental yield is expressed as a percentage - reflecting your rental income against the property's market value.
While calculating rental yield can give you an indication of whether investing in a buy-to-let property is worth it, there's other factors to consider. You'll also need to think about whether there might be any problem finding tenants, collecting rent or void periods, for example. Then, there's capital growth - the value by which your property is set to increase over time.
All of these can impact your decision on whether a property is worth your investment.

```json
{"score":0,"reason":"While the document explains the concept of rental yield, there is no indication as to what a good percentage yield is."}
```

Example 2:
INPUT
What percentage is considered a good rental yield?

DOCUMENT
As of 2024, the average rental yield in the UK is between 5% and 8%. Anything around the 5-6% mark could be considered a 'good' rental yield, while anything above 6% could be considered very good.
Some parts of the country can deliver significantly higher or lower returns to others. It's worth bearing in mind that you may get a lower yield in areas where the house prices are highest, such as in London and the South East.
This is because the potential for capital gains in the region pushes sale prices up, while rent levels are less affected.

```json
{"score":1,"reason":"The document indicates what can be considered a good rental yield percentage."}
```"""
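To make the placeholders concrete, a small sketch (illustrative only) of how these templates are filled for a single document. The exact text produced by InferenceModel.render_schema is not shown in this diff, so a plain JSON-style string stands in for it.

# Illustrative sketch only. The schema string below is a stand-in; in the
# package it comes from InferenceModel.render_schema(DocumentRelevancyResult).
from adaptive_harmony.graders.context_relevancy_judge.prompts import DEFAULT_SHOTS, SYSTEM, USER

schema_stub = '{"reason": "<string>", "score": 0 | 1}'  # assumption, not the real rendered schema
system_prompt = SYSTEM.format(json_schema=schema_stub, shots=DEFAULT_SHOTS)
user_prompt = USER.format(
    user_question="What percentage is considered a good rental yield?",
    document="Anything around the 5-6% mark could be considered a good rental yield.",
)
print(system_prompt)
print(user_prompt)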
adaptive_harmony/graders/exceptions.py
@@ -0,0 +1,9 @@
class IgnoreScoreException(Exception):
    """
    Exception to indicate that a score should be ignored/is not applicable.
    Accepts a custom message explaining why the score was ignored.
    """

    def __init__(self, message: str = "Score should be ignored"):
        super().__init__(message)
        self.message = message
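To show how this exception interacts with CombinedGrader, here is a toy custom grader (illustrative only, not part of the package) that opts out of scoring instead of failing hard; CombinedGrader treats IgnoreScoreException and JsonParseError as "skip this grader" rather than a hard failure.

# Illustrative sketch only. Assumes StringThread.last_content() returns the
# text of the final turn (that is how the faithfulness grader uses it).
from adaptive_harmony import StringThread
from adaptive_harmony.graders import BaseGrader, Grade
from adaptive_harmony.graders.exceptions import IgnoreScoreException


class LengthGrader(BaseGrader):
    """Toy grader: rewards completions under a character budget."""

    def __init__(self, grader_key: str = "length", max_chars: int = 500):
        super().__init__(grader_key)
        self.max_chars = max_chars

    async def grade(self, sample: StringThread) -> Grade:
        completion = sample.last_content()
        if not completion:
            # Nothing to grade: tell CombinedGrader to ignore this score
            raise IgnoreScoreException("No completion to grade")
        score = 1.0 if len(completion) <= self.max_chars else 0.0
        self.add_log({"score": score})
        return Grade(value=score, grader_key=self.grader_key, reasoning=f"{len(completion)} chars")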
adaptive_harmony/graders/faithfulness_judge/faithfulness_judge.py
@@ -0,0 +1,159 @@
from typing import Literal

import pysbd
from pydantic import BaseModel, Field

from adaptive_harmony import Grade, InferenceModel, StringThread
from adaptive_harmony.core.structured_output import JsonParseError
from adaptive_harmony.core.utils import stringify_thread
from adaptive_harmony.graders.base_grader import BaseGrader
from adaptive_harmony.graders.exceptions import IgnoreScoreException
from adaptive_harmony.graders.faithfulness_judge.prompts import SYSTEM, USER
from adaptive_harmony.graders.utils import (
    FailedJudgeLog,
    SuccessJudgeLog,
    sample_score_distribution,
    separate_context_from_last_user_turn,
    validate_thread_last_assistant,
)
from adaptive_harmony.logging_table import Table


class SingleStatementFaithfulnessJudgeOutput(BaseModel):
    statement_idx: int = Field(description="The original index of the sentence being scored")
    reasoning: str = Field(description="Reasoning to support the rationale behind the score")
    score: Literal["1", "0"] = Field(
        description="The score of the sample, 1 if the statement is fully supported by the context, 0 if it is not"
    )


class FaithfulnessGraderOutput(BaseModel):
    all_statements_scoring: list[SingleStatementFaithfulnessJudgeOutput] = Field(
        description="An array of objects, each analyzing a single statement from the original list of statements"
    )


SupportedLanguages = Literal[
    "en",
    "hi",
    "mr",
    "zh",
    "es",
    "am",
    "ar",
    "hy",
    "bg",
    "ur",
    "ru",
    "pl",
    "fa",
    "nl",
    "da",
    "fr",
    "my",
    "el",
    "it",
    "ja",
    "de",
    "kk",
    "sk",
]


class FaithfulnessGrader(BaseGrader):
    """
    Scores each sentence in the last assistant turn as fully supported by the context or not (1 or 0).
    The context is the rest of the thread, excluding the system prompt.
    The final score is the average of the per-sentence scores.
    Requires an input language code to split the sentences.
    """

    def __init__(
        self,
        model: InferenceModel,
        language: SupportedLanguages = "en",
        grader_key: str = "faithfulness_judge",
        grader_id: str | None = None,
    ):
        super().__init__(grader_key)
        self._logs: list[SuccessJudgeLog | FailedJudgeLog] = []  # already created in super, this is for typing
        self.model = model
        self.language = language
        self.sentence_splitter = pysbd.Segmenter(language=language)
        self.grader_id_or_key = grader_id or grader_key

    async def grade(self, sample: StringThread) -> Grade:
        # Split response into sentences
        validate_thread_last_assistant(sample)
        # Separate conversation context from last user turn
        context_turns, user_question = separate_context_from_last_user_turn(sample)
        completion = sample.last_content()
        split_sentences = self.sentence_splitter.segment(completion)
        sentences = [f"{i}: {sentence.strip()}" for i, sentence in enumerate(split_sentences) if sentence.strip()]
        sentences_judge_str = "\n".join(sentences)

        # Build prompt
        context_str = stringify_thread(StringThread(context_turns))
        judge_thread = (
            StringThread()
            .system(SYSTEM.format(json_schema=self.model.render_schema(FaithfulnessGraderOutput)))
            .user(USER.format(context=context_str, user_question=user_question, sentences=sentences_judge_str))
        )
        judge_str_prompt = stringify_thread(judge_thread, sep=f"\n\n{'-' * 10}\n\n")
        # Generate response
        try:
            _, parsed_response = await self.model.temperature(0.0).generate_and_validate(
                judge_thread, FaithfulnessGraderOutput
            )
        except JsonParseError as e:
            self.add_log({"prompt": judge_str_prompt, "error": f"{str(e)}\n\nCOMPLETION:\n{e.completion}"})
            raise
        except Exception as e:
            self.add_log({"prompt": judge_str_prompt, "error": str(e)})
            raise

        # Raise error if judge failed to judge any sentence
        n_judged_sentences = len(parsed_response.all_statements_scoring)
        if n_judged_sentences != len(sentences):
            raise IgnoreScoreException(
                f"Number of sentences in the response ({n_judged_sentences}) "
                f"does not match the number of sentences in the input ({len(sentences)})"
            )
        # Calculate avg score
        score = round(
            sum([float(judgement.score) for judgement in parsed_response.all_statements_scoring]) / n_judged_sentences,
            3,
        )
        merged_reasoning_traces = "\n-".join(
            [judgement.reasoning for judgement in parsed_response.all_statements_scoring]
        )
        self.add_log({"score": score, "prompt": judge_str_prompt, "reasoning": merged_reasoning_traces})

        return Grade(value=score, grader_key=self.grader_id_or_key, reasoning=merged_reasoning_traces)

    def add_log(self, log: SuccessJudgeLog | FailedJudgeLog) -> None:
        self._logs.append(log)

    def get_logs(self, clear: bool = False, log_all_samples: bool = False) -> dict[str, float | Table]:
        # Only clear logs at the end if clear is True
        logs = super().get_logs(clear=False)

        successfully_scored_samples = [log for log in self._logs if "score" in log]

        # stratified sample range of scores to see high and low
        if not log_all_samples:
            subset_successfully_scored_samples = sample_score_distribution(successfully_scored_samples, 15)
        else:
            # if we have fewer than 15 samples or we want to log all samples, take them all
            subset_successfully_scored_samples = successfully_scored_samples

        failed_scored_samples = [log for log in self._logs if "error" in log]

        sample_logs = self.get_sample_tables(subset_successfully_scored_samples, failed_scored_samples)

        logs.update(sample_logs)

        if clear:
            self.clear_logs()

        return logs
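A usage sketch for the faithfulness judge (illustrative only, not packaged code): grading a single question/answer pair. The diff only shows the .system()/.user() builders on StringThread, so the .assistant() call below is an assumption.

# Illustrative sketch only. Assumes StringThread also exposes an .assistant()
# builder; the grader requires the thread to end on an assistant turn.
from adaptive_harmony import StringThread
from adaptive_harmony.graders.faithfulness_judge.faithfulness_judge import FaithfulnessGrader


async def check_faithfulness(judge_model, context_doc, question, answer):
    thread = (
        StringThread()
        .user(f"{context_doc}\n\n{question}")
        .assistant(answer)  # assumed builder, not shown in this diff
    )
    grader = FaithfulnessGrader(judge_model, language="en")
    grade = await grader.grade(thread)
    # grade.value is the rounded fraction of answer sentences judged supported
    return grade.value, grade.reasoning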
adaptive_harmony/graders/faithfulness_judge/prompts.py
@@ -0,0 +1,22 @@
SYSTEM = """Your task is to judge the faithfulness of a series of statements to the provided context.
For each statement_idx you must return your reasoning to support a score you attribute to the corresponding statement.
The score should be 1 in case the statement is fully supported and can be directly inferred based on the context, or 0 in case it cannot.
If there is no relevant context to be faithful to, the score should be 1.
You must score every single sentence without skipping any.

If it exists, you will be given the whole CONVERSATION so far leading up to the LAST USER TURN.
You must evaluate the statements with a focus on what is being asked in the LAST USER TURN, and never on any intermediary questions that might have been asked in the course of the conversation.

You always output a JSON object with the following schema, and nothing else before or after:
{json_schema}"""


USER = """CONVERSATION
{context}

LAST USER TURN
{user_question}

STATEMENTS
{sentences}
"""