langchain-maritaca 0.2.2__py3-none-any.whl

@@ -0,0 +1,794 @@
+"""Maritaca AI Chat wrapper for LangChain.
+
+Maritaca AI provides Brazilian Portuguese-optimized language models,
+including the Sabiá family of models.
+
+Author: Anderson Henrique da Silva
+Location: Minas Gerais, Brasil
+GitHub: https://github.com/anderson-ufrj
+"""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import time
+from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
+from operator import itemgetter
+from typing import Any, Literal
+
+import httpx
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import (
+    BaseChatModel,
+    LangSmithParams,
+    agenerate_from_stream,
+    generate_from_stream,
+)
+from langchain_core.messages import (
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+    ChatMessage,
+    HumanMessage,
+    SystemMessage,
+    ToolMessage,
+)
+from langchain_core.messages.ai import UsageMetadata
+from langchain_core.messages.tool import ToolCall
+from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
+from langchain_core.output_parsers.openai_tools import (
+    JsonOutputKeyToolsParser,
+    PydanticToolsParser,
+)
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
+from langchain_core.tools import BaseTool
+from langchain_core.utils import from_env, secret_from_env
+from langchain_core.utils.function_calling import convert_to_openai_tool
+from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
+from typing_extensions import Self
+
+from langchain_maritaca.version import __version__
+
+# HTTP status code for rate limiting
+HTTP_TOO_MANY_REQUESTS = 429
+
+
+class ChatMaritaca(BaseChatModel):
+    r"""Maritaca AI Chat large language models API.
+
+    Maritaca AI provides Brazilian Portuguese-optimized language models,
+    offering excellent performance for Portuguese text generation, analysis,
+    and understanding tasks.
+
+    To use, you should have the environment variable `MARITACA_API_KEY`
+    set with your API key, or pass it as a named parameter to the constructor.
+
+    Setup:
+        Install `langchain-maritaca` and set the environment variable
+        `MARITACA_API_KEY`.
+
+        ```bash
+        pip install -U langchain-maritaca
+        export MARITACA_API_KEY="your-api-key"
+        ```
+
+    Key init args - completion params:
+        model:
+            Name of the Maritaca model to use. Available models:
+            - `sabia-3.1` (default): Most capable model
+            - `sabiazinho-3.1`: Faster and more economical
+        temperature:
+            Sampling temperature. Ranges from 0.0 to 2.0.
+        max_tokens:
+            Max number of tokens to generate.
+
+    Key init args - client params:
+        timeout:
+            Timeout for requests.
+        max_retries:
+            Max number of retries.
+        api_key:
+            Maritaca API key. If not passed in, will be read from
+            the env var `MARITACA_API_KEY`.
+
+    Instantiate:
+        ```python
+        from langchain_maritaca import ChatMaritaca
+
+        model = ChatMaritaca(
+            model="sabia-3.1",
+            temperature=0.7,
+            max_retries=2,
+        )
+        ```
+
+    Invoke:
+        ```python
+        messages = [
+            ("system", "Você é um assistente prestativo."),
+            ("human", "Qual é a capital do Brasil?"),
+        ]
+        model.invoke(messages)
+        ```
+        ```python
+        AIMessage(
+            content="A capital do Brasil é Brasília.",
+            response_metadata={"model": "sabia-3.1", "finish_reason": "stop"},
+        )
+        ```
+
+    Stream:
+        ```python
+        for chunk in model.stream(messages):
+            print(chunk.text, end="")
+        ```
+
+    Async:
+        ```python
+        await model.ainvoke(messages)
+        ```
+    """
+
+    client: Any = Field(default=None, exclude=True)
+    """Sync HTTP client."""
+
+    async_client: Any = Field(default=None, exclude=True)
+    """Async HTTP client."""
+
+    model_name: str = Field(default="sabia-3.1", alias="model")
+    """Model name to use.
+
+    Available models:
+    - sabia-3.1: Most capable model, best for complex tasks
+    - sabiazinho-3.1: Fast and economical, great for simple tasks
+    """
+
+    temperature: float = 0.7
+    """Sampling temperature (0.0 to 2.0)."""
+
+    max_tokens: int | None = Field(default=None)
+    """Maximum number of tokens to generate."""
+
+    top_p: float = 0.9
+    """Top-p sampling parameter."""
+
+    stop: list[str] | str | None = Field(default=None, alias="stop_sequences")
+    """Default stop sequences."""
+
+    frequency_penalty: float = 0.0
+    """Frequency penalty (-2.0 to 2.0)."""
+
+    presence_penalty: float = 0.0
+    """Presence penalty (-2.0 to 2.0)."""
+
+    maritaca_api_key: SecretStr | None = Field(
+        alias="api_key",
+        default_factory=secret_from_env("MARITACA_API_KEY", default=None),
+    )
+    """Maritaca API key. Automatically inferred from env var `MARITACA_API_KEY`."""
+
+    maritaca_api_base: str = Field(
+        alias="base_url",
+        default_factory=from_env(
+            "MARITACA_API_BASE", default="https://chat.maritaca.ai/api"
+        ),
+    )
+    """Base URL for Maritaca API."""
+
+    request_timeout: float | None = Field(default=60.0, alias="timeout")
+    """Timeout for requests in seconds."""
+
+    max_retries: int = 2
+    """Maximum number of retries."""
+
+    streaming: bool = False
+    """Whether to stream results."""
+
+    n: int = 1
+    """Number of completions to generate."""
+
+    tools: list[dict[str, Any]] | None = Field(default=None, exclude=True)
+    """List of tools (functions) available for the model to call."""
+
+    tool_choice: str | dict[str, Any] | None = Field(default=None, exclude=True)
+    """Control which tool is called. Options: 'auto', 'required', or specific tool."""
+
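+    # populate_by_name=True lets callers construct the model with either the
+    # field names (model_name, request_timeout, ...) or their aliases
+    # (model, timeout, ...).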
+    model_config = ConfigDict(
+        populate_by_name=True,
+    )
+
+    @model_validator(mode="after")
+    def validate_environment(self) -> Self:
+        """Validate that API key exists and initialize HTTP clients."""
+        if self.n < 1:
+            msg = "n must be at least 1."
+            raise ValueError(msg)
+        if self.n > 1 and self.streaming:
+            msg = "n must be 1 when streaming."
+            raise ValueError(msg)
+
+        # Ensure temperature is not exactly 0 (causes issues with some APIs)
+        if self.temperature == 0:
+            self.temperature = 1e-8
+
+        # Initialize HTTP clients
+        api_key = (
+            self.maritaca_api_key.get_secret_value() if self.maritaca_api_key else ""
+        )
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json",
+            "User-Agent": f"langchain-maritaca/{__version__}",
+        }
+
+        if not self.client:
+            self.client = httpx.Client(
+                base_url=self.maritaca_api_base,
+                headers=headers,
+                timeout=httpx.Timeout(self.request_timeout),
+            )
+
+        if not self.async_client:
+            self.async_client = httpx.AsyncClient(
+                base_url=self.maritaca_api_base,
+                headers=headers,
+                timeout=httpx.Timeout(self.request_timeout),
+            )
+
+        return self
+
+    @property
+    def lc_secrets(self) -> dict[str, str]:
+        """Mapping of secret environment variables."""
+        return {"maritaca_api_key": "MARITACA_API_KEY"}
+
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        """Return whether this model can be serialized by LangChain."""
+        return True
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of model."""
+        return "maritaca-chat"
+
+    def _get_ls_params(
+        self, stop: list[str] | None = None, **kwargs: Any
+    ) -> LangSmithParams:
+        """Get standard params for tracing."""
+        params = self._get_invocation_params(stop=stop, **kwargs)
+        ls_params = LangSmithParams(
+            ls_provider="maritaca",
+            ls_model_name=params.get("model", self.model_name),
+            ls_model_type="chat",
+            ls_temperature=params.get("temperature", self.temperature),
+        )
+        if ls_max_tokens := params.get("max_tokens", self.max_tokens):
+            ls_params["ls_max_tokens"] = ls_max_tokens
+        if ls_stop := stop or params.get("stop", None) or self.stop:
+            ls_params["ls_stop"] = ls_stop if isinstance(ls_stop, list) else [ls_stop]
+        return ls_params
+
+    @property
+    def _default_params(self) -> dict[str, Any]:
+        """Get the default parameters for calling Maritaca API."""
+        params: dict[str, Any] = {
+            "model": self.model_name,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "frequency_penalty": self.frequency_penalty,
+            "presence_penalty": self.presence_penalty,
+            "n": self.n,
+        }
+        if self.max_tokens is not None:
+            params["max_tokens"] = self.max_tokens
+        if self.stop is not None:
+            params["stop"] = self.stop
+        if self.tools is not None:
+            params["tools"] = self.tools
+        if self.tool_choice is not None:
+            params["tool_choice"] = self.tool_choice
+        return params
+
+    def bind_tools(
+        self,
+        tools: Sequence[dict[str, Any] | type | Callable[..., Any] | BaseTool],
+        *,
+        tool_choice: str | dict[str, Any] | None = None,
+        **kwargs: Any,
+    ) -> Runnable[Any, BaseMessage]:
+        """Bind tools to this chat model.
+
+        Args:
+            tools: A list of tools to bind. Can be:
+                - Dict with OpenAI tool schema
+                - Pydantic BaseModel class
+                - Python function with type hints
+                - LangChain BaseTool instance
+            tool_choice: Control which tool is called:
+                - "auto": Model decides (default)
+                - "required": Model must call a tool
+                - {"type": "function", "function": {"name": "..."}}:
+                  Force specific tool
+            **kwargs: Additional arguments passed to the model.
+
+        Returns:
+            A Runnable that will pass the tools to the model.
+
+        Example:
+            .. code-block:: python
+
+                from pydantic import BaseModel, Field
+
+
+                class GetWeather(BaseModel):
+                    '''Get the weather for a location.'''
+
+                    location: str = Field(description="City name")
+
+
+                model = ChatMaritaca()
+                model_with_tools = model.bind_tools([GetWeather])
+                response = model_with_tools.invoke("What's the weather in São Paulo?")
+        """
+        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
+        return self.bind(tools=formatted_tools, tool_choice=tool_choice, **kwargs)
+
+    def with_structured_output(
+        self,
+        schema: dict[str, Any] | type[BaseModel] | None = None,
+        *,
+        method: Literal["function_calling", "json_mode"] = "function_calling",
+        include_raw: bool = False,
+        **kwargs: Any,
+    ) -> Runnable[Any, dict[str, Any] | BaseModel]:
+        """Create a runnable that returns structured output matching a schema.
+
+        Uses the model's tool-calling or JSON mode capabilities to coerce the
+        output into the specified schema.
+
+        Args:
+            schema: The output schema. Can be:
+                - A Pydantic BaseModel class
+                - A dictionary with JSON Schema
+            method: The method to use for structured output:
+                - "function_calling": Uses tool calling (default, recommended)
+                - "json_mode": Uses JSON response format
+            include_raw: If True, returns a dict with keys:
+                - "raw": The raw model response (BaseMessage)
+                - "parsed": The parsed structured output
+                - "parsing_error": Any parsing error that occurred
+            **kwargs: Additional arguments passed to the model.
+
+        Returns:
+            A Runnable that outputs structured data matching the schema.
+
+        Example:
+            .. code-block:: python
+
+                from pydantic import BaseModel, Field
+
+
+                class Person(BaseModel):
+                    '''Information about a person.'''
+
+                    name: str = Field(description="Person's name")
+                    age: int = Field(description="Person's age")
+
+
+                model = ChatMaritaca()
+                structured_model = model.with_structured_output(Person)
+                result = structured_model.invoke("João tem 25 anos")
+                # Returns: Person(name="João", age=25)
+
+        Note:
+            The "function_calling" method is more reliable as it uses the
+            model's native tool calling capabilities.
+        """
+        if schema is None:
+            msg = "schema must be specified for with_structured_output"
+            raise ValueError(msg)
+
+        is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
+
+        if method == "function_calling":
+            formatted_tool = convert_to_openai_tool(schema)
+            tool_name = formatted_tool["function"]["name"]
+            llm = self.bind_tools(
+                [schema],
+                tool_choice={"type": "function", "function": {"name": tool_name}},
+                **kwargs,
+            )
+            if is_pydantic_schema:
+                # Type narrowing: we know schema is type[BaseModel] here
+                pydantic_schema: type[BaseModel] = schema  # type: ignore[assignment]
+                output_parser: Runnable[Any, Any] = PydanticToolsParser(
+                    tools=[pydantic_schema],
+                    first_tool_only=True,
+                )
+            else:
+                output_parser = JsonOutputKeyToolsParser(
+                    key_name=tool_name, first_tool_only=True
+                )
+
+        elif method == "json_mode":
+            llm = self.bind(
+                response_format={"type": "json_object"},
+                **kwargs,
+            )
+            if is_pydantic_schema:
+                # Type narrowing: we know schema is type[BaseModel] here
+                pydantic_schema = schema  # type: ignore[assignment]
+                output_parser = PydanticOutputParser(pydantic_object=pydantic_schema)
+            else:
+                output_parser = JsonOutputParser()
+
+        else:
+            msg = (
+                f"Unrecognized method argument. Expected 'function_calling' or "
+                f"'json_mode'. Received: '{method}'"
+            )
+            raise ValueError(msg)
+
+        if include_raw:
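+            # include_raw: run the model, then attempt to parse its output;
+            # if parsing fails, the fallback returns the exception under
+            # "parsing_error" instead of raising.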
+            parser_assign = RunnablePassthrough.assign(
+                parsed=itemgetter("raw") | output_parser,
+                parsing_error=lambda _: None,
+            )
+            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
+            parser_with_fallback = parser_assign.with_fallbacks(
+                [parser_none], exception_key="parsing_error"
+            )
+            return RunnableMap(raw=llm) | parser_with_fallback
+
+        return llm | output_parser
+
+    def _generate(
+        self,
+        messages: list[BaseMessage],
+        stop: list[str] | None = None,
+        run_manager: CallbackManagerForLLMRun | None = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Generate a chat completion."""
+        if self.streaming:
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return generate_from_stream(stream_iter)
+
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs}
+
+        response = self._make_request(message_dicts, params)
+        return self._create_chat_result(response)
+
+    async def _agenerate(
+        self,
+        messages: list[BaseMessage],
+        stop: list[str] | None = None,
+        run_manager: AsyncCallbackManagerForLLMRun | None = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Async generate a chat completion."""
+        if self.streaming:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs}
+
+        response = await self._amake_request(message_dicts, params)
+        return self._create_chat_result(response)
+
+    def _stream(
+        self,
+        messages: list[BaseMessage],
+        stop: list[str] | None = None,
+        run_manager: CallbackManagerForLLMRun | None = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        """Stream a chat completion."""
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs, "stream": True}
+
+        with self.client.stream(
+            "POST",
+            "/chat/completions",
+            json={"messages": message_dicts, **params},
+        ) as response:
+            response.raise_for_status()
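+            # The endpoint streams Server-Sent Events: each payload line is
+            # prefixed with "data: " and the stream ends with "data: [DONE]".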
+            for line in response.iter_lines():
+                if line.startswith("data: "):
+                    data = line[6:]
+                    if data == "[DONE]":
+                        break
+                    try:
+                        chunk = json.loads(data)
+                        if not chunk.get("choices"):
+                            continue
+                        choice = chunk["choices"][0]
+                        delta = choice.get("delta", {})
+                        content = delta.get("content", "")
+
+                        message_chunk = AIMessageChunk(content=content)
+                        generation_info = {}
+
+                        if finish_reason := choice.get("finish_reason"):
+                            generation_info["finish_reason"] = finish_reason
+                            generation_info["model"] = self.model_name
+
+                        if generation_info:
+                            message_chunk = message_chunk.model_copy(
+                                update={"response_metadata": generation_info}
+                            )
+
+                        generation_chunk = ChatGenerationChunk(
+                            message=message_chunk,
+                            generation_info=generation_info or None,
+                        )
+
+                        if run_manager:
+                            run_manager.on_llm_new_token(
+                                generation_chunk.text, chunk=generation_chunk
+                            )
+
+                        yield generation_chunk
+
+                    except json.JSONDecodeError:
+                        continue
+
+    async def _astream(
+        self,
+        messages: list[BaseMessage],
+        stop: list[str] | None = None,
+        run_manager: AsyncCallbackManagerForLLMRun | None = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        """Async stream a chat completion."""
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs, "stream": True}
+
+        async with self.async_client.stream(
+            "POST",
+            "/chat/completions",
+            json={"messages": message_dicts, **params},
+        ) as response:
+            response.raise_for_status()
+            async for line in response.aiter_lines():
+                if line.startswith("data: "):
+                    data = line[6:]
+                    if data == "[DONE]":
+                        break
+                    try:
+                        chunk = json.loads(data)
+                        if not chunk.get("choices"):
+                            continue
+                        choice = chunk["choices"][0]
+                        delta = choice.get("delta", {})
+                        content = delta.get("content", "")
+
+                        message_chunk = AIMessageChunk(content=content)
+                        generation_info = {}
+
+                        if finish_reason := choice.get("finish_reason"):
+                            generation_info["finish_reason"] = finish_reason
+                            generation_info["model"] = self.model_name
+
+                        if generation_info:
+                            message_chunk = message_chunk.model_copy(
+                                update={"response_metadata": generation_info}
+                            )
+
+                        generation_chunk = ChatGenerationChunk(
+                            message=message_chunk,
+                            generation_info=generation_info or None,
+                        )
+
+                        if run_manager:
+                            await run_manager.on_llm_new_token(
+                                token=generation_chunk.text, chunk=generation_chunk
+                            )
+
+                        yield generation_chunk
+
+                    except json.JSONDecodeError:
+                        continue
+
+    def _make_request(
+        self, messages: list[dict[str, Any]], params: dict[str, Any]
+    ) -> dict[str, Any]:
+        """Make a sync request to the Maritaca API."""
+        for attempt in range(self.max_retries + 1):
+            try:
+                response = self.client.post(
+                    "/chat/completions",
+                    json={"messages": messages, **params},
+                )
+                response.raise_for_status()
+                return response.json()
+            except httpx.HTTPStatusError as e:
+                is_rate_limited = e.response.status_code == HTTP_TOO_MANY_REQUESTS
+                if is_rate_limited and attempt < self.max_retries:
+                    # Assumes Retry-After carries a delay in seconds, the usual
+                    # form for rate-limit responses.
+                    retry_after = int(e.response.headers.get("Retry-After", 60))
+                    time.sleep(retry_after)
+                    continue
+                raise
+            except httpx.TimeoutException:
+                if attempt < self.max_retries:
+                    continue
+                raise
+        # Unreachable in practice: the final attempt either returns or raises.
+        msg = f"Failed after {self.max_retries + 1} attempts"
+        raise RuntimeError(msg)
+
+    async def _amake_request(
+        self, messages: list[dict[str, Any]], params: dict[str, Any]
+    ) -> dict[str, Any]:
+        """Make an async request to the Maritaca API."""
+        for attempt in range(self.max_retries + 1):
+            try:
+                response = await self.async_client.post(
+                    "/chat/completions",
+                    json={"messages": messages, **params},
+                )
+                response.raise_for_status()
+                return response.json()
+            except httpx.HTTPStatusError as e:
+                is_rate_limited = e.response.status_code == HTTP_TOO_MANY_REQUESTS
+                if is_rate_limited and attempt < self.max_retries:
+                    retry_after = int(e.response.headers.get("Retry-After", 60))
+                    await asyncio.sleep(retry_after)
+                    continue
+                raise
+            except httpx.TimeoutException:
+                if attempt < self.max_retries:
+                    continue
+                raise
+        # Unreachable in practice: the final attempt either returns or raises.
+        msg = f"Failed after {self.max_retries + 1} attempts"
+        raise RuntimeError(msg)
+
+    def _create_message_dicts(
+        self, messages: list[BaseMessage], stop: list[str] | None
+    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
+        """Convert LangChain messages to Maritaca format."""
+        params = self._default_params.copy()
+        if stop is not None:
+            params["stop"] = stop
+        message_dicts = [_convert_message_to_dict(m) for m in messages]
+        return message_dicts, params
+
+    def _create_chat_result(self, response: dict[str, Any]) -> ChatResult:
+        """Create a ChatResult from Maritaca API response."""
+        generations = []
+        token_usage = response.get("usage", {})
+
+        for choice in response.get("choices", []):
+            message = _convert_dict_to_message(choice.get("message", {}))
+
+            if token_usage and isinstance(message, AIMessage):
+                message.usage_metadata = _create_usage_metadata(token_usage)
+
+            generation_info = {"finish_reason": choice.get("finish_reason")}
+            gen = ChatGeneration(message=message, generation_info=generation_info)
+            generations.append(gen)
+
+        llm_output = {
+            "token_usage": token_usage,
+            "model": response.get("model", self.model_name),
+        }
+
+        return ChatResult(generations=generations, llm_output=llm_output)
+
+
+def _convert_message_to_dict(message: BaseMessage) -> dict[str, Any]:
+    """Convert a LangChain message to Maritaca format.
+
+    Args:
+        message: The LangChain message.
+
+    Returns:
+        Dictionary in Maritaca API format.
+    """
+    if isinstance(message, ChatMessage):
+        return {"role": message.role, "content": message.content}
+    if isinstance(message, HumanMessage):
+        return {"role": "user", "content": message.content}
+    if isinstance(message, AIMessage):
+        result: dict[str, Any] = {
+            "role": "assistant",
+            "content": message.content or "",
+        }
+        # Handle tool calls in AIMessage
+        if message.tool_calls:
+            result["tool_calls"] = [
+                {
+                    "id": tc["id"],
+                    "type": "function",
+                    "function": {
+                        "name": tc["name"],
+                        "arguments": json.dumps(tc["args"]),
+                    },
+                }
+                for tc in message.tool_calls
+            ]
+        return result
+    if isinstance(message, SystemMessage):
+        return {"role": "system", "content": message.content}
+    if isinstance(message, ToolMessage):
+        return {
+            "role": "tool",
+            "content": message.content,
+            "tool_call_id": message.tool_call_id,
+        }
+    msg = f"Got unknown message type: {type(message)}"
+    raise TypeError(msg)
+
+
+def _convert_dict_to_message(message_dict: Mapping[str, Any]) -> BaseMessage:
+    """Convert a Maritaca message dict to a LangChain message.
+
+    Args:
+        message_dict: Dictionary from a Maritaca API response.
+
+    Returns:
+        LangChain BaseMessage.
+    """
+    role = message_dict.get("role", "")
+    content = message_dict.get("content", "") or ""
+
+    if role == "user":
+        return HumanMessage(content=content)
+    if role == "assistant":
+        # Parse tool_calls if present
+        tool_calls_data = message_dict.get("tool_calls", [])
+        tool_calls = []
+        for tc in tool_calls_data:
+            func = tc.get("function", {})
+            args_str = func.get("arguments", "{}")
+            try:
+                args = json.loads(args_str)
+            except json.JSONDecodeError:
+                args = {}
+            tool_calls.append(
+                ToolCall(
+                    name=func.get("name", ""),
+                    args=args,
+                    id=tc.get("id", ""),
+                )
+            )
+        return AIMessage(content=content, tool_calls=tool_calls)
+    if role == "system":
+        return SystemMessage(content=content)
+    if role == "tool":
+        return ToolMessage(
+            content=content,
+            tool_call_id=message_dict.get("tool_call_id", ""),
+        )
+    return ChatMessage(content=content, role=role)
+
+
+def _create_usage_metadata(token_usage: dict[str, Any]) -> UsageMetadata:
+    """Create usage metadata from Maritaca token usage response.
+
+    Args:
+        token_usage: Token usage dict from Maritaca API response.
+
+    Returns:
+        UsageMetadata with token counts.
+    """
+    input_tokens = token_usage.get("prompt_tokens", 0)
+    output_tokens = token_usage.get("completion_tokens", 0)
+    total_tokens = token_usage.get("total_tokens", input_tokens + output_tokens)
+
+    return UsageMetadata(
+        input_tokens=input_tokens,
+        output_tokens=output_tokens,
+        total_tokens=total_tokens,
+    )
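
Below is a minimal, hypothetical end-to-end sketch of the tool-calling round trip this wrapper supports; it is not part of the package diff above. The `GetWeather` schema and the fake weather lookup are illustrative assumptions; the rest follows the module's docstrings and assumes a valid `MARITACA_API_KEY` is set.

```python
from langchain_core.messages import HumanMessage, ToolMessage
from pydantic import BaseModel, Field

from langchain_maritaca import ChatMaritaca


class GetWeather(BaseModel):
    """Get the weather for a location."""

    location: str = Field(description="City name")


model = ChatMaritaca(model="sabia-3.1")
model_with_tools = model.bind_tools([GetWeather])

messages = [HumanMessage("Qual é o clima em São Paulo?")]
ai_msg = model_with_tools.invoke(messages)
messages.append(ai_msg)

# Execute each requested tool call and feed the result back as a ToolMessage,
# so the model can compose its final answer.
for tool_call in ai_msg.tool_calls:
    weather = f"Ensolarado, 25°C em {tool_call['args']['location']}"  # fake lookup
    messages.append(ToolMessage(content=weather, tool_call_id=tool_call["id"]))

final = model_with_tools.invoke(messages)
print(final.content)
```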