langchain-maritaca 0.2.2__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- langchain_maritaca/__init__.py +14 -0
- langchain_maritaca/chat_models.py +794 -0
- langchain_maritaca/embeddings.py +289 -0
- langchain_maritaca/py.typed +0 -0
- langchain_maritaca/version.py +7 -0
- langchain_maritaca-0.2.2.dist-info/METADATA +274 -0
- langchain_maritaca-0.2.2.dist-info/RECORD +9 -0
- langchain_maritaca-0.2.2.dist-info/WHEEL +4 -0
- langchain_maritaca-0.2.2.dist-info/licenses/LICENSE +21 -0
langchain_maritaca/chat_models.py
@@ -0,0 +1,794 @@
"""Maritaca AI Chat wrapper for LangChain.

Maritaca AI provides Brazilian Portuguese-optimized language models,
including the Sabiá family of models.

Author: Anderson Henrique da Silva
Location: Minas Gerais, Brasil
GitHub: https://github.com/anderson-ufrj
"""

from __future__ import annotations

import asyncio
import json
import time
from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
from operator import itemgetter
from typing import Any, Literal

import httpx
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    LangSmithParams,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.messages.tool import ToolCall
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers.openai_tools import (
    JsonOutputKeyToolsParser,
    PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import from_env, secret_from_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

from langchain_maritaca.version import __version__

# HTTP status code for rate limiting
HTTP_TOO_MANY_REQUESTS = 429


class ChatMaritaca(BaseChatModel):
    r"""Maritaca AI Chat large language models API.

    Maritaca AI provides Brazilian Portuguese-optimized language models,
    offering excellent performance for Portuguese text generation, analysis,
    and understanding tasks.

    To use, you should have the environment variable `MARITACA_API_KEY`
    set with your API key, or pass it as a named parameter to the constructor.

    Setup:
        Install `langchain-maritaca` and set environment variable
        `MARITACA_API_KEY`.

        ```bash
        pip install -U langchain-maritaca
        export MARITACA_API_KEY="your-api-key"
        ```

    Key init args - completion params:
        model:
            Name of Maritaca model to use. Available models:
            - `sabia-3.1` (default): Most capable model
            - `sabiazinho-3.1`: Faster and more economical
        temperature:
            Sampling temperature. Ranges from 0.0 to 2.0.
        max_tokens:
            Max number of tokens to generate.

    Key init args - client params:
        timeout:
            Timeout for requests.
        max_retries:
            Max number of retries.
        api_key:
            Maritaca API key. If not passed in will be read from
            env var `MARITACA_API_KEY`.

    Instantiate:
        ```python
        from langchain_maritaca import ChatMaritaca

        model = ChatMaritaca(
            model="sabia-3.1",
            temperature=0.7,
            max_retries=2,
        )
        ```

    Invoke:
        ```python
        messages = [
            ("system", "Você é um assistente prestativo."),
            ("human", "Qual é a capital do Brasil?"),
        ]
        model.invoke(messages)
        ```
        ```python
        AIMessage(
            content="A capital do Brasil é Brasília.",
            response_metadata={"model": "sabia-3.1", "finish_reason": "stop"},
        )
        ```

    Stream:
        ```python
        for chunk in model.stream(messages):
            print(chunk.text, end="")
        ```

    Async:
        ```python
        await model.ainvoke(messages)
        ```
    """

    client: Any = Field(default=None, exclude=True)
    """Sync HTTP client."""

    async_client: Any = Field(default=None, exclude=True)
    """Async HTTP client."""

    model_name: str = Field(default="sabia-3.1", alias="model")
    """Model name to use.

    Available models:
    - sabia-3.1: Most capable model, best for complex tasks
    - sabiazinho-3.1: Fast and economical, great for simple tasks
    """

    temperature: float = 0.7
    """Sampling temperature (0.0 to 2.0)."""

    max_tokens: int | None = Field(default=None)
    """Maximum number of tokens to generate."""

    top_p: float = 0.9
    """Top-p sampling parameter."""

    stop: list[str] | str | None = Field(default=None, alias="stop_sequences")
    """Default stop sequences."""

    frequency_penalty: float = 0.0
    """Frequency penalty (-2.0 to 2.0)."""

    presence_penalty: float = 0.0
    """Presence penalty (-2.0 to 2.0)."""

    maritaca_api_key: SecretStr | None = Field(
        alias="api_key",
        default_factory=secret_from_env("MARITACA_API_KEY", default=None),
    )
    """Maritaca API key. Automatically inferred from env var `MARITACA_API_KEY`."""

    maritaca_api_base: str = Field(
        alias="base_url",
        default_factory=from_env(
            "MARITACA_API_BASE", default="https://chat.maritaca.ai/api"
        ),
    )
    """Base URL for Maritaca API."""

    request_timeout: float | None = Field(default=60.0, alias="timeout")
    """Timeout for requests in seconds."""

    max_retries: int = 2
    """Maximum number of retries."""

    streaming: bool = False
    """Whether to stream results."""

    n: int = 1
    """Number of completions to generate."""

    tools: list[dict[str, Any]] | None = Field(default=None, exclude=True)
    """List of tools (functions) available for the model to call."""

    tool_choice: str | dict[str, Any] | None = Field(default=None, exclude=True)
    """Control which tool is called. Options: 'auto', 'required', or specific tool."""

    model_config = ConfigDict(
        populate_by_name=True,
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate that API key exists and initialize HTTP clients."""
        if self.n < 1:
            msg = "n must be at least 1."
            raise ValueError(msg)
        if self.n > 1 and self.streaming:
            msg = "n must be 1 when streaming."
            raise ValueError(msg)

        # Ensure temperature is not exactly 0 (causes issues with some APIs)
        if self.temperature == 0:
            self.temperature = 1e-8

        # Initialize HTTP clients
        api_key = (
            self.maritaca_api_key.get_secret_value() if self.maritaca_api_key else ""
        )
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
            "User-Agent": f"langchain-maritaca/{__version__}",
        }

        if not self.client:
            self.client = httpx.Client(
                base_url=self.maritaca_api_base,
                headers=headers,
                timeout=httpx.Timeout(self.request_timeout),
            )

        if not self.async_client:
            self.async_client = httpx.AsyncClient(
                base_url=self.maritaca_api_base,
                headers=headers,
                timeout=httpx.Timeout(self.request_timeout),
            )

        return self

    @property
    def lc_secrets(self) -> dict[str, str]:
        """Mapping of secret environment variables."""
        return {"maritaca_api_key": "MARITACA_API_KEY"}

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by LangChain."""
        return True

    @property
    def _llm_type(self) -> str:
        """Return type of model."""
        return "maritaca-chat"

    def _get_ls_params(
        self, stop: list[str] | None = None, **kwargs: Any
    ) -> LangSmithParams:
        """Get standard params for tracing."""
        params = self._get_invocation_params(stop=stop, **kwargs)
        ls_params = LangSmithParams(
            ls_provider="maritaca",
            ls_model_name=params.get("model", self.model_name),
            ls_model_type="chat",
            ls_temperature=params.get("temperature", self.temperature),
        )
        if ls_max_tokens := params.get("max_tokens", self.max_tokens):
            ls_params["ls_max_tokens"] = ls_max_tokens
        if ls_stop := stop or params.get("stop", None) or self.stop:
            ls_params["ls_stop"] = ls_stop if isinstance(ls_stop, list) else [ls_stop]
        return ls_params

    @property
    def _default_params(self) -> dict[str, Any]:
        """Get the default parameters for calling Maritaca API."""
        params: dict[str, Any] = {
            "model": self.model_name,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "n": self.n,
        }
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        if self.stop is not None:
            params["stop"] = self.stop
        if self.tools is not None:
            params["tools"] = self.tools
        if self.tool_choice is not None:
            params["tool_choice"] = self.tool_choice
        return params

    def bind_tools(
        self,
        tools: Sequence[dict[str, Any] | type | Callable[..., Any] | BaseTool],
        *,
        tool_choice: str | dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Runnable[Any, BaseMessage]:
        """Bind tools to this chat model.

        Args:
            tools: A list of tools to bind. Can be:
                - Dict with OpenAI tool schema
                - Pydantic BaseModel class
                - Python function with type hints
                - LangChain BaseTool instance
            tool_choice: Control which tool is called:
                - "auto": Model decides (default)
                - "required": Model must call a tool
                - {"type": "function", "function": {"name": "..."}}:
                  Force specific tool
            **kwargs: Additional arguments passed to the model.

        Returns:
            A Runnable that will pass the tools to the model.

        Example:
            .. code-block:: python

                from pydantic import BaseModel, Field


                class GetWeather(BaseModel):
                    '''Get the weather for a location.'''

                    location: str = Field(description="City name")


                model = ChatMaritaca()
                model_with_tools = model.bind_tools([GetWeather])
                response = model_with_tools.invoke("What's the weather in São Paulo?")
        """
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        return self.bind(tools=formatted_tools, tool_choice=tool_choice, **kwargs)

    def with_structured_output(
        self,
        schema: dict[str, Any] | type[BaseModel] | None = None,
        *,
        method: Literal["function_calling", "json_mode"] = "function_calling",
        include_raw: bool = False,
        **kwargs: Any,
    ) -> Runnable[Any, dict[str, Any] | BaseModel]:
        """Create a runnable that returns structured output matching a schema.

        Uses the model's tool-calling or JSON mode capabilities to guarantee
        output conforms to the specified schema.

        Args:
            schema: The output schema. Can be:
                - A Pydantic BaseModel class
                - A dictionary with JSON Schema
            method: The method to use for structured output:
                - "function_calling": Uses tool calling (default, recommended)
                - "json_mode": Uses JSON response format
            include_raw: If True, returns a dict with keys:
                - "raw": The raw model response (BaseMessage)
                - "parsed": The parsed structured output
                - "parsing_error": Any parsing error that occurred
            **kwargs: Additional arguments passed to the model.

        Returns:
            A Runnable that outputs structured data matching the schema.

        Example:
            .. code-block:: python

                from pydantic import BaseModel, Field


                class Person(BaseModel):
                    '''Information about a person.'''

                    name: str = Field(description="Person's name")
                    age: int = Field(description="Person's age")


                model = ChatMaritaca()
                structured_model = model.with_structured_output(Person)
                result = structured_model.invoke("João tem 25 anos")
                # Returns: Person(name="João", age=25)

        Note:
            The "function_calling" method is more reliable as it uses the
            model's native tool calling capabilities.
        """
        if schema is None:
            msg = "schema must be specified for with_structured_output"
            raise ValueError(msg)

        is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)

        if method == "function_calling":
            formatted_tool = convert_to_openai_tool(schema)
            tool_name = formatted_tool["function"]["name"]
            llm = self.bind_tools(
                [schema],
                tool_choice={"type": "function", "function": {"name": tool_name}},
                **kwargs,
            )
            if is_pydantic_schema:
                # Type narrowing: we know schema is type[BaseModel] here
                pydantic_schema: type[BaseModel] = schema  # type: ignore[assignment]
                output_parser: Runnable[Any, Any] = PydanticToolsParser(
                    tools=[pydantic_schema],
                    first_tool_only=True,
                )
            else:
                output_parser = JsonOutputKeyToolsParser(
                    key_name=tool_name, first_tool_only=True
                )

        elif method == "json_mode":
            llm = self.bind(
                response_format={"type": "json_object"},
                **kwargs,
            )
            if is_pydantic_schema:
                # Type narrowing: we know schema is type[BaseModel] here
                pydantic_schema = schema  # type: ignore[assignment]
                output_parser = PydanticOutputParser(pydantic_object=pydantic_schema)
            else:
                output_parser = JsonOutputParser()

        else:
            msg = (
                f"Unrecognized method argument. Expected 'function_calling' or "
                f"'json_mode'. Received: '{method}'"
            )
            raise ValueError(msg)

        if include_raw:
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser,
                parsing_error=lambda _: None,
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback

        return llm | output_parser

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat completion."""
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}

        response = self._make_request(message_dicts, params)
        return self._create_chat_result(response)

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async generate a chat completion."""
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}

        response = await self._amake_request(message_dicts, params)
        return self._create_chat_result(response)
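
    # The two streaming methods below parse OpenAI-style server-sent events.
    # Based on the parsing logic (this wire format is an inference, not a
    # documented Maritaca contract), each event line is assumed to look like:
    #
    #   data: {"choices": [{"delta": {"content": "Bras"}, "finish_reason": null}]}
    #   data: {"choices": [{"delta": {}, "finish_reason": "stop"}]}
    #   data: [DONE]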

    def _stream(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream a chat completion."""
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        with self.client.stream(
            "POST",
            "/chat/completions",
            json={"messages": message_dicts, **params},
        ) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if line.startswith("data: "):
                    data = line[6:]
                    if data == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data)
                        if not chunk.get("choices"):
                            continue
                        choice = chunk["choices"][0]
                        delta = choice.get("delta", {})
                        content = delta.get("content", "")

                        message_chunk = AIMessageChunk(content=content)
                        generation_info = {}

                        if finish_reason := choice.get("finish_reason"):
                            generation_info["finish_reason"] = finish_reason
                            generation_info["model"] = self.model_name

                        if generation_info:
                            message_chunk = message_chunk.model_copy(
                                update={"response_metadata": generation_info}
                            )

                        generation_chunk = ChatGenerationChunk(
                            message=message_chunk,
                            generation_info=generation_info or None,
                        )

                        if run_manager:
                            run_manager.on_llm_new_token(
                                generation_chunk.text, chunk=generation_chunk
                            )

                        yield generation_chunk

                    except json.JSONDecodeError:
                        continue

    async def _astream(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Async stream a chat completion."""
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        async with self.async_client.stream(
            "POST",
            "/chat/completions",
            json={"messages": message_dicts, **params},
        ) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    data = line[6:]
                    if data == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data)
                        if not chunk.get("choices"):
                            continue
                        choice = chunk["choices"][0]
                        delta = choice.get("delta", {})
                        content = delta.get("content", "")

                        message_chunk = AIMessageChunk(content=content)
                        generation_info = {}

                        if finish_reason := choice.get("finish_reason"):
                            generation_info["finish_reason"] = finish_reason
                            generation_info["model"] = self.model_name

                        if generation_info:
                            message_chunk = message_chunk.model_copy(
                                update={"response_metadata": generation_info}
                            )

                        generation_chunk = ChatGenerationChunk(
                            message=message_chunk,
                            generation_info=generation_info or None,
                        )

                        if run_manager:
                            await run_manager.on_llm_new_token(
                                token=generation_chunk.text, chunk=generation_chunk
                            )

                        yield generation_chunk

                    except json.JSONDecodeError:
                        continue
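
    # The request helpers below retry only two failure modes: HTTP 429
    # (sleeping for the server's Retry-After value, defaulting to 60 seconds)
    # and timeouts (retrying immediately). Any other HTTP error is raised on
    # the first occurrence.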

    def _make_request(
        self, messages: list[dict[str, Any]], params: dict[str, Any]
    ) -> dict[str, Any]:
        """Make a sync request to Maritaca API."""
        for attempt in range(self.max_retries + 1):
            try:
                response = self.client.post(
                    "/chat/completions",
                    json={"messages": messages, **params},
                )
                response.raise_for_status()
                return response.json()
            except httpx.HTTPStatusError as e:
                is_rate_limited = e.response.status_code == HTTP_TOO_MANY_REQUESTS
                if is_rate_limited and attempt < self.max_retries:
                    retry_after = int(e.response.headers.get("Retry-After", 60))
                    time.sleep(retry_after)
                    continue
                raise
            except httpx.TimeoutException:
                if attempt < self.max_retries:
                    continue
                raise
        msg = f"Failed after {self.max_retries + 1} attempts"
        raise RuntimeError(msg)

    async def _amake_request(
        self, messages: list[dict[str, Any]], params: dict[str, Any]
    ) -> dict[str, Any]:
        """Make an async request to Maritaca API."""
        for attempt in range(self.max_retries + 1):
            try:
                response = await self.async_client.post(
                    "/chat/completions",
                    json={"messages": messages, **params},
                )
                response.raise_for_status()
                return response.json()
            except httpx.HTTPStatusError as e:
                is_rate_limited = e.response.status_code == HTTP_TOO_MANY_REQUESTS
                if is_rate_limited and attempt < self.max_retries:
                    retry_after = int(e.response.headers.get("Retry-After", 60))
                    await asyncio.sleep(retry_after)
                    continue
                raise
            except httpx.TimeoutException:
                if attempt < self.max_retries:
                    continue
                raise
        msg = f"Failed after {self.max_retries + 1} attempts"
        raise RuntimeError(msg)

    def _create_message_dicts(
        self, messages: list[BaseMessage], stop: list[str] | None
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Convert LangChain messages to Maritaca format."""
        params = self._default_params.copy()
        if stop is not None:
            params["stop"] = stop
        message_dicts = [_convert_message_to_dict(m) for m in messages]
        return message_dicts, params
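
    # Assumed OpenAI-compatible response shape, inferred from the accessors
    # in _create_chat_result below:
    #
    #   {
    #       "model": "sabia-3.1",
    #       "choices": [{"message": {...}, "finish_reason": "stop"}],
    #       "usage": {"prompt_tokens": 10, "completion_tokens": 8,
    #                 "total_tokens": 18},
    #   }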

    def _create_chat_result(self, response: dict[str, Any]) -> ChatResult:
        """Create a ChatResult from Maritaca API response."""
        generations = []
        token_usage = response.get("usage", {})

        for choice in response.get("choices", []):
            message = _convert_dict_to_message(choice.get("message", {}))

            if token_usage and isinstance(message, AIMessage):
                message.usage_metadata = _create_usage_metadata(token_usage)

            generation_info = {"finish_reason": choice.get("finish_reason")}
            gen = ChatGeneration(message=message, generation_info=generation_info)
            generations.append(gen)

        llm_output = {
            "token_usage": token_usage,
            "model": response.get("model", self.model_name),
        }

        return ChatResult(generations=generations, llm_output=llm_output)
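
# Illustrative round trip for the two converters below: HumanMessage("Oi")
# serializes to {"role": "user", "content": "Oi"}, and an assistant dict
# carrying OpenAI-style tool_calls deserializes to an AIMessage whose
# tool_calls hold the parsed ToolCall entries.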

def _convert_message_to_dict(message: BaseMessage) -> dict[str, Any]:
    """Convert a LangChain message to Maritaca format.

    Args:
        message: The LangChain message.

    Returns:
        Dictionary in Maritaca API format.
    """
    if isinstance(message, ChatMessage):
        return {"role": message.role, "content": message.content}
    if isinstance(message, HumanMessage):
        return {"role": "user", "content": message.content}
    if isinstance(message, AIMessage):
        result: dict[str, Any] = {
            "role": "assistant",
            "content": message.content or "",
        }
        # Handle tool calls in AIMessage
        if message.tool_calls:
            result["tool_calls"] = [
                {
                    "id": tc["id"],
                    "type": "function",
                    "function": {
                        "name": tc["name"],
                        "arguments": json.dumps(tc["args"]),
                    },
                }
                for tc in message.tool_calls
            ]
        return result
    if isinstance(message, SystemMessage):
        return {"role": "system", "content": message.content}
    if isinstance(message, ToolMessage):
        return {
            "role": "tool",
            "content": message.content,
            "tool_call_id": message.tool_call_id,
        }
    msg = f"Got unknown message type: {type(message)}"
    raise TypeError(msg)


def _convert_dict_to_message(message_dict: Mapping[str, Any]) -> BaseMessage:
    """Convert a Maritaca message dict to LangChain message.

    Args:
        message_dict: Dictionary from Maritaca API response.

    Returns:
        LangChain BaseMessage.
    """
    role = message_dict.get("role", "")
    content = message_dict.get("content", "") or ""

    if role == "user":
        return HumanMessage(content=content)
    if role == "assistant":
        # Parse tool_calls if present
        tool_calls_data = message_dict.get("tool_calls", [])
        tool_calls = []
        for tc in tool_calls_data:
            func = tc.get("function", {})
            args_str = func.get("arguments", "{}")
            try:
                args = json.loads(args_str)
            except json.JSONDecodeError:
                args = {}
            tool_calls.append(
                ToolCall(
                    name=func.get("name", ""),
                    args=args,
                    id=tc.get("id", ""),
                )
            )
        return AIMessage(content=content, tool_calls=tool_calls if tool_calls else [])
    if role == "system":
        return SystemMessage(content=content)
    if role == "tool":
        return ToolMessage(
            content=content,
            tool_call_id=message_dict.get("tool_call_id", ""),
        )
    return ChatMessage(content=content, role=role)
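
# For example, {"prompt_tokens": 10, "completion_tokens": 8} yields
# UsageMetadata(input_tokens=10, output_tokens=8, total_tokens=18); the
# total is summed locally only when the API omits "total_tokens".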

def _create_usage_metadata(token_usage: dict[str, Any]) -> UsageMetadata:
    """Create usage metadata from Maritaca token usage response.

    Args:
        token_usage: Token usage dict from Maritaca API response.

    Returns:
        UsageMetadata with token counts.
    """
    input_tokens = token_usage.get("prompt_tokens", 0)
    output_tokens = token_usage.get("completion_tokens", 0)
    total_tokens = token_usage.get("total_tokens", input_tokens + output_tokens)

    return UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
    )
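
Taken together, this file makes `ChatMaritaca` a drop-in LangChain chat model. A minimal end-to-end sketch of the added API (illustrative only, assuming a valid `MARITACA_API_KEY` in the environment and network access; the weather tool is a made-up example):

```python
from pydantic import BaseModel, Field

from langchain_maritaca import ChatMaritaca


class GetWeather(BaseModel):
    """Get the weather for a location."""

    location: str = Field(description="City name")


model = ChatMaritaca(model="sabia-3.1", temperature=0.7)

# Plain chat: returns an AIMessage; usage_metadata is filled in from the
# response's "usage" block by _create_usage_metadata.
reply = model.invoke(
    [
        ("system", "Você é um assistente prestativo."),
        ("human", "Qual é a capital do Brasil?"),
    ]
)
print(reply.content)

# Tool calling: "required" forces a tool call, which _convert_dict_to_message
# parses into the reply's tool_calls list.
with_tools = model.bind_tools([GetWeather], tool_choice="required")
print(with_tools.invoke("Como está o tempo em São Paulo?").tool_calls)
```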