opik-optimizer 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1 @@
+ from .mipro_optimizer import MiproOptimizer
@@ -0,0 +1,394 @@
+ import functools
+ import logging
+ import os
+ import re
+ import threading
+ from hashlib import sha256
+ from typing import Any, Dict, List, Literal, Optional, cast
+
+ import litellm
+ import pydantic
+ import ujson
+ from anyio.streams.memory import MemoryObjectSendStream
+ from asyncer import syncify
+ from cachetools import LRUCache, cached
+ from litellm import RetryPolicy
+
+ import dspy
+ from dspy.clients.openai import OpenAIProvider
+ from dspy.clients.provider import Provider, TrainingJob
+ from dspy.clients.utils_finetune import TrainDataFormat
+ from dspy.dsp.utils.settings import settings
+ from dspy.utils.callback import BaseCallback, with_callbacks
+ from dspy.clients.base_lm import BaseLM
+
+ from .._throttle import RateLimiter, rate_limited
+
+ logger = logging.getLogger(__name__)
+ # Limit how fast an LLM can be called:
+ limiter = RateLimiter(max_calls_per_second=15)
+
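+ # The module-level limiter is shared by every `LM` instance: `rate_limited(limiter)` wraps
+ # `LM.forward` below so that at most 15 LLM calls per second are issued across all instances.
+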
+ class LM(BaseLM):
+     """
+     A language model supporting chat or text completion requests for use with DSPy modules.
+     """
+
+     def __init__(
+         self,
+         model: str,
+         model_type: Literal["chat", "text"] = "chat",
+         temperature: float = 0.0,
+         max_tokens: int = 1000,
+         cache: bool = True,
+         cache_in_memory: bool = True,
+         callbacks: Optional[List[BaseCallback]] = None,
+         num_retries: int = 8,
+         provider=None,
+         finetuning_model: Optional[str] = None,
+         launch_kwargs: Optional[dict[str, Any]] = None,
+         train_kwargs: Optional[dict[str, Any]] = None,
+         **kwargs,
+     ):
+         """
+         Create a new language model instance for use with DSPy modules and programs.
+
+         Args:
+             model: The model to use. This should be a string of the form ``"llm_provider/llm_name"``
+                 supported by LiteLLM. For example, ``"openai/gpt-4o"``.
+             model_type: The type of the model, either ``"chat"`` or ``"text"``.
+             temperature: The sampling temperature to use when generating responses.
+             max_tokens: The maximum number of tokens to generate per response.
+             cache: Whether to cache the model responses for reuse, to improve performance
+                 and reduce costs.
+             cache_in_memory: Whether to additionally cache responses in an in-memory LRU cache.
+             callbacks: A list of callback functions to run before and after each request.
+             num_retries: The number of times to retry a request if it fails transiently due to
+                 network errors, rate limiting, etc. Requests are retried with exponential
+                 backoff.
+             provider: The provider to use. If not specified, the provider is inferred from the model.
+             finetuning_model: The model to fine-tune. For some providers, the models available for
+                 fine-tuning differ from the models available for inference.
+         """
+         # Remember to update LM.copy() if you modify the constructor!
+         self.model = model
+         self.model_type = model_type
+         self.cache = cache
+         self.cache_in_memory = cache_in_memory
+         self.provider = provider or self.infer_provider()
+         self.callbacks = callbacks or []
+         self.history = []
+         self.callbacks = callbacks or []
+         self.num_retries = num_retries
+         self.finetuning_model = finetuning_model
+         self.launch_kwargs = launch_kwargs or {}
+         self.train_kwargs = train_kwargs or {}
+
+         # Handle model-specific configuration for different model families
+         model_family = model.split("/")[-1].lower() if "/" in model else model.lower()
+
+         # Match pattern: o1 or o3 at the start, optionally followed by -mini and anything else
+         model_pattern = re.match(r"^o([13])(?:-mini)?", model_family)
+
+         if model_pattern:
+             # Handle OpenAI reasoning models (o1, o3)
+             assert (
+                 max_tokens >= 20_000 and temperature == 1.0
+             ), "OpenAI's reasoning models require passing temperature=1.0 and max_tokens >= 20_000 to `dspy.LM(...)`"
+             self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs)
+         else:
+             self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
+
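+     # Illustrative usage only (a sketch, not part of the published module): assumes provider
+     # credentials are configured for LiteLLM and relies on `__call__` being inherited from
+     # `BaseLM`, which delegates to `forward` below.
+     #
+     #     lm = LM("openai/gpt-4o-mini", temperature=0.0, max_tokens=1000)
+     #     outputs = lm(messages=[{"role": "user", "content": "Say hello."}])
+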
+     @rate_limited(limiter)
+     @with_callbacks
+     def forward(self, prompt=None, messages=None, **kwargs):
+         # Build the request.
+         cache = kwargs.pop("cache", self.cache)
+         # Disabling the cache also disables the in-memory cache.
+         cache_in_memory = cache and kwargs.pop("cache_in_memory", self.cache_in_memory)
+         messages = messages or [{"role": "user", "content": prompt}]
+         kwargs = {**self.kwargs, **kwargs}
+
+         # Make the request and handle LRU & disk caching.
+         if cache_in_memory:
+             completion = cached_litellm_completion if self.model_type == "chat" else cached_litellm_text_completion
+
+             results = completion(
+                 request=dict(model=self.model, messages=messages, **kwargs),
+                 num_retries=self.num_retries,
+             )
+         else:
+             completion = litellm_completion if self.model_type == "chat" else litellm_text_completion
+
+             results = completion(
+                 request=dict(model=self.model, messages=messages, **kwargs),
+                 num_retries=self.num_retries,
+                 # only leverage LiteLLM cache in this case
+                 cache={"no-cache": not cache, "no-store": not cache},
+             )
+
+         if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
+             settings.usage_tracker.add_usage(self.model, dict(results.usage))
+
+         return results
+
+     def launch(self, launch_kwargs: Optional[Dict[str, Any]] = None):
+         self.provider.launch(self, launch_kwargs)
+
+     def kill(self, launch_kwargs: Optional[Dict[str, Any]] = None):
+         self.provider.kill(self, launch_kwargs)
+
+     def finetune(
+         self,
+         train_data: List[Dict[str, Any]],
+         train_data_format: Optional[TrainDataFormat],
+         train_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> TrainingJob:
+         from dspy import settings as settings
+
+         err = "Fine-tuning is an experimental feature."
+         err += " Set `dspy.settings.experimental` to `True` to use it."
+         assert settings.experimental, err
+
+         err = f"Provider {self.provider} does not support fine-tuning."
+         assert self.provider.finetunable, err
+
+         def thread_function_wrapper():
+             return self._run_finetune_job(job)
+
+         thread = threading.Thread(target=thread_function_wrapper)
+         train_kwargs = train_kwargs or self.train_kwargs
+         model_to_finetune = self.finetuning_model or self.model
+         job = self.provider.TrainingJob(
+             thread=thread,
+             model=model_to_finetune,
+             train_data=train_data,
+             train_data_format=train_data_format,
+             train_kwargs=train_kwargs,
+         )
+         thread.start()
+
+         return job
+
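+     # `_run_finetune_job` runs on the background thread started by `finetune`: on success the
+     # job's result is set to a copy of this LM pointing at the fine-tuned model; on failure,
+     # to the raised exception.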
+     def _run_finetune_job(self, job: TrainingJob):
+         # TODO(enhance): We should listen for keyboard interrupts somewhere.
+         # Requires TrainingJob.cancel() to be implemented for each provider.
+         try:
+             model = self.provider.finetune(
+                 job=job,
+                 model=job.model,
+                 train_data=job.train_data,
+                 train_data_format=job.train_data_format,
+                 train_kwargs=job.train_kwargs,
+             )
+             lm = self.copy(model=model)
+             job.set_result(lm)
+         except Exception as err:
+             logger.error(err)
+             job.set_result(err)
+
+     def infer_provider(self) -> Provider:
+         if OpenAIProvider.is_provider_model(self.model):
+             return OpenAIProvider()
+         return Provider()
+
+     def dump_state(self):
+         state_keys = [
+             "model",
+             "model_type",
+             "cache",
+             "cache_in_memory",
+             "num_retries",
+             "finetuning_model",
+             "launch_kwargs",
+             "train_kwargs",
+         ]
+         return {key: getattr(self, key) for key in state_keys} | self.kwargs
+
+
+ def request_cache(maxsize: Optional[int] = None):
+     """
+     A threadsafe decorator to create an in-memory LRU cache for LM inference functions that accept
+     a dictionary-like LM request. An in-memory cache for LM calls is critical for ensuring
+     good performance when optimizing and evaluating DSPy LMs (disk caching alone is too slow).
+
+     Args:
+         maxsize: The maximum size of the cache. If unspecified, no max size is enforced (cache is unbounded).
+
+     Returns:
+         A decorator that wraps the target function with caching.
+     """
+
+     def cache_key(request: Dict[str, Any]) -> str:
+         """
+         Obtain a unique cache key for the given request dictionary by hashing its JSON
+         representation. For request fields having types that are known to be JSON-incompatible,
+         convert them to a JSON-serializable format before hashing.
+
+         Note: Values that cannot be converted to JSON should *not* be ignored / discarded, since
+         that would potentially lead to cache collisions. For example, consider request A
+         containing only JSON-convertible values and request B containing the same JSON-convertible
+         values in addition to one unconvertible value. Discarding the unconvertible value would
+         lead to a cache collision between requests A and B, even though they are semantically
+         different.
+         """
+
+         def transform_value(value):
+             if isinstance(value, type) and issubclass(value, pydantic.BaseModel):
+                 return value.model_json_schema()
+             elif isinstance(value, pydantic.BaseModel):
+                 return value.model_dump()
+             elif callable(value) and hasattr(value, "__code__") and hasattr(value.__code__, "co_code"):
+                 return value.__code__.co_code.decode("utf-8")
+             else:
+                 # Note: We don't attempt to compute a hash of the value, since the default
+                 # implementation of hash() is id(), which may collide if the same memory address
+                 # is reused for different objects at different times
+                 return value
+
+         params = {k: transform_value(v) for k, v in request.items()}
+         return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()
+
+     def decorator(func):
+         @cached(
+             # NB: cachetools doesn't support maxsize=None; it recommends using float("inf") instead
+             cache=LRUCache(maxsize=maxsize or float("inf")),
+             key=lambda key, request, *args, **kwargs: key,
+             # Use a lock to ensure thread safety for the cache when DSPy LMs are queried
+             # concurrently, e.g. during optimization and evaluation
+             lock=threading.RLock(),
+         )
+         def func_cached(key: str, request: Dict[str, Any], *args, **kwargs):
+             return func(request, *args, **kwargs)
+
+         @functools.wraps(func)
+         def wrapper(request: dict, *args, **kwargs):
+             try:
+                 key = cache_key(request)
+             except Exception:
+                 # If the cache key cannot be computed (e.g. because it contains a value that cannot
+                 # be converted to JSON), bypass the cache and call the target function directly
+                 return func(request, *args, **kwargs)
+             cache_hit = key in func_cached.cache
+             output = func_cached(key, request, *args, **kwargs)
+             if cache_hit and hasattr(output, "usage"):
+                 # Clear the usage data when cache is hit, because no LM call is made
+                 output.usage = {}
+
+             return output
+
+         return wrapper
+
+     return decorator
+
+
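+ # The `cached_*` helpers below enable both cache layers: the in-memory LRU cache added by
+ # `request_cache` and LiteLLM's own cache ("no-cache"/"no-store" set to False). The plain
+ # `litellm_completion` / `litellm_text_completion` helpers only pass the LiteLLM cache flags through.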
+ @request_cache(maxsize=None)
+ def cached_litellm_completion(request: Dict[str, Any], num_retries: int):
+     return litellm_completion(
+         request,
+         cache={"no-cache": False, "no-store": False},
+         num_retries=num_retries,
+     )
+
+
+ def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
+     retry_kwargs = dict(
+         retry_policy=_get_litellm_retry_policy(num_retries),
+         retry_strategy="exponential_backoff_retry",
+         # In LiteLLM version 1.55.3 (the first version that supports retry_policy as an argument
+         # to completion()), the default value of max_retries is non-zero for certain providers, and
+         # max_retries is stacked on top of the retry_policy. To avoid this, we set max_retries=0
+         max_retries=0,
+     )
+
+     stream = dspy.settings.send_stream
+     caller_predict = dspy.settings.caller_predict
+     if stream is None:
+         # If `streamify` is not used, or if the exact predict doesn't need to be streamed,
+         # we can just return the completion without streaming.
+         return litellm.completion(
+             cache=cache,
+             **retry_kwargs,
+             **request,
+         )
+
+     # The stream is already opened, and will be closed by the caller.
+     stream = cast(MemoryObjectSendStream, stream)
+     caller_predict_id = id(caller_predict) if caller_predict else None
+
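+     # `syncify` lets this synchronous function drive the async LiteLLM streaming call: each chunk
+     # is pushed to the caller's `MemoryObjectSendStream` before the full response is rebuilt with
+     # `litellm.stream_chunk_builder(chunks)`.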
+     @syncify
+     async def stream_completion():
+         response = await litellm.acompletion(
+             cache=cache,
+             stream=True,
+             **retry_kwargs,
+             **request,
+         )
+         chunks = []
+         async for chunk in response:
+             if caller_predict_id:
+                 # Add the predict id to the chunk so that the stream listener can identify which predict produces it.
+                 chunk.predict_id = caller_predict_id
+             chunks.append(chunk)
+             await stream.send(chunk)
+         return litellm.stream_chunk_builder(chunks)
+
+     return stream_completion()
+
+
+ @request_cache(maxsize=None)
+ def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int):
+     return litellm_text_completion(
+         request,
+         num_retries=num_retries,
+         cache={"no-cache": False, "no-store": False},
+     )
+
+
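+ # `litellm_text_completion` flattens the chat messages into a single prompt (contents joined by
+ # blank lines, terminated with "BEGIN RESPONSE:") and routes it through LiteLLM's
+ # "text-completion-openai/<model>" path, pulling credentials from the request or from the
+ # `<provider>_API_KEY` / `<provider>_API_BASE` environment variables.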
+ def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
+     # Extract the provider and model from the model string.
+     # TODO: Not all the models are in the format of "provider/model"
+     model = request.pop("model").split("/", 1)
+     provider, model = model[0] if len(model) > 1 else "openai", model[-1]
+
+     # Use the API key and base from the request, or from the environment.
+     api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
+     api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
+
+     # Build the prompt from the messages.
+     prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])
+
+     return litellm.text_completion(
+         cache=cache,
+         model=f"text-completion-openai/{model}",
+         api_key=api_key,
+         api_base=api_base,
+         prompt=prompt,
+         retry_policy=_get_litellm_retry_policy(num_retries),
+         # In LiteLLM version 1.55.3 (the first version that supports retry_policy as an argument
+         # to completion()), the default value of max_retries is non-zero for certain providers, and
+         # max_retries is stacked on top of the retry_policy. To avoid this, we set max_retries=0
+         max_retries=0,
+         **request,
+     )
+
+
+ def _get_litellm_retry_policy(num_retries: int) -> RetryPolicy:
+     """
+     Get a LiteLLM retry policy for retrying requests when transient API errors occur.
+     Args:
+         num_retries: The number of times to retry a request if it fails transiently due to
+             network error, rate limiting, etc. Requests are retried with exponential
+             backoff.
+     Returns:
+         A LiteLLM RetryPolicy instance.
+     """
+     return RetryPolicy(
+         TimeoutErrorRetries=num_retries,
+         RateLimitErrorRetries=num_retries,
+         InternalServerErrorRetries=num_retries,
+         ContentPolicyViolationErrorRetries=num_retries,
+         # We don't retry on errors that are unlikely to be transient
+         # (e.g. bad request, invalid auth credentials)
+         BadRequestErrorRetries=0,
+         AuthenticationErrorRetries=0,
+     )