themis-eval 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/__init__.py +12 -1
- themis/_version.py +2 -2
- themis/api.py +343 -0
- themis/backends/__init__.py +17 -0
- themis/backends/execution.py +197 -0
- themis/backends/storage.py +260 -0
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/results.py +252 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +463 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/comparison/__init__.py +25 -0
- themis/comparison/engine.py +348 -0
- themis/comparison/reports.py +283 -0
- themis/comparison/statistics.py +402 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +184 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/code/__init__.py +19 -0
- themis/evaluation/metrics/code/codebleu.py +144 -0
- themis/evaluation/metrics/code/execution.py +280 -0
- themis/evaluation/metrics/code/pass_at_k.py +181 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/nlp/__init__.py +21 -0
- themis/evaluation/metrics/nlp/bertscore.py +138 -0
- themis/evaluation/metrics/nlp/bleu.py +129 -0
- themis/evaluation/metrics/nlp/meteor.py +153 -0
- themis/evaluation/metrics/nlp/rouge.py +136 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +348 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +134 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +798 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +415 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +1458 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +72 -0
- themis/integrations/wandb.py +77 -0
- themis/interfaces/__init__.py +169 -0
- themis/presets/__init__.py +10 -0
- themis/presets/benchmarks.py +354 -0
- themis/presets/models.py +190 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/server/__init__.py +28 -0
- themis/server/app.py +337 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- themis_eval-0.2.0.dist-info/METADATA +596 -0
- themis_eval-0.2.0.dist-info/RECORD +157 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/WHEEL +1 -1
- themis_eval-0.1.0.dist-info/METADATA +0 -758
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.2.0.dist-info}/top_level.txt +0 -0
themis/generation/turn_strategies.py
@@ -0,0 +1,393 @@
"""Turn strategies for multi-turn conversations.

This module provides strategies for determining the next turn in a conversation.
Strategies can be fixed (predefined sequences), dynamic (generated based on context),
or interactive.

Examples:
    # Fixed sequence
    strategy = FixedSequenceTurnStrategy([
        "What is 2+2?",
        "What about 3+3?",
        "And 5+5?"
    ])

    # Dynamic strategy
    def planner(context):
        if len(context) < 2:
            return "Can you explain more?"
        return None  # Stop

    strategy = DynamicTurnStrategy(planner)
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Protocol

from themis.core import conversation as conv
from themis.core import entities as core_entities


class TurnStrategy(Protocol):
    """Strategy for determining the next turn in a conversation.

    A turn strategy decides what the user's next message should be
    based on the current conversation state and the last model response.
    """

    def next_turn(
        self,
        context: conv.ConversationContext,
        last_record: core_entities.GenerationRecord,
    ) -> str | None:
        """Determine the next user message.

        Args:
            context: Current conversation context
            last_record: Last generation record

        Returns:
            Next user message, or None to end conversation
        """
        ...


@dataclass
class FixedSequenceTurnStrategy:
    """Pre-determined sequence of user messages.

    This strategy iterates through a fixed list of user messages,
    useful for scripted conversations or testing.

    Examples:
        strategy = FixedSequenceTurnStrategy([
            "Hello!",
            "How are you?",
            "Goodbye!"
        ])
    """

    messages: list[str]
    _index: int = 0

    def next_turn(
        self,
        context: conv.ConversationContext,
        last_record: core_entities.GenerationRecord,
    ) -> str | None:
        """Return next message from sequence.

        Args:
            context: Current conversation context
            last_record: Last generation record

        Returns:
            Next message or None if sequence exhausted
        """
        if self._index >= len(self.messages):
            return None

        message = self.messages[self._index]
        self._index += 1
        return message

    def reset(self) -> None:
        """Reset strategy to beginning of sequence."""
        self._index = 0


@dataclass
class DynamicTurnStrategy:
    """Generate next message based on conversation state.

    This strategy uses a function to dynamically determine the next
    user message based on the conversation context.

    Examples:
        def planner(context, record):
            outputs = [msg.content for msg in context.get_messages_by_role("assistant")]
            if "error" in outputs[-1].lower():
                return "Can you try again?"
            elif len(context) >= 10:
                return None  # Stop after 10 messages
            else:
                return "Please continue."

        strategy = DynamicTurnStrategy(planner)
    """

    planner: Callable[
        [conv.ConversationContext, core_entities.GenerationRecord], str | None
    ]

    def next_turn(
        self,
        context: conv.ConversationContext,
        last_record: core_entities.GenerationRecord,
    ) -> str | None:
        """Generate next message using planner function.

        Args:
            context: Current conversation context
            last_record: Last generation record

        Returns:
            Next message or None to stop
        """
        return self.planner(context, last_record)


@dataclass
class RepeatUntilSuccessTurnStrategy:
    """Repeat the same question until getting a successful response.

    This strategy is useful for testing robustness or debugging.

    Examples:
        strategy = RepeatUntilSuccessTurnStrategy(
            question="What is 2+2?",
            success_checker=lambda output: "4" in output,
            max_attempts=5
        )
    """

    question: str
    success_checker: Callable[[str], bool]
    max_attempts: int = 5
    _attempts: int = 0

    def next_turn(
        self,
        context: conv.ConversationContext,
        last_record: core_entities.GenerationRecord,
    ) -> str | None:
        """Repeat question until success or max attempts.

        Args:
            context: Current conversation context
            last_record: Last generation record

        Returns:
            Question or None if success/max attempts reached
        """
        # Check if this is first turn
        if self._attempts == 0:
            self._attempts += 1
            return self.question

        # Check if last response was successful
        if last_record.output:
            if self.success_checker(last_record.output.text):
                return None  # Success, stop

        # Check if we've exhausted attempts
        if self._attempts >= self.max_attempts:
            return None  # Give up

        self._attempts += 1
        return self.question

    def reset(self) -> None:
        """Reset attempt counter."""
        self._attempts = 0


@dataclass
class ConditionalTurnStrategy:
    """Choose next message based on conditions.

    This strategy evaluates conditions and returns different messages
    based on which condition matches.

    Examples:
        strategy = ConditionalTurnStrategy(
            conditions=[
                (lambda ctx, rec: "error" in rec.output.text.lower(), "Please try again."),
                (lambda ctx, rec: len(ctx) >= 5, None),  # Stop after 5 turns
            ],
            default="Continue."
        )
    """

    conditions: list[
        tuple[
            Callable[[conv.ConversationContext, core_entities.GenerationRecord], bool],
            str | None,
        ]
    ]
    default: str | None = None

    def next_turn(
        self,
        context: conv.ConversationContext,
        last_record: core_entities.GenerationRecord,
    ) -> str | None:
        """Evaluate conditions and return matching message.

        Args:
            context: Current conversation context
            last_record: Last generation record

        Returns:
            Message from first matching condition, or default
        """
        for condition, message in self.conditions:
            try:
                if condition(context, last_record):
                    return message
            except Exception:
                # Skip conditions that fail
                continue

        return self.default


@dataclass
class ChainedTurnStrategy:
    """Chain multiple strategies together.

    This strategy tries strategies in sequence until one returns
    a non-None message.

    Examples:
        strategy = ChainedTurnStrategy([
            FixedSequenceTurnStrategy(["Hello", "How are you?"]),
            DynamicTurnStrategy(lambda ctx, rec: "Goodbye" if len(ctx) > 5 else None)
        ])
    """

    strategies: list[TurnStrategy]

    def next_turn(
        self,
        context: conv.ConversationContext,
        last_record: core_entities.GenerationRecord,
    ) -> str | None:
        """Try each strategy until one returns a message.

        Args:
            context: Current conversation context
            last_record: Last generation record

        Returns:
            First non-None message, or None if all return None
        """
        for strategy in self.strategies:
            message = strategy.next_turn(context, last_record)
            if message is not None:
                return message

        return None


# Helper functions for creating common strategies


def create_qa_strategy(questions: list[str]) -> FixedSequenceTurnStrategy:
    """Create a simple Q&A strategy from a list of questions.

    Args:
        questions: List of questions to ask

    Returns:
        FixedSequenceTurnStrategy with questions
    """
    return FixedSequenceTurnStrategy(messages=questions)


def create_max_turns_strategy(
    max_turns: int, message: str = "Continue."
) -> DynamicTurnStrategy:
    """Create strategy that stops after max turns.

    Args:
        max_turns: Maximum number of turns
        message: Message to send each turn

    Returns:
        DynamicTurnStrategy that stops after max_turns
    """

    def planner(
        context: conv.ConversationContext, record: core_entities.GenerationRecord
    ) -> str | None:
        if len(context) >= max_turns:
            return None
        return message

    return DynamicTurnStrategy(planner=planner)


def create_keyword_stop_strategy(
    keywords: list[str], message: str = "Continue."
) -> DynamicTurnStrategy:
    """Create strategy that stops when any keyword appears in response.

    Args:
        keywords: List of keywords to trigger stop
        message: Message to send each turn

    Returns:
        DynamicTurnStrategy that stops on keywords
    """

    def planner(
        context: conv.ConversationContext, record: core_entities.GenerationRecord
    ) -> str | None:
        if record.output:
            text_lower = record.output.text.lower()
            if any(kw.lower() in text_lower for kw in keywords):
                return None
        return message

    return DynamicTurnStrategy(planner=planner)


# Prompt perturbation and seed helpers for robustness sweeps

import random


def set_sampling_seed(task_metadata: dict[str, object], seed: int) -> dict[str, object]:
    """Attach a deterministic seed to task metadata for providers that support it.

    This does not enforce provider behavior but offers a convention: 'sampling_seed'.
    """
    md = dict(task_metadata)
    md["sampling_seed"] = int(seed)
    return md


def perturb_prompt(text: str, *, seed: int | None = None, max_changes: int = 2) -> str:
    """Apply small, semantics-preserving perturbations to a prompt.

    Changes include optional punctuation tweaks and inserting polite filler words.
    """
    rng = random.Random(seed)
    t = text
    changes = 0
    # Optional punctuation swap
    if "?" in t and changes < max_changes and rng.random() < 0.5:
        t = t.replace("?", "??", 1)
        changes += 1
    # Optional polite filler insertion
    fillers = ["please", "kindly", "if possible"]
    if changes < max_changes and rng.random() < 0.5:
        words = t.split()
        if words:
            idx = rng.randint(0, len(words) - 1)
            words.insert(idx, rng.choice(fillers))
            t = " ".join(words)
            changes += 1
    return t


def create_prompt_variants(base_text: str, *, count: int, seed: int) -> list[str]:
    """Create multiple perturbed variants of a base prompt with deterministic seeding."""
    rng = random.Random(seed)
    return [
        perturb_prompt(base_text, seed=rng.randint(0, 1_000_000))
        for _ in range(max(1, count))
    ]
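Usage sketch (not part of the released diff): the loop below drives FixedSequenceTurnStrategy by hand to illustrate the TurnStrategy contract added above. Passing None for the context and record works only because this particular strategy ignores both arguments; in a real run the conversation runner supplies ConversationContext and GenerationRecord objects. The last line exercises the new perturbation helper, which is deterministic for a fixed seed.

    from themis.generation.turn_strategies import (
        FixedSequenceTurnStrategy,
        create_prompt_variants,
    )

    strategy = FixedSequenceTurnStrategy(messages=["What is 2+2?", "What about 3+3?"])

    turn = strategy.next_turn(None, None)  # placeholders; this strategy ignores both args
    while turn is not None:
        print("user:", turn)
        # ...send `turn` to the model and build the next GenerationRecord here...
        turn = strategy.next_turn(None, None)

    variants = create_prompt_variants("What is 2+2?", count=3, seed=0)  # three perturbed copies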
themis/generation/types.py
@@ -0,0 +1,9 @@
"""Backwards-compatible aliases for core entities."""

from themis.core import entities as core_entities

SamplingParameters = core_entities.SamplingConfig
ModelOutput = core_entities.ModelOutput
GenerationError = core_entities.ModelError
GenerationRequest = core_entities.GenerationTask
GenerationResult = core_entities.GenerationRecord
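The nine-line hunk above matches themis/generation/types.py in the file list; the aliases let code written against the 0.1.0 names keep working. A minimal sketch, assuming the 0.2.0 wheel is installed:

    from themis.core import entities as core_entities
    from themis.generation import types as legacy_types

    # Old and new names resolve to the same classes after the aliasing above.
    assert legacy_types.GenerationResult is core_entities.GenerationRecord
    assert legacy_types.SamplingParameters is core_entities.SamplingConfig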
themis/integrations/__init__.py: File without changes
themis/integrations/huggingface.py
@@ -0,0 +1,72 @@
from __future__ import annotations

import json
from dataclasses import asdict, is_dataclass
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from huggingface_hub import HfApi
else:
    try:
        from huggingface_hub import HfApi
    except ImportError:
        HfApi = None  # type: ignore

from themis.config.schema import HuggingFaceHubConfig
from themis.core.entities import ExperimentReport


def to_dict(obj):
    if is_dataclass(obj):
        return asdict(obj)
    if hasattr(obj, "to_dict"):
        return obj.to_dict()
    if isinstance(obj, (list, tuple)):
        return [to_dict(item) for item in obj]
    if isinstance(obj, dict):
        return {key: to_dict(value) for key, value in obj.items()}
    return obj


class HuggingFaceHubUploader:
    def __init__(self, config: HuggingFaceHubConfig):
        if HfApi is None:
            raise ImportError(
                "huggingface_hub is not installed. Install with: pip install huggingface_hub"
            )
        self.config = config
        self.api = HfApi()

    def upload_results(self, report: ExperimentReport, storage_path: Path) -> None:
        if not self.config.enable or not self.config.repository:
            return

        report_dict = to_dict(report)

        # Upload the full report as a JSON file
        report_path = storage_path / "report.json"
        with open(report_path, "w") as f:
            json.dump(report_dict, f, indent=4)

        self.api.upload_file(
            path_or_fileobj=str(report_path),
            path_in_repo=f"{report.metadata.get('run_id')}/report.json",
            repo_id=self.config.repository,
            repo_type="dataset",
        )

        # Upload individual generation results
        for record in report.generation_results:
            record_dict = to_dict(record)
            record_path = (
                storage_path / f"{record.task.metadata.get('dataset_id')}.json"
            )
            with open(record_path, "w") as f:
                json.dump(record_dict, f, indent=4)
            self.api.upload_file(
                path_or_fileobj=str(record_path),
                path_in_repo=f"{report.metadata.get('run_id')}/generations/{record.task.metadata.get('dataset_id')}.json",
                repo_id=self.config.repository,
                repo_type="dataset",
            )
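Usage sketch for the uploader above: upload_results reads the `enable` and `repository` fields and pushes JSON files to a dataset repo under the run id. Constructing HuggingFaceHubConfig with keyword arguments and the repository id shown here are assumptions for illustration, and the upload call is left commented because it needs a real ExperimentReport and Hub credentials.

    from pathlib import Path

    from themis.config.schema import HuggingFaceHubConfig
    from themis.integrations.huggingface import HuggingFaceHubUploader

    # Hypothetical config; requires `pip install huggingface_hub` and a Hub token.
    config = HuggingFaceHubConfig(enable=True, repository="my-org/themis-runs")
    uploader = HuggingFaceHubUploader(config)

    # `report` would be the ExperimentReport produced by a themis run:
    # uploader.upload_results(report, storage_path=Path("./results"))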
themis/integrations/wandb.py
@@ -0,0 +1,77 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import wandb
else:
    try:
        import wandb
    except ImportError:
        wandb = None  # type: ignore

from themis.config.schema import WandbConfig
from themis.core.entities import ExperimentReport


class WandbTracker:
    def __init__(self, config: WandbConfig):
        if wandb is None:
            raise ImportError(
                "wandb is not installed. Install with: pip install wandb"
            )
        self.config = config

    def init(self, experiment_config: dict) -> None:
        if not self.config.enable:
            return
        wandb.init(
            project=self.config.project,
            entity=self.config.entity,
            tags=self.config.tags,
            config=experiment_config,
        )

    def log_results(self, report: ExperimentReport) -> None:
        if not self.config.enable:
            return
        summary = {
            "total_samples": report.metadata.get("total_samples"),
            "successful_generations": report.metadata.get("successful_generations"),
            "failed_generations": report.metadata.get("failed_generations"),
            "evaluation_failures": report.metadata.get("evaluation_failures"),
        }
        for name, aggregate in report.evaluation_report.metrics.items():
            summary[f"{name}_mean"] = aggregate.mean
        wandb.summary.update(summary)

        records_table = wandb.Table(
            columns=[
                "sample_id",
                "prompt",
                "raw_response",
                "parsed_response",
                "error",
                "metric_scores",
            ]
        )
        for record in report.generation_results:
            eval_record = next(
                (
                    r
                    for r in report.evaluation_report.records
                    if r.sample_id == record.task.metadata.get("dataset_id")
                ),
                None,
            )
            records_table.add_data(
                record.task.metadata.get("dataset_id"),
                record.task.prompt,
                [resp.text for resp in record.responses],
                eval_record.parsed_response if eval_record else None,
                record.error.message if record.error else None,
                {s.metric_name: s.value for s in eval_record.scores}
                if eval_record
                else None,
            )
        wandb.log({"generation_results": records_table})
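Usage sketch for WandbTracker: the tracker reads enable, project, entity, and tags from WandbConfig, calls wandb.init with the experiment config, and later logs a per-metric summary plus a results table. Passing the config fields as keyword arguments here is an assumption about the config schema; the values are placeholders.

    from themis.config.schema import WandbConfig
    from themis.integrations.wandb import WandbTracker

    # Hypothetical config values; requires `pip install wandb` and a logged-in session.
    tracker = WandbTracker(WandbConfig(enable=True, project="themis-eval", entity=None, tags=["demo"]))
    tracker.init({"model": "example-model", "dataset": "gsm8k"})

    # After the run completes:
    # tracker.log_results(report)  # report: ExperimentReport with evaluation results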