promptbuilder 0.4.37__tar.gz → 0.4.39__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {promptbuilder-0.4.37/promptbuilder.egg-info → promptbuilder-0.4.39}/PKG-INFO +2 -1
  2. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/agent/agent.py +16 -8
  3. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/llm_client/base_client.py +127 -20
  4. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/llm_client/google_client.py +2 -0
  5. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/llm_client/litellm_client.py +6 -6
  6. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/llm_client/logfire_decorators.py +6 -3
  7. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/llm_client/openai_client.py +56 -2
  8. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/llm_client/utils.py +155 -0
  9. {promptbuilder-0.4.37 → promptbuilder-0.4.39/promptbuilder.egg-info}/PKG-INFO +2 -1
  10. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder.egg-info/SOURCES.txt +0 -1
  11. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder.egg-info/requires.txt +1 -0
  12. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/setup.py +3 -2
  13. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/tests/test_timeout_google.py +2 -2
  14. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/tests/test_timeout_litellm.py +7 -3
  15. promptbuilder-0.4.37/promptbuilder/llm_client/vertex_client.py +0 -403
  16. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/LICENSE +0 -0
  17. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/MANIFEST.in +0 -0
  18. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/Readme.md +0 -0
  19. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/__init__.py +0 -0
  20. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/agent/__init__.py +0 -0
  21. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/agent/context.py +0 -0
  22. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/agent/tool.py +0 -0
  23. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/agent/utils.py +0 -0
  24. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/embeddings.py +0 -0
  25. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/llm_client/__init__.py +0 -0
  26. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/llm_client/aisuite_client.py +0 -0
  27. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/llm_client/anthropic_client.py +0 -0
  28. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/llm_client/bedrock_client.py +0 -0
  29. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/llm_client/config.py +0 -0
  30. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/llm_client/exceptions.py +0 -0
  31. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/llm_client/main.py +0 -0
  32. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/llm_client/types.py +0 -0
  33. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder/prompt_builder.py +0 -0
  34. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder.egg-info/dependency_links.txt +0 -0
  35. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/promptbuilder.egg-info/top_level.txt +0 -0
  36. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/pyproject.toml +0 -0
  37. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/setup.cfg +0 -0
  38. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/tests/test_llm_client.py +0 -0
  39. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/tests/test_llm_client_async.py +0 -0
  40. {promptbuilder-0.4.37 → promptbuilder-0.4.39}/tests/test_timeout_openai.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: promptbuilder
- Version: 0.4.37
+ Version: 0.4.39
  Summary: Library for building prompts for LLMs
  Home-page: https://github.com/kapulkin/promptbuilder
  Author: Kapulkin Stanislav
@@ -21,6 +21,7 @@ Requires-Dist: aioboto3
  Requires-Dist: litellm
  Requires-Dist: httpx
  Requires-Dist: aiohttp
+ Requires-Dist: tiktoken
  Dynamic: author
  Dynamic: author-email
  Dynamic: classifier
@@ -82,21 +82,24 @@ class AgentRouter(Agent[MessageType, ContextType]):
  )
  content = response.candidates[0].content

+ router_tool_contents = []
  for part in content.parts:
  if part.function_call is None:
  if part.text is not None:
- self.context.dialog_history.add_message(Content(parts=[Part(text=part.text, thought=part.thought)], role="model"))
+ router_tool_contents.append(Content(parts=[Part(text=part.text)], role="model"))
  else:
  tr_name = part.function_call.name
- args = part.function_call.args
- if args is None:
- args = {}
+ tr_args = part.function_call.args
+ if tr_args is None:
+ tr_args = {}

  route = self.routes.get(tr_name)
  if route is not None:
+ router_tool_contents = []
+
  self.last_used_tr_name = tr_name
- logger.debug("Route %s called with args: %s", tr_name, args)
- merged_args = {**kwargs, **args}
+ logger.debug("Route %s called with args: %s", tr_name, tr_args)
+ merged_args = {**kwargs, **tr_args}
  result = await route(**merged_args)
  logger.debug("Route %s result: %s", tr_name, result)
  trs_to_exclude = trs_to_exclude | {tr_name}
@@ -108,9 +111,14 @@ class AgentRouter(Agent[MessageType, ContextType]):
  tool = self.tools.get(tr_name)
  if tool is not None:
  self.last_used_tr_name = tr_name
+
+ for rtc in router_tool_contents:
+ self.context.dialog_history.add_message(rtc)
+ router_tool_contents = []
+
  self.context.dialog_history.add_message(content)
- logger.debug("Tool %s called with args: %s", tr_name, args)
- tool_response = await tool(**args)
+ logger.debug("Tool %s called with args: %s", tr_name, tr_args)
+ tool_response = await tool(**tr_args)
  logger.debug("Tool %s response: %s", tr_name, tool_response)
  self.context.dialog_history.add_message(tool_response.candidates[0].content)
  trs_to_exclude = trs_to_exclude | {tr_name}
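In plain terms, the agent change above stops writing the model's free-text parts into the dialog history immediately: they are buffered in router_tool_contents, discarded when a route handler fires, and flushed only right before a regular tool call is recorded. A minimal Python sketch of that buffering pattern (the history list and callback names here are illustrative, not part of the package API):

    from promptbuilder.llm_client.types import Content, Part

    pending: list[Content] = []   # plays the role of router_tool_contents
    history: list[Content] = []   # stand-in for self.context.dialog_history

    def on_text_part(text: str) -> None:
        # Text emitted alongside function calls is buffered, not stored yet.
        pending.append(Content(parts=[Part(text=text)], role="model"))

    def on_route_call() -> None:
        # A matched route discards any buffered text.
        pending.clear()

    def on_tool_call(tool_content: Content) -> None:
        # A tool call flushes the buffered text first, then records the tool turn.
        history.extend(pending)
        pending.clear()
        history.append(tool_content)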
@@ -134,6 +134,7 @@ class BaseLLMClient(ABC, utils.InheritDecoratorsMixin):
  @logfire_decorators.create
  @utils.retry_cls
  @utils.rpm_limit_cls
+ @utils.tpm_limit_cls
  @abstractmethod
  def _create(
  self,
@@ -252,13 +253,12 @@ class BaseLLMClient(ABC, utils.InheritDecoratorsMixin):
  if result_type is None:
  return response.text
  else:
- if result_type == "json":
+ if result_type == "json" and response.parsed is None:
  response.parsed = BaseLLMClient.as_json(response.text)
  return response.parsed

-
  @staticmethod
- def _append_generated_part(messages: list[Content], response: Response) -> Content | None:
+ def _responce_to_text(response: Response):
  assert(response.candidates and response.candidates[0].content), "Response must contain at least one candidate with content."

  text_parts = [
@@ -267,6 +267,7 @@ class BaseLLMClient(ABC, utils.InheritDecoratorsMixin):
  if text_parts is not None and len(text_parts) > 0:
  response_text = "".join(part.text for part in text_parts)
  is_thought = False
+ return response_text, is_thought
  else:
  thought_parts = [
  part for part in response.candidates[0].content.parts if part.text and part.thought
@@ -274,17 +275,28 @@ class BaseLLMClient(ABC, utils.InheritDecoratorsMixin):
  if thought_parts is not None and len(thought_parts) > 0:
  response_text = "".join(part.text for part in thought_parts)
  is_thought = True
+ return response_text, is_thought
  else:
- return None
+ return None, None
+
+ @staticmethod
+ def _append_to_message(message: Content, text: str, is_thought: bool):
+ if message.parts and message.parts[-1].text is not None and message.parts[-1].thought == is_thought:
+ message.parts[-1].text += text
+ else:
+ if not message.parts:
+ message.parts = []
+ message.parts.append(Part(text=text, thought=is_thought))
+
+ @staticmethod
+ def _append_generated_part(messages: list[Content], response: Response) -> Content | None:
+ response_text, is_thought = BaseLLMClient._responce_to_text(response)
+ if response_text is None:
+ return None

  if len(messages) > 0 and messages[-1].role == "model":
  message_to_append = messages[-1]
- if message_to_append.parts and message_to_append.parts[-1].text is not None and message_to_append.parts[-1].thought == is_thought:
- message_to_append.parts[-1].text += response_text
- else:
- if not message_to_append.parts:
- message_to_append.parts = []
- message_to_append.parts.append(Part(text=response_text, thought=is_thought))
+ BaseLLMClient._append_to_message(message_to_append, response_text, is_thought)
  else:
  messages.append(Content(parts=[Part(text=response_text, thought=is_thought)], role="model"))
  return messages[-1]
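The refactor splits the old _append_generated_part into two helpers: _responce_to_text pulls either the plain text parts or, failing that, the thought parts out of a Response, and _append_to_message merges new text into the last Part when its thought flag matches, otherwise appending a fresh Part. A rough illustration of the merging behaviour, assuming the Part/Content models from promptbuilder.llm_client.types and the BaseLLMClient statics introduced above:

    from promptbuilder.llm_client.types import Content, Part

    msg = Content(parts=[Part(text="Hello, ", thought=False)], role="model")

    # Same thought flag as the trailing part: text is concatenated in place.
    BaseLLMClient._append_to_message(msg, "world!", False)
    assert msg.parts[-1].text == "Hello, world!"

    # Different thought flag: a new Part is appended instead of being merged.
    BaseLLMClient._append_to_message(msg, "thinking...", True)
    assert len(msg.parts) == 2 and msg.parts[-1].thought is True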
@@ -296,6 +308,7 @@ class BaseLLMClient(ABC, utils.InheritDecoratorsMixin):
  @logfire_decorators.create_stream
  @utils.retry_cls
  @utils.rpm_limit_cls
+ @utils.tpm_limit_cls
  def _create_stream(
  self,
  messages: list[Content],
@@ -539,6 +552,7 @@ class BaseLLMClientAsync(ABC, utils.InheritDecoratorsMixin):
  @logfire_decorators.create_async
  @utils.retry_cls_async
  @utils.rpm_limit_cls_async
+ @utils.tpm_limit_cls_async
  @abstractmethod
  async def _create(
  self,
@@ -656,13 +670,14 @@ class BaseLLMClientAsync(ABC, utils.InheritDecoratorsMixin):
  if result_type is None:
  return response.text
  else:
- if result_type == "json":
+ if result_type == "json" and response.parsed is None:
  response.parsed = BaseLLMClient.as_json(response.text)
  return response.parsed

  @logfire_decorators.create_stream_async
  @utils.retry_cls_async
  @utils.rpm_limit_cls_async
+ @utils.tpm_limit_cls_async
  async def _create_stream(
  self,
  messages: list[Content],
@@ -819,19 +834,65 @@ class CachedLLMClient(BaseLLMClient):
  self.llm_client = llm_client
  self.cache_dir = cache_dir

- def _create(self, messages: list[Content], **kwargs) -> Response:
- response, messages_dump, cache_path = CachedLLMClient.create_cached(self.llm_client, self.cache_dir, messages, **kwargs)
+ def _create(self, messages: list[Content], system_message: str | None = None, **kwargs) -> Response:
+ response, messages_dump, cache_path = CachedLLMClient.create_cached(self.llm_client, self.cache_dir, messages, system_message, **kwargs)
  if response is not None:
  return response
- response = self.llm_client.create(messages, **kwargs)
+ response = self.llm_client.create(messages, system_message=system_message, **kwargs)
  CachedLLMClient.save_cache(cache_path, self.llm_client.full_model_name, messages_dump, response)
  return response

+
+ def _create_stream(
+ self,
+ messages: list[Content],
+ *,
+ thinking_config: ThinkingConfig | None = None,
+ system_message: str | None = None,
+ max_tokens: int | None = None,
+ ) -> Iterator[Response]:
+ response, messages_dump, cache_path = CachedLLMClient.create_cached(
+ self.llm_client, self.cache_dir, messages,
+ thinking_config=thinking_config,
+ system_message=system_message,
+ max_tokens=max_tokens,
+ )
+ if response is not None:
+ yield response
+ return
+
+ accumulated_content: Content | None = None
+ final_response: Response | None = None
+
+ for response in self.llm_client._create_stream(
+ messages=messages,
+ thinking_config=thinking_config,
+ system_message=system_message,
+ max_tokens=max_tokens,
+ ):
+ # Accumulate content from each response chunk
+ if response.candidates and response.candidates[0].content:
+ response_text, is_thought = BaseLLMClient._responce_to_text(response)
+ if response_text is not None:
+ if accumulated_content is None:
+ accumulated_content = Content(parts=[], role="model")
+ BaseLLMClient._append_to_message(accumulated_content, response_text, is_thought or False)
+ final_response = response
+ yield response
+
+ # Save accumulated response to cache
+ if final_response is not None and accumulated_content is not None and final_response.candidates:
+ cached_response = Response(
+ candidates=[final_response.candidates[0].model_copy(update={"content": accumulated_content})],
+ usage_metadata=final_response.usage_metadata,
+ )
+ CachedLLMClient.save_cache(cache_path, self.llm_client.full_model_name, messages_dump, cached_response)
+
  @staticmethod
- def create_cached(llm_client: BaseLLMClient | BaseLLMClientAsync, cache_dir: str, messages: list[Content], **kwargs) -> tuple[Response | None, list[dict], str]:
+ def create_cached(llm_client: BaseLLMClient | BaseLLMClientAsync, cache_dir: str, messages: list[Content], system_message: str | None = None, **kwargs) -> tuple[Response | None, list[dict], str]:
  messages_dump = [message.model_dump() for message in messages]
  key = hashlib.sha256(
- json.dumps((llm_client.full_model_name, messages_dump)).encode()
+ json.dumps((llm_client.full_model_name, messages_dump, system_message)).encode()
  ).hexdigest()
  cache_path = os.path.join(cache_dir, f"{key}.json")
  if os.path.exists(cache_path):
@@ -855,7 +916,7 @@ class CachedLLMClient(BaseLLMClient):
  @staticmethod
  def save_cache(cache_path: str, full_model_name: str, messages_dump: list[dict], response: Response):
  with open(cache_path, 'wt') as f:
- json.dump({"full_model_name": full_model_name, "request": messages_dump, "response": response.model_dump()}, f, indent=4)
+ json.dump({"full_model_name": full_model_name, "request": messages_dump, "response": Response.model_dump(response, mode="json")}, f, indent=4)


  class CachedLLMClientAsync(BaseLLMClientAsync):
@@ -869,10 +930,56 @@ class CachedLLMClientAsync(BaseLLMClientAsync):
  self.llm_client = llm_client
  self.cache_dir = cache_dir

- async def _create(self, messages: list[Content], **kwargs) -> Response:
- response, messages_dump, cache_path = CachedLLMClient.create_cached(self.llm_client, self.cache_dir, messages, **kwargs)
+ async def _create(self, messages: list[Content], system_message: str | None = None, **kwargs) -> Response:
+ response, messages_dump, cache_path = CachedLLMClient.create_cached(self.llm_client, self.cache_dir, messages, system_message, **kwargs)
  if response is not None:
  return response
- response = await self.llm_client.create(messages, **kwargs)
+ response = await self.llm_client.create(messages, system_message=system_message, **kwargs)
  CachedLLMClient.save_cache(cache_path, self.llm_client.full_model_name, messages_dump, response)
  return response
+
+
+ async def _create_stream(
+ self,
+ messages: list[Content],
+ *,
+ thinking_config: ThinkingConfig | None = None,
+ system_message: str | None = None,
+ max_tokens: int | None = None,
+ ) -> AsyncIterator[Response]:
+ response, messages_dump, cache_path = CachedLLMClient.create_cached(
+ self.llm_client, self.cache_dir, messages,
+ thinking_config=thinking_config,
+ system_message=system_message,
+ max_tokens=max_tokens,
+ )
+ if response is not None:
+ yield response
+ return
+
+ accumulated_content: Content | None = None
+ final_response: Response | None = None
+
+ async for response in self.llm_client._create_stream(
+ messages=messages,
+ thinking_config=thinking_config,
+ system_message=system_message,
+ max_tokens=max_tokens,
+ ):
+ # Accumulate content from each response chunk
+ if response.candidates and response.candidates[0].content:
+ response_text, is_thought = BaseLLMClient._responce_to_text(response)
+ if response_text is not None:
+ if accumulated_content is None:
+ accumulated_content = Content(parts=[], role="model")
+ BaseLLMClient._append_to_message(accumulated_content, response_text, is_thought or False)
+ final_response = response
+ yield response
+
+ # Save accumulated response to cache
+ if final_response is not None and accumulated_content is not None and final_response.candidates:
+ cached_response = Response(
+ candidates=[final_response.candidates[0].model_copy(update={"content": accumulated_content})],
+ usage_metadata=final_response.usage_metadata,
+ )
+ CachedLLMClient.save_cache(cache_path, self.llm_client.full_model_name, messages_dump, cached_response)
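With these overrides, the cached wrappers now also cover streaming: on a cache miss the chunks from the wrapped client are yielded through unchanged while their text is re-assembled via _append_to_message, and the final accumulated Response is written to the cache keyed by model, messages, and system message; on a hit the single cached Response is yielded once. A hedged usage sketch in Python (the constructor keywords and the .text accessor are assumed from this diff, not from documented API):

    from promptbuilder.llm_client.base_client import CachedLLMClient
    from promptbuilder.llm_client.types import Content, Part

    # some_llm_client is any configured BaseLLMClient instance (hypothetical placeholder).
    cached = CachedLLMClient(llm_client=some_llm_client, cache_dir=".llm_cache")  # assumed kwargs

    messages = [Content(parts=[Part(text="Summarize the changelog")], role="user")]
    for chunk in cached.create_stream(messages):
        print(chunk.text, end="")  # first run: live chunks; a later identical run: one cached Response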
@@ -123,6 +123,7 @@ class GoogleLLMClient(BaseLLMClient):
  config=config,
  )
  elif result_type == "json":
+ config.response_mime_type = "application/json"
  response = self.client.models.generate_content(
  model=self.model,
  contents=messages,
@@ -273,6 +274,7 @@ class GoogleLLMClientAsync(BaseLLMClientAsync):
  config.thinking_config = thinking_config

  if result_type is None or result_type == "json":
+ config.response_mime_type = "application/json"
  return await self.client.aio.models.generate_content(
  model=self.model,
  contents=messages,
@@ -241,7 +241,7 @@ class LiteLLMClient(BaseLLMClient):
  finish_reason_val = first_choice.get("finish_reason")
  else:
  finish_reason_val = getattr(first_choice, "finish_reason", None)
- mapped_finish_reason = LiteLLMLLMClient._map_finish_reason(finish_reason_val)
+ mapped_finish_reason = LiteLLMClient._map_finish_reason(finish_reason_val)

  content_parts: list[Part | Any] = list(parts)
  return Response(
@@ -293,7 +293,7 @@ class LiteLLMClient(BaseLLMClient):
  finish_reason_val = first_choice.get("finish_reason")
  else:
  finish_reason_val = getattr(first_choice, "finish_reason", None)
- mapped_finish_reason = LiteLLMLLMClient._map_finish_reason(finish_reason_val)
+ mapped_finish_reason = LiteLLMClient._map_finish_reason(finish_reason_val)

  content_parts2: list[Part | Any] = list(parts)
  return Response(
@@ -460,11 +460,11 @@ class LiteLLMClientAsync(BaseLLMClientAsync):

  @staticmethod
  def make_function_call(tool_call) -> FunctionCall | None:
- return LiteLLMLLMClient.make_function_call(tool_call)
+ return LiteLLMClient.make_function_call(tool_call)

  @staticmethod
  def make_usage_metadata(usage) -> UsageMetadata:
- return LiteLLMLLMClient.make_usage_metadata(usage)
+ return LiteLLMClient.make_usage_metadata(usage)

  async def _create(
  self,
@@ -569,7 +569,7 @@ class LiteLLMClientAsync(BaseLLMClientAsync):
  finish_reason_val = first_choice.get("finish_reason")
  else:
  finish_reason_val = getattr(first_choice, "finish_reason", None)
- mapped_finish_reason = LiteLLMLLMClient._map_finish_reason(finish_reason_val)
+ mapped_finish_reason = LiteLLMClient._map_finish_reason(finish_reason_val)

  content_parts3: list[Part | Any] = list(parts)
  return Response(
@@ -621,7 +621,7 @@ class LiteLLMClientAsync(BaseLLMClientAsync):
  finish_reason_val = first_choice.get("finish_reason")
  else:
  finish_reason_val = getattr(first_choice, "finish_reason", None)
- mapped_finish_reason = LiteLLMLLMClient._map_finish_reason(finish_reason_val)
+ mapped_finish_reason = LiteLLMClient._map_finish_reason(finish_reason_val)

  content_parts4: list[Part | Any] = list(parts)
  return Response(
@@ -46,9 +46,12 @@ def extract_response_data(response: Response) -> dict[str, Any]:
  response_data = {"message": {"role": "assistant"}}
  response_data["message"]["content"] = response.text
  tool_calls = []
- for part in response.candidates[0].content.parts:
- if part.function_call is not None:
- tool_calls.append({"function": {"name": part.function_call.name, "arguments": part.function_call.args}})
+ if response.candidates is not None and len(response.candidates) > 0:
+ content = response.candidates[0].content
+ if content is not None and content.parts is not None:
+ for part in content.parts:
+ if part.function_call is not None:
+ tool_calls.append({"function": {"name": part.function_call.name, "arguments": part.function_call.args}})
  if len(tool_calls) > 0:
  response_data["message"]["tool_calls"] = tool_calls
  return response_data
@@ -205,7 +205,7 @@ class OpenaiLLMClient(BaseLLMClient):
  elif tool_choice_mode == "ANY":
  openai_kwargs["tool_choice"] = "required"

- if result_type is None or result_type == "json":
+ if result_type is None:
  # Forward timeout to OpenAI per-request if provided
  if timeout is not None:
  openai_kwargs["timeout"] = timeout
@@ -222,6 +222,33 @@ class OpenaiLLMClient(BaseLLMClient):
  elif output_item.type == "function_call":
  parts.append(Part(function_call=FunctionCall(args=json.loads(output_item.arguments), name=output_item.name)))

+ return Response(
+ candidates=[Candidate(content=Content(parts=parts, role="model"))],
+ usage_metadata=UsageMetadata(
+ candidates_token_count=response.usage.output_tokens,
+ prompt_token_count=response.usage.input_tokens,
+ total_token_count=response.usage.total_tokens,
+ )
+ )
+ elif result_type == "json":
+ # Forward timeout to OpenAI per-request if provided
+ if timeout is not None:
+ openai_kwargs["timeout"] = timeout
+ response = self.client.responses.create(**openai_kwargs, text={ "format" : { "type": "json_object" } })
+
+ response_text = ""
+ parts: list[Part] = []
+ for output_item in response.output:
+ if output_item.type == "message":
+ for content in output_item.content:
+ parts.append(Part(text=content.text))
+ response_text += content.text
+ elif output_item.type == "reasoning":
+ for summary in output_item.summary:
+ parts.append(Part(text=summary.text, thought=True))
+ elif output_item.type == "function_call":
+ parts.append(Part(function_call=FunctionCall(args=json.loads(output_item.arguments), name=output_item.name)))
+
  return Response(
  candidates=[Candidate(content=Content(parts=parts, role="model"))],
  usage_metadata=UsageMetadata(
@@ -229,6 +256,7 @@ class OpenaiLLMClient(BaseLLMClient):
  prompt_token_count=response.usage.input_tokens,
  total_token_count=response.usage.total_tokens,
  ),
+ parsed=BaseLLMClient.as_json(response_text)
  )
  elif isinstance(result_type, type(BaseModel)):
  if timeout is not None:
@@ -453,7 +481,7 @@ class OpenaiLLMClientAsync(BaseLLMClientAsync):
  elif tool_choice_mode == "ANY":
  openai_kwargs["tool_choice"] = "required"

- if result_type is None or result_type == "json":
+ if result_type is None:
  if timeout is not None:
  openai_kwargs["timeout"] = timeout
  response = await self.client.responses.create(**openai_kwargs)
@@ -476,6 +504,32 @@ class OpenaiLLMClientAsync(BaseLLMClientAsync):
  total_token_count=response.usage.total_tokens,
  ),
  )
+ elif result_type == "json":
+ if timeout is not None:
+ openai_kwargs["timeout"] = timeout
+ response = await self.client.responses.create(**openai_kwargs, text={ "format" : { "type": "json_object" } })
+ parts: list[Part] = []
+ response_text = ""
+ for output_item in response.output:
+ if output_item.type == "message":
+ for content in output_item.content:
+ parts.append(Part(text=content.text))
+ response_text += content.text
+ elif output_item.type == "reasoning":
+ for summary in output_item.summary:
+ parts.append(Part(text=summary.text, thought=True))
+ elif output_item.type == "function_call":
+ parts.append(Part(function_call=FunctionCall(args=json.loads(output_item.arguments), name=output_item.name)))
+
+ return Response(
+ candidates=[Candidate(content=Content(parts=parts, role="model"))],
+ usage_metadata=UsageMetadata(
+ candidates_token_count=response.usage.output_tokens,
+ prompt_token_count=response.usage.input_tokens,
+ total_token_count=response.usage.total_tokens,
+ ),
+ parsed=BaseLLMClient.as_json(response_text)
+ )
  elif isinstance(result_type, type(BaseModel)):
  if timeout is not None:
  openai_kwargs["timeout"] = timeout
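Both OpenAI clients now give result_type == "json" its own branch: the request goes through the Responses API with text={"format": {"type": "json_object"}}, and response.parsed is pre-filled via BaseLLMClient.as_json, which is why the base-class check was relaxed to re-parse only when parsed is still None. A hedged Python sketch of the call shape (the public create wrapper and the .parsed attribute are inferred from this diff, and the client construction is omitted):

    from promptbuilder.llm_client.types import Content, Part

    messages = [Content(parts=[Part(text='Reply with JSON: {"status": "ok"}')], role="user")]
    # openai_client is an OpenaiLLMClient instance (hypothetical here).
    response = openai_client.create(messages, result_type="json")  # assumed public wrapper of _create
    print(response.parsed)  # dict produced by BaseLLMClient.as_json(...), no second parse needed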
@@ -4,10 +4,14 @@ import logging
  import traceback
  from functools import wraps
  from typing import Callable, Awaitable, ParamSpec, TypeVar
+ import tiktoken
  from collections import defaultdict

  from pydantic import BaseModel

+ from promptbuilder.llm_client.types import Content
+
+

  logger = logging.getLogger(__name__)

@@ -48,9 +52,14 @@ class RetryConfig(BaseModel):
  class RpmLimitConfig(BaseModel):
  rpm_limit: int = 0

+ class TpmLimitConfig(BaseModel):
+ tpm_limit: int = 0
+ fast: bool = False
+
  class DecoratorConfigs(BaseModel):
  retry: RetryConfig | None = None
  rpm_limit: RpmLimitConfig | None = None
+ tpm_limit: TpmLimitConfig | None = None


  @inherited_decorator
@@ -181,3 +190,149 @@ def rpm_limit_cls_async(class_method: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
  self._last_request_time = time.time()
  return await class_method(self, *args, **kwargs)
  return wrapper
+
+
+ def _estimate_input_tokens_from_messages(self, messages: list[Content], fast: bool = False) -> int:
+ """Estimate input tokens for a list[Content] using best available method.
+
+ Priority:
+ 1) If provider == "google" and a google.genai client is available, use
+ models.count_tokens for accurate counts.
+ 2) If tiktoken is installed, approximate with a BPE encoding.
+ 3) Fallback heuristic: ~4 characters per token across text parts.
+ """
+ if not messages:
+ return 0
+
+ # Collect text parts for non-Google fallback methods
+ texts: list[str] = []
+ for m in messages:
+ parts = m.parts
+ if not parts:
+ continue
+ for part in parts:
+ text = part.text
+ if text:
+ texts.append(text)
+
+ if not fast:
+ # 1) Google Gemini accurate count via genai API (when provider == google)
+ if self.provider == "google":
+ genai_client = self.client
+ contents_arg = "\n".join(texts)
+ total_tokens = genai_client.models.count_tokens(
+ model=self.model,
+ contents=contents_arg,
+ ).total_tokens
+ return total_tokens
+
+ # 2) tiktoken approximation
+ # cl100k_base is a good default for many chat models
+ enc = tiktoken.get_encoding("cl100k_base")
+ return sum(len(enc.encode(t)) for t in texts)
+
+ else:
+ # 3) Heuristic fallback
+ total_chars = sum(len(t) for t in texts)
+ tokens = total_chars // 4
+ return tokens if tokens > 0 else (1 if total_chars > 0 else 0)
+
+
+ @inherited_decorator
+ def tpm_limit_cls(class_method: Callable[P, T]) -> Callable[P, T]:
+ """
+ Decorator that limits the number of input tokens per minute to the decorated class methods.
+ Decorated methods must have 'self' as its first arg and accept a 'messages' argument
+ either positionally (first arg) or by keyword.
+
+ The decorator estimates tokens from input messages and ensures the total tokens
+ sent within a 60-second window do not exceed the configured TPM limit. If the
+ limit would be exceeded, it waits until the window resets.
+ """
+ @wraps(class_method)
+ def wrapper(self, *args, **kwargs):
+ if not hasattr(self, "_decorator_configs"):
+ self._decorator_configs = DecoratorConfigs()
+ if getattr(self._decorator_configs, "tpm_limit", None) is None:
+ self._decorator_configs.tpm_limit = TpmLimitConfig()
+
+ limit = self._decorator_configs.tpm_limit.tpm_limit
+ if limit <= 0:
+ return class_method(self, *args, **kwargs)
+
+ # Extract messages from either kwargs or positional args
+ messages = kwargs.get("messages") if "messages" in kwargs else (args[0] if len(args) > 0 else None)
+ tokens_needed = _estimate_input_tokens_from_messages(self, messages, self._decorator_configs.tpm_limit.fast)
+
+ # Initialize sliding window state
+ now = time.time()
+ if not hasattr(self, "_tpm_window_start"):
+ self._tpm_window_start = now
+ if not hasattr(self, "_tpm_used_tokens"):
+ self._tpm_used_tokens = 0
+
+ while True:
+ now = time.time()
+ elapsed = now - self._tpm_window_start
+ if elapsed >= 60:
+ # Reset window
+ self._tpm_window_start = now
+ self._tpm_used_tokens = 0
+
+ if self._tpm_used_tokens + tokens_needed <= limit:
+ self._tpm_used_tokens += tokens_needed
+ break
+ # Need to wait until window resets
+ sleep_for = max(0.0, 60 - elapsed)
+ if sleep_for > 0:
+ time.sleep(sleep_for)
+ continue
+ # If sleep_for == 0, loop will reset on next iteration
+
+ return class_method(self, *args, **kwargs)
+ return wrapper
+
+
+ @inherited_decorator
+ def tpm_limit_cls_async(class_method: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
+ """
+ Async variant of TPM limiter.
+ """
+ @wraps(class_method)
+ async def wrapper(self, *args, **kwargs):
+ if not hasattr(self, "_decorator_configs"):
+ self._decorator_configs = DecoratorConfigs()
+ if getattr(self._decorator_configs, "tpm_limit", None) is None:
+ self._decorator_configs.tpm_limit = TpmLimitConfig()
+
+ limit = self._decorator_configs.tpm_limit.tpm_limit
+ if limit <= 0:
+ return await class_method(self, *args, **kwargs)
+
+ messages = kwargs.get("messages") if "messages" in kwargs else (args[0] if len(args) > 0 else None)
+ tokens_needed = _estimate_input_tokens_from_messages(self, messages, self._decorator_configs.tpm_limit.fast)
+
+ now = time.time()
+ if not hasattr(self, "_tpm_window_start"):
+ self._tpm_window_start = now
+ if not hasattr(self, "_tpm_used_tokens"):
+ self._tpm_used_tokens = 0
+
+ while True:
+ now = time.time()
+ elapsed = now - self._tpm_window_start
+ if elapsed >= 60:
+ self._tpm_window_start = now
+ self._tpm_used_tokens = 0
+
+ if self._tpm_used_tokens + tokens_needed <= limit:
+ self._tpm_used_tokens += tokens_needed
+ break
+
+ sleep_for = max(0.0, 60 - elapsed)
+ if sleep_for > 0:
+ await asyncio.sleep(sleep_for)
+ continue
+
+ return await class_method(self, *args, **kwargs)
+ return wrapper
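The new tpm_limit_cls / tpm_limit_cls_async decorators sit next to the existing RPM limiter on every _create and _create_stream and throttle on estimated input tokens per 60-second window: an exact count via count_tokens for Google clients, a tiktoken cl100k_base approximation otherwise, or a ~4-characters-per-token heuristic when fast is enabled. A hedged Python sketch of turning it on (the config import path and the client constructor keyword are assumed from how decorator_configs appears elsewhere in this diff; the model name is illustrative):

    from promptbuilder.llm_client.utils import DecoratorConfigs, TpmLimitConfig  # assumed import path
    from promptbuilder.llm_client.google_client import GoogleLLMClient

    configs = DecoratorConfigs(
        tpm_limit=TpmLimitConfig(
            tpm_limit=200_000,  # cap on estimated input tokens per 60-second window
            fast=True,          # skip count_tokens/tiktoken and use the chars-per-token heuristic
        ),
    )
    # Clients are assumed to accept decorator_configs as a keyword, as in the removed Vertex client:
    client = GoogleLLMClient(model="gemini-2.0-flash", decorator_configs=configs)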
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: promptbuilder
- Version: 0.4.37
+ Version: 0.4.39
  Summary: Library for building prompts for LLMs
  Home-page: https://github.com/kapulkin/promptbuilder
  Author: Kapulkin Stanislav
@@ -21,6 +21,7 @@ Requires-Dist: aioboto3
  Requires-Dist: litellm
  Requires-Dist: httpx
  Requires-Dist: aiohttp
+ Requires-Dist: tiktoken
  Dynamic: author
  Dynamic: author-email
  Dynamic: classifier
@@ -30,7 +30,6 @@ promptbuilder/llm_client/main.py
  promptbuilder/llm_client/openai_client.py
  promptbuilder/llm_client/types.py
  promptbuilder/llm_client/utils.py
- promptbuilder/llm_client/vertex_client.py
  tests/test_llm_client.py
  tests/test_llm_client_async.py
  tests/test_timeout_google.py
@@ -8,3 +8,4 @@ aioboto3
  litellm
  httpx
  aiohttp
+ tiktoken
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages

  setup(
  name="promptbuilder",
- version="0.4.37",
+ version="0.4.39",
  packages=find_packages(),
  install_requires=[
  "pydantic",
@@ -14,7 +14,8 @@ setup(
  "aioboto3",
  "litellm",
  "httpx",
- "aiohttp"
+ "aiohttp",
+ "tiktoken"
  ],
  author="Kapulkin Stanislav",
  author_email="kapulkin@gmail.com",
@@ -36,7 +36,7 @@ def test_google_timeout_forwarded_sync(monkeypatch):
  cfg = rec.get("last_config")
  assert cfg is not None
  assert cfg.http_options is not None
- assert int(cfg.http_options.timeout) == 12
+ assert int(cfg.http_options.timeout) == 12000 # Google API expects milliseconds


  class _FakeAioGoogleModels:
@@ -75,4 +75,4 @@ async def test_google_timeout_forwarded_async(monkeypatch):
  cfg = rec.get("last_config_async")
  assert cfg is not None
  assert cfg.http_options is not None
- assert int(cfg.http_options.timeout) == 8
+ assert int(cfg.http_options.timeout) == 8500 # Google API expects milliseconds
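The updated expectations reflect that the Google client now forwards the per-request timeout as the milliseconds that http_options.timeout expects rather than raw seconds; the test bodies are not shown here, but the asserted values are consistent with a conversion like the following sketch (timeout values assumed):

    timeout_seconds = 8.5                               # assumed value passed in the async test
    http_options_timeout = int(timeout_seconds * 1000)  # 8500, matching the assertion above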
@@ -1,7 +1,7 @@
  import pytest
  from pydantic import BaseModel
+ import litellm

- import promptbuilder.llm_client.litellm_client as litellm_mod
  from promptbuilder.llm_client.litellm_client import LiteLLMClient, LiteLLMClientAsync
  from promptbuilder.llm_client.types import Content, Part

@@ -14,8 +14,10 @@ def test_litellm_timeout_forwarded_sync(monkeypatch):
  def __init__(self):
  self.choices = []
  self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+ def get(self, key, default=None):
+ return getattr(self, key, default)
  return R()
- monkeypatch.setattr(litellm_mod, "completion", fake_completion)
+ monkeypatch.setattr(litellm, "completion", fake_completion)

  cli = LiteLLMClient(full_model_name="ollama:llama3.1", api_key=None)
  _ = cli.create([Content(parts=[Part(text="hi")], role="user")], timeout=7.5)
@@ -33,8 +35,10 @@ async def test_litellm_timeout_forwarded_async(monkeypatch):
  def __init__(self):
  self.choices = []
  self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+ def get(self, key, default=None):
+ return getattr(self, key, default)
  return R()
- monkeypatch.setattr(litellm_mod, "acompletion", fake_acompletion)
+ monkeypatch.setattr(litellm, "acompletion", fake_acompletion)

  cli = LiteLLMClientAsync(full_model_name="ollama:llama3.1", api_key=None)
  _ = await cli.create([Content(parts=[Part(text="hi")], role="user")], timeout=5.0)
@@ -1,403 +0,0 @@
- import os
- import importlib
- from functools import wraps
- from typing import AsyncIterator, Iterator, Callable, ParamSpec, Awaitable, Any, cast
-
- from pydantic import BaseModel, ConfigDict
- from tenacity import RetryError
-
- from vertexai import init as vertex_init
- from vertexai.generative_models import GenerativeModel
-
- from promptbuilder.llm_client.base_client import BaseLLMClient, BaseLLMClientAsync, ResultType
- from promptbuilder.llm_client.types import (
- Response,
- Content,
- Candidate,
- UsageMetadata,
- Part,
- PartLike,
- ApiKey,
- ThinkingConfig,
- Tool,
- ToolConfig,
- Model,
- CustomApiKey,
- )
- from promptbuilder.llm_client.config import DecoratorConfigs
- from promptbuilder.llm_client.utils import inherited_decorator
- from promptbuilder.llm_client.exceptions import APIError
-
-
- P = ParamSpec("P")
-
-
- class VertexApiKey(BaseModel, CustomApiKey):
- model_config = ConfigDict(frozen=True)
- project: str
- location: str
-
-
- @inherited_decorator
- def _error_handler(func: Callable[P, Response]) -> Callable[P, Response]:
- @wraps(func)
- def wrapper(*args, **kwargs):
- try:
- return func(*args, **kwargs)
- except RetryError as retry_error:
- e = retry_error.last_attempt._exception
- if e is None:
- raise APIError()
- code = getattr(e, "code", None)
- response_json = {
- "status": getattr(e, "status", None),
- "message": str(e),
- }
- response = getattr(e, "response", None)
- raise APIError(code, response_json, response)
- except Exception as e: # noqa: BLE001
- raise APIError(None, {"status": None, "message": str(e)}, None)
- return wrapper
-
-
- def _to_vertex_content(messages: list[Content]):
- gen_mod = importlib.import_module("vertexai.generative_models")
- VPart = getattr(gen_mod, "Part")
- VContent = getattr(gen_mod, "Content")
- v_messages: list[Any] = []
- for m in messages:
- v_parts: list[Any] = []
- if m.parts:
- for p in m.parts:
- if p.text is not None:
- v_parts.append(VPart.from_text(p.text))
- elif p.inline_data is not None and p.inline_data.data is not None:
- v_parts.append(VPart.from_bytes(data=p.inline_data.data, mime_type=p.inline_data.mime_type or "application/octet-stream"))
- v_messages.append(VContent(role=m.role, parts=v_parts))
- return v_messages
-
-
- def _tool_to_vertex(tool: Tool):
- VTool = getattr(importlib.import_module("vertexai.generative_models"), "Tool")
- if not tool.function_declarations:
- return VTool(function_declarations=[])
- fds = []
- for fd in tool.function_declarations:
- fds.append({
- "name": fd.name,
- "description": fd.description,
- "parameters": fd.parameters.model_dump() if fd.parameters is not None else None,
- "response": fd.response.model_dump() if fd.response is not None else None,
- })
- return VTool(function_declarations=fds)
-
-
- def _tool_config_to_vertex(cfg: ToolConfig | None):
- VToolConfig = getattr(importlib.import_module("vertexai.generative_models"), "ToolConfig")
- if cfg is None or cfg.function_calling_config is None:
- return None
- mode = cfg.function_calling_config.mode or "AUTO"
- allowed = cfg.function_calling_config.allowed_function_names
- return VToolConfig(function_calling_config={"mode": mode, "allowedFunctionNames": allowed})
-
-
- def _from_vertex_response(v_resp: Any) -> Response:
- candidates: list[Candidate] = []
- if getattr(v_resp, "candidates", None):
- for c in v_resp.candidates:
- parts: list[Part] = []
- if c.content and getattr(c.content, "parts", None):
- for vp in c.content.parts:
- t = getattr(vp, "text", None)
- if isinstance(t, str):
- parts.append(Part(text=t))
- candidates.append(Candidate(content=Content(parts=cast(list[Part | PartLike], parts), role="model")))
-
- usage = None
- um = getattr(v_resp, "usage_metadata", None)
- if um is not None:
- usage = UsageMetadata(
- cached_content_token_count=getattr(um, "cached_content_token_count", None),
- candidates_token_count=getattr(um, "candidates_token_count", None),
- prompt_token_count=getattr(um, "prompt_token_count", None),
- thoughts_token_count=getattr(um, "thoughts_token_count", None),
- total_token_count=getattr(um, "total_token_count", None),
- )
-
- return Response(candidates=candidates, usage_metadata=usage)
-
-
- class VertexLLMClient(BaseLLMClient):
- PROVIDER: str = "vertexai"
-
- def __init__(
- self,
- model: str,
- api_key: ApiKey | None = None,
- decorator_configs: DecoratorConfigs | None = None,
- default_thinking_config: ThinkingConfig | None = None,
- default_max_tokens: int | None = None,
- project: str | None = None,
- location: str | None = None,
- **kwargs,
- ):
- # Resolve project/location from args or env
- project = project or os.getenv("VERTEXAI_PROJECT") or os.getenv("GOOGLE_CLOUD_PROJECT") or os.getenv("GCLOUD_PROJECT")
- location = location or os.getenv("VERTEXAI_LOCATION") or os.getenv("GOOGLE_CLOUD_REGION") or os.getenv("GOOGLE_CLOUD_LOCATION")
-
- # Allow API Key (string) or ADC (VertexApiKey)
- api_key_str: str | None = None
- if isinstance(api_key, str):
- api_key_str = api_key
- elif api_key is None:
- # Fallback to env vars for API key
- api_key_str = os.getenv("VERTEX_API_KEY") or os.getenv("GOOGLE_API_KEY")
- elif isinstance(api_key, VertexApiKey):
- # ADC path with explicit project/location
- pass
- else:
- # Unexpected CustomApiKey subtype
- raise ValueError("Unsupported api_key type for Vertex: expected str or VertexApiKey")
-
- if not project or not location:
- raise ValueError("To create a vertexai llm client you need to provide project and location via args or env vars VERTEXAI_PROJECT and VERTEXAI_LOCATION")
-
- if not isinstance(api_key, VertexApiKey):
- api_key = VertexApiKey(project=project, location=location)
-
- super().__init__(
- VertexLLMClient.PROVIDER,
- model,
- decorator_configs=decorator_configs,
- default_thinking_config=default_thinking_config,
- default_max_tokens=default_max_tokens,
- )
- self._api_key = api_key
- self._api_key_str = api_key_str
-
- vertex_init(project=self._api_key.project, location=self._api_key.location)
- self._model = GenerativeModel(self.model)
-
- @property
- def api_key(self) -> VertexApiKey:
- return self._api_key
-
- @_error_handler
- def _create(
- self,
- messages: list[Content],
- result_type: ResultType = None,
- *,
- thinking_config: ThinkingConfig | None = None,
- system_message: str | None = None,
- max_tokens: int | None = None,
- timeout: float | None = None,
- tools: list[Tool] | None = None,
- tool_config: ToolConfig = ToolConfig(),
- ) -> Response:
- v_messages = _to_vertex_content(messages)
- GenerationConfig = getattr(importlib.import_module("vertexai.generative_models"), "GenerationConfig")
- gen_cfg = GenerationConfig(max_output_tokens=max_tokens or self.default_max_tokens)
-
- # Handle thinking config
- if thinking_config is None:
- thinking_config = self.default_thinking_config
- if thinking_config is not None:
- # Vertex AI supports thinking via response_logprobs and logprobs parameters
- # but the exact implementation may vary - for now, we'll store it for potential future use
- pass
-
- req_opts: dict[str, Any] | None = {}
- if timeout is not None:
- req_opts["timeout"] = timeout
- if self._api_key_str:
- req_opts["api_key"] = self._api_key_str
- if not req_opts:
- req_opts = None
-
- v_tools = None
- if tools is not None:
- v_tools = [_tool_to_vertex(t) for t in tools]
- v_tool_cfg = _tool_config_to_vertex(tool_config)
-
- v_resp = self._model.generate_content(
- contents=v_messages,
- generation_config=gen_cfg,
- tools=v_tools,
- tool_config=v_tool_cfg,
- system_instruction=system_message,
- request_options=req_opts,
- )
-
- resp = _from_vertex_response(v_resp)
- if result_type == "json" and resp.text is not None:
- resp.parsed = BaseLLMClient.as_json(resp.text)
- elif isinstance(result_type, type(BaseModel)) and resp.text is not None:
- parsed = BaseLLMClient.as_json(resp.text)
- resp.parsed = result_type.model_validate(parsed)
- return resp
-
- def create_stream(
- self,
- messages: list[Content],
- *,
- thinking_config: ThinkingConfig | None = None,
- system_message: str | None = None,
- max_tokens: int | None = None,
- ) -> Iterator[Response]:
- v_messages = _to_vertex_content(messages)
- GenerationConfig = getattr(importlib.import_module("vertexai.generative_models"), "GenerationConfig")
- gen_cfg = GenerationConfig(max_output_tokens=max_tokens or self.default_max_tokens)
-
- # Handle thinking config
- if thinking_config is None:
- thinking_config = self.default_thinking_config
- if thinking_config is not None:
- # Store for potential future use when Vertex AI supports thinking features
- pass
-
- req_opts: dict[str, Any] | None = {}
- if self._api_key_str:
- req_opts["api_key"] = self._api_key_str
- if not req_opts:
- req_opts = None
- stream = self._model.generate_content(
- contents=v_messages,
- generation_config=gen_cfg,
- system_instruction=system_message,
- request_options=req_opts,
- stream=True,
- )
- for ev in stream:
- yield _from_vertex_response(ev)
-
- @staticmethod
- def models_list() -> list[Model]:
- return []
-
-
- @inherited_decorator
- def _error_handler_async(func: Callable[P, Awaitable[Response]]) -> Callable[P, Awaitable[Response]]:
- @wraps(func)
- async def wrapper(*args, **kwargs):
- try:
- return await func(*args, **kwargs)
- except RetryError as retry_error:
- e = retry_error.last_attempt._exception
- if e is None:
- raise APIError()
- code = getattr(e, "code", None)
- response_json = {
- "status": getattr(e, "status", None),
- "message": str(e),
- }
- response = getattr(e, "response", None)
- raise APIError(code, response_json, response)
- except Exception as e: # noqa: BLE001
- raise APIError(None, {"status": None, "message": str(e)}, None)
- return wrapper
-
-
- class VertexLLMClientAsync(BaseLLMClientAsync):
- PROVIDER: str = "vertexai"
-
- def __init__(
- self,
- model: str,
- api_key: ApiKey | None = None,
- decorator_configs: DecoratorConfigs | None = None,
- default_thinking_config: ThinkingConfig | None = None,
- default_max_tokens: int | None = None,
- project: str | None = None,
- location: str | None = None,
- **kwargs,
- ):
- project = project or os.getenv("VERTEXAI_PROJECT") or os.getenv("GOOGLE_CLOUD_PROJECT") or os.getenv("GCLOUD_PROJECT")
- location = location or os.getenv("VERTEXAI_LOCATION") or os.getenv("GOOGLE_CLOUD_REGION") or os.getenv("GOOGLE_CLOUD_LOCATION")
-
- api_key_str: str | None = None
- if isinstance(api_key, str):
- api_key_str = api_key
- elif api_key is None:
- api_key_str = os.getenv("VERTEX_API_KEY") or os.getenv("GOOGLE_API_KEY")
- elif isinstance(api_key, VertexApiKey):
- pass
- else:
- raise ValueError("Unsupported api_key type for Vertex: expected str or VertexApiKey")
-
- if not project or not location:
- raise ValueError("To create a vertexai llm client you need to provide project and location via args or env vars VERTEXAI_PROJECT and VERTEXAI_LOCATION")
-
- if not isinstance(api_key, VertexApiKey):
- api_key = VertexApiKey(project=project, location=location)
-
- super().__init__(
- VertexLLMClientAsync.PROVIDER,
- model,
- decorator_configs=decorator_configs,
- default_thinking_config=default_thinking_config,
- default_max_tokens=default_max_tokens,
- )
- self._api_key = api_key
- self._api_key_str = api_key_str
-
- vertex_init(project=self._api_key.project, location=self._api_key.location)
- self._model = GenerativeModel(self.model)
-
- @property
- def api_key(self) -> VertexApiKey:
- return self._api_key
-
- @_error_handler_async
- async def _create(
- self,
- messages: list[Content],
- result_type: ResultType = None,
- *,
- thinking_config: ThinkingConfig | None = None,
- system_message: str | None = None,
- max_tokens: int | None = None,
- timeout: float | None = None,
- tools: list[Tool] | None = None,
- tool_config: ToolConfig = ToolConfig(),
- ) -> Response:
- # Reuse sync implementation (SDK is sync). For real async, offload to thread.
- client = VertexLLMClient(
- model=self.model,
- api_key=self._api_key,
- decorator_configs=self._decorator_configs,
- default_thinking_config=self.default_thinking_config,
- default_max_tokens=self.default_max_tokens,
- )
- return client._create(
- messages=messages,
- result_type=result_type,
- thinking_config=thinking_config,
- system_message=system_message,
- max_tokens=max_tokens,
- timeout=timeout,
- tools=tools,
- tool_config=tool_config,
- )
-
- async def create_stream(
- self,
- messages: list[Content],
- *,
- thinking_config: ThinkingConfig | None = None,
- system_message: str | None = None,
- max_tokens: int | None = None,
- ) -> AsyncIterator[Response]:
- # Provide a simple wrapper yielding once (non-streaming)
- resp = await self._create(
- messages=messages,
- result_type=None,
- thinking_config=thinking_config,
- system_message=system_message,
- max_tokens=max_tokens,
- )
- yield resp
-
- @staticmethod
- def models_list() -> list[Model]:
- return VertexLLMClient.models_list()