chatlas 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chatlas might be problematic. Click here for more details.

chatlas/_openai.py ADDED
@@ -0,0 +1,654 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from typing import TYPE_CHECKING, Any, Literal, Optional, cast, overload
5
+
6
+ from pydantic import BaseModel
7
+
8
+ from ._chat import Chat
9
+ from ._content import (
10
+ Content,
11
+ ContentImageInline,
12
+ ContentImageRemote,
13
+ ContentJson,
14
+ ContentText,
15
+ ContentToolRequest,
16
+ ContentToolResult,
17
+ )
18
+ from ._logging import log_model_default
19
+ from ._merge import merge_dicts
20
+ from ._provider import Provider
21
+ from ._tokens import tokens_log
22
+ from ._tools import Tool, basemodel_to_param_schema
23
+ from ._turn import Turn, normalize_turns
24
+ from ._utils import MISSING, MISSING_TYPE, is_testing
25
+
26
+ if TYPE_CHECKING:
27
+ from openai.types.chat import (
28
+ ChatCompletion,
29
+ ChatCompletionChunk,
30
+ ChatCompletionMessageParam,
31
+ )
32
+ from openai.types.chat.chat_completion_assistant_message_param import (
33
+ ContentArrayOfContentPart,
34
+ )
35
+ from openai.types.chat.chat_completion_content_part_param import (
36
+ ChatCompletionContentPartParam,
37
+ )
38
+ from openai.types.chat_model import ChatModel
39
+
40
+ from .types.openai import ChatAzureClientArgs, ChatClientArgs, SubmitInputArgs
41
+ else:
42
+ ChatCompletion = object
43
+ ChatCompletionChunk = object
44
+
45
+
46
# The dictionary form of a ChatCompletion. While streaming, each chunk's
# ``model_dump()`` dict is merged into one of these (see
# ``OpenAIProvider.stream_merge_chunks``), and ``stream_turn`` later rebuilds a
# ``ChatCompletion`` from it.
# TODO: stronger typing (e.g., a TypedDict)?
ChatCompletionDict = dict[str, Any]
48
+
49
+
50
def ChatOpenAI(
    *,
    system_prompt: Optional[str] = None,
    turns: Optional[list[Turn]] = None,
    model: "Optional[ChatModel | str]" = None,
    api_key: Optional[str] = None,
    base_url: str = "https://api.openai.com/v1",
    seed: int | None | MISSING_TYPE = MISSING,
    kwargs: Optional["ChatClientArgs"] = None,
) -> Chat["SubmitInputArgs", ChatCompletion]:
    """
    Chat with an OpenAI model.

    [OpenAI](https://openai.com/) offers a family of chat models under the
    [ChatGPT](https://chatgpt.com) name.

    Prerequisites
    -------------

    ::: {.callout-note}
    ## API key

    A ChatGPT Plus membership does not include API access. To call models
    through the API, sign up for (and pay for) a developer account on the
    [developer platform](https://platform.openai.com). That account issues
    the API key this package uses.
    :::

    ::: {.callout-note}
    ## Python requirements

    `ChatOpenAI` requires the `openai` package (e.g., `pip install openai`).
    :::

    Examples
    --------
    ```python
    import os
    from chatlas import ChatOpenAI

    chat = ChatOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    chat.chat("What is the capital of France?")
    ```

    Parameters
    ----------
    system_prompt
        A system prompt that sets the assistant's behavior.
    turns
        A list of `Turn` objects to start the chat with (e.g., to continue a
        previous conversation). When omitted, the conversation starts from
        scratch. Do not supply non-`None` values for both `turns` and
        `system_prompt`.
    model
        The model to chat with. When `None` (the default), `"gpt-4o"` is used
        and you are warned about the choice. Pick a model explicitly for
        anything beyond casual use.
    api_key
        The API key used for authentication. Rather than passing it here,
        prefer setting the `OPENAI_API_KEY` environment variable.
    base_url
        The base URL of the endpoint. The default points at OpenAI.
    seed
        Optional integer seed that ChatGPT uses to try to make output more
        reproducible. If not supplied, no seed is sent (except under tests,
        where a fixed seed is used).
    kwargs
        Additional arguments forwarded to the `openai.OpenAI()` (and
        `openai.AsyncOpenAI()`) client constructors.

    Returns
    -------
    Chat
        A chat object that retains the state of the conversation.

    Note
    ----
    Pasting an API key into a chat constructor (e.g., `ChatOpenAI(api_key="...")`)
    is the simplest way to get started and is fine for interactive use. It is
    problematic for code that may be shared with others.

    Instead, consider managing credentials with environment variables or a
    configuration file. A popular approach is to keep them in a `.env` file
    and load it with the `python-dotenv` package.

    ```shell
    pip install python-dotenv
    ```

    ```shell
    # .env
    OPENAI_API_KEY=...
    ```

    ```python
    from chatlas import ChatOpenAI
    from dotenv import load_dotenv

    load_dotenv()
    chat = ChatOpenAI()
    chat.console()
    ```

    A more general alternative is to export the variables in your shell
    before starting Python (e.g., in a `.bashrc` or `.zshrc` file):

    ```shell
    export OPENAI_API_KEY=...
    ```
    """
    # An unspecified seed becomes a fixed value under tests, otherwise None.
    if isinstance(seed, MISSING_TYPE):
        seed = 1014 if is_testing() else None

    chosen_model = log_model_default("gpt-4o") if model is None else model

    provider = OpenAIProvider(
        api_key=api_key,
        model=chosen_model,
        base_url=base_url,
        seed=seed,
        kwargs=kwargs,
    )
    return Chat(
        provider=provider,
        turns=normalize_turns(turns or [], system_prompt),
    )
