lm-deluge 0.0.13__tar.gz → 0.0.15__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of lm-deluge might be problematic.

Files changed (82)
  1. {lm_deluge-0.0.13/src/lm_deluge.egg-info → lm_deluge-0.0.15}/PKG-INFO +4 -1
  2. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/README.md +3 -0
  3. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/pyproject.toml +1 -1
  4. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/__init__.py +2 -0
  5. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/api_requests/base.py +2 -148
  6. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/api_requests/common.py +2 -0
  7. lm_deluge-0.0.15/src/lm_deluge/api_requests/gemini.py +222 -0
  8. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/api_requests/openai.py +72 -6
  9. lm_deluge-0.0.15/src/lm_deluge/api_requests/response.py +153 -0
  10. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/client.py +36 -48
  11. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/config.py +3 -2
  12. lm_deluge-0.0.15/src/lm_deluge/file.py +154 -0
  13. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/models.py +57 -0
  14. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/prompt.py +70 -9
  15. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/tracker.py +5 -3
  16. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/usage.py +10 -0
  17. {lm_deluge-0.0.13 → lm_deluge-0.0.15/src/lm_deluge.egg-info}/PKG-INFO +4 -1
  18. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge.egg-info/SOURCES.txt +8 -0
  19. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_all_models.py +24 -24
  20. lm_deluge-0.0.15/tests/test_file_integration.py +156 -0
  21. lm_deluge-0.0.15/tests/test_file_support.py +210 -0
  22. lm_deluge-0.0.15/tests/test_gemini_integration.py +238 -0
  23. lm_deluge-0.0.15/tests/test_retry_fix.py +67 -0
  24. lm_deluge-0.0.15/tests/test_simple_gemini.py +32 -0
  25. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/LICENSE +0 -0
  26. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/setup.cfg +0 -0
  27. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/agent.py +0 -0
  28. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/api_requests/__init__.py +0 -0
  29. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/api_requests/anthropic.py +0 -0
  30. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/api_requests/bedrock.py +0 -0
  31. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/api_requests/deprecated/bedrock.py +0 -0
  32. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/api_requests/deprecated/cohere.py +0 -0
  33. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/api_requests/deprecated/deepseek.py +0 -0
  34. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/api_requests/deprecated/mistral.py +0 -0
  35. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/api_requests/deprecated/vertex.py +0 -0
  36. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/api_requests/mistral.py +0 -0
  37. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/batches.py +0 -0
  38. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/cache.py +0 -0
  39. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/computer_use/anthropic_tools.py +0 -0
  40. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/embed.py +0 -0
  41. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/errors.py +0 -0
  42. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/gemini_limits.py +0 -0
  43. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/image.py +0 -0
  44. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/llm_tools/__init__.py +0 -0
  45. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/llm_tools/extract.py +0 -0
  46. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/llm_tools/score.py +0 -0
  47. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/llm_tools/translate.py +0 -0
  48. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/rerank.py +0 -0
  49. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/tool.py +0 -0
  50. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/util/json.py +0 -0
  51. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/util/logprobs.py +0 -0
  52. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/util/validation.py +0 -0
  53. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/util/xml.py +0 -0
  54. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge.egg-info/dependency_links.txt +0 -0
  55. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge.egg-info/requires.txt +0 -0
  56. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge.egg-info/top_level.txt +0 -0
  57. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_batch_real.py +0 -0
  58. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_bedrock_computer_use.py +0 -0
  59. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_bedrock_models.py +0 -0
  60. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_cache.py +0 -0
  61. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_client_tracker_integration.py +0 -0
  62. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_computer_use.py +0 -0
  63. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_computer_use_integration.py +0 -0
  64. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_debug_format.py +0 -0
  65. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_image_models.py +0 -0
  66. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_image_utils.py +0 -0
  67. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_json_utils.py +0 -0
  68. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_logprobs_refactor.py +0 -0
  69. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_max_concurrent_requests.py +0 -0
  70. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_mcp_tools.py +0 -0
  71. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_openai_responses.py +0 -0
  72. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_prompt_caching.py +0 -0
  73. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_real_caching.py +0 -0
  74. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_real_caching_bedrock.py +0 -0
  75. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_rich_display.py +0 -0
  76. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_sampling_params.py +0 -0
  77. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_tool_calls.py +0 -0
  78. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_tool_from_function.py +0 -0
  79. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_tool_validation.py +0 -0
  80. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_tracker_refactor.py +0 -0
  81. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_translate.py +0 -0
  82. {lm_deluge-0.0.13 → lm_deluge-0.0.15}/tests/test_xml_utils.py +0 -0
{lm_deluge-0.0.13/src/lm_deluge.egg-info → lm_deluge-0.0.15}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lm_deluge
-Version: 0.0.13
+Version: 0.0.15
 Summary: Python utility for using LLM API models.
 Author-email: Benjamin Anderson <ben@trytaylor.ai>
 Requires-Python: >=3.10
@@ -30,6 +30,7 @@ Dynamic: license-file
 `lm-deluge` is a lightweight helper library for maxing out your rate limits with LLM providers. It provides the following:
 
 - **Unified client** – Send prompts to all relevant models with a single client.
+- **Files and Images** - Include images easily for multimodal models, and PDF files for models that support them (OpenAI and Anthropic).
 - **Massive concurrency with throttling** – Set `max_tokens_per_minute` and `max_requests_per_minute` and let it fly. The client will process as many requests as possible while respecting rate limits and retrying failures.
 - **Spray across models/providers** – Configure a client with multiple models from any provider(s), and sampling weights. The client samples a model for each request.
 - **Tool Use** – Unified API for defining tools for all providers, and creating tools automatically from python functions.
@@ -41,6 +42,8 @@ Dynamic: license-file
 
 **STREAMING IS NOT IN SCOPE.** There are plenty of packages that let you stream chat completions across providers. The sole purpose of this package is to do very fast batch inference using APIs. Sorry!
 
+**Update 06/02/2025:** I lied, it supports (very basic) streaming now via client.stream(...). It will print tokens as they arrive, then return an APIResponse at the end. More sophisticated streaming may or may not be implemented later, don't count on it.
+
 ## Installation
 
 ```bash
{lm_deluge-0.0.13 → lm_deluge-0.0.15}/README.md

@@ -3,6 +3,7 @@
 `lm-deluge` is a lightweight helper library for maxing out your rate limits with LLM providers. It provides the following:
 
 - **Unified client** – Send prompts to all relevant models with a single client.
+- **Files and Images** - Include images easily for multimodal models, and PDF files for models that support them (OpenAI and Anthropic).
 - **Massive concurrency with throttling** – Set `max_tokens_per_minute` and `max_requests_per_minute` and let it fly. The client will process as many requests as possible while respecting rate limits and retrying failures.
 - **Spray across models/providers** – Configure a client with multiple models from any provider(s), and sampling weights. The client samples a model for each request.
 - **Tool Use** – Unified API for defining tools for all providers, and creating tools automatically from python functions.
@@ -14,6 +15,8 @@
 
 **STREAMING IS NOT IN SCOPE.** There are plenty of packages that let you stream chat completions across providers. The sole purpose of this package is to do very fast batch inference using APIs. Sorry!
 
+**Update 06/02/2025:** I lied, it supports (very basic) streaming now via client.stream(...). It will print tokens as they arrive, then return an APIResponse at the end. More sophisticated streaming may or may not be implemented later, don't count on it.
+
 ## Installation
 
 ```bash
{lm_deluge-0.0.13 → lm_deluge-0.0.15}/pyproject.toml

@@ -3,7 +3,7 @@ requires = ["setuptools", "wheel"]
 
 [project]
 name = "lm_deluge"
-version = "0.0.13"
+version = "0.0.15"
 authors = [{ name = "Benjamin Anderson", email = "ben@trytaylor.ai" }]
 description = "Python utility for using LLM API models."
 readme = "README.md"
{lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/__init__.py

@@ -1,6 +1,7 @@
 from .client import LLMClient, SamplingParams, APIResponse
 from .prompt import Conversation, Message
 from .tool import Tool
+from .file import File
 import dotenv
 
 dotenv.load_dotenv()
@@ -12,4 +13,5 @@ __all__ = [
     "Conversation",
     "Message",
     "Tool",
+    "File",
 ]
{lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/api_requests/base.py

@@ -1,165 +1,19 @@
 import asyncio
-import json
 import random
 import traceback
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from typing import Callable
 
 import aiohttp
 from aiohttp import ClientResponse
 
-from lm_deluge.prompt import CachePattern, Conversation, Message
-from lm_deluge.usage import Usage
+from lm_deluge.prompt import CachePattern, Conversation
 
 from ..config import SamplingParams
 from ..errors import raise_if_modal_exception
 from ..models import APIModel
 from ..tracker import StatusTracker
-
-
-@dataclass
-class APIResponse:
-    # request information
-    id: int  # should be unique to the request within a given prompt-processing call
-    model_internal: str  # our internal model tag
-    prompt: Conversation
-    sampling_params: SamplingParams
-
-    # http response information
-    status_code: int | None
-    is_error: bool | None
-    error_message: str | None
-
-    # completion information - unified usage tracking
-    usage: Usage | None = None
-
-    # response content - structured format
-    content: Message | None = None
-
-    # optional or calculated automatically
-    thinking: str | None = None  # if model shows thinking tokens
-    model_external: str | None = None  # the model tag used by the API
-    region: str | None = None
-    logprobs: list | None = None
-    finish_reason: str | None = None  # make required later
-    cost: float | None = None  # calculated automatically
-    cache_hit: bool = False  # manually set if true
-    # set to true if is_error and should be retried with a different model
-    retry_with_different_model: bool | None = False
-    # set to true if should NOT retry with the same model (unrecoverable error)
-    give_up_if_no_other_models: bool | None = False
-    # OpenAI Responses API specific - used for computer use continuation
-    response_id: str | None = None
-    # Raw API response for debugging
-    raw_response: dict | None = None
-
-    @property
-    def completion(self) -> str | None:
-        """Backward compatibility: extract text from content Message."""
-        if self.content is not None:
-            return self.content.completion
-        return None
-
-    @property
-    def input_tokens(self) -> int | None:
-        """Get input tokens from usage object."""
-        return self.usage.input_tokens if self.usage else None
-
-    @property
-    def output_tokens(self) -> int | None:
-        """Get output tokens from usage object."""
-        return self.usage.output_tokens if self.usage else None
-
-    @property
-    def cache_read_tokens(self) -> int | None:
-        """Get cache read tokens from usage object."""
-        return self.usage.cache_read_tokens if self.usage else None
-
-    @property
-    def cache_write_tokens(self) -> int | None:
-        """Get cache write tokens from usage object."""
-        return self.usage.cache_write_tokens if self.usage else None
-
-    def __post_init__(self):
-        # calculate cost & get external model name
-        self.id = int(self.id)
-        api_model = APIModel.from_registry(self.model_internal)
-        self.model_external = api_model.name
-        self.cost = None
-        if (
-            self.usage is not None
-            and api_model.input_cost is not None
-            and api_model.output_cost is not None
-        ):
-            self.cost = (
-                self.usage.input_tokens * api_model.input_cost / 1e6
-                + self.usage.output_tokens * api_model.output_cost / 1e6
-            )
-        elif self.content is not None and self.completion is not None:
-            print(
-                f"Warning: Completion provided without token counts for model {self.model_internal}."
-            )
-
-    def to_dict(self):
-        return {
-            "id": self.id,
-            "model_internal": self.model_internal,
-            "model_external": self.model_external,
-            "region": self.region,
-            "prompt": self.prompt.to_log(),  # destroys image if present
-            "sampling_params": self.sampling_params.__dict__,
-            "status_code": self.status_code,
-            "is_error": self.is_error,
-            "error_message": self.error_message,
-            "completion": self.completion,  # computed property
-            "content": self.content.to_log() if self.content else None,
-            "usage": self.usage.to_dict() if self.usage else None,
-            "finish_reason": self.finish_reason,
-            "cost": self.cost,
-        }
-
-    @classmethod
-    def from_dict(cls, data: dict):
-        # Handle backward compatibility for content/completion
-        content = None
-        if "content" in data and data["content"] is not None:
-            # Reconstruct message from log format
-            content = Message.from_log(data["content"])
-        elif "completion" in data and data["completion"] is not None:
-            # Backward compatibility: create a Message with just text
-            content = Message.ai(data["completion"])
-
-        usage = None
-        if "usage" in data and data["usage"] is not None:
-            usage = Usage.from_dict(data["usage"])
-
-        return cls(
-            id=data.get("id", random.randint(0, 1_000_000_000)),
-            model_internal=data["model_internal"],
-            prompt=Conversation.from_log(data["prompt"]),
-            sampling_params=SamplingParams(**data["sampling_params"]),
-            status_code=data["status_code"],
-            is_error=data["is_error"],
-            error_message=data["error_message"],
-            usage=usage,
-            content=content,
-            thinking=data.get("thinking"),
-            model_external=data.get("model_external"),
-            region=data.get("region"),
-            logprobs=data.get("logprobs"),
-            finish_reason=data.get("finish_reason"),
-            cost=data.get("cost"),
-            cache_hit=data.get("cache_hit", False),
-        )
-
-    def write_to_file(self, filename):
-        """
-        Writes the APIResponse as a line to a file.
-        If file exists, appends to it.
-        """
-        with open(filename, "a") as f:
-            f.write(json.dumps(self.to_dict()) + "\n")
+from .response import APIResponse
 
 
 class APIRequestBase(ABC):
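
For reference, the cost calculation in the APIResponse dataclass removed above (now living in response.py) is just the registry's per-million-token prices applied to the usage counts. A quick worked sketch of the same formula, with made-up prices rather than real registry values:

```python
# Sketch of APIResponse.__post_init__'s cost formula; the real prices come from
# APIModel.from_registry(...).input_cost / .output_cost ($ per 1M tokens).
input_cost = 0.50    # hypothetical $ per 1M input tokens
output_cost = 1.50   # hypothetical $ per 1M output tokens
input_tokens, output_tokens = 1_200, 350

cost = input_tokens * input_cost / 1e6 + output_tokens * output_cost / 1e6
print(f"cost = ${cost:.6f}")  # cost = $0.001125
```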
{lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/api_requests/common.py

@@ -2,6 +2,7 @@ from .openai import OpenAIRequest, OpenAIResponsesRequest
 from .anthropic import AnthropicRequest
 from .mistral import MistralRequest
 from .bedrock import BedrockRequest
+from .gemini import GeminiRequest
 
 CLASSES = {
     "openai": OpenAIRequest,
@@ -9,4 +10,5 @@ CLASSES = {
     "anthropic": AnthropicRequest,
     "mistral": MistralRequest,
     "bedrock": BedrockRequest,
+    "gemini": GeminiRequest,
 }
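
The CLASSES mapping is what lets the client resolve a request implementation per provider. A minimal lookup sketch, assuming registry models carry an api_spec string matching these keys (consistent with the model.api_spec check in stream_chat further down); the model id used here is hypothetical:

```python
# Sketch only: resolve a request class from the CLASSES registry.
from lm_deluge.api_requests.common import CLASSES
from lm_deluge.models import APIModel

model = APIModel.from_registry("gemini-2.5-flash")  # hypothetical registry id
request_cls = CLASSES[model.api_spec]  # e.g. GeminiRequest for a Gemini model
print(request_cls.__name__)
```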
lm_deluge-0.0.15/src/lm_deluge/api_requests/gemini.py

@@ -0,0 +1,222 @@
+import json
+import os
+import warnings
+from typing import Callable
+
+from aiohttp import ClientResponse
+
+from lm_deluge.tool import Tool
+
+from ..config import SamplingParams
+from ..models import APIModel
+from ..prompt import CachePattern, Conversation, Message, Text, Thinking, ToolCall
+from ..tracker import StatusTracker
+from ..usage import Usage
+from .base import APIRequestBase, APIResponse
+
+
+def _build_gemini_request(
+    model: APIModel,
+    prompt: Conversation,
+    tools: list[Tool] | None,
+    sampling_params: SamplingParams,
+) -> dict:
+    system_message, messages = prompt.to_gemini()
+
+    request_json = {
+        "contents": messages,
+        "generationConfig": {
+            "temperature": sampling_params.temperature,
+            "topP": sampling_params.top_p,
+            "maxOutputTokens": sampling_params.max_new_tokens,
+        },
+    }
+
+    # Add system instruction if present
+    if system_message:
+        request_json["systemInstruction"] = {"parts": [{"text": system_message}]}
+
+    # Handle reasoning models (thinking)
+    if model.reasoning_model:
+        request_json["generationConfig"]["thinkingConfig"] = {"includeThoughts": True}
+        if sampling_params.reasoning_effort and "flash" in model.id:
+            budget = {"low": 1024, "medium": 4096, "high": 16384}.get(
+                sampling_params.reasoning_effort
+            )
+            request_json["generationConfig"]["thinkingConfig"]["thinkingBudget"] = (
+                budget
+            )
+
+    else:
+        if sampling_params.reasoning_effort:
+            warnings.warn(
+                f"Ignoring reasoning_effort param for non-reasoning model: {model.name}"
+            )
+
+    # Add tools if provided
+    if tools:
+        tool_declarations = [tool.dump_for("google") for tool in tools]
+        request_json["tools"] = [{"functionDeclarations": tool_declarations}]
+
+    # Handle JSON mode
+    if sampling_params.json_mode and model.supports_json:
+        request_json["generationConfig"]["responseMimeType"] = "application/json"
+
+    return request_json
+
+
+class GeminiRequest(APIRequestBase):
+    def __init__(
+        self,
+        task_id: int,
+        model_name: str,  # must correspond to registry
+        prompt: Conversation,
+        attempts_left: int,
+        status_tracker: StatusTracker,
+        results_arr: list,
+        request_timeout: int = 30,
+        sampling_params: SamplingParams = SamplingParams(),
+        callback: Callable | None = None,
+        all_model_names: list[str] | None = None,
+        all_sampling_params: list[SamplingParams] | None = None,
+        tools: list | None = None,
+        cache: CachePattern | None = None,
+    ):
+        super().__init__(
+            task_id=task_id,
+            model_name=model_name,
+            prompt=prompt,
+            attempts_left=attempts_left,
+            status_tracker=status_tracker,
+            results_arr=results_arr,
+            request_timeout=request_timeout,
+            sampling_params=sampling_params,
+            callback=callback,
+            all_model_names=all_model_names,
+            all_sampling_params=all_sampling_params,
+            tools=tools,
+            cache=cache,
+        )
+
+        # Warn if cache is specified for Gemini model
+        if cache is not None:
+            warnings.warn(
+                f"Cache parameter '{cache}' is not supported for Gemini models, ignoring for {model_name}"
+            )
+
+        self.model = APIModel.from_registry(model_name)
+        # Gemini API endpoint format: https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent
+        self.url = f"{self.model.api_base}/models/{self.model.name}:generateContent"
+        self.request_header = {
+            "Content-Type": "application/json",
+        }
+
+        # Add API key as query parameter for Gemini
+        api_key = os.getenv(self.model.api_key_env_var)
+        if not api_key:
+            raise ValueError(
+                f"API key environment variable {self.model.api_key_env_var} not set"
+            )
+        self.url += f"?key={api_key}"
+
+        self.request_json = _build_gemini_request(
+            self.model, prompt, tools, sampling_params
+        )
+
+    async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+        is_error = False
+        error_message = None
+        thinking = None
+        content = None
+        usage = None
+        status_code = http_response.status
+        mimetype = http_response.headers.get("Content-Type", None)
+        data = None
+
+        if status_code >= 200 and status_code < 300:
+            try:
+                data = await http_response.json()
+            except Exception as e:
+                is_error = True
+                error_message = (
+                    f"Error calling .json() on response w/ status {status_code}: {e}"
+                )
+
+            if not is_error:
+                assert data
+                try:
+                    # Parse Gemini response format
+                    parts = []
+
+                    if "candidates" in data and data["candidates"]:
+                        candidate = data["candidates"][0]
+                        if "content" in candidate and "parts" in candidate["content"]:
+                            for part in candidate["content"]["parts"]:
+                                if "text" in part:
+                                    parts.append(Text(part["text"]))
+                                elif "thought" in part:
+                                    parts.append(Thinking(part["thought"]))
+                                elif "functionCall" in part:
+                                    func_call = part["functionCall"]
+                                    # Generate a unique ID since Gemini doesn't provide one
+                                    import uuid
+
+                                    tool_id = f"call_{uuid.uuid4().hex[:8]}"
+                                    parts.append(
+                                        ToolCall(
+                                            id=tool_id,
+                                            name=func_call["name"],
+                                            arguments=func_call.get("args", {}),
+                                        )
+                                    )
+
+                    content = Message("assistant", parts)
+
+                    # Extract usage information if present
+                    if "usageMetadata" in data:
+                        usage_data = data["usageMetadata"]
+                        usage = Usage.from_gemini_usage(usage_data)
+
+                except Exception as e:
+                    is_error = True
+                    error_message = f"Error parsing Gemini response: {str(e)}"
+
+        elif mimetype and "json" in mimetype.lower():
+            is_error = True
+            try:
+                data = await http_response.json()
+                error_message = json.dumps(data)
+            except Exception:
+                error_message = (
+                    f"HTTP {status_code} with JSON content type but failed to parse"
+                )
+        else:
+            is_error = True
+            text = await http_response.text()
+            error_message = text
+
+        # Handle special kinds of errors
+        if is_error and error_message is not None:
+            if "rate limit" in error_message.lower() or status_code == 429:
+                error_message += " (Rate limit error, triggering cooldown.)"
+                self.status_tracker.rate_limit_exceeded()
+            if (
+                "context length" in error_message.lower()
+                or "token limit" in error_message.lower()
+            ):
+                error_message += " (Context length exceeded, set retries to 0.)"
+                self.attempts_left = 0
+
+        return APIResponse(
+            id=self.task_id,
+            status_code=status_code,
+            is_error=is_error,
+            error_message=error_message,
+            prompt=self.prompt,
+            content=content,
+            thinking=thinking,
+            model_internal=self.model_name,
+            sampling_params=self.sampling_params,
+            usage=usage,
+            raw_response=data,
+        )
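
To make the new request builder concrete, this is roughly the payload _build_gemini_request assembles for a single-turn prompt with a system message. The shape follows the code above; the concrete values and the exact output of prompt.to_gemini() are illustrative, not taken from this diff:

```python
# Approximate request body POSTed to .../models/{model}:generateContent.
request_json = {
    "contents": [
        {"role": "user", "parts": [{"text": "Summarize this paragraph..."}]},
    ],
    "generationConfig": {
        "temperature": 0.7,      # sampling_params.temperature (example value)
        "topP": 1.0,             # sampling_params.top_p (example value)
        "maxOutputTokens": 512,  # sampling_params.max_new_tokens (example value)
    },
    # only added when the conversation includes a system message
    "systemInstruction": {"parts": [{"text": "You are a terse assistant."}]},
}
```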
{lm_deluge-0.0.13 → lm_deluge-0.0.15}/src/lm_deluge/api_requests/openai.py

@@ -1,17 +1,19 @@
-import warnings
-from aiohttp import ClientResponse
 import json
 import os
+import warnings
 from typing import Callable
 
+import aiohttp
+from aiohttp import ClientResponse
+
 from lm_deluge.tool import Tool
 
-from .base import APIRequestBase, APIResponse
-from ..prompt import Conversation, Message, Text, ToolCall, Thinking, CachePattern
-from ..usage import Usage
-from ..tracker import StatusTracker
 from ..config import SamplingParams
 from ..models import APIModel
+from ..prompt import CachePattern, Conversation, Message, Text, Thinking, ToolCall
+from ..tracker import StatusTracker
+from ..usage import Usage
+from .base import APIRequestBase, APIResponse
 
 
 def _build_oa_chat_request(
@@ -111,6 +113,7 @@ class OpenAIRequest(APIRequestBase):
         status_code = http_response.status
         mimetype = http_response.headers.get("Content-Type", None)
         data = None
+        finish_reason = None
         if status_code >= 200 and status_code < 300:
             try:
                 data = await http_response.json()
@@ -125,6 +128,7 @@ class OpenAIRequest(APIRequestBase):
                 # Parse response into Message with parts
                 parts = []
                 message = data["choices"][0]["message"]
+                finish_reason = data["choices"][0]["finish_reason"]
 
                 # Add text content if present
                 if message.get("content"):
@@ -190,6 +194,7 @@ class OpenAIRequest(APIRequestBase):
             sampling_params=self.sampling_params,
             usage=usage,
             raw_response=data,
+            finish_reason=finish_reason,
         )
 
 
@@ -266,6 +271,13 @@ class OpenAIResponsesRequest(APIRequestBase):
         self.request_json["max_output_tokens"] = sampling_params.max_new_tokens
 
         if self.model.reasoning_model:
+            if sampling_params.reasoning_effort in [None, "none"]:
+                # gemini models can switch reasoning off
+                if "gemini" in self.model.id:
+                    self.sampling_params.reasoning_effort = "none"  # expects string
+                # openai models can only go down to "low"
+                else:
+                    self.sampling_params.reasoning_effort = "low"
             self.request_json["temperature"] = 1.0
             self.request_json["top_p"] = 1.0
             self.request_json["reasoning"] = {
@@ -413,3 +425,57 @@
             usage=usage,
             raw_response=data,
         )
+
+
+async def stream_chat(
+    model_name: str,  # must correspond to registry
+    prompt: Conversation,
+    sampling_params: SamplingParams = SamplingParams(),
+    tools: list | None = None,
+    cache: CachePattern | None = None,
+):
+    if cache is not None:
+        warnings.warn(
+            f"Cache parameter '{cache}' is only supported for Anthropic models, ignoring for {model_name}"
+        )
+
+    model = APIModel.from_registry(model_name)
+    if model.api_spec != "openai":
+        raise ValueError("streaming only supported on openai models for now")
+    url = f"{model.api_base}/chat/completions"
+    request_header = {"Authorization": f"Bearer {os.getenv(model.api_key_env_var)}"}
+    request_json = _build_oa_chat_request(model, prompt, tools, sampling_params)
+    request_json["stream"] = True
+
+    async with aiohttp.ClientSession() as s:
+        async with s.post(url, headers=request_header, json=request_json) as r:
+            r.raise_for_status()  # bail on 4xx/5xx
+            content = ""
+            buf = ""
+            async for chunk in r.content.iter_any():  # raw bytes
+                buf += chunk.decode()
+                while "\n\n" in buf:  # full SSE frame
+                    event, buf = buf.split("\n\n", 1)
+                    if not event.startswith("data:"):
+                        continue  # ignore comments
+                    data = event[5:].strip()  # after "data:"
+                    if data == "[DONE]":
+                        yield APIResponse(
+                            id=0,
+                            status_code=None,
+                            is_error=False,
+                            error_message=None,
+                            prompt=prompt,
+                            content=Message(
+                                role="assistant", parts=[Text(text=content)]
+                            ),
+                            model_internal=model.id,
+                            sampling_params=sampling_params,
+                            usage=None,
+                            raw_response=None,
+                        )
+                    msg = json.loads(data)  # SSE payload
+                    delta = msg["choices"][0]["delta"].get("content")
+                    if delta:
+                        content += delta
+                        yield delta
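
A minimal consumption sketch for the new stream_chat generator, based only on what the code above yields (string deltas while streaming, then an APIResponse after the [DONE] frame). The model id and the Conversation helper used to build the prompt are assumptions, not taken from this diff:

```python
import asyncio

from lm_deluge import APIResponse, Conversation
from lm_deluge.api_requests.openai import stream_chat


async def main():
    prompt = Conversation.user("Write a haiku about rate limits.")  # assumed helper
    async for item in stream_chat("gpt-4.1-mini", prompt):  # hypothetical model id
        if isinstance(item, APIResponse):
            print("\n---")
            print(item.completion)  # full assembled text
            break  # final item; stop iterating
        print(item, end="", flush=True)  # incremental text delta


asyncio.run(main())
```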