opik-optimizer 1.0.6__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opik_optimizer/__init__.py +4 -0
- opik_optimizer/_throttle.py +2 -1
- opik_optimizer/base_optimizer.py +402 -28
- opik_optimizer/data/context7_eval.jsonl +3 -0
- opik_optimizer/datasets/context7_eval.py +90 -0
- opik_optimizer/datasets/tiny_test.py +33 -34
- opik_optimizer/datasets/truthful_qa.py +2 -2
- opik_optimizer/evolutionary_optimizer/crossover_ops.py +194 -0
- opik_optimizer/evolutionary_optimizer/evaluation_ops.py +136 -0
- opik_optimizer/evolutionary_optimizer/evolutionary_optimizer.py +289 -966
- opik_optimizer/evolutionary_optimizer/helpers.py +10 -0
- opik_optimizer/evolutionary_optimizer/llm_support.py +136 -0
- opik_optimizer/evolutionary_optimizer/mcp.py +249 -0
- opik_optimizer/evolutionary_optimizer/mutation_ops.py +306 -0
- opik_optimizer/evolutionary_optimizer/population_ops.py +228 -0
- opik_optimizer/evolutionary_optimizer/prompts.py +352 -0
- opik_optimizer/evolutionary_optimizer/reporting.py +28 -4
- opik_optimizer/evolutionary_optimizer/style_ops.py +86 -0
- opik_optimizer/few_shot_bayesian_optimizer/few_shot_bayesian_optimizer.py +90 -81
- opik_optimizer/few_shot_bayesian_optimizer/reporting.py +12 -5
- opik_optimizer/gepa_optimizer/__init__.py +3 -0
- opik_optimizer/gepa_optimizer/adapter.py +154 -0
- opik_optimizer/gepa_optimizer/gepa_optimizer.py +653 -0
- opik_optimizer/gepa_optimizer/reporting.py +181 -0
- opik_optimizer/logging_config.py +42 -7
- opik_optimizer/mcp_utils/__init__.py +22 -0
- opik_optimizer/mcp_utils/mcp.py +541 -0
- opik_optimizer/mcp_utils/mcp_second_pass.py +152 -0
- opik_optimizer/mcp_utils/mcp_simulator.py +116 -0
- opik_optimizer/mcp_utils/mcp_workflow.py +547 -0
- opik_optimizer/meta_prompt_optimizer/meta_prompt_optimizer.py +470 -134
- opik_optimizer/meta_prompt_optimizer/reporting.py +16 -2
- opik_optimizer/mipro_optimizer/_lm.py +30 -23
- opik_optimizer/mipro_optimizer/_mipro_optimizer_v2.py +52 -51
- opik_optimizer/mipro_optimizer/mipro_optimizer.py +126 -46
- opik_optimizer/mipro_optimizer/utils.py +2 -4
- opik_optimizer/optimizable_agent.py +21 -16
- opik_optimizer/optimization_config/chat_prompt.py +44 -23
- opik_optimizer/optimization_config/configs.py +3 -3
- opik_optimizer/optimization_config/mappers.py +9 -8
- opik_optimizer/optimization_result.py +22 -14
- opik_optimizer/reporting_utils.py +61 -10
- opik_optimizer/task_evaluator.py +9 -8
- opik_optimizer/utils/__init__.py +15 -0
- opik_optimizer/utils/colbert.py +236 -0
- opik_optimizer/{utils.py → utils/core.py} +160 -33
- opik_optimizer/utils/dataset_utils.py +49 -0
- opik_optimizer/utils/prompt_segments.py +186 -0
- opik_optimizer-2.0.0.dist-info/METADATA +345 -0
- opik_optimizer-2.0.0.dist-info/RECORD +74 -0
- opik_optimizer-2.0.0.dist-info/licenses/LICENSE +203 -0
- opik_optimizer-1.0.6.dist-info/METADATA +0 -181
- opik_optimizer-1.0.6.dist-info/RECORD +0 -50
- opik_optimizer-1.0.6.dist-info/licenses/LICENSE +0 -21
- {opik_optimizer-1.0.6.dist-info → opik_optimizer-2.0.0.dist-info}/WHEEL +0 -0
- {opik_optimizer-1.0.6.dist-info → opik_optimizer-2.0.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
"""Module containing the OptimizationResult class."""
|
2
2
|
|
3
|
-
from typing import Any
|
3
|
+
from typing import Any
|
4
4
|
|
5
5
|
import pydantic
|
6
6
|
import rich
|
@@ -13,25 +13,26 @@ class OptimizationResult(pydantic.BaseModel):
|
|
13
13
|
|
14
14
|
optimizer: str = "Optimizer"
|
15
15
|
|
16
|
-
prompt:
|
16
|
+
prompt: list[dict[str, str]]
|
17
17
|
score: float
|
18
18
|
metric_name: str
|
19
19
|
|
20
|
-
optimization_id:
|
21
|
-
dataset_id:
|
20
|
+
optimization_id: str | None = None
|
21
|
+
dataset_id: str | None = None
|
22
22
|
|
23
23
|
# Initial score
|
24
|
-
initial_prompt:
|
25
|
-
initial_score:
|
24
|
+
initial_prompt: list[dict[str, str]] | None = None
|
25
|
+
initial_score: float | None = None
|
26
26
|
|
27
|
-
details:
|
28
|
-
history:
|
29
|
-
llm_calls:
|
27
|
+
details: dict[str, Any] = pydantic.Field(default_factory=dict)
|
28
|
+
history: list[dict[str, Any]] = []
|
29
|
+
llm_calls: int | None = None
|
30
|
+
tool_calls: int | None = None
|
30
31
|
|
31
32
|
# MIPRO specific
|
32
|
-
demonstrations:
|
33
|
-
mipro_prompt:
|
34
|
-
tool_prompts:
|
33
|
+
demonstrations: list[dict[str, Any]] | None = None
|
34
|
+
mipro_prompt: str | None = None
|
35
|
+
tool_prompts: dict[str, str] | None = None
|
35
36
|
|
36
37
|
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
|
37
38
|
|
@@ -40,7 +41,7 @@ class OptimizationResult(pydantic.BaseModel):
|
|
40
41
|
optimization_id=self.optimization_id, dataset_id=self.dataset_id
|
41
42
|
)
|
42
43
|
|
43
|
-
def model_dump(self, *kargs: Any, **kwargs: Any) ->
|
44
|
+
def model_dump(self, *kargs: Any, **kwargs: Any) -> dict[str, Any]:
|
44
45
|
return super().model_dump(*kargs, **kwargs)
|
45
46
|
|
46
47
|
def _calculate_improvement_str(self) -> str:
|
@@ -205,4 +206,11 @@ class OptimizationResult(pydantic.BaseModel):
|
|
205
206
|
"""
|
206
207
|
console = get_console()
|
207
208
|
console.print(self)
|
208
|
-
|
209
|
+
# Gracefully handle cases where optimization tracking isn't available
|
210
|
+
if self.dataset_id and self.optimization_id:
|
211
|
+
try:
|
212
|
+
print("Optimization run link:", self.get_run_link())
|
213
|
+
except Exception:
|
214
|
+
print("Optimization run link: No optimization run link available")
|
215
|
+
else:
|
216
|
+
print("Optimization run link: No optimization run link available")
|
@@ -1,6 +1,7 @@
|
|
1
|
+
import json
|
1
2
|
import logging
|
2
3
|
from contextlib import contextmanager
|
3
|
-
from typing import Any
|
4
|
+
from typing import Any
|
4
5
|
|
5
6
|
from rich import box
|
6
7
|
from rich.console import Console, Group
|
@@ -20,7 +21,7 @@ def get_console(*args: Any, **kwargs: Any) -> Console:
|
|
20
21
|
|
21
22
|
|
22
23
|
@contextmanager
|
23
|
-
def convert_tqdm_to_rich(description:
|
24
|
+
def convert_tqdm_to_rich(description: str | None = None, verbose: int = 1) -> Any:
|
24
25
|
"""Context manager to convert tqdm to rich."""
|
25
26
|
import opik.evaluation.engine.evaluation_tasks_executor
|
26
27
|
|
@@ -66,7 +67,7 @@ def suppress_opik_logs() -> Any:
|
|
66
67
|
opik_logger.setLevel(original_level)
|
67
68
|
|
68
69
|
|
69
|
-
def display_messages(messages:
|
70
|
+
def display_messages(messages: list[dict[str, str]], prefix: str = "") -> None:
|
70
71
|
for i, msg in enumerate(messages):
|
71
72
|
panel = Panel(
|
72
73
|
Text(msg.get("content", ""), overflow="fold"),
|
@@ -90,11 +91,53 @@ def display_messages(messages: List[Dict[str, str]], prefix: str = "") -> None:
|
|
90
91
|
console.print(Text(prefix) + Text.from_ansi(line))
|
91
92
|
|
92
93
|
|
94
|
+
def _format_tool_panel(tool: dict[str, Any]) -> Panel:
    """Build a Rich panel summarising a single tool definition.

    The tool's metadata is read from its ``"function"`` sub-dictionary. The
    name falls back to the top-level ``"name"`` key and finally to
    ``"unknown_tool"``.

    Args:
        tool: Tool definition dictionary.

    Returns:
        A cyan-bordered panel titled ``tool: <name>`` whose body holds the
        description and, when present, the JSON-formatted parameter schema.
        The body is ``(no metadata)`` when neither is available.
    """
    spec = tool.get("function", {})
    tool_name = spec.get("name") or tool.get("name", "unknown_tool")
    summary = spec.get("description", "")
    schema = spec.get("parameters", {})

    sections: list[str] = [summary] if summary else []
    if schema:
        # Keys are sorted so the same schema always renders identically.
        sections.append(
            "\nSchema:\n" + json.dumps(schema, indent=2, sort_keys=True)
        )

    body_text = "\n".join(sections) if sections else "(no metadata)"
    return Panel(
        Text(body_text, overflow="fold"),
        title=f"tool: {tool_name}",
        title_align="left",
        border_style="cyan",
        width=PANEL_WIDTH,
        padding=(1, 2),
    )
|
118
|
+
|
119
|
+
|
120
|
+
def _display_tools(tools: list[dict[str, Any]] | None) -> None:
    """Print a "Tools registered" section with one panel per tool.

    Does nothing when ``tools`` is ``None`` or empty.

    Args:
        tools: Tool definition dictionaries to display, or ``None``.
    """
    if not tools:
        return

    out = get_console()
    out.print(Text("\nTools registered:\n", style="bold"))

    for tool_spec in tools:
        # Render the panel into a capture buffer first, then replay the
        # captured output one line at a time.
        with out.capture() as captured:
            out.print(_format_tool_panel(tool_spec))
        for rendered_line in captured.get().splitlines():
            out.print(Text.from_ansi(rendered_line))
        out.print("")
|
134
|
+
|
135
|
+
|
93
136
|
def get_link_text(
|
94
137
|
pre_text: str,
|
95
138
|
link_text: str,
|
96
|
-
optimization_id:
|
97
|
-
dataset_id:
|
139
|
+
optimization_id: str | None = None,
|
140
|
+
dataset_id: str | None = None,
|
98
141
|
) -> Text:
|
99
142
|
if optimization_id is not None and dataset_id is not None:
|
100
143
|
optimization_url = get_optimization_run_url_by_id(
|
@@ -112,8 +155,8 @@ def get_link_text(
|
|
112
155
|
|
113
156
|
def display_header(
|
114
157
|
algorithm: str,
|
115
|
-
optimization_id:
|
116
|
-
dataset_id:
|
158
|
+
optimization_id: str | None = None,
|
159
|
+
dataset_id: str | None = None,
|
117
160
|
verbose: int = 1,
|
118
161
|
) -> None:
|
119
162
|
if verbose < 1:
|
@@ -140,8 +183,9 @@ def display_header(
|
|
140
183
|
def display_result(
|
141
184
|
initial_score: float,
|
142
185
|
best_score: float,
|
143
|
-
best_prompt:
|
186
|
+
best_prompt: list[dict[str, str]],
|
144
187
|
verbose: int = 1,
|
188
|
+
tools: list[dict[str, Any]] | None = None,
|
145
189
|
) -> None:
|
146
190
|
if verbose < 1:
|
147
191
|
return
|
@@ -149,7 +193,7 @@ def display_result(
|
|
149
193
|
console = get_console()
|
150
194
|
console.print(Text("\n> Optimization complete\n"))
|
151
195
|
|
152
|
-
content:
|
196
|
+
content: Text | Panel = []
|
153
197
|
|
154
198
|
if best_score > initial_score:
|
155
199
|
if initial_score == 0:
|
@@ -199,9 +243,15 @@ def display_result(
|
|
199
243
|
)
|
200
244
|
)
|
201
245
|
|
246
|
+
if tools:
|
247
|
+
_display_tools(tools)
|
248
|
+
|
202
249
|
|
203
250
|
def display_configuration(
|
204
|
-
messages:
|
251
|
+
messages: list[dict[str, str]],
|
252
|
+
optimizer_config: dict[str, Any],
|
253
|
+
verbose: int = 1,
|
254
|
+
tools: list[dict[str, Any]] | None = None,
|
205
255
|
) -> None:
|
206
256
|
"""Displays the LLM messages and optimizer configuration using Rich panels."""
|
207
257
|
|
@@ -213,6 +263,7 @@ def display_configuration(
|
|
213
263
|
console.print(Text("> Let's optimize the prompt:\n"))
|
214
264
|
|
215
265
|
display_messages(messages)
|
266
|
+
_display_tools(tools)
|
216
267
|
|
217
268
|
# Panel for configuration
|
218
269
|
console.print(
|
opik_optimizer/task_evaluator.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
import logging
|
2
|
-
from typing import Any
|
2
|
+
from typing import Any
|
3
|
+
from collections.abc import Callable
|
3
4
|
|
4
5
|
import opik
|
5
6
|
from opik.evaluation import evaluator as opik_evaluator
|
@@ -38,14 +39,14 @@ def _create_metric_class(metric: Callable) -> base_metric.BaseMetric:
|
|
38
39
|
|
39
40
|
def evaluate(
|
40
41
|
dataset: opik.Dataset,
|
41
|
-
evaluated_task: Callable[[
|
42
|
+
evaluated_task: Callable[[dict[str, Any]], dict[str, Any]],
|
42
43
|
metric: Callable,
|
43
44
|
num_threads: int,
|
44
|
-
optimization_id:
|
45
|
-
dataset_item_ids:
|
46
|
-
project_name:
|
47
|
-
n_samples:
|
48
|
-
experiment_config:
|
45
|
+
optimization_id: str | None = None,
|
46
|
+
dataset_item_ids: list[str] | None = None,
|
47
|
+
project_name: str | None = None,
|
48
|
+
n_samples: int | None = None,
|
49
|
+
experiment_config: dict[str, Any] | None = None,
|
49
50
|
verbose: int = 1,
|
50
51
|
) -> float:
|
51
52
|
"""
|
@@ -107,7 +108,7 @@ def evaluate(
|
|
107
108
|
return 0.0
|
108
109
|
|
109
110
|
# We may allow score aggregation customization.
|
110
|
-
score_results:
|
111
|
+
score_results: list[score_result.ScoreResult] = [
|
111
112
|
test_result.score_results[0] for test_result in result.test_results
|
112
113
|
]
|
113
114
|
if not score_results:
|
@@ -0,0 +1,15 @@
|
|
1
|
+
"""Utility helpers exposed as part of the opik_optimizer package."""
|
2
|
+
|
3
|
+
from .core import * # noqa: F401,F403
|
4
|
+
from .dataset_utils import * # noqa: F401,F403
|
5
|
+
from .prompt_segments import * # noqa: F401,F403
|
6
|
+
|
7
|
+
from . import core as _core
|
8
|
+
from . import dataset_utils as _dataset_utils
|
9
|
+
from . import prompt_segments as _prompt_segments
|
10
|
+
|
11
|
+
# Public API of the package: the concatenation of each submodule's own
# ``__all__`` (a submodule that does not define ``__all__`` contributes nothing).
__all__: list[str] = [
    exported_name
    for submodule in (_core, _dataset_utils, _prompt_segments)
    for exported_name in getattr(submodule, "__all__", [])
]
|
@@ -0,0 +1,236 @@
|
|
1
|
+
"""
|
2
|
+
Minimal ColBERTv2 implementation extracted from dspy (MIT license).
|
3
|
+
|
4
|
+
This module provides a lightweight implementation of ColBERTv2 search functionality
|
5
|
+
without requiring the full dspy dependency.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import copy
|
9
|
+
import time
|
10
|
+
from typing import Any
|
11
|
+
import requests # type: ignore[import-untyped]
|
12
|
+
from requests.adapters import HTTPAdapter # type: ignore[import-untyped]
|
13
|
+
from urllib3.util.retry import Retry
|
14
|
+
|
15
|
+
|
16
|
+
def _create_session_with_retries(max_retries: int = 4) -> requests.Session:
    """
    Build a requests session whose HTTP and HTTPS adapters retry failed calls.

    Args:
        max_retries: Maximum number of retry attempts

    Returns:
        Configured requests session
    """
    retrying_adapter = HTTPAdapter(
        max_retries=Retry(
            total=max_retries,
            # Exponential backoff scaled by 1 second; the exact delay sequence
            # depends on the installed urllib3 version.
            backoff_factor=1,
            # HTTP status codes that trigger a retry.
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=[
                "HEAD",
                "GET",
                "POST",
                "PUT",
                "DELETE",
                "OPTIONS",
                "TRACE",
            ],
        )
    )

    http = requests.Session()
    # The same adapter serves both URL schemes.
    for scheme in ("http://", "https://"):
        http.mount(scheme, retrying_adapter)
    return http
|
40
|
+
|
41
|
+
|
42
|
+
class dotdict(dict):
    """Dictionary with attribute access (extracted from dspy).

    Non-dunder attribute reads, writes and deletes are redirected to the
    dictionary's items, so ``d.key`` is equivalent to ``d["key"]``. Names of
    the form ``__x__`` keep the normal object attribute behaviour.
    """

    def __getattr__(self, key: str) -> Any:
        """Return ``self[key]``; raise ``AttributeError`` if the key is missing."""
        if key.startswith("__") and key.endswith("__"):
            return super().__getattribute__(key)
        try:
            return self[key]
        except KeyError:
            # ``from None`` hides the internal KeyError so callers only see
            # the AttributeError they expect from attribute access.
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{key}'"
            ) from None

    def __setattr__(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` as a dictionary item."""
        if key.startswith("__") and key.endswith("__"):
            super().__setattr__(key, value)
        else:
            self[key] = value

    def __delattr__(self, key: str) -> None:
        """Delete the item stored under ``key`` (``KeyError`` if it is absent)."""
        if key.startswith("__") and key.endswith("__"):
            super().__delattr__(key)
        else:
            del self[key]

    def __deepcopy__(self, memo: dict[Any, Any]) -> "dotdict":
        """Return a deep copy, preserving shared and cyclic references."""
        duplicate = dotdict()
        # Register the copy before descending so that a dotdict which contains
        # itself (directly or indirectly) resolves to the copy instead of
        # recursing without end.
        memo[id(self)] = duplicate
        for item_key, item_value in self.items():
            duplicate[copy.deepcopy(item_key, memo)] = copy.deepcopy(
                item_value, memo
            )
        return duplicate
|
70
|
+
|
71
|
+
|
72
|
+
def colbertv2_get_request(
    url: str, query: str, k: int, max_retries: int = 4
) -> list[dict[str, Any]]:
    """
    Make a GET request to ColBERTv2 server with retry logic.

    The request is attempted up to ``max_retries`` times at application level.
    This is in addition to whatever transport-level retries the session from
    ``_create_session_with_retries`` performs.

    Args:
        url: The ColBERTv2 server URL
        query: The search query
        k: Number of results to return (at most 100)
        max_retries: Maximum number of retry attempts

    Returns:
        List of at most ``k`` search results; each result gains a
        ``"long_text"`` key that mirrors its ``"text"`` value.

    Raises:
        AssertionError: If ``k`` is greater than 100.
        Exception: If the request keeps failing, the server reports an error,
            or the response has no ``"topk"`` field.
    """
    # Raised explicitly instead of via ``assert`` so the check is not stripped
    # when Python runs with -O; the exception type is unchanged for callers.
    if k > 100:
        raise AssertionError(
            "Only k <= 100 is supported for the hosted ColBERTv2 server at the moment."
        )

    payload: dict[str, str | int] = {"query": query, "k": k}

    # The context manager closes the session (and its pooled connections) on
    # every exit path, including the error paths below.
    with _create_session_with_retries(max_retries) as session:
        # Application-level retry for server connection errors
        for attempt in range(max_retries):
            try:
                res = session.get(url, params=payload, timeout=5)
                response_data = res.json()
            except requests.RequestException as e:
                if attempt == max_retries - 1:
                    raise Exception(f"ColBERTv2 request failed: {str(e)}") from e
                time.sleep(1)  # Wait 1 second before retrying
                continue

            # Check for application-level errors (server connection issues, etc.)
            if "error" in response_data and response_data["error"]:
                error_msg = response_data.get("message", "Unknown error")
                is_connection_error = (
                    "Cannot connect to host" in error_msg
                    or "Connection refused" in error_msg
                )
                # Only connection errors are retried, and only while attempts
                # remain; every other error fails immediately.
                if not is_connection_error or attempt == max_retries - 1:
                    raise Exception(f"ColBERTv2 server error: {error_msg}")
                time.sleep(1)  # Wait 1 second before retrying
                continue

            if "topk" not in response_data:
                raise Exception(
                    f"Unexpected response format from ColBERTv2 server: {list(response_data.keys())}"
                )

            return [
                {**d, "long_text": d["text"]} for d in response_data["topk"][:k]
            ]

    # Reached only when the loop body never runs (max_retries <= 0).
    raise Exception("Unexpected end of retry loop")
|
131
|
+
|
132
|
+
|
133
|
+
def colbertv2_post_request(
    url: str, query: str, k: int, max_retries: int = 4
) -> list[dict[str, Any]]:
    """
    Make a POST request to ColBERTv2 server with retry logic.

    The query and ``k`` are sent as a JSON body. The request is attempted up
    to ``max_retries`` times at application level, in addition to whatever
    transport-level retries the session from ``_create_session_with_retries``
    performs.

    Args:
        url: The ColBERTv2 server URL
        query: The search query
        k: Number of results to return
        max_retries: Maximum number of retry attempts

    Returns:
        List of at most ``k`` search results, exactly as returned by the server

    Raises:
        Exception: If the request keeps failing, the server reports an error,
            or the response has no ``"topk"`` field.
    """
    headers = {"Content-Type": "application/json; charset=utf-8"}
    payload = {"query": query, "k": k}

    # The context manager closes the session (and its pooled connections) on
    # every exit path, including the error paths below.
    with _create_session_with_retries(max_retries) as session:
        # Application-level retry for server connection errors
        for attempt in range(max_retries):
            try:
                res = session.post(url, json=payload, headers=headers, timeout=5)
                response_data = res.json()
            except requests.RequestException as e:
                if attempt == max_retries - 1:
                    raise Exception(f"ColBERTv2 request failed: {str(e)}") from e
                time.sleep(1)  # Wait 1 second before retrying
                continue

            # Check for application-level errors (server connection issues, etc.)
            if "error" in response_data and response_data["error"]:
                error_msg = response_data.get("message", "Unknown error")
                is_connection_error = (
                    "Cannot connect to host" in error_msg
                    or "Connection refused" in error_msg
                )
                # Only connection errors are retried, and only while attempts
                # remain; every other error fails immediately.
                if not is_connection_error or attempt == max_retries - 1:
                    raise Exception(f"ColBERTv2 server error: {error_msg}")
                time.sleep(1)  # Wait 1 second before retrying
                continue

            if "topk" not in response_data:
                raise Exception(
                    f"Unexpected response format from ColBERTv2 server: {list(response_data.keys())}"
                )

            return response_data["topk"][:k]

    # Reached only when the loop body never runs (max_retries <= 0).
    raise Exception("Unexpected end of retry loop")
|
187
|
+
|
188
|
+
|
189
|
+
class ColBERTv2:
    """Wrapper for the ColBERTv2 Retrieval (extracted from dspy)."""

    def __init__(
        self,
        url: str = "http://0.0.0.0",
        port: str | int | None = None,
        post_requests: bool = False,
    ):
        """
        Initialize ColBERTv2 client.

        Args:
            url: Base URL for the ColBERTv2 server
            port: Optional port number; appended to ``url`` as ``:port`` when truthy
            post_requests: Whether to use POST requests instead of GET
        """
        self.post_requests = post_requests
        self.url = f"{url}:{port}" if port else url

    def __call__(
        self,
        query: str,
        k: int = 10,
        simplify: bool = False,
        max_retries: int = 4,
    ) -> list[str] | list[dotdict]:
        """
        Search using ColBERTv2.

        Args:
            query: The search query
            k: Number of results to return
            simplify: If True, return only text strings; if False, return dotdict objects
            max_retries: Maximum number of retry attempts

        Returns:
            List of search results (either strings or dotdict objects)

        Raises:
            KeyError: If ``simplify`` is True and a result has neither a
                ``"long_text"`` nor a ``"text"`` field.
        """
        send_request = (
            colbertv2_post_request if self.post_requests else colbertv2_get_request
        )
        topk_results = send_request(self.url, query, k, max_retries)

        if simplify:
            # Only the GET helper adds the "long_text" alias; POST results are
            # passed through untouched, so fall back to "text" when it is absent.
            return [
                psg["long_text"] if "long_text" in psg else psg["text"]
                for psg in topk_results
            ]

        return [dotdict(psg) for psg in topk_results]
|