182
+
183
+
184
class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletionDict]):
    """
    Provider backed by the OpenAI Chat Completions API, or any endpoint that
    speaks the same protocol (selected via `base_url`).

    Holds a synchronous `openai.OpenAI` client and an asynchronous
    `openai.AsyncOpenAI` client built from the same arguments. It translates
    between chatlas `Turn` objects and OpenAI request/response payloads.
    """

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        model: str,
        base_url: str = "https://api.openai.com/v1",
        seed: Optional[int] = None,
        kwargs: Optional["ChatClientArgs"] = None,
    ):
        """
        Parameters
        ----------
        api_key
            API key passed to both `openai` clients.
        model
            Model name sent with every request.
        base_url
            Base URL passed to both `openai` clients.
        seed
            Seed sent with every request (may be `None`).
        kwargs
            Extra client-constructor arguments. Keys shared with `api_key` or
            `base_url` override those values.

        Raises
        ------
        ImportError
            If the `openai` package is not installed.
        """
        try:
            from openai import AsyncOpenAI, OpenAI
        except ImportError:
            raise ImportError(
                "`ChatOpenAI()` requires the `openai` package. "
                "Install it with `pip install openai`."
            )

        self._model = model
        self._seed = seed

        # Caller-supplied kwargs are spread last so they take precedence.
        kwargs_full: "ChatClientArgs" = {
            "api_key": api_key,
            "base_url": base_url,
            **(kwargs or {}),
        }

        # TODO: worth bringing in AsyncOpenAI types?
        self._client = OpenAI(**kwargs_full)  # type: ignore
        self._async_client = AsyncOpenAI(**kwargs_full)

    @overload
    def chat_perform(
        self,
        *,
        stream: Literal[False],
        turns: list[Turn],
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]] = None,
        kwargs: Optional["SubmitInputArgs"] = None,
    ): ...

    @overload
    def chat_perform(
        self,
        *,
        stream: Literal[True],
        turns: list[Turn],
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]] = None,
        kwargs: Optional["SubmitInputArgs"] = None,
    ): ...

    def chat_perform(
        self,
        *,
        stream: bool,
        turns: list[Turn],
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]] = None,
        kwargs: Optional["SubmitInputArgs"] = None,
    ):
        """Send a chat completion request with the synchronous client."""
        kwargs = self._chat_perform_args(stream, turns, tools, data_model, kwargs)
        return self._client.chat.completions.create(**kwargs)  # type: ignore

    @overload
    async def chat_perform_async(
        self,
        *,
        stream: Literal[False],
        turns: list[Turn],
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]] = None,
        kwargs: Optional["SubmitInputArgs"] = None,
    ): ...

    @overload
    async def chat_perform_async(
        self,
        *,
        stream: Literal[True],
        turns: list[Turn],
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]] = None,
        kwargs: Optional["SubmitInputArgs"] = None,
    ): ...

    async def chat_perform_async(
        self,
        *,
        stream: bool,
        turns: list[Turn],
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]] = None,
        kwargs: Optional["SubmitInputArgs"] = None,
    ):
        """Send a chat completion request with the asynchronous client."""
        kwargs = self._chat_perform_args(stream, turns, tools, data_model, kwargs)
        return await self._async_client.chat.completions.create(**kwargs)  # type: ignore

    def _chat_perform_args(
        self,
        stream: bool,
        turns: list[Turn],
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]] = None,
        kwargs: Optional["SubmitInputArgs"] = None,
    ) -> "SubmitInputArgs":
        """
        Build the keyword arguments for `chat.completions.create()`.

        - User-supplied `kwargs` override the defaults for `stream`,
          `messages`, `model` and `seed`.
        - Tool schemas are attached when `tools` is non-empty.
        - When `data_model` is given, a strict JSON-schema `response_format`
          is requested and any tools are removed.
        - When streaming, `stream_options={"include_usage": True}` is added
          unless the caller already supplied `stream_options`.
        """
        tool_schemas = [tool.schema for tool in tools.values()]

        kwargs_full: "SubmitInputArgs" = {
            "stream": stream,
            "messages": self._as_message_param(turns),
            "model": self._model,
            "seed": self._seed,
            **(kwargs or {}),
        }

        if tool_schemas:
            kwargs_full["tools"] = tool_schemas

        if data_model is not None:
            params = basemodel_to_param_schema(data_model)
            params = cast(dict, params)
            # OpenAI's strict structured-output mode requires this to be False.
            params["additionalProperties"] = False
            kwargs_full["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": "structured_data",
                    "description": params.get("description", ""),
                    "schema": params,
                    "strict": True,
                },
            }
            # Apparently OpenAI gets confused if you include
            # both response_format and tools
            if "tools" in kwargs_full:
                del kwargs_full["tools"]

        # Ask for token usage to be reported in the stream.
        if stream and "stream_options" not in kwargs_full:
            kwargs_full["stream_options"] = {"include_usage": True}

        return kwargs_full

    def stream_text(self, chunk):
        """Return the text delta of a streamed chunk, or None if it has no choices."""
        if not chunk.choices:
            return None
        return chunk.choices[0].delta.content

    def stream_merge_chunks(self, completion, chunk):
        """Merge a streamed chunk (as a dict) into the accumulated completion dict."""
        chunkd = chunk.model_dump()
        if completion is None:
            return chunkd
        return merge_dicts(completion, chunkd)

    def stream_turn(self, completion, has_data_model, stream) -> Turn:
        """
        Turn the merged dict of streamed chunks into an assistant `Turn`.

        The accumulated message lives under `choices[0]["delta"]`. It is moved
        to `"message"` so the dict has the shape of a non-streaming
        `ChatCompletion`, which is then built with `ChatCompletion.construct`.
        Note that this mutates `completion` in place.
        """
        from openai.types.chat import ChatCompletion

        delta = completion["choices"][0].pop("delta")  # type: ignore
        completion["choices"][0]["message"] = delta  # type: ignore
        completion = ChatCompletion.construct(**completion)
        return self._as_turn(completion, has_data_model)

    async def stream_turn_async(self, completion, has_data_model, stream):
        """Async counterpart of `stream_turn` (no I/O; delegates directly)."""
        return self.stream_turn(completion, has_data_model, stream)

    def value_turn(self, completion, has_data_model) -> Turn:
        """Convert a non-streaming `ChatCompletion` into an assistant `Turn`."""
        return self._as_turn(completion, has_data_model)

    @staticmethod
    def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]:
        """
        Convert chatlas turns into OpenAI message params.

        - system: one system message holding the turn's text.
        - assistant: one assistant message. Text/JSON become content parts and
          tool requests become `tool_calls` (arguments JSON-encoded). Empty
          `content`/`tool_calls` keys are omitted.
        - user: each tool result becomes a separate `role="tool"` message,
          emitted *before* the user message so it directly follows the
          assistant message whose `tool_calls` it answers. The remaining
          content (text, JSON placeholder, images) forms one user message,
          which is only emitted if non-empty.

        Raises
        ------
        ValueError
            For an unknown role or an unsupported content type for a role.
        """
        from openai.types.chat import (
            ChatCompletionAssistantMessageParam,
            ChatCompletionMessageToolCallParam,
            ChatCompletionSystemMessageParam,
            ChatCompletionToolMessageParam,
            ChatCompletionUserMessageParam,
        )

        res: list["ChatCompletionMessageParam"] = []
        for turn in turns:
            if turn.role == "system":
                res.append(
                    ChatCompletionSystemMessageParam(content=turn.text, role="system")
                )
            elif turn.role == "assistant":
                content_parts: list["ContentArrayOfContentPart"] = []
                tool_calls: list["ChatCompletionMessageToolCallParam"] = []
                for x in turn.contents:
                    if isinstance(x, ContentText):
                        content_parts.append({"type": "text", "text": x.text})
                    elif isinstance(x, ContentJson):
                        # NOTE(review): the JSON payload itself is not sent; an
                        # empty text part stands in for it. Confirm this is intended.
                        content_parts.append({"type": "text", "text": ""})
                    elif isinstance(x, ContentToolRequest):
                        tool_calls.append(
                            {
                                "id": x.id,
                                "function": {
                                    "name": x.name,
                                    "arguments": json.dumps(x.arguments),
                                },
                                "type": "function",
                            }
                        )
                    else:
                        raise ValueError(
                            f"Don't know how to handle content type {type(x)} for role='assistant'."
                        )

                # Some OpenAI-compatible models (e.g., Groq) don't work nicely with empty content
                args = {
                    "role": "assistant",
                    "content": content_parts,
                    "tool_calls": tool_calls,
                }
                if not content_parts:
                    del args["content"]
                if not tool_calls:
                    del args["tool_calls"]

                res.append(ChatCompletionAssistantMessageParam(**args))

            elif turn.role == "user":
                contents: list["ChatCompletionContentPartParam"] = []
                tool_results: list["ChatCompletionToolMessageParam"] = []
                for x in turn.contents:
                    if isinstance(x, ContentText):
                        contents.append({"type": "text", "text": x.text})
                    elif isinstance(x, ContentJson):
                        contents.append({"type": "text", "text": ""})
                    elif isinstance(x, ContentImageRemote):
                        contents.append(
                            {"type": "image_url", "image_url": {"url": x.url}}
                        )
                    elif isinstance(x, ContentImageInline):
                        contents.append(
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:{x.content_type};base64,{x.data}"
                                },
                            }
                        )
                    elif isinstance(x, ContentToolResult):
                        tool_results.append(
                            ChatCompletionToolMessageParam(
                                # TODO: a tool could return an image!?!
                                content=x.get_final_value(),
                                tool_call_id=x.id,
                                role="tool",
                            )
                        )
                    else:
                        raise ValueError(
                            f"Don't know how to handle content type {type(x)} for role='user'."
                        )

                # Tool results go first. OpenAI requires `role="tool"` messages
                # to directly follow the assistant message whose `tool_calls`
                # they answer, so a user message must not sit between them.
                res.extend(tool_results)
                if contents:
                    res.append(
                        ChatCompletionUserMessageParam(content=contents, role="user")
                    )

            else:
                raise ValueError(f"Unknown role: {turn.role}")

        return res

    def _as_turn(
        self, completion: "ChatCompletion", has_data_model: bool
    ) -> Turn[ChatCompletion]:
        """
        Convert a `ChatCompletion` into an assistant `Turn`.

        Only the first choice is used.
        - Message content: parsed as JSON into `ContentJson` when
          `has_data_model` is true, otherwise kept as `ContentText`.
        - Tool calls: each becomes a `ContentToolRequest` with JSON-decoded
          arguments. Calls without a `function` are skipped.
        - Token counts: taken from `completion.usage`. If that is missing,
          they fall back to Groq's `x_groq["usage"]`, and otherwise to
          `(0, 0)`. The counts are recorded via `tokens_log`.

        Raises
        ------
        ValueError
            If a tool call's arguments are not valid JSON.
        """
        message = completion.choices[0].message

        contents: list[Content] = []
        if message.content is not None:
            if has_data_model:
                data = json.loads(message.content)
                contents = [ContentJson(data)]
            else:
                contents = [ContentText(message.content)]

        tool_calls = message.tool_calls

        if tool_calls is not None:
            for call in tool_calls:
                func = call.function
                if func is None:
                    continue

                try:
                    args = json.loads(func.arguments) if func.arguments else {}
                except json.JSONDecodeError as e:
                    raise ValueError(
                        f"The model's completion included a tool request ({func.name}) "
                        f"with invalid JSON for input arguments: '{func.arguments}'. "
                        "This can happen if the model hallucinates parameters not defined by "
                        "your function schema. Try revising your tool description and system "
                        "prompt to be more specific about the expected input arguments to this function."
                    ) from e

                contents.append(
                    ContentToolRequest(
                        call.id,
                        name=func.name,
                        arguments=args,
                    )
                )

        usage = completion.usage
        if usage is None:
            tokens = (0, 0)
        else:
            tokens = usage.prompt_tokens, usage.completion_tokens

        # For some reason ChatGroq() includes tokens under completion.x_groq
        if usage is None and hasattr(completion, "x_groq"):
            usage = completion.x_groq["usage"]  # type: ignore
            tokens = usage["prompt_tokens"], usage["completion_tokens"]

        tokens_log(self, tokens)

        return Turn(
            "assistant",
            contents,
            tokens=tokens,
            finish_reason=completion.choices[0].finish_reason,
            completion=completion,
        )
