speedy-utils 1.0.22__py3-none-any.whl → 1.0.24__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
llm_utils/lm/async_lm.py CHANGED
@@ -82,6 +82,7 @@ from functools import lru_cache
 from typing import (
     Any,
     Dict,
+    Generic,
     List,
     Literal,
     Optional,
@@ -95,6 +96,7 @@ from typing import (
 
 from httpx import URL
 from loguru import logger
+from numpy import isin
 from openai import AsyncOpenAI, AuthenticationError, BadRequestError, RateLimitError
 from openai.pagination import AsyncPage as AsyncSyncPage
 
@@ -462,6 +464,7 @@ class AsyncLM:
         max_tokens: Optional[int] = None,
         return_openai_response: bool = False,
         cache: Optional[bool] = True,
+        return_messages: bool = False,
         **kwargs,
     ):
         """Parse response using guided JSON generation."""
@@ -532,9 +535,19 @@ class AsyncLM:
             self._dump_cache(cache_key, completion)
 
         self.last_log = [prompt, messages, completion]
+
+        output = self._parse_complete_output(completion, response_model)
         if return_openai_response:
-            return completion
-        return self._parse_complete_output(completion, response_model)
+            return {"completion": completion, "parsed": output}
+        if return_messages:
+            # content = completion.model_dump()
+            full_messages = messages + [completion.model_dump()]
+            return {
+                "messages": full_messages,
+                "completion": completion,
+                "parsed": output,
+            }
+        return output
 
     def _parse_complete_output(
         self, completion: Any, response_model: Type[BaseModel]
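
The hunk above changes the return contract of AsyncLM.parse. A minimal sketch of the three return shapes, assuming an already-configured AsyncLM instance; the Answer model, the prompt text, and the demo function are invented for illustration:

    from pydantic import BaseModel

    class Answer(BaseModel):  # hypothetical response model
        text: str

    async def demo(lm) -> None:  # lm: a configured AsyncLM
        # Default: only the parsed response model instance comes back.
        parsed = await lm.parse(prompt="Say hi", response_model=Answer)

        # return_openai_response=True now returns a dict rather than the
        # bare completion object (a behavior change in this release).
        r = await lm.parse(
            prompt="Say hi", response_model=Answer, return_openai_response=True
        )
        completion, parsed = r["completion"], r["parsed"]

        # return_messages=True (new in this release) also returns the
        # conversation, with the raw completion appended via model_dump().
        r = await lm.parse(
            prompt="Say hi", response_model=Answer, return_messages=True
        )
        messages, parsed = r["messages"], r["parsed"]
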
@@ -839,8 +852,11 @@ async def inspect_word_probs_async(lm, tokenizer, messages):
 # Async LLMTask class
 # --------------------------------------------------------------------------- #
 
+InputModelType = TypeVar("InputModelType", bound=BaseModel)
+OutputModelType = TypeVar("OutputModelType", bound=BaseModel)
 
-class AsyncLLMTask(ABC):
+
+class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
     """
     Async callable wrapper around an AsyncLM endpoint.
 
@@ -885,30 +901,58 @@ class AsyncLLMTask(ABC):
     think: bool = False
     add_json_schema: bool = False
 
-    async def __call__(self, data: BaseModel | dict) -> BaseModel:
+    async def __call__(
+        self,
+        data: BaseModel | dict,
+        temperature: float = 0.1,
+        cache: bool = False,
+        collect_messages: bool = False,
+    ) -> OutputModelType | tuple[OutputModelType, List[Dict[str, Any]]]:
+        # Get the input and output model types from the generic parameters
+        type_args = getattr(self.__class__, "__orig_bases__", None)
         if (
-            not hasattr(self, "InputModel")
-            or not hasattr(self, "OutputModel")
-            or not hasattr(self, "lm")
+            type_args
+            and hasattr(type_args[0], "__args__")
+            and len(type_args[0].__args__) >= 2
         ):
-            raise NotImplementedError(
-                f"{self.__class__.__name__} must define lm, InputModel, and OutputModel as class attributes."
-            )
+            input_model = type_args[0].__args__[0]
+            output_model = type_args[0].__args__[1]
+        else:
+            # Fallback to the old way if type introspection fails
+            if (
+                not hasattr(self, "InputModel")
+                or not hasattr(self, "OutputModel")
+                or not hasattr(self, "lm")
+            ):
+                raise NotImplementedError(
+                    f"{self.__class__.__name__} must define lm, InputModel, and OutputModel as class attributes or use proper generic typing."
+                )
+            input_model = self.InputModel
+            output_model = self.OutputModel
 
-        item = data if isinstance(data, BaseModel) else self.InputModel(**data)
+        item = data if isinstance(data, BaseModel) else input_model(**data)
 
-        return await self.lm.parse(
+        result = await self.lm.parse(
             prompt=item.model_dump_json(),
             instruction=self.__doc__ or "",
-            response_model=self.OutputModel,
-            temperature=self.temperature,
+            response_model=output_model,
+            temperature=temperature or self.temperature,
             think=self.think,
             add_json_schema_to_instruction=self.add_json_schema,
+            cache=cache,
+            return_messages=True,
         )
 
+        if collect_messages:
+            return (
+                cast(OutputModelType, result),
+                result["messages"],
+            )
+        return cast(OutputModelType, result)
+
     def generate_training_data(
         self, input_dict: Dict[str, Any], output: Dict[str, Any]
-    ):
+    ) -> Dict[str, Any]:
         """Return share gpt like format"""
         system_prompt = self.__doc__ or ""
         user_msg = self.InputModel(**input_dict).model_dump_json()  # type: ignore[attr-defined]
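
The rewritten __call__ above resolves the input and output models from the subclass's generic parameters via __orig_bases__, an attribute CPython sets on classes that inherit from a parameterized generic. A self-contained sketch of that introspection technique (the Task/MyTask names are illustrative, not from the package):

    from typing import Generic, TypeVar, get_args

    I = TypeVar("I")
    O = TypeVar("O")

    class Task(Generic[I, O]):
        pass

    class MyTask(Task[int, str]):
        pass

    # __orig_bases__ keeps the parameterized base (here Task[int, str]);
    # get_args (or .__args__, as the diff uses) recovers the concrete
    # type parameters.
    base = MyTask.__orig_bases__[0]
    assert get_args(base) == (int, str)
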
@@ -917,4 +961,5 @@ class AsyncLLMTask(ABC):
             system_msg=system_prompt, user_msg=user_msg, assistant_msg=assistant_msg
         )
         return {"messages": messages}
-    arun = __call__  # alias for compatibility with other LLMTask implementations
+
+    # arun = __call__  # alias for compatibility with other LLMTask implementations
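
Taken together, these hunks let a task bind its models through the generic parameters instead of the InputModel/OutputModel class attributes. A hypothetical subclass, assuming AsyncLLMTask and AsyncLM are imported from llm_utils.lm.async_lm; the model names, fields, question text, and the AsyncLM construction are invented for illustration:

    import asyncio

    from pydantic import BaseModel

    class Question(BaseModel):  # assumed input schema
        question: str

    class Verdict(BaseModel):  # assumed output schema
        answer: str

    class JudgeTask(AsyncLLMTask[Question, Verdict]):
        """Answer the user's question concisely."""  # docstring doubles as the instruction

        lm = AsyncLM()  # hypothetical construction; see AsyncLM for real arguments

    async def main() -> None:
        task = JudgeTask()
        # A dict input is validated through the resolved input model; the call
        # is annotated to return a Verdict, or a (Verdict, messages) tuple when
        # collect_messages=True.
        verdict = await task({"question": "What is 2 + 2?"}, temperature=0.0)

    asyncio.run(main())

Note that the arun alias is commented out in this release, so callers relying on task.arun(...) will need to invoke the task directly.
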
llm_utils/lm/utils.py CHANGED
@@ -7,11 +7,6 @@ import numpy as np
 from loguru import logger
 
 
-def _clear_port_use(ports):
-    for port in ports:
-        file_counter: str = f"/tmp/port_use_counter_{port}.npy"
-        if os.path.exists(file_counter):
-            os.remove(file_counter)
 
 
 def _atomic_save(array: np.ndarray, filename: str):
speedy_utils-1.0.22.dist-info/METADATA → speedy_utils-1.0.24.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: speedy-utils
-Version: 1.0.22
+Version: 1.0.24
 Summary: Fast and easy-to-use package for data science
 Author: AnhVTH
 Author-email: anhvth.226@gmail.com
@@ -5,11 +5,11 @@ llm_utils/chat_format/transform.py,sha256=8TZhvUS5DrjUeMNtDIuWY54B_QZ7jjpXEL9c8F
5
5
  llm_utils/chat_format/utils.py,sha256=xTxN4HrLHcRO2PfCTR43nH1M5zCa7v0kTTdzAcGkZg0,1229
6
6
  llm_utils/group_messages.py,sha256=8CU9nKOja3xeuhdrX5CvYVveSqSKb2zQ0eeNzA88aTQ,3621
7
7
  llm_utils/lm/__init__.py,sha256=rX36_MsnekM5GHwWS56XELbm4W5x2TDwnPERDTfo0eU,194
8
- llm_utils/lm/async_lm.py,sha256=_NmWEp_jCbD6soexXo489L40KS8xJPgtY5QxXLDYsis,34174
8
+ llm_utils/lm/async_lm.py,sha256=xUOlAOivxrI-KRaUv1V0l5y7ajYTKTA6QZkqpK6Uue8,35847
9
9
  llm_utils/lm/chat_html.py,sha256=FkGo0Dv_nAHYBMZzXfMu_bGQKaCx302goh3XaT-_ETc,8674
10
10
  llm_utils/lm/lm_json.py,sha256=fMt42phzFV2f6ulrtWcDXsWHi8WcG7gGkCzpIq8VSSM,1975
11
11
  llm_utils/lm/sync_lm.py,sha256=ANw_m5KiWcRwwoeQ5no6dzPFLc6j9o2oEcJtkMKqrn8,34640
12
- llm_utils/lm/utils.py,sha256=GMvs64DRzVnXAki4SZ-A6mx2Fi9IgeF11BA-5FB-CYg,4777
12
+ llm_utils/lm/utils.py,sha256=gUejbVZPYg97g4ftYEptYN52WhH3TAKOFW81sjLvi08,4585
13
13
  llm_utils/scripts/README.md,sha256=yuOLnLa2od2jp4wVy3rV0rESeiV3o8zol5MNMsZx0DY,999
14
14
  llm_utils/scripts/vllm_load_balancer.py,sha256=GjMdoZrdT9cSLos0qSdkLg2dwZgW1enAMsD3aTZAfNs,20845
15
15
  llm_utils/scripts/vllm_serve.py,sha256=4NaqpVs7LBvxtvTCMPsNCAOfqiWkKRttxWMmWY7SitA,14729
@@ -31,7 +31,7 @@ speedy_utils/multi_worker/thread.py,sha256=u_hTwXh7_FciMa5EukdEA1fDCY_vUC4moDceB
31
31
  speedy_utils/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
32
32
  speedy_utils/scripts/mpython.py,sha256=73PHm1jqbCt2APN4xuNjD0VDKwzOj4EZsViEMQiZU2g,3853
33
33
  speedy_utils/scripts/openapi_client_codegen.py,sha256=f2125S_q0PILgH5dyzoKRz7pIvNEjCkzpi4Q4pPFRZE,9683
34
- speedy_utils-1.0.22.dist-info/METADATA,sha256=Ll1EUWmXjsgvn2nK2NZ-uSrf6SbkTY1mLaHRKWfgR2Q,7442
35
- speedy_utils-1.0.22.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
36
- speedy_utils-1.0.22.dist-info/entry_points.txt,sha256=T1t85jwx8fK6m5msdkBGIXH5R5Kd0zSL0S6erXERPzg,237
37
- speedy_utils-1.0.22.dist-info/RECORD,,
34
+ speedy_utils-1.0.24.dist-info/METADATA,sha256=p8QjQuz3u1B50Sj5qoPR0p-FsrYGvhd19dwfYM_MlwA,7442
35
+ speedy_utils-1.0.24.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
36
+ speedy_utils-1.0.24.dist-info/entry_points.txt,sha256=T1t85jwx8fK6m5msdkBGIXH5R5Kd0zSL0S6erXERPzg,237
37
+ speedy_utils-1.0.24.dist-info/RECORD,,