opik-optimizer 0.9.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (45)
  1. opik_optimizer/__init__.py +7 -3
  2. opik_optimizer/_throttle.py +8 -8
  3. opik_optimizer/base_optimizer.py +98 -45
  4. opik_optimizer/cache_config.py +5 -3
  5. opik_optimizer/datasets/ai2_arc.py +15 -13
  6. opik_optimizer/datasets/cnn_dailymail.py +19 -15
  7. opik_optimizer/datasets/election_questions.py +10 -11
  8. opik_optimizer/datasets/gsm8k.py +16 -11
  9. opik_optimizer/datasets/halu_eval.py +6 -5
  10. opik_optimizer/datasets/hotpot_qa.py +17 -16
  11. opik_optimizer/datasets/medhallu.py +10 -7
  12. opik_optimizer/datasets/rag_hallucinations.py +11 -8
  13. opik_optimizer/datasets/ragbench.py +17 -9
  14. opik_optimizer/datasets/tiny_test.py +33 -37
  15. opik_optimizer/datasets/truthful_qa.py +18 -12
  16. opik_optimizer/demo/cache.py +6 -6
  17. opik_optimizer/demo/datasets.py +3 -7
  18. opik_optimizer/evolutionary_optimizer/__init__.py +3 -1
  19. opik_optimizer/evolutionary_optimizer/evolutionary_optimizer.py +748 -437
  20. opik_optimizer/evolutionary_optimizer/reporting.py +155 -76
  21. opik_optimizer/few_shot_bayesian_optimizer/few_shot_bayesian_optimizer.py +291 -181
  22. opik_optimizer/few_shot_bayesian_optimizer/reporting.py +79 -28
  23. opik_optimizer/logging_config.py +19 -15
  24. opik_optimizer/meta_prompt_optimizer/meta_prompt_optimizer.py +234 -138
  25. opik_optimizer/meta_prompt_optimizer/reporting.py +121 -47
  26. opik_optimizer/mipro_optimizer/__init__.py +2 -0
  27. opik_optimizer/mipro_optimizer/_lm.py +41 -9
  28. opik_optimizer/mipro_optimizer/_mipro_optimizer_v2.py +37 -26
  29. opik_optimizer/mipro_optimizer/mipro_optimizer.py +135 -67
  30. opik_optimizer/mipro_optimizer/utils.py +5 -2
  31. opik_optimizer/optimizable_agent.py +179 -0
  32. opik_optimizer/optimization_config/chat_prompt.py +143 -73
  33. opik_optimizer/optimization_config/configs.py +4 -3
  34. opik_optimizer/optimization_config/mappers.py +18 -6
  35. opik_optimizer/optimization_result.py +28 -20
  36. opik_optimizer/py.typed +0 -0
  37. opik_optimizer/reporting_utils.py +96 -46
  38. opik_optimizer/task_evaluator.py +12 -14
  39. opik_optimizer/utils.py +122 -37
  40. {opik_optimizer-0.9.1.dist-info → opik_optimizer-1.0.0.dist-info}/METADATA +8 -8
  41. opik_optimizer-1.0.0.dist-info/RECORD +50 -0
  42. opik_optimizer-0.9.1.dist-info/RECORD +0 -48
  43. {opik_optimizer-0.9.1.dist-info → opik_optimizer-1.0.0.dist-info}/WHEEL +0 -0
  44. {opik_optimizer-0.9.1.dist-info → opik_optimizer-1.0.0.dist-info}/licenses/LICENSE +0 -0
  45. {opik_optimizer-0.9.1.dist-info → opik_optimizer-1.0.0.dist-info}/top_level.txt +0 -0
@@ -3,15 +3,18 @@ import logging
3
3
 
4
4
  from opik.evaluation.models.litellm import warning_filters
5
5
 
6
- from opik_optimizer.evolutionary_optimizer.evolutionary_optimizer import EvolutionaryOptimizer
6
+ from opik_optimizer.evolutionary_optimizer.evolutionary_optimizer import (
7
+ EvolutionaryOptimizer,
8
+ )
7
9
 
8
10
  from . import datasets
