speedy-utils 1.1.18__py3-none-any.whl → 1.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_utils/__init__.py CHANGED
@@ -4,7 +4,7 @@ from llm_utils.vector_cache import VectorCache
 from llm_utils.lm.lm_base import get_model_name
 from llm_utils.lm.base_prompt_builder import BasePromptBuilder
 
-
+LLM = LLMTask
 
 from .chat_format import (
     build_chatml_input,
@@ -34,5 +34,6 @@ __all__ = [
     "MOpenAI",
     "get_model_name",
     "VectorCache",
-    "BasePromptBuilder"
+    "BasePromptBuilder",
+    "LLM"
 ]
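
The new top-level binding and export mean `LLM` can be imported as a shorter name for `LLMTask`. A minimal sketch of what that gives callers (nothing beyond the name; both refer to the same class):

```python
from llm_utils import LLM, LLMTask

# The alias added in this release points at the same class object.
assert LLM is LLMTask
```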
@@ -1,3 +1,4 @@
+# type: ignore
 """
 Async LLM Task module for handling language model interactions with structured input/output.
 """
llm_utils/lm/llm_task.py CHANGED
@@ -4,10 +4,12 @@
 Simplified LLM Task module for handling language model interactions with structured input/output.
 """
 
+import os
 from typing import Any, Dict, List, Optional, Type, Union, cast
 
+import requests
 from loguru import logger
-from openai import OpenAI
+from openai import OpenAI, AuthenticationError, BadRequestError, RateLimitError
 from openai.types.chat import ChatCompletionMessageParam
 from pydantic import BaseModel
 
@@ -38,6 +40,90 @@ def get_base_client(
 )
 
 
+def _is_lora_path(path: str) -> bool:
+    """Check if the given path is a LoRA adapter directory.
+
+    Args:
+        path: Path to check
+
+    Returns:
+        True if the path contains adapter_config.json, False otherwise
+    """
+    if not os.path.isdir(path):
+        return False
+    adapter_config_path = os.path.join(path, 'adapter_config.json')
+    return os.path.isfile(adapter_config_path)
+
+
+def _get_port_from_client(client: OpenAI) -> Optional[int]:
+    """Extract port number from OpenAI client base_url.
+
+    Args:
+        client: OpenAI client instance
+
+    Returns:
+        Port number if found, None otherwise
+    """
+    if hasattr(client, 'base_url') and client.base_url:
+        base_url = str(client.base_url)
+        if 'localhost:' in base_url:
+            try:
+                # Extract port from localhost:PORT/v1 format
+                port_part = base_url.split('localhost:')[1].split('/')[0]
+                return int(port_part)
+            except (IndexError, ValueError):
+                pass
+    return None
+
+
+def _load_lora_adapter(lora_path: str, port: int) -> str:
+    """Load a LoRA adapter from the specified path.
+
+    Args:
+        lora_path: Path to the LoRA adapter directory
+        port: Port number for the API endpoint
+
+    Returns:
+        Name of the loaded LoRA adapter
+
+    Raises:
+        requests.RequestException: If the API call fails
+    """
+    lora_name = os.path.basename(lora_path.rstrip('/\\'))
+    if not lora_name:  # Handle edge case of empty basename
+        lora_name = os.path.basename(os.path.dirname(lora_path))
+
+    response = requests.post(
+        f'http://localhost:{port}/v1/load_lora_adapter',
+        headers={'accept': 'application/json', 'Content-Type': 'application/json'},
+        json={"lora_name": lora_name, "lora_path": os.path.abspath(lora_path)}
+    )
+    response.raise_for_status()
+    return lora_name
+
+
+def _unload_lora_adapter(lora_path: str, port: int) -> None:
+    """Unload the current LoRA adapter.
+
+    Args:
+        lora_path: Path to the LoRA adapter directory
+        port: Port number for the API endpoint
+    """
+    try:
+        lora_name = os.path.basename(lora_path.rstrip('/\\'))
+        if not lora_name:  # Handle edge case of empty basename
+            lora_name = os.path.basename(os.path.dirname(lora_path))
+
+        response = requests.post(
+            f'http://localhost:{port}/v1/unload_lora_adapter',
+            headers={'accept': 'application/json', 'Content-Type': 'application/json'},
+            json={"lora_name": lora_name, "lora_int_id": 0}
+        )
+        response.raise_for_status()
+    except requests.RequestException as e:
+        logger.warning(f"Error unloading LoRA adapter: {str(e)[:100]}")
+
+
 class LLMTask:
     """
     Language model task with structured input/output and optional system instruction.
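
These helpers post to `/v1/load_lora_adapter` and `/v1/unload_lora_adapter` on localhost, which matches the dynamic-LoRA endpoints exposed by vLLM-style OpenAI-compatible servers when runtime LoRA updating is enabled. A hedged sketch of the raw request the load helper issues (the port and adapter directory are placeholders):

```python
import os
import requests

lora_path = "/path/to/my-adapter"        # directory that contains adapter_config.json
lora_name = os.path.basename(lora_path)  # this becomes the model name to request later

resp = requests.post(
    "http://localhost:8000/v1/load_lora_adapter",
    headers={"accept": "application/json", "Content-Type": "application/json"},
    json={"lora_name": lora_name, "lora_path": os.path.abspath(lora_path)},
)
resp.raise_for_status()  # the helper surfaces HTTP errors the same way
```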
@@ -106,6 +192,9 @@ class LLMTask:
         output_model: Type[BaseModel] | Type[str] = None,
         client: Union[OpenAI, int, str, None] = None,
         cache=True,
+        is_reasoning_model: bool = False,
+        force_lora_unload: bool = False,
+        lora_path: Optional[str] = None,
         **model_kwargs,
     ):
         """
@@ -117,6 +206,12 @@ class LLMTask:
             output_model: Output BaseModel type
             client: OpenAI client, port number, or base_url string
             cache: Whether to use cached responses (default True)
+            is_reasoning_model: Whether the model is a reasoning model (o1-preview, o1-mini, etc.)
+                that outputs reasoning_content separately from content (default False)
+            force_lora_unload: If True, forces unloading of any existing LoRA adapter before loading
+                a new one when lora_path is provided (default False)
+            lora_path: Optional path to LoRA adapter directory. If provided, will load the LoRA
+                and use it as the model. Takes precedence over model parameter.
             **model_kwargs: Additional model parameters including:
                 - temperature: Controls randomness (0.0 to 2.0)
                 - n: Number of responses to generate (when n > 1, returns list)
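
A hedged constructor sketch using only the keyword arguments visible in this hunk; the server port and adapter directory are placeholders, and any positional arguments of `LLMTask` not shown here are omitted:

```python
from llm_utils import LLMTask

task = LLMTask(
    client=8000,                       # port of a local OpenAI-compatible server
    cache=True,
    is_reasoning_model=False,
    force_lora_unload=False,
    lora_path="/path/to/my-adapter",   # loaded at init; the model name becomes the adapter name
    temperature=0.2,                   # forwarded through **model_kwargs
)
```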
@@ -127,6 +222,10 @@ class LLMTask:
         self.input_model = input_model
         self.output_model = output_model
         self.model_kwargs = model_kwargs
+        self.is_reasoning_model = is_reasoning_model
+        self.force_lora_unload = force_lora_unload
+        self.lora_path = lora_path
+        self.last_ai_response = None  # Store raw response from client
 
         # if cache:
         #     print("Caching is enabled will use llm_utils.MOpenAI")
@@ -135,11 +234,152 @@ class LLMTask:
         # else:
         #     self.client = OpenAI(base_url=base_url, api_key=api_key)
         self.client = get_base_client(client, cache=cache)
+        # check connection of client
+        try:
+            self.client.models.list()
+        except Exception as e:
+            logger.error(f"Failed to connect to OpenAI client: {str(e)}, base_url={self.client.base_url}")
+            raise e
 
         if not self.model_kwargs.get("model", ""):
             self.model_kwargs["model"] = self.client.models.list().data[0].id
+
+        # Handle LoRA loading if lora_path is provided
+        if self.lora_path:
+            self._load_lora_adapter()
+
         print(self.model_kwargs)
 
+    def _load_lora_adapter(self) -> None:
+        """
+        Load LoRA adapter from the specified lora_path.
+
+        This method:
+        1. Validates that lora_path is a valid LoRA directory
+        2. Checks if LoRA is already loaded (unless force_lora_unload is True)
+        3. Loads the LoRA adapter and updates the model name
+        """
+        if not self.lora_path:
+            return
+
+        if not _is_lora_path(self.lora_path):
+            raise ValueError(
+                f"Invalid LoRA path '{self.lora_path}': "
+                "Directory must contain 'adapter_config.json'"
+            )
+
+        logger.info(f"Loading LoRA adapter from: {self.lora_path}")
+
+        # Get the expected LoRA name (basename of the path)
+        lora_name = os.path.basename(self.lora_path.rstrip('/\\'))
+        if not lora_name:  # Handle edge case of empty basename
+            lora_name = os.path.basename(os.path.dirname(self.lora_path))
+
+        # Get list of available models to check if LoRA is already loaded
+        try:
+            available_models = [m.id for m in self.client.models.list().data]
+        except Exception as e:
+            logger.warning(f"Failed to list models, proceeding with LoRA load: {str(e)[:100]}")
+            available_models = []
+
+        # Check if LoRA is already loaded
+        if lora_name in available_models and not self.force_lora_unload:
+            logger.info(f"LoRA adapter '{lora_name}' is already loaded, using existing model")
+            self.model_kwargs["model"] = lora_name
+            return
+
+        # Force unload if requested
+        if self.force_lora_unload and lora_name in available_models:
+            logger.info(f"Force unloading LoRA adapter '{lora_name}' before reloading")
+            port = _get_port_from_client(self.client)
+            if port is not None:
+                try:
+                    LLMTask.unload_lora(port, lora_name)
+                    logger.info(f"Successfully unloaded LoRA adapter: {lora_name}")
+                except Exception as e:
+                    logger.warning(f"Failed to unload LoRA adapter: {str(e)[:100]}")
+
+        # Get port from client for API calls
+        port = _get_port_from_client(self.client)
+        if port is None:
+            raise ValueError(
+                f"Cannot load LoRA adapter '{self.lora_path}': "
+                "Unable to determine port from client base_url. "
+                "LoRA loading requires a client initialized with port number."
+            )
+
+        try:
+            # Load the LoRA adapter
+            loaded_lora_name = _load_lora_adapter(self.lora_path, port)
+            logger.info(f"Successfully loaded LoRA adapter: {loaded_lora_name}")
+
+            # Update model name to the loaded LoRA name
+            self.model_kwargs["model"] = loaded_lora_name
+
+        except requests.RequestException as e:
+            # Check if the error is due to LoRA already being loaded
+            error_msg = str(e)
+            if "400" in error_msg or "Bad Request" in error_msg:
+                logger.info(f"LoRA adapter may already be loaded, attempting to use '{lora_name}'")
+                # Refresh the model list to check if it's now available
+                try:
+                    updated_models = [m.id for m in self.client.models.list().data]
+                    if lora_name in updated_models:
+                        logger.info(f"Found LoRA adapter '{lora_name}' in updated model list")
+                        self.model_kwargs["model"] = lora_name
+                        return
+                except Exception:
+                    pass  # Fall through to original error
+
+            raise ValueError(
+                f"Failed to load LoRA adapter from '{self.lora_path}': {error_msg[:100]}"
+            )
+
+    def unload_lora_adapter(self, lora_path: str) -> None:
+        """
+        Unload a LoRA adapter.
+
+        Args:
+            lora_path: Path to the LoRA adapter directory to unload
+
+        Raises:
+            ValueError: If unable to determine port from client
+        """
+        port = _get_port_from_client(self.client)
+        if port is None:
+            raise ValueError(
+                "Cannot unload LoRA adapter: "
+                "Unable to determine port from client base_url. "
+                "LoRA operations require a client initialized with port number."
+            )
+
+        _unload_lora_adapter(lora_path, port)
+        lora_name = os.path.basename(lora_path.rstrip('/\\'))
+        logger.info(f"Unloaded LoRA adapter: {lora_name}")
+
+    @staticmethod
+    def unload_lora(port: int, lora_name: str) -> None:
+        """Static method to unload a LoRA adapter by name.
+
+        Args:
+            port: Port number for the API endpoint
+            lora_name: Name of the LoRA adapter to unload
+
+        Raises:
+            requests.RequestException: If the API call fails
+        """
+        try:
+            response = requests.post(
+                f'http://localhost:{port}/v1/unload_lora_adapter',
+                headers={'accept': 'application/json', 'Content-Type': 'application/json'},
+                json={"lora_name": lora_name, "lora_int_id": 0}
+            )
+            response.raise_for_status()
+            logger.info(f"Successfully unloaded LoRA adapter: {lora_name}")
+        except requests.RequestException as e:
+            logger.error(f"Error unloading LoRA adapter '{lora_name}': {str(e)[:100]}")
+            raise
+
     def _prepare_input(self, input_data: Union[str, BaseModel, List[Dict]]) -> Messages:
         """Convert input to messages format."""
         if isinstance(input_data, list):
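
Adapters can be unloaded either through the new instance method (which derives the adapter name from the directory path and the port from the client's base_url) or through the static helper (which only needs the port and the adapter name). A hedged sketch, with the port and path as placeholders:

```python
from llm_utils import LLMTask

# Static helper: no LLMTask instance required.
LLMTask.unload_lora(8000, "my-adapter")

# Instance method: requires a client that was created from a port number,
# since the port is recovered from base_url.
task = LLMTask(client=8000, lora_path="/path/to/my-adapter")
task.unload_lora_adapter("/path/to/my-adapter")
```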
@@ -200,9 +440,24 @@ class LLMTask:
         # Extract model name from kwargs for API call
         api_kwargs = {k: v for k, v in effective_kwargs.items() if k != "model"}
 
-        completion = self.client.chat.completions.create(
-            model=model_name, messages=messages, **api_kwargs
-        )
+        try:
+            completion = self.client.chat.completions.create(
+                model=model_name, messages=messages, **api_kwargs
+            )
+            # Store raw response from client
+            self.last_ai_response = completion
+        except (AuthenticationError, RateLimitError, BadRequestError) as exc:
+            error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
+            logger.error(error_msg)
+            raise
+        except Exception as e:
+            is_length_error = "Length" in str(e) or "maximum context length" in str(e)
+            if is_length_error:
+                raise ValueError(
+                    f"Input too long for model {model_name}. Error: {str(e)[:100]}..."
+                )
+            # Re-raise all other exceptions
+            raise
         # print(completion)
 
         results: List[Dict[str, Any]] = []
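
The completion call is now wrapped so that authentication, rate-limit, and bad-request errors are logged and re-raised, context-length failures become a ValueError, and the raw response is kept on `last_ai_response`. A hedged sketch of what calling code can rely on (the input string is a placeholder and `task` is an already-constructed LLMTask):

```python
from openai import AuthenticationError, BadRequestError, RateLimitError

try:
    results = task("Summarize the release notes in one sentence.")
except (AuthenticationError, RateLimitError, BadRequestError) as exc:
    print(f"API rejected the request: {type(exc).__name__}")
except ValueError as exc:
    print(f"Likely a context-length problem: {exc}")
else:
    raw = task.last_ai_response  # full ChatCompletion object from the last call
```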
@@ -211,9 +466,13 @@ class LLMTask:
                 Messages,
                 messages + [{"role": "assistant", "content": choice.message.content}],
             )
-            results.append(
-                {"parsed": choice.message.content, "messages": choice_messages}
-            )
+            result_dict = {"parsed": choice.message.content, "messages": choice_messages}
+
+            # Add reasoning content if this is a reasoning model
+            if self.is_reasoning_model and hasattr(choice.message, 'reasoning_content'):
+                result_dict["reasoning_content"] = choice.message.reasoning_content
+
+            results.append(result_dict)
         return results
 
     def pydantic_parse(
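
With `is_reasoning_model=True`, each result dict built in the hunk above may carry a `reasoning_content` key next to `parsed` and `messages`, when the serving backend exposes it on the message. A hedged sketch of reading it (continuing the placeholder `task` from the earlier sketches):

```python
results = task("Explain why the sky is blue.")

for result in results:
    answer = result["parsed"]
    thinking = result.get("reasoning_content")  # absent for non-reasoning models
    if thinking:
        print("model reasoning:", thinking[:200])
    print("answer:", answer)
```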
@@ -239,6 +498,11 @@ class LLMTask:
             List of dicts [{'parsed': parsed_model, 'messages': messages}, ...]
             When n=1: List contains one dict
             When n>1: List contains multiple dicts
+
+        Note:
+            This method ensures consistent Pydantic model output for both fresh and cached responses.
+            When responses are cached and loaded back, the parsed content is re-validated to maintain
+            type consistency between first-time and subsequent calls.
         """
         # Prepare messages
         messages = self._prepare_input(input_data)
@@ -265,12 +529,20 @@ class LLMTask:
                 response_format=pydantic_model_to_use,
                 **api_kwargs,
             )
+            # Store raw response from client
+            self.last_ai_response = completion
+        except (AuthenticationError, RateLimitError, BadRequestError) as exc:
+            error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
+            logger.error(error_msg)
+            raise
         except Exception as e:
             is_length_error = "Length" in str(e) or "maximum context length" in str(e)
             if is_length_error:
                 raise ValueError(
                     f"Input too long for model {model_name}. Error: {str(e)[:100]}..."
                 )
+            # Re-raise all other exceptions
+            raise
 
         results: List[Dict[str, Any]] = []
         for choice in completion.choices:  # type: ignore[attr-defined]
@@ -278,9 +550,23 @@ class LLMTask:
                 Messages,
                 messages + [{"role": "assistant", "content": choice.message.content}],
             )
-            results.append(
-                {"parsed": choice.message.parsed, "messages": choice_messages}
-            )  # type: ignore[attr-defined]
+
+            # Ensure consistent Pydantic model output for both fresh and cached responses
+            parsed_content = choice.message.parsed  # type: ignore[attr-defined]
+            if isinstance(parsed_content, dict):
+                # Cached response: validate dict back to Pydantic model
+                parsed_content = pydantic_model_to_use.model_validate(parsed_content)
+            elif not isinstance(parsed_content, pydantic_model_to_use):
+                # Fallback: ensure it's the correct type
+                parsed_content = pydantic_model_to_use.model_validate(parsed_content)
+
+            result_dict = {"parsed": parsed_content, "messages": choice_messages}
+
+            # Add reasoning content if this is a reasoning model
+            if self.is_reasoning_model and hasattr(choice.message, 'reasoning_content'):
+                result_dict["reasoning_content"] = choice.message.reasoning_content
+
+            results.append(result_dict)
         return results
 
     def __call__(
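
Because cached responses are now re-validated, `results[i]["parsed"]` is an instance of the output model on both fresh and cache-served calls. A hedged sketch with a hypothetical output model (the client port is a placeholder, and the call signature of `pydantic_parse` beyond the input is not shown in this diff):

```python
from pydantic import BaseModel
from llm_utils import LLMTask


class Sentiment(BaseModel):
    label: str
    confidence: float


task = LLMTask(client=8000, output_model=Sentiment)
results = task.pydantic_parse("I really enjoyed this film.")
assert isinstance(results[0]["parsed"], Sentiment)  # holds for cached calls too
```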
@@ -364,6 +650,8 @@ class LLMTask:
         builder: BasePromptBuilder,
         client: Union[OpenAI, int, str, None] = None,
         cache=True,
+        is_reasoning_model: bool = False,
+        lora_path: Optional[str] = None,
         **model_kwargs,
     ) -> "LLMTask":
         """
@@ -382,6 +670,10 @@ class LLMTask:
             input_model=input_model,
             output_model=output_model,
             client=client,
+            cache=cache,
+            is_reasoning_model=is_reasoning_model,
+            lora_path=lora_path,
+            **model_kwargs,
         )
 
     @staticmethod
@@ -398,3 +690,4 @@ class LLMTask:
         client = get_base_client(client, cache=False)
         models = client.models.list().data
         return [m.id for m in models]
+
@@ -1,4 +1,5 @@
 from openai import OpenAI, AsyncOpenAI
+from typing import Any, Callable
 
 from speedy_utils.common.utils_cache import memoize
 
@@ -30,6 +31,8 @@ class MOpenAI(OpenAI):
     - If you need a shared cache across instances, or more advanced cache controls,
       modify `memoize` or wrap at a class/static level instead of assigning to the
       bound method.
+    - Type information is now fully preserved by the memoize decorator, eliminating
+      the need for type casting.
 
     Example
         m = MOpenAI(api_key="...", model="gpt-4")
@@ -40,7 +43,12 @@ class MOpenAI(OpenAI):
     def __init__(self, *args, cache=True, **kwargs):
         super().__init__(*args, **kwargs)
         if cache:
-            self.post = memoize(self.post)
+            # Create a memoized wrapper for the instance's post method.
+            # The memoize decorator now preserves exact type information,
+            # so no casting is needed.
+            orig_post = self.post
+            memoized = memoize(orig_post)
+            self.post = memoized
 
 
 class MAsyncOpenAI(AsyncOpenAI):
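
When `cache=True`, `MOpenAI` replaces the bound `post` method with a memoized wrapper, so repeated identical requests can be answered from cache. A hedged usage sketch (base_url, api_key, and model name are placeholders):

```python
from llm_utils import MOpenAI

client = MOpenAI(base_url="http://localhost:8000/v1", api_key="dummy", cache=True)
messages = [{"role": "user", "content": "Say hi."}]

first = client.chat.completions.create(model="my-model", messages=messages)
second = client.chat.completions.create(model="my-model", messages=messages)  # may be served from cache
```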
@@ -69,4 +77,4 @@ class MAsyncOpenAI(AsyncOpenAI):
     def __init__(self, *args, cache=True, **kwargs):
         super().__init__(*args, **kwargs)
         if cache:
-            self.post = memoize(self.post)
+            self.post = memoize(self.post)  # type: ignore