opik-optimizer 1.0.5__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. opik_optimizer/__init__.py +2 -0
  2. opik_optimizer/_throttle.py +2 -1
  3. opik_optimizer/base_optimizer.py +28 -11
  4. opik_optimizer/colbert.py +236 -0
  5. opik_optimizer/data/context7_eval.jsonl +3 -0
  6. opik_optimizer/datasets/context7_eval.py +90 -0
  7. opik_optimizer/datasets/tiny_test.py +33 -34
  8. opik_optimizer/datasets/truthful_qa.py +2 -2
  9. opik_optimizer/evolutionary_optimizer/crossover_ops.py +194 -0
  10. opik_optimizer/evolutionary_optimizer/evaluation_ops.py +73 -0
  11. opik_optimizer/evolutionary_optimizer/evolutionary_optimizer.py +124 -941
  12. opik_optimizer/evolutionary_optimizer/helpers.py +10 -0
  13. opik_optimizer/evolutionary_optimizer/llm_support.py +134 -0
  14. opik_optimizer/evolutionary_optimizer/mutation_ops.py +292 -0
  15. opik_optimizer/evolutionary_optimizer/population_ops.py +223 -0
  16. opik_optimizer/evolutionary_optimizer/prompts.py +305 -0
  17. opik_optimizer/evolutionary_optimizer/reporting.py +16 -4
  18. opik_optimizer/evolutionary_optimizer/style_ops.py +86 -0
  19. opik_optimizer/few_shot_bayesian_optimizer/few_shot_bayesian_optimizer.py +26 -23
  20. opik_optimizer/few_shot_bayesian_optimizer/reporting.py +12 -5
  21. opik_optimizer/gepa_optimizer/__init__.py +3 -0
  22. opik_optimizer/gepa_optimizer/adapter.py +152 -0
  23. opik_optimizer/gepa_optimizer/gepa_optimizer.py +556 -0
  24. opik_optimizer/gepa_optimizer/reporting.py +181 -0
  25. opik_optimizer/logging_config.py +42 -7
  26. opik_optimizer/mcp_utils/__init__.py +22 -0
  27. opik_optimizer/mcp_utils/mcp.py +541 -0
  28. opik_optimizer/mcp_utils/mcp_second_pass.py +152 -0
  29. opik_optimizer/mcp_utils/mcp_simulator.py +116 -0
  30. opik_optimizer/mcp_utils/mcp_workflow.py +493 -0
  31. opik_optimizer/meta_prompt_optimizer/meta_prompt_optimizer.py +399 -69
  32. opik_optimizer/meta_prompt_optimizer/reporting.py +16 -2
  33. opik_optimizer/mipro_optimizer/_lm.py +20 -20
  34. opik_optimizer/mipro_optimizer/_mipro_optimizer_v2.py +51 -50
  35. opik_optimizer/mipro_optimizer/mipro_optimizer.py +33 -28
  36. opik_optimizer/mipro_optimizer/utils.py +2 -4
  37. opik_optimizer/optimizable_agent.py +18 -17
  38. opik_optimizer/optimization_config/chat_prompt.py +44 -23
  39. opik_optimizer/optimization_config/configs.py +3 -3
  40. opik_optimizer/optimization_config/mappers.py +9 -8
  41. opik_optimizer/optimization_result.py +21 -14
  42. opik_optimizer/reporting_utils.py +61 -10
  43. opik_optimizer/task_evaluator.py +9 -8
  44. opik_optimizer/utils/__init__.py +15 -0
  45. opik_optimizer/{utils.py → utils/core.py} +111 -26
  46. opik_optimizer/utils/dataset_utils.py +49 -0
  47. opik_optimizer/utils/prompt_segments.py +186 -0
  48. {opik_optimizer-1.0.5.dist-info → opik_optimizer-1.1.0.dist-info}/METADATA +93 -16
  49. opik_optimizer-1.1.0.dist-info/RECORD +73 -0
  50. opik_optimizer-1.1.0.dist-info/licenses/LICENSE +203 -0
  51. opik_optimizer-1.0.5.dist-info/RECORD +0 -50
  52. opik_optimizer-1.0.5.dist-info/licenses/LICENSE +0 -21
  53. {opik_optimizer-1.0.5.dist-info → opik_optimizer-1.1.0.dist-info}/WHEEL +0 -0
  54. {opik_optimizer-1.0.5.dist-info → opik_optimizer-1.1.0.dist-info}/top_level.txt +0 -0
opik_optimizer/__init__.py

@@ -12,6 +12,7 @@ from .optimizable_agent import OptimizableAgent
 from .optimization_config.chat_prompt import ChatPrompt
 from .base_optimizer import BaseOptimizer
 from .few_shot_bayesian_optimizer import FewShotBayesianOptimizer
+from .gepa_optimizer import GepaOptimizer
 from .logging_config import setup_logging
 from .meta_prompt_optimizer import MetaPromptOptimizer
 from .optimization_config.configs import TaskConfig
@@ -28,6 +29,7 @@ __all__ = [
     "BaseOptimizer",
     "ChatPrompt",
     "FewShotBayesianOptimizer",
+    "GepaOptimizer",
     "MetaPromptOptimizer",
     "EvolutionaryOptimizer",
     "OptimizationResult",
opik_optimizer/_throttle.py

@@ -3,7 +3,8 @@ import pyrate_limiter
 import time
 import opik.config

-from typing import Callable, Any
+from typing import Any
+from collections.abc import Callable


 class RateLimiter:
opik_optimizer/base_optimizer.py

@@ -1,4 +1,5 @@
-from typing import Any, Callable, Dict, List, Optional, Type
+from typing import Any
+from collections.abc import Callable

 import logging
 import time
@@ -59,7 +60,7 @@ class BaseOptimizer:
         self.reasoning_model = model
         self.model_kwargs = model_kwargs
         self.verbose = verbose
-        self._history: List[OptimizationRound] = []
+        self._history: list[OptimizationRound] = []
         self.experiment_config = None
         self.llm_call_counter = 0

@@ -72,7 +73,7 @@ class BaseOptimizer:
         prompt: "chat_prompt.ChatPrompt",
         dataset: Dataset,
         metric: Callable,
-        experiment_config: Optional[Dict] = None,
+        experiment_config: dict | None = None,
         **kwargs: Any,
     ) -> optimization_result.OptimizationResult:
         """
@@ -90,7 +91,23 @@
         """
         pass

-    def get_history(self) -> List[OptimizationRound]:
+    def optimize_mcp(
+        self,
+        prompt: "chat_prompt.ChatPrompt",
+        dataset: Dataset,
+        metric: Callable,
+        *,
+        tool_name: str,
+        second_pass: Any,
+        experiment_config: dict | None = None,
+        **kwargs: Any,
+    ) -> optimization_result.OptimizationResult:
+        """Optimize prompts that rely on MCP tooling."""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not implement optimize_mcp yet."
+        )
+
+    def get_history(self) -> list[OptimizationRound]:
         """
         Get the optimization history.

@@ -133,11 +150,11 @@
         metric: Callable,
         n_threads: int,
         verbose: int = 1,
-        dataset_item_ids: Optional[List[str]] = None,
-        experiment_config: Optional[Dict] = None,
-        n_samples: Optional[int] = None,
-        seed: Optional[int] = None,
-        agent_class: Optional[Type[OptimizableAgent]] = None,
+        dataset_item_ids: list[str] | None = None,
+        experiment_config: dict | None = None,
+        n_samples: int | None = None,
+        seed: int | None = None,
+        agent_class: type[OptimizableAgent] | None = None,
     ) -> float:
         random.seed(seed)

@@ -146,7 +163,7 @@
         if prompt.model_kwargs is None:
             prompt.model_kwargs = self.model_kwargs

-        self.agent_class: Type[OptimizableAgent]
+        self.agent_class: type[OptimizableAgent]

         if agent_class is None:
             self.agent_class = create_litellm_agent_class(prompt)
@@ -155,7 +172,7 @@

         agent = self.agent_class(prompt)

-        def llm_task(dataset_item: Dict[str, Any]) -> Dict[str, str]:
+        def llm_task(dataset_item: dict[str, Any]) -> dict[str, str]:
             messages = prompt.get_messages(dataset_item)
             raw_model_output = agent.invoke(messages)
             cleaned_model_output = raw_model_output.strip()
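
The new optimize_mcp hook gives every optimizer a common entry point for MCP tool prompts while defaulting to a clear failure. A minimal sketch of the resulting contract, assuming BaseOptimizer subclasses can be constructed with just a model keyword (not confirmed by this diff):

    # Hypothetical sketch: only optimize_mcp's signature is taken from the diff above;
    # the constructor kwarg and placeholder arguments are assumptions.
    from opik_optimizer import BaseOptimizer

    class MyOptimizer(BaseOptimizer):
        pass  # no optimize_mcp override, so the base fallback applies

    optimizer = MyOptimizer(model="openai/gpt-4o-mini")  # assumed kwarg
    try:
        optimizer.optimize_mcp(
            prompt=None,  # placeholders; real calls pass a ChatPrompt, Dataset, metric
            dataset=None,
            metric=None,
            tool_name="get-library-docs",
            second_pass=None,
        )
    except NotImplementedError as exc:
        print(exc)  # "MyOptimizer does not implement optimize_mcp yet."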
opik_optimizer/colbert.py

@@ -0,0 +1,236 @@
+"""
+Minimal ColBERTv2 implementation extracted from dspy (MIT license).
+
+This module provides a lightweight implementation of ColBERTv2 search functionality
+without requiring the full dspy dependency.
+"""
+
+import copy
+import time
+from typing import Any
+import requests  # type: ignore[import-untyped]
+from requests.adapters import HTTPAdapter  # type: ignore[import-untyped]
+from urllib3.util.retry import Retry
+
+
+def _create_session_with_retries(max_retries: int = 4) -> requests.Session:
+    """
+    Create a requests session with retry configuration.
+
+    Args:
+        max_retries: Maximum number of retry attempts
+
+    Returns:
+        Configured requests session
+    """
+    session = requests.Session()
+
+    retry_strategy = Retry(
+        total=max_retries,
+        backoff_factor=1,  # Wait 1, 2, 4, 8 seconds between retries
+        status_forcelist=[429, 500, 502, 503, 504],  # HTTP status codes to retry on
+        allowed_methods=["HEAD", "GET", "POST", "PUT", "DELETE", "OPTIONS", "TRACE"],
+    )
+
+    adapter = HTTPAdapter(max_retries=retry_strategy)
+    session.mount("http://", adapter)
+    session.mount("https://", adapter)
+
+    return session
+
+
+class dotdict(dict):
+    """Dictionary with attribute access (extracted from dspy)."""
+
+    def __getattr__(self, key: str) -> Any:
+        if key.startswith("__") and key.endswith("__"):
+            return super().__getattribute__(key)
+        try:
+            return self[key]
+        except KeyError:
+            raise AttributeError(
+                f"'{type(self).__name__}' object has no attribute '{key}'"
+            )
+
+    def __setattr__(self, key: str, value: Any) -> None:
+        if key.startswith("__") and key.endswith("__"):
+            super().__setattr__(key, value)
+        else:
+            self[key] = value
+
+    def __delattr__(self, key: str) -> None:
+        if key.startswith("__") and key.endswith("__"):
+            super().__delattr__(key)
+        else:
+            del self[key]
+
+    def __deepcopy__(self, memo: dict[Any, Any]) -> "dotdict":
+        # Use the default dict copying method to avoid infinite recursion.
+        return dotdict(copy.deepcopy(dict(self), memo))
+
+
+def colbertv2_get_request(
+    url: str, query: str, k: int, max_retries: int = 4
+) -> list[dict[str, Any]]:
+    """
+    Make a GET request to ColBERTv2 server with retry logic.
+
+    Args:
+        url: The ColBERTv2 server URL
+        query: The search query
+        k: Number of results to return
+        max_retries: Maximum number of retry attempts
+
+    Returns:
+        List of search results
+    """
+    assert k <= 100, (
+        "Only k <= 100 is supported for the hosted ColBERTv2 server at the moment."
+    )
+
+    session = _create_session_with_retries(max_retries)
+    payload: dict[str, str | int] = {"query": query, "k": k}
+
+    # Application-level retry for server connection errors
+    for attempt in range(max_retries):
+        try:
+            res = session.get(url, params=payload, timeout=5)
+            response_data = res.json()
+
+            # Check for application-level errors (server connection issues, etc.)
+            if "error" in response_data and response_data["error"]:
+                error_msg = response_data.get("message", "Unknown error")
+                # If it's a connection error, retry; otherwise, fail immediately
+                if (
+                    "Cannot connect to host" in error_msg
+                    or "Connection refused" in error_msg
+                ):
+                    if attempt == max_retries - 1:
+                        raise Exception(f"ColBERTv2 server error: {error_msg}")
+                    time.sleep(1)  # Wait 1 second before retrying
+                    continue
+                else:
+                    raise Exception(f"ColBERTv2 server error: {error_msg}")
+
+            if "topk" not in response_data:
+                raise Exception(
+                    f"Unexpected response format from ColBERTv2 server: {list(response_data.keys())}"
+                )
+
+            topk = response_data["topk"][:k]
+            topk = [{**d, "long_text": d["text"]} for d in topk]
+            return topk[:k]
+
+        except requests.RequestException as e:
+            if attempt == max_retries - 1:
+                raise Exception(f"ColBERTv2 request failed: {str(e)}")
+            time.sleep(1)  # Wait 1 second before retrying
+
+    # This should never be reached, but mypy requires a return statement
+    raise Exception("Unexpected end of retry loop")
+
+
+def colbertv2_post_request(
+    url: str, query: str, k: int, max_retries: int = 4
+) -> list[dict[str, Any]]:
+    """
+    Make a POST request to ColBERTv2 server with retry logic.
+
+    Args:
+        url: The ColBERTv2 server URL
+        query: The search query
+        k: Number of results to return
+        max_retries: Maximum number of retry attempts
+
+    Returns:
+        List of search results
+    """
+    session = _create_session_with_retries(max_retries)
+    headers = {"Content-Type": "application/json; charset=utf-8"}
+    payload = {"query": query, "k": k}
+
+    # Application-level retry for server connection errors
+    for attempt in range(max_retries):
+        try:
+            res = session.post(url, json=payload, headers=headers, timeout=5)
+            response_data = res.json()
+
+            # Check for application-level errors (server connection issues, etc.)
+            if "error" in response_data and response_data["error"]:
+                error_msg = response_data.get("message", "Unknown error")
+                # If it's a connection error, retry; otherwise, fail immediately
+                if (
+                    "Cannot connect to host" in error_msg
+                    or "Connection refused" in error_msg
+                ):
+                    if attempt == max_retries - 1:
+                        raise Exception(f"ColBERTv2 server error: {error_msg}")
+                    time.sleep(1)  # Wait 1 second before retrying
+                    continue
+                else:
+                    raise Exception(f"ColBERTv2 server error: {error_msg}")
+
+            if "topk" not in response_data:
+                raise Exception(
+                    f"Unexpected response format from ColBERTv2 server: {list(response_data.keys())}"
+                )
+
+            return response_data["topk"][:k]
+
+        except requests.RequestException as e:
+            if attempt == max_retries - 1:
+                raise Exception(f"ColBERTv2 request failed: {str(e)}")
+            time.sleep(1)  # Wait 1 second before retrying
+
+    # This should never be reached, but mypy requires a return statement
+    raise Exception("Unexpected end of retry loop")
+
+
+class ColBERTv2:
+    """Wrapper for the ColBERTv2 Retrieval (extracted from dspy)."""
+
+    def __init__(
+        self,
+        url: str = "http://0.0.0.0",
+        port: str | int | None = None,
+        post_requests: bool = False,
+    ):
+        """
+        Initialize ColBERTv2 client.
+
+        Args:
+            url: Base URL for the ColBERTv2 server
+            port: Optional port number
+            post_requests: Whether to use POST requests instead of GET
+        """
+        self.post_requests = post_requests
+        self.url = f"{url}:{port}" if port else url
+
+    def __call__(
+        self,
+        query: str,
+        k: int = 10,
+        simplify: bool = False,
+        max_retries: int = 4,
+    ) -> list[str] | list[dotdict]:
+        """
+        Search using ColBERTv2.
+
+        Args:
+            query: The search query
+            k: Number of results to return
+            simplify: If True, return only text strings; if False, return dotdict objects
+            max_retries: Maximum number of retry attempts
+
+        Returns:
+            List of search results (either strings or dotdict objects)
+        """
+        if self.post_requests:
+            topk_results = colbertv2_post_request(self.url, query, k, max_retries)
+        else:
+            topk_results = colbertv2_get_request(self.url, query, k, max_retries)
+
+        if simplify:
+            return [psg["long_text"] for psg in topk_results]
+
+        return [dotdict(psg) for psg in topk_results]
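
A usage sketch for the vendored client; the host and port below are placeholders for wherever a ColBERTv2 server is running:

    from opik_optimizer.colbert import ColBERTv2

    retriever = ColBERTv2(url="http://localhost", port=8893)  # hypothetical endpoint
    passages = retriever("what is a colbert index?", k=3, simplify=True)
    for text in passages:
        print(text)  # with simplify=True each result is a plain string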
opik_optimizer/data/context7_eval.jsonl

@@ -0,0 +1,3 @@
+{"id": "ctx-001", "user_query": "Using the Context7 library ID /vercel/next.js, how can I route users down different UI flows with the App Router?", "expected_tool": "get-library-docs", "arguments": {"context7CompatibleLibraryID": "/vercel/next.js", "topic": "routing", "tokens": 1500}, "reference_answer": "The App Router handles conditional experiences with parallel routes. Create directories that start with @ to declare each slot, provide a default.tsx so the route still renders when a branch is missing, and decide which slot to render inside your layout based on the user's state. This lets you show different UI branches without blocking navigation."}
+{"id": "ctx-002", "user_query": "With library ID /supabase/supabase, what do the docs recommend for keeping edge functions secure?", "expected_tool": "get-library-docs", "arguments": {"context7CompatibleLibraryID": "/supabase/supabase", "topic": "security", "tokens": 1200}, "reference_answer": "Supabase recommends enabling Row Level Security (RLS) on your Postgres tables so edge functions can only access data allowed by fine-grained policies. Run `alter table ... enable row level security;` (for example on the `todos` table) to enforce those policies and prevent unauthorized access."}
+{"id": "ctx-003", "user_query": "Given /mongodb/docs, remind me what makes up the basic aggregation pipeline.", "expected_tool": "get-library-docs", "arguments": {"context7CompatibleLibraryID": "/mongodb/docs", "topic": "aggregation", "tokens": 1000}, "reference_answer": "An aggregation pipeline runs ordered stages such as $match, $group, $project, $sort, and $limit. Each stage accepts the stream of documents from the previous stage so you can filter, reshape, and summarize the data step by step."}
opik_optimizer/datasets/context7_eval.py

@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+import json
+from dataclasses import dataclass
+from importlib import resources
+from typing import Any, Union
+
+try:  # pragma: no cover - optional dependency
+    import opik  # type: ignore
+except ImportError:  # pragma: no cover - fallback for tests
+    opik = None
+
+from opik_optimizer.utils.dataset_utils import attach_uuids, dataset_suffix
+
+OpikDataset = Any
+
+DATA_PACKAGE = "opik_optimizer.data"
+DATA_FILENAME = "context7_eval.jsonl"
+DATASET_NAME = "context7_eval"
+
+
+def _load_examples() -> list[dict[str, Any]]:
+    text = (
+        resources.files(DATA_PACKAGE)
+        .joinpath(DATA_FILENAME)
+        .read_text(encoding="utf-8")
+    )
+    return [json.loads(line) for line in text.splitlines() if line.strip()]
+
+
+def _dataset_name(test_mode: bool) -> str:
+    suffix = dataset_suffix(DATA_PACKAGE, DATA_FILENAME)
+    return f"{DATASET_NAME}_{suffix}{'_test' if test_mode else ''}"
+
+
+@dataclass
+class _ListDataset:
+    name: str
+    _items: list[dict[str, Any]]
+
+    def __post_init__(self) -> None:
+        for idx, item in enumerate(self._items):
+            item.setdefault("id", f"{self.name}-{idx}")
+        self.id = self.name
+
+    def copy(self) -> _ListDataset:
+        return _ListDataset(self.name, [dict(item) for item in self._items])
+
+    def get_items(self, nb_samples: int | None = None) -> list[dict[str, Any]]:
+        if nb_samples is None:
+            return [dict(item) for item in self._items]
+        return [dict(item) for item in self._items[:nb_samples]]
+
+
+DatasetResult = Union["_ListDataset", OpikDataset]
+
+
+def load_context7_dataset(test_mode: bool = False) -> DatasetResult:
+    """Return the context7 synthetic dataset as an Opik dataset when available."""
+
+    examples = _load_examples()
+    dataset_name = _dataset_name(test_mode)
+
+    if opik is None:
+        return _ListDataset(dataset_name, examples)
+
+    try:
+        client = opik.Opik()
+        dataset: OpikDataset = client.get_or_create_dataset(dataset_name)
+        items = dataset.get_items()
+        expected_len = len(examples) if not test_mode else min(len(examples), 2)
+
+        if len(items) == expected_len:
+            return dataset
+        if len(items) != 0:  # pragma: no cover - defensive path
+            raise ValueError(
+                f"Dataset {dataset_name} already exists with {len(items)} items. Delete it to regenerate."
+            )
+
+        if test_mode:
+            dataset.insert(attach_uuids(examples[:expected_len]))
+        else:
+            dataset.insert(attach_uuids(examples))
+        return dataset
+    except Exception:
+        # If Opik client fails (e.g., no API key configured), fall back to local dataset
+        return _ListDataset(dataset_name, examples)
+
+
+__all__ = ["load_context7_dataset"]
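
A usage sketch for the loader; when no Opik client is available it falls back to the in-memory _ListDataset, so get_items works either way:

    from opik_optimizer.datasets.context7_eval import load_context7_dataset

    dataset = load_context7_dataset(test_mode=True)  # at most 2 items in test mode
    for item in dataset.get_items(nb_samples=1):
        print(item["id"], item["expected_tool"])  # e.g. ctx-001 get-library-docs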
opik_optimizer/datasets/tiny_test.py

@@ -1,42 +1,12 @@
 import opik

-TINY_TEST_ITEMS = [
-    {
-        "text": "What is the capital of France?",
-        "label": "Paris",
-        "metadata": {"context": "France is a country in Europe. Its capital is Paris."},
-    },
-    {
-        "text": "Who wrote Romeo and Juliet?",
-        "label": "William Shakespeare",
-        "metadata": {
-            "context": "Romeo and Juliet is a famous play written by William Shakespeare."
-        },
-    },
-    {
-        "text": "What is 2 + 2?",
-        "label": "4",
-        "metadata": {"context": "Basic arithmetic: 2 + 2 equals 4."},
-    },
-    {
-        "text": "What is the largest planet in our solar system?",
-        "label": "Jupiter",
-        "metadata": {"context": "Jupiter is the largest planet in our solar system."},
-    },
-    {
-        "text": "Who painted the Mona Lisa?",
-        "label": "Leonardo da Vinci",
-        "metadata": {"context": "The Mona Lisa was painted by Leonardo da Vinci."},
-    },
-]
-

 def tiny_test(test_mode: bool = False) -> opik.Dataset:
     """
-    Dataset containing the first 5 samples of the HotpotQA dataset.
+    Tiny QA benchmark (core_en subset from vincentkoc/tiny_qa_benchmark_pp).
     """
     dataset_name = "tiny_test" if not test_mode else "tiny_test_test"
-    nb_items = len(TINY_TEST_ITEMS)
+    nb_items = 5  # keep tiny dataset size consistent with tests/docs

     client = opik.Opik()
     dataset = client.get_or_create_dataset(dataset_name)
@@ -49,5 +19,34 @@ def tiny_test(test_mode: bool = False) -> opik.Dataset:
             f"Dataset {dataset_name} contains {len(items)} items, expected {nb_items}. We recommend deleting the dataset and re-creating it."
         )
     elif len(items) == 0:
-        dataset.insert(TINY_TEST_ITEMS)
-    return dataset
+        import datasets as ds
+
+        download_config = ds.DownloadConfig(download_desc=False, disable_tqdm=True)
+        ds.disable_progress_bar()
+        try:
+            # Load only the core_en subset JSONL from the repo
+            # Use the generic JSON loader with streaming for efficiency
+            hf_dataset = ds.load_dataset(
+                "json",
+                data_files="hf://datasets/vincentkoc/tiny_qa_benchmark_pp/data/core_en/core_en.jsonl",
+                streaming=True,
+                download_config=download_config,
+            )["train"]
+
+            data = []
+            for i, item in enumerate(hf_dataset):
+                if i >= nb_items:
+                    break
+                data.append(
+                    {
+                        "text": item.get("text", ""),
+                        "label": item.get("label", ""),
+                        # Preserve original tiny_test shape with metadata.context
+                        "metadata": {"context": item.get("context", "")},
+                    }
+                )
+
+            dataset.insert(data)
+            return dataset
+        finally:
+            ds.enable_progress_bar()
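
Because the item shape (text, label, metadata.context) is preserved, existing callers of tiny_test keep working; a quick check, assuming a configured Opik client:

    from opik_optimizer.datasets.tiny_test import tiny_test

    dataset = tiny_test(test_mode=True)
    for item in dataset.get_items():
        print(item["text"], "->", item["label"])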
opik_optimizer/datasets/truthful_qa.py

@@ -1,5 +1,5 @@
 import opik
-from typing import Any, Dict, List
+from typing import Any


 def truthful_qa(test_mode: bool = False) -> opik.Dataset:
@@ -33,7 +33,7 @@ def truthful_qa(test_mode: bool = False) -> opik.Dataset:
         "truthful_qa", "multiple_choice", download_config=download_config
     )

-    data: List[Dict[str, Any]] = []
+    data: list[dict[str, Any]] = []
     for gen_item, mc_item in zip(
         gen_dataset["validation"], mc_dataset["validation"]
     ):