adaptive-harmony 0.1.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. adaptive_harmony/__init__.py +162 -0
  2. adaptive_harmony/common/__init__.py +40 -0
  3. adaptive_harmony/common/callbacks.py +219 -0
  4. adaptive_harmony/common/checkpointing.py +163 -0
  5. adaptive_harmony/common/dpo.py +92 -0
  6. adaptive_harmony/common/env_grpo.py +361 -0
  7. adaptive_harmony/common/grpo.py +260 -0
  8. adaptive_harmony/common/gspo.py +70 -0
  9. adaptive_harmony/common/ppo.py +303 -0
  10. adaptive_harmony/common/rm.py +79 -0
  11. adaptive_harmony/common/sft.py +121 -0
  12. adaptive_harmony/core/__init__.py +0 -0
  13. adaptive_harmony/core/dataset.py +72 -0
  14. adaptive_harmony/core/display.py +93 -0
  15. adaptive_harmony/core/image_utils.py +110 -0
  16. adaptive_harmony/core/reasoning.py +12 -0
  17. adaptive_harmony/core/reward_client/__init__.py +19 -0
  18. adaptive_harmony/core/reward_client/client.py +160 -0
  19. adaptive_harmony/core/reward_client/reward_types.py +49 -0
  20. adaptive_harmony/core/reward_client/websocket_utils.py +18 -0
  21. adaptive_harmony/core/rich_counter.py +351 -0
  22. adaptive_harmony/core/rl_utils.py +38 -0
  23. adaptive_harmony/core/schedulers.py +38 -0
  24. adaptive_harmony/core/structured_output.py +385 -0
  25. adaptive_harmony/core/utils.py +365 -0
  26. adaptive_harmony/environment/__init__.py +8 -0
  27. adaptive_harmony/environment/environment.py +121 -0
  28. adaptive_harmony/evaluation/__init__.py +1 -0
  29. adaptive_harmony/evaluation/evaluation_artifact.py +67 -0
  30. adaptive_harmony/graders/__init__.py +20 -0
  31. adaptive_harmony/graders/answer_relevancy_judge/__init__.py +3 -0
  32. adaptive_harmony/graders/answer_relevancy_judge/answer_relevancy_judge.py +102 -0
  33. adaptive_harmony/graders/answer_relevancy_judge/prompts.py +58 -0
  34. adaptive_harmony/graders/base_grader.py +265 -0
  35. adaptive_harmony/graders/binary_judge/__init__.py +8 -0
  36. adaptive_harmony/graders/binary_judge/binary_judge.py +202 -0
  37. adaptive_harmony/graders/binary_judge/prompts.py +125 -0
  38. adaptive_harmony/graders/combined_grader.py +118 -0
  39. adaptive_harmony/graders/context_relevancy_judge/__init__.py +3 -0
  40. adaptive_harmony/graders/context_relevancy_judge/context_relevancy_judge.py +128 -0
  41. adaptive_harmony/graders/context_relevancy_judge/prompts.py +84 -0
  42. adaptive_harmony/graders/exceptions.py +9 -0
  43. adaptive_harmony/graders/faithfulness_judge/__init__.py +3 -0
  44. adaptive_harmony/graders/faithfulness_judge/faithfulness_judge.py +159 -0
  45. adaptive_harmony/graders/faithfulness_judge/prompts.py +22 -0
  46. adaptive_harmony/graders/range_judge/__init__.py +7 -0
  47. adaptive_harmony/graders/range_judge/prompts.py +232 -0
  48. adaptive_harmony/graders/range_judge/range_judge.py +188 -0
  49. adaptive_harmony/graders/range_judge/types.py +12 -0
  50. adaptive_harmony/graders/reward_server_grader.py +36 -0
  51. adaptive_harmony/graders/templated_prompt_judge.py +237 -0
  52. adaptive_harmony/graders/utils.py +79 -0
  53. adaptive_harmony/logging_table.py +1 -0
  54. adaptive_harmony/metric_logger.py +452 -0
  55. adaptive_harmony/parameters/__init__.py +2 -0
  56. adaptive_harmony/py.typed +0 -0
  57. adaptive_harmony/runtime/__init__.py +2 -0
  58. adaptive_harmony/runtime/context.py +2 -0
  59. adaptive_harmony/runtime/data.py +2 -0
  60. adaptive_harmony/runtime/decorators.py +2 -0
  61. adaptive_harmony/runtime/model_artifact_save.py +2 -0
  62. adaptive_harmony/runtime/runner.py +27 -0
  63. adaptive_harmony/runtime/simple_notifier.py +2 -0
  64. adaptive_harmony-0.1.23.dist-info/METADATA +37 -0
  65. adaptive_harmony-0.1.23.dist-info/RECORD +67 -0
  66. adaptive_harmony-0.1.23.dist-info/WHEEL +5 -0
  67. adaptive_harmony-0.1.23.dist-info/top_level.txt +1 -0
adaptive_harmony/graders/range_judge/prompts.py
@@ -0,0 +1,232 @@
+ import textwrap
+ from typing import NamedTuple, TypedDict
+
+ from pydantic import BaseModel
+
+ from adaptive_harmony import StringThread, StringTurn
+ from adaptive_harmony.core.structured_output import render_pydantic_model
+ from adaptive_harmony.core.utils import stringify_thread
+ from adaptive_harmony.graders.range_judge.types import PromptBuildingBlocks, ReasonedScore
+ from adaptive_harmony.graders.utils import (
+     separate_context_from_last_user_turn,
+     validate_thread_last_assistant,
+ )
+
+
+ class RangeJudgeShot(BaseModel):
+     thread: StringThread
+     reasoning: str
+     score: int
+
+     class Config:
+         arbitrary_types_allowed = True
+
+
+ class RangeShots(TypedDict):
+     reasoning: list[StringTurn]
+     scoring: list[StringTurn]
+
+
+ class SubrangeExpectations(NamedTuple):
+     subrange: tuple[int, int]
+     expectation: str
+
+
+ COMMON_EVALUATION_PROMPT = """You are an expert evaluator of AI-user interactions.
+ You will be given:
+ - CONTEXT : previous conversation history, might be empty
+ - LAST USER INPUT : the latest input from the user
+ - LAST ASSISTANT OUTPUT : the latest output/answer from the AI
+ - CRITERIA: the evaluation criteria
+ - EVALUATION STEPS : logical reasoning steps to take when evaluating the interaction against the CRITERIA
+ """
+
+
+ def get_common_user_template(
+     context: str, last_user_input: str, assistant_answer: str, criteria: str, evaluation_steps: str
+ ):
+     return f"""CONTEXT\n{context}\n
+ LAST USER INPUT\n{last_user_input}\n
+ LAST ASSISTANT OUTPUT\n{assistant_answer}\n
+ CRITERIA\n{criteria}\n
+ EVALUATION STEPS\n{evaluation_steps}"""
+
+
+ class RangeScorerTemplates:
+     @staticmethod
+     def get_evaluation_steps(criteria: str) -> StringThread:
+         return (
+             StringThread()
+             .system(
+                 textwrap.dedent(
+                     """\
+                     Given an evaluation criteria which outlines how you should judge an interaction between an AI and a user, generate 3-4 concise evaluation steps based on the criteria below.
+
+                     Return your evaluation steps as a numbered list of evaluation steps, such as:
+
+                     Steps list:
+                     1. First step
+                     2. Second step
+                     3. Third step
+                     etc.
+
+                     Focus on specific, concise steps that can be objectively followed based on the evaluation criteria provided.
+                     Don't return any preamble or explanation, only the list.
+                     """
+                 )
+             )
+             .user(f"Evaluation Criteria:\n{criteria}\nSteps list:\n")
+         )
+
+     @staticmethod
+     def get_json_reasoned_score_user(
+         context: str, last_user_input: str, assistant_answer: str, criteria: str, evaluation_steps: str
+     ):
+         common = get_common_user_template(context, last_user_input, assistant_answer, criteria, evaluation_steps)
+         return f"{common}\n\nJSON OUTPUT:\n"
+
+     @staticmethod
+     def get_json_reasoned_score(
+         context: str,
+         last_user_input: str,
+         assistant_answer: str,
+         criteria: str,
+         evaluation_steps: str,
+         score_range: tuple[int, int],
+         json_schema: str,
+         subrange_expectations: list[SubrangeExpectations] | None = None,
+         shots: list[StringTurn] | None = None,
+     ) -> StringThread:
+         if subrange_expectations:
+             subrange_expectations_str = "which should correspond to:\n" + "\n".join(
+                 [f"{sub.subrange[0]} - {sub.subrange[1]}: {sub.expectation}" for sub in subrange_expectations]
+             )
+         else:
+             subrange_expectations_str = f"where {score_range[1]} indicates strong alignment with the criteria and {score_range[0]} indicates no alignment."
+
+         system_prompt = (
+             COMMON_EVALUATION_PROMPT
+             + f"""
+ Your task is to evaluate and score the ASSISTANT ANSWER, strictly following the provided EVALUATION STEPS to evaluate the CRITERIA.
+ You must respond with a valid JSON object that matches the following schema:
+
+ {json_schema}
+
+ Your reasoning for the score:
+ - Be specific and grounded in the EVALUATION STEPS
+ - Uphold the evaluation objective and nuances expressed in the CRITERIA as the main target
+ - Mention specific details, strengths or shortcomings of the answer, referencing relevant details from the input
+ - Be concise, clear, and focused on the evaluation logic.
+ - **Never** quote the score itself in the explanation; focus only on reasoning through the evaluation steps
+
+ Your final evaluation score must be strictly within the range of [{score_range[0]} - {score_range[1]}], {subrange_expectations_str}
+
+ Return only the JSON object after the OUTPUT header, no other text, preamble or explanation.
+ """
+         )
+
+         user_prompt = RangeScorerTemplates.get_json_reasoned_score_user(
+             context, last_user_input, assistant_answer, criteria, evaluation_steps
+         )
+
+         shots = shots or []
+         return StringThread([("system", system_prompt)] + shots + [("user", user_prompt)])
+
+     @staticmethod
+     def get_up_to_score_user(
+         context: str, last_user_input: str, assistant_answer: str, criteria: str, evaluation_steps: str, reasoning: str
+     ) -> str:
+         common = get_common_user_template(context, last_user_input, assistant_answer, criteria, evaluation_steps)
+         return f"{common}\n\nREASONING\n{reasoning}\n\nSCORE: "
+
+     @staticmethod
+     def get_up_to_score(
+         context: str,
+         last_user_input: str,
+         assistant_answer: str,
+         criteria: str,
+         evaluation_steps: str,
+         score_range: tuple[int, int],
+         reasoning: str,
+         subrange_expectations: list[SubrangeExpectations] | None = None,
+         shots: list[StringTurn] | None = None,
+     ) -> StringThread:
+         if subrange_expectations:
+             subrange_expectations_str = "which should correspond to:\n" + "\n".join(
+                 [f"{sub.subrange[0]} - {sub.subrange[1]}: {sub.expectation}" for sub in subrange_expectations]
+             )
+         else:
+             subrange_expectations_str = f"where {score_range[1]} indicates strong alignment with the evaluation steps and {score_range[0]} indicates no alignment."
+
+         system_prompt = COMMON_EVALUATION_PROMPT + textwrap.dedent(
+             f"""\
+             - REASONING : the reasoning for the score, following the process described by the EVALUATION STEPS to assess the presented interaction against the CRITERIA
+
+             You must respond only with a score, based on the original CRITERIA and the REASONING for the sample,
+             which should justify your score for the sample.
+
+             Your final evaluation score must be strictly within the range of [{score_range[0]} - {score_range[1]}], {subrange_expectations_str}
+
+             Return only the integer score, nothing before or after.
+             """
+         )
+         user_prompt = RangeScorerTemplates.get_up_to_score_user(
+             context, last_user_input, assistant_answer, criteria, evaluation_steps, reasoning
+         )
+
+         shots = shots or []
+         return StringThread([("system", system_prompt)] + shots + [("user", user_prompt)])
+
+
+ def get_prompt_building_blocks(thread: StringThread) -> PromptBuildingBlocks:
+     validate_thread_last_assistant(thread)
+     context_turns, last_user_turn = separate_context_from_last_user_turn(thread)
+     context_str = stringify_thread(StringThread(context_turns))
+     last_assistant_turn = thread.last_content()
+     assert last_user_turn, "There must be at least one user turn"
+     return PromptBuildingBlocks(
+         context=context_str, last_user_turn=last_user_turn, last_assistant_turn=last_assistant_turn
+     )
+
+
+ def create_shots(criteria: str, evaluation_steps: str, shots: list[RangeJudgeShot]) -> RangeShots:
+     reasoning_shots: list[StringTurn] = []
+     scoring_shots: list[StringTurn] = []
+     for shot in shots:
+         prompt_components = get_prompt_building_blocks(shot.thread)
+         reasoning_shots.extend(
+             [
+                 StringTurn(
+                     role="user",
+                     content=RangeScorerTemplates.get_json_reasoned_score_user(
+                         prompt_components.context,
+                         prompt_components.last_user_turn,
+                         prompt_components.last_assistant_turn,
+                         criteria,
+                         evaluation_steps,
+                     ),
+                 ),
+                 StringTurn(
+                     role="assistant",
+                     content=render_pydantic_model(ReasonedScore(reasoning=shot.reasoning, score=shot.score)),
+                 ),
+             ]
+         )
+         scoring_shots.extend(
+             [
+                 StringTurn(
+                     role="user",
+                     content=RangeScorerTemplates.get_up_to_score_user(
+                         prompt_components.context,
+                         prompt_components.last_user_turn,
+                         prompt_components.last_assistant_turn,
+                         criteria,
+                         evaluation_steps,
+                         shot.reasoning,
+                     ),
+                 ),
+                 StringTurn(role="assistant", content=str(shot.score)),
+             ]
+         )
+
+     return {"reasoning": reasoning_shots, "scoring": scoring_shots}
adaptive_harmony/graders/range_judge/range_judge.py
@@ -0,0 +1,188 @@
+ import asyncio
+
+ import numpy as np
+
+ from adaptive_harmony import InferenceModel, StringThread
+ from adaptive_harmony.core.structured_output import JsonParseError
+ from adaptive_harmony.core.utils import stringify_thread
+ from adaptive_harmony.graders import BaseGrader, Grade
+ from adaptive_harmony.graders.range_judge.prompts import (
+     RangeJudgeShot,
+     RangeScorerTemplates,
+     ReasonedScore,
+     SubrangeExpectations,
+     create_shots,
+     get_prompt_building_blocks,
+ )
+ from adaptive_harmony.graders.utils import (
+     FailedJudgeLog,
+     SuccessJudgeLog,
+     sample_score_distribution,
+ )
+ from adaptive_harmony.logging_table import Table
+
+
+ class RangeJudgeGrader(BaseGrader[SuccessJudgeLog | FailedJudgeLog]):
+     """
+     Scores a thread in a range of integer scores, based on a list of evaluation steps.
+     If evaluation steps are not provided, they are generated from the criteria.
+     The final score is computed as a weighted average of all possible scores,
+     where the weights are the probabilities derived from each score's logprobs.
+     You can pass subrange_expectations to the scorer, to help the judge
+     understand the correspondence between score subranges and expected quality levels.
+     """
+
+     def __init__(
+         self,
+         grader_key: str,
+         model: InferenceModel,
+         criteria: str,
+         score_range: tuple[int, int] = (1, 5),
+         evaluation_steps: list[str] | None = None,
+         subrange_expectations: list[SubrangeExpectations] | None = None,
+         shots: list[RangeJudgeShot] | None = None,
+         normalize_score: bool = True,
+     ):
+         model_path: str = model.get_builder_args().get("path")  # type: ignore[assignment]
+         assert model_path.startswith("model_registry://"), "External models cannot be used in RangeJudgeScorer"
+
+         super().__init__(grader_key)
+         self.model = model
+         self.criteria = criteria
+         self.score_range = score_range
+         self.min_score, self.max_score = score_range
+         self.subrange_expectations = subrange_expectations
+         self._shots = shots
+         self.normalize_score = normalize_score
+
+         if evaluation_steps is None:
+             self._str_evaluation_steps = None
+             self._list_eval_steps = None
+         else:
+             self._str_evaluation_steps = "\n".join([f"{i + 1}: {step}" for i, step in enumerate(evaluation_steps)])
+             self._list_eval_steps = evaluation_steps
+
+     @property
+     def evaluation_steps(self) -> list[str] | None:
+         return self._list_eval_steps
+
+     @property
+     def str_evaluation_steps(self) -> str | None:
+         return self._str_evaluation_steps
+
+     @evaluation_steps.setter
+     def evaluation_steps(self, steps: list[str]):
+         self._list_eval_steps = steps
+         self._str_evaluation_steps = "\n".join([f"{i + 1}: {step}" for i, step in enumerate(steps)])
+
+     async def generate_evaluation_steps(self) -> list[str]:
+         thread = await self.model.temperature(0.0).generate(RangeScorerTemplates.get_evaluation_steps(self.criteria))
+         self._str_evaluation_steps = thread.last_content()
+         self._list_eval_steps = self._str_evaluation_steps.split("\n")
+         assert self.evaluation_steps
+         return self.evaluation_steps
+
+     async def grade(self, sample: StringThread) -> Grade:
+         if self.evaluation_steps is None:
+             if self._shots is not None:
+                 raise ValueError(
+                     "You cannot pass shots without specifying evaluation steps, since your shots' reasoning must match the steps"
+                 )
+             await self.generate_evaluation_steps()
+
+         # Format shots into both user turn formats (for first reasoned scoring step, and for last logprobs scoring step)
+         assert self.str_evaluation_steps
+
+         shots = create_shots(self.criteria, self.str_evaluation_steps, self._shots) if self._shots is not None else {}
+
+         # Separate relevant parts of the prompt turns
+         prompt_components = get_prompt_building_blocks(sample)
+         # Get reasoned scoring thread
+         eval_thread = RangeScorerTemplates.get_json_reasoned_score(
+             context=prompt_components.context,
+             last_user_input=prompt_components.last_user_turn,
+             assistant_answer=prompt_components.last_assistant_turn,
+             criteria=self.criteria,
+             evaluation_steps=self.str_evaluation_steps,
+             score_range=self.score_range,
+             json_schema=self.model.render_schema(ReasonedScore),
+             subrange_expectations=self.subrange_expectations,
+             shots=shots.get("reasoning"),
+         )
+         eval_str_prompt = stringify_thread(eval_thread, sep=f"\n\n{'-' * 10}\n\n")
+         try:
+             _, reasoned_score = await self.model.temperature(0.0).generate_and_validate(eval_thread, ReasonedScore)
+         except JsonParseError as e:
+             self.add_log({"prompt": eval_str_prompt, "error": f"{str(e)}\n\nCOMPLETION:\n{e.completion}"})
+             raise
+         except Exception as e:
+             self.add_log({"prompt": eval_str_prompt, "error": str(e)})
+             raise
+
+         # Get a prompt that includes the reasoning for the sample, all the way to form-filling the score
+         up_to_score_thread = RangeScorerTemplates.get_up_to_score(
+             context=prompt_components.context,
+             last_user_input=prompt_components.last_user_turn,
+             assistant_answer=prompt_components.last_assistant_turn,
+             criteria=self.criteria,
+             evaluation_steps=self.str_evaluation_steps,
+             score_range=self.score_range,
+             reasoning=reasoned_score.reasoning,
+             subrange_expectations=self.subrange_expectations,
+             shots=shots.get("scoring"),
+         )
+
+         # Get logprobs for each possible final score
+         possible_score_ints = [s for s in range(self.min_score, self.max_score + 1)]
+         logprobs = await asyncio.gather(
+             *[self.model.temperature(0.0).logprobs(up_to_score_thread.assistant(f"{s}")) for s in possible_score_ints]
+         )
+
+         # Convert to probabilities and compute weighted average
+         probs = np.exp(logprobs - np.logaddexp.reduce(logprobs))
+         weighted_score = np.average(possible_score_ints, weights=probs)
+
+         final_score: float = weighted_score
+         if self.normalize_score:  # normalize to 0-1 range
+             final_score = (weighted_score - self.min_score) / (self.max_score - self.min_score)
+
+         str_prompt = stringify_thread(eval_thread, sep=f"\n\n{'-' * 10}\n\n")
+         self.add_log({"score": final_score, "prompt": str_prompt, "reasoning": reasoned_score.reasoning})
+
+         metadata = dict(
+             criteria=self.criteria,
+             raw_avg_score=float(weighted_score),
+             scale_range=(self.min_score, self.max_score),
+             score_probabilities={str(score): float(prob) for score, prob in zip(possible_score_ints, probs)},
+             evaluation_steps=self.evaluation_steps,
+             reasoning=reasoned_score.reasoning,
+         )
+
+         return Grade(value=float(final_score), grader_key=self.grader_key, reasoning=reasoned_score.reasoning)
+
+     def add_log(self, log_data: SuccessJudgeLog | FailedJudgeLog) -> None:
+         self._logs.append(log_data)
+
+     def get_logs(self, clear: bool = False, log_all_samples: bool = False) -> dict[str, float | Table]:
+         # Only clear logs at the end if clear is True
+         logs = super().get_logs(clear=False)
+
+         successfully_scored_samples = [log for log in self._logs if "score" in log]
+
+         # stratified sample range of scores to see high and low
+         if not log_all_samples:
+             subset_successfully_scored_samples = sample_score_distribution(successfully_scored_samples, 15)
+         else:
+             # if we have fewer than 15 samples or we want to log all samples, take them all
+             subset_successfully_scored_samples = successfully_scored_samples
+
+         failed_scored_samples = [log for log in self._logs if "error" in log]
+
+         sample_logs = self.get_sample_tables(subset_successfully_scored_samples, failed_scored_samples)
+
+         logs.update(sample_logs)
+
+         if clear:
+             self.clear_logs()
+
+         return logs
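The docstring above describes the final score as a probability-weighted average; as a self-contained illustration with made-up logprob values (plain NumPy, independent of the grader), the computation in grade() reduces to the following.

import numpy as np

# Hypothetical logprobs for candidate scores 1..5, one per call to model.logprobs(...)
possible_score_ints = [1, 2, 3, 4, 5]
logprobs = np.array([-6.2, -4.8, -2.1, -0.4, -1.3])

# Softmax computed in log space: subtract the logsumexp, then exponentiate
probs = np.exp(logprobs - np.logaddexp.reduce(logprobs))

# Expected score under that distribution
weighted_score = np.average(possible_score_ints, weights=probs)

# With normalize_score=True the result is rescaled to [0, 1]
normalized = (weighted_score - 1) / (5 - 1)
print(float(weighted_score), float(normalized))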
adaptive_harmony/graders/range_judge/types.py
@@ -0,0 +1,12 @@
+ from pydantic import BaseModel, Field
+
+
+ class ReasonedScore(BaseModel):
+     reasoning: str = Field(description="String reasoning to support the rationale behind the score")
+     score: int = Field(description="Integer score for the interaction, must be within the specified score range")
+
+
+ class PromptBuildingBlocks(BaseModel):
+     context: str
+     last_user_turn: str
+     last_assistant_turn: str
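For reference, a quick sketch (placeholder text, not from the package) of the structured output the judge is expected to produce against this schema, serialized with the render_pydantic_model helper imported in prompts.py:

from adaptive_harmony.core.structured_output import render_pydantic_model
from adaptive_harmony.graders.range_judge.types import ReasonedScore

# A well-formed judge response once parsed into the schema above.
example = ReasonedScore(
    reasoning="The answer follows every evaluation step and stays grounded in the provided context.",
    score=4,
)

# The same helper is used in prompts.py to serialize few-shot assistant turns.
rendered = render_pydantic_model(example)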
adaptive_harmony/graders/reward_server_grader.py
@@ -0,0 +1,36 @@
+ import asyncio
+
+ from adaptive_harmony import StringThread
+ from adaptive_harmony.core.reward_client.client import Request, RewardClient, Turn
+ from adaptive_harmony.graders.base_grader import BaseGrader, Grade
+
+
+ class RewardServerGrader(BaseGrader):
+     def __init__(self, grader_key: str, grader_id: str, reward_server_ip: str):
+         super().__init__(grader_key)
+         self.reward_client = RewardClient(reward_server_ip)
+         self.grader_id_or_key = grader_id or grader_key
+         self._setup_task = None
+         self._setup_lock = None
+
+     async def _ensure_setup(self):
+         if self._setup_lock is None:
+             self._setup_lock = asyncio.Lock()
+
+         if self._setup_task is None:
+             async with self._setup_lock:
+                 if self._setup_task is None:
+                     self._setup_task = asyncio.create_task(self.reward_client.setup())
+
+         await self._setup_task
+
+     async def grade(self, sample: StringThread) -> Grade:
+         await self._ensure_setup()
+
+         response = await self.reward_client.score(
+             Request(
+                 turns=[Turn(content=turn.content, role=turn.role) for turn in sample.get_turns()],
+                 metadata=sample.metadata,
+             )
+         )
+         return Grade(value=response.reward, grader_key=self.grader_id_or_key, reasoning=response.metadata.get("reason"))
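A minimal usage sketch (not from the package docs; the server address, grader names, and thread content are placeholders, and Grade is assumed to expose the value and reasoning fields it is constructed with above):

import asyncio

from adaptive_harmony import StringThread
from adaptive_harmony.graders.reward_server_grader import RewardServerGrader


async def main() -> None:
    # Placeholder reward server address and grader identifiers.
    grader = RewardServerGrader(grader_key="reward", grader_id="my-reward-model", reward_server_ip="10.0.0.5")

    # Thread whose turns are forwarded to the reward server for scoring.
    thread = StringThread().user("Summarize the report in one sentence.").assistant("The report covers Q3 revenue growth.")

    grade = await grader.grade(thread)
    print(grade.value, grade.reasoning)


asyncio.run(main())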