11
+ from .optimizable_agent import OptimizableAgent
12
+ from .optimization_config.chat_prompt import ChatPrompt
9
13
  from .base_optimizer import BaseOptimizer
10
14
  from .few_shot_bayesian_optimizer import FewShotBayesianOptimizer
11
15
  from .logging_config import setup_logging
12
16
  from .meta_prompt_optimizer import MetaPromptOptimizer
13
17
  from .mipro_optimizer import MiproOptimizer
14
- from .optimization_config.chat_prompt import ChatPrompt
15
18
  from .optimization_config.configs import TaskConfig
16
19
  from .optimization_result import OptimizationResult
17
20
 
@@ -30,7 +33,8 @@ __all__ = [
30
33
  "MiproOptimizer",
31
34
  "EvolutionaryOptimizer",
32
35
  "OptimizationResult",
36
+ "OptimizableAgent",
33
37
  "setup_logging",
34
38
  "datasets",
35
- "TaskConfig"
39
+ "TaskConfig",
36
40
  ]
@@ -10,34 +10,34 @@ class RateLimiter:
10
10
  """
11
11
  Rate limiter that enforces a maximum number of calls across all threads using pyrate_limiter.
12
12
  """
13
+
13
14
  def __init__(self, max_calls_per_second: int):
14
15
  self.max_calls_per_second = max_calls_per_second
15
16
  rate = pyrate_limiter.Rate(max_calls_per_second, pyrate_limiter.Duration.SECOND)
16
17
 
17
18
  self.limiter = pyrate_limiter.Limiter(rate, raise_when_fail=False)
18
19
  self.bucket_key = "global_rate_limit"
19
-
20
+
20
21
  def acquire(self) -> None:
21
22
  while not self.limiter.try_acquire(self.bucket_key):
22
23
  time.sleep(0.01)
23
24
 
25
+
24
26
  def rate_limited(limiter: RateLimiter) -> Callable[[Callable], Callable]:
25
27
  """Decorator to rate limit a function using the provided limiter"""
26
28
 
27
29
  def decorator(func: Callable) -> Callable:
28
30
  @functools.wraps(func)
29
- def wrapper(*args, **kwargs) -> Any:
31
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
30
32
  limiter.acquire()
31
33
  return func(*args, **kwargs)
34
+
32
35
  return wrapper
36
+
33
37
  return decorator
34
38
 
35
39
 
36
40
  def get_rate_limiter_for_current_opik_installation() -> RateLimiter:
37
41
  opik_config = opik.config.OpikConfig()
38
- max_calls_per_second = (
39
- 10
40
- if opik_config.is_cloud_installation
41
- else 50
42
- )
43
- return RateLimiter(max_calls_per_second=max_calls_per_second)
42
+ max_calls_per_second = 10 if opik_config.is_cloud_installation else 50
43
+ return RateLimiter(max_calls_per_second=max_calls_per_second)
@@ -1,16 +1,23 @@
1
+ from typing import Any, Callable, Dict, List, Optional, Type
2
+
1
3
  import logging
2
4
  import time
3
5
  from abc import abstractmethod
4
- from typing import Any, Callable, Dict, List, Optional
6
+ import random
7
+
5
8
 
6
9
  import litellm
7
- import opik
8
10
  from opik.rest_api.core import ApiError
11
+ from opik.api_objects import optimization
12
+ from opik import Dataset
9
13
  from pydantic import BaseModel
10
14
 
11
15
  from . import _throttle, optimization_result
12
16
  from .cache_config import initialize_cache
13
- from .optimization_config import chat_prompt
17
+ from .optimization_config import chat_prompt, mappers
18
+ from .optimizable_agent import OptimizableAgent
19
+ from .utils import create_litellm_agent_class
20
+ from . import task_evaluator
14
21
 
15
22
  _limiter = _throttle.get_rate_limiter_for_current_opik_installation()
16
23
 
@@ -34,22 +41,25 @@ class OptimizationRound(BaseModel):
34
41
 
35
42
 
36
43
  class BaseOptimizer:
