opik-optimizer 1.0.6__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opik_optimizer/__init__.py +4 -0
- opik_optimizer/_throttle.py +2 -1
- opik_optimizer/base_optimizer.py +402 -28
- opik_optimizer/data/context7_eval.jsonl +3 -0
- opik_optimizer/datasets/context7_eval.py +90 -0
- opik_optimizer/datasets/tiny_test.py +33 -34
- opik_optimizer/datasets/truthful_qa.py +2 -2
- opik_optimizer/evolutionary_optimizer/crossover_ops.py +194 -0
- opik_optimizer/evolutionary_optimizer/evaluation_ops.py +136 -0
- opik_optimizer/evolutionary_optimizer/evolutionary_optimizer.py +289 -966
- opik_optimizer/evolutionary_optimizer/helpers.py +10 -0
- opik_optimizer/evolutionary_optimizer/llm_support.py +136 -0
- opik_optimizer/evolutionary_optimizer/mcp.py +249 -0
- opik_optimizer/evolutionary_optimizer/mutation_ops.py +306 -0
- opik_optimizer/evolutionary_optimizer/population_ops.py +228 -0
- opik_optimizer/evolutionary_optimizer/prompts.py +352 -0
- opik_optimizer/evolutionary_optimizer/reporting.py +28 -4
- opik_optimizer/evolutionary_optimizer/style_ops.py +86 -0
- opik_optimizer/few_shot_bayesian_optimizer/few_shot_bayesian_optimizer.py +90 -81
- opik_optimizer/few_shot_bayesian_optimizer/reporting.py +12 -5
- opik_optimizer/gepa_optimizer/__init__.py +3 -0
- opik_optimizer/gepa_optimizer/adapter.py +154 -0
- opik_optimizer/gepa_optimizer/gepa_optimizer.py +653 -0
- opik_optimizer/gepa_optimizer/reporting.py +181 -0
- opik_optimizer/logging_config.py +42 -7
- opik_optimizer/mcp_utils/__init__.py +22 -0
- opik_optimizer/mcp_utils/mcp.py +541 -0
- opik_optimizer/mcp_utils/mcp_second_pass.py +152 -0
- opik_optimizer/mcp_utils/mcp_simulator.py +116 -0
- opik_optimizer/mcp_utils/mcp_workflow.py +547 -0
- opik_optimizer/meta_prompt_optimizer/meta_prompt_optimizer.py +470 -134
- opik_optimizer/meta_prompt_optimizer/reporting.py +16 -2
- opik_optimizer/mipro_optimizer/_lm.py +30 -23
- opik_optimizer/mipro_optimizer/_mipro_optimizer_v2.py +52 -51
- opik_optimizer/mipro_optimizer/mipro_optimizer.py +126 -46
- opik_optimizer/mipro_optimizer/utils.py +2 -4
- opik_optimizer/optimizable_agent.py +21 -16
- opik_optimizer/optimization_config/chat_prompt.py +44 -23
- opik_optimizer/optimization_config/configs.py +3 -3
- opik_optimizer/optimization_config/mappers.py +9 -8
- opik_optimizer/optimization_result.py +22 -14
- opik_optimizer/reporting_utils.py +61 -10
- opik_optimizer/task_evaluator.py +9 -8
- opik_optimizer/utils/__init__.py +15 -0
- opik_optimizer/utils/colbert.py +236 -0
- opik_optimizer/{utils.py → utils/core.py} +160 -33
- opik_optimizer/utils/dataset_utils.py +49 -0
- opik_optimizer/utils/prompt_segments.py +186 -0
- opik_optimizer-2.0.0.dist-info/METADATA +345 -0
- opik_optimizer-2.0.0.dist-info/RECORD +74 -0
- opik_optimizer-2.0.0.dist-info/licenses/LICENSE +203 -0
- opik_optimizer-1.0.6.dist-info/METADATA +0 -181
- opik_optimizer-1.0.6.dist-info/RECORD +0 -50
- opik_optimizer-1.0.6.dist-info/licenses/LICENSE +0 -21
- {opik_optimizer-1.0.6.dist-info → opik_optimizer-2.0.0.dist-info}/WHEEL +0 -0
- {opik_optimizer-1.0.6.dist-info → opik_optimizer-2.0.0.dist-info}/top_level.txt +0 -0
opik_optimizer/{utils.py → utils/core.py}
@@ -2,18 +2,14 @@

 from typing import (
     Any,
-    Dict,
     Final,
     Literal,
-    Optional,
-    Type,
     TYPE_CHECKING,
-    List,
-    Callable,
 )
+from collections.abc import Callable

+import ast
 import inspect
-import typing
 import base64
 import json
 import logging
@@ -22,16 +18,20 @@ import string
 import urllib.parse
 from types import TracebackType

+import requests
+
 import opik
 from opik.api_objects.opik_client import Opik
 from opik.api_objects.optimization import Optimization

-
-logger = logging.getLogger(__name__)
+from .colbert import ColBERTv2

 if TYPE_CHECKING:
-    from .optimizable_agent import OptimizableAgent
-    from .optimization_config.chat_prompt import ChatPrompt
+    from opik_optimizer.optimizable_agent import OptimizableAgent
+    from opik_optimizer.optimization_config.chat_prompt import ChatPrompt
+
+ALLOWED_URL_CHARACTERS: Final[str] = ":/&?="
+logger = logging.getLogger(__name__)


 class OptimizationContextManager:
@@ -45,8 +45,8 @@ class OptimizationContextManager:
         client: Opik,
         dataset_name: str,
         objective_name: str,
-        name: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        name: str | None = None,
+        metadata: dict[str, Any] | None = None,
     ):
         """
         Initialize the optimization context.
@@ -63,9 +63,9 @@ class OptimizationContextManager:
         self.objective_name = objective_name
         self.name = name
         self.metadata = metadata
-        self.optimization: Optional[Optimization] = None
+        self.optimization: Optimization | None = None

-    def __enter__(self) -> Optional[Optimization]:
+    def __enter__(self) -> Optimization | None:
         """Create and return the optimization."""
         try:
             self.optimization = self.client.create_optimization(
@@ -88,9 +88,9 @@ class OptimizationContextManager:

     def __exit__(
         self,
-        exc_type: Optional[Type[BaseException]],
-        exc_val: Optional[BaseException],
-        exc_tb: Optional[TracebackType],
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
     ) -> Literal[False]:
         """Update optimization status based on context exit."""
         if self.optimization is None:
@@ -205,7 +205,7 @@ def json_to_dict(json_str: str) -> Any:

     try:
         return json.loads(cleaned_json_string)
-    except json.JSONDecodeError:
+    except json.JSONDecodeError as json_error:
         if cleaned_json_string.startswith("```json"):
             cleaned_json_string = cleaned_json_string[7:]
         if cleaned_json_string.endswith("```"):
@@ -217,18 +217,52 @@ def json_to_dict(json_str: str) -> Any:

         try:
             return json.loads(cleaned_json_string)
-        except json.JSONDecodeError
-
-
-
+        except json.JSONDecodeError:
+            try:
+                literal_result = ast.literal_eval(cleaned_json_string)
+            except (ValueError, SyntaxError):
+                logger.debug("Failed to parse JSON string: %s", json_str)
+                raise json_error
+
+            normalized = _convert_literals_to_json_compatible(literal_result)
+
+            try:
+                return json.loads(json.dumps(normalized))
+            except (TypeError, ValueError) as serialization_error:
+                logger.debug(
+                    "Failed to serialise literal-evaluated payload %r: %s",
+                    literal_result,
+                    serialization_error,
+                )
+                raise json_error
+
+
+def _convert_literals_to_json_compatible(value: Any) -> Any:
+    """Convert Python literals to JSON-compatible structures."""
+    if isinstance(value, dict):
+        return {
+            key: _convert_literals_to_json_compatible(val) for key, val in value.items()
+        }
+    if isinstance(value, list):
+        return [_convert_literals_to_json_compatible(item) for item in value]
+    if isinstance(value, tuple):
+        return [_convert_literals_to_json_compatible(item) for item in value]
+    if isinstance(value, set):
+        return [
+            _convert_literals_to_json_compatible(item)
+            for item in sorted(value, key=repr)
+        ]
+    if isinstance(value, (str, int, float, bool)) or value is None:
+        return value
+    return str(value)


 def optimization_context(
     client: Opik,
     dataset_name: str,
     objective_name: str,
-    name: Optional[str] = None,
-    metadata: Optional[Dict[str, Any]] = None,
+    name: str | None = None,
+    metadata: dict[str, Any] | None = None,
 ) -> OptimizationContextManager:
     """
     Create a context manager for handling optimization lifecycle.
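In effect, `json_to_dict` now falls back to `ast.literal_eval` when strict JSON parsing fails, converting Python literals (tuples, sets) into JSON-compatible lists before re-serialising. A minimal sketch of the observable behaviour, assuming the function is importable from the renamed `opik_optimizer.utils.core` module:

# Sketch only; the import path is assumed from the renamed module.
from opik_optimizer.utils.core import json_to_dict

json_to_dict('{"scores": [1, 2]}')   # valid JSON parses as before -> {"scores": [1, 2]}
json_to_dict("{'scores': (1, 2)}")   # Python-literal syntax: literal_eval path -> {"scores": [1, 2]}
json_to_dict("not json at all")      # literal_eval also fails -> original JSONDecodeError is re-raised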
@@ -258,7 +292,7 @@ def ensure_ending_slash(url: str) -> str:


 def get_optimization_run_url_by_id(
-    dataset_id: Optional[str], optimization_id: Optional[str]
+    dataset_id: str | None, optimization_id: str | None
 ) -> str:
     if dataset_id is None or optimization_id is None:
         raise ValueError(
@@ -276,11 +310,17 @@ def get_optimization_run_url_by_id(
     return urllib.parse.urljoin(ensure_ending_slash(url_override), run_path)


-def create_litellm_agent_class(prompt: "ChatPrompt") -> Type["OptimizableAgent"]:
+def create_litellm_agent_class(
+    prompt: "ChatPrompt", optimizer_ref: Any = None
+) -> type["OptimizableAgent"]:
     """
     Create a LiteLLMAgent from a chat prompt.
+
+    Args:
+        prompt: The chat prompt to use
+        optimizer_ref: Optional optimizer instance to attach to the agent
     """
-    from .optimizable_agent import OptimizableAgent
+    from opik_optimizer.optimizable_agent import OptimizableAgent

     if prompt.invoke is not None:

@@ -288,9 +328,10 @@ def create_litellm_agent_class(prompt: "ChatPrompt") -> Type["OptimizableAgent"]
         model = prompt.model
         model_kwargs = prompt.model_kwargs
         project_name = prompt.project_name
+        optimizer = optimizer_ref

         def invoke(
-            self, messages: List[Dict[str, str]], seed: Optional[int] = None
+            self, messages: list[dict[str, str]], seed: int | None = None
         ) -> str:
             return prompt.invoke(
                 self.model, messages, prompt.tools, **self.model_kwargs
@@ -302,18 +343,19 @@ def create_litellm_agent_class(prompt: "ChatPrompt") -> Type["OptimizableAgent"]
         model = prompt.model
         model_kwargs = prompt.model_kwargs
         project_name = prompt.project_name
+        optimizer = optimizer_ref

     return LiteLLMAgent


 def function_to_tool_definition(
-    func: Callable, description: Optional[str] = None
-) -> Dict[str, Any]:
+    func: Callable, description: str | None = None
+) -> dict[str, Any]:
     sig = inspect.signature(func)
     doc = description or func.__doc__ or ""

-    properties: Dict[str, Dict[str, str]] = {}
-    required: List[str] = []
+    properties: dict[str, dict[str, str]] = {}
+    required: list[str] = []

     for name, param in sig.parameters.items():
         param_type = (
@@ -350,7 +392,92 @@ def python_type_to_json_type(python_type: type) -> str:
         return "boolean"
     elif python_type in [dict]:
         return "object"
-    elif python_type in [list, List]:
+    elif python_type in [list, list]:
         return "array"
     else:
         return "string"  # default fallback
+
+
+def search_wikipedia(query: str, use_api: bool | None = False) -> list[str]:
+    """
+    This agent is used to search wikipedia. It can retrieve additional details
+    about a topic.
+
+    Args:
+        query: The search query string
+        use_api: (Optional) If True, directly use Wikipedia API instead of ColBERTv2.
+            If False (default), try ColBERTv2 first with API fallback.
+    """
+    if use_api:
+        # Directly use Wikipedia API when requested
+        try:
+            return _search_wikipedia_api(query)
+        except Exception as api_error:
+            print(f"Wikipedia API failed: {api_error}")
+            return [f"Wikipedia search unavailable. Query was: {query}"]
+
+    # Default behavior: Try ColBERTv2 first with API fallback
+    # Try ColBERTv2 first with a short timeout
+    try:
+        colbert = ColBERTv2(url="http://20.102.90.50:2017/wiki17_abstracts")
+        # Use a shorter timeout by modifying the max_retries parameter
+        results = colbert(query, k=3, max_retries=1)
+        return [str(item.text) for item in results if hasattr(item, "text")]
+    except Exception:
+        # Fallback to Wikipedia API
+        try:
+            return _search_wikipedia_api(query)
+        except Exception as api_error:
+            print(f"Wikipedia API fallback also failed: {api_error}")
+            return [f"Wikipedia search unavailable. Query was: {query}"]
+
+
+def _search_wikipedia_api(query: str, max_results: int = 3) -> list[str]:
+    """
+    Fallback Wikipedia search using the Wikipedia API.
+    """
+    try:
+        # First, search for pages using the search API
+        search_params: dict[str, str | int] = {
+            "action": "query",
+            "format": "json",
+            "list": "search",
+            "srsearch": query,
+            "srlimit": max_results,
+            "srprop": "snippet",
+        }
+
+        headers = {
+            "User-Agent": "OpikOptimizer/1.0 (https://github.com/opik-ai/opik-optimizer)"
+        }
+        search_response = requests.get(
+            "https://en.wikipedia.org/w/api.php",
+            params=search_params,
+            headers=headers,
+            timeout=5,
+        )
+
+        if search_response.status_code != 200:
+            raise Exception(f"Search API returned status {search_response.status_code}")
+
+        search_data = search_response.json()
+
+        results = []
+        if "query" in search_data and "search" in search_data["query"]:
+            for item in search_data["query"]["search"][:max_results]:
+                page_title = item["title"]
+                snippet = item.get("snippet", "")
+
+                # Clean up the snippet (remove HTML tags)
+                import re
+
+                clean_snippet = re.sub(r"<[^>]+>", "", snippet)
+                clean_snippet = re.sub(r"&[^;]+;", " ", clean_snippet)
+
+                if clean_snippet.strip():
+                    results.append(f"{page_title}: {clean_snippet.strip()}")
+
+        return results if results else [f"No Wikipedia results found for: {query}"]
+
+    except Exception as e:
+        raise Exception(f"Wikipedia API request failed: {e}") from e
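The new `search_wikipedia` helper queries the public ColBERTv2 wiki17_abstracts endpoint first and falls back to the Wikipedia search API, or goes straight to the API with `use_api=True`. A rough usage sketch, assuming both helpers are exposed from `opik_optimizer.utils.core` and live network access:

# Sketch only; the import path is assumed and results depend on live endpoints.
from opik_optimizer.utils.core import search_wikipedia, function_to_tool_definition

passages = search_wikipedia("Marie Curie")                 # ColBERTv2 first, API fallback
passages = search_wikipedia("Marie Curie", use_api=True)   # skip ColBERTv2 entirely

# The existing helper can derive an LLM tool definition from the signature and docstring.
tool = function_to_tool_definition(search_wikipedia)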
opik_optimizer/utils/dataset_utils.py
@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+import hashlib
+import secrets
+import time
+from functools import lru_cache
+from importlib import resources
+from typing import Any
+from collections.abc import Iterable
+
+
+@lru_cache(maxsize=None)
+def dataset_suffix(package: str, filename: str) -> str:
+    """Return a stable checksum-based suffix for a JSONL dataset file."""
+    text = resources.files(package).joinpath(filename).read_text(encoding="utf-8")
+    return hashlib.md5(text.encode("utf-8")).hexdigest()[:8]
+
+
+def generate_uuid7_str() -> str:
+    """Generate a UUIDv7-compatible string, emulating the layout if unavailable."""
+    import uuid
+
+    if hasattr(uuid, "uuid7"):
+        return str(uuid.uuid7())  # type: ignore[attr-defined]
+
+    unix_ts_ms = int(time.time() * 1000) & ((1 << 48) - 1)
+    rand_a = secrets.randbits(12)
+    rand_b = secrets.randbits(62)
+
+    uuid_int = unix_ts_ms << 80
+    uuid_int |= 0x7 << 76  # version 7
+    uuid_int |= rand_a << 64
+    uuid_int |= 0b10 << 62  # RFC4122 variant
+    uuid_int |= rand_b
+
+    return str(uuid.UUID(int=uuid_int))
+
+
+def attach_uuids(records: Iterable[dict[str, Any]]) -> list[dict[str, Any]]:
+    """Copy records and assign a fresh UUIDv7 `id` to each."""
+    payload: list[dict[str, Any]] = []
+    for record in records:
+        rec = dict(record)
+        rec["id"] = generate_uuid7_str()
+        payload.append(rec)
+    return payload
+
+
+__all__ = ["dataset_suffix", "generate_uuid7_str", "attach_uuids"]
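These helpers give the bundled datasets deterministic, content-addressed names (`dataset_suffix`) and per-record UUIDv7-style ids (`attach_uuids`). A small sketch of how they might be combined; the package argument and the naming scheme are illustrative, not taken from this diff:

# Illustrative only; the package path and dataset naming convention are assumptions.
from opik_optimizer.utils.dataset_utils import dataset_suffix, attach_uuids

records = [{"question": "What is Opik?", "answer": "An LLM evaluation platform."}]

suffix = dataset_suffix("opik_optimizer.data", "context7_eval.jsonl")  # 8-char md5 prefix
dataset_name = f"context7_eval_{suffix}"   # changes only when the packaged JSONL changes

items = attach_uuids(records)              # shallow-copies each record and adds a fresh "id"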
opik_optimizer/utils/prompt_segments.py
@@ -0,0 +1,186 @@
+"""Prompt segmentation helpers for targeted prompt updates.
+
+These utilities operate on existing ``ChatPrompt`` instances without
+changing their constructor, allowing callers to identify and update
+specific sections (system message, individual chat messages, or tool
+descriptions) while preserving backwards compatibility for the rest of
+the optimizer stack.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+from collections.abc import Iterable
+
+import copy
+
+from ..optimization_config.chat_prompt import ChatPrompt
+
+
+PROMPT_SEGMENT_PREFIX_TOOL = "tool:"
+PROMPT_SEGMENT_PREFIX_MESSAGE = "message:"
+
+
+@dataclass
+class PromptSegment:
+    """Lightweight view over a prompt component that may be edited."""
+
+    segment_id: str
+    kind: str
+    role: str | None
+    content: str
+    metadata: dict[str, Any]
+
+    def is_tool(self) -> bool:
+        return self.segment_id.startswith(PROMPT_SEGMENT_PREFIX_TOOL)
+
+
+def _normalise_tool(tool: dict[str, Any]) -> dict[str, Any]:
+    """Return tools in the ``{"function": {...}}`` structure for consistency."""
+
+    if "function" in tool:
+        return copy.deepcopy(tool)
+
+    normalised = copy.deepcopy(tool)
+    function_block = {
+        "name": normalised.pop("name", None),
+        "description": normalised.pop("description", ""),
+        "parameters": normalised.pop("parameters", None),
+    }
+    normalised = {"function": function_block, **normalised}
+    return normalised
+
+
+def extract_prompt_segments(prompt: ChatPrompt) -> list[PromptSegment]:
+    """Extract individual editable segments from ``prompt``.
+
+    The extraction preserves order for chat messages while assigning
+    stable segment identifiers:
+
+    * ``system`` for the system field (if present)
+    * ``user`` for the top-level user field (if present)
+    * ``message:<index>`` for entries in ``messages``
+    * ``tool:<name>`` for tool descriptions
+    """
+
+    segments: list[PromptSegment] = []
+
+    if prompt.system is not None:
+        segments.append(
+            PromptSegment(
+                segment_id="system",
+                kind="system",
+                role="system",
+                content=prompt.system,
+                metadata={},
+            )
+        )
+
+    if prompt.messages is not None:
+        for idx, message in enumerate(prompt.messages):
+            segments.append(
+                PromptSegment(
+                    segment_id=f"{PROMPT_SEGMENT_PREFIX_MESSAGE}{idx}",
+                    kind="message",
+                    role=message.get("role"),
+                    content=message.get("content", ""),
+                    metadata={
+                        key: value for key, value in message.items() if key != "content"
+                    },
+                )
+            )
+
+    if prompt.user is not None:
+        segments.append(
+            PromptSegment(
+                segment_id="user",
+                kind="user",
+                role="user",
+                content=prompt.user,
+                metadata={},
+            )
+        )
+
+    if prompt.tools:
+        for tool in prompt.tools:
+            normalised = _normalise_tool(tool)
+            function_block = normalised.get("function", {})
+            tool_name = function_block.get("name")
+            if not tool_name:
+                continue
+            segments.append(
+                PromptSegment(
+                    segment_id=f"{PROMPT_SEGMENT_PREFIX_TOOL}{tool_name}",
+                    kind="tool",
+                    role="tool",
+                    content=function_block.get("description", ""),
+                    metadata={
+                        "parameters": function_block.get("parameters"),
+                        "raw_tool": normalised,
+                    },
+                )
+            )
+
+    return segments
+
+
+def apply_segment_updates(
+    prompt: ChatPrompt,
+    updates: dict[str, str],
+) -> ChatPrompt:
+    """Return a new ``ChatPrompt`` with selected segments replaced.
+
+    ``updates`` maps segment identifiers (as produced by
+    ``extract_prompt_segments``) to replacement strings.
+    """
+
+    system = updates.get("system", prompt.system)
+    user = updates.get("user", prompt.user)
+
+    messages: list[dict[str, Any]] | None = None
+    if prompt.messages is not None:
+        new_messages: list[dict[str, Any]] = []
+        for idx, message in enumerate(prompt.messages):
+            segment_id = f"{PROMPT_SEGMENT_PREFIX_MESSAGE}{idx}"
+            replacement = updates.get(segment_id)
+            if replacement is not None:
+                updated_message = copy.deepcopy(message)
+                updated_message["content"] = replacement
+                new_messages.append(updated_message)
+            else:
+                new_messages.append(copy.deepcopy(message))
+        messages = new_messages
+
+    tools = copy.deepcopy(prompt.tools) if prompt.tools else None
+    if tools:
+        for tool in tools:
+            normalised = _normalise_tool(tool)
+            function_block = normalised.get("function", {})
+            tool_name = function_block.get("name")
+            if not tool_name:
+                continue
+            segment_id = f"{PROMPT_SEGMENT_PREFIX_TOOL}{tool_name}"
+            replacement = updates.get(segment_id)
+            if replacement is not None:
+                function_block["description"] = replacement
+                tool.update(normalised)
+
+    return ChatPrompt(
+        name=prompt.name,
+        system=system,
+        user=user,
+        messages=messages,
+        tools=tools,
+        function_map=prompt.function_map,
+        model=prompt.model,
+        invoke=prompt.invoke,
+        project_name=prompt.project_name,
+        **prompt.model_kwargs,
+    )
+
+
+def segment_ids_for_tools(segments: Iterable[PromptSegment]) -> list[str]:
+    """Convenience helper returning IDs of tool segments."""
+
+    return [segment.segment_id for segment in segments if segment.is_tool()]
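Together, `extract_prompt_segments` and `apply_segment_updates` let an optimizer target one part of a prompt at a time via stable segment ids. A hedged sketch; the `ChatPrompt` keyword arguments are inferred from the constructor call visible in this module:

# Sketch only; ChatPrompt usage inferred from the fields referenced above.
from opik_optimizer.optimization_config.chat_prompt import ChatPrompt
from opik_optimizer.utils.prompt_segments import (
    apply_segment_updates,
    extract_prompt_segments,
)

prompt = ChatPrompt(system="You are terse.", user="{question}")

for segment in extract_prompt_segments(prompt):
    print(segment.segment_id, repr(segment.content))  # "system" ..., "user" ...

updated = apply_segment_updates(prompt, {"system": "You are thorough and cite sources."})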