chatlas 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatlas/__init__.py +2 -1
- chatlas/_anthropic.py +79 -45
- chatlas/_auto.py +3 -12
- chatlas/_chat.py +800 -169
- chatlas/_content.py +149 -29
- chatlas/_databricks.py +4 -14
- chatlas/_github.py +21 -25
- chatlas/_google.py +71 -32
- chatlas/_groq.py +15 -18
- chatlas/_interpolate.py +3 -4
- chatlas/_mcp_manager.py +306 -0
- chatlas/_ollama.py +14 -18
- chatlas/_openai.py +74 -39
- chatlas/_perplexity.py +14 -18
- chatlas/_provider.py +78 -8
- chatlas/_snowflake.py +29 -18
- chatlas/_tokens.py +93 -5
- chatlas/_tools.py +181 -22
- chatlas/_turn.py +2 -18
- chatlas/_utils.py +27 -1
- chatlas/_version.py +2 -2
- chatlas/data/prices.json +264 -0
- chatlas/types/anthropic/_submit.py +2 -0
- chatlas/types/openai/_client.py +1 -0
- chatlas/types/openai/_client_azure.py +1 -0
- chatlas/types/openai/_submit.py +4 -1
- chatlas-0.9.0.dist-info/METADATA +141 -0
- chatlas-0.9.0.dist-info/RECORD +48 -0
- chatlas-0.8.0.dist-info/METADATA +0 -383
- chatlas-0.8.0.dist-info/RECORD +0 -46
- {chatlas-0.8.0.dist-info → chatlas-0.9.0.dist-info}/WHEEL +0 -0
- {chatlas-0.8.0.dist-info → chatlas-0.9.0.dist-info}/licenses/LICENSE +0 -0
chatlas/_content.py
CHANGED
|
@@ -1,15 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import textwrap
|
|
4
3
|
from pprint import pformat
|
|
5
|
-
from typing import
|
|
4
|
+
from typing import Any, Literal, Optional, Union
|
|
6
5
|
|
|
7
6
|
import orjson
|
|
8
7
|
from pydantic import BaseModel, ConfigDict
|
|
9
8
|
|
|
10
|
-
if TYPE_CHECKING:
|
|
11
|
-
from htmltools import TagChild
|
|
12
|
-
|
|
13
9
|
ImageContentTypes = Literal[
|
|
14
10
|
"image/png",
|
|
15
11
|
"image/jpeg",
|
|
@@ -26,6 +22,8 @@ ContentTypeEnum = Literal[
|
|
|
26
22
|
"image_inline",
|
|
27
23
|
"tool_request",
|
|
28
24
|
"tool_result",
|
|
25
|
+
"tool_result_image",
|
|
26
|
+
"tool_result_resource",
|
|
29
27
|
"json",
|
|
30
28
|
"pdf",
|
|
31
29
|
]
|
|
@@ -202,7 +200,10 @@ class ContentToolRequest(Content):
|
|
|
202
200
|
return ", ".join(f"{k}={v}" for k, v in self.arguments.items())
|
|
203
201
|
return str(self.arguments)
|
|
204
202
|
|
|
205
|
-
def
|
|
203
|
+
def __repr_html__(self) -> str:
|
|
204
|
+
return str(self.tagify())
|
|
205
|
+
|
|
206
|
+
def tagify(self):
|
|
206
207
|
"Returns an HTML string suitable for passing to htmltools/shiny's `Chat()` component."
|
|
207
208
|
try:
|
|
208
209
|
from htmltools import HTML, TagList, head_content, tags
|
|
@@ -314,7 +315,7 @@ class ContentToolResult(Content):
|
|
|
314
315
|
return res + ">"
|
|
315
316
|
|
|
316
317
|
# Format the value for display purposes
|
|
317
|
-
def _get_display_value(self)
|
|
318
|
+
def _get_display_value(self):
|
|
318
319
|
if self.error:
|
|
319
320
|
return f"Tool call failed with error: '{self.error}'"
|
|
320
321
|
|
|
@@ -333,7 +334,7 @@ class ContentToolResult(Content):
|
|
|
333
334
|
# Not valid JSON, return as string
|
|
334
335
|
return val
|
|
335
336
|
|
|
336
|
-
return val
|
|
337
|
+
return str(val)
|
|
337
338
|
|
|
338
339
|
def get_model_value(self) -> object:
|
|
339
340
|
"Get the actual value sent to the model."
|
|
@@ -370,37 +371,62 @@ class ContentToolResult(Content):
|
|
|
370
371
|
|
|
371
372
|
return orjson.dumps(value).decode("utf-8")
|
|
372
373
|
|
|
373
|
-
def
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
374
|
+
def __repr_html__(self):
|
|
375
|
+
return str(self.tagify())
|
|
376
|
+
|
|
377
|
+
def tagify(self):
|
|
378
|
+
"A method for rendering this object via htmltools/shiny."
|
|
377
379
|
try:
|
|
378
|
-
from htmltools import HTML
|
|
380
|
+
from htmltools import HTML, html_escape
|
|
379
381
|
except ImportError:
|
|
380
382
|
raise ImportError(
|
|
381
383
|
".tagify() is only intended to be called by htmltools/shiny, ",
|
|
382
384
|
"but htmltools is not installed. ",
|
|
383
385
|
)
|
|
384
386
|
|
|
387
|
+
# Helper function to format code blocks (optionally with labels for arguments).
|
|
388
|
+
def pre_code(code: str, label: str | None = None) -> str:
|
|
389
|
+
lbl = f"<span class='input-parameter-label'>{label}</span>" if label else ""
|
|
390
|
+
return f"<pre>{lbl}<code>{html_escape(code)}</code></pre>"
|
|
391
|
+
|
|
392
|
+
# Helper function to wrap content in a <details> block.
|
|
393
|
+
def details_block(summary: str, content: str, open_: bool = True) -> str:
|
|
394
|
+
open_attr = " open" if open_ else ""
|
|
395
|
+
return (
|
|
396
|
+
f"<details{open_attr}><summary>{summary}</summary>{content}</details>"
|
|
397
|
+
)
|
|
398
|
+
|
|
399
|
+
# First, format the input parameters.
|
|
400
|
+
args = self.arguments or {}
|
|
401
|
+
if isinstance(args, dict):
|
|
402
|
+
args = "".join(pre_code(str(v), label=k) for k, v in args.items())
|
|
403
|
+
else:
|
|
404
|
+
args = pre_code(str(args))
|
|
405
|
+
|
|
406
|
+
# Wrap the input parameters in an (open) details block.
|
|
407
|
+
if args:
|
|
408
|
+
params = details_block("<strong>Input parameters:</strong>", args)
|
|
409
|
+
else:
|
|
410
|
+
params = ""
|
|
411
|
+
|
|
412
|
+
# Also wrap the tool result in an (open) details block.
|
|
413
|
+
result = details_block(
|
|
414
|
+
"<strong>Result:</strong>",
|
|
415
|
+
pre_code(self._get_display_value()),
|
|
416
|
+
)
|
|
417
|
+
|
|
418
|
+
# Put both the result and parameters into a container
|
|
419
|
+
result_div = f'<div class="chatlas-tool-result-content">{result}{params}</div>'
|
|
420
|
+
|
|
421
|
+
# Header for the top-level result details block.
|
|
385
422
|
if not self.error:
|
|
386
|
-
header = f"
|
|
423
|
+
header = f"Result from tool call: <code>{self.name}</code>"
|
|
387
424
|
else:
|
|
388
425
|
header = f"❌ Failed to call tool <code>{self.name}</code>"
|
|
389
426
|
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
return HTML(
|
|
394
|
-
textwrap.dedent(f"""
|
|
395
|
-
<details class="chatlas-tool-result">
|
|
396
|
-
<summary>{header}</summary>
|
|
397
|
-
<div class="chatlas-tool-result-content">
|
|
398
|
-
Result: <p><code>{content}</code></p>
|
|
399
|
-
Arguments: <p><code>{args}</code></p>
|
|
400
|
-
</div>
|
|
401
|
-
</details>
|
|
402
|
-
""")
|
|
403
|
-
)
|
|
427
|
+
res = details_block(header, result_div, open_=False)
|
|
428
|
+
|
|
429
|
+
return HTML(f'<div class="chatlas-tool-result">{res}</div>')
|
|
404
430
|
|
|
405
431
|
def _arguments_str(self) -> str:
|
|
406
432
|
if isinstance(self.arguments, dict):
|
|
@@ -408,6 +434,68 @@ class ContentToolResult(Content):
|
|
|
408
434
|
return str(self.arguments)
|
|
409
435
|
|
|
410
436
|
|
|
437
|
+
class ContentToolResultImage(ContentToolResult):
|
|
438
|
+
"""
|
|
439
|
+
A tool result that contains an image.
|
|
440
|
+
|
|
441
|
+
This is a specialized version of ContentToolResult for returning images.
|
|
442
|
+
It requires the image data to be base64-encoded (as `value`) and
|
|
443
|
+
the MIME type of the image (as `mime_type`).
|
|
444
|
+
|
|
445
|
+
Parameters
|
|
446
|
+
----------
|
|
447
|
+
value
|
|
448
|
+
The image data as a base64-encoded string.
|
|
449
|
+
mime_type
|
|
450
|
+
The MIME type of the image (e.g., "image/png").
|
|
451
|
+
"""
|
|
452
|
+
|
|
453
|
+
value: str
|
|
454
|
+
model_format: Literal["auto", "json", "str", "as_is"] = "as_is"
|
|
455
|
+
mime_type: ImageContentTypes
|
|
456
|
+
|
|
457
|
+
content_type: ContentTypeEnum = "tool_result_image"
|
|
458
|
+
|
|
459
|
+
def __str__(self):
|
|
460
|
+
return f"<ContentToolResultImage mime_type='{self.mime_type}'>"
|
|
461
|
+
|
|
462
|
+
def _repr_markdown_(self):
|
|
463
|
+
return f""
|
|
464
|
+
|
|
465
|
+
|
|
466
|
+
class ContentToolResultResource(ContentToolResult):
|
|
467
|
+
"""
|
|
468
|
+
A tool result that contains a resource.
|
|
469
|
+
|
|
470
|
+
This is a specialized version of ContentToolResult for returning resources
|
|
471
|
+
(e.g., images, files) as raw bytes. It requires the resource data to be
|
|
472
|
+
provided as bytes (as `value`) and the MIME type of the resource (as
|
|
473
|
+
`mime_type`).
|
|
474
|
+
|
|
475
|
+
Parameters
|
|
476
|
+
----------
|
|
477
|
+
value
|
|
478
|
+
The resource data, in bytes.
|
|
479
|
+
mime_type
|
|
480
|
+
The MIME type of the image (e.g., "image/png").
|
|
481
|
+
"""
|
|
482
|
+
|
|
483
|
+
value: bytes
|
|
484
|
+
model_format: Literal["auto", "json", "str", "as_is"] = "as_is"
|
|
485
|
+
mime_type: Optional[str]
|
|
486
|
+
|
|
487
|
+
content_type: ContentTypeEnum = "tool_result_resource"
|
|
488
|
+
|
|
489
|
+
def __str__(self):
|
|
490
|
+
return f"<ContentToolResultResource mime_type='{self.mime_type}'>"
|
|
491
|
+
|
|
492
|
+
def _repr_mimebundle_(self, include=None, exclude=None):
|
|
493
|
+
return {
|
|
494
|
+
self.mime_type: self.value,
|
|
495
|
+
"text/plain": f"<{self.mime_type} object>",
|
|
496
|
+
}
|
|
497
|
+
|
|
498
|
+
|
|
411
499
|
class ContentJson(Content):
|
|
412
500
|
"""
|
|
413
501
|
JSON content
|
|
@@ -468,6 +556,8 @@ ContentUnion = Union[
|
|
|
468
556
|
ContentImageInline,
|
|
469
557
|
ContentToolRequest,
|
|
470
558
|
ContentToolResult,
|
|
559
|
+
ContentToolResultImage,
|
|
560
|
+
ContentToolResultResource,
|
|
471
561
|
ContentJson,
|
|
472
562
|
ContentPDF,
|
|
473
563
|
]
|
|
@@ -494,6 +584,10 @@ def create_content(data: dict[str, Any]) -> ContentUnion:
|
|
|
494
584
|
return ContentToolRequest.model_validate(data)
|
|
495
585
|
elif ct == "tool_result":
|
|
496
586
|
return ContentToolResult.model_validate(data)
|
|
587
|
+
elif ct == "tool_result_image":
|
|
588
|
+
return ContentToolResultImage.model_validate(data)
|
|
589
|
+
elif ct == "tool_result_resource":
|
|
590
|
+
return ContentToolResultResource.model_validate(data)
|
|
497
591
|
elif ct == "json":
|
|
498
592
|
return ContentJson.model_validate(data)
|
|
499
593
|
elif ct == "pdf":
|
|
@@ -536,11 +630,12 @@ TOOL_CSS = """
|
|
|
536
630
|
vertical-align: middle;
|
|
537
631
|
}
|
|
538
632
|
|
|
539
|
-
.chatlas-tool-result[open] summary::after {
|
|
633
|
+
.chatlas-tool-result details[open] summary::after {
|
|
540
634
|
content: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='16' height='16' fill='currentColor' class='bi bi-caret-down-fill' viewBox='0 0 16 16'%3E%3Cpath d='M7.247 11.14 2.451 5.658C1.885 5.013 2.345 4 3.204 4h9.592a1 1 0 0 1 .753 1.659l-4.796 5.48a1 1 0 0 1-1.506 0z'/%3E%3C/svg%3E");
|
|
541
635
|
}
|
|
542
636
|
|
|
543
637
|
.chatlas-tool-result-content {
|
|
638
|
+
position: relative;
|
|
544
639
|
border: 1px solid var(--bs-border-color, #0066cc);
|
|
545
640
|
width: 100%;
|
|
546
641
|
padding: 1rem;
|
|
@@ -548,4 +643,29 @@ TOOL_CSS = """
|
|
|
548
643
|
margin-top: 1rem;
|
|
549
644
|
margin-bottom: 1rem;
|
|
550
645
|
}
|
|
646
|
+
|
|
647
|
+
.chatlas-tool-result-content pre, .chatlas-tool-result-content code {
|
|
648
|
+
background-color: var(--bs-body-bg, white) !important;
|
|
649
|
+
}
|
|
650
|
+
|
|
651
|
+
.chatlas-tool-result-content .input-parameter-label {
|
|
652
|
+
position: absolute;
|
|
653
|
+
top: 0;
|
|
654
|
+
width: 100%;
|
|
655
|
+
text-align: center;
|
|
656
|
+
font-weight: 300;
|
|
657
|
+
font-size: 0.8rem;
|
|
658
|
+
color: var(--bs-gray-600);
|
|
659
|
+
background-color: var(--bs-body-bg);
|
|
660
|
+
padding: 0.5rem;
|
|
661
|
+
font-family: var(--bs-font-monospace, monospace);
|
|
662
|
+
}
|
|
663
|
+
|
|
664
|
+
pre:has(> .input-parameter-label) {
|
|
665
|
+
padding-top: 1.5rem;
|
|
666
|
+
}
|
|
667
|
+
|
|
668
|
+
shiny-markdown-stream p:first-of-type:empty {
|
|
669
|
+
display: none;
|
|
670
|
+
}
|
|
551
671
|
"""
|
chatlas/_databricks.py
CHANGED
|
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Optional
|
|
|
5
5
|
from ._chat import Chat
|
|
6
6
|
from ._logging import log_model_default
|
|
7
7
|
from ._openai import OpenAIProvider
|
|
8
|
-
from ._turn import Turn, normalize_turns
|
|
9
8
|
|
|
10
9
|
if TYPE_CHECKING:
|
|
11
10
|
from databricks.sdk import WorkspaceClient
|
|
@@ -18,7 +17,6 @@ def ChatDatabricks(
|
|
|
18
17
|
*,
|
|
19
18
|
system_prompt: Optional[str] = None,
|
|
20
19
|
model: Optional[str] = None,
|
|
21
|
-
turns: Optional[list[Turn]] = None,
|
|
22
20
|
workspace_client: Optional["WorkspaceClient"] = None,
|
|
23
21
|
) -> Chat["SubmitInputArgs", ChatCompletion]:
|
|
24
22
|
"""
|
|
@@ -68,13 +66,6 @@ def ChatDatabricks(
|
|
|
68
66
|
The model to use for the chat. The default, None, will pick a reasonable
|
|
69
67
|
default, and warn you about it. We strongly recommend explicitly
|
|
70
68
|
choosing a model for all but the most casual use.
|
|
71
|
-
turns
|
|
72
|
-
A list of turns to start the chat with (i.e., continuing a previous
|
|
73
|
-
conversation). If not provided, the conversation begins from scratch. Do
|
|
74
|
-
not provide non-`None` values for both `turns` and `system_prompt`. Each
|
|
75
|
-
message in the list should be a dictionary with at least `role` (usually
|
|
76
|
-
`system`, `user`, or `assistant`, but `tool` is also possible). Normally
|
|
77
|
-
there is also a `content` field, which is a string.
|
|
78
69
|
workspace_client
|
|
79
70
|
A `databricks.sdk.WorkspaceClient()` to use for the connection. If not
|
|
80
71
|
provided, a new client will be created.
|
|
@@ -92,10 +83,7 @@ def ChatDatabricks(
|
|
|
92
83
|
model=model,
|
|
93
84
|
workspace_client=workspace_client,
|
|
94
85
|
),
|
|
95
|
-
|
|
96
|
-
turns or [],
|
|
97
|
-
system_prompt,
|
|
98
|
-
),
|
|
86
|
+
system_prompt=system_prompt,
|
|
99
87
|
)
|
|
100
88
|
|
|
101
89
|
|
|
@@ -104,6 +92,7 @@ class DatabricksProvider(OpenAIProvider):
|
|
|
104
92
|
self,
|
|
105
93
|
*,
|
|
106
94
|
model: str,
|
|
95
|
+
name: str = "Databricks",
|
|
107
96
|
workspace_client: Optional["WorkspaceClient"] = None,
|
|
108
97
|
):
|
|
109
98
|
try:
|
|
@@ -117,7 +106,8 @@ class DatabricksProvider(OpenAIProvider):
|
|
|
117
106
|
import httpx
|
|
118
107
|
from openai import AsyncOpenAI
|
|
119
108
|
|
|
120
|
-
|
|
109
|
+
super().__init__(name=name, model=model)
|
|
110
|
+
|
|
121
111
|
self._seed = None
|
|
122
112
|
|
|
123
113
|
if workspace_client is None:
|
chatlas/_github.py
CHANGED
|
@@ -5,9 +5,8 @@ from typing import TYPE_CHECKING, Optional
|
|
|
5
5
|
|
|
6
6
|
from ._chat import Chat
|
|
7
7
|
from ._logging import log_model_default
|
|
8
|
-
from ._openai import
|
|
9
|
-
from .
|
|
10
|
-
from ._utils import MISSING, MISSING_TYPE
|
|
8
|
+
from ._openai import OpenAIProvider
|
|
9
|
+
from ._utils import MISSING, MISSING_TYPE, is_testing
|
|
11
10
|
|
|
12
11
|
if TYPE_CHECKING:
|
|
13
12
|
from ._openai import ChatCompletion
|
|
@@ -17,7 +16,6 @@ if TYPE_CHECKING:
|
|
|
17
16
|
def ChatGithub(
|
|
18
17
|
*,
|
|
19
18
|
system_prompt: Optional[str] = None,
|
|
20
|
-
turns: Optional[list[Turn]] = None,
|
|
21
19
|
model: Optional[str] = None,
|
|
22
20
|
api_key: Optional[str] = None,
|
|
23
21
|
base_url: str = "https://models.inference.ai.azure.com/",
|
|
@@ -48,7 +46,7 @@ def ChatGithub(
|
|
|
48
46
|
import os
|
|
49
47
|
from chatlas import ChatGithub
|
|
50
48
|
|
|
51
|
-
chat = ChatGithub(api_key=os.getenv("
|
|
49
|
+
chat = ChatGithub(api_key=os.getenv("GITHUB_TOKEN"))
|
|
52
50
|
chat.chat("What is the capital of France?")
|
|
53
51
|
```
|
|
54
52
|
|
|
@@ -56,20 +54,13 @@ def ChatGithub(
|
|
|
56
54
|
----------
|
|
57
55
|
system_prompt
|
|
58
56
|
A system prompt to set the behavior of the assistant.
|
|
59
|
-
turns
|
|
60
|
-
A list of turns to start the chat with (i.e., continuing a previous
|
|
61
|
-
conversation). If not provided, the conversation begins from scratch. Do
|
|
62
|
-
not provide non-`None` values for both `turns` and `system_prompt`. Each
|
|
63
|
-
message in the list should be a dictionary with at least `role` (usually
|
|
64
|
-
`system`, `user`, or `assistant`, but `tool` is also possible). Normally
|
|
65
|
-
there is also a `content` field, which is a string.
|
|
66
57
|
model
|
|
67
58
|
The model to use for the chat. The default, None, will pick a reasonable
|
|
68
59
|
default, and warn you about it. We strongly recommend explicitly
|
|
69
60
|
choosing a model for all but the most casual use.
|
|
70
61
|
api_key
|
|
71
62
|
The API key to use for authentication. You generally should not supply
|
|
72
|
-
this directly, but instead set the `
|
|
63
|
+
this directly, but instead set the `GITHUB_TOKEN` environment variable.
|
|
73
64
|
base_url
|
|
74
65
|
The base URL to the endpoint; the default uses Github's API.
|
|
75
66
|
seed
|
|
@@ -106,7 +97,7 @@ def ChatGithub(
|
|
|
106
97
|
|
|
107
98
|
```shell
|
|
108
99
|
# .env
|
|
109
|
-
|
|
100
|
+
GITHUB_TOKEN=...
|
|
110
101
|
```
|
|
111
102
|
|
|
112
103
|
```python
|
|
@@ -122,20 +113,25 @@ def ChatGithub(
|
|
|
122
113
|
before starting Python (maybe in a `.bashrc`, `.zshrc`, etc. file):
|
|
123
114
|
|
|
124
115
|
```shell
|
|
125
|
-
export
|
|
116
|
+
export GITHUB_TOKEN=...
|
|
126
117
|
```
|
|
127
118
|
"""
|
|
128
119
|
if model is None:
|
|
129
|
-
model = log_model_default("gpt-
|
|
120
|
+
model = log_model_default("gpt-4.1")
|
|
130
121
|
if api_key is None:
|
|
131
|
-
api_key = os.getenv("GITHUB_PAT")
|
|
132
|
-
|
|
133
|
-
|
|
122
|
+
api_key = os.getenv("GITHUB_TOKEN", os.getenv("GITHUB_PAT"))
|
|
123
|
+
|
|
124
|
+
if isinstance(seed, MISSING_TYPE):
|
|
125
|
+
seed = 1014 if is_testing() else None
|
|
126
|
+
|
|
127
|
+
return Chat(
|
|
128
|
+
provider=OpenAIProvider(
|
|
129
|
+
api_key=api_key,
|
|
130
|
+
model=model,
|
|
131
|
+
base_url=base_url,
|
|
132
|
+
seed=seed,
|
|
133
|
+
name="GitHub",
|
|
134
|
+
kwargs=kwargs,
|
|
135
|
+
),
|
|
134
136
|
system_prompt=system_prompt,
|
|
135
|
-
turns=turns,
|
|
136
|
-
model=model,
|
|
137
|
-
api_key=api_key,
|
|
138
|
-
base_url=base_url,
|
|
139
|
-
seed=seed,
|
|
140
|
-
kwargs=kwargs,
|
|
141
137
|
)
|
chatlas/_google.py
CHANGED
|
@@ -16,17 +16,20 @@ from ._content import (
|
|
|
16
16
|
ContentText,
|
|
17
17
|
ContentToolRequest,
|
|
18
18
|
ContentToolResult,
|
|
19
|
+
ContentToolResultImage,
|
|
20
|
+
ContentToolResultResource,
|
|
19
21
|
)
|
|
20
22
|
from ._logging import log_model_default
|
|
21
23
|
from ._merge import merge_dicts
|
|
22
|
-
from ._provider import Provider
|
|
24
|
+
from ._provider import Provider, StandardModelParamNames, StandardModelParams
|
|
23
25
|
from ._tokens import tokens_log
|
|
24
26
|
from ._tools import Tool
|
|
25
|
-
from ._turn import Turn,
|
|
27
|
+
from ._turn import Turn, user_turn
|
|
26
28
|
|
|
27
29
|
if TYPE_CHECKING:
|
|
28
30
|
from google.genai.types import Content as GoogleContent
|
|
29
31
|
from google.genai.types import (
|
|
32
|
+
GenerateContentConfigDict,
|
|
30
33
|
GenerateContentResponse,
|
|
31
34
|
GenerateContentResponseDict,
|
|
32
35
|
Part,
|
|
@@ -41,7 +44,6 @@ else:
|
|
|
41
44
|
def ChatGoogle(
|
|
42
45
|
*,
|
|
43
46
|
system_prompt: Optional[str] = None,
|
|
44
|
-
turns: Optional[list[Turn]] = None,
|
|
45
47
|
model: Optional[str] = None,
|
|
46
48
|
api_key: Optional[str] = None,
|
|
47
49
|
kwargs: Optional["ChatClientArgs"] = None,
|
|
@@ -80,13 +82,6 @@ def ChatGoogle(
|
|
|
80
82
|
----------
|
|
81
83
|
system_prompt
|
|
82
84
|
A system prompt to set the behavior of the assistant.
|
|
83
|
-
turns
|
|
84
|
-
A list of turns to start the chat with (i.e., continuing a previous
|
|
85
|
-
conversation). If not provided, the conversation begins from scratch.
|
|
86
|
-
Do not provide non-`None` values for both `turns` and `system_prompt`.
|
|
87
|
-
Each message in the list should be a dictionary with at least `role`
|
|
88
|
-
(usually `system`, `user`, or `assistant`, but `tool` is also possible).
|
|
89
|
-
Normally there is also a `content` field, which is a string.
|
|
90
85
|
model
|
|
91
86
|
The model to use for the chat. The default, None, will pick a reasonable
|
|
92
87
|
default, and warn you about it. We strongly recommend explicitly choosing
|
|
@@ -140,24 +135,25 @@ def ChatGoogle(
|
|
|
140
135
|
"""
|
|
141
136
|
|
|
142
137
|
if model is None:
|
|
143
|
-
model = log_model_default("gemini-2.
|
|
138
|
+
model = log_model_default("gemini-2.5-flash")
|
|
144
139
|
|
|
145
140
|
return Chat(
|
|
146
141
|
provider=GoogleProvider(
|
|
147
142
|
model=model,
|
|
148
143
|
api_key=api_key,
|
|
144
|
+
name="Google/Gemini",
|
|
149
145
|
kwargs=kwargs,
|
|
150
146
|
),
|
|
151
|
-
|
|
152
|
-
turns or [],
|
|
153
|
-
system_prompt=system_prompt,
|
|
154
|
-
),
|
|
147
|
+
system_prompt=system_prompt,
|
|
155
148
|
)
|
|
156
149
|
|
|
157
150
|
|
|
158
151
|
class GoogleProvider(
|
|
159
152
|
Provider[
|
|
160
|
-
GenerateContentResponse,
|
|
153
|
+
GenerateContentResponse,
|
|
154
|
+
GenerateContentResponse,
|
|
155
|
+
"GenerateContentResponseDict",
|
|
156
|
+
"SubmitInputArgs",
|
|
161
157
|
]
|
|
162
158
|
):
|
|
163
159
|
def __init__(
|
|
@@ -165,6 +161,7 @@ class GoogleProvider(
|
|
|
165
161
|
*,
|
|
166
162
|
model: str,
|
|
167
163
|
api_key: str | None,
|
|
164
|
+
name: str = "Google/Gemini",
|
|
168
165
|
kwargs: Optional["ChatClientArgs"],
|
|
169
166
|
):
|
|
170
167
|
try:
|
|
@@ -174,8 +171,7 @@ class GoogleProvider(
|
|
|
174
171
|
f"The {self.__class__.__name__} class requires the `google-genai` package. "
|
|
175
172
|
"Install it with `pip install google-genai`."
|
|
176
173
|
)
|
|
177
|
-
|
|
178
|
-
self._model = model
|
|
174
|
+
super().__init__(name=name, model=model)
|
|
179
175
|
|
|
180
176
|
kwargs_full: "ChatClientArgs" = {
|
|
181
177
|
"api_key": api_key,
|
|
@@ -267,7 +263,7 @@ class GoogleProvider(
|
|
|
267
263
|
from google.genai.types import Tool as GoogleTool
|
|
268
264
|
|
|
269
265
|
kwargs_full: "SubmitInputArgs" = {
|
|
270
|
-
"model": self.
|
|
266
|
+
"model": self.model,
|
|
271
267
|
"contents": cast("GoogleContent", self._google_contents(turns)),
|
|
272
268
|
**(kwargs or {}),
|
|
273
269
|
}
|
|
@@ -430,6 +426,13 @@ class GoogleProvider(
|
|
|
430
426
|
)
|
|
431
427
|
)
|
|
432
428
|
elif isinstance(content, ContentToolResult):
|
|
429
|
+
if isinstance(
|
|
430
|
+
content, (ContentToolResultImage, ContentToolResultResource)
|
|
431
|
+
):
|
|
432
|
+
raise NotImplementedError(
|
|
433
|
+
"Tool results with images or resources aren't supported by Google (Gemini). "
|
|
434
|
+
)
|
|
435
|
+
|
|
433
436
|
if content.error:
|
|
434
437
|
resp = {"error": content.error}
|
|
435
438
|
else:
|
|
@@ -524,6 +527,52 @@ class GoogleProvider(
|
|
|
524
527
|
completion=message,
|
|
525
528
|
)
|
|
526
529
|
|
|
530
|
+
def translate_model_params(self, params: StandardModelParams) -> "SubmitInputArgs":
|
|
531
|
+
config: "GenerateContentConfigDict" = {}
|
|
532
|
+
if "temperature" in params:
|
|
533
|
+
config["temperature"] = params["temperature"]
|
|
534
|
+
|
|
535
|
+
if "top_p" in params:
|
|
536
|
+
config["top_p"] = params["top_p"]
|
|
537
|
+
|
|
538
|
+
if "top_k" in params:
|
|
539
|
+
config["top_k"] = params["top_k"]
|
|
540
|
+
|
|
541
|
+
if "frequency_penalty" in params:
|
|
542
|
+
config["frequency_penalty"] = params["frequency_penalty"]
|
|
543
|
+
|
|
544
|
+
if "presence_penalty" in params:
|
|
545
|
+
config["presence_penalty"] = params["presence_penalty"]
|
|
546
|
+
|
|
547
|
+
if "seed" in params:
|
|
548
|
+
config["seed"] = params["seed"]
|
|
549
|
+
|
|
550
|
+
if "max_tokens" in params:
|
|
551
|
+
config["max_output_tokens"] = params["max_tokens"]
|
|
552
|
+
|
|
553
|
+
if "log_probs" in params:
|
|
554
|
+
config["logprobs"] = params["log_probs"]
|
|
555
|
+
|
|
556
|
+
if "stop_sequences" in params:
|
|
557
|
+
config["stop_sequences"] = params["stop_sequences"]
|
|
558
|
+
|
|
559
|
+
res: "SubmitInputArgs" = {"config": config}
|
|
560
|
+
|
|
561
|
+
return res
|
|
562
|
+
|
|
563
|
+
def supported_model_params(self) -> set[StandardModelParamNames]:
|
|
564
|
+
return {
|
|
565
|
+
"temperature",
|
|
566
|
+
"top_p",
|
|
567
|
+
"top_k",
|
|
568
|
+
"frequency_penalty",
|
|
569
|
+
"presence_penalty",
|
|
570
|
+
"seed",
|
|
571
|
+
"max_tokens",
|
|
572
|
+
"log_probs",
|
|
573
|
+
"stop_sequences",
|
|
574
|
+
}
|
|
575
|
+
|
|
527
576
|
|
|
528
577
|
def ChatVertex(
|
|
529
578
|
*,
|
|
@@ -532,7 +581,6 @@ def ChatVertex(
|
|
|
532
581
|
location: Optional[str] = None,
|
|
533
582
|
api_key: Optional[str] = None,
|
|
534
583
|
system_prompt: Optional[str] = None,
|
|
535
|
-
turns: Optional[list[Turn]] = None,
|
|
536
584
|
kwargs: Optional["ChatClientArgs"] = None,
|
|
537
585
|
) -> Chat["SubmitInputArgs", GenerateContentResponse]:
|
|
538
586
|
"""
|
|
@@ -569,13 +617,6 @@ def ChatVertex(
|
|
|
569
617
|
GOOGLE_CLOUD_LOCATION environment variable will be used.
|
|
570
618
|
system_prompt
|
|
571
619
|
A system prompt to set the behavior of the assistant.
|
|
572
|
-
turns
|
|
573
|
-
A list of turns to start the chat with (i.e., continuing a previous
|
|
574
|
-
conversation). If not provided, the conversation begins from scratch.
|
|
575
|
-
Do not provide non-`None` values for both `turns` and `system_prompt`.
|
|
576
|
-
Each message in the list should be a dictionary with at least `role`
|
|
577
|
-
(usually `system`, `user`, or `assistant`, but `tool` is also possible).
|
|
578
|
-
Normally there is also a `content` field, which is a string.
|
|
579
620
|
|
|
580
621
|
Returns
|
|
581
622
|
-------
|
|
@@ -605,16 +646,14 @@ def ChatVertex(
|
|
|
605
646
|
kwargs["location"] = location
|
|
606
647
|
|
|
607
648
|
if model is None:
|
|
608
|
-
model = log_model_default("gemini-2.
|
|
649
|
+
model = log_model_default("gemini-2.5-flash")
|
|
609
650
|
|
|
610
651
|
return Chat(
|
|
611
652
|
provider=GoogleProvider(
|
|
612
653
|
model=model,
|
|
613
654
|
api_key=api_key,
|
|
655
|
+
name="Google/Vertex",
|
|
614
656
|
kwargs=kwargs,
|
|
615
657
|
),
|
|
616
|
-
|
|
617
|
-
turns or [],
|
|
618
|
-
system_prompt=system_prompt,
|
|
619
|
-
),
|
|
658
|
+
system_prompt=system_prompt,
|
|
620
659
|
)
|
chatlas/_groq.py
CHANGED
|
@@ -5,9 +5,8 @@ from typing import TYPE_CHECKING, Optional
|
|
|
5
5
|
|
|
6
6
|
from ._chat import Chat
|
|
7
7
|
from ._logging import log_model_default
|
|
8
|
-
from ._openai import
|
|
9
|
-
from .
|
|
10
|
-
from ._utils import MISSING, MISSING_TYPE
|
|
8
|
+
from ._openai import OpenAIProvider
|
|
9
|
+
from ._utils import MISSING, MISSING_TYPE, is_testing
|
|
11
10
|
|
|
12
11
|
if TYPE_CHECKING:
|
|
13
12
|
from ._openai import ChatCompletion
|
|
@@ -17,7 +16,6 @@ if TYPE_CHECKING:
|
|
|
17
16
|
def ChatGroq(
|
|
18
17
|
*,
|
|
19
18
|
system_prompt: Optional[str] = None,
|
|
20
|
-
turns: Optional[list[Turn]] = None,
|
|
21
19
|
model: Optional[str] = None,
|
|
22
20
|
api_key: Optional[str] = None,
|
|
23
21
|
base_url: str = "https://api.groq.com/openai/v1",
|
|
@@ -53,13 +51,6 @@ def ChatGroq(
|
|
|
53
51
|
----------
|
|
54
52
|
system_prompt
|
|
55
53
|
A system prompt to set the behavior of the assistant.
|
|
56
|
-
turns
|
|
57
|
-
A list of turns to start the chat with (i.e., continuing a previous
|
|
58
|
-
conversation). If not provided, the conversation begins from scratch.
|
|
59
|
-
Do not provide non-`None` values for both `turns` and `system_prompt`.
|
|
60
|
-
Each message in the list should be a dictionary with at least `role`
|
|
61
|
-
(usually `system`, `user`, or `assistant`, but `tool` is also possible).
|
|
62
|
-
Normally there is also a `content` field, which is a string.
|
|
63
54
|
model
|
|
64
55
|
The model to use for the chat. The default, None, will pick a reasonable
|
|
65
56
|
default, and warn you about it. We strongly recommend explicitly choosing
|
|
@@ -123,15 +114,21 @@ def ChatGroq(
|
|
|
123
114
|
"""
|
|
124
115
|
if model is None:
|
|
125
116
|
model = log_model_default("llama3-8b-8192")
|
|
117
|
+
|
|
126
118
|
if api_key is None:
|
|
127
119
|
api_key = os.getenv("GROQ_API_KEY")
|
|
128
120
|
|
|
129
|
-
|
|
121
|
+
if isinstance(seed, MISSING_TYPE):
|
|
122
|
+
seed = 1014 if is_testing() else None
|
|
123
|
+
|
|
124
|
+
return Chat(
|
|
125
|
+
provider=OpenAIProvider(
|
|
126
|
+
api_key=api_key,
|
|
127
|
+
model=model,
|
|
128
|
+
base_url=base_url,
|
|
129
|
+
seed=seed,
|
|
130
|
+
name="Groq",
|
|
131
|
+
kwargs=kwargs,
|
|
132
|
+
),
|
|
130
133
|
system_prompt=system_prompt,
|
|
131
|
-
turns=turns,
|
|
132
|
-
model=model,
|
|
133
|
-
api_key=api_key,
|
|
134
|
-
base_url=base_url,
|
|
135
|
-
seed=seed,
|
|
136
|
-
kwargs=kwargs,
|
|
137
134
|
)
|