37
- def __init__(self, model: str, project_name: Optional[str] = None, verbose: int = 1, **model_kwargs):
44
+ def __init__(
45
+ self,
46
+ model: str,
47
+ verbose: int = 1,
48
+ **model_kwargs: Any,
49
+ ) -> None:
38
50
  """
39
51
  Base class for optimizers.
40
52
 
41
53
  Args:
42
54
  model: LiteLLM model name
43
- project_name: Opik project name
44
55
  verbose: Controls internal logging/progress bars (0=off, 1=on).
45
56
  model_kwargs: additional args for model (eg, temperature)
46
57
  """
47
58
  self.model = model
48
59
  self.reasoning_model = model
49
60
  self.model_kwargs = model_kwargs
50
- self.project_name = project_name
51
61
  self.verbose = verbose
52
- self._history = []
62
+ self._history: List[OptimizationRound] = []
53
63
  self.experiment_config = None
54
64
  self.llm_call_counter = 0
55
65
 
@@ -59,18 +69,18 @@ class BaseOptimizer:
59
69
  @abstractmethod
60
70
  def optimize_prompt(
61
71
  self,
62
- prompt: chat_prompt.ChatPrompt,
63
- dataset: opik.Dataset,
64
- metrics: List[Callable],
72
+ prompt: "chat_prompt.ChatPrompt",
73
+ dataset: Dataset,
74
+ metric: Callable,
65
75
  experiment_config: Optional[Dict] = None,
66
- **kwargs,
76
+ **kwargs: Any,
67
77
  ) -> optimization_result.OptimizationResult:
68
78
  """
69
79
  Optimize a prompt.
70
80
 
71
81
  Args:
72
82
  dataset: Opik dataset name, or Opik dataset
73
- metrics: A list of metric functions, these functions should have two arguments:
83
+ metric: A metric function, this function should have two arguments:
74
84
  dataset_item and llm_output
75
85
  prompt: the prompt to optimize
76
86
  input_key: input field of dataset
@@ -80,36 +90,7 @@ class BaseOptimizer:
80
90
  """
81
91
  pass
82
92
 
83
- @abstractmethod
84
- def evaluate_prompt(
85
- self,
86
- prompt: chat_prompt.ChatPrompt,
87
- dataset: opik.Dataset,
88
- metrics: List[Callable],
89
- n_samples: Optional[int] = None,
90
- dataset_item_ids: Optional[List[str]] = None,
91
- experiment_config: Optional[Dict] = None,
92
- **kwargs,
93
- ) -> float:
94
- """
95
- Evaluate a prompt.
96
-
97
- Args:
98
- prompt: the prompt to evaluate
99
- dataset: Opik dataset name, or Opik dataset
100
- metrics: A list of metric functions, these functions should have two arguments:
101
- dataset_item and llm_output
102
- n_samples: number of items to test in the dataset
103
- dataset_item_ids: Optional list of dataset item IDs to evaluate
104
- experiment_config: Optional configuration for the experiment
105
- **kwargs: Additional arguments for evaluation
106
-
107
- Returns:
108
- float: The evaluation score
109
- """
110
- pass
111
-
112
- def get_history(self) -> List[Dict[str, Any]]:
93
+ def get_history(self) -> List[OptimizationRound]:
113
94
  """
114
95
  Get the optimization history.
115
96
 
@@ -118,7 +99,7 @@ class BaseOptimizer:
118
99
  """
119
100
  return self._history
120
101
 
121
- def _add_to_history(self, round_data: Dict[str, Any]):
102
+ def _add_to_history(self, round_data: OptimizationRound) -> None:
122
103
  """
123
104
  Add a round to the optimization history.
124
105
 
@@ -127,8 +108,9 @@ class BaseOptimizer:
127
108
  """
128
109
  self._history.append(round_data)
129
110
 
130
-
131
- def update_optimization(self, optimization, status: str) -> None:
111
+ def update_optimization(
112
+ self, optimization: optimization.Optimization, status: str
113
+ ) -> None:
132
114
  """
133
115
  Update the optimization status
134
116
  """
@@ -143,3 +125,74 @@ class BaseOptimizer:
143
125
  time.sleep(5)
