opik-optimizer 0.7.0 (opik_optimizer-0.7.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opik_optimizer/__init__.py +65 -0
- opik_optimizer/_throttle.py +43 -0
- opik_optimizer/base_optimizer.py +240 -0
- opik_optimizer/cache_config.py +24 -0
- opik_optimizer/demo/__init__.py +7 -0
- opik_optimizer/demo/cache.py +112 -0
- opik_optimizer/demo/datasets.py +656 -0
- opik_optimizer/few_shot_bayesian_optimizer/__init__.py +5 -0
- opik_optimizer/few_shot_bayesian_optimizer/few_shot_bayesian_optimizer.py +408 -0
- opik_optimizer/few_shot_bayesian_optimizer/prompt_parameter.py +91 -0
- opik_optimizer/few_shot_bayesian_optimizer/prompt_templates.py +80 -0
- opik_optimizer/integrations/__init__.py +0 -0
- opik_optimizer/logging_config.py +69 -0
- opik_optimizer/meta_prompt_optimizer.py +1100 -0
- opik_optimizer/mipro_optimizer/__init__.py +1 -0
- opik_optimizer/mipro_optimizer/_lm.py +394 -0
- opik_optimizer/mipro_optimizer/_mipro_optimizer_v2.py +1058 -0
- opik_optimizer/mipro_optimizer/mipro_optimizer.py +395 -0
- opik_optimizer/mipro_optimizer/utils.py +107 -0
- opik_optimizer/optimization_config/__init__.py +0 -0
- opik_optimizer/optimization_config/configs.py +35 -0
- opik_optimizer/optimization_config/mappers.py +49 -0
- opik_optimizer/optimization_result.py +211 -0
- opik_optimizer/task_evaluator.py +102 -0
- opik_optimizer/utils.py +132 -0
- opik_optimizer-0.7.0.dist-info/METADATA +35 -0
- opik_optimizer-0.7.0.dist-info/RECORD +30 -0
- opik_optimizer-0.7.0.dist-info/WHEEL +5 -0
- opik_optimizer-0.7.0.dist-info/licenses/LICENSE +21 -0
- opik_optimizer-0.7.0.dist-info/top_level.txt +1 -0
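For orientation, the sketch below imports the modules named in the listing. The module paths come straight from the listing; apart from MiproOptimizer (re-exported in the hunk further down), the public contents of these modules are not shown in this excerpt, so nothing beyond the import lines should be read into it.

# Orientation sketch only: module paths are taken from the file listing above.
# MiproOptimizer is confirmed by the mipro_optimizer/__init__.py hunk below;
# the contents of the other modules are not visible in this diff.
import opik_optimizer.base_optimizer
import opik_optimizer.task_evaluator
import opik_optimizer.meta_prompt_optimizer
from opik_optimizer.mipro_optimizer import MiproOptimizer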
opik_optimizer/mipro_optimizer/__init__.py
@@ -0,0 +1 @@
+from .mipro_optimizer import MiproOptimizer
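The subpackage's __init__.py simply re-exports the optimizer class. A minimal import sketch, assuming nothing beyond the re-export shown above (the constructor and optimization API of MiproOptimizer live in mipro_optimizer.py, which is not part of this excerpt):

# Confirmed by the hunk above:
from opik_optimizer.mipro_optimizer import MiproOptimizer

# The top-level opik_optimizer/__init__.py (+65 lines in the listing) may re-export
# the same class, but its contents are not shown here, so the shorter import below
# is an assumption:
# from opik_optimizer import MiproOptimizer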
opik_optimizer/mipro_optimizer/_lm.py
@@ -0,0 +1,394 @@
+import functools
+import logging
+import os
+import re
+import threading
+from hashlib import sha256
+from typing import Any, Dict, List, Literal, Optional, cast
+
+import litellm
+import pydantic
+import ujson
+from anyio.streams.memory import MemoryObjectSendStream
+from asyncer import syncify
+from cachetools import LRUCache, cached
+from litellm import RetryPolicy
+
+import dspy
+from dspy.clients.openai import OpenAIProvider
+from dspy.clients.provider import Provider, TrainingJob
+from dspy.clients.utils_finetune import TrainDataFormat
+from dspy.dsp.utils.settings import settings
+from dspy.utils.callback import BaseCallback, with_callbacks
+from dspy.clients.base_lm import BaseLM
+
+from .._throttle import RateLimiter, rate_limited
+
+logger = logging.getLogger(__name__)
+# Limit how fast an LLM can be called:
+limiter = RateLimiter(max_calls_per_second=15)
+
+class LM(BaseLM):
+    """
+    A language model supporting chat or text completion requests for use with DSPy modules.
+    """
+
+    def __init__(
+        self,
+        model: str,
+        model_type: Literal["chat", "text"] = "chat",
+        temperature: float = 0.0,
+        max_tokens: int = 1000,
+        cache: bool = True,
+        cache_in_memory: bool = True,
+        callbacks: Optional[List[BaseCallback]] = None,
+        num_retries: int = 8,
+        provider=None,
+        finetuning_model: Optional[str] = None,
+        launch_kwargs: Optional[dict[str, Any]] = None,
+        train_kwargs: Optional[dict[str, Any]] = None,
+        **kwargs,
+    ):
+        """
+        Create a new language model instance for use with DSPy modules and programs.
+
+        Args:
+            model: The model to use. This should be a string of the form ``"llm_provider/llm_name"``
+                   supported by LiteLLM. For example, ``"openai/gpt-4o"``.
+            model_type: The type of the model, either ``"chat"`` or ``"text"``.
+            temperature: The sampling temperature to use when generating responses.
+            max_tokens: The maximum number of tokens to generate per response.
+            cache: Whether to cache the model responses for reuse to improve performance
+                   and reduce costs.
+            cache_in_memory: To enable additional caching with LRU in memory.
+            callbacks: A list of callback functions to run before and after each request.
+            num_retries: The number of times to retry a request if it fails transiently due to
+                         network error, rate limiting, etc. Requests are retried with exponential
+                         backoff.
+            provider: The provider to use. If not specified, the provider will be inferred from the model.
+            finetuning_model: The model to finetune. In some providers, the models available for finetuning is different
+                from the models available for inference.
+        """
+        # Remember to update LM.copy() if you modify the constructor!
+        self.model = model
+        self.model_type = model_type
+        self.cache = cache
+        self.cache_in_memory = cache_in_memory
+        self.provider = provider or self.infer_provider()
+        self.callbacks = callbacks or []
+        self.history = []
+        self.callbacks = callbacks or []
+        self.num_retries = num_retries
+        self.finetuning_model = finetuning_model
+        self.launch_kwargs = launch_kwargs or {}
+        self.train_kwargs = train_kwargs or {}
+
+        # Handle model-specific configuration for different model families
+        model_family = model.split("/")[-1].lower() if "/" in model else model.lower()
+
+        # Match pattern: o[1,3] at the start, optionally followed by -mini and anything else
+        model_pattern = re.match(r"^o([13])(?:-mini)?", model_family)
+
+        if model_pattern:
+            # Handle OpenAI reasoning models (o1, o3)
+            assert (
+                max_tokens >= 20_000 and temperature == 1.0
+            ), "OpenAI's reasoning models require passing temperature=1.0 and max_tokens >= 20_000 to `dspy.LM(...)`"
+            self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs)
+        else:
+            self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
+
+    @rate_limited(limiter)
+    @with_callbacks
+    def forward(self, prompt=None, messages=None, **kwargs):
+        # Build the request.
+        cache = kwargs.pop("cache", self.cache)
+        # disable cache will also disable in memory cache
+        cache_in_memory = cache and kwargs.pop("cache_in_memory", self.cache_in_memory)
+        messages = messages or [{"role": "user", "content": prompt}]
+        kwargs = {**self.kwargs, **kwargs}
+
+        # Make the request and handle LRU & disk caching.
+        if cache_in_memory:
+            completion = cached_litellm_completion if self.model_type == "chat" else cached_litellm_text_completion
+
+            results = completion(
+                request=dict(model=self.model, messages=messages, **kwargs),
+                num_retries=self.num_retries,
+            )
+        else:
+            completion = litellm_completion if self.model_type == "chat" else litellm_text_completion
+
+            results = completion(
+                request=dict(model=self.model, messages=messages, **kwargs),
+                num_retries=self.num_retries,
+                # only leverage LiteLLM cache in this case
+                cache={"no-cache": not cache, "no-store": not cache},
+            )
+
+        if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
+            settings.usage_tracker.add_usage(self.model, dict(results.usage))
+
+        return results
+
+    def launch(self, launch_kwargs: Optional[Dict[str, Any]] = None):
+        self.provider.launch(self, launch_kwargs)
+
+    def kill(self, launch_kwargs: Optional[Dict[str, Any]] = None):
+        self.provider.kill(self, launch_kwargs)
+
+    def finetune(
+        self,
+        train_data: List[Dict[str, Any]],
+        train_data_format: Optional[TrainDataFormat],
+        train_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> TrainingJob:
+        from dspy import settings as settings
+
+        err = "Fine-tuning is an experimental feature."
+        err += " Set `dspy.settings.experimental` to `True` to use it."
+        assert settings.experimental, err
+
+        err = f"Provider {self.provider} does not support fine-tuning."
+        assert self.provider.finetunable, err
+
+        def thread_function_wrapper():
+            return self._run_finetune_job(job)
+
+        thread = threading.Thread(target=thread_function_wrapper)
+        train_kwargs = train_kwargs or self.train_kwargs
+        model_to_finetune = self.finetuning_model or self.model
+        job = self.provider.TrainingJob(
+            thread=thread,
+            model=model_to_finetune,
+            train_data=train_data,
+            train_data_format=train_data_format,
+            train_kwargs=train_kwargs,
+        )
+        thread.start()
+
+        return job
+
+    def _run_finetune_job(self, job: TrainingJob):
+        # TODO(enhance): We should listen for keyboard interrupts somewhere.
+        # Requires TrainingJob.cancel() to be implemented for each provider.
+        try:
+            model = self.provider.finetune(
+                job=job,
+                model=job.model,
+                train_data=job.train_data,
+                train_data_format=job.train_data_format,
+                train_kwargs=job.train_kwargs,
+            )
+            lm = self.copy(model=model)
+            job.set_result(lm)
+        except Exception as err:
+            logger.error(err)
+            job.set_result(err)
+
+    def infer_provider(self) -> Provider:
+        if OpenAIProvider.is_provider_model(self.model):
+            return OpenAIProvider()
+        return Provider()
+
+    def dump_state(self):
+        state_keys = [
+            "model",
+            "model_type",
+            "cache",
+            "cache_in_memory",
+            "num_retries",
+            "finetuning_model",
+            "launch_kwargs",
+            "train_kwargs",
+        ]
+        return {key: getattr(self, key) for key in state_keys} | self.kwargs
+
+
+def request_cache(maxsize: Optional[int] = None):
+    """
+    A threadsafe decorator to create an in-memory LRU cache for LM inference functions that accept
+    a dictionary-like LM request. An in-memory cache for LM calls is critical for ensuring
+    good performance when optimizing and evaluating DSPy LMs (disk caching alone is too slow).
+
+    Args:
+        maxsize: The maximum size of the cache. If unspecified, no max size is enforced (cache is unbounded).
+
+    Returns:
+        A decorator that wraps the target function with caching.
+    """
+
+    def cache_key(request: Dict[str, Any]) -> str:
+        """
+        Obtain a unique cache key for the given request dictionary by hashing its JSON
+        representation. For request fields having types that are known to be JSON-incompatible,
+        convert them to a JSON-serializable format before hashing.
+
+        Note: Values that cannot be converted to JSON should *not* be ignored / discarded, since
+        that would potentially lead to cache collisions. For example, consider request A
+        containing only JSON-convertible values and request B containing the same JSON-convertible
+        values in addition to one unconvertible value. Discarding the unconvertible value would
+        lead to a cache collision between requests A and B, even though they are semantically
+        different.
+        """
+
+        def transform_value(value):
+            if isinstance(value, type) and issubclass(value, pydantic.BaseModel):
+                return value.model_json_schema()
+            elif isinstance(value, pydantic.BaseModel):
+                return value.model_dump()
+            elif callable(value) and hasattr(value, "__code__") and hasattr(value.__code__, "co_code"):
+                return value.__code__.co_code.decode("utf-8")
+            else:
+                # Note: We don't attempt to compute a hash of the value, since the default
+                # implementation of hash() is id(), which may collide if the same memory address
+                # is reused for different objects at different times
+                return value
+
+        params = {k: transform_value(v) for k, v in request.items()}
+        return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()
+
+    def decorator(func):
+        @cached(
+            # NB: cachetools doesn't support maxsize=None; it recommends using float("inf") instead
+            cache=LRUCache(maxsize=maxsize or float("inf")),
+            key=lambda key, request, *args, **kwargs: key,
+            # Use a lock to ensure thread safety for the cache when DSPy LMs are queried
+            # concurrently, e.g. during optimization and evaluation
+            lock=threading.RLock(),
+        )
+        def func_cached(key: str, request: Dict[str, Any], *args, **kwargs):
+            return func(request, *args, **kwargs)
+
+        @functools.wraps(func)
+        def wrapper(request: dict, *args, **kwargs):
+            try:
+                key = cache_key(request)
+            except Exception:
+                # If the cache key cannot be computed (e.g. because it contains a value that cannot
+                # be converted to JSON), bypass the cache and call the target function directly
+                return func(request, *args, **kwargs)
+            cache_hit = key in func_cached.cache
+            output = func_cached(key, request, *args, **kwargs)
+            if cache_hit and hasattr(output, "usage"):
+                # Clear the usage data when cache is hit, because no LM call is made
+                output.usage = {}
+
+            return func_cached(key, request, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+@request_cache(maxsize=None)
+def cached_litellm_completion(request: Dict[str, Any], num_retries: int):
+    return litellm_completion(
+        request,
+        cache={"no-cache": False, "no-store": False},
+        num_retries=num_retries,
+    )
+
+
+def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
+    retry_kwargs = dict(
+        retry_policy=_get_litellm_retry_policy(num_retries),
+        retry_strategy="exponential_backoff_retry",
+        # In LiteLLM version 1.55.3 (the first version that supports retry_policy as an argument
+        # to completion()), the default value of max_retries is non-zero for certain providers, and
+        # max_retries is stacked on top of the retry_policy. To avoid this, we set max_retries=0
+        max_retries=0,
+    )
+
+    stream = dspy.settings.send_stream
+    caller_predict = dspy.settings.caller_predict
+    if stream is None:
+        # If `streamify` is not used, or if the exact predict doesn't need to be streamed,
+        # we can just return the completion without streaming.
+        return litellm.completion(
+            cache=cache,
+            **retry_kwargs,
+            **request,
+        )
+
+    # The stream is already opened, and will be closed by the caller.
+    stream = cast(MemoryObjectSendStream, stream)
+    caller_predict_id = id(caller_predict) if caller_predict else None
+
+    @syncify
+    async def stream_completion():
+        response = await litellm.acompletion(
+            cache=cache,
+            stream=True,
+            **retry_kwargs,
+            **request,
+        )
+        chunks = []
+        async for chunk in response:
+            if caller_predict_id:
+                # Add the predict id to the chunk so that the stream listener can identify which predict produces it.
+                chunk.predict_id = caller_predict_id
+            chunks.append(chunk)
+            await stream.send(chunk)
+        return litellm.stream_chunk_builder(chunks)
+
+    return stream_completion()
+
+
+@request_cache(maxsize=None)
+def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int):
+    return litellm_text_completion(
+        request,
+        num_retries=num_retries,
+        cache={"no-cache": False, "no-store": False},
+    )
+
+
+def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
+    # Extract the provider and model from the model string.
+    # TODO: Not all the models are in the format of "provider/model"
+    model = request.pop("model").split("/", 1)
+    provider, model = model[0] if len(model) > 1 else "openai", model[-1]
+
+    # Use the API key and base from the request, or from the environment.
+    api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
+    api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
+
+    # Build the prompt from the messages.
+    prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])
+
+    return litellm.text_completion(
+        cache=cache,
+        model=f"text-completion-openai/{model}",
+        api_key=api_key,
+        api_base=api_base,
+        prompt=prompt,
+        retry_policy=_get_litellm_retry_policy(num_retries),
+        # In LiteLLM version 1.55.3 (the first version that supports retry_policy as an argument
+        # to completion()), the default value of max_retries is non-zero for certain providers, and
+        # max_retries is stacked on top of the retry_policy. To avoid this, we set max_retries=0
+        max_retries=0,
+        **request,
+    )
+
+
+def _get_litellm_retry_policy(num_retries: int) -> RetryPolicy:
+    """
+    Get a LiteLLM retry policy for retrying requests when transient API errors occur.
+    Args:
+        num_retries: The number of times to retry a request if it fails transiently due to
+            network error, rate limiting, etc. Requests are retried with exponential
+            backoff.
+    Returns:
+        A LiteLLM RetryPolicy instance.
+    """
+    return RetryPolicy(
+        TimeoutErrorRetries=num_retries,
+        RateLimitErrorRetries=num_retries,
+        InternalServerErrorRetries=num_retries,
+        ContentPolicyViolationErrorRetries=num_retries,
+        # We don't retry on errors that are unlikely to be transient
+        # (e.g. bad request, invalid auth credentials)
+        BadRequestErrorRetries=0,
+        AuthenticationErrorRetries=0,
+    )
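Taken together, _lm.py wraps LiteLLM with a per-process rate limit (15 calls per second), exponential-backoff retries, and a two-level cache: a thread-safe in-memory LRU from request_cache() layered on top of LiteLLM's own response cache. Below is a minimal usage sketch of the LM class as it appears in the hunk above; the model string is illustrative and provider credentials are assumed to be configured in the environment.

# Sketch only: exercises the constructor and forward() shown above.
# "openai/gpt-4o" is an illustrative model string; API keys are assumed to be set.
from opik_optimizer.mipro_optimizer._lm import LM

lm = LM(
    model="openai/gpt-4o",   # "llm_provider/llm_name" form expected by LiteLLM
    model_type="chat",
    temperature=0.0,
    max_tokens=1000,
    cache=True,              # enable response caching; disabling it also disables the in-memory cache
    cache_in_memory=True,    # additionally use the thread-safe in-memory LRU from request_cache()
    num_retries=8,           # mapped onto a LiteLLM RetryPolicy for transient errors only
)

# forward() builds the chat request, applies the rate limiter, and dispatches to the
# cached or uncached LiteLLM helpers depending on the cache flags above.
result = lm.forward(prompt="Say hello in one word.")

# dump_state() returns the serializable constructor settings merged with the extra kwargs.
state = lm.dump_state()

# Note the constructor assertion: OpenAI reasoning models (o1/o3 families) must be
# created with temperature=1.0 and max_tokens >= 20_000, e.g. (illustrative only):
# reasoning_lm = LM(model="openai/o1-mini", temperature=1.0, max_tokens=20_000)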