speedy-utils 1.0.23__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_utils/lm/async_lm.py CHANGED
@@ -82,6 +82,7 @@ from functools import lru_cache
 from typing import (
     Any,
     Dict,
+    Generic,
     List,
     Literal,
     Optional,
@@ -92,9 +93,10 @@ from typing import (
     cast,
     overload,
 )
-
+from typing_extensions import TypedDict
 from httpx import URL
 from loguru import logger
+from numpy import isin
 from openai import AsyncOpenAI, AuthenticationError, BadRequestError, RateLimitError
 from openai.pagination import AsyncPage as AsyncSyncPage
 
@@ -144,6 +146,12 @@ def _yellow(t):
     return _color(33, t)
 
 
+class ParsedOutput(TypedDict):
+    messages: List
+    completion: Any
+    parsed: BaseModel
+
+
 class AsyncLM:
     """Unified **async** language‑model wrapper with optional JSON parsing."""
 
@@ -460,10 +468,9 @@ class AsyncLM:
         add_json_schema_to_instruction: bool = False,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
-        return_openai_response: bool = False,
         cache: Optional[bool] = True,
         **kwargs,
-    ):
+    ) -> ParsedOutput:  # -> dict[str, Any]:
         """Parse response using guided JSON generation."""
         if messages is None:
             assert instruction is not None, "Instruction must be provided."
@@ -514,27 +521,27 @@ class AsyncLM:
             "response_format": response_model.__name__,
         }
         cache_key = self._cache_key(cache_data, {}, response_model)
-        cached_response = self._load_cache(cache_key)
-        self.last_log = [prompt, messages, cached_response]
-        if cached_response is not None:
-            if return_openai_response:
-                return cached_response
-            return self._parse_complete_output(cached_response, response_model)
-
-        completion = await self.client.chat.completions.create(
-            model=self.model,  # type: ignore
-            messages=messages,  # type: ignore
-            extra_body={"guided_json": json_schema},
-            **model_kwargs,
-        )
-
-        if cache_key:
-            self._dump_cache(cache_key, completion)
+            completion = self._load_cache(cache_key)  # dict
+        else:
+            completion = await self.client.chat.completions.create(
+                model=self.model,  # type: ignore
+                messages=messages,  # type: ignore
+                extra_body={"guided_json": json_schema},
+                **model_kwargs,
+            )
+            completion = completion.model_dump()
+            if cache_key:
+                self._dump_cache(cache_key, completion)
 
         self.last_log = [prompt, messages, completion]
-        if return_openai_response:
-            return completion
-        return self._parse_complete_output(completion, response_model)
+
+        output = self._parse_complete_output(completion, response_model)
+        full_messages = messages + [completion]
+        return ParsedOutput(
+            messages=full_messages,
+            completion=completion,
+            parsed=output,
+        )
 
     def _parse_complete_output(
         self, completion: Any, response_model: Type[BaseModel]
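
With this hunk, `parse` drops the `return_openai_response` flag and always returns a `ParsedOutput` TypedDict, with the completion cached and carried as a plain dict. A minimal caller-side sketch of the new contract; the `Answer` model, the prompt strings, and the pre-existing `AsyncLM` instance `lm` are hypothetical, not part of this diff:

    from pydantic import BaseModel

    class Answer(BaseModel):  # hypothetical response model
        value: int

    async def demo(lm) -> None:  # `lm` is an existing AsyncLM instance
        result = await lm.parse(
            instruction="Answer the question.",
            prompt="What is 2 + 2?",
            response_model=Answer,
        )
        parsed = result["parsed"]      # validated Pydantic instance
        raw = result["completion"]     # raw completion as a plain dict
        history = result["messages"]   # input messages plus the completion
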
@@ -839,8 +846,11 @@ async def inspect_word_probs_async(lm, tokenizer, messages):
 # Async LLMTask class
 # --------------------------------------------------------------------------- #
 
+InputModelType = TypeVar("InputModelType", bound=BaseModel)
+OutputModelType = TypeVar("OutputModelType", bound=BaseModel)
 
-class AsyncLLMTask(ABC):
+
+class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
     """
     Async callable wrapper around an AsyncLM endpoint.
 
@@ -878,37 +888,71 @@ class AsyncLLMTask(ABC):
     """
 
     lm: "AsyncLM"
-    InputModel: Type[BaseModel]
-    OutputModel: Type[BaseModel]
+    InputModel: InputModelType
+    OutputModel: OutputModelType
 
     temperature: float = 0.6
     think: bool = False
     add_json_schema: bool = False
 
-    async def __call__(self, data: BaseModel | dict) -> BaseModel:
+    async def __call__(
+        self,
+        data: BaseModel | dict,
+        temperature: float = 0.1,
+        cache: bool = False,
+    ) -> tuple[OutputModelType, List[Dict[str, Any]]]:
+        # Get the input and output model types from the generic parameters
+        type_args = getattr(self.__class__, "__orig_bases__", None)
         if (
-            not hasattr(self, "InputModel")
-            or not hasattr(self, "OutputModel")
-            or not hasattr(self, "lm")
+            type_args
+            and hasattr(type_args[0], "__args__")
+            and len(type_args[0].__args__) >= 2
         ):
-            raise NotImplementedError(
-                f"{self.__class__.__name__} must define lm, InputModel, and OutputModel as class attributes."
-            )
+            input_model = type_args[0].__args__[0]
+            output_model = type_args[0].__args__[1]
+        else:
+            # Fallback to the old way if type introspection fails
+            if (
+                not hasattr(self, "InputModel")
+                or not hasattr(self, "OutputModel")
+                or not hasattr(self, "lm")
+            ):
+                raise NotImplementedError(
+                    f"{self.__class__.__name__} must define lm, InputModel, and OutputModel as class attributes or use proper generic typing."
+                )
+            input_model = self.InputModel
+            output_model = self.OutputModel
+
+        # Ensure input_model is a class before calling
+        if isinstance(data, BaseModel):
+            item = data
+        elif isinstance(input_model, type) and issubclass(input_model, BaseModel):
+            item = input_model(**data)
+        else:
+            raise TypeError("InputModel must be a subclass of BaseModel")
 
-        item = data if isinstance(data, BaseModel) else self.InputModel(**data)
+        assert isinstance(output_model, type) and issubclass(output_model, BaseModel), (
+            "OutputModel must be a subclass of BaseModel"
+        )
 
-        return await self.lm.parse(
+        result = await self.lm.parse(
             prompt=item.model_dump_json(),
             instruction=self.__doc__ or "",
-            response_model=self.OutputModel,
-            temperature=self.temperature,
+            response_model=output_model,
+            temperature=temperature or self.temperature,
             think=self.think,
             add_json_schema_to_instruction=self.add_json_schema,
+            cache=cache,
+        )
+
+        return (
+            cast(OutputModelType, result["parsed"]),  # type: ignore
+            cast(List[dict], result["messages"]),  # type: ignore
         )
 
     def generate_training_data(
         self, input_dict: Dict[str, Any], output: Dict[str, Any]
-    ):
+    ) -> Dict[str, Any]:
         """Return share gpt like format"""
         system_prompt = self.__doc__ or ""
         user_msg = self.InputModel(**input_dict).model_dump_json()  # type: ignore[attr-defined]
@@ -917,4 +961,5 @@ class AsyncLLMTask(ABC):
             system_msg=system_prompt, user_msg=user_msg, assistant_msg=assistant_msg
         )
         return {"messages": messages}
-    arun = __call__  # alias for compatibility with other LLMTask implementations
+
+    arun = __call__  # alias for compatibility with other LLMTask implementations
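
These hunks make `AsyncLLMTask` generic over its input and output models and change `__call__` to return a `(parsed, messages)` tuple instead of the bare parsed model. A subclassing sketch under those assumptions; the two models, the task docstring, and the `AsyncLM` constructor arguments are hypothetical:

    from pydantic import BaseModel

    class SummaryInput(BaseModel):    # hypothetical input model
        text: str

    class SummaryOutput(BaseModel):   # hypothetical output model
        summary: str

    class SummarizeTask(AsyncLLMTask[SummaryInput, SummaryOutput]):
        """Summarize the provided text in one sentence."""
        lm = AsyncLM(base_url="http://localhost:8000/v1")  # hypothetical args

    async def demo() -> None:
        task = SummarizeTask()
        # dict input is validated through SummaryInput; the parsed result
        # is typed as SummaryOutput via the generic parameters
        summary, messages = await task({"text": "Long article ..."})
        print(summary.summary)
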
llm_utils/lm/utils.py CHANGED
@@ -7,11 +7,6 @@ import numpy as np
 from loguru import logger
 
 
-def _clear_port_use(ports):
-    for port in ports:
-        file_counter: str = f"/tmp/port_use_counter_{port}.npy"
-        if os.path.exists(file_counter):
-            os.remove(file_counter)
 
 
 def _atomic_save(array: np.ndarray, filename: str):
{speedy_utils-1.0.23.dist-info → speedy_utils-1.1.0.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: speedy-utils
-Version: 1.0.23
+Version: 1.1.0
 Summary: Fast and easy-to-use package for data science
 Author: AnhVTH
 Author-email: anhvth.226@gmail.com
{speedy_utils-1.0.23.dist-info → speedy_utils-1.1.0.dist-info}/RECORD RENAMED
@@ -5,11 +5,11 @@ llm_utils/chat_format/transform.py,sha256=8TZhvUS5DrjUeMNtDIuWY54B_QZ7jjpXEL9c8F
 llm_utils/chat_format/utils.py,sha256=xTxN4HrLHcRO2PfCTR43nH1M5zCa7v0kTTdzAcGkZg0,1229
 llm_utils/group_messages.py,sha256=8CU9nKOja3xeuhdrX5CvYVveSqSKb2zQ0eeNzA88aTQ,3621
 llm_utils/lm/__init__.py,sha256=rX36_MsnekM5GHwWS56XELbm4W5x2TDwnPERDTfo0eU,194
-llm_utils/lm/async_lm.py,sha256=_NmWEp_jCbD6soexXo489L40KS8xJPgtY5QxXLDYsis,34174
+llm_utils/lm/async_lm.py,sha256=kiWEecrkCTTQFlQj5JiHNziFeLOF1-7G_2xC2Dra1bw,35806
 llm_utils/lm/chat_html.py,sha256=FkGo0Dv_nAHYBMZzXfMu_bGQKaCx302goh3XaT-_ETc,8674
 llm_utils/lm/lm_json.py,sha256=fMt42phzFV2f6ulrtWcDXsWHi8WcG7gGkCzpIq8VSSM,1975
 llm_utils/lm/sync_lm.py,sha256=ANw_m5KiWcRwwoeQ5no6dzPFLc6j9o2oEcJtkMKqrn8,34640
-llm_utils/lm/utils.py,sha256=GMvs64DRzVnXAki4SZ-A6mx2Fi9IgeF11BA-5FB-CYg,4777
+llm_utils/lm/utils.py,sha256=gUejbVZPYg97g4ftYEptYN52WhH3TAKOFW81sjLvi08,4585
 llm_utils/scripts/README.md,sha256=yuOLnLa2od2jp4wVy3rV0rESeiV3o8zol5MNMsZx0DY,999
 llm_utils/scripts/vllm_load_balancer.py,sha256=GjMdoZrdT9cSLos0qSdkLg2dwZgW1enAMsD3aTZAfNs,20845
 llm_utils/scripts/vllm_serve.py,sha256=4NaqpVs7LBvxtvTCMPsNCAOfqiWkKRttxWMmWY7SitA,14729
@@ -31,7 +31,7 @@ speedy_utils/multi_worker/thread.py,sha256=u_hTwXh7_FciMa5EukdEA1fDCY_vUC4moDceB
 speedy_utils/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 speedy_utils/scripts/mpython.py,sha256=73PHm1jqbCt2APN4xuNjD0VDKwzOj4EZsViEMQiZU2g,3853
 speedy_utils/scripts/openapi_client_codegen.py,sha256=f2125S_q0PILgH5dyzoKRz7pIvNEjCkzpi4Q4pPFRZE,9683
-speedy_utils-1.0.23.dist-info/METADATA,sha256=E2NtrXhJt45XHFz5cv9BuxmdlPHqSZwHdq2vHJG7xqk,7442
-speedy_utils-1.0.23.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-speedy_utils-1.0.23.dist-info/entry_points.txt,sha256=T1t85jwx8fK6m5msdkBGIXH5R5Kd0zSL0S6erXERPzg,237
-speedy_utils-1.0.23.dist-info/RECORD,,
+speedy_utils-1.1.0.dist-info/METADATA,sha256=h1Alzm4q92GSiw5GNZWn6d8sHaSJS4X8RTMXStjkqHY,7441
+speedy_utils-1.1.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+speedy_utils-1.1.0.dist-info/entry_points.txt,sha256=T1t85jwx8fK6m5msdkBGIXH5R5Kd0zSL0S6erXERPzg,237
+speedy_utils-1.1.0.dist-info/RECORD,,