chatlas 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
This release of chatlas has been flagged as potentially problematic.
- chatlas/__init__.py +11 -1
- chatlas/_anthropic.py +8 -10
- chatlas/_auto.py +183 -0
- chatlas/_chat.py +50 -19
- chatlas/_content.py +23 -7
- chatlas/_display.py +12 -2
- chatlas/_github.py +1 -1
- chatlas/_google.py +263 -166
- chatlas/_groq.py +1 -1
- chatlas/_live_render.py +116 -0
- chatlas/_merge.py +1 -1
- chatlas/_ollama.py +1 -1
- chatlas/_openai.py +4 -6
- chatlas/_perplexity.py +1 -1
- chatlas/_provider.py +0 -9
- chatlas/_snowflake.py +321 -0
- chatlas/_utils.py +7 -0
- chatlas/_version.py +21 -0
- chatlas/py.typed +0 -0
- chatlas/types/__init__.py +5 -1
- chatlas/types/anthropic/_submit.py +24 -2
- chatlas/types/google/_client.py +12 -91
- chatlas/types/google/_submit.py +40 -87
- chatlas/types/openai/_submit.py +9 -2
- chatlas/types/snowflake/__init__.py +8 -0
- chatlas/types/snowflake/_submit.py +24 -0
- {chatlas-0.3.0.dist-info → chatlas-0.5.0.dist-info}/METADATA +35 -7
- chatlas-0.5.0.dist-info/RECORD +44 -0
- chatlas-0.3.0.dist-info/RECORD +0 -37
- {chatlas-0.3.0.dist-info → chatlas-0.5.0.dist-info}/WHEEL +0 -0
chatlas/_google.py
CHANGED
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
+import base64
 import json
-from typing import TYPE_CHECKING, Any, Literal, Optional, overload
+from typing import TYPE_CHECKING, Any, Literal, Optional, cast, overload
 
 from pydantic import BaseModel
 
@@ -16,21 +17,19 @@ from ._content import (
     ContentToolResult,
 )
 from ._logging import log_model_default
+from ._merge import merge_dicts
 from ._provider import Provider
 from ._tokens import tokens_log
-from ._tools import Tool
+from ._tools import Tool
 from ._turn import Turn, normalize_turns, user_turn
 
 if TYPE_CHECKING:
-    from google.generativeai.types.content_types import (
-        ContentDict,
-        FunctionDeclaration,
-        PartType,
-    )
-    from google.generativeai.types.generation_types import (
-        AsyncGenerateContentResponse,
+    from google.genai.types import Content as GoogleContent
+    from google.genai.types import (
         GenerateContentResponse,
-
+        GenerateContentResponseDict,
+        Part,
+        PartDict,
     )
 
 from .types.google import ChatClientArgs, SubmitInputArgs
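The import changes above capture the headline change in this release: chatlas drops the legacy `google-generativeai` SDK in favor of the newer `google-genai` SDK. For orientation, a minimal sketch of the two styles of direct SDK usage (the API keys and model names here are placeholders):

```python
# Legacy SDK, used by chatlas 0.3.0: module-level configuration
import google.generativeai as legacy_genai

legacy_genai.configure(api_key="...")
model = legacy_genai.GenerativeModel("gemini-1.5-flash")
print(model.generate_content("Hello").text)

# New SDK, used by chatlas 0.5.0: an explicit client object
from google import genai

client = genai.Client(api_key="...")
response = client.models.generate_content(model="gemini-2.0-flash", contents="Hello")
print(response.text)
```

Much of the rest of this file's diff is the mechanical fallout of that swap: dict-shaped payloads become typed `Content`/`Part` objects, and per-request options move into a `GenerateContentConfig`.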
@@ -62,8 +61,7 @@ def ChatGoogle(
     ::: {.callout-note}
     ## Python requirements
 
-    `ChatGoogle` requires the `google-generativeai` package
-    (e.g., `pip install google-generativeai`).
+    `ChatGoogle` requires the `google-genai` package: `pip install "chatlas[google]"`.
     :::
 
     Examples
@@ -96,17 +94,13 @@ def ChatGoogle(
         The API key to use for authentication. You generally should not supply
         this directly, but instead set the `GOOGLE_API_KEY` environment variable.
     kwargs
-        Additional arguments to pass to the `genai.GenerativeModel` constructor.
+        Additional arguments to pass to the `genai.Client` constructor.
 
     Returns
     -------
     Chat
         A Chat object.
 
-    Limitations
-    -----------
-    `ChatGoogle` currently doesn't work with streaming tools.
-
     Note
     ----
     Pasting an API key into a chat constructor (e.g., `ChatGoogle(api_key="...")`)
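With this change, `kwargs` feeds `genai.Client` rather than the old `genai.GenerativeModel`, so client-level options travel through it. A hedged sketch; the `http_options` value is an illustrative assumption about client configuration, not something this diff exercises:

```python
from chatlas import ChatGoogle

# kwargs are forwarded to the google.genai.Client constructor
chat = ChatGoogle(
    model="gemini-2.0-flash",
    kwargs={"http_options": {"timeout": 60_000}},  # request timeout, in milliseconds
)
```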
@@ -145,63 +139,49 @@ def ChatGoogle(
     """
 
     if model is None:
-        model = log_model_default("gemini-1.5-flash")
-
-    turns = normalize_turns(
-        turns or [],
-        system_prompt=system_prompt,
-    )
+        model = log_model_default("gemini-2.0-flash")
 
     return Chat(
         provider=GoogleProvider(
-            turns=turns,
             model=model,
             api_key=api_key,
             kwargs=kwargs,
         ),
-        turns=turns,
+        turns=normalize_turns(
+            turns or [],
+            system_prompt=system_prompt,
+        ),
     )
 
 
-# The dictionary form of ChatCompletion (TODO: stronger typing)?
-GenerateContentDict = dict[str, Any]
-
-
 class GoogleProvider(
-    Provider[GenerateContentResponse, AsyncGenerateContentResponse, GenerateContentDict]
+    Provider[
+        GenerateContentResponse, GenerateContentResponse, "GenerateContentResponseDict"
+    ]
 ):
     def __init__(
         self,
         *,
-        turns: list[Turn],
         model: str,
         api_key: str | None,
         kwargs: Optional["ChatClientArgs"],
     ):
         try:
-            from google.generativeai import GenerativeModel
+            from google import genai
         except ImportError:
             raise ImportError(
-                f"The {self.__class__.__name__} class requires the `google-generativeai` package. "
-                "Install it with `pip install google-generativeai`."
+                f"The {self.__class__.__name__} class requires the `google-genai` package. "
+                "Install it with `pip install google-genai`."
             )
 
-
-        import google.generativeai as genai
-
-        genai.configure(api_key=api_key)
-
-        system_prompt = None
-        if len(turns) > 0 and turns[0].role == "system":
-            system_prompt = turns[0].text
+        self._model = model
 
         kwargs_full: "ChatClientArgs" = {
-            "model_name": model,
-            "system_instruction": system_prompt,
+            "api_key": api_key,
             **(kwargs or {}),
         }
 
-        self._client = GenerativeModel(**kwargs_full)
+        self._client = genai.Client(**kwargs_full)
 
     @overload
     def chat_perform(
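Two structural points in the hunk above: `GoogleProvider` no longer receives `turns`, and the system prompt is no longer baked into the client at construction time (it now rides along with each request's config, as a later hunk shows). The public surface is unchanged; a minimal usage sketch, assuming `GOOGLE_API_KEY` is set in the environment:

```python
from chatlas import ChatGoogle

chat = ChatGoogle(
    model="gemini-2.0-flash",
    system_prompt="You are a terse assistant.",
)
chat.chat("What is the capital of France?")
```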
@@ -233,8 +213,11 @@ class GoogleProvider(
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ):
-        kwargs = self._chat_perform_args(stream, turns, tools, data_model, kwargs)
-
+        kwargs = self._chat_perform_args(turns, tools, data_model, kwargs)
+        if stream:
+            return self._client.models.generate_content_stream(**kwargs)
+        else:
+            return self._client.models.generate_content(**kwargs)
 
     @overload
     async def chat_perform_async(
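The new branch exists because the legacy SDK toggled streaming with a flag on one method, while `google-genai` exposes a separate streaming method. Against the raw client, the difference looks roughly like this (sketch; key and model are placeholders):

```python
from google import genai

client = genai.Client(api_key="...")

# Non-streaming: a single response object
response = client.models.generate_content(model="gemini-2.0-flash", contents="Hi")

# Streaming: an iterator of partial responses
for chunk in client.models.generate_content_stream(
    model="gemini-2.0-flash", contents="Tell me a story"
):
    print(chunk.text or "", end="")
```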
@@ -266,71 +249,82 @@ class GoogleProvider(
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ):
-        kwargs = self._chat_perform_args(stream, turns, tools, data_model, kwargs)
-
+        kwargs = self._chat_perform_args(turns, tools, data_model, kwargs)
+        if stream:
+            return await self._client.aio.models.generate_content_stream(**kwargs)
+        else:
+            return await self._client.aio.models.generate_content(**kwargs)
 
     def _chat_perform_args(
         self,
-        stream: bool,
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ) -> "SubmitInputArgs":
+        from google.genai.types import FunctionDeclaration, GenerateContentConfig
+        from google.genai.types import Tool as GoogleTool
+
         kwargs_full: "SubmitInputArgs" = {
-            "
-            "
-            "tools": self._gemini_tools(list(tools.values())) if tools else None,
+            "model": self._model,
+            "contents": cast("GoogleContent", self._google_contents(turns)),
             **(kwargs or {}),
         }
 
-
-
-
+        config = kwargs_full.get("config")
+        if config is None:
+            config = GenerateContentConfig()
+        if isinstance(config, dict):
+            config = GenerateContentConfig.model_construct(**config)
 
-
-
+        if config.system_instruction is None:
+            if len(turns) > 0 and turns[0].role == "system":
+                config.system_instruction = turns[0].text
 
-
-
-
-
-
-
-
+        if data_model:
+            config.response_schema = data_model
+            config.response_mime_type = "application/json"
+
+        if tools:
+            config.tools = [
+                GoogleTool(
+                    function_declarations=[
+                        FunctionDeclaration.from_callable(
+                            client=self._client, callable=tool.func
+                        )
+                        for tool in tools.values()
+                    ]
+                )
+            ]
 
-
+        kwargs_full["config"] = config
 
         return kwargs_full
 
     def stream_text(self, chunk) -> Optional[str]:
-
+        try:
+            # Errors if there is no text (e.g., tool request)
             return chunk.text
-
+        except Exception:
+            return None
 
     def stream_merge_chunks(self, completion, chunk):
-
-
-
-
-
-
-        stream.resolve()
-        return self._as_turn(
-            stream,
-            has_data_model,
+        chunkd = chunk.model_dump()
+        if completion is None:
+            return cast("GenerateContentResponseDict", chunkd)
+        return cast(
+            "GenerateContentResponseDict",
+            merge_dicts(completion, chunkd),  # type: ignore
         )
 
-
-        self, completion, has_data_model, stream: AsyncGenerateContentResponse
-    ) -> Turn:
-        await stream.resolve()
+    def stream_turn(self, completion, has_data_model) -> Turn:
         return self._as_turn(
-            stream,
+            completion,
             has_data_model,
         )
 
     def value_turn(self, completion, has_data_model) -> Turn:
+        completion = cast("GenerateContentResponseDict", completion.model_dump())
         return self._as_turn(completion, has_data_model)
 
     def token_count(
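Per-request options now live on a single `GenerateContentConfig`: the system instruction, structured-output settings (`response_schema` plus `response_mime_type`), and tool declarations all hang off it, and `FunctionDeclaration.from_callable` derives a declaration from a plain Python function's signature and docstring. A sketch of the kind of config `_chat_perform_args` assembles; the `get_weather` tool is hypothetical:

```python
from google import genai
from google.genai import types

client = genai.Client(api_key="...")


def get_weather(city: str) -> str:
    """Return a short weather report for a city."""
    return f"Sunny in {city}"


config = types.GenerateContentConfig(
    system_instruction="You are a terse assistant.",
    tools=[
        types.Tool(
            function_declarations=[
                types.FunctionDeclaration.from_callable(
                    client=client, callable=get_weather
                )
            ]
        )
    ],
)
response = client.models.generate_content(
    model="gemini-2.0-flash", contents="Weather in Paris?", config=config
)
```

The streaming methods follow the same dict-first approach: rather than resolving a stream object as the legacy code did, each chunk is `model_dump()`ed and deep-merged into the running completion with chatlas's internal `merge_dicts`, so `stream_turn` and `value_turn` can share one dict-based `_as_turn` path.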
@@ -345,8 +339,8 @@ class GoogleProvider(
             data_model=data_model,
         )
 
-        res = self._client.count_tokens(**kwargs)
-        return res.total_tokens
+        res = self._client.models.count_tokens(**kwargs)
+        return res.total_tokens or 0
 
     async def token_count_async(
         self,
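Token counting moves under the client's `models` namespace, and `total_tokens` is optional in the new SDK's response type, hence the `or 0` guard. A direct-SDK sketch:

```python
from google import genai

client = genai.Client(api_key="...")
response = client.models.count_tokens(
    model="gemini-2.0-flash", contents="How many tokens am I?"
)
print(response.total_tokens or 0)
```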
@@ -360,8 +354,8 @@ class GoogleProvider(
             data_model=data_model,
         )
 
-        res = await self._client.count_tokens_async(**kwargs)
-        return res.total_tokens
+        res = await self._client.aio.models.count_tokens(**kwargs)
+        return res.total_tokens or 0
 
     def _token_count_args(
         self,
@@ -372,44 +366,43 @@ class GoogleProvider(
         turn = user_turn(*args)
 
         kwargs = self._chat_perform_args(
-            stream=False,
             turns=[turn],
             tools=tools,
             data_model=data_model,
         )
 
-        args_to_keep = ["contents", "tools"]
+        args_to_keep = ["model", "contents", "tools"]
 
         return {arg: kwargs[arg] for arg in args_to_keep if arg in kwargs}
 
-    def _google_contents(self, turns: list[Turn]) -> list["ContentDict"]:
-        contents: list["ContentDict"] = []
+    def _google_contents(self, turns: list[Turn]) -> list["GoogleContent"]:
+        from google.genai.types import Content as GoogleContent
+
+        contents: list["GoogleContent"] = []
         for turn in turns:
             if turn.role == "system":
                 continue  # System messages are handled separately
             elif turn.role == "user":
                 parts = [self._as_part_type(c) for c in turn.contents]
-                contents.append({"role": turn.role, "parts": parts})
+                contents.append(GoogleContent(role=turn.role, parts=parts))
             elif turn.role == "assistant":
                 parts = [self._as_part_type(c) for c in turn.contents]
-                contents.append({"role": "model", "parts": parts})
+                contents.append(GoogleContent(role="model", parts=parts))
             else:
                 raise ValueError(f"Unknown role {turn.role}")
         return contents
 
-    def _as_part_type(self, content: Content) -> "PartType":
-        from google.
+    def _as_part_type(self, content: Content) -> "Part":
+        from google.genai.types import FunctionCall, FunctionResponse, Part
 
         if isinstance(content, ContentText):
-            return
+            return Part.from_text(text=content.text)
         elif isinstance(content, ContentJson):
-            return
-        elif isinstance(content, ContentImageInline):
-            return
-
-
-                "data": content.data,
-            }
+            return Part.from_text(text="<structured data/>")
+        elif isinstance(content, ContentImageInline) and content.data:
+            return Part.from_bytes(
+                data=base64.b64decode(content.data),
+                mime_type=content.content_type,
             )
         elif isinstance(content, ContentImageRemote):
             raise NotImplementedError(
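Message payloads likewise switch from the legacy dict shapes (`ContentDict`) to typed `Content` and `Part` objects, with classmethod helpers for the common cases. A sketch of the shapes `_google_contents` and `_as_part_type` now produce; `image_bytes` stands in for your own data:

```python
from google.genai import types

image_bytes = b"..."  # raw image bytes (note the diff base64-decodes stored data first)

content = types.Content(
    role="user",
    parts=[
        types.Part.from_text(text="Describe this image."),
        types.Part.from_bytes(data=image_bytes, mime_type="image/png"),
    ],
)
```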
@@ -417,92 +410,196 @@ class GoogleProvider(
                 "Consider downloading the image and using content_image_file() instead."
             )
         elif isinstance(content, ContentToolRequest):
-            return
-                function_call=
-
-
-
+            return Part(
+                function_call=FunctionCall(
+                    id=content.id if content.name != content.id else None,
+                    name=content.name,
+                    # Goes in a dict, so should come out as a dict
+                    args=cast(dict[str, Any], content.arguments),
+                )
             )
         elif isinstance(content, ContentToolResult):
-
-
-
-
-
+            if content.error:
+                resp = {"error": content.error}
+            else:
+                resp = {"result": str(content.value)}
+            return Part(
+                # TODO: seems function response parts might need role='tool'???
+                # https://github.com/googleapis/python-genai/blame/c8cfef85c/README.md#L344
+                function_response=FunctionResponse(
+                    id=content.id if content.name != content.id else None,
+                    name=content.name,
+                    response=resp,
+                )
             )
         raise ValueError(f"Unknown content type: {type(content)}")
 
     def _as_turn(
         self,
-        message: "
+        message: "GenerateContentResponseDict",
         has_data_model: bool,
     ) -> Turn:
-
-
-
-
-
-
+        from google.genai.types import FinishReason
+
+        candidates = message.get("candidates")
+        if not candidates:
+            return Turn("assistant", "")
+
+        parts: list["PartDict"] = []
+        finish_reason = None
+        for candidate in candidates:
+            content = candidate.get("content")
+            if content:
+                parts.extend(content.get("parts") or {})
+            finish = candidate.get("finish_reason")
+            if finish:
+                finish_reason = finish
+
+        contents: list[Content] = []
+        for part in parts:
+            text = part.get("text")
+            if text:
                 if has_data_model:
-                    contents.append(ContentJson(json.loads(
+                    contents.append(ContentJson(json.loads(text)))
                 else:
-                    contents.append(ContentText(
-
-
-
-
-
-
+                    contents.append(ContentText(text))
+            function_call = part.get("function_call")
+            if function_call:
+                # Seems name is required but id is optional?
+                name = function_call.get("name")
+                if name:
+                    contents.append(
+                        ContentToolRequest(
+                            id=function_call.get("id") or name,
+                            name=name,
+                            arguments=function_call.get("args"),
+                        )
                    )
-
-            if
-
-
-
-
+            function_response = part.get("function_response")
+            if function_response:
+                # Seems name is required but id is optional?
+                name = function_response.get("name")
+                if name:
+                    contents.append(
+                        ContentToolResult(
+                            id=function_response.get("id") or name,
+                            value=function_response.get("response"),
+                            name=name,
+                        )
                    )
-            )
 
-        usage = message.usage_metadata
-        tokens = (
-            usage.prompt_token_count,
-            usage.candidates_token_count,
-        )
+        usage = message.get("usage_metadata")
+        tokens = (0, 0)
+        if usage:
+            tokens = (
+                usage.get("prompt_token_count") or 0,
+                usage.get("candidates_token_count") or 0,
+            )
 
         tokens_log(self, tokens)
 
-
+        if isinstance(finish_reason, FinishReason):
+            finish_reason = finish_reason.name
 
         return Turn(
             "assistant",
             contents,
             tokens=tokens,
-            finish_reason=
+            finish_reason=finish_reason,
             completion=message,
         )
 
-    def _gemini_tools(self, tools: list[Tool]) -> list["FunctionDeclaration"]:
-        from google.generativeai.types.content_types import FunctionDeclaration
-
-        res: list["FunctionDeclaration"] = []
-        for tool in tools:
-            fn = tool.schema["function"]
-            params = None
-            if "parameters" in fn and fn["parameters"]["properties"]:
-                params = {
-                    "type": "object",
-                    "properties": fn["parameters"]["properties"],
-                    "required": fn["parameters"]["required"],
-                }
-
-            res.append(
-                FunctionDeclaration(
-                    name=fn["name"],
-                    description=fn.get("description", ""),
-                    parameters=params,
-                )
-            )
 
-        return res
+def ChatVertex(
+    *,
+    model: Optional[str] = None,
+    project: Optional[str] = None,
+    location: Optional[str] = None,
+    api_key: Optional[str] = None,
+    system_prompt: Optional[str] = None,
+    turns: Optional[list[Turn]] = None,
+    kwargs: Optional["ChatClientArgs"] = None,
+) -> Chat["SubmitInputArgs", GenerateContentResponse]:
+    """
+    Chat with a Google Vertex AI model.
+
+    Prerequisites
+    -------------
+
+    ::: {.callout-note}
+    ## Python requirements
+
+    `ChatGoogle` requires the `google-genai` package: `pip install "chatlas[vertex]"`.
+    :::
+
+    ::: {.callout-note}
+    ## Credentials
+
+    To use Google's models (i.e., Vertex AI), you'll need to sign up for an account
+    with [Vertex AI](https://cloud.google.com/vertex-ai), then specify the appropriate
+    model, project, and location.
+    :::
+
+    Parameters
+    ----------
+    model
+        The model to use for the chat. The default, None, will pick a reasonable
+        default, and warn you about it. We strongly recommend explicitly choosing
+        a model for all but the most casual use.
+    project
+        The Google Cloud project ID (e.g., "your-project-id"). If not provided, the
+        GOOGLE_CLOUD_PROJECT environment variable will be used.
+    location
+        The Google Cloud location (e.g., "us-central1"). If not provided, the
+        GOOGLE_CLOUD_LOCATION environment variable will be used.
+    system_prompt
+        A system prompt to set the behavior of the assistant.
+    turns
+        A list of turns to start the chat with (i.e., continuing a previous
+        conversation). If not provided, the conversation begins from scratch.
+        Do not provide non-`None` values for both `turns` and `system_prompt`.
+        Each message in the list should be a dictionary with at least `role`
+        (usually `system`, `user`, or `assistant`, but `tool` is also possible).
+        Normally there is also a `content` field, which is a string.
+
+    Returns
+    -------
+    Chat
+        A Chat object.
+
+    Examples
+    --------
+
+    ```python
+    import os
+    from chatlas import ChatVertex
+
+    chat = ChatVertex(
+        project="your-project-id",
+        location="us-central1",
+    )
+    chat.chat("What is the capital of France?")
+    ```
+    """
+
+    if kwargs is None:
+        kwargs = {}
+
+    kwargs["vertexai"] = True
+    kwargs["project"] = project
+    kwargs["location"] = location
+
+    if model is None:
+        model = log_model_default("gemini-2.0-flash")
+
+    return Chat(
+        provider=GoogleProvider(
+            model=model,
+            api_key=api_key,
+            kwargs=kwargs,
+        ),
+        turns=normalize_turns(
+            turns or [],
+            system_prompt=system_prompt,
+        ),
+    )
chatlas/_groq.py
CHANGED