144
126
  if count == 3:
145
127
  logger.warning("Unable to update optimization status; continuing...")
128
+
129
+ def evaluate_prompt(
130
+ self,
131
+ prompt: chat_prompt.ChatPrompt,
132
+ dataset: Dataset,
133
+ metric: Callable,
134
+ n_threads: int,
135
+ verbose: int = 1,
136
+ dataset_item_ids: Optional[List[str]] = None,
137
+ experiment_config: Optional[Dict] = None,
138
+ n_samples: Optional[int] = None,
139
+ seed: Optional[int] = None,
140
+ agent_class: Optional[Type[OptimizableAgent]] = None,
141
+ ) -> float:
142
+ random.seed(seed)
143
+
144
+ if prompt.model is None:
145
+ prompt.model = self.model
146
+ if prompt.model_kwargs is None:
147
+ prompt.model_kwargs = self.model_kwargs
148
+
149
+ self.agent_class: Type[OptimizableAgent]
150
+
151
+ if agent_class is None:
152
+ self.agent_class = create_litellm_agent_class(prompt)
153
+ else:
154
+ self.agent_class = agent_class
155
+
156
+ agent = self.agent_class(prompt)
157
+
158
+ def llm_task(dataset_item: Dict[str, Any]) -> Dict[str, str]:
159
+ messages = prompt.get_messages(dataset_item)
160
+ raw_model_output = agent.invoke(messages)
161
+ cleaned_model_output = raw_model_output.strip()
162
+ result = {
163
+ mappers.EVALUATED_LLM_TASK_OUTPUT: cleaned_model_output,
164
+ }
165
+ return result
166
+
167
+ experiment_config = experiment_config or {}
168
+ experiment_config["project_name"] = self.__class__.__name__
169
+ experiment_config = {
170
+ **experiment_config,
171
+ **{
172
+ "agent_class": self.agent_class.__name__,
173
+ "agent_config": prompt.to_dict(),
174
+ "metric": metric.__name__,
175
+ "dataset": dataset.name,
176
+ "configuration": {"prompt": (prompt.get_messages() if prompt else [])},
177
+ },
178
+ }
179
+
180
+ if n_samples is not None:
181
+ if dataset_item_ids is not None:
182
+ raise Exception("Can't use n_samples and dataset_item_ids")
183
+
184
+ all_ids = [dataset_item["id"] for dataset_item in dataset.get_items()]
185
+ dataset_item_ids = random.sample(all_ids, n_samples)
186
+
187
+ score = task_evaluator.evaluate(
188
+ dataset=dataset,
189
+ dataset_item_ids=dataset_item_ids,
190
+ metric=metric,
191
+ evaluated_task=llm_task,
192
+ num_threads=n_threads,
193
+ project_name=self.agent_class.project_name,
194
+ experiment_config=experiment_config,
195
+ optimization_id=None,
196
+ verbose=verbose,
197
+ )
198
+ return score
@@ -13,12 +13,14 @@ CACHE_CONFIG = {
13
13
  "disk_cache_dir": CACHE_DIR,
14
14
  }
15
15
 
16
- def initialize_cache():
16
+
17
+ def initialize_cache() -> Cache:
17
18
  """Initialize the LiteLLM cache with custom configuration."""
18
19
  litellm.cache = Cache(**CACHE_CONFIG)
19
20
  return litellm.cache
20
21
 
21
- def clear_cache():
22
+
23
+ def clear_cache() -> None:
22
24
  """Clear the LiteLLM cache."""
23
25
  if litellm.cache:
24
- litellm.cache.clear()
26
+ litellm.cache.clear()
@@ -1,8 +1,7 @@
1
1
  import opik
2
2
 
3
- def ai2_arc(
4
- test_mode: bool = False
5
- ) -> opik.Dataset:
3
+
4
+ def ai2_arc(test_mode: bool = False) -> opik.Dataset:
6
5
  """
7
6
  Dataset containing the first 300 samples of the AI2 ARC dataset.
8
7
  """
