speedy-utils 1.0.24__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/lm/async_lm.py +44 -44
- {speedy_utils-1.0.24.dist-info → speedy_utils-1.1.0.dist-info}/METADATA +1 -1
- {speedy_utils-1.0.24.dist-info → speedy_utils-1.1.0.dist-info}/RECORD +5 -5
- {speedy_utils-1.0.24.dist-info → speedy_utils-1.1.0.dist-info}/WHEEL +0 -0
- {speedy_utils-1.0.24.dist-info → speedy_utils-1.1.0.dist-info}/entry_points.txt +0 -0
llm_utils/lm/async_lm.py
CHANGED
@@ -93,7 +93,7 @@ from typing import (
     cast,
     overload,
 )
-
+from typing_extensions import TypedDict
 from httpx import URL
 from loguru import logger
 from numpy import isin
@@ -146,6 +146,12 @@ def _yellow(t):
     return _color(33, t)


+class ParsedOutput(TypedDict):
+    messages: List
+    completion: Any
+    parsed: BaseModel
+
+
 class AsyncLM:
     """Unified **async** language‑model wrapper with optional JSON parsing."""

@@ -462,11 +468,9 @@ class AsyncLM:
         add_json_schema_to_instruction: bool = False,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
-        return_openai_response: bool = False,
         cache: Optional[bool] = True,
-        return_messages: bool = False,
         **kwargs,
-    ):
+    ) -> ParsedOutput:  # -> dict[str, Any]:
         """Parse response using guided JSON generation."""
         if messages is None:
             assert instruction is not None, "Instruction must be provided."
@@ -517,37 +521,27 @@ class AsyncLM:
             "response_format": response_model.__name__,
         }
         cache_key = self._cache_key(cache_data, {}, response_model)
-
-
-
-
-
-
-
-
-
-
-
-            **model_kwargs,
-        )
-
-        if cache_key:
-            self._dump_cache(cache_key, completion)
+            completion = self._load_cache(cache_key)  # dict
+        else:
+            completion = await self.client.chat.completions.create(
+                model=self.model,  # type: ignore
+                messages=messages,  # type: ignore
+                extra_body={"guided_json": json_schema},
+                **model_kwargs,
+            )
+            completion = completion.model_dump()
+        if cache_key:
+            self._dump_cache(cache_key, completion)

         self.last_log = [prompt, messages, completion]

         output = self._parse_complete_output(completion, response_model)
-
-
-
-
-
-
-            "messages": full_messages,
-            "completion": completion,
-            "parsed": output,
-        }
-        return output
+        full_messages = messages + [completion]
+        return ParsedOutput(
+            messages=full_messages,
+            completion=completion,
+            parsed=output,
+        )

     def _parse_complete_output(
         self, completion: Any, response_model: Type[BaseModel]
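Taken together, the hunks above drop the old return_openai_response / return_messages flags and have AsyncLM.parse() always return the new ParsedOutput TypedDict. The following is a minimal consumer sketch, not taken from the package docs: the Address model, the helper function, and the exact keyword arguments passed to parse() are illustrative assumptions.

from pydantic import BaseModel

from llm_utils.lm.async_lm import AsyncLM, ParsedOutput


class Address(BaseModel):
    # Example response model; not part of the package.
    city: str
    country: str


async def extract_address(lm: AsyncLM, text: str) -> Address:
    # parse() now returns a ParsedOutput TypedDict rather than a bare parsed model.
    result: ParsedOutput = await lm.parse(
        instruction="Extract the address as JSON.",  # illustrative instruction
        prompt=text,
        response_model=Address,
    )
    messages = result["messages"]   # prompt messages plus the raw completion dict
    parsed = result["parsed"]       # validated pydantic instance
    assert isinstance(parsed, Address)
    return parsed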
@@ -894,8 +888,8 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
     """

     lm: "AsyncLM"
-    InputModel:
-    OutputModel:
+    InputModel: InputModelType
+    OutputModel: OutputModelType

     temperature: float = 0.6
     think: bool = False
@@ -906,8 +900,7 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
         data: BaseModel | dict,
         temperature: float = 0.1,
         cache: bool = False,
-
-    ) -> OutputModelType | tuple[OutputModelType, List[Dict[str, Any]]]:
+    ) -> tuple[OutputModelType, List[Dict[str, Any]]]:
         # Get the input and output model types from the generic parameters
         type_args = getattr(self.__class__, "__orig_bases__", None)
         if (
@@ -930,7 +923,17 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
         input_model = self.InputModel
         output_model = self.OutputModel

-
+        # Ensure input_model is a class before calling
+        if isinstance(data, BaseModel):
+            item = data
+        elif isinstance(input_model, type) and issubclass(input_model, BaseModel):
+            item = input_model(**data)
+        else:
+            raise TypeError("InputModel must be a subclass of BaseModel")
+
+        assert isinstance(output_model, type) and issubclass(output_model, BaseModel), (
+            "OutputModel must be a subclass of BaseModel"
+        )

         result = await self.lm.parse(
             prompt=item.model_dump_json(),
@@ -940,15 +943,12 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
             think=self.think,
             add_json_schema_to_instruction=self.add_json_schema,
             cache=cache,
-            return_messages=True,
         )

-
-
-
-
-        )
-        return cast(OutputModelType, result)
+        return (
+            cast(OutputModelType, result["parsed"]),  # type: ignore
+            cast(List[dict], result["messages"]),  # type: ignore
+        )

     def generate_training_data(
         self, input_dict: Dict[str, Any], output: Dict[str, Any]
@@ -962,4 +962,4 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
         )
         return {"messages": messages}

-
+    arun = __call__  # alias for compatibility with other LLMTask implementations
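Because AsyncLLMTask.__call__ now returns a (parsed, messages) tuple unconditionally and gains the arun alias, existing callers need a small adjustment. Below is a hedged usage sketch, assuming task is an instance of some concrete AsyncLLMTask subclass and the payload dict matches its InputModel; the function name and payload are illustrative only.

from typing import Any, Dict

from llm_utils.lm.async_lm import AsyncLLMTask


async def run_task(task: AsyncLLMTask, payload: Dict[str, Any]):
    # __call__ now always returns tuple[OutputModelType, List[Dict[str, Any]]];
    # the old return_messages flag no longer exists.
    parsed, messages = await task(payload, temperature=0.1, cache=False)

    # arun is an alias for __call__ added in this release for compatibility
    # with other LLMTask implementations.
    parsed_again, _ = await task.arun(payload)

    return parsed, messages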
{speedy_utils-1.0.24.dist-info → speedy_utils-1.1.0.dist-info}/RECORD
CHANGED

@@ -5,7 +5,7 @@ llm_utils/chat_format/transform.py,sha256=8TZhvUS5DrjUeMNtDIuWY54B_QZ7jjpXEL9c8F
 llm_utils/chat_format/utils.py,sha256=xTxN4HrLHcRO2PfCTR43nH1M5zCa7v0kTTdzAcGkZg0,1229
 llm_utils/group_messages.py,sha256=8CU9nKOja3xeuhdrX5CvYVveSqSKb2zQ0eeNzA88aTQ,3621
 llm_utils/lm/__init__.py,sha256=rX36_MsnekM5GHwWS56XELbm4W5x2TDwnPERDTfo0eU,194
-llm_utils/lm/async_lm.py,sha256=
+llm_utils/lm/async_lm.py,sha256=kiWEecrkCTTQFlQj5JiHNziFeLOF1-7G_2xC2Dra1bw,35806
 llm_utils/lm/chat_html.py,sha256=FkGo0Dv_nAHYBMZzXfMu_bGQKaCx302goh3XaT-_ETc,8674
 llm_utils/lm/lm_json.py,sha256=fMt42phzFV2f6ulrtWcDXsWHi8WcG7gGkCzpIq8VSSM,1975
 llm_utils/lm/sync_lm.py,sha256=ANw_m5KiWcRwwoeQ5no6dzPFLc6j9o2oEcJtkMKqrn8,34640
@@ -31,7 +31,7 @@ speedy_utils/multi_worker/thread.py,sha256=u_hTwXh7_FciMa5EukdEA1fDCY_vUC4moDceB
 speedy_utils/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 speedy_utils/scripts/mpython.py,sha256=73PHm1jqbCt2APN4xuNjD0VDKwzOj4EZsViEMQiZU2g,3853
 speedy_utils/scripts/openapi_client_codegen.py,sha256=f2125S_q0PILgH5dyzoKRz7pIvNEjCkzpi4Q4pPFRZE,9683
-speedy_utils-1.0.
-speedy_utils-1.0.
-speedy_utils-1.0.
-speedy_utils-1.0.
+speedy_utils-1.1.0.dist-info/METADATA,sha256=h1Alzm4q92GSiw5GNZWn6d8sHaSJS4X8RTMXStjkqHY,7441
+speedy_utils-1.1.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+speedy_utils-1.1.0.dist-info/entry_points.txt,sha256=T1t85jwx8fK6m5msdkBGIXH5R5Kd0zSL0S6erXERPzg,237
+speedy_utils-1.1.0.dist-info/RECORD,,
{speedy_utils-1.0.24.dist-info → speedy_utils-1.1.0.dist-info}/WHEEL
File without changes

{speedy_utils-1.0.24.dist-info → speedy_utils-1.1.0.dist-info}/entry_points.txt
File without changes