speedy-utils 1.0.23__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/lm/async_lm.py +83 -38
- llm_utils/lm/utils.py +0 -5
- {speedy_utils-1.0.23.dist-info → speedy_utils-1.1.0.dist-info}/METADATA +1 -1
- {speedy_utils-1.0.23.dist-info → speedy_utils-1.1.0.dist-info}/RECORD +6 -6
- {speedy_utils-1.0.23.dist-info → speedy_utils-1.1.0.dist-info}/WHEEL +0 -0
- {speedy_utils-1.0.23.dist-info → speedy_utils-1.1.0.dist-info}/entry_points.txt +0 -0
llm_utils/lm/async_lm.py
CHANGED
|
@@ -82,6 +82,7 @@ from functools import lru_cache
|
|
|
82
82
|
from typing import (
|
|
83
83
|
Any,
|
|
84
84
|
Dict,
|
|
85
|
+
Generic,
|
|
85
86
|
List,
|
|
86
87
|
Literal,
|
|
87
88
|
Optional,
|
|
@@ -92,9 +93,10 @@ from typing import (
|
|
|
92
93
|
cast,
|
|
93
94
|
overload,
|
|
94
95
|
)
|
|
95
|
-
|
|
96
|
+
from typing_extensions import TypedDict
|
|
96
97
|
from httpx import URL
|
|
97
98
|
from loguru import logger
|
|
99
|
+
from numpy import isin
|
|
98
100
|
from openai import AsyncOpenAI, AuthenticationError, BadRequestError, RateLimitError
|
|
99
101
|
from openai.pagination import AsyncPage as AsyncSyncPage
|
|
100
102
|
|
|
@@ -144,6 +146,12 @@ def _yellow(t):
|
|
|
144
146
|
return _color(33, t)
|
|
145
147
|
|
|
146
148
|
|
|
149
|
+
class ParsedOutput(TypedDict):
|
|
150
|
+
messages: List
|
|
151
|
+
completion: Any
|
|
152
|
+
parsed: BaseModel
|
|
153
|
+
|
|
154
|
+
|
|
147
155
|
class AsyncLM:
|
|
148
156
|
"""Unified **async** language‑model wrapper with optional JSON parsing."""
|
|
149
157
|
|
|
@@ -460,10 +468,9 @@ class AsyncLM:
|
|
|
460
468
|
add_json_schema_to_instruction: bool = False,
|
|
461
469
|
temperature: Optional[float] = None,
|
|
462
470
|
max_tokens: Optional[int] = None,
|
|
463
|
-
return_openai_response: bool = False,
|
|
464
471
|
cache: Optional[bool] = True,
|
|
465
472
|
**kwargs,
|
|
466
|
-
):
|
|
473
|
+
) -> ParsedOutput: # -> dict[str, Any]:
|
|
467
474
|
"""Parse response using guided JSON generation."""
|
|
468
475
|
if messages is None:
|
|
469
476
|
assert instruction is not None, "Instruction must be provided."
|
|
@@ -514,27 +521,27 @@ class AsyncLM:
|
|
|
514
521
|
"response_format": response_model.__name__,
|
|
515
522
|
}
|
|
516
523
|
cache_key = self._cache_key(cache_data, {}, response_model)
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
**model_kwargs,
|
|
529
|
-
)
|
|
530
|
-
|
|
531
|
-
if cache_key:
|
|
532
|
-
self._dump_cache(cache_key, completion)
|
|
524
|
+
completion = self._load_cache(cache_key) # dict
|
|
525
|
+
else:
|
|
526
|
+
completion = await self.client.chat.completions.create(
|
|
527
|
+
model=self.model, # type: ignore
|
|
528
|
+
messages=messages, # type: ignore
|
|
529
|
+
extra_body={"guided_json": json_schema},
|
|
530
|
+
**model_kwargs,
|
|
531
|
+
)
|
|
532
|
+
completion = completion.model_dump()
|
|
533
|
+
if cache_key:
|
|
534
|
+
self._dump_cache(cache_key, completion)
|
|
533
535
|
|
|
534
536
|
self.last_log = [prompt, messages, completion]
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
537
|
+
|
|
538
|
+
output = self._parse_complete_output(completion, response_model)
|
|
539
|
+
full_messages = messages + [completion]
|
|
540
|
+
return ParsedOutput(
|
|
541
|
+
messages=full_messages,
|
|
542
|
+
completion=completion,
|
|
543
|
+
parsed=output,
|
|
544
|
+
)
|
|
538
545
|
|
|
539
546
|
def _parse_complete_output(
|
|
540
547
|
self, completion: Any, response_model: Type[BaseModel]
|
|
@@ -839,8 +846,11 @@ async def inspect_word_probs_async(lm, tokenizer, messages):
|
|
|
839
846
|
# Async LLMTask class
|
|
840
847
|
# --------------------------------------------------------------------------- #
|
|
841
848
|
|
|
849
|
+
InputModelType = TypeVar("InputModelType", bound=BaseModel)
|
|
850
|
+
OutputModelType = TypeVar("OutputModelType", bound=BaseModel)
|
|
842
851
|
|
|
843
|
-
|
|
852
|
+
|
|
853
|
+
class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
|
|
844
854
|
"""
|
|
845
855
|
Async callable wrapper around an AsyncLM endpoint.
|
|
846
856
|
|
|
@@ -878,37 +888,71 @@ class AsyncLLMTask(ABC):
|
|
|
878
888
|
"""
|
|
879
889
|
|
|
880
890
|
lm: "AsyncLM"
|
|
881
|
-
InputModel:
|
|
882
|
-
OutputModel:
|
|
891
|
+
InputModel: InputModelType
|
|
892
|
+
OutputModel: OutputModelType
|
|
883
893
|
|
|
884
894
|
temperature: float = 0.6
|
|
885
895
|
think: bool = False
|
|
886
896
|
add_json_schema: bool = False
|
|
887
897
|
|
|
888
|
-
async def __call__(
|
|
898
|
+
async def __call__(
|
|
899
|
+
self,
|
|
900
|
+
data: BaseModel | dict,
|
|
901
|
+
temperature: float = 0.1,
|
|
902
|
+
cache: bool = False,
|
|
903
|
+
) -> tuple[OutputModelType, List[Dict[str, Any]]]:
|
|
904
|
+
# Get the input and output model types from the generic parameters
|
|
905
|
+
type_args = getattr(self.__class__, "__orig_bases__", None)
|
|
889
906
|
if (
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
907
|
+
type_args
|
|
908
|
+
and hasattr(type_args[0], "__args__")
|
|
909
|
+
and len(type_args[0].__args__) >= 2
|
|
893
910
|
):
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
911
|
+
input_model = type_args[0].__args__[0]
|
|
912
|
+
output_model = type_args[0].__args__[1]
|
|
913
|
+
else:
|
|
914
|
+
# Fallback to the old way if type introspection fails
|
|
915
|
+
if (
|
|
916
|
+
not hasattr(self, "InputModel")
|
|
917
|
+
or not hasattr(self, "OutputModel")
|
|
918
|
+
or not hasattr(self, "lm")
|
|
919
|
+
):
|
|
920
|
+
raise NotImplementedError(
|
|
921
|
+
f"{self.__class__.__name__} must define lm, InputModel, and OutputModel as class attributes or use proper generic typing."
|
|
922
|
+
)
|
|
923
|
+
input_model = self.InputModel
|
|
924
|
+
output_model = self.OutputModel
|
|
925
|
+
|
|
926
|
+
# Ensure input_model is a class before calling
|
|
927
|
+
if isinstance(data, BaseModel):
|
|
928
|
+
item = data
|
|
929
|
+
elif isinstance(input_model, type) and issubclass(input_model, BaseModel):
|
|
930
|
+
item = input_model(**data)
|
|
931
|
+
else:
|
|
932
|
+
raise TypeError("InputModel must be a subclass of BaseModel")
|
|
897
933
|
|
|
898
|
-
|
|
934
|
+
assert isinstance(output_model, type) and issubclass(output_model, BaseModel), (
|
|
935
|
+
"OutputModel must be a subclass of BaseModel"
|
|
936
|
+
)
|
|
899
937
|
|
|
900
|
-
|
|
938
|
+
result = await self.lm.parse(
|
|
901
939
|
prompt=item.model_dump_json(),
|
|
902
940
|
instruction=self.__doc__ or "",
|
|
903
|
-
response_model=
|
|
904
|
-
temperature=self.temperature,
|
|
941
|
+
response_model=output_model,
|
|
942
|
+
temperature=temperature or self.temperature,
|
|
905
943
|
think=self.think,
|
|
906
944
|
add_json_schema_to_instruction=self.add_json_schema,
|
|
945
|
+
cache=cache,
|
|
946
|
+
)
|
|
947
|
+
|
|
948
|
+
return (
|
|
949
|
+
cast(OutputModelType, result["parsed"]), # type: ignore
|
|
950
|
+
cast(List[dict], result["messages"]), # type: ignore
|
|
907
951
|
)
|
|
908
952
|
|
|
909
953
|
def generate_training_data(
|
|
910
954
|
self, input_dict: Dict[str, Any], output: Dict[str, Any]
|
|
911
|
-
):
|
|
955
|
+
) -> Dict[str, Any]:
|
|
912
956
|
"""Return share gpt like format"""
|
|
913
957
|
system_prompt = self.__doc__ or ""
|
|
914
958
|
user_msg = self.InputModel(**input_dict).model_dump_json() # type: ignore[attr-defined]
|
|
@@ -917,4 +961,5 @@ class AsyncLLMTask(ABC):
|
|
|
917
961
|
system_msg=system_prompt, user_msg=user_msg, assistant_msg=assistant_msg
|
|
918
962
|
)
|
|
919
963
|
return {"messages": messages}
|
|
920
|
-
|
|
964
|
+
|
|
965
|
+
arun = __call__ # alias for compatibility with other LLMTask implementations
|
llm_utils/lm/utils.py
CHANGED
|
@@ -7,11 +7,6 @@ import numpy as np
|
|
|
7
7
|
from loguru import logger
|
|
8
8
|
|
|
9
9
|
|
|
10
|
-
def _clear_port_use(ports):
|
|
11
|
-
for port in ports:
|
|
12
|
-
file_counter: str = f"/tmp/port_use_counter_{port}.npy"
|
|
13
|
-
if os.path.exists(file_counter):
|
|
14
|
-
os.remove(file_counter)
|
|
15
10
|
|
|
16
11
|
|
|
17
12
|
def _atomic_save(array: np.ndarray, filename: str):
|
|
@@ -5,11 +5,11 @@ llm_utils/chat_format/transform.py,sha256=8TZhvUS5DrjUeMNtDIuWY54B_QZ7jjpXEL9c8F
|
|
|
5
5
|
llm_utils/chat_format/utils.py,sha256=xTxN4HrLHcRO2PfCTR43nH1M5zCa7v0kTTdzAcGkZg0,1229
|
|
6
6
|
llm_utils/group_messages.py,sha256=8CU9nKOja3xeuhdrX5CvYVveSqSKb2zQ0eeNzA88aTQ,3621
|
|
7
7
|
llm_utils/lm/__init__.py,sha256=rX36_MsnekM5GHwWS56XELbm4W5x2TDwnPERDTfo0eU,194
|
|
8
|
-
llm_utils/lm/async_lm.py,sha256=
|
|
8
|
+
llm_utils/lm/async_lm.py,sha256=kiWEecrkCTTQFlQj5JiHNziFeLOF1-7G_2xC2Dra1bw,35806
|
|
9
9
|
llm_utils/lm/chat_html.py,sha256=FkGo0Dv_nAHYBMZzXfMu_bGQKaCx302goh3XaT-_ETc,8674
|
|
10
10
|
llm_utils/lm/lm_json.py,sha256=fMt42phzFV2f6ulrtWcDXsWHi8WcG7gGkCzpIq8VSSM,1975
|
|
11
11
|
llm_utils/lm/sync_lm.py,sha256=ANw_m5KiWcRwwoeQ5no6dzPFLc6j9o2oEcJtkMKqrn8,34640
|
|
12
|
-
llm_utils/lm/utils.py,sha256=
|
|
12
|
+
llm_utils/lm/utils.py,sha256=gUejbVZPYg97g4ftYEptYN52WhH3TAKOFW81sjLvi08,4585
|
|
13
13
|
llm_utils/scripts/README.md,sha256=yuOLnLa2od2jp4wVy3rV0rESeiV3o8zol5MNMsZx0DY,999
|
|
14
14
|
llm_utils/scripts/vllm_load_balancer.py,sha256=GjMdoZrdT9cSLos0qSdkLg2dwZgW1enAMsD3aTZAfNs,20845
|
|
15
15
|
llm_utils/scripts/vllm_serve.py,sha256=4NaqpVs7LBvxtvTCMPsNCAOfqiWkKRttxWMmWY7SitA,14729
|
|
@@ -31,7 +31,7 @@ speedy_utils/multi_worker/thread.py,sha256=u_hTwXh7_FciMa5EukdEA1fDCY_vUC4moDceB
|
|
|
31
31
|
speedy_utils/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
32
32
|
speedy_utils/scripts/mpython.py,sha256=73PHm1jqbCt2APN4xuNjD0VDKwzOj4EZsViEMQiZU2g,3853
|
|
33
33
|
speedy_utils/scripts/openapi_client_codegen.py,sha256=f2125S_q0PILgH5dyzoKRz7pIvNEjCkzpi4Q4pPFRZE,9683
|
|
34
|
-
speedy_utils-1.0.
|
|
35
|
-
speedy_utils-1.0.
|
|
36
|
-
speedy_utils-1.0.
|
|
37
|
-
speedy_utils-1.0.
|
|
34
|
+
speedy_utils-1.1.0.dist-info/METADATA,sha256=h1Alzm4q92GSiw5GNZWn6d8sHaSJS4X8RTMXStjkqHY,7441
|
|
35
|
+
speedy_utils-1.1.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
36
|
+
speedy_utils-1.1.0.dist-info/entry_points.txt,sha256=T1t85jwx8fK6m5msdkBGIXH5R5Kd0zSL0S6erXERPzg,237
|
|
37
|
+
speedy_utils-1.1.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|