speedy-utils 0.1.28__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_utils/lm.py ADDED
@@ -0,0 +1,742 @@
+ import fcntl
+ import os
+ import random
+ import tempfile
+ from copy import deepcopy
+ import time
+ from typing import Any, List, Literal, Optional, TypedDict, Dict, Type, Union, cast
+
+
+ import numpy as np
+ from loguru import logger
+ from pydantic import BaseModel
+ from speedy_utils import dump_json_or_pickle, identify_uuid, load_json_or_pickle
+
+
+ class Message(TypedDict):
+     role: Literal["user", "assistant", "system"]
+     content: str | BaseModel
+
+
+ class ChatSession:
+
+     def __init__(
+         self,
+         lm: "OAI_LM",
+         system_prompt: Optional[str] = None,
+         history: Optional[List[Message]] = None,  # Avoid a shared mutable default; deep-copied below
+         callback=None,
+         response_format: Optional[Type[BaseModel]] = None,
+     ):
+         self.lm = deepcopy(lm)
+         self.history = deepcopy(history) if history is not None else []
+         self.callback = callback
+         self.response_format = response_format
+         if system_prompt:
+             system_message: Message = {
+                 "role": "system",
+                 "content": system_prompt,
+             }
+             self.history.insert(0, system_message)
+
+     def __len__(self):
+         return len(self.history)
+
+     def __call__(
+         self,
+         text,
+         response_format: Optional[Type[BaseModel]] = None,
+         display=False,
+         max_prev_turns=3,
+         **kwargs,
+     ) -> str | BaseModel:
+         current_response_format = response_format or self.response_format
+         self.history.append({"role": "user", "content": text})
+         output = self.lm(
+             messages=self.parse_history(),
+             response_format=current_response_format,
+             **kwargs,
+         )
+         # output is either a plain string or a pydantic model instance
+         if isinstance(output, BaseModel):
+             self.history.append({"role": "assistant", "content": output})
+         else:
+             assert current_response_format is None
+             self.history.append({"role": "assistant", "content": output})
+         if display:
+             self.inspect_history(max_prev_turns=max_prev_turns)
+
+         if self.callback:
+             self.callback(self, output)
+         return output
+
+     def send_message(self, text, **kwargs):
+         """
+         Wrapper around __call__ method for sending messages.
+         This maintains compatibility with the test suite.
+         """
+         return self.__call__(text, **kwargs)
+
+     def parse_history(self, indent=None):
+         parsed_history = []
+         for m in self.history:
+             if isinstance(m["content"], str):
+                 parsed_history.append(m)
+             elif isinstance(m["content"], BaseModel):
+                 parsed_history.append(
+                     {
+                         "role": m["role"],
+                         "content": m["content"].model_dump_json(indent=indent),
+                     }
+                 )
+             else:
+                 raise ValueError(f"Unexpected content type: {type(m['content'])}")
+         return parsed_history
+
+     def inspect_history(self, max_prev_turns=3):
+         from llm_utils import display_chat_messages_as_html
+
+         h = self.parse_history(indent=2)
+         try:
+             from IPython.display import clear_output
+
+             clear_output()
+             display_chat_messages_as_html(h[-max_prev_turns * 2 :])
+         except Exception:
+             # Silently skip display when IPython is not available (e.g. outside a notebook)
+             pass
+
+
+ def _clear_port_use(ports):
+     """
+     Clear the usage counters for all ports.
+     """
+     for port in ports:
+         file_counter = f"/tmp/port_use_counter_{port}.npy"
+         if os.path.exists(file_counter):
+             os.remove(file_counter)
+
+
+ def _atomic_save(array: np.ndarray, filename: str):
+     """
+     Write `array` to `filename` with an atomic rename to avoid partial writes.
+     """
+     # The temp file must be on the same filesystem as `filename` to ensure
+     # that os.replace() is truly atomic.
+     tmp_dir = os.path.dirname(filename) or "."
+     with tempfile.NamedTemporaryFile(dir=tmp_dir, delete=False) as tmp:
+         np.save(tmp, array)
+         temp_name = tmp.name
+
+     # Atomically rename the temp file to the final name.
+     # On POSIX systems, os.replace is an atomic operation.
+     os.replace(temp_name, filename)
+
+
+ def _update_port_use(port: int, increment: int):
+     """
+     Update the usage counter for a given port, safely with an exclusive lock
+     and atomic writes to avoid file corruption.
+     """
+     file_counter = f"/tmp/port_use_counter_{port}.npy"
+     file_counter_lock = f"/tmp/port_use_counter_{port}.lock"
+
+     with open(file_counter_lock, "w") as lock_file:
+         fcntl.flock(lock_file, fcntl.LOCK_EX)
+         try:
+             # If file exists, load it. Otherwise assume zero usage.
+             if os.path.exists(file_counter):
+                 try:
+                     counter = np.load(file_counter)
+                 except Exception as e:
+                     # If we fail to load (e.g. file corrupted), start from zero
+                     logger.warning(f"Corrupted usage file {file_counter}: {e}")
+                     counter = np.array([0])
+             else:
+                 counter = np.array([0])
+
+             # Increment usage and atomically overwrite the old file
+             counter[0] += increment
+             _atomic_save(counter, file_counter)
+
+         finally:
+             fcntl.flock(lock_file, fcntl.LOCK_UN)
+
+
+ def _pick_least_used_port(ports: List[int]) -> int:
+     """
+     Pick the least-used port among the provided list, safely under a global lock
+     so that no two processes pick a port at the same time.
+     """
+     global_lock_file = "/tmp/ports.lock"
+
+     with open(global_lock_file, "w") as lock_file:
+         fcntl.flock(lock_file, fcntl.LOCK_EX)
+         try:
+             port_use: Dict[int, int] = {}
+             # Read usage for each port
+             for port in ports:
+                 file_counter = f"/tmp/port_use_counter_{port}.npy"
+                 if os.path.exists(file_counter):
+                     try:
+                         counter = np.load(file_counter)
+                     except Exception as e:
+                         # If the file is corrupted, reset usage to 0
+                         logger.warning(f"Corrupted usage file {file_counter}: {e}")
+                         counter = np.array([0])
+                 else:
+                     counter = np.array([0])
+                 port_use[port] = counter[0]
+
+             logger.debug(f"Port use: {port_use}")
+
+             if not port_use:
+                 # port_use can only be empty when no ports were provided
+                 raise ValueError("No ports provided to pick from.")
+
+             # Pick the least-used port
+             lsp = min(port_use, key=lambda k: port_use[k])
+
+             # Increment usage of that port
+             _update_port_use(lsp, 1)
+
+         finally:
+             fcntl.flock(lock_file, fcntl.LOCK_UN)
+
+     return lsp
+
+
+ class OAI_LM:
+     """
+     A language model supporting chat or text completion requests for use with DSPy modules.
+     """
+
+     def __init__(
+         self,
+         model: Optional[str] = None,
+         model_type: Literal["chat", "text"] = "chat",
+         temperature: float = 0.0,
+         max_tokens: int = 2000,
+         cache: bool = True,
+         callbacks: Optional[Any] = None,
+         num_retries: int = 3,
+         provider=None,
+         finetuning_model: Optional[str] = None,
+         launch_kwargs: Optional[dict[str, Any]] = None,
+         host: str = "localhost",
+         port: Optional[int] = None,
+         ports: Optional[List[int]] = None,
+         api_key: Optional[str] = None,
+         **kwargs,
+     ):
+         # Lazy import dspy
+         import dspy
+
+         self.ports = ports
+         self.host = host
+         self.api_key = api_key or os.getenv("OPENAI_API_KEY", "abc")
+
+         # Determine base_url: kwargs["base_url"] > http://host:port > http://host:ports[0]
+         resolved_base_url_from_kwarg = kwargs.get("base_url")
+         if resolved_base_url_from_kwarg is not None and not isinstance(
+             resolved_base_url_from_kwarg, str
+         ):
+             logger.warning(
+                 f"base_url in kwargs was not a string ({type(resolved_base_url_from_kwarg)}), ignoring."
+             )
+             resolved_base_url_from_kwarg = None
+
+         resolved_base_url: Optional[str] = cast(
+             Optional[str], resolved_base_url_from_kwarg
+         )
+
+         if resolved_base_url is None:
+             selected_port = port
+             if selected_port is None and ports is not None and len(ports) > 0:
+                 selected_port = ports[0]
+
+             if selected_port is not None:
+                 resolved_base_url = f"http://{host}:{selected_port}/v1"
+         self.base_url = resolved_base_url
+
+         if model is None:
+             if self.base_url:
+                 try:
+                     model_list = (
+                         self.list_models()
+                     )  # Uses self.base_url and self.api_key
+                     if model_list:
+                         model_name_from_list = model_list[0]
+                         model = f"openai/{model_name_from_list}"
+                         logger.info(f"Using default model: {model}")
+                     else:
+                         logger.warning(
+                             f"No models found at {self.base_url}. Please specify a model."
+                         )
+                 except Exception as e:
+                     example_cmd = (
+                         "LM.start_server('unsloth/gemma-3-1b-it')\n"
+                         "# Or manually run: svllm serve --model unsloth/gemma-3-1b-it --gpus 0 -hp localhost:9150"
+                     )
+                     logger.error(
+                         f"Failed to list models from {self.base_url}: {e}\n"
+                         f"Make sure your model server is running and accessible.\n"
+                         f"Example to start a server:\n{example_cmd}"
+                     )
+             else:
+                 logger.warning(
+                     "base_url not configured, cannot fetch default model. Please specify a model."
+                 )
+
+         assert (
+             model is not None
+         ), "Model name must be provided or discoverable via list_models"
+
+         if not model.startswith("openai/"):
+             model = f"openai/{model}"
+
+         dspy_lm_kwargs = kwargs.copy()
+         dspy_lm_kwargs["api_key"] = self.api_key  # Ensure dspy.LM gets this
+
+         if self.base_url and "base_url" not in dspy_lm_kwargs:
+             dspy_lm_kwargs["base_url"] = self.base_url
+         elif (
+             self.base_url
+             and "base_url" in dspy_lm_kwargs
+             and dspy_lm_kwargs["base_url"] != self.base_url
+         ):
+             # If kwarg['base_url'] exists and differs from derived self.base_url,
+             # dspy.LM will use kwarg['base_url']. Update self.base_url to reflect this.
+             self.base_url = dspy_lm_kwargs["base_url"]
+
+         self._dspy_lm: dspy.LM = dspy.LM(
+             model=model,
+             model_type=model_type,
+             temperature=temperature,
+             max_tokens=max_tokens,
+             callbacks=callbacks,
+             num_retries=num_retries,
+             provider=provider,
+             finetuning_model=finetuning_model,
+             launch_kwargs=launch_kwargs,
+             # api_key is passed via dspy_lm_kwargs
+             **dspy_lm_kwargs,
+         )
+         # Store the actual kwargs used by dspy.LM
+         self.kwargs = self._dspy_lm.kwargs
+         self.model = self._dspy_lm.model  # self.model is str
+
+         # Ensure self.base_url and self.api_key are consistent with what dspy.LM is using
+         self.base_url = self.kwargs.get("base_url")
+         self.api_key = self.kwargs.get("api_key")
+
+         self.do_cache = cache
+
+     @property
+     def last_response(self):
+         return self._dspy_lm.history[-1]["response"].model_dump()["choices"][0][
+             "message"
+         ]
+
+     def __call__(
+         self,
+         prompt: Optional[str] = None,
+         messages: Optional[List[Message]] = None,
+         response_format: Optional[Type[BaseModel]] = None,
+         cache: Optional[bool] = None,
+         retry_count: int = 0,
+         port: Optional[int] = None,
+         error: Optional[Exception] = None,
+         use_loadbalance: Optional[bool] = None,
+         must_load_cache: bool = False,
+         max_tokens: Optional[int] = None,
+         num_retries: int = 10,
+         **kwargs,
+     ) -> Union[str, BaseModel]:
+         if retry_count > num_retries:
+             logger.error(f"Retry limit exceeded, error: {error}")
+             if error:
+                 raise error
+             raise ValueError("Retry limit exceeded with no specific error.")
+
+         effective_kwargs = {**self.kwargs, **kwargs}
+         id_for_cache: Optional[str] = None
+
+         effective_cache = cache if cache is not None else self.do_cache
+
+         if max_tokens is not None:
+             effective_kwargs["max_tokens"] = max_tokens
+
+         if response_format:
+             assert issubclass(
+                 response_format, BaseModel
+             ), f"response_format must be a Pydantic model class, {type(response_format)} provided"
+
+         cached_result: Optional[Union[str, BaseModel, List[Union[str, BaseModel]]]] = (
+             None
+         )
+         if effective_cache:
+             cache_key_list = [
+                 prompt,
+                 messages,
+                 (response_format.model_json_schema() if response_format else None),
+                 effective_kwargs.get("temperature"),
+                 effective_kwargs.get("max_tokens"),
+                 self.model,
+             ]
+             s = str(cache_key_list)
+             id_for_cache = identify_uuid(s)
+             cached_result = self.load_cache(id_for_cache)
+
+         if cached_result is not None:
+             if response_format:
+                 if isinstance(cached_result, str):
+                     try:
+                         import json_repair
+
+                         parsed = json_repair.loads(cached_result)
+                         if not isinstance(parsed, dict):
+                             raise ValueError("Parsed cached_result is not a dict")
+                         # Ensure keys are strings
+                         parsed = {str(k): v for k, v in parsed.items()}
+                         return response_format(**parsed)
+                     except Exception as e_parse:
+                         logger.warning(
+                             f"Failed to parse cached string for {id_for_cache} into {response_format.__name__}: {e_parse}. Retrying LLM call."
+                         )
+                 elif isinstance(cached_result, response_format):
+                     return cached_result
+                 else:
+                     logger.warning(
+                         f"Cached result for {id_for_cache} has unexpected type {type(cached_result)}. Expected {response_format.__name__} or str. Retrying LLM call."
+                     )
+             else:  # No response_format, expect string
+                 if isinstance(cached_result, str):
+                     return cached_result
+                 else:
+                     logger.warning(
+                         f"Cached result for {id_for_cache} has unexpected type {type(cached_result)}. Expected str. Retrying LLM call."
+                     )
+
+         if (
+             must_load_cache and cached_result is None
+         ):  # If we are here, cache load failed or was not suitable
+             raise ValueError(
+                 "must_load_cache is True, but failed to load a valid response from cache."
+             )
+
+         import litellm
+
+         current_port: int | None = port
+         if self.ports and not current_port:
+             if use_loadbalance:
+                 current_port = self.get_least_used_port()
+             else:
+                 current_port = random.choice(self.ports)
+
+         if current_port:
+             effective_kwargs["base_url"] = f"http://{self.host}:{current_port}/v1"
+
+         llm_output_or_outputs: Union[str, BaseModel, List[Union[str, BaseModel]]]
+         try:
+             dspy_main_input: Union[str, List[Message]]
+             if messages is not None:
+                 dspy_main_input = messages
+             elif prompt is not None:
+                 dspy_main_input = prompt
+             else:
+                 # Depending on LM capabilities, this might be valid if other means of generation are used (e.g. tool use)
+                 # For now, assume one is needed for typical completion/chat.
+                 # Consider if _dspy_lm can handle None/empty input gracefully or if an error is better.
+                 # If dspy.LM expects a non-null primary argument, this will fail there.
+                 # For safety, let's raise if both are None, assuming typical usage.
+                 raise ValueError(
+                     "Either 'prompt' or 'messages' must be provided for the LLM call."
+                 )
+
+             llm_outputs_list = self._dspy_lm(
+                 dspy_main_input,  # Pass as positional argument
+                 response_format=response_format,  # Pass as keyword argument, dspy will handle it in its **kwargs
+                 **effective_kwargs,
+             )
+
+             if not llm_outputs_list:
+                 raise ValueError("LLM call returned an empty list.")
+
+             # Convert dict outputs to string to match expected return type
+             def convert_output(o):
+                 if isinstance(o, dict):
+                     import json
+
+                     return json.dumps(o)
+                 return o
+
+             if effective_kwargs.get("n", 1) == 1:
+                 llm_output_or_outputs = convert_output(llm_outputs_list[0])
+             else:
+                 llm_output_or_outputs = [convert_output(o) for o in llm_outputs_list]
+
+         except (litellm.exceptions.APIError, litellm.exceptions.Timeout) as e_llm:
+             t = 3
+             base_url_info = effective_kwargs.get("base_url", "N/A")
+             log_msg = f"[{base_url_info=}] {type(e_llm).__name__}: {str(e_llm)[:100]}, will sleep for {t}s and retry"
+             logger.warning(log_msg)  # Always warn on retry for these
+             time.sleep(t)
+             return self.__call__(
+                 prompt=prompt,
+                 messages=messages,
+                 response_format=response_format,
+                 cache=cache,
+                 retry_count=retry_count + 1,
+                 port=current_port,
+                 error=e_llm,
+                 use_loadbalance=use_loadbalance,
+                 must_load_cache=must_load_cache,
+                 max_tokens=max_tokens,
+                 num_retries=num_retries,
+                 **kwargs,
+             )
+         except litellm.exceptions.ContextWindowExceededError as e_cwe:
+             logger.error(f"Context window exceeded: {e_cwe}")
+             raise
+         except Exception as e_generic:
+             logger.error(f"Generic error during LLM call: {e_generic}")
+             import traceback
+
+             traceback.print_exc()
+             raise
+         finally:
+             if (
+                 current_port and use_loadbalance is True
+             ):  # Ensure use_loadbalance is explicitly True
+                 _update_port_use(current_port, -1)
+
+         if effective_cache and id_for_cache:
+             self.dump_cache(id_for_cache, llm_output_or_outputs)
+
+         # Ensure single return if n=1, which is implied by method signature str | BaseModel
+         final_output: Union[str, BaseModel]
+         if isinstance(llm_output_or_outputs, list):
+             # This should ideally not happen if n=1 was handled correctly above.
+             # If it's a list, it means n > 1. The method signature needs to change for that.
+             # For now, stick to returning the first element if it's a list.
+             logger.warning(
+                 "LLM returned multiple completions; __call__ expects single. Returning first."
+             )
+             final_output = llm_output_or_outputs[0]
+         else:
+             final_output = llm_output_or_outputs  # type: ignore  # It's already Union[str, BaseModel]
+
+         if response_format:
+             if not isinstance(final_output, response_format):
+                 if isinstance(final_output, str):
+                     logger.warning(
+                         f"LLM call returned string, but expected {response_format.__name__}. Attempting parse."
+                     )
+                     try:
+                         import json_repair
+
+                         parsed_dict = json_repair.loads(final_output)
+                         if not isinstance(parsed_dict, dict):
+                             raise ValueError("Parsed output is not a dict")
+                         parsed_dict = {str(k): v for k, v in parsed_dict.items()}
+                         parsed_output = response_format(**parsed_dict)
+                         if effective_cache and id_for_cache:
+                             self.dump_cache(
+                                 id_for_cache, parsed_output
+                             )  # Cache the successfully parsed model
+                         return parsed_output
+                     except Exception as e_final_parse:
+                         logger.error(
+                             f"Final attempt to parse LLM string output into {response_format.__name__} failed: {e_final_parse}"
+                         )
+                         # Retry without cache to force regeneration
+                         return self.__call__(
+                             prompt=prompt,
+                             messages=messages,
+                             response_format=response_format,
+                             cache=False,
+                             retry_count=retry_count + 1,
+                             port=current_port,
+                             error=e_final_parse,
+                             use_loadbalance=use_loadbalance,
+                             must_load_cache=False,
+                             max_tokens=max_tokens,
+                             num_retries=num_retries,
+                             **kwargs,
+                         )
+                 else:
+                     logger.error(
+                         f"LLM output type mismatch. Expected {response_format.__name__} or str, got {type(final_output)}. Raising error."
+                     )
+                     raise TypeError(
+                         f"LLM output type mismatch: expected {response_format.__name__}, got {type(final_output)}"
+                     )
+             return final_output  # Already a response_format instance
+         else:  # No response_format, expect string
+             if not isinstance(final_output, str):
+                 # This could happen if LLM returns structured data and dspy parses it even without response_format
+                 logger.warning(
+                     f"LLM output type mismatch. Expected str, got {type(final_output)}. Attempting to convert to string."
+                 )
+                 # Convert to string, or handle as error depending on desired strictness
+                 return str(final_output)  # Or raise TypeError
+             return final_output
+
+     def clear_port_use(self):
+         if self.ports:
+             _clear_port_use(self.ports)
+         else:
+             logger.warning("No ports configured to clear usage for.")
+
+     def get_least_used_port(self) -> int:
+         if self.ports is None:
+             raise ValueError("Ports must be configured to pick the least used port.")
+         if not self.ports:
+             raise ValueError("Ports list is empty, cannot pick a port.")
+         return _pick_least_used_port(self.ports)
+
+     def get_session(
+         self,
+         system_prompt: Optional[str],
+         history: Optional[List[Message]] = None,
+         callback=None,
+         response_format: Optional[Type[BaseModel]] = None,
+         **kwargs,  # kwargs are not used by ChatSession constructor
+     ) -> ChatSession:
+         actual_history = deepcopy(history) if history is not None else []
+         return ChatSession(
+             self,
+             system_prompt=system_prompt,
+             history=actual_history,
+             callback=callback,
+             response_format=response_format,
+             # **kwargs,  # ChatSession constructor does not accept **kwargs
+         )
+
+     def dump_cache(
+         self, id: str, result: Union[str, BaseModel, List[Union[str, BaseModel]]]
+     ):
+         try:
+             cache_file = f"~/.cache/oai_lm/{self.model}/{id}.pkl"
+             cache_file = os.path.expanduser(cache_file)
+
+             dump_json_or_pickle(result, cache_file)
+         except Exception as e:
+             logger.warning(f"Cache dump failed: {e}")
+
+     def load_cache(
+         self, id: str
+     ) -> Optional[Union[str, BaseModel, List[Union[str, BaseModel]]]]:
+         try:
+             cache_file = f"~/.cache/oai_lm/{self.model}/{id}.pkl"
+             cache_file = os.path.expanduser(cache_file)
+             if not os.path.exists(cache_file):
+                 return
+             return load_json_or_pickle(cache_file)
+         except Exception as e:
+             logger.warning(f"Cache load failed for {id}: {e}")  # Added id to log
+             return None
+
+     def list_models(self) -> List[str]:
+         import openai
+
+         if not self.base_url:
+             raise ValueError("Cannot list models: base_url is not configured.")
+         if not self.api_key:  # api_key should be set by __init__
+             logger.warning(
+                 "API key not available for listing models. Using default 'abc'."
+             )
+
+         api_key_str = str(self.api_key) if self.api_key is not None else "abc"
+         base_url_str = str(self.base_url) if self.base_url is not None else None
+         if isinstance(self.base_url, float):
+             raise TypeError(f"base_url must be a string or None, got float: {self.base_url}")
+         client = openai.OpenAI(base_url=base_url_str, api_key=api_key_str)
+         page = client.models.list()
+         return [d.id for d in page.data]
+
+     @property
+     def client(self):
+         import openai
+         if not self.base_url:
+             raise ValueError("Cannot create client: base_url is not configured.")
+         if not self.api_key:
+             logger.warning("API key not available for client. Using default 'abc'.")
+
+         base_url_str = str(self.base_url) if self.base_url is not None else None
+         api_key_str = str(self.api_key) if self.api_key is not None else "abc"
+         return openai.OpenAI(base_url=base_url_str, api_key=api_key_str)
+
+     def __getattr__(self, name):
+         """
+         Delegate any attributes not found in OAI_LM to the underlying dspy.LM instance.
+         This makes sure any dspy.LM methods not explicitly defined in OAI_LM are still accessible.
+         """
+         # Check __dict__ directly to avoid recursion via hasattr
+         if "_dspy_lm" in self.__dict__ and hasattr(self._dspy_lm, name):
+             return getattr(self._dspy_lm, name)
+         raise AttributeError(
+             f"'{self.__class__.__name__}' object has no attribute '{name}'"
+         )
+
+     @classmethod
+     def get_deepseek_chat(
+         cls, api_key: Optional[str] = None, max_tokens: int = 2000, **kwargs
+     ):
+         api_key_to_pass = cast(
+             Optional[str], api_key or os.environ.get("DEEPSEEK_API_KEY")
+         )
+         return cls(  # Use cls instead of OAI_LM
+             base_url="https://api.deepseek.com/v1",
+             model="deepseek-chat",
+             api_key=api_key_to_pass,
+             max_tokens=max_tokens,
+             **kwargs,
+         )
+
+     @classmethod
+     def get_deepseek_reasoner(
+         cls, api_key: Optional[str] = None, max_tokens: int = 2000, **kwargs
+     ):
+         api_key_to_pass = cast(
+             Optional[str], api_key or os.environ.get("DEEPSEEK_API_KEY")
+         )
+         return cls(  # Use cls instead of OAI_LM
+             base_url="https://api.deepseek.com/v1",
+             model="deepseek-reasoner",
+             api_key=api_key_to_pass,
+             max_tokens=max_tokens,
+             **kwargs,
+         )
+
+     @classmethod
+     def start_server(
+         cls, model_name: str, gpus: str = "4567", port: int = 9150, eager: bool = True
+     ):
+         cmd = f"svllm serve --model {model_name} --gpus {gpus} -hp localhost:{port}"
+         if eager:
+             cmd += " --eager"
+         session_name = f"vllm_{port}"
+         is_session_exists = os.system(f"tmux has-session -t {session_name}")
+         logger.info(f"Starting server with command: {cmd}")
+         if is_session_exists == 0:
+             logger.warning(
+                 f"Session {session_name} exists, please kill it before running the script"
+             )
+             # Ask the user whether they want to kill the existing session
+             user_input = input(
+                 f"Session {session_name} exists, do you want to kill it? (y/n): "
+             )
+             if user_input.lower() == "y":
+                 os.system(f"tmux kill-session -t {session_name}")
+                 logger.info(f"Session {session_name} killed")
+         os.system(cmd)
+         # return subprocess.Popen(shlex.split(cmd))
+
+     # get_agent is an alias for get_session
+     get_agent = get_session
+
+
+ LM = OAI_LM
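
For orientation, the sketch below shows how the OAI_LM / ChatSession API added by this file might be exercised. It is a minimal illustration based only on the signatures above; the model name, the port, and the assumption that an OpenAI-compatible server is already listening locally are placeholders, not values shipped with the package.

from pydantic import BaseModel

from llm_utils.lm import LM  # LM is an alias for OAI_LM


class Answer(BaseModel):
    # Hypothetical response schema, used only for this illustration
    text: str
    confidence: float


# Assumes an OpenAI-compatible endpoint (e.g. vLLM) is already serving on localhost:9150
lm = LM(model="openai/unsloth/gemma-3-1b-it", port=9150, cache=True)

# Plain string completion
reply = lm(prompt="Say hello in one sentence.")

# Structured completion: the output is parsed into the pydantic model
answer = lm(prompt="How many legs does a spider have?", response_format=Answer)

# Multi-turn chat with a persistent history
session = lm.get_session(system_prompt="You are a concise assistant.")
first = session("What is the capital of France?")
second = session("And what is its population?")

When the instance is constructed with ports=[...] and a call passes use_loadbalance=True, the call is routed to the least-used local port via the file-based counters defined earlier in this module.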