@@ -11,12 +10,14 @@ def ai2_arc(
11
10
 
12
11
  client = opik.Opik()
13
12
  dataset = client.get_or_create_dataset(dataset_name)
14
-
13
+
15
14
  items = dataset.get_items()
16
15
  if len(items) == nb_items:
17
16
  return dataset
18
17
  elif len(items) != 0:
19
- raise ValueError(f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it.")
18
+ raise ValueError(
19
+ f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it."
20
+ )
20
21
  elif len(items) == 0:
21
22
  import datasets as ds
22
23
 
@@ -24,19 +25,20 @@ def ai2_arc(
24
25
  download_config = ds.DownloadConfig(download_desc=False, disable_tqdm=True)
25
26
  ds.disable_progress_bar()
26
27
  hf_dataset = ds.load_dataset(
27
- "ai2_arc", "ARC-Challenge",
28
- streaming=True, download_config=download_config
28
+ "ai2_arc", "ARC-Challenge", streaming=True, download_config=download_config
29
29
  )
30
-
30
+
31
31
  data = []
32
32
  for i, item in enumerate(hf_dataset["train"]):
33
33
  if i >= nb_items:
34
34
  break
35
- data.append({
36
- "question": item["question"],
37
- "answer": item["answerKey"],
38
- "choices": item["choices"],
39
- })
35
+ data.append(
36
+ {
37
+ "question": item["question"],
38
+ "answer": item["answerKey"],
39
+ "choices": item["choices"],
40
+ }
41
+ )
40
42
  ds.enable_progress_bar()
41
43
 
42
44
  dataset.insert(data)
@@ -1,8 +1,7 @@
1
1
  import opik
2
2
 
3
- def cnn_dailymail(
4
- test_mode: bool = False
5
- ) -> opik.Dataset:
3
+
4
+ def cnn_dailymail(test_mode: bool = False) -> opik.Dataset:
6
5
  """
7
6
  Dataset containing the first 100 samples of the CNN Daily Mail dataset.
8
7
  """
@@ -11,30 +10,35 @@ def cnn_dailymail(
11
10
 
12
11
  client = opik.Opik()
13
12
  dataset = client.get_or_create_dataset(dataset_name)
14
-
13
+
15
14
  items = dataset.get_items()
16
15
  if len(items) == nb_items:
17
16
  return dataset
18
17
  elif len(items) != 0:
19
- raise ValueError(f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it.")
18
+ raise ValueError(
19
+ f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it."
20
+ )
20
21
  elif len(items) == 0:
21
22
  import datasets as ds
22
-
23
+
23
24
  download_config = ds.DownloadConfig(download_desc=False, disable_tqdm=True)
24
25
  ds.disable_progress_bar()
25
- hf_dataset = ds.load_dataset("cnn_dailymail", "3.0.0", streaming=True, download_config=download_config)
26
-
26
+ hf_dataset = ds.load_dataset(
27
+ "cnn_dailymail", "3.0.0", streaming=True, download_config=download_config
28
+ )
29
+
27
30
  data = []
28
31
  for i, item in enumerate(hf_dataset["validation"]):
29
32
  if i >= nb_items:
30
33
  break
31
- data.append({
32
- "article": item["article"],
33
- "highlights": item["highlights"],
34
- })
34
+ data.append(
35
+ {
36
+ "article": item["article"],
37
+ "highlights": item["highlights"],
38
+ }
39
+ )
35
40
  ds.enable_progress_bar()
36
-
41
+
37
42
  dataset.insert(data)
38
-
43
+
39
44
  return dataset
40
-
@@ -1,33 +1,32 @@
1
1
  import opik
2
2
 
3
3
 
4
- def election_questions(
5
- test_mode: bool = False
6
- ) -> opik.Dataset:
4
+ def election_questions(test_mode: bool = False) -> opik.Dataset:
7
5
  dataset_name = "election_questions" if not test_mode else "election_questions_test"
8
6
  nb_items = 300 if not test_mode else 5
9
7
 
10
8
  client = opik.Opik()
11
9
  dataset = client.get_or_create_dataset(dataset_name)
12
-
10
+
13
11
  items = dataset.get_items()
14
12
  if len(items) == nb_items:
15
13
  return dataset
16
14
  elif len(items) != 0:
17
- raise ValueError(f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it.")
15
+ raise ValueError(
16
+ f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it."
17
+ )
18
18
  elif len(items) == 0:
19
19
  import datasets as ds
20
20
 
21
21
  # Load data from file and insert into the dataset
22
22
  download_config = ds.DownloadConfig(download_desc=False, disable_tqdm=True)
23
23
  ds.disable_progress_bar()
24
- hf_dataset = ds.load_dataset("Anthropic/election_questions", download_config=download_config)
25
-
24
+ hf_dataset = ds.load_dataset(
25
+ "Anthropic/election_questions", download_config=download_config
26
+ )
27
+
26
28
  data = [
27
- {
28
- "question": item["question"],
29
- "label": item["label"]
30
- }
29
+ {"question": item["question"], "label": item["label"]}
31
30
  for item in hf_dataset["test"].select(range(nb_items))
32
31
  ]
33
32
  ds.enable_progress_bar()
@@ -1,8 +1,7 @@
1
1
  import opik
2
2
 
3
- def gsm8k(
4
- test_mode: bool = False
5
- ) -> opik.Dataset:
3
+
4
+ def gsm8k(test_mode: bool = False) -> opik.Dataset:
6
5
  """
7
6
  Dataset containing the first 300 samples of the GSM8K dataset.
8
7
  """
@@ -11,28 +10,34 @@ def gsm8k(
11
10
 
12
11
  client = opik.Opik()
13
12
  dataset = client.get_or_create_dataset(dataset_name)
14
-
13
+
15
14
  items = dataset.get_items()
16
15
  if len(items) == nb_items:
17
16
  return dataset
18
17
  elif len(items) != 0:
19
- raise ValueError(f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it.")
18
+ raise ValueError(
19
+ f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it."
20
+ )
20
21
  elif len(items) == 0:
21
22
  import datasets as ds
22
23
 
23
24
  # Load data from file and insert into the dataset
24
25
  download_config = ds.DownloadConfig(download_desc=False, disable_tqdm=True)
25
26
  ds.disable_progress_bar()
26
- hf_dataset = ds.load_dataset("gsm8k", "main", streaming=True, download_config=download_config)
27
-
27
+ hf_dataset = ds.load_dataset(
28
+ "gsm8k", "main", streaming=True, download_config=download_config
29
+ )
30
+
28
31
  data = []
29
32
  for i, item in enumerate(hf_dataset["train"]):
30
33
  if i >= nb_items:
31
34
  break
32
- data.append({
33
- "question": item["question"],
34
- "answer": item["answer"],
35
- })
35
+ data.append(
36
+ {
37
+ "question": item["question"],
38
+ "answer": item["answer"],
39
+ }
40
+ )
36
41
  ds.enable_progress_bar()
37
42
 
38
43
  dataset.insert(data)
@@ -1,8 +1,7 @@
1
1
  import opik
2
2
 
3
- def halu_eval_300(
4
- test_mode: bool = False
5
- ) -> opik.Dataset:
3
+
4
+ def halu_eval_300(test_mode: bool = False) -> opik.Dataset:
6
5
  """
7
6
  Dataset containing the first 300 samples of the HaluEval dataset.
8
7
  """
@@ -11,12 +10,14 @@ def halu_eval_300(
11
10
 
12
11
  client = opik.Opik()
13
12
  dataset = client.get_or_create_dataset(dataset_name)
14
-
13
+
15
14
  items = dataset.get_items()
16
15
  if len(items) == nb_items:
17
16
  return dataset
18
17
  elif len(items) != 0:
19
- raise ValueError(f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it.")
18
+ raise ValueError(
19
+ f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it."
20
+ )
20
21
  elif len(items) == 0:
21
22
  import pandas as pd
22
23
 
@@ -3,9 +3,7 @@ from importlib.resources import files
3
3
  import json
4
4
 
5
5
 
6
- def hotpot_300(
7
- test_mode: bool = False
8
- ) -> opik.Dataset:
6
+ def hotpot_300(test_mode: bool = False) -> opik.Dataset:
9
7
  """
10
8
  Dataset containing the first 300 samples of the HotpotQA dataset.
11
9
  """
@@ -14,15 +12,19 @@ def hotpot_300(
14
12
 
15
13
  client = opik.Opik()
16
14
  dataset = client.get_or_create_dataset(dataset_name)
17
-
15
+
18
16
  items = dataset.get_items()
19
17
  if len(items) == nb_items:
20
18
  return dataset
21
19
  elif len(items) != 0:
22
- raise ValueError(f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it.")
20
+ raise ValueError(
21
+ f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it."
22
+ )
23
23
  elif len(items) == 0:
24
24
  # Load data from file and insert into the dataset
25
- json_content = (files('opik_optimizer') / 'data' / 'hotpot-500.json').read_text(encoding='utf-8')
25
+ json_content = (files("opik_optimizer") / "data" / "hotpot-500.json").read_text(
26
+ encoding="utf-8"
27
+ )
26
28
  all_data = json.loads(json_content)
27
29
  trainset = all_data[:nb_items]
28
30
 
@@ -33,9 +35,8 @@ def hotpot_300(
33
35
  dataset.insert(data)
34
36
  return dataset
35
37
 
36
- def hotpot_500(
37
- test_mode: bool = False
38
- ) -> opik.Dataset:
38
+
39
+ def hotpot_500(test_mode: bool = False) -> opik.Dataset:
39
40
  """
40
41
  Dataset containing the first 500 samples of the HotpotQA dataset.
41
42
  """
@@ -44,15 +45,19 @@ def hotpot_500(
44
45
 
45
46
  client = opik.Opik()
46
47
  dataset = client.get_or_create_dataset(dataset_name)
47
-
48
+
48
49
  items = dataset.get_items()
49
50
  if len(items) == nb_items:
50
51
  return dataset
51
52
  elif len(items) != 0:
52
- raise ValueError(f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it.")
53
+ raise ValueError(
54
+ f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it."
55
+ )
53
56
  elif len(items) == 0:
54
57
  # Load data from file and insert into the dataset
55
- json_content = (files('opik_optimizer') / 'data' / 'hotpot-500.json').read_text(encoding='utf-8')
58
+ json_content = (files("opik_optimizer") / "data" / "hotpot-500.json").read_text(
59
+ encoding="utf-8"
60
+ )
56
61
  all_data = json.loads(json_content)
57
62
  trainset = all_data[:nb_items]
58
63
 
@@ -62,7 +67,3 @@ def hotpot_500(
62
67
 
63
68
  dataset.insert(data)
64
69
  return dataset
65
-
66
-
67
-
68
-
@@ -1,27 +1,30 @@
1
1
  import opik
2
2
 
3
- def medhallu(
4
- test_mode: bool = False
5
- ) -> opik.Dataset:
3
+
4
+ def medhallu(test_mode: bool = False) -> opik.Dataset:
6
5
  dataset_name = "medhallu" if not test_mode else "medhallu_test"
7
6
  nb_items = 300 if not test_mode else 5
8
7
 
9
8
  client = opik.Opik()
10
9
  dataset = client.get_or_create_dataset(dataset_name)
11
-
10
+
12
11
  items = dataset.get_items()
13
12
  if len(items) == nb_items:
14
13
  return dataset
15
14
  elif len(items) != 0:
16
- raise ValueError(f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it.")
15
+ raise ValueError(
16
+ f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it."
17
+ )
17
18
  elif len(items) == 0:
18
19
  import datasets as ds
19
20
 
20
21
  # Load data from file and insert into the dataset
21
22
  download_config = ds.DownloadConfig(download_desc=False, disable_tqdm=True)
22
23
  ds.disable_progress_bar()
23
- hf_dataset = ds.load_dataset("UTAustin-AIHealth/MedHallu", "pqa_labeled", download_config=download_config)
24
-
24
+ hf_dataset = ds.load_dataset(
25
+ "UTAustin-AIHealth/MedHallu", "pqa_labeled", download_config=download_config
26
+ )
27
+
25
28
  data = [
26
29
  {
27
30
  "question": item["Question"],