512
+
513
+
514
def ChatAzureOpenAI(
    *,
    endpoint: Optional[str] = None,
    deployment_id: str,
    api_version: str,
    api_key: Optional[str] = None,
    system_prompt: Optional[str] = None,
    turns: Optional[list[Turn]] = None,
    seed: int | None | MISSING_TYPE = MISSING,
    kwargs: Optional["ChatAzureClientArgs"] = None,
) -> Chat["SubmitInputArgs", ChatCompletion]:
    """
    Chat with a model hosted on Azure OpenAI.

    The [Azure OpenAI server](https://azure.microsoft.com/en-us/products/ai-services/openai-service)
    hosts a number of open source models as well as proprietary models
    from OpenAI.

    Prerequisites
    -------------

    ::: {.callout-note}
    ## Python requirements

    `ChatAzureOpenAI` requires the `openai` package (e.g., `pip install openai`).
    :::

    Examples
    --------
    ```python
    import os
    from chatlas import ChatAzureOpenAI

    chat = ChatAzureOpenAI(
        endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        deployment_id="REPLACE_WITH_YOUR_DEPLOYMENT_ID",
        api_version="YYYY-MM-DD",
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    )

    chat.chat("What is the capital of France?")
    ```

    Parameters
    ----------
    endpoint
        Azure OpenAI endpoint URL with protocol and hostname, e.g.
        `https://{your-resource-name}.openai.azure.com`. If `None` (the
        default), the `openai` Azure client falls back to the value of the
        `AZURE_OPENAI_ENDPOINT` environment variable.
    deployment_id
        Deployment id for the model you want to use.
    api_version
        The API version to use.
    api_key
        The API key to use for authentication. Rather than supplying it
        directly, prefer setting the `AZURE_OPENAI_API_KEY` environment
        variable.
    system_prompt
        A system prompt to set the behavior of the assistant.
    turns
        A list of `Turn` objects to start the chat with (e.g., to continue a
        previous conversation). If not provided, the conversation begins from
        scratch. Do not provide non-`None` values for both `turns` and
        `system_prompt`.
    seed
        Optional integer seed that ChatGPT uses to try to make output more
        reproducible. If not supplied, no seed is sent (except under tests,
        where a fixed seed is used).
    kwargs
        Additional arguments to pass to the `openai.AzureOpenAI()` (and
        `openai.AsyncAzureOpenAI()`) client constructors.

    Returns
    -------
    Chat
        A Chat object.
    """

    # An unspecified seed becomes a fixed value under tests, otherwise None.
    if isinstance(seed, MISSING_TYPE):
        seed = 1014 if is_testing() else None

    return Chat(
        provider=OpenAIAzureProvider(
            endpoint=endpoint,
            deployment_id=deployment_id,
            api_version=api_version,
            api_key=api_key,
            seed=seed,
            kwargs=kwargs,
        ),
        turns=normalize_turns(
            turns or [],
            system_prompt,
        ),
    )
