speedy-utils 1.0.22__py3-none-any.whl → 1.0.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/lm/async_lm.py +61 -16
- llm_utils/lm/utils.py +0 -5
- {speedy_utils-1.0.22.dist-info → speedy_utils-1.0.24.dist-info}/METADATA +1 -1
- {speedy_utils-1.0.22.dist-info → speedy_utils-1.0.24.dist-info}/RECORD +6 -6
- {speedy_utils-1.0.22.dist-info → speedy_utils-1.0.24.dist-info}/WHEEL +0 -0
- {speedy_utils-1.0.22.dist-info → speedy_utils-1.0.24.dist-info}/entry_points.txt +0 -0
llm_utils/lm/async_lm.py
CHANGED
```diff
@@ -82,6 +82,7 @@ from functools import lru_cache
 from typing import (
     Any,
     Dict,
+    Generic,
     List,
     Literal,
     Optional,
```
```diff
@@ -95,6 +96,7 @@ from typing import (
 
 from httpx import URL
 from loguru import logger
+from numpy import isin
 from openai import AsyncOpenAI, AuthenticationError, BadRequestError, RateLimitError
 from openai.pagination import AsyncPage as AsyncSyncPage
 
```
```diff
@@ -462,6 +464,7 @@ class AsyncLM:
         max_tokens: Optional[int] = None,
         return_openai_response: bool = False,
         cache: Optional[bool] = True,
+        return_messages: bool = False,
         **kwargs,
     ):
         """Parse response using guided JSON generation."""
```
```diff
@@ -532,9 +535,19 @@ class AsyncLM:
             self._dump_cache(cache_key, completion)
 
         self.last_log = [prompt, messages, completion]
+
+        output = self._parse_complete_output(completion, response_model)
         if return_openai_response:
-            return completion
-
+            return {"completion": completion, "parsed": output}
+        if return_messages:
+            # content = completion.model_dump()
+            full_messages = messages + [completion.model_dump()]
+            return {
+                "messages": full_messages,
+                "completion": completion,
+                "parsed": output,
+            }
+        return output
 
     def _parse_complete_output(
         self, completion: Any, response_model: Type[BaseModel]
```
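The net effect of this hunk is a new return contract for `parse()`: `return_openai_response=True` now yields a dict carrying both the raw completion and the parsed model (1.0.22 returned the bare completion), and the new `return_messages=True` flag additionally returns the full chat transcript. A self-contained sketch of the new dispatch, as a simplified mimic of the hunk above (not the package's actual code; the real method returns the completion object itself, not a dump):

```python
from typing import Any

def dispatch_output(
    parsed: Any,
    completion_dump: dict,
    messages: list,
    return_openai_response: bool = False,
    return_messages: bool = False,
) -> Any:
    if return_openai_response:
        # 1.0.22 returned the raw completion here; 1.0.24 returns both pieces.
        return {"completion": completion_dump, "parsed": parsed}
    if return_messages:
        # New in 1.0.24: the prompt messages plus the assistant turn appended.
        return {
            "messages": messages + [completion_dump],
            "completion": completion_dump,
            "parsed": parsed,
        }
    # Default behaviour is unchanged: just the parsed response model.
    return parsed
```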
```diff
@@ -839,8 +852,11 @@ async def inspect_word_probs_async(lm, tokenizer, messages):
 # Async LLMTask class
 # --------------------------------------------------------------------------- #
 
+InputModelType = TypeVar("InputModelType", bound=BaseModel)
+OutputModelType = TypeVar("OutputModelType", bound=BaseModel)
 
-class AsyncLLMTask(ABC):
+
+class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
     """
     Async callable wrapper around an AsyncLM endpoint.
 
```
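The new `Generic[InputModelType, OutputModelType]` base exists so the reworked `__call__` (next hunk) can recover the concrete Pydantic models from the subclass declaration via `__orig_bases__`. A self-contained illustration of that introspection pattern, with made-up model names (a minimal mimic, not the package's class):

```python
from abc import ABC
from typing import Generic, TypeVar
from pydantic import BaseModel

InModel = TypeVar("InModel", bound=BaseModel)
OutModel = TypeVar("OutModel", bound=BaseModel)

class Task(ABC, Generic[InModel, OutModel]):
    def models(self) -> tuple:
        # Subscripting the generic base records it on the subclass:
        base = self.__class__.__orig_bases__[0]  # e.g. Task[Question, Answer]
        return base.__args__                     # (Question, Answer)

class Question(BaseModel):
    text: str

class Answer(BaseModel):
    text: str

class QATask(Task[Question, Answer]):
    pass

print(QATask().models())  # (<class '...Question'>, <class '...Answer'>)
```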
```diff
@@ -885,30 +901,58 @@ class AsyncLLMTask(ABC):
     think: bool = False
     add_json_schema: bool = False
 
-    async def __call__(
+    async def __call__(
+        self,
+        data: BaseModel | dict,
+        temperature: float = 0.1,
+        cache: bool = False,
+        collect_messages: bool = False,
+    ) -> OutputModelType | tuple[OutputModelType, List[Dict[str, Any]]]:
+        # Get the input and output model types from the generic parameters
+        type_args = getattr(self.__class__, "__orig_bases__", None)
         if (
-
-
-
+            type_args
+            and hasattr(type_args[0], "__args__")
+            and len(type_args[0].__args__) >= 2
         ):
-
-
-
+            input_model = type_args[0].__args__[0]
+            output_model = type_args[0].__args__[1]
+        else:
+            # Fallback to the old way if type introspection fails
+            if (
+                not hasattr(self, "InputModel")
+                or not hasattr(self, "OutputModel")
+                or not hasattr(self, "lm")
+            ):
+                raise NotImplementedError(
+                    f"{self.__class__.__name__} must define lm, InputModel, and OutputModel as class attributes or use proper generic typing."
+                )
+            input_model = self.InputModel
+            output_model = self.OutputModel
 
-        item = data if isinstance(data, BaseModel) else
+        item = data if isinstance(data, BaseModel) else input_model(**data)
 
-
+        result = await self.lm.parse(
             prompt=item.model_dump_json(),
             instruction=self.__doc__ or "",
-            response_model=
-            temperature=self.temperature,
+            response_model=output_model,
+            temperature=temperature or self.temperature,
             think=self.think,
             add_json_schema_to_instruction=self.add_json_schema,
+            cache=cache,
+            return_messages=True,
         )
 
+        if collect_messages:
+            return (
+                cast(OutputModelType, result),
+                result["messages"],
+            )
+        return cast(OutputModelType, result)
+
     def generate_training_data(
         self, input_dict: Dict[str, Any], output: Dict[str, Any]
-    ):
+    ) -> Dict[str, Any]:
         """Return share gpt like format"""
         system_prompt = self.__doc__ or ""
         user_msg = self.InputModel(**input_dict).model_dump_json()  # type: ignore[attr-defined]
```
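Putting the pieces together, a subclass call under the 1.0.24 API might look like the hedged sketch below. `Article`, `Summary`, `SummarizeTask`, and the `AsyncLM(...)` wiring are illustrative assumptions; only the `AsyncLLMTask[...]` parametrization and the `temperature`/`cache`/`collect_messages` keywords come from the diff above:

```python
import asyncio
from pydantic import BaseModel
from llm_utils.lm.async_lm import AsyncLLMTask, AsyncLM  # module path per this diff

class Article(BaseModel):
    text: str

class Summary(BaseModel):
    summary: str

class SummarizeTask(AsyncLLMTask[Article, Summary]):
    """Summarize the article in one sentence."""
    lm = AsyncLM(...)  # placeholder: configure your endpoint/model here

async def main() -> None:
    task = SummarizeTask()
    # Plain call: returns the parsed Summary model.
    summary = await task({"text": "LLMs are..."}, temperature=0.0, cache=True)
    # collect_messages=True also returns the chat transcript, which pairs
    # with generate_training_data() for ShareGPT-style datasets.
    summary, messages = await task({"text": "LLMs are..."}, collect_messages=True)

# asyncio.run(main())  # requires a live OpenAI-compatible endpoint
```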
```diff
@@ -917,4 +961,5 @@ class AsyncLLMTask(ABC):
             system_msg=system_prompt, user_msg=user_msg, assistant_msg=assistant_msg
         )
         return {"messages": messages}
-
+
+    # arun = __call__ # alias for compatibility with other LLMTask implementations
```
llm_utils/lm/utils.py
CHANGED
```diff
@@ -7,11 +7,6 @@ import numpy as np
 from loguru import logger
 
 
-def _clear_port_use(ports):
-    for port in ports:
-        file_counter: str = f"/tmp/port_use_counter_{port}.npy"
-        if os.path.exists(file_counter):
-            os.remove(file_counter)
 
 
 def _atomic_save(array: np.ndarray, filename: str):
```
{speedy_utils-1.0.22.dist-info → speedy_utils-1.0.24.dist-info}/RECORD
CHANGED
```diff
@@ -5,11 +5,11 @@ llm_utils/chat_format/transform.py,sha256=8TZhvUS5DrjUeMNtDIuWY54B_QZ7jjpXEL9c8F
 llm_utils/chat_format/utils.py,sha256=xTxN4HrLHcRO2PfCTR43nH1M5zCa7v0kTTdzAcGkZg0,1229
 llm_utils/group_messages.py,sha256=8CU9nKOja3xeuhdrX5CvYVveSqSKb2zQ0eeNzA88aTQ,3621
 llm_utils/lm/__init__.py,sha256=rX36_MsnekM5GHwWS56XELbm4W5x2TDwnPERDTfo0eU,194
-llm_utils/lm/async_lm.py,sha256=
+llm_utils/lm/async_lm.py,sha256=xUOlAOivxrI-KRaUv1V0l5y7ajYTKTA6QZkqpK6Uue8,35847
 llm_utils/lm/chat_html.py,sha256=FkGo0Dv_nAHYBMZzXfMu_bGQKaCx302goh3XaT-_ETc,8674
 llm_utils/lm/lm_json.py,sha256=fMt42phzFV2f6ulrtWcDXsWHi8WcG7gGkCzpIq8VSSM,1975
 llm_utils/lm/sync_lm.py,sha256=ANw_m5KiWcRwwoeQ5no6dzPFLc6j9o2oEcJtkMKqrn8,34640
-llm_utils/lm/utils.py,sha256=
+llm_utils/lm/utils.py,sha256=gUejbVZPYg97g4ftYEptYN52WhH3TAKOFW81sjLvi08,4585
 llm_utils/scripts/README.md,sha256=yuOLnLa2od2jp4wVy3rV0rESeiV3o8zol5MNMsZx0DY,999
 llm_utils/scripts/vllm_load_balancer.py,sha256=GjMdoZrdT9cSLos0qSdkLg2dwZgW1enAMsD3aTZAfNs,20845
 llm_utils/scripts/vllm_serve.py,sha256=4NaqpVs7LBvxtvTCMPsNCAOfqiWkKRttxWMmWY7SitA,14729
```
```diff
@@ -31,7 +31,7 @@ speedy_utils/multi_worker/thread.py,sha256=u_hTwXh7_FciMa5EukdEA1fDCY_vUC4moDceB
 speedy_utils/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 speedy_utils/scripts/mpython.py,sha256=73PHm1jqbCt2APN4xuNjD0VDKwzOj4EZsViEMQiZU2g,3853
 speedy_utils/scripts/openapi_client_codegen.py,sha256=f2125S_q0PILgH5dyzoKRz7pIvNEjCkzpi4Q4pPFRZE,9683
-speedy_utils-1.0.
-speedy_utils-1.0.
-speedy_utils-1.0.
-speedy_utils-1.0.
+speedy_utils-1.0.24.dist-info/METADATA,sha256=p8QjQuz3u1B50Sj5qoPR0p-FsrYGvhd19dwfYM_MlwA,7442
+speedy_utils-1.0.24.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+speedy_utils-1.0.24.dist-info/entry_points.txt,sha256=T1t85jwx8fK6m5msdkBGIXH5R5Kd0zSL0S6erXERPzg,237
+speedy_utils-1.0.24.dist-info/RECORD,,
```
{speedy_utils-1.0.22.dist-info → speedy_utils-1.0.24.dist-info}/WHEEL
File without changes

{speedy_utils-1.0.22.dist-info → speedy_utils-1.0.24.dist-info}/entry_points.txt
File without changes