speedy-utils 1.0.24__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_utils/lm/async_lm.py CHANGED
@@ -93,7 +93,7 @@ from typing import (
     cast,
     overload,
 )
-
+from typing_extensions import TypedDict
 from httpx import URL
 from loguru import logger
 from numpy import isin
@@ -146,6 +146,12 @@ def _yellow(t):
     return _color(33, t)
 
 
+class ParsedOutput(TypedDict):
+    messages: List
+    completion: Any
+    parsed: BaseModel
+
+
 class AsyncLM:
     """Unified **async** language‑model wrapper with optional JSON parsing."""
 
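The new `ParsedOutput` TypedDict gives `parse` a concrete return contract. A minimal consumer sketch, mirroring the definitions added above (the `report` helper is hypothetical, not part of the package):

```python
from typing import Any, List

from pydantic import BaseModel
from typing_extensions import TypedDict


class ParsedOutput(TypedDict):
    messages: List        # chat history, completion included
    completion: Any       # raw completion dumped to a plain dict
    parsed: BaseModel     # validated response-model instance


def report(result: ParsedOutput) -> None:
    # A type checker knows these keys exist and that "parsed" is a BaseModel.
    print(result["parsed"].model_dump_json())
    print(f"{len(result['messages'])} messages exchanged")
```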
@@ -462,11 +468,9 @@ class AsyncLM:
         add_json_schema_to_instruction: bool = False,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
-        return_openai_response: bool = False,
         cache: Optional[bool] = True,
-        return_messages: bool = False,
         **kwargs,
-    ):
+    ) -> ParsedOutput:  # -> dict[str, Any]:
         """Parse response using guided JSON generation."""
         if messages is None:
             assert instruction is not None, "Instruction must be provided."
@@ -517,37 +521,27 @@ class AsyncLM:
                 "response_format": response_model.__name__,
             }
             cache_key = self._cache_key(cache_data, {}, response_model)
-        cached_response = self._load_cache(cache_key)
-        self.last_log = [prompt, messages, cached_response]
-        if cached_response is not None:
-            if return_openai_response:
-                return cached_response
-            return self._parse_complete_output(cached_response, response_model)
-
-        completion = await self.client.chat.completions.create(
-            model=self.model,  # type: ignore
-            messages=messages,  # type: ignore
-            extra_body={"guided_json": json_schema},
-            **model_kwargs,
-        )
-
-        if cache_key:
-            self._dump_cache(cache_key, completion)
+            completion = self._load_cache(cache_key)  # dict
+        else:
+            completion = await self.client.chat.completions.create(
+                model=self.model,  # type: ignore
+                messages=messages,  # type: ignore
+                extra_body={"guided_json": json_schema},
+                **model_kwargs,
+            )
+            completion = completion.model_dump()
+            if cache_key:
+                self._dump_cache(cache_key, completion)
 
         self.last_log = [prompt, messages, completion]
 
         output = self._parse_complete_output(completion, response_model)
-        if return_openai_response:
-            return {"completion": completion, "parsed": output}
-        if return_messages:
-            # content = completion.model_dump()
-            full_messages = messages + [completion.model_dump()]
-            return {
-                "messages": full_messages,
-                "completion": completion,
-                "parsed": output,
-            }
-        return output
+        full_messages = messages + [completion]
+        return ParsedOutput(
+            messages=full_messages,
+            completion=completion,
+            parsed=output,
+        )
 
     def _parse_complete_output(
         self, completion: Any, response_model: Type[BaseModel]
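With the `return_openai_response` and `return_messages` flags gone, `parse` has a single return shape, and the completion is cached as a plain dict via `model_dump()`. A hedged caller sketch: the keyword names follow those visible in this diff, but the `AsyncLM` constructor arguments and the model name are assumptions.

```python
import asyncio

from pydantic import BaseModel

from llm_utils.lm.async_lm import AsyncLM


class City(BaseModel):
    name: str
    population: int


async def main() -> None:
    lm = AsyncLM(model="gpt-4o-mini")  # assumed constructor signature
    result = await lm.parse(
        instruction="Extract the city mentioned in the text.",
        prompt="Hanoi has roughly 8 million residents.",
        response_model=City,
    )
    city = result["parsed"]        # validated City instance
    history = result["messages"]   # prompt messages + completion dict
    raw = result["completion"]     # completion as a plain dict (cache-safe)
    print(city, len(history), type(raw))


asyncio.run(main())
```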
@@ -894,8 +888,8 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
     """
 
     lm: "AsyncLM"
-    InputModel: Type[BaseModel]
-    OutputModel: Type[BaseModel]
+    InputModel: InputModelType
+    OutputModel: OutputModelType
 
     temperature: float = 0.6
     think: bool = False
@@ -906,8 +900,7 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
         data: BaseModel | dict,
         temperature: float = 0.1,
         cache: bool = False,
-        collect_messages: bool = False,
-    ) -> OutputModelType | tuple[OutputModelType, List[Dict[str, Any]]]:
+    ) -> tuple[OutputModelType, List[Dict[str, Any]]]:
         # Get the input and output model types from the generic parameters
         type_args = getattr(self.__class__, "__orig_bases__", None)
         if (
@@ -930,7 +923,17 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
             input_model = self.InputModel
             output_model = self.OutputModel
 
-        item = data if isinstance(data, BaseModel) else input_model(**data)
+        # Ensure input_model is a class before calling
+        if isinstance(data, BaseModel):
+            item = data
+        elif isinstance(input_model, type) and issubclass(input_model, BaseModel):
+            item = input_model(**data)
+        else:
+            raise TypeError("InputModel must be a subclass of BaseModel")
+
+        assert isinstance(output_model, type) and issubclass(output_model, BaseModel), (
+            "OutputModel must be a subclass of BaseModel"
+        )
 
         result = await self.lm.parse(
             prompt=item.model_dump_json(),
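For context on the new guards, a hypothetical concrete task: the generic arguments are what the `__orig_bases__` resolution extracts into `input_model` and `output_model`, and the `issubclass` checks above now reject anything that is not a real `BaseModel` subclass. Any abstract members of the base class are omitted from this sketch.

```python
from pydantic import BaseModel

from llm_utils.lm.async_lm import AsyncLLMTask


class Question(BaseModel):
    text: str


class Answer(BaseModel):
    text: str
    confidence: float


class QATask(AsyncLLMTask[Question, Answer]):
    """InputModel resolves to Question, OutputModel to Answer."""
```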
@@ -940,15 +943,12 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
             think=self.think,
             add_json_schema_to_instruction=self.add_json_schema,
             cache=cache,
-            return_messages=True,
         )
 
-        if collect_messages:
-            return (
-                cast(OutputModelType, result),
-                result["messages"],
-            )
-        return cast(OutputModelType, result)
+        return (
+            cast(OutputModelType, result["parsed"]),  # type: ignore
+            cast(List[dict], result["messages"]),  # type: ignore
+        )
 
     def generate_training_data(
         self, input_dict: Dict[str, Any], output: Dict[str, Any]
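`__call__` now always returns the `(parsed, messages)` pair, so call sites written against the old `collect_messages=False` default must unpack. A migration sketch reusing the hypothetical `QATask` above:

```python
async def run_once(task: QATask) -> None:
    # Before 1.1.0: answer = await task({"text": ...})
    # From 1.1.0 on, __call__ always returns a tuple:
    answer, messages = await task({"text": "What is 2 + 2?"})
    print(answer.text)    # validated Answer instance
    print(len(messages))  # prompt messages plus the completion dict
```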
@@ -962,4 +962,4 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
         )
         return {"messages": messages}
 
-    # arun = __call__  # alias for compatibility with other LLMTask implementations
+    arun = __call__  # alias for compatibility with other LLMTask implementations
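Uncommenting the alias makes `arun` a real class attribute bound to `__call__`, so both spellings hit the same code path (again reusing the hypothetical `QATask`):

```python
async def demo(task: QATask) -> None:
    a1, _ = await task({"text": "ping"})
    a2, _ = await task.arun({"text": "ping"})  # identical to calling the task
    assert type(a1) is type(a2)
```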
{speedy_utils-1.0.24.dist-info → speedy_utils-1.1.0.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: speedy-utils
-Version: 1.0.24
+Version: 1.1.0
 Summary: Fast and easy-to-use package for data science
 Author: AnhVTH
 Author-email: anhvth.226@gmail.com
{speedy_utils-1.0.24.dist-info → speedy_utils-1.1.0.dist-info}/RECORD RENAMED
@@ -5,7 +5,7 @@ llm_utils/chat_format/transform.py,sha256=8TZhvUS5DrjUeMNtDIuWY54B_QZ7jjpXEL9c8F
 llm_utils/chat_format/utils.py,sha256=xTxN4HrLHcRO2PfCTR43nH1M5zCa7v0kTTdzAcGkZg0,1229
 llm_utils/group_messages.py,sha256=8CU9nKOja3xeuhdrX5CvYVveSqSKb2zQ0eeNzA88aTQ,3621
 llm_utils/lm/__init__.py,sha256=rX36_MsnekM5GHwWS56XELbm4W5x2TDwnPERDTfo0eU,194
-llm_utils/lm/async_lm.py,sha256=xUOlAOivxrI-KRaUv1V0l5y7ajYTKTA6QZkqpK6Uue8,35847
+llm_utils/lm/async_lm.py,sha256=kiWEecrkCTTQFlQj5JiHNziFeLOF1-7G_2xC2Dra1bw,35806
 llm_utils/lm/chat_html.py,sha256=FkGo0Dv_nAHYBMZzXfMu_bGQKaCx302goh3XaT-_ETc,8674
 llm_utils/lm/lm_json.py,sha256=fMt42phzFV2f6ulrtWcDXsWHi8WcG7gGkCzpIq8VSSM,1975
 llm_utils/lm/sync_lm.py,sha256=ANw_m5KiWcRwwoeQ5no6dzPFLc6j9o2oEcJtkMKqrn8,34640
@@ -31,7 +31,7 @@ speedy_utils/multi_worker/thread.py,sha256=u_hTwXh7_FciMa5EukdEA1fDCY_vUC4moDceB
 speedy_utils/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 speedy_utils/scripts/mpython.py,sha256=73PHm1jqbCt2APN4xuNjD0VDKwzOj4EZsViEMQiZU2g,3853
 speedy_utils/scripts/openapi_client_codegen.py,sha256=f2125S_q0PILgH5dyzoKRz7pIvNEjCkzpi4Q4pPFRZE,9683
-speedy_utils-1.0.24.dist-info/METADATA,sha256=p8QjQuz3u1B50Sj5qoPR0p-FsrYGvhd19dwfYM_MlwA,7442
-speedy_utils-1.0.24.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-speedy_utils-1.0.24.dist-info/entry_points.txt,sha256=T1t85jwx8fK6m5msdkBGIXH5R5Kd0zSL0S6erXERPzg,237
-speedy_utils-1.0.24.dist-info/RECORD,,
+speedy_utils-1.1.0.dist-info/METADATA,sha256=h1Alzm4q92GSiw5GNZWn6d8sHaSJS4X8RTMXStjkqHY,7441
+speedy_utils-1.1.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+speedy_utils-1.1.0.dist-info/entry_points.txt,sha256=T1t85jwx8fK6m5msdkBGIXH5R5Kd0zSL0S6erXERPzg,237
+speedy_utils-1.1.0.dist-info/RECORD,,