609
+
610
+
611
class OpenAIAzureProvider(OpenAIProvider):
    """
    `OpenAIProvider` variant that talks to Azure OpenAI.

    All request/response conversion is inherited from `OpenAIProvider`. Only
    the clients differ: `openai.AzureOpenAI` and `openai.AsyncAzureOpenAI`.
    The parent's `__init__` is not called, since it would build plain
    (non-Azure) OpenAI clients.
    """

    def __init__(
        self,
        *,
        endpoint: Optional[str] = None,
        deployment_id: Optional[str] = None,
        api_version: Optional[str] = None,
        api_key: Optional[str] = None,
        seed: int | None = None,
        kwargs: Optional["ChatAzureClientArgs"] = None,
    ):
        """
        Parameters
        ----------
        endpoint
            Passed to the clients as `azure_endpoint`.
        deployment_id
            Passed to the clients as `azure_deployment`, and also used as the
            model name sent with each request.
        api_version
            Passed to the clients as `api_version`.
        api_key
            Passed to the clients as `api_key`.
        seed
            Seed sent with every request (may be `None`).
        kwargs
            Extra client-constructor arguments. These override any of the
            keys above.

        Raises
        ------
        ImportError
            If the `openai` package is not installed.
        """
        try:
            from openai import AsyncAzureOpenAI, AzureOpenAI
        except ImportError:
            raise ImportError(
                "`ChatAzureOpenAI()` requires the `openai` package. "
                "Install it with `pip install openai`."
            )

        client_args: "ChatAzureClientArgs" = {
            "azure_endpoint": endpoint,
            "azure_deployment": deployment_id,
            "api_version": api_version,
            "api_key": api_key,
        }
        # Caller-supplied kwargs win over the named arguments.
        client_args.update(kwargs or {})

        self._model = deployment_id
        self._seed = seed
        self._client = AzureOpenAI(**client_args)  # type: ignore
        self._async_client = AsyncAzureOpenAI(**client_args)  # type: ignore
643
+
644
+
645
class InvalidJSONParameterWarning(RuntimeWarning):
    """
    Warning category for tool requests whose input arguments are not valid JSON.

    Subclasses `RuntimeWarning`. Malformed tool arguments can appear when a
    model hallucinates parameters that are not part of your function schema.
    In this module, invalid tool-argument JSON currently raises `ValueError`
    instead of emitting this warning.
    """