chatlas 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- chatlas/__init__.py +2 -1
- chatlas/_anthropic.py +3 -5
- chatlas/_chat.py +48 -19
- chatlas/_content.py +20 -7
- chatlas/_google.py +265 -166
- chatlas/_merge.py +1 -1
- chatlas/_openai.py +1 -4
- chatlas/_provider.py +0 -9
- chatlas/py.typed +0 -0
- chatlas/types/__init__.py +5 -1
- chatlas/types/anthropic/_submit.py +1 -1
- chatlas/types/google/_client.py +12 -91
- chatlas/types/google/_submit.py +40 -87
- chatlas/types/openai/_submit.py +3 -1
- {chatlas-0.3.0.dist-info → chatlas-0.4.0.dist-info}/METADATA +10 -7
- {chatlas-0.3.0.dist-info → chatlas-0.4.0.dist-info}/RECORD +17 -16
- {chatlas-0.3.0.dist-info → chatlas-0.4.0.dist-info}/WHEEL +0 -0
chatlas/_google.py
CHANGED
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
+import base64
 import json
-from typing import TYPE_CHECKING, Any, Literal, Optional, overload
+from typing import TYPE_CHECKING, Any, Literal, Optional, cast, overload
 
 from pydantic import BaseModel
 
@@ -16,21 +17,19 @@ from ._content import (
     ContentToolResult,
 )
 from ._logging import log_model_default
+from ._merge import merge_dicts
 from ._provider import Provider
 from ._tokens import tokens_log
-from ._tools import Tool
+from ._tools import Tool
 from ._turn import Turn, normalize_turns, user_turn
 
 if TYPE_CHECKING:
-    from google.generativeai.types.content_types import (
-        ContentDict,
-        FunctionDeclaration,
-        PartType,
-    )
-    from google.generativeai.types.generation_types import (
-        AsyncGenerateContentResponse,
+    from google.genai.types import Content as GoogleContent
+    from google.genai.types import (
         GenerateContentResponse,
-        GenerateContentResponseDict,
+        GenerateContentResponseDict,
+        Part,
+        PartDict,
     )
 
 from .types.google import ChatClientArgs, SubmitInputArgs
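The core change in 0.4.0: the Google provider migrates from the legacy `google-generativeai` SDK to Google's unified `google-genai` SDK. A minimal sketch of the new import surface (the old namespace is shown in a comment for contrast; both package names are real, everything else here is illustrative):

```python
# chatlas 0.4.0: one client object, request/response types under google.genai.types
from google import genai
from google.genai.types import Content, Part

# chatlas 0.3.0, for contrast:
# import google.generativeai as genai
```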
@@ -62,8 +61,8 @@ def ChatGoogle(
     ::: {.callout-note}
     ## Python requirements
 
-    `ChatGoogle` requires the `google-generativeai` package
-    (e.g., `pip install google-generativeai`).
+    `ChatGoogle` requires the `google-genai` package
+    (e.g., `pip install google-genai`).
     :::
 
     Examples
@@ -96,17 +95,13 @@ def ChatGoogle(
         The API key to use for authentication. You generally should not supply
         this directly, but instead set the `GOOGLE_API_KEY` environment variable.
     kwargs
-        Additional arguments to pass to the `genai.GenerativeModel` constructor.
+        Additional arguments to pass to the `genai.Client` constructor.
 
     Returns
     -------
     Chat
         A Chat object.
 
-    Limitations
-    -----------
-    `ChatGoogle` currently doesn't work with streaming tools.
-
     Note
     ----
     Pasting an API key into a chat constructor (e.g., `ChatGoogle(api_key="...")`)
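Since `kwargs` now lands on `genai.Client` rather than the old `genai.GenerativeModel`, client-level settings pass straight through `ChatGoogle`. A hedged sketch: `http_options` is a real `genai.Client` option, but the timeout value and its use here are purely illustrative:

```python
from chatlas import ChatGoogle

# kwargs are forwarded verbatim to genai.Client(...).
chat = ChatGoogle(
    model="gemini-2.0-flash",
    kwargs={"http_options": {"timeout": 60_000}},  # milliseconds (illustrative)
)
chat.chat("Hello!")
```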
@@ -145,63 +140,49 @@ def ChatGoogle(
     """
 
     if model is None:
-        model = log_model_default("gemini-1.5-flash")
-
-    turns = normalize_turns(
-        turns or [],
-        system_prompt=system_prompt,
-    )
+        model = log_model_default("gemini-2.0-flash")
 
     return Chat(
         provider=GoogleProvider(
-            turns=turns,
             model=model,
             api_key=api_key,
             kwargs=kwargs,
         ),
-        turns=turns,
+        turns=normalize_turns(
+            turns or [],
+            system_prompt=system_prompt,
+        ),
     )
 
 
-# The dictionary form of ChatCompletion (TODO: stronger typing)?
-GenerateContentDict = dict[str, Any]
-
-
 class GoogleProvider(
-    Provider[GenerateContentResponse, AsyncGenerateContentResponse, GenerateContentDict]
+    Provider[
+        GenerateContentResponse, GenerateContentResponse, "GenerateContentResponseDict"
+    ]
 ):
     def __init__(
         self,
         *,
-        turns: list[Turn],
         model: str,
         api_key: str | None,
         kwargs: Optional["ChatClientArgs"],
     ):
         try:
-            from google.generativeai import GenerativeModel
+            from google import genai
         except ImportError:
             raise ImportError(
-                f"The {self.__class__.__name__} class requires the `google-generativeai` package. "
-                "Install it with `pip install google-generativeai`."
+                f"The {self.__class__.__name__} class requires the `google-genai` package. "
+                "Install it with `pip install google-genai`."
             )
 
-
-        import google.generativeai as genai
-
-        genai.configure(api_key=api_key)
-
-        system_prompt = None
-        if len(turns) > 0 and turns[0].role == "system":
-            system_prompt = turns[0].text
+        self._model = model
 
         kwargs_full: "ChatClientArgs" = {
-            "model_name": model,
-            "system_instruction": system_prompt,
+            "api_key": api_key,
             **(kwargs or {}),
         }
 
-        self._client = GenerativeModel(**kwargs_full)
+        self._client = genai.Client(**kwargs_full)
 
     @overload
     def chat_perform(
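For reference, the raw `google-genai` setup that `__init__` now performs: a single `genai.Client` replaces the old `genai.configure()` plus `GenerativeModel()` pair. A minimal sketch, assuming `pip install google-genai` and a valid API key:

```python
from google import genai

client = genai.Client(api_key="your-api-key")  # or rely on GOOGLE_API_KEY
response = client.models.generate_content(
    model="gemini-2.0-flash",  # the new chatlas default
    contents="Why is the sky blue?",
)
print(response.text)
```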
@@ -233,8 +214,11 @@ class GoogleProvider(
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ):
-        kwargs = self._chat_perform_args(stream, turns, tools, data_model, kwargs)
-        return self._client.generate_content(**kwargs)
+        kwargs = self._chat_perform_args(turns, tools, data_model, kwargs)
+        if stream:
+            return self._client.models.generate_content_stream(**kwargs)
+        else:
+            return self._client.models.generate_content(**kwargs)
 
     @overload
     async def chat_perform_async(
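`chat_perform` now chooses between the SDK's two entry points rather than passing a `stream=` flag. Roughly what the streaming path looks like from user code, as a self-contained sketch (prompt and key are placeholders):

```python
from google import genai

client = genai.Client(api_key="your-api-key")

# generate_content_stream() returns an iterator of partial
# GenerateContentResponse chunks instead of a single final response.
for chunk in client.models.generate_content_stream(
    model="gemini-2.0-flash",
    contents="Tell me a short story.",
):
    # Text can be absent on a chunk (e.g., a tool request); the provider's
    # stream_text() below guards exactly that case.
    if chunk.text:
        print(chunk.text, end="")
```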
@@ -266,71 +250,82 @@ class GoogleProvider(
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ):
-        kwargs = self._chat_perform_args(stream, turns, tools, data_model, kwargs)
-        return await self._client.generate_content_async(**kwargs)
+        kwargs = self._chat_perform_args(turns, tools, data_model, kwargs)
+        if stream:
+            return await self._client.aio.models.generate_content_stream(**kwargs)
+        else:
+            return await self._client.aio.models.generate_content(**kwargs)
 
     def _chat_perform_args(
         self,
-        stream: bool,
         turns: list[Turn],
         tools: dict[str, Tool],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ) -> "SubmitInputArgs":
+        from google.genai.types import FunctionDeclaration, GenerateContentConfig
+        from google.genai.types import Tool as GoogleTool
+
         kwargs_full: "SubmitInputArgs" = {
-            "contents": self._google_contents(turns),
-            "stream": stream,
-            "tools": self._gemini_tools(list(tools.values())) if tools else None,
+            "model": self._model,
+            "contents": cast("GoogleContent", self._google_contents(turns)),
             **(kwargs or {}),
         }
 
-
-
-
+        config = kwargs_full.get("config")
+        if config is None:
+            config = GenerateContentConfig()
+        if isinstance(config, dict):
+            config = GenerateContentConfig.model_construct(**config)
 
-
-
+        if config.system_instruction is None:
+            if len(turns) > 0 and turns[0].role == "system":
+                config.system_instruction = turns[0].text
 
-
-
-
-
-
-
-
+        if data_model:
+            config.response_schema = data_model
+            config.response_mime_type = "application/json"
+
+        if tools:
+            config.tools = [
+                GoogleTool(
+                    function_declarations=[
+                        FunctionDeclaration.from_callable(
+                            client=self._client, callable=tool.func
+                        )
+                        for tool in tools.values()
+                    ]
+                )
+            ]
 
-
+        kwargs_full["config"] = config
 
         return kwargs_full
 
     def stream_text(self, chunk) -> Optional[str]:
-
+        try:
+            # Errors if there is no text (e.g., tool request)
             return chunk.text
-
+        except Exception:
+            return None
 
     def stream_merge_chunks(self, completion, chunk):
-
-
-
-
-
-
-        stream.resolve()
-        return self._as_turn(
-            stream,
-            has_data_model,
+        chunkd = chunk.model_dump()
+        if completion is None:
+            return cast("GenerateContentResponseDict", chunkd)
+        return cast(
+            "GenerateContentResponseDict",
+            merge_dicts(completion, chunkd),  # type: ignore
         )
 
-    async def stream_turn_async(
-        self, completion, has_data_model, stream: AsyncGenerateContentResponse
-    ) -> Turn:
-        await stream.resolve()
+    def stream_turn(self, completion, has_data_model) -> Turn:
         return self._as_turn(
-            stream,
+            completion,
             has_data_model,
         )
 
     def value_turn(self, completion, has_data_model) -> Turn:
+        completion = cast("GenerateContentResponseDict", completion.model_dump())
         return self._as_turn(completion, has_data_model)
 
     def token_count(
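Request options now travel on a single `GenerateContentConfig` (system instruction, structured-output schema, tools) instead of separate constructor and call arguments. A self-contained sketch of the same pattern; `Answer` is an illustrative schema, not part of chatlas:

```python
from google import genai
from google.genai.types import GenerateContentConfig
from pydantic import BaseModel


class Answer(BaseModel):
    capital: str


client = genai.Client(api_key="your-api-key")
config = GenerateContentConfig(
    system_instruction="Answer concisely.",
    # Mirrors the data_model branch above: a Pydantic model as the
    # response schema plus a JSON mime type yields structured output.
    response_schema=Answer,
    response_mime_type="application/json",
)
resp = client.models.generate_content(
    model="gemini-2.0-flash",
    contents="What is the capital of France?",
    config=config,
)
print(resp.text)  # JSON conforming to Answer
```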
@@ -345,8 +340,8 @@ class GoogleProvider(
             data_model=data_model,
         )
 
-        res = self._client.count_tokens(**kwargs)
-        return res.total_tokens
+        res = self._client.models.count_tokens(**kwargs)
+        return res.total_tokens or 0
 
     async def token_count_async(
         self,
@@ -360,8 +355,8 @@ class GoogleProvider(
             data_model=data_model,
         )
 
-        res = await self._client.count_tokens_async(**kwargs)
-        return res.total_tokens
+        res = await self._client.aio.models.count_tokens(**kwargs)
+        return res.total_tokens or 0
 
     def _token_count_args(
         self,
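Token counting moves with everything else onto the client (`client.models.count_tokens`, with an async twin under `client.aio`); the new `or 0` guards against the SDK reporting `total_tokens` as `None`. A quick sketch (key and prompt are placeholders):

```python
from google import genai

client = genai.Client(api_key="your-api-key")
res = client.models.count_tokens(
    model="gemini-2.0-flash",
    contents="How many tokens am I?",
)
print(res.total_tokens or 0)
```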
@@ -372,44 +367,43 @@ class GoogleProvider(
         turn = user_turn(*args)
 
         kwargs = self._chat_perform_args(
-            stream=False,
             turns=[turn],
             tools=tools,
             data_model=data_model,
         )
 
-        args_to_keep = ["contents", "tools"]
+        args_to_keep = ["model", "contents", "tools"]
 
         return {arg: kwargs[arg] for arg in args_to_keep if arg in kwargs}
 
-    def _google_contents(self, turns: list[Turn]) -> list["ContentDict"]:
-        contents: list["ContentDict"] = []
+    def _google_contents(self, turns: list[Turn]) -> list["GoogleContent"]:
+        from google.genai.types import Content as GoogleContent
+
+        contents: list["GoogleContent"] = []
         for turn in turns:
             if turn.role == "system":
                 continue  # System messages are handled separately
             elif turn.role == "user":
                 parts = [self._as_part_type(c) for c in turn.contents]
-                contents.append({"role": turn.role, "parts": parts})
+                contents.append(GoogleContent(role=turn.role, parts=parts))
             elif turn.role == "assistant":
                 parts = [self._as_part_type(c) for c in turn.contents]
-                contents.append({"role": "model", "parts": parts})
+                contents.append(GoogleContent(role="model", parts=parts))
             else:
                 raise ValueError(f"Unknown role {turn.role}")
         return contents
 
-    def _as_part_type(self, content: Content) -> "PartType":
-        from google.
+    def _as_part_type(self, content: Content) -> "Part":
+        from google.genai.types import FunctionCall, FunctionResponse, Part
 
         if isinstance(content, ContentText):
-            return
+            return Part.from_text(text=content.text)
         elif isinstance(content, ContentJson):
-            return
-        elif isinstance(content, ContentImageInline):
-            return
-
-
-                "data": content.data,
-            }
+            return Part.from_text(text="<structured data/>")
+        elif isinstance(content, ContentImageInline) and content.data:
+            return Part.from_bytes(
+                data=base64.b64decode(content.data),
+                mime_type=content.content_type,
             )
         elif isinstance(content, ContentImageRemote):
             raise NotImplementedError(
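`_as_part_type` now builds real `Part` objects through the SDK's constructors instead of loose dicts. The two constructors it leans on, in isolation (the image bytes below are a placeholder):

```python
from google.genai.types import Part

text_part = Part.from_text(text="Hello!")
# Note that the provider base64-decodes chatlas's stored image data first:
# from_bytes() expects raw bytes plus a mime type.
image_part = Part.from_bytes(data=b"<raw image bytes>", mime_type="image/png")
```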
@@ -417,92 +411,197 @@ class GoogleProvider(
                 "Consider downloading the image and using content_image_file() instead."
             )
         elif isinstance(content, ContentToolRequest):
-            return
-                function_call=
-
-
-
+            return Part(
+                function_call=FunctionCall(
+                    id=content.id,
+                    name=content.name,
+                    # Goes in a dict, so should come out as a dict
+                    args=cast(dict[str, Any], content.arguments),
+                )
             )
         elif isinstance(content, ContentToolResult):
-
-
-
-
-
+            if content.error:
+                resp = {"error": content.error}
+            else:
+                resp = {"result": str(content.value)}
+            return Part(
+                # TODO: seems function response parts might need role='tool'???
+                # https://github.com/googleapis/python-genai/blame/c8cfef85c/README.md#L344
+                function_response=FunctionResponse(
+                    id=content.id,
+                    name=content.name,
+                    response=resp,
+                )
             )
         raise ValueError(f"Unknown content type: {type(content)}")
 
     def _as_turn(
         self,
-        message: "
+        message: "GenerateContentResponseDict",
         has_data_model: bool,
     ) -> Turn:
-
-
-
-
-
-
+        from google.genai.types import FinishReason
+
+        candidates = message.get("candidates")
+        if not candidates:
+            return Turn("assistant", "")
+
+        parts: list["PartDict"] = []
+        finish_reason = None
+        for candidate in candidates:
+            content = candidate.get("content")
+            if content:
+                parts.extend(content.get("parts") or {})
+            finish = candidate.get("finish_reason")
+            if finish:
+                finish_reason = finish
+
+        contents: list[Content] = []
+        for part in parts:
+            text = part.get("text")
+            if text:
                 if has_data_model:
-                    contents.append(ContentJson(json.loads(
+                    contents.append(ContentJson(json.loads(text)))
                 else:
-                    contents.append(ContentText(
-
-
-
-
-
-
-
+                    contents.append(ContentText(text))
+            function_call = part.get("function_call")
+            if function_call:
+                # Seems name is required but id is optional?
+                name = function_call.get("name")
+                if name:
+                    contents.append(
+                        ContentToolRequest(
+                            id=function_call.get("id") or name,
+                            name=name,
+                            arguments=function_call.get("args"),
+                        )
                     )
-
-        if
-
-
-
-
-
+            function_response = part.get("function_response")
+            if function_response:
+                # Seems name is required but id is optional?
+                name = function_response.get("name")
+                if name:
+                    contents.append(
+                        ContentToolResult(
+                            id=function_response.get("id") or name,
+                            value=function_response.get("response"),
+                            name=name,
+                        )
                     )
-            )
 
-        usage = message.usage_metadata
-        tokens = (
-            usage.prompt_token_count,
-            usage.candidates_token_count,
-        )
+        usage = message.get("usage_metadata")
+        tokens = (0, 0)
+        if usage:
+            tokens = (
+                usage.get("prompt_token_count") or 0,
+                usage.get("candidates_token_count") or 0,
+            )
 
         tokens_log(self, tokens)
 
-
+        if isinstance(finish_reason, FinishReason):
+            finish_reason = finish_reason.name
 
         return Turn(
             "assistant",
             contents,
             tokens=tokens,
-            finish_reason=
+            finish_reason=finish_reason,
             completion=message,
         )
 
-    def _gemini_tools(self, tools: list[Tool]) -> list["FunctionDeclaration"]:
-        from google.generativeai.types.content_types import FunctionDeclaration
-
-        res: list["FunctionDeclaration"] = []
-        for tool in tools:
-            fn = tool.schema["function"]
-            params = None
-            if "parameters" in fn and fn["parameters"]["properties"]:
-                params = {
-                    "type": "object",
-                    "properties": fn["parameters"]["properties"],
-                    "required": fn["parameters"]["required"],
-                }
-
-            res.append(
-                FunctionDeclaration(
-                    name=fn["name"],
-                    description=fn.get("description", ""),
-                    parameters=params,
-                )
-            )
 
-
+def ChatVertex(
+    *,
+    model: Optional[str] = None,
+    project: Optional[str] = None,
+    location: Optional[str] = None,
+    api_key: Optional[str] = None,
+    system_prompt: Optional[str] = None,
+    turns: Optional[list[Turn]] = None,
+    kwargs: Optional["ChatClientArgs"] = None,
+) -> Chat["SubmitInputArgs", GenerateContentResponse]:
+    """
+    Chat with a Google Vertex AI model.
+
+    Prerequisites
+    -------------
+
+    ::: {.callout-note}
+    ## Python requirements
+
+    `ChatGoogle` requires the `google-genai` package
+    (e.g., `pip install google-genai`).
+    :::
+
+    ::: {.callout-note}
+    ## Credentials
+
+    To use Google's models (i.e., Vertex AI), you'll need to sign up for an account
+    with [Vertex AI](https://cloud.google.com/vertex-ai), then specify the appropriate
+    model, project, and location.
+    :::
+
+    Parameters
+    ----------
+    model
+        The model to use for the chat. The default, None, will pick a reasonable
+        default, and warn you about it. We strongly recommend explicitly choosing
+        a model for all but the most casual use.
+    project
+        The Google Cloud project ID (e.g., "your-project-id"). If not provided, the
+        GOOGLE_CLOUD_PROJECT environment variable will be used.
+    location
+        The Google Cloud location (e.g., "us-central1"). If not provided, the
+        GOOGLE_CLOUD_LOCATION environment variable will be used.
+    system_prompt
+        A system prompt to set the behavior of the assistant.
+    turns
+        A list of turns to start the chat with (i.e., continuing a previous
+        conversation). If not provided, the conversation begins from scratch.
+        Do not provide non-`None` values for both `turns` and `system_prompt`.
+        Each message in the list should be a dictionary with at least `role`
+        (usually `system`, `user`, or `assistant`, but `tool` is also possible).
+        Normally there is also a `content` field, which is a string.
+
+    Returns
+    -------
+    Chat
+        A Chat object.
+
+    Examples
+    --------
+
+    ```python
+    import os
+    from chatlas import ChatVertex
+
+    chat = ChatVertex(
+        project="your-project-id",
+        location="us-central1",
+    )
+    chat.chat("What is the capital of France?")
+    ```
+    """
+
+    if kwargs is None:
+        kwargs = {}
+
+    kwargs["vertexai"] = True
+    kwargs["project"] = project
+    kwargs["location"] = location
+
+    if model is None:
+        model = log_model_default("gemini-2.0-flash")
+
+    return Chat(
+        provider=GoogleProvider(
+            model=model,
+            api_key=api_key,
+            kwargs=kwargs,
+        ),
+        turns=normalize_turns(
+            turns or [],
+            system_prompt=system_prompt,
+        ),
+    )
chatlas/_merge.py
CHANGED
@@ -1,5 +1,5 @@
 # Adapted from https://github.com/langchain-ai/langchain/blob/master/libs/core/langchain_core/utils/_merge.py
-# Also tweaked to more closely match https://github.com/hadley/elmer/blob/main/R/utils-merge.R
+# Also tweaked to more closely match https://github.com/hadley/ellmer/blob/main/R/utils-merge.R
 
 from __future__ import annotations
 
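`merge_dicts` is what the new `stream_merge_chunks` in `_google.py` uses to fold each `model_dump()`'d chunk into the running completion. A toy illustration with hypothetical values (real chunks are nested much deeper); per the langchain utility it adapts, string leaves concatenate and `None` is overridden:

```python
from chatlas._merge import merge_dicts  # internal helper, not public API

acc = {"text": "Hel", "finish_reason": None}
new = {"text": "lo!", "finish_reason": "STOP"}
print(merge_dicts(acc, new))  # {'text': 'Hello!', 'finish_reason': 'STOP'}
```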
chatlas/_openai.py
CHANGED
@@ -338,7 +338,7 @@ class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletio
             return chunkd
         return merge_dicts(completion, chunkd)
 
-    def stream_turn(self, completion, has_data_model, stream) -> Turn:
+    def stream_turn(self, completion, has_data_model) -> Turn:
         from openai.types.chat import ChatCompletion
 
         delta = completion["choices"][0].pop("delta")  # type: ignore
@@ -346,9 +346,6 @@ class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletio
         completion = ChatCompletion.construct(**completion)
         return self._as_turn(completion, has_data_model)
 
-    async def stream_turn_async(self, completion, has_data_model, stream):
-        return self.stream_turn(completion, has_data_model, stream)
-
     def value_turn(self, completion, has_data_model) -> Turn:
         return self._as_turn(completion, has_data_model)
 
chatlas/_provider.py
CHANGED
@@ -125,15 +125,6 @@ class Provider(
         self,
         completion: ChatCompletionDictT,
         has_data_model: bool,
-        stream: Any,
-    ) -> Turn: ...
-
-    @abstractmethod
-    async def stream_turn_async(
-        self,
-        completion: ChatCompletionDictT,
-        has_data_model: bool,
-        stream: Any,
     ) -> Turn: ...
 
     @abstractmethod
chatlas/py.typed
ADDED
File without changes