opik-optimizer 0.7.8__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. opik_optimizer/__init__.py +2 -0
  2. opik_optimizer/base_optimizer.py +6 -4
  3. opik_optimizer/datasets/__init__.py +27 -0
  4. opik_optimizer/datasets/ai2_arc.py +44 -0
  5. opik_optimizer/datasets/cnn_dailymail.py +40 -0
  6. opik_optimizer/datasets/election_questions.py +36 -0
  7. opik_optimizer/datasets/gsm8k.py +40 -0
  8. opik_optimizer/datasets/halu_eval.py +43 -0
  9. opik_optimizer/datasets/hotpot_qa.py +67 -0
  10. opik_optimizer/datasets/medhallu.py +39 -0
  11. opik_optimizer/datasets/rag_hallucinations.py +41 -0
  12. opik_optimizer/datasets/ragbench.py +40 -0
  13. opik_optimizer/datasets/tiny_test.py +57 -0
  14. opik_optimizer/datasets/truthful_qa.py +107 -0
  15. opik_optimizer/demo/datasets.py +53 -607
  16. opik_optimizer/evolutionary_optimizer/evolutionary_optimizer.py +3 -1
  17. opik_optimizer/few_shot_bayesian_optimizer/few_shot_bayesian_optimizer.py +88 -17
  18. opik_optimizer/logging_config.py +1 -1
  19. opik_optimizer/meta_prompt_optimizer.py +57 -11
  20. opik_optimizer/mipro_optimizer/mipro_optimizer.py +151 -13
  21. opik_optimizer/optimization_result.py +11 -0
  22. opik_optimizer/task_evaluator.py +6 -1
  23. opik_optimizer/utils.py +0 -52
  24. opik_optimizer-0.8.0.dist-info/METADATA +196 -0
  25. opik_optimizer-0.8.0.dist-info/RECORD +45 -0
  26. opik_optimizer-0.7.8.dist-info/METADATA +0 -174
  27. opik_optimizer-0.7.8.dist-info/RECORD +0 -33
  28. {opik_optimizer-0.7.8.dist-info → opik_optimizer-0.8.0.dist-info}/WHEEL +0 -0
  29. {opik_optimizer-0.7.8.dist-info → opik_optimizer-0.8.0.dist-info}/licenses/LICENSE +0 -0
  30. {opik_optimizer-0.7.8.dist-info → opik_optimizer-0.8.0.dist-info}/top_level.txt +0 -0
opik_optimizer/__init__.py
@@ -23,6 +23,7 @@ from .optimization_config.mappers import (
 )
 
 from opik.evaluation.models.litellm import warning_filters
+from . import datasets
 
 warning_filters.add_warning_filters()
 
@@ -42,4 +43,5 @@ __all__ = [
     "from_llm_response_text",
     "OptimizationResult",
     "setup_logging",
+    "datasets",
 ]
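
The two added lines above make the new `datasets` subpackage part of the public API. As a minimal usage sketch (not taken from the package; it only assumes an Opik client is configured in the environment), the subpackage can now be reached from the top-level import:

    # Sketch: the dataset helpers are re-exported at the package root in 0.8.0.
    import opik_optimizer

    # Any loader listed in the datasets/__init__.py hunk below can be called;
    # tiny_test(test_mode=True) builds a 5-item fixture dataset in Opik.
    fixture = opik_optimizer.datasets.tiny_test(test_mode=True)
    print(len(fixture.get_items()))
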
opik_optimizer/base_optimizer.py
@@ -4,15 +4,15 @@ import logging
 import time
 
 import litellm
+from . import _throttle
 from opik.rest_api.core import ApiError
 
 from pydantic import BaseModel
-from ._throttle import RateLimiter, rate_limited
 from .cache_config import initialize_cache
 from opik.evaluation.models.litellm import opik_monitor as opik_litellm_monitor
 from .optimization_config.configs import TaskConfig, MetricConfig
 
-limiter = RateLimiter(max_calls_per_second=8)
+_limiter = _throttle.get_rate_limiter_for_current_opik_installation()
 
 # Don't use unsupported params:
 litellm.drop_params = True
@@ -32,19 +32,21 @@ class OptimizationRound(BaseModel):
 
 
 class BaseOptimizer:
-    def __init__(self, model: str, project_name: Optional[str] = None, **model_kwargs):
+    def __init__(self, model: str, project_name: Optional[str] = None, verbose: int = 1, **model_kwargs):
         """
         Base class for optimizers.
 
         Args:
             model: LiteLLM model name
             project_name: Opik project name
+            verbose: Controls internal logging/progress bars (0=off, 1=on).
             model_kwargs: additional args for model (eg, temperature)
         """
         self.model = model
         self.reasoning_model = model
         self.model_kwargs = model_kwargs
         self.project_name = project_name
+        self.verbose = verbose
         self._history = []
         self.experiment_config = None
         self.llm_call_counter = 0
@@ -141,7 +143,7 @@ class BaseOptimizer:
         """
         self._history.append(round_data)
 
-
+
     def update_optimization(self, optimization, status: str) -> None:
         """
         Update the optimization status
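
The constructor change above adds a `verbose` switch that subclasses inherit. A hedged sketch of constructing an optimizer with it (the model name and keyword arguments are illustrative; concrete subclasses such as the meta-prompt or MIPRO optimizers may accept additional parameters not shown in this hunk):

    # Sketch: passing the new verbose flag introduced in 0.8.0.
    from opik_optimizer.base_optimizer import BaseOptimizer

    optimizer = BaseOptimizer(
        model="openai/gpt-4o-mini",    # any LiteLLM model name (illustrative)
        project_name="prompt-tuning",  # optional Opik project (illustrative)
        verbose=0,                     # 0 = suppress internal logging/progress bars
        temperature=0.1,               # forwarded through **model_kwargs
    )
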
opik_optimizer/datasets/__init__.py (new file)
@@ -0,0 +1,27 @@
+from .hotpot_qa import hotpot_300, hotpot_500
+from .halu_eval import halu_eval_300
+from .tiny_test import tiny_test
+from .gsm8k import gsm8k
+from .ai2_arc import ai2_arc
+from .truthful_qa import truthful_qa
+from .cnn_dailymail import cnn_dailymail
+from .ragbench import ragbench_sentence_relevance
+from .election_questions import election_questions
+from .medhallu import medhallu
+from .rag_hallucinations import rag_hallucinations
+
+
+__all__ = [
+    "hotpot_300",
+    "hotpot_500",
+    "halu_eval_300",
+    "tiny_test",
+    "gsm8k",
+    "ai2_arc",
+    "truthful_qa",
+    "cnn_dailymail",
+    "ragbench_sentence_relevance",
+    "election_questions",
+    "medhallu",
+    "rag_hallucinations",
+]
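
Every loader re-exported here follows the same contract visible in the modules below: the function returns an `opik.Dataset`, populating it on first call, returning the existing copy when the stored item count matches, and raising `ValueError` otherwise. A hedged usage sketch (assumes a configured Opik client):

    # Sketch: using the dataset helpers added in 0.8.0.
    from opik_optimizer import datasets

    hotpot = datasets.hotpot_300()             # 300 HotpotQA samples
    tiny = datasets.tiny_test(test_mode=True)  # 5-item fixture for quick tests

    for item in tiny.get_items():
        print(item["text"], "->", item["label"])
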
opik_optimizer/datasets/ai2_arc.py (new file)
@@ -0,0 +1,44 @@
+import opik
+
+def ai2_arc(
+    test_mode: bool = False
+) -> opik.Dataset:
+    """
+    Dataset containing the first 300 samples of the AI2 ARC dataset.
+    """
+    dataset_name = "ai2_arc" if not test_mode else "ai2_arc_test"
+    nb_items = 300 if not test_mode else 5
+
+    client = opik.Opik()
+    dataset = client.get_or_create_dataset(dataset_name)
+
+    items = dataset.get_items()
+    if len(items) == nb_items:
+        return dataset
+    elif len(items) != 0:
+        raise ValueError(f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it.")
+    elif len(items) == 0:
+        import datasets as ds
+
+        # Load data from file and insert into the dataset
+        download_config = ds.DownloadConfig(download_desc=False, disable_tqdm=True)
+        ds.disable_progress_bar()
+        hf_dataset = ds.load_dataset(
+            "ai2_arc", "ARC-Challenge",
+            streaming=True, download_config=download_config
+        )
+
+        data = []
+        for i, item in enumerate(hf_dataset["train"]):
+            if i >= nb_items:
+                break
+            data.append({
+                "question": item["question"],
+                "answer": item["answerKey"],
+                "choices": item["choices"],
+            })
+        ds.enable_progress_bar()
+
+        dataset.insert(data)
+
+        return dataset
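
All Hugging Face-backed loaders in this release repeat the guard shown above: an existing dataset with the expected item count is reused, while a partially populated one raises `ValueError` with a recommendation to delete and recreate it. A hedged caller-side sketch of handling that case:

    # Sketch: the reuse-or-raise guard shared by the dataset loaders.
    from opik_optimizer import datasets

    try:
        arc = datasets.ai2_arc()
    except ValueError as exc:
        # Unexpected item count: delete the "ai2_arc" dataset in Opik, then retry.
        print(f"Dataset needs to be recreated: {exc}")
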
opik_optimizer/datasets/cnn_dailymail.py (new file)
@@ -0,0 +1,40 @@
+import opik
+
+def cnn_dailymail(
+    test_mode: bool = False
+) -> opik.Dataset:
+    """
+    Dataset containing the first 100 samples of the CNN Daily Mail dataset.
+    """
+    dataset_name = "cnn_dailymail" if not test_mode else "cnn_dailymail_test"
+    nb_items = 100 if not test_mode else 5
+
+    client = opik.Opik()
+    dataset = client.get_or_create_dataset(dataset_name)
+
+    items = dataset.get_items()
+    if len(items) == nb_items:
+        return dataset
+    elif len(items) != 0:
+        raise ValueError(f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it.")
+    elif len(items) == 0:
+        import datasets as ds
+
+        download_config = ds.DownloadConfig(download_desc=False, disable_tqdm=True)
+        ds.disable_progress_bar()
+        hf_dataset = ds.load_dataset("cnn_dailymail", "3.0.0", streaming=True, download_config=download_config)
+
+        data = []
+        for i, item in enumerate(hf_dataset["validation"]):
+            if i >= nb_items:
+                break
+            data.append({
+                "article": item["article"],
+                "highlights": item["highlights"],
+            })
+        ds.enable_progress_bar()
+
+        dataset.insert(data)
+
+        return dataset
+
opik_optimizer/datasets/election_questions.py (new file)
@@ -0,0 +1,36 @@
+import opik
+
+
+def election_questions(
+    test_mode: bool = False
+) -> opik.Dataset:
+    dataset_name = "election_questions" if not test_mode else "election_questions_test"
+    nb_items = 300 if not test_mode else 5
+
+    client = opik.Opik()
+    dataset = client.get_or_create_dataset(dataset_name)
+
+    items = dataset.get_items()
+    if len(items) == nb_items:
+        return dataset
+    elif len(items) != 0:
+        raise ValueError(f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it.")
+    elif len(items) == 0:
+        import datasets as ds
+
+        # Load data from file and insert into the dataset
+        download_config = ds.DownloadConfig(download_desc=False, disable_tqdm=True)
+        ds.disable_progress_bar()
+        hf_dataset = ds.load_dataset("Anthropic/election_questions", download_config=download_config)
+
+        data = [
+            {
+                "question": item["question"],
+                "label": item["label"]
+            }
+            for item in hf_dataset["test"].select(range(nb_items))
+        ]
+        ds.enable_progress_bar()
+        dataset.insert(data)
+
+        return dataset
opik_optimizer/datasets/gsm8k.py (new file)
@@ -0,0 +1,40 @@
+import opik
+
+def gsm8k(
+    test_mode: bool = False
+) -> opik.Dataset:
+    """
+    Dataset containing the first 300 samples of the GSM8K dataset.
+    """
+    dataset_name = "gsm8k" if not test_mode else "gsm8k_test"
+    nb_items = 300 if not test_mode else 5
+
+    client = opik.Opik()
+    dataset = client.get_or_create_dataset(dataset_name)
+
+    items = dataset.get_items()
+    if len(items) == nb_items:
+        return dataset
+    elif len(items) != 0:
+        raise ValueError(f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it.")
+    elif len(items) == 0:
+        import datasets as ds
+
+        # Load data from file and insert into the dataset
+        download_config = ds.DownloadConfig(download_desc=False, disable_tqdm=True)
+        ds.disable_progress_bar()
+        hf_dataset = ds.load_dataset("gsm8k", "main", streaming=True, download_config=download_config)
+
+        data = []
+        for i, item in enumerate(hf_dataset["train"]):
+            if i >= nb_items:
+                break
+            data.append({
+                "question": item["question"],
+                "answer": item["answer"],
+            })
+        ds.enable_progress_bar()
+
+        dataset.insert(data)
+
+        return dataset
opik_optimizer/datasets/halu_eval.py (new file)
@@ -0,0 +1,43 @@
+import opik
+
+def halu_eval_300(
+    test_mode: bool = False
+) -> opik.Dataset:
+    """
+    Dataset containing the first 300 samples of the HaluEval dataset.
+    """
+    dataset_name = "halu_eval_300" if not test_mode else "halu_eval_300_test"
+    nb_items = 300 if not test_mode else 5
+
+    client = opik.Opik()
+    dataset = client.get_or_create_dataset(dataset_name)
+
+    items = dataset.get_items()
+    if len(items) == nb_items:
+        return dataset
+    elif len(items) != 0:
+        raise ValueError(f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it.")
+    elif len(items) == 0:
+        import pandas as pd
+
+        try:
+            df = pd.read_parquet(
+                "hf://datasets/pminervini/HaluEval/general/data-00000-of-00001.parquet"
+            )
+        except Exception:
+            raise Exception("Unable to download HaluEval; please try again") from None
+
+        sample_size = min(nb_items, len(df))
+        df_sampled = df.sample(n=sample_size, random_state=42)
+
+        dataset_records = [
+            {
+                "input": x["user_query"],
+                "llm_output": x["chatgpt_response"],
+                "expected_hallucination_label": x["hallucination"],
+            }
+            for x in df_sampled.to_dict(orient="records")
+        ]
+
+        dataset.insert(dataset_records)
+    return dataset
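
Unlike the other loaders, `halu_eval_300` pulls a parquet file directly over the `hf://` filesystem scheme rather than through the `datasets` library. Reading that path with pandas requires the `huggingface_hub` fsspec backend to be installed; a hedged standalone sketch of the same read:

    # Sketch: the hf:// scheme used above needs pandas plus huggingface_hub installed.
    import pandas as pd

    df = pd.read_parquet(
        "hf://datasets/pminervini/HaluEval/general/data-00000-of-00001.parquet"
    )
    # Columns used by the loader: user_query, chatgpt_response, hallucination.
    print(df[["user_query", "chatgpt_response", "hallucination"]].head())
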
opik_optimizer/datasets/hotpot_qa.py (new file)
@@ -0,0 +1,67 @@
+import opik
+from importlib.resources import files
+import json
+
+def hotpot_300(
+    test_mode: bool = False
+) -> opik.Dataset:
+    """
+    Dataset containing the first 300 samples of the HotpotQA dataset.
+    """
+    dataset_name = "hotpot_300" if not test_mode else "hotpot_300_test"
+    nb_items = 300 if not test_mode else 3
+
+    client = opik.Opik()
+    dataset = client.get_or_create_dataset(dataset_name)
+
+    items = dataset.get_items()
+    if len(items) == nb_items:
+        return dataset
+    elif len(items) != 0:
+        raise ValueError(f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it.")
+    elif len(items) == 0:
+        # Load data from file and insert into the dataset
+        json_content = (files('opik_optimizer') / 'data' / 'hotpot-500.json').read_text(encoding='utf-8')
+        all_data = json.loads(json_content)
+        trainset = all_data[:nb_items]
+
+        data = []
+        for row in reversed(trainset):
+            data.append(row)
+
+        dataset.insert(data)
+    return dataset
+
+def hotpot_500(
+    test_mode: bool = False
+) -> opik.Dataset:
+    """
+    Dataset containing the first 500 samples of the HotpotQA dataset.
+    """
+    dataset_name = "hotpot_500" if not test_mode else "hotpot_500_test"
+    nb_items = 500 if not test_mode else 5
+
+    client = opik.Opik()
+    dataset = client.get_or_create_dataset(dataset_name)
+
+    items = dataset.get_items()
+    if len(items) == nb_items:
+        return dataset
+    elif len(items) != 0:
+        raise ValueError(f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it.")
+    elif len(items) == 0:
+        # Load data from file and insert into the dataset
+        json_content = (files('opik_optimizer') / 'data' / 'hotpot-500.json').read_text(encoding='utf-8')
+        all_data = json.loads(json_content)
+        trainset = all_data[:nb_items]
+
+        data = []
+        for row in reversed(trainset):
+            data.append(row)
+
+        dataset.insert(data)
+    return dataset
+
+
+
+
opik_optimizer/datasets/medhallu.py (new file)
@@ -0,0 +1,39 @@
+import opik
+
+def medhallu(
+    test_mode: bool = False
+) -> opik.Dataset:
+    dataset_name = "medhallu" if not test_mode else "medhallu_test"
+    nb_items = 300 if not test_mode else 5
+
+    client = opik.Opik()
+    dataset = client.get_or_create_dataset(dataset_name)
+
+    items = dataset.get_items()
+    if len(items) == nb_items:
+        return dataset
+    elif len(items) != 0:
+        raise ValueError(f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it.")
+    elif len(items) == 0:
+        import datasets as ds
+
+        # Load data from file and insert into the dataset
+        download_config = ds.DownloadConfig(download_desc=False, disable_tqdm=True)
+        ds.disable_progress_bar()
+        hf_dataset = ds.load_dataset("UTAustin-AIHealth/MedHallu", "pqa_labeled", download_config=download_config)
+
+        data = [
+            {
+                "question": item["Question"],
+                "knowledge": item["Knowledge"],
+                "ground_truth": item["Ground Truth"],
+                "hallucinated_answer": item["Hallucinated Answer"],
+                "difficulty_level": item["Difficulty Level"],
+                "hallucination_category": item["Category of Hallucination"],
+            }
+            for item in hf_dataset["train"].select(range(nb_items))
+        ]
+        ds.enable_progress_bar()
+        dataset.insert(data)
+
+        return dataset
opik_optimizer/datasets/rag_hallucinations.py (new file)
@@ -0,0 +1,41 @@
+import opik
+
+def rag_hallucinations(
+    test_mode: bool = False
+) -> opik.Dataset:
+    """
+    Dataset containing the first 300 samples of the RAG Hallucinations dataset.
+    """
+    dataset_name = "rag_hallucination" if not test_mode else "rag_hallucination_test"
+    nb_items = 300 if not test_mode else 5
+
+    client = opik.Opik()
+    dataset = client.get_or_create_dataset(dataset_name)
+
+    items = dataset.get_items()
+    if len(items) == nb_items:
+        return dataset
+    elif len(items) != 0:
+        raise ValueError(f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it.")
+    elif len(items) == 0:
+        import datasets as ds
+
+        # Load data from file and insert into the dataset
+        download_config = ds.DownloadConfig(download_desc=False, disable_tqdm=True)
+        ds.disable_progress_bar()
+        hf_dataset = ds.load_dataset("aporia-ai/rag_hallucinations", download_config=download_config)
+
+        data = [
+            {
+                "context": item["context"],
+                "question": item["question"],
+                "answer": item["answer"],
+                "is_hallucination": item["is_hallucination"],
+            }
+            for item in hf_dataset["train"].select(range(nb_items))
+        ]
+        ds.enable_progress_bar()
+
+        dataset.insert(data)
+
+        return dataset
opik_optimizer/datasets/ragbench.py (new file)
@@ -0,0 +1,40 @@
+import opik
+
+def ragbench_sentence_relevance(
+    test_mode: bool = False
+) -> opik.Dataset:
+    """
+    Dataset containing the first 300 samples of the RAGBench sentence relevance dataset.
+    """
+    dataset_name = "ragbench_sentence_relevance" if not test_mode else "ragbench_sentence_relevance_test"
+    nb_items = 300 if not test_mode else 5
+
+    client = opik.Opik()
+    dataset = client.get_or_create_dataset(dataset_name)
+
+    items = dataset.get_items()
+    if len(items) == nb_items:
+        return dataset
+    elif len(items) != 0:
+        raise ValueError(f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it.")
+    elif len(items) == 0:
+        import datasets as ds
+
+        # Load data from file and insert into the dataset
+        download_config = ds.DownloadConfig(download_desc=False, disable_tqdm=True)
+        ds.disable_progress_bar()
+        hf_dataset = ds.load_dataset("wandb/ragbench-sentence-relevance-balanced", download_config=download_config)
+
+        data = [
+            {
+                "question": item["question"],
+                "sentence": item["sentence"],
+                "label": item["label"],
+            }
+            for item in hf_dataset["train"].select(range(nb_items))
+        ]
+        ds.enable_progress_bar()
+
+        dataset.insert(data)
+
+        return dataset
opik_optimizer/datasets/tiny_test.py (new file)
@@ -0,0 +1,57 @@
+import opik
+
+TINY_TEST_ITEMS = [
+    {
+        "text": "What is the capital of France?",
+        "label": "Paris",
+        "metadata": {
+            "context": "France is a country in Europe. Its capital is Paris."
+        },
+    },
+    {
+        "text": "Who wrote Romeo and Juliet?",
+        "label": "William Shakespeare",
+        "metadata": {
+            "context": "Romeo and Juliet is a famous play written by William Shakespeare."
+        },
+    },
+    {
+        "text": "What is 2 + 2?",
+        "label": "4",
+        "metadata": {"context": "Basic arithmetic: 2 + 2 equals 4."},
+    },
+    {
+        "text": "What is the largest planet in our solar system?",
+        "label": "Jupiter",
+        "metadata": {
+            "context": "Jupiter is the largest planet in our solar system."
+        },
+    },
+    {
+        "text": "Who painted the Mona Lisa?",
+        "label": "Leonardo da Vinci",
+        "metadata": {"context": "The Mona Lisa was painted by Leonardo da Vinci."},
+    },
+]
+
+def tiny_test(
+    test_mode: bool = False
+) -> opik.Dataset:
+    """
+    Dataset containing the first 5 samples of the HotpotQA dataset.
+    """
+    dataset_name = "tiny_test" if not test_mode else "tiny_test_test"
+    nb_items = len(TINY_TEST_ITEMS)
+
+    client = opik.Opik()
+    dataset = client.get_or_create_dataset(dataset_name)
+
+    items = dataset.get_items()
+    if len(items) == nb_items:
+        return dataset
+    elif len(items) != 0:
+        raise ValueError(f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it.")
+    elif len(items) == 0:
+        dataset.insert(TINY_TEST_ITEMS)
+    return dataset
+
opik_optimizer/datasets/truthful_qa.py (new file)
@@ -0,0 +1,107 @@
+import opik
+
+def truthful_qa(
+    test_mode: bool = False
+) -> opik.Dataset:
+    """
+    Dataset containing the first 300 samples of the TruthfulQA dataset.
+    """
+    dataset_name = "truthful_qa" if not test_mode else "truthful_qa_test"
+    nb_items = 300 if not test_mode else 5
+
+    client = opik.Opik()
+    dataset = client.get_or_create_dataset(dataset_name)
+
+    items = dataset.get_items()
+    if len(items) == nb_items:
+        return dataset
+    elif len(items) != 0:
+        raise ValueError(f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it.")
+    elif len(items) == 0:
+        import datasets as ds
+
+        # Load data from file and insert into the dataset
+        download_config = ds.DownloadConfig(download_desc=False, disable_tqdm=True)
+        ds.disable_progress_bar()
+
+        gen_dataset = ds.load_dataset("truthful_qa", "generation", download_config=download_config)
+        mc_dataset = ds.load_dataset("truthful_qa", "multiple_choice", download_config=download_config)
+
+        data = []
+        for gen_item, mc_item in zip(
+            gen_dataset["validation"], mc_dataset["validation"]
+        ):
+            if len(data) >= nb_items:
+                break
+
+            # Get correct answers from both configurations
+            correct_answers = set(gen_item["correct_answers"])
+            if "mc1_targets" in mc_item:
+                correct_answers.update(
+                    [
+                        choice
+                        for choice, label in zip(
+                            mc_item["mc1_targets"]["choices"],
+                            mc_item["mc1_targets"]["labels"],
+                        )
+                        if label == 1
+                    ]
+                )
+            if "mc2_targets" in mc_item:
+                correct_answers.update(
+                    [
+                        choice
+                        for choice, label in zip(
+                            mc_item["mc2_targets"]["choices"],
+                            mc_item["mc2_targets"]["labels"],
+                        )
+                        if label == 1
+                    ]
+                )
+
+            # Get all possible answers
+            all_answers = set(
+                gen_item["correct_answers"] + gen_item["incorrect_answers"]
+            )
+            if "mc1_targets" in mc_item:
+                all_answers.update(mc_item["mc1_targets"]["choices"])
+            if "mc2_targets" in mc_item:
+                all_answers.update(mc_item["mc2_targets"]["choices"])
+
+            # Create a single example with all necessary fields
+            example = {
+                "question": gen_item["question"],
+                "answer": gen_item["best_answer"],
+                "choices": list(all_answers),
+                "correct_answer": gen_item["best_answer"],
+                "input": gen_item["question"],  # For AnswerRelevance metric
+                "output": gen_item["best_answer"],  # For output_key requirement
+                "context": gen_item.get("source", ""),  # Use source as context
+                "type": "TEXT",  # Set type to TEXT as required by Opik
+                "category": gen_item["category"],
+                "source": "MANUAL",  # Set source to MANUAL as required by Opik
+                "correct_answers": list(
+                    correct_answers
+                ),  # Keep track of all correct answers
+                "incorrect_answers": gen_item[
+                    "incorrect_answers"
+                ],  # Keep track of incorrect answers
+            }
+
+            # Ensure all required fields are present
+            required_fields = [
+                "question",
+                "answer",
+                "choices",
+                "correct_answer",
+                "input",
+                "output",
+                "context",
+            ]
+            if all(field in example and example[field] for field in required_fields):
+                data.append(example)
+        ds.enable_progress_bar()
+
+        dataset.insert(data)
+
+    return dataset
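
Because `truthful_qa` merges the `generation` and `multiple_choice` configurations, each inserted item carries both free-form and choice-style fields. A hedged sketch of what a consumer can read back, based on the `example` dict built above:

    # Sketch: fields stored per truthful_qa item (names taken from the hunk above).
    from opik_optimizer import datasets

    qa = datasets.truthful_qa(test_mode=True)
    item = qa.get_items()[0]

    print(item["question"])           # also duplicated into "input"
    print(item["correct_answer"])     # best_answer from the generation config
    print(len(item["choices"]))       # union of correct and incorrect answers
    print(item["incorrect_answers"])  # distractors kept for metric computation
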