opik-optimizer 0.7.8__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opik_optimizer/__init__.py +2 -0
- opik_optimizer/base_optimizer.py +6 -4
- opik_optimizer/data/hotpot-500.json +501 -1001
- opik_optimizer/datasets/__init__.py +27 -0
- opik_optimizer/datasets/ai2_arc.py +44 -0
- opik_optimizer/datasets/cnn_dailymail.py +40 -0
- opik_optimizer/datasets/election_questions.py +36 -0
- opik_optimizer/datasets/gsm8k.py +40 -0
- opik_optimizer/datasets/halu_eval.py +43 -0
- opik_optimizer/datasets/hotpot_qa.py +68 -0
- opik_optimizer/datasets/medhallu.py +39 -0
- opik_optimizer/datasets/rag_hallucinations.py +41 -0
- opik_optimizer/datasets/ragbench.py +40 -0
- opik_optimizer/datasets/tiny_test.py +57 -0
- opik_optimizer/datasets/truthful_qa.py +107 -0
- opik_optimizer/demo/datasets.py +53 -607
- opik_optimizer/evolutionary_optimizer/evolutionary_optimizer.py +3 -1
- opik_optimizer/few_shot_bayesian_optimizer/few_shot_bayesian_optimizer.py +90 -19
- opik_optimizer/logging_config.py +1 -1
- opik_optimizer/meta_prompt_optimizer.py +60 -14
- opik_optimizer/mipro_optimizer/mipro_optimizer.py +151 -13
- opik_optimizer/optimization_result.py +11 -0
- opik_optimizer/task_evaluator.py +6 -1
- opik_optimizer/utils.py +0 -52
- opik_optimizer-0.8.1.dist-info/METADATA +196 -0
- opik_optimizer-0.8.1.dist-info/RECORD +45 -0
- opik_optimizer-0.7.8.dist-info/METADATA +0 -174
- opik_optimizer-0.7.8.dist-info/RECORD +0 -33
- {opik_optimizer-0.7.8.dist-info → opik_optimizer-0.8.1.dist-info}/WHEEL +0 -0
- {opik_optimizer-0.7.8.dist-info → opik_optimizer-0.8.1.dist-info}/licenses/LICENSE +0 -0
- {opik_optimizer-0.7.8.dist-info → opik_optimizer-0.8.1.dist-info}/top_level.txt +0 -0
opik_optimizer/demo/datasets.py
CHANGED
@@ -5,6 +5,21 @@ from datasets import load_dataset
|
|
5
5
|
import traceback
|
6
6
|
from importlib.resources import files
|
7
7
|
import json
|
8
|
+
import warnings
|
9
|
+
from ..datasets import (
|
10
|
+
hotpot_300,
|
11
|
+
hotpot_500,
|
12
|
+
halu_eval_300,
|
13
|
+
tiny_test,
|
14
|
+
gsm8k,
|
15
|
+
ai2_arc,
|
16
|
+
truthful_qa,
|
17
|
+
cnn_dailymail,
|
18
|
+
ragbench_sentence_relevance,
|
19
|
+
election_questions,
|
20
|
+
medhallu,
|
21
|
+
rag_hallucinations,
|
22
|
+
)
|
8
23
|
|
9
24
|
class HaltError(Exception):
|
10
25
|
"""Exception raised when we need to halt the process due to a critical error."""
|
@@ -29,611 +44,42 @@ def get_or_create_dataset(
|
|
29
44
|
"rag_hallucinations",
|
30
45
|
],
|
31
46
|
test_mode: bool = False,
|
47
|
+
seed: int = 42,
|
32
48
|
) -> opik.Dataset:
|
33
|
-
"""Get or create a dataset from HuggingFace."""
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
data = _load_cnn_dailymail(test_mode)
|
71
|
-
elif name == "ragbench_sentence_relevance":
|
72
|
-
data = _load_ragbench_sentence_relevance(test_mode)
|
73
|
-
elif name == "election_questions":
|
74
|
-
data = _load_election_questions(test_mode)
|
75
|
-
elif name == "medhallu":
|
76
|
-
data = _load_medhallu(test_mode)
|
77
|
-
elif name == "rag_hallucinations":
|
78
|
-
data = _load_rag_hallucinations(test_mode)
|
79
|
-
elif name == "math-50":
|
80
|
-
data = _load_math_50()
|
81
|
-
else:
|
82
|
-
raise HaltError(f"Unknown dataset: {name}")
|
83
|
-
|
84
|
-
if not data:
|
85
|
-
raise HaltError(f"No data loaded for dataset: {name}")
|
86
|
-
|
87
|
-
# Create dataset in Opik
|
88
|
-
try:
|
89
|
-
dataset = opik_client.create_dataset(dataset_name) # Use dataset_name with test mode suffix
|
90
|
-
except opik.rest_api.core.api_error.ApiError as e:
|
91
|
-
if e.status_code == 409: # Dataset already exists
|
92
|
-
# Try to get the dataset again
|
93
|
-
dataset = opik_client.get_dataset(dataset_name)
|
94
|
-
if not dataset:
|
95
|
-
raise HaltError(f"Dataset {dataset_name} exists but is empty")
|
96
|
-
return dataset
|
97
|
-
raise HaltError(f"Failed to create dataset {dataset_name}: {e}")
|
98
|
-
|
99
|
-
# Insert data into the dataset
|
100
|
-
try:
|
101
|
-
dataset.insert(data)
|
102
|
-
except Exception as e:
|
103
|
-
raise HaltError(f"Failed to insert data into dataset {dataset_name}: {e}")
|
104
|
-
|
105
|
-
# Verify data was added
|
106
|
-
items = dataset.get_items()
|
107
|
-
if not items or len(items) == 0:
|
108
|
-
raise HaltError(f"Failed to add data to dataset {dataset_name}")
|
109
|
-
|
110
|
-
return dataset
|
111
|
-
except HaltError:
|
112
|
-
raise # Re-raise HaltError to stop the process
|
113
|
-
except Exception as e:
|
114
|
-
print(f"Error loading dataset {name}: {e}")
|
115
|
-
print(traceback.format_exc())
|
116
|
-
raise HaltError(f"Critical error loading dataset {name}: {e}")
|
117
|
-
|
118
|
-
|
119
|
-
def _load_hotpot_500(test_mode: bool = False) -> List[Dict[str, Any]]:
|
120
|
-
size = 500 if not test_mode else 5
|
121
|
-
|
122
|
-
# This is not a random dataset
|
123
|
-
|
124
|
-
json_content = (files('opik_optimizer') / 'data' / 'hotpot-500.json').read_text(encoding='utf-8')
|
125
|
-
all_data = json.loads(json_content)
|
126
|
-
trainset = all_data[:size]
|
127
|
-
|
128
|
-
data = []
|
129
|
-
for row in reversed(trainset):
|
130
|
-
data.append(row)
|
131
|
-
return data
|
132
|
-
|
133
|
-
|
134
|
-
def _load_hotpot_300(test_mode: bool = False) -> List[Dict[str, Any]]:
|
135
|
-
size = 300 if not test_mode else 3
|
136
|
-
|
137
|
-
# This is not a random dataset
|
138
|
-
|
139
|
-
json_content = (files('opik_optimizer') / 'data' / 'hotpot-500.json').read_text(encoding='utf-8')
|
140
|
-
all_data = json.loads(json_content)
|
141
|
-
trainset = all_data[:size]
|
142
|
-
|
143
|
-
data = []
|
144
|
-
for row in reversed(trainset):
|
145
|
-
data.append(row)
|
146
|
-
return data
|
147
|
-
|
148
|
-
|
149
|
-
def _load_halu_eval_300(test_mode: bool = False) -> List[Dict[str, Any]]:
|
150
|
-
import pandas as pd
|
151
|
-
|
152
|
-
try:
|
153
|
-
df = pd.read_parquet(
|
154
|
-
"hf://datasets/pminervini/HaluEval/general/data-00000-of-00001.parquet"
|
155
|
-
)
|
156
|
-
except Exception:
|
157
|
-
raise Exception("Unable to download HaluEval; please try again") from None
|
158
|
-
|
159
|
-
df = df.sample(n=300, random_state=42)
|
160
|
-
|
161
|
-
dataset_records = [
|
162
|
-
{
|
163
|
-
"input": x["user_query"],
|
164
|
-
"llm_output": x["chatgpt_response"],
|
165
|
-
"expected_hallucination_label": x["hallucination"],
|
166
|
-
}
|
167
|
-
for x in df.to_dict(orient="records")
|
168
|
-
]
|
169
|
-
|
170
|
-
return dataset_records
|
171
|
-
|
172
|
-
|
173
|
-
def _load_tiny_test() -> List[Dict[str, Any]]:
|
174
|
-
return [
|
175
|
-
{
|
176
|
-
"text": "What is the capital of France?",
|
177
|
-
"label": "Paris",
|
178
|
-
"metadata": {
|
179
|
-
"context": "France is a country in Europe. Its capital is Paris."
|
180
|
-
},
|
181
|
-
},
|
182
|
-
{
|
183
|
-
"text": "Who wrote Romeo and Juliet?",
|
184
|
-
"label": "William Shakespeare",
|
185
|
-
"metadata": {
|
186
|
-
"context": "Romeo and Juliet is a famous play written by William Shakespeare."
|
187
|
-
},
|
188
|
-
},
|
189
|
-
{
|
190
|
-
"text": "What is 2 + 2?",
|
191
|
-
"label": "4",
|
192
|
-
"metadata": {"context": "Basic arithmetic: 2 + 2 equals 4."},
|
193
|
-
},
|
194
|
-
{
|
195
|
-
"text": "What is the largest planet in our solar system?",
|
196
|
-
"label": "Jupiter",
|
197
|
-
"metadata": {
|
198
|
-
"context": "Jupiter is the largest planet in our solar system."
|
199
|
-
},
|
200
|
-
},
|
201
|
-
{
|
202
|
-
"text": "Who painted the Mona Lisa?",
|
203
|
-
"label": "Leonardo da Vinci",
|
204
|
-
"metadata": {"context": "The Mona Lisa was painted by Leonardo da Vinci."},
|
205
|
-
},
|
206
|
-
]
|
207
|
-
|
208
|
-
|
209
|
-
def _load_gsm8k(test_mode: bool = False) -> List[Dict[str, Any]]:
|
210
|
-
"""Load GSM8K dataset with 300 examples."""
|
211
|
-
try:
|
212
|
-
# Use streaming to avoid downloading the entire dataset
|
213
|
-
dataset = load_dataset("gsm8k", "main", streaming=True)
|
214
|
-
n_samples = 5 if test_mode else 300
|
215
|
-
|
216
|
-
# Convert streaming dataset to list
|
217
|
-
data = []
|
218
|
-
for i, item in enumerate(dataset["train"]):
|
219
|
-
if i >= n_samples:
|
220
|
-
break
|
221
|
-
data.append({
|
222
|
-
"question": item["question"],
|
223
|
-
"answer": item["answer"],
|
224
|
-
})
|
225
|
-
return data
|
226
|
-
except Exception as e:
|
227
|
-
print(f"Error loading GSM8K dataset: {e}")
|
228
|
-
raise Exception("Unable to download gsm8k; please try again") from None
|
229
|
-
|
230
|
-
|
231
|
-
def _load_hotpot_qa(test_mode: bool = False) -> List[Dict[str, Any]]:
|
232
|
-
"""Load HotpotQA dataset with 300 examples."""
|
233
|
-
try:
|
234
|
-
# Use streaming to avoid downloading the entire dataset
|
235
|
-
dataset = load_dataset("hotpot_qa", "distractor", streaming=True)
|
236
|
-
n_samples = 5 if test_mode else 300
|
237
|
-
|
238
|
-
# Convert streaming dataset to list
|
239
|
-
data = []
|
240
|
-
for i, item in enumerate(dataset["train"]):
|
241
|
-
if i >= n_samples:
|
242
|
-
break
|
243
|
-
data.append({
|
244
|
-
"question": item["question"],
|
245
|
-
"answer": item["answer"],
|
246
|
-
"context": item["context"],
|
247
|
-
})
|
248
|
-
return data
|
249
|
-
except Exception as e:
|
250
|
-
print(f"Error loading HotpotQA dataset: {e}")
|
251
|
-
raise Exception("Unable to download HotPotQA; please try again") from None
|
252
|
-
|
253
|
-
|
254
|
-
def _load_ai2_arc(test_mode: bool = False) -> List[Dict[str, Any]]:
|
255
|
-
"""Load AI2 ARC dataset with 300 examples."""
|
256
|
-
try:
|
257
|
-
# Use streaming to avoid downloading the entire dataset
|
258
|
-
dataset = load_dataset("ai2_arc", "ARC-Challenge", streaming=True)
|
259
|
-
n_samples = 5 if test_mode else 300
|
260
|
-
|
261
|
-
# Convert streaming dataset to list
|
262
|
-
data = []
|
263
|
-
for i, item in enumerate(dataset["train"]):
|
264
|
-
if i >= n_samples:
|
265
|
-
break
|
266
|
-
data.append({
|
267
|
-
"question": item["question"],
|
268
|
-
"answer": item["answerKey"],
|
269
|
-
"choices": item["choices"],
|
270
|
-
})
|
271
|
-
return data
|
272
|
-
except Exception as e:
|
273
|
-
print(f"Error loading AI2 ARC dataset: {e}")
|
274
|
-
raise Exception("Unable to download ai2_arc; please try again") from None
|
275
|
-
|
276
|
-
|
277
|
-
def _load_truthful_qa(test_mode: bool = False) -> List[Dict]:
|
278
|
-
"""Load TruthfulQA dataset."""
|
279
|
-
try:
|
280
|
-
# Load both configurations
|
281
|
-
try:
|
282
|
-
gen_dataset = load_dataset("truthful_qa", "generation")
|
283
|
-
mc_dataset = load_dataset("truthful_qa", "multiple_choice")
|
284
|
-
except Exception:
|
285
|
-
raise Exception(
|
286
|
-
"Unable to download truthful_qa; please try again"
|
287
|
-
) from None
|
288
|
-
|
289
|
-
# Combine data from both configurations
|
290
|
-
data = []
|
291
|
-
n_samples = 5 if test_mode else 300
|
292
|
-
for gen_item, mc_item in zip(
|
293
|
-
gen_dataset["validation"], mc_dataset["validation"]
|
294
|
-
):
|
295
|
-
if len(data) >= n_samples:
|
296
|
-
break
|
297
|
-
|
298
|
-
# Get correct answers from both configurations
|
299
|
-
correct_answers = set(gen_item["correct_answers"])
|
300
|
-
if "mc1_targets" in mc_item:
|
301
|
-
correct_answers.update(
|
302
|
-
[
|
303
|
-
choice
|
304
|
-
for choice, label in zip(
|
305
|
-
mc_item["mc1_targets"]["choices"],
|
306
|
-
mc_item["mc1_targets"]["labels"],
|
307
|
-
)
|
308
|
-
if label == 1
|
309
|
-
]
|
310
|
-
)
|
311
|
-
if "mc2_targets" in mc_item:
|
312
|
-
correct_answers.update(
|
313
|
-
[
|
314
|
-
choice
|
315
|
-
for choice, label in zip(
|
316
|
-
mc_item["mc2_targets"]["choices"],
|
317
|
-
mc_item["mc2_targets"]["labels"],
|
318
|
-
)
|
319
|
-
if label == 1
|
320
|
-
]
|
321
|
-
)
|
322
|
-
|
323
|
-
# Get all possible answers
|
324
|
-
all_answers = set(
|
325
|
-
gen_item["correct_answers"] + gen_item["incorrect_answers"]
|
326
|
-
)
|
327
|
-
if "mc1_targets" in mc_item:
|
328
|
-
all_answers.update(mc_item["mc1_targets"]["choices"])
|
329
|
-
if "mc2_targets" in mc_item:
|
330
|
-
all_answers.update(mc_item["mc2_targets"]["choices"])
|
331
|
-
|
332
|
-
# Create a single example with all necessary fields
|
333
|
-
example = {
|
334
|
-
"question": gen_item["question"],
|
335
|
-
"answer": gen_item["best_answer"],
|
336
|
-
"choices": list(all_answers),
|
337
|
-
"correct_answer": gen_item["best_answer"],
|
338
|
-
"input": gen_item["question"], # For AnswerRelevance metric
|
339
|
-
"output": gen_item["best_answer"], # For output_key requirement
|
340
|
-
"context": gen_item.get("source", ""), # Use source as context
|
341
|
-
"type": "TEXT", # Set type to TEXT as required by Opik
|
342
|
-
"category": gen_item["category"],
|
343
|
-
"source": "MANUAL", # Set source to MANUAL as required by Opik
|
344
|
-
"correct_answers": list(
|
345
|
-
correct_answers
|
346
|
-
), # Keep track of all correct answers
|
347
|
-
"incorrect_answers": gen_item[
|
348
|
-
"incorrect_answers"
|
349
|
-
], # Keep track of incorrect answers
|
350
|
-
}
|
351
|
-
|
352
|
-
# Ensure all required fields are present
|
353
|
-
required_fields = [
|
354
|
-
"question",
|
355
|
-
"answer",
|
356
|
-
"choices",
|
357
|
-
"correct_answer",
|
358
|
-
"input",
|
359
|
-
"output",
|
360
|
-
"context",
|
361
|
-
]
|
362
|
-
if all(field in example and example[field] for field in required_fields):
|
363
|
-
data.append(example)
|
364
|
-
|
365
|
-
if not data:
|
366
|
-
raise ValueError("No valid examples found in TruthfulQA dataset")
|
367
|
-
|
368
|
-
return data
|
369
|
-
except Exception as e:
|
370
|
-
print(f"Error loading TruthfulQA dataset: {e}")
|
371
|
-
print(traceback.format_exc())
|
372
|
-
raise
|
373
|
-
|
374
|
-
|
375
|
-
def _load_cnn_dailymail(test_mode: bool = False) -> List[Dict]:
|
376
|
-
"""Load CNN Daily Mail dataset with 100 examples."""
|
377
|
-
try:
|
378
|
-
dataset = load_dataset("cnn_dailymail", "3.0.0", streaming=True)
|
379
|
-
n_samples = 5 if test_mode else 100
|
380
|
-
|
381
|
-
# Convert streaming dataset to list
|
382
|
-
data = []
|
383
|
-
for i, item in enumerate(dataset["validation"]):
|
384
|
-
if i >= n_samples:
|
385
|
-
break
|
386
|
-
data.append({
|
387
|
-
"article": item["article"],
|
388
|
-
"highlights": item["highlights"],
|
389
|
-
})
|
390
|
-
return data
|
391
|
-
except Exception as e:
|
392
|
-
print(f"Error loading CNN Daily Mail dataset: {e}")
|
393
|
-
raise Exception("Unable to download cnn_dailymail; please try again") from None
|
394
|
-
|
395
|
-
|
396
|
-
def _load_math_50():
|
397
|
-
return [
|
398
|
-
{"question": "What is (5 + 3) * 2 - 4?", "expected answer": "12"},
|
399
|
-
{
|
400
|
-
"question": "If you divide 20 by 4 and then add 7, what do you get?",
|
401
|
-
"expected answer": "12",
|
402
|
-
},
|
403
|
-
{
|
404
|
-
"question": "Start with 10, subtract 2, multiply the result by 3, then add 5.",
|
405
|
-
"expected answer": "29",
|
406
|
-
},
|
407
|
-
{
|
408
|
-
"question": "Add 6 and 4, then divide by 2, and finally multiply by 5.",
|
409
|
-
"expected answer": "25",
|
410
|
-
},
|
411
|
-
{
|
412
|
-
"question": "Take 15, subtract 3, add 2, then divide the result by 2.",
|
413
|
-
"expected answer": "7",
|
414
|
-
},
|
415
|
-
{"question": "What is 7 * (6 - 2) + 1?", "expected answer": "29"},
|
416
|
-
{
|
417
|
-
"question": "If you multiply 8 by 3 and subtract 5, what is the result?",
|
418
|
-
"expected answer": "19",
|
419
|
-
},
|
420
|
-
{
|
421
|
-
"question": "Begin with 25, divide by 5, then multiply by 4.",
|
422
|
-
"expected answer": "20",
|
423
|
-
},
|
424
|
-
{
|
425
|
-
"question": "Subtract 9 from 17, then multiply the difference by 3.",
|
426
|
-
"expected answer": "24",
|
427
|
-
},
|
428
|
-
{"question": "What is 10 + 5 * 3 - 8?", "expected answer": "17"},
|
429
|
-
{"question": "Divide 36 by 6, then add 11.", "expected answer": "17"},
|
430
|
-
{
|
431
|
-
"question": "Start with 2, multiply by 9, subtract 7, and add 4.",
|
432
|
-
"expected answer": "15",
|
433
|
-
},
|
434
|
-
{
|
435
|
-
"question": "Add 12 and 8, divide by 4, and then subtract 1.",
|
436
|
-
"expected answer": "4",
|
437
|
-
},
|
438
|
-
{
|
439
|
-
"question": "Take 30, subtract 10, divide by 2, and add 7.",
|
440
|
-
"expected answer": "17",
|
441
|
-
},
|
442
|
-
{"question": "What is (15 - 5) / 2 * 3?", "expected answer": "15"},
|
443
|
-
{
|
444
|
-
"question": "If you add 14 and 6, and then divide by 5, what do you get?",
|
445
|
-
"expected answer": "4",
|
446
|
-
},
|
447
|
-
{
|
448
|
-
"question": "Start with 50, divide by 10, multiply by 2, and subtract 3.",
|
449
|
-
"expected answer": "7",
|
450
|
-
},
|
451
|
-
{
|
452
|
-
"question": "Subtract 4 from 11, multiply by 5, and then add 2.",
|
453
|
-
"expected answer": "37",
|
454
|
-
},
|
455
|
-
{"question": "What is 9 * 4 - 12 / 3?", "expected answer": "32"},
|
456
|
-
{
|
457
|
-
"question": "Divide 42 by 7, and then multiply by 3.",
|
458
|
-
"expected answer": "18",
|
459
|
-
},
|
460
|
-
{
|
461
|
-
"question": "Begin with 1, add 19, divide by 4, and multiply by 6.",
|
462
|
-
"expected answer": "30",
|
463
|
-
},
|
464
|
-
{
|
465
|
-
"question": "Subtract 6 from 21, then divide the result by 5.",
|
466
|
-
"expected answer": "3",
|
467
|
-
},
|
468
|
-
{"question": "What is (8 + 7) * 2 - 9?", "expected answer": "21"},
|
469
|
-
{
|
470
|
-
"question": "If you multiply 7 by 5 and then subtract 11, what is the answer?",
|
471
|
-
"expected answer": "24",
|
472
|
-
},
|
473
|
-
{
|
474
|
-
"question": "Start with 3, multiply by 8, add 6, and then divide by 2.",
|
475
|
-
"expected answer": "15",
|
476
|
-
},
|
477
|
-
{"question": "What is 3 * (10 - 4) + 5?", "expected answer": "23"},
|
478
|
-
{
|
479
|
-
"question": "If you multiply 12 by 2 and subtract 7, what is the result?",
|
480
|
-
"expected answer": "17",
|
481
|
-
},
|
482
|
-
{
|
483
|
-
"question": "Begin with 35, divide by 7, then multiply by 6.",
|
484
|
-
"expected answer": "30",
|
485
|
-
},
|
486
|
-
{
|
487
|
-
"question": "Subtract 11 from 20, then multiply the difference by 4.",
|
488
|
-
"expected answer": "36",
|
489
|
-
},
|
490
|
-
{"question": "What is 15 + 3 * 7 - 9?", "expected answer": "27"},
|
491
|
-
{"question": "Divide 63 by 9, then add 13.", "expected answer": "20"},
|
492
|
-
{
|
493
|
-
"question": "Start with 6, multiply by 5, subtract 8, and add 11.",
|
494
|
-
"expected answer": "33",
|
495
|
-
},
|
496
|
-
{
|
497
|
-
"question": "Add 18 and 6, divide by 3, and then subtract 4.",
|
498
|
-
"expected answer": "4",
|
499
|
-
},
|
500
|
-
{
|
501
|
-
"question": "Take 50, subtract 20, divide by 5, and add 9.",
|
502
|
-
"expected answer": "15",
|
503
|
-
},
|
504
|
-
{"question": "What is (25 - 10) / 3 * 4?", "expected answer": "20"},
|
505
|
-
{
|
506
|
-
"question": "If you add 9 and 15, and then divide by 8, what do you get?",
|
507
|
-
"expected answer": "3",
|
508
|
-
},
|
509
|
-
{
|
510
|
-
"question": "Start with 40, divide by 5, multiply by 3, and subtract 7.",
|
511
|
-
"expected answer": "17",
|
512
|
-
},
|
513
|
-
{
|
514
|
-
"question": "Subtract 5 from 22, multiply by 2, and then divide by 6.",
|
515
|
-
"expected answer": "5.666666666666667",
|
516
|
-
},
|
517
|
-
{"question": "What is 7 * 6 + 8 - 11?", "expected answer": "39"},
|
518
|
-
{
|
519
|
-
"question": "Divide 72 by 8, and then multiply by 5.",
|
520
|
-
"expected answer": "45",
|
521
|
-
},
|
522
|
-
{
|
523
|
-
"question": "Begin with 3, add 17, divide by 5, and multiply by 7.",
|
524
|
-
"expected answer": "28",
|
525
|
-
},
|
526
|
-
{
|
527
|
-
"question": "Subtract 9 from 31, then divide the result by 4.",
|
528
|
-
"expected answer": "5.5",
|
529
|
-
},
|
530
|
-
{"question": "What is (11 + 9) * 3 - 15?", "expected answer": "45"},
|
531
|
-
{
|
532
|
-
"question": "If you multiply 8 by 7 and then subtract 19, what is the answer?",
|
533
|
-
"expected answer": "37",
|
534
|
-
},
|
535
|
-
{
|
536
|
-
"question": "Start with 2, multiply by 12, add 16, and then divide by 4.",
|
537
|
-
"expected answer": "10",
|
538
|
-
},
|
539
|
-
{
|
540
|
-
"question": "Add 13 and 19, then subtract 6, and finally divide by 2.",
|
541
|
-
"expected answer": "13",
|
542
|
-
},
|
543
|
-
{
|
544
|
-
"question": "Take 45, divide by 9, add 11, and then subtract 3.",
|
545
|
-
"expected answer": "13",
|
546
|
-
},
|
547
|
-
{"question": "What is 18 - 4 * 3 + 7?", "expected answer": "13"},
|
548
|
-
{
|
549
|
-
"question": "If you divide 56 by 7 and then add 9, what do you get?",
|
550
|
-
"expected answer": "17",
|
551
|
-
},
|
552
|
-
{
|
553
|
-
"question": "Begin with 4, multiply by 9, subtract 12, and then divide by 6.",
|
554
|
-
"expected answer": "4",
|
555
|
-
},
|
556
|
-
]
|
557
|
-
|
558
|
-
|
559
|
-
def _load_ragbench_sentence_relevance(test_mode: bool = False) -> List[Dict]:
|
560
|
-
"""Load RAGBench sentence relevance dataset."""
|
561
|
-
try:
|
562
|
-
dataset = load_dataset("wandb/ragbench-sentence-relevance-balanced")
|
563
|
-
except Exception:
|
564
|
-
raise Exception("Unable to download ragbench-sentence-relevance; please try again") from None
|
565
|
-
|
566
|
-
n_samples = 5 if test_mode else 300
|
567
|
-
train_data = dataset["train"].select(range(n_samples))
|
568
|
-
|
569
|
-
return [
|
570
|
-
{
|
571
|
-
"question": item["question"],
|
572
|
-
"sentence": item["sentence"],
|
573
|
-
"label": item["label"],
|
574
|
-
}
|
575
|
-
for item in train_data
|
576
|
-
]
|
577
|
-
|
578
|
-
|
579
|
-
def _load_election_questions(test_mode: bool = False) -> List[Dict]:
|
580
|
-
"""Load Anthropic election questions dataset."""
|
581
|
-
try:
|
582
|
-
dataset = load_dataset("Anthropic/election_questions")
|
583
|
-
except Exception:
|
584
|
-
raise Exception("Unable to download election_questions; please try again") from None
|
585
|
-
|
586
|
-
n_samples = 5 if test_mode else 300
|
587
|
-
train_data = dataset["test"].select(range(n_samples))
|
588
|
-
|
589
|
-
return [
|
590
|
-
{
|
591
|
-
"question": item["question"],
|
592
|
-
"label": item["label"], # "Harmless" or "Harmful"
|
593
|
-
}
|
594
|
-
for item in train_data
|
595
|
-
]
|
596
|
-
|
597
|
-
|
598
|
-
def _load_medhallu(test_mode: bool = False) -> List[Dict]:
|
599
|
-
"""Load MedHallu medical hallucinations dataset."""
|
600
|
-
try:
|
601
|
-
dataset = load_dataset("UTAustin-AIHealth/MedHallu", "pqa_labeled")
|
602
|
-
except Exception:
|
603
|
-
raise Exception("Unable to download medhallu; please try again") from None
|
604
|
-
|
605
|
-
n_samples = 5 if test_mode else 300
|
606
|
-
train_data = dataset["train"].select(range(n_samples))
|
607
|
-
|
608
|
-
return [
|
609
|
-
{
|
610
|
-
"question": item["Question"],
|
611
|
-
"knowledge": item["Knowledge"],
|
612
|
-
"ground_truth": item["Ground Truth"],
|
613
|
-
"hallucinated_answer": item["Hallucinated Answer"],
|
614
|
-
"difficulty_level": item["Difficulty Level"],
|
615
|
-
"hallucination_category": item["Category of Hallucination"],
|
616
|
-
}
|
617
|
-
for item in train_data
|
618
|
-
]
|
619
|
-
|
620
|
-
|
621
|
-
def _load_rag_hallucinations(test_mode: bool = False) -> List[Dict]:
|
622
|
-
"""Load Aporia RAG hallucinations dataset."""
|
623
|
-
try:
|
624
|
-
dataset = load_dataset("aporia-ai/rag_hallucinations")
|
625
|
-
except Exception:
|
626
|
-
raise Exception("Unable to download rag_hallucinations; please try again") from None
|
627
|
-
|
628
|
-
n_samples = 5 if test_mode else 300
|
629
|
-
train_data = dataset["train"].select(range(n_samples))
|
630
|
-
|
631
|
-
return [
|
632
|
-
{
|
633
|
-
"context": item["context"],
|
634
|
-
"question": item["question"],
|
635
|
-
"answer": item["answer"],
|
636
|
-
"is_hallucination": item["is_hallucination"],
|
637
|
-
}
|
638
|
-
for item in train_data
|
639
|
-
]
|
49
|
+
"""Get or create a dataset from HuggingFace, using the provided seed for sampling."""
|
50
|
+
warnings.warn(
|
51
|
+
"This function is deprecated. Please use the datasets directly from opik_optimizer.datasets module instead."
|
52
|
+
" For example: opik_optimizer.datasets.truthful_qa() or opik_optimizer.datasets.rag_hallucination()",
|
53
|
+
DeprecationWarning,
|
54
|
+
stacklevel=2
|
55
|
+
)
|
56
|
+
if name == "hotpot-300":
|
57
|
+
dataset = hotpot_300(test_mode)
|
58
|
+
elif name == "hotpot-500":
|
59
|
+
dataset = hotpot_500(test_mode)
|
60
|
+
elif name == "halu-eval-300":
|
61
|
+
dataset = halu_eval_300(test_mode)
|
62
|
+
elif name == "tiny-test":
|
63
|
+
dataset = tiny_test()
|
64
|
+
elif name == "gsm8k":
|
65
|
+
dataset = gsm8k(test_mode)
|
66
|
+
elif name == "hotpot_qa":
|
67
|
+
raise HaltError("HotpotQA dataset is no longer available in the demo datasets.")
|
68
|
+
elif name == "ai2_arc":
|
69
|
+
dataset = ai2_arc(test_mode)
|
70
|
+
elif name == "truthful_qa":
|
71
|
+
dataset = truthful_qa(test_mode)
|
72
|
+
elif name == "cnn_dailymail":
|
73
|
+
dataset = cnn_dailymail(test_mode)
|
74
|
+
elif name == "ragbench_sentence_relevance":
|
75
|
+
dataset = ragbench_sentence_relevance(test_mode)
|
76
|
+
elif name == "election_questions":
|
77
|
+
dataset = election_questions(test_mode)
|
78
|
+
elif name == "medhallu":
|
79
|
+
dataset = medhallu(test_mode)
|
80
|
+
elif name == "rag_hallucinations":
|
81
|
+
dataset = rag_hallucinations(test_mode)
|
82
|
+
else:
|
83
|
+
raise HaltError(f"Unknown dataset: {name}")
|
84
|
+
|
85
|
+
return dataset
|
@@ -706,7 +706,9 @@ Ensure a good mix of variations, all targeting the specified output style from t
|
|
706
706
|
opik_optimization_run = None
|
707
707
|
try:
|
708
708
|
opik_optimization_run = self._opik_client.create_optimization(
|
709
|
-
dataset_name=opik_dataset_obj.name,
|
709
|
+
dataset_name=opik_dataset_obj.name,
|
710
|
+
objective_name=metric_config.metric.name,
|
711
|
+
metadata={"optimizer": self.__class__.__name__},
|
710
712
|
)
|
711
713
|
self._current_optimization_id = opik_optimization_run.id
|
712
714
|
logger.info(f"Created Opik Optimization run with ID: {self._current_optimization_id}")
|