langroid 0.59.0b3__py3-none-any.whl → 0.59.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. langroid/agent/done_sequence_parser.py +46 -11
  2. langroid/agent/special/doc_chat_task.py +0 -0
  3. langroid/agent/task.py +44 -7
  4. langroid/language_models/model_info.py +51 -0
  5. langroid/mcp/__init__.py +1 -0
  6. langroid/mcp/server/__init__.py +1 -0
  7. langroid/pydantic_v1/__init__.py +1 -1
  8. {langroid-0.59.0b3.dist-info → langroid-0.59.2.dist-info}/METADATA +4 -1
  9. {langroid-0.59.0b3.dist-info → langroid-0.59.2.dist-info}/RECORD +11 -47
  10. langroid/agent/base.py-e +0 -2216
  11. langroid/agent/chat_agent.py-e +0 -2086
  12. langroid/agent/chat_document.py-e +0 -513
  13. langroid/agent/openai_assistant.py-e +0 -882
  14. langroid/agent/special/arangodb/arangodb_agent.py-e +0 -648
  15. langroid/agent/special/lance_tools.py-e +0 -61
  16. langroid/agent/special/neo4j/neo4j_chat_agent.py-e +0 -430
  17. langroid/agent/task.py-e +0 -2418
  18. langroid/agent/tool_message.py-e +0 -400
  19. langroid/agent/tools/file_tools.py-e +0 -234
  20. langroid/agent/tools/mcp/fastmcp_client.py-e +0 -584
  21. langroid/agent/tools/orchestration.py-e +0 -301
  22. langroid/agent/tools/task_tool.py-e +0 -249
  23. langroid/agent/xml_tool_message.py-e +0 -392
  24. langroid/embedding_models/models.py-e +0 -563
  25. langroid/language_models/azure_openai.py-e +0 -134
  26. langroid/language_models/base.py-e +0 -812
  27. langroid/language_models/config.py-e +0 -18
  28. langroid/language_models/model_info.py-e +0 -483
  29. langroid/language_models/openai_gpt.py-e +0 -2280
  30. langroid/language_models/provider_params.py-e +0 -153
  31. langroid/mytypes.py-e +0 -132
  32. langroid/parsing/file_attachment.py-e +0 -246
  33. langroid/parsing/md_parser.py-e +0 -574
  34. langroid/parsing/parser.py-e +0 -410
  35. langroid/parsing/repo_loader.py-e +0 -812
  36. langroid/parsing/url_loader.py-e +0 -683
  37. langroid/parsing/urls.py-e +0 -279
  38. langroid/pydantic_v1/__init__.py-e +0 -36
  39. langroid/pydantic_v1/main.py-e +0 -11
  40. langroid/utils/configuration.py-e +0 -141
  41. langroid/utils/constants.py-e +0 -32
  42. langroid/utils/globals.py-e +0 -49
  43. langroid/utils/html_logger.py-e +0 -825
  44. langroid/utils/object_registry.py-e +0 -66
  45. langroid/utils/pydantic_utils.py-e +0 -602
  46. langroid/utils/types.py-e +0 -113
  47. langroid/vector_store/lancedb.py-e +0 -404
  48. langroid/vector_store/pineconedb.py-e +0 -427
  49. {langroid-0.59.0b3.dist-info → langroid-0.59.2.dist-info}/WHEEL +0 -0
  50. {langroid-0.59.0b3.dist-info → langroid-0.59.2.dist-info}/licenses/LICENSE +0 -0
@@ -1,563 +0,0 @@
1
- import atexit
2
- import os
3
- from functools import cached_property
4
- from typing import Any, Callable, Dict, List, Optional
5
-
6
- import requests
7
- import tiktoken
8
- from dotenv import load_dotenv
9
- from openai import AzureOpenAI, OpenAI
10
-
11
- from langroid.embedding_models.base import EmbeddingModel, EmbeddingModelsConfig
12
- from langroid.exceptions import LangroidImportError
13
- from langroid.language_models.provider_params import LangDBParams
14
- from langroid.mytypes import Embeddings
15
- from langroid.parsing.utils import batched
16
- from pydantic import ConfigDict
17
-
18
- AzureADTokenProvider = Callable[[], str]
19
-
20
-
21
- class OpenAIEmbeddingsConfig(EmbeddingModelsConfig):
22
- model_type: str = "openai"
23
- model_name: str = "text-embedding-3-small"
24
- api_key: str = ""
25
- api_base: Optional[str] = None
26
- organization: str = ""
27
- dims: int = 1536
28
- context_length: int = 8192
29
- langdb_params: LangDBParams = LangDBParams()
30
-
31
- model_config = ConfigDict(env_prefix="OPENAI_")
32
-
33
- class AzureOpenAIEmbeddingsConfig(EmbeddingModelsConfig):
34
- model_type: str = "azure-openai"
35
- model_name: str = "text-embedding-3-small"
36
- api_key: str = ""
37
- api_base: str = ""
38
- deployment_name: Optional[str] = None
39
- # api_version defaulted to 2024-06-01 as per https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/embeddings?tabs=python-new
40
- # change this to required supported version
41
- api_version: Optional[str] = "2024-06-01"
42
- # TODO: Add auth support for Azure OpenAI via AzureADTokenProvider
43
- azure_ad_token: Optional[str] = None
44
- azure_ad_token_provider: Optional[AzureADTokenProvider] = None
45
- dims: int = 1536
46
- context_length: int = 8192
47
-
48
- model_config = ConfigDict(env_prefix="AZURE_OPENAI_")
49
-
50
- class SentenceTransformerEmbeddingsConfig(EmbeddingModelsConfig):
51
- model_type: str = "sentence-transformer"
52
- model_name: str = "BAAI/bge-large-en-v1.5"
53
- context_length: int = 512
54
- data_parallel: bool = False
55
- # Select device (e.g. "cuda", "cpu") when data parallel is disabled
56
- device: Optional[str] = None
57
- # Select devices when data parallel is enabled
58
- devices: Optional[list[str]] = None
59
-
60
-
61
- class FastEmbedEmbeddingsConfig(EmbeddingModelsConfig):
62
- """Config for qdrant/fastembed embeddings,
63
- see here: https://github.com/qdrant/fastembed
64
- """
65
-
66
- model_type: str = "fastembed"
67
- model_name: str = "BAAI/bge-small-en-v1.5"
68
- batch_size: int = 256
69
- cache_dir: Optional[str] = None
70
- threads: Optional[int] = None
71
- parallel: Optional[int] = None
72
- additional_kwargs: Dict[str, Any] = {}
73
-
74
-
75
- class LlamaCppServerEmbeddingsConfig(EmbeddingModelsConfig):
76
- api_base: str = ""
77
- context_length: int = 2048
78
- batch_size: int = 2048
79
-
80
-
81
- class GeminiEmbeddingsConfig(EmbeddingModelsConfig):
82
- model_type: str = "gemini"
83
- model_name: str = "models/text-embedding-004"
84
- api_key: str = ""
85
- dims: int = 768
86
- batch_size: int = 512
87
-
88
-
89
- class EmbeddingFunctionCallable:
90
- """
91
- A callable class designed to generate embeddings for a list of texts using
92
- the OpenAI or Azure OpenAI API, with automatic retries on failure.
93
-
94
- Attributes:
95
- embed_model (EmbeddingModel): An instance of EmbeddingModel that provides
96
- configuration and utilities for generating embeddings.
97
-
98
- Methods:
99
- __call__(input: List[str]) -> Embeddings: Generate embeddings for
100
- a list of input texts.
101
- """
102
-
103
- def __init__(self, embed_model: EmbeddingModel, batch_size: int = 512):
104
- """
105
- Initialize the EmbeddingFunctionCallable with a specific model.
106
-
107
- Args:
108
- model ( OpenAIEmbeddings or AzureOpenAIEmbeddings): An instance of
109
- OpenAIEmbeddings or AzureOpenAIEmbeddings to use for
110
- generating embeddings.
111
- batch_size (int): Batch size
112
- """
113
- self.embed_model = embed_model
114
- self.batch_size = batch_size
115
-
116
- def __call__(self, input: List[str]) -> Embeddings:
117
- """
118
- Generate embeddings for a given list of input texts using the OpenAI API,
119
- with retries on failure.
120
-
121
- This method:
122
- - Truncates each text in the input list to the model's maximum context length.
123
- - Processes the texts in batches to generate embeddings efficiently.
124
- - Automatically retries the embedding generation process with exponential
125
- backoff in case of failures.
126
-
127
- Args:
128
- input (List[str]): A list of input texts to generate embeddings for.
129
-
130
- Returns:
131
- Embeddings: A list of embedding vectors corresponding to the input texts.
132
- """
133
- embeds = []
134
- if isinstance(self.embed_model, (OpenAIEmbeddings, AzureOpenAIEmbeddings)):
135
- # Truncate texts to context length while preserving text format
136
- truncated_texts = self.embed_model.truncate_texts(input)
137
-
138
- # Process in batches
139
- for batch in batched(truncated_texts, self.batch_size):
140
- result = self.embed_model.client.embeddings.create(
141
- input=batch, model=self.embed_model.config.model_name # type: ignore
142
- )
143
- batch_embeds = [d.embedding for d in result.data]
144
- embeds.extend(batch_embeds)
145
-
146
- elif isinstance(self.embed_model, SentenceTransformerEmbeddings):
147
- if self.embed_model.config.data_parallel:
148
- embeds = self.embed_model.model.encode_multi_process(
149
- input,
150
- self.embed_model.pool,
151
- batch_size=self.batch_size,
152
- ).tolist()
153
- else:
154
- for str_batch in batched(input, self.batch_size):
155
- batch_embeds = self.embed_model.model.encode(
156
- str_batch, convert_to_numpy=True
157
- ).tolist() # type: ignore
158
- embeds.extend(batch_embeds)
159
-
160
- elif isinstance(self.embed_model, FastEmbedEmbeddings):
161
- embeddings = self.embed_model.model.embed(
162
- input, batch_size=self.batch_size, parallel=self.embed_model.parallel
163
- )
164
-
165
- embeds = [embedding.tolist() for embedding in embeddings]
166
- elif isinstance(self.embed_model, LlamaCppServerEmbeddings):
167
- for input_string in input:
168
- tokenized_text = self.embed_model.tokenize_string(input_string)
169
- for token_batch in batched(tokenized_text, self.batch_size):
170
- gen_embedding = self.embed_model.generate_embedding(
171
- self.embed_model.detokenize_string(list(token_batch))
172
- )
173
- embeds.append(gen_embedding)
174
- elif isinstance(self.embed_model, GeminiEmbeddings):
175
- embeds = self.embed_model.generate_embeddings(input)
176
- return embeds
177
-
178
-
179
- class OpenAIEmbeddings(EmbeddingModel):
180
- def __init__(self, config: OpenAIEmbeddingsConfig = OpenAIEmbeddingsConfig()):
181
- super().__init__()
182
- self.config = config
183
- load_dotenv()
184
-
185
- # Check if using LangDB
186
- self.is_langdb = self.config.model_name.startswith("langdb/")
187
-
188
- if self.is_langdb:
189
- self.config.model_name = self.config.model_name.replace("langdb/", "")
190
- self.config.api_base = self.config.langdb_params.base_url
191
- project_id = self.config.langdb_params.project_id
192
- if project_id:
193
- self.config.api_base += "/" + project_id + "/v1"
194
- self.config.api_key = self.config.langdb_params.api_key
195
-
196
- if not self.config.api_key:
197
- self.config.api_key = os.getenv("OPENAI_API_KEY", "")
198
-
199
- self.config.organization = os.getenv("OPENAI_ORGANIZATION", "")
200
-
201
- if self.config.api_key == "":
202
- if self.is_langdb:
203
- raise ValueError(
204
- """
205
- LANGDB_API_KEY must be set in .env or your environment
206
- to use OpenAIEmbeddings via LangDB.
207
- """
208
- )
209
- else:
210
- raise ValueError(
211
- """
212
- OPENAI_API_KEY must be set in .env or your environment
213
- to use OpenAIEmbeddings.
214
- """
215
- )
216
-
217
- self.client = OpenAI(
218
- base_url=self.config.api_base,
219
- api_key=self.config.api_key,
220
- organization=self.config.organization,
221
- )
222
- model_for_tokenizer = self.config.model_name
223
- if model_for_tokenizer.startswith("openai/"):
224
- self.config.model_name = model_for_tokenizer.replace("openai/", "")
225
- self.tokenizer = tiktoken.encoding_for_model(self.config.model_name)
226
-
227
- def truncate_texts(self, texts: List[str]) -> List[str] | List[List[int]]:
228
- """
229
- Truncate texts to the embedding model's context length.
230
- TODO: Maybe we should show warning, and consider doing T5 summarization?
231
- """
232
- truncated_tokens = [
233
- self.tokenizer.encode(text, disallowed_special=())[
234
- : self.config.context_length
235
- ]
236
- for text in texts
237
- ]
238
-
239
- if self.is_langdb:
240
- # LangDB embedding endpt only works with strings, not tokens
241
- return [self.tokenizer.decode(tokens) for tokens in truncated_tokens]
242
- return truncated_tokens
243
-
244
- def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
245
- return EmbeddingFunctionCallable(self, self.config.batch_size)
246
-
247
- @property
248
- def embedding_dims(self) -> int:
249
- return self.config.dims
250
-
251
-
252
- class AzureOpenAIEmbeddings(EmbeddingModel):
253
- """
254
- Azure OpenAI embeddings model implementation.
255
- """
256
-
257
- def __init__(
258
- self, config: AzureOpenAIEmbeddingsConfig = AzureOpenAIEmbeddingsConfig()
259
- ):
260
- """
261
- Initializes Azure OpenAI embeddings model.
262
-
263
- Args:
264
- config: Configuration for Azure OpenAI embeddings model.
265
- Raises:
266
- ValueError: If required Azure config values are not set.
267
- """
268
- super().__init__()
269
- self.config = config
270
- load_dotenv()
271
-
272
- if self.config.api_key == "":
273
- raise ValueError(
274
- """AZURE_OPENAI_API_KEY env variable must be set to use
275
- AzureOpenAIEmbeddings. Please set the AZURE_OPENAI_API_KEY value
276
- in your .env file."""
277
- )
278
-
279
- if self.config.api_base == "":
280
- raise ValueError(
281
- """AZURE_OPENAI_API_BASE env variable must be set to use
282
- AzureOpenAIEmbeddings. Please set the AZURE_OPENAI_API_BASE value
283
- in your .env file."""
284
- )
285
- self.client = AzureOpenAI(
286
- api_key=self.config.api_key,
287
- api_version=self.config.api_version,
288
- azure_endpoint=self.config.api_base,
289
- azure_deployment=self.config.deployment_name,
290
- )
291
- self.tokenizer = tiktoken.encoding_for_model(self.config.model_name)
292
-
293
- def truncate_texts(self, texts: List[str]) -> List[str] | List[List[int]]:
294
- """
295
- Truncate texts to the embedding model's context length.
296
- TODO: Maybe we should show warning, and consider doing T5 summarization?
297
- """
298
- return [
299
- self.tokenizer.encode(text, disallowed_special=())[
300
- : self.config.context_length
301
- ]
302
- for text in texts
303
- ]
304
-
305
- def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
306
- """Get the embedding function for Azure OpenAI.
307
-
308
- Returns:
309
- Callable that generates embeddings for input texts.
310
- """
311
- return EmbeddingFunctionCallable(self, self.config.batch_size)
312
-
313
- @property
314
- def embedding_dims(self) -> int:
315
- return self.config.dims
316
-
317
-
318
- STEC = SentenceTransformerEmbeddingsConfig
319
-
320
-
321
- class SentenceTransformerEmbeddings(EmbeddingModel):
322
- def __init__(self, config: STEC = STEC()):
323
- # this is an "extra" optional dependency, so we import it here
324
- try:
325
- from sentence_transformers import SentenceTransformer
326
- from transformers import AutoTokenizer
327
- except ImportError:
328
- raise ImportError(
329
- """
330
- To use sentence_transformers embeddings,
331
- you must install langroid with the [hf-embeddings] extra, e.g.:
332
- pip install "langroid[hf-embeddings]"
333
- """
334
- )
335
-
336
- super().__init__()
337
- self.config = config
338
-
339
- self.model = SentenceTransformer(
340
- self.config.model_name,
341
- device=self.config.device,
342
- )
343
- if self.config.data_parallel:
344
- self.pool = self.model.start_multi_process_pool(
345
- self.config.devices # type: ignore
346
- )
347
- atexit.register(
348
- lambda: SentenceTransformer.stop_multi_process_pool(self.pool)
349
- )
350
-
351
- self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
352
- self.config.context_length = self.tokenizer.model_max_length
353
-
354
- def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
355
- return EmbeddingFunctionCallable(self, self.config.batch_size)
356
-
357
- @property
358
- def embedding_dims(self) -> int:
359
- dims = self.model.get_sentence_embedding_dimension()
360
- if dims is None:
361
- raise ValueError(
362
- f"Could not get embedding dimension for model {self.config.model_name}"
363
- )
364
- return dims # type: ignore
365
-
366
-
367
- class FastEmbedEmbeddings(EmbeddingModel):
368
- def __init__(self, config: FastEmbedEmbeddingsConfig = FastEmbedEmbeddingsConfig()):
369
- try:
370
- from fastembed import TextEmbedding
371
- except ImportError:
372
- raise LangroidImportError("fastembed", extra="fastembed")
373
-
374
- super().__init__()
375
- self.config = config
376
- self.batch_size = config.batch_size
377
- self.parallel = config.parallel
378
-
379
- self.model = TextEmbedding(
380
- model_name=self.config.model_name,
381
- cache_dir=self.config.cache_dir,
382
- threads=self.config.threads,
383
- **self.config.additional_kwargs,
384
- )
385
-
386
- def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
387
- return EmbeddingFunctionCallable(self, self.config.batch_size)
388
-
389
- @cached_property
390
- def embedding_dims(self) -> int:
391
- embed_func = self.embedding_fn()
392
- return len(embed_func(["text"])[0])
393
-
394
-
395
- LCSEC = LlamaCppServerEmbeddingsConfig
396
-
397
-
398
- class LlamaCppServerEmbeddings(EmbeddingModel):
399
- def __init__(self, config: LCSEC = LCSEC()):
400
- super().__init__()
401
- self.config = config
402
-
403
- if self.config.api_base == "":
404
- raise ValueError(
405
- """Api Base MUST be set for Llama Server Embeddings.
406
- """
407
- )
408
-
409
- self.tokenize_url = self.config.api_base + "/tokenize"
410
- self.detokenize_url = self.config.api_base + "/detokenize"
411
- self.embedding_url = self.config.api_base + "/embeddings"
412
-
413
- def tokenize_string(self, text: str) -> List[int]:
414
- data = {"content": text, "add_special": False, "with_pieces": False}
415
- response = requests.post(self.tokenize_url, json=data)
416
-
417
- if response.status_code == 200:
418
- tokens = response.model_dump_json()["tokens"]
419
- if not (isinstance(tokens, list) and isinstance(tokens[0], (int, float))):
420
- # not all(isinstance(token, (int, float)) for token in tokens):
421
- raise ValueError(
422
- """Tokenizer endpoint has not returned the correct format.
423
- Is the URL correct?
424
- """
425
- )
426
- return tokens
427
- else:
428
- raise requests.HTTPError(
429
- self.tokenize_url,
430
- response.status_code,
431
- "Failed to connect to tokenization provider",
432
- )
433
-
434
- def detokenize_string(self, tokens: List[int]) -> str:
435
- data = {"tokens": tokens}
436
- response = requests.post(self.detokenize_url, json=data)
437
-
438
- if response.status_code == 200:
439
- text = response.model_dump_json()["content"]
440
- if not isinstance(text, str):
441
- raise ValueError(
442
- """Deokenizer endpoint has not returned the correct format.
443
- Is the URL correct?
444
- """
445
- )
446
- return text
447
- else:
448
- raise requests.HTTPError(
449
- self.detokenize_url,
450
- response.status_code,
451
- "Failed to connect to detokenization provider",
452
- )
453
-
454
- def truncate_string_to_context_size(self, text: str) -> str:
455
- tokens = self.tokenize_string(text)
456
- tokens = tokens[: self.config.context_length]
457
- return self.detokenize_string(tokens)
458
-
459
- def generate_embedding(self, text: str) -> List[int | float]:
460
- data = {"content": text}
461
- response = requests.post(self.embedding_url, json=data)
462
-
463
- if response.status_code == 200:
464
- embeddings = response.model_dump_json()["embedding"]
465
- if not (
466
- isinstance(embeddings, list) and isinstance(embeddings[0], (int, float))
467
- ):
468
- raise ValueError(
469
- """Embedding endpoint has not returned the correct format.
470
- Is the URL correct?
471
- """
472
- )
473
- return embeddings
474
- else:
475
- raise requests.HTTPError(
476
- self.embedding_url,
477
- response.status_code,
478
- "Failed to connect to embedding provider",
479
- )
480
-
481
- def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
482
- return EmbeddingFunctionCallable(self, self.config.batch_size)
483
-
484
- @property
485
- def embedding_dims(self) -> int:
486
- return self.config.dims
487
-
488
-
489
- class GeminiEmbeddings(EmbeddingModel):
490
- def __init__(self, config: GeminiEmbeddingsConfig = GeminiEmbeddingsConfig()):
491
- try:
492
- from google import genai
493
- except ImportError as e:
494
- raise LangroidImportError(extra="google-genai", error=str(e))
495
- super().__init__()
496
- self.config = config
497
- load_dotenv()
498
- self.config.api_key = os.getenv("GEMINI_API_KEY", "")
499
-
500
- if self.config.api_key == "":
501
- raise ValueError(
502
- """
503
- GEMINI_API_KEY env variable must be set to use GeminiEmbeddings.
504
- """
505
- )
506
- self.client = genai.Client(api_key=self.config.api_key)
507
-
508
- def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
509
- return EmbeddingFunctionCallable(self, self.config.batch_size)
510
-
511
- def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
512
- """Generates embeddings for a list of input texts."""
513
- all_embeddings: List[List[float]] = []
514
-
515
- for batch in batched(texts, self.config.batch_size):
516
- result = self.client.models.embed_content( # type: ignore[attr-defined]
517
- model=self.config.model_name,
518
- contents=batch, # type: ignore
519
- )
520
-
521
- if not hasattr(result, "embeddings") or not isinstance(
522
- result.embeddings, list
523
- ):
524
- raise ValueError(
525
- "Unexpected format for embeddings: missing or incorrect type"
526
- )
527
-
528
- # Extract .values from ContentEmbedding objects
529
- all_embeddings.extend(
530
- [emb.values for emb in result.embeddings] # type: ignore
531
- )
532
-
533
- return all_embeddings
534
-
535
- @property
536
- def embedding_dims(self) -> int:
537
- return self.config.dims
538
-
539
-
540
- def embedding_model(embedding_fn_type: str = "openai") -> EmbeddingModel:
541
- """
542
- Args:
543
- embedding_fn_type: Type of embedding model to use. Options are:
544
- - "openai",
545
- - "azure-openai",
546
- - "sentencetransformer", or
547
- - "fastembed".
548
- (others may be added in the future)
549
- Returns:
550
- EmbeddingModel: The corresponding embedding model class.
551
- """
552
- if embedding_fn_type == "openai":
553
- return OpenAIEmbeddings # type: ignore
554
- elif embedding_fn_type == "azure-openai":
555
- return AzureOpenAIEmbeddings # type: ignore
556
- elif embedding_fn_type == "fastembed":
557
- return FastEmbedEmbeddings # type: ignore
558
- elif embedding_fn_type == "llamacppserver":
559
- return LlamaCppServerEmbeddings # type: ignore
560
- elif embedding_fn_type == "gemini":
561
- return GeminiEmbeddings # type: ignore
562
- else: # default sentence transformer
563
- return SentenceTransformerEmbeddings # type: ignore
@@ -1,134 +0,0 @@
1
- import logging
2
- from typing import Callable
3
-
4
- from dotenv import load_dotenv
5
- from httpx import Timeout
6
- from openai import AsyncAzureOpenAI, AzureOpenAI
7
-
8
- from langroid.language_models.openai_gpt import (
9
- OpenAIGPT,
10
- OpenAIGPTConfig,
11
- )
12
- from pydantic import ConfigDict
13
-
14
- azureStructuredOutputList = [
15
- "2024-08-06",
16
- "2024-11-20",
17
- ]
18
-
19
- azureStructuredOutputAPIMin = "2024-08-01-preview"
20
-
21
- logger = logging.getLogger(__name__)
22
-
23
-
24
- class AzureConfig(OpenAIGPTConfig):
25
- """
26
- Configuration for Azure OpenAI GPT.
27
-
28
- Attributes:
29
- type (str): should be ``azure.``
30
- api_version (str): can be set in the ``.env`` file as
31
- ``AZURE_OPENAI_API_VERSION.``
32
- deployment_name (str|None): can be optionally set in the ``.env`` file as
33
- ``AZURE_OPENAI_DEPLOYMENT_NAME`` and should be based the custom name you
34
- chose for your deployment when you deployed a model.
35
- model_name (str): [DEPRECATED] can be set in the ``.env``
36
- file as ``AZURE_OPENAI_MODEL_NAME``
37
- and should be based on the model name chosen during setup.
38
- chat_model (str): the chat model name to use. Can be set via
39
- the env variable ``AZURE_OPENAI_CHAT_MODEL``.
40
- Recommended to use this instead of ``model_name``.
41
-
42
- """
43
-
44
- api_key: str = "" # CAUTION: set this ONLY via env var AZURE_OPENAI_API_KEY
45
- type: str = "azure"
46
- api_version: str = "2023-05-15"
47
- deployment_name: str | None = None
48
- model_name: str = ""
49
- api_base: str = ""
50
-
51
- # Alternatively, bring your own clients:
52
- azure_openai_client_provider: Callable[[], AzureOpenAI] | None = None
53
- azure_openai_async_client_provider: Callable[[], AsyncAzureOpenAI] | None = None
54
-
55
- # all of the vars above can be set via env vars,
56
- # by upper-casing the name and prefixing with `env_prefix`, e.g.
57
- # AZURE_OPENAI_API_VERSION=2023-05-15
58
- # This is either done in the .env file, or via an explicit
59
- # `export AZURE_OPENAI_API_VERSION=...`
60
- model_config = ConfigDict(env_prefix="AZURE_OPENAI_")
61
-
62
- def __init__(self, **kwargs) -> None: # type: ignore
63
- if "model_name" in kwargs and "chat_model" not in kwargs:
64
- kwargs["chat_model"] = kwargs["model_name"]
65
- super().__init__(**kwargs)
66
-
67
-
68
- class AzureGPT(OpenAIGPT):
69
- """
70
- Class to access OpenAI LLMs via Azure. These env variables can be obtained from the
71
- file `.azure_env`. Azure OpenAI doesn't support ``completion``
72
- """
73
-
74
- def __init__(self, config: AzureConfig):
75
- # This will auto-populate config values from .env file
76
- load_dotenv()
77
- super().__init__(config)
78
- self.config: AzureConfig = config
79
-
80
- if (
81
- self.config.azure_openai_client_provider
82
- or self.config.azure_openai_async_client_provider
83
- ):
84
- if not self.config.azure_openai_client_provider:
85
- self.client = None
86
- logger.warning(
87
- "Using user-provided Azure OpenAI client, but only async "
88
- "client has been provided. Synchronous calls will fail."
89
- )
90
- if not self.config.azure_openai_async_client_provider:
91
- self.async_client = None
92
- logger.warning(
93
- "Using user-provided Azure OpenAI client, but no async "
94
- "client has been provided. Asynchronous calls will fail."
95
- )
96
-
97
- if self.config.azure_openai_client_provider:
98
- self.client = self.config.azure_openai_client_provider()
99
- if self.config.azure_openai_async_client_provider:
100
- self.async_client = self.config.azure_openai_async_client_provider()
101
- self.async_client.timeout = Timeout(self.config.timeout)
102
- else:
103
- if self.config.api_key == "":
104
- raise ValueError(
105
- """
106
- AZURE_OPENAI_API_KEY not set in .env file,
107
- please set it to your Azure API key."""
108
- )
109
-
110
- if self.config.api_base == "":
111
- raise ValueError(
112
- """
113
- AZURE_OPENAI_API_BASE not set in .env file,
114
- please set it to your Azure API key."""
115
- )
116
-
117
- self.client = AzureOpenAI(
118
- api_key=self.config.api_key,
119
- azure_endpoint=self.config.api_base,
120
- api_version=self.config.api_version,
121
- azure_deployment=self.config.deployment_name,
122
- )
123
- self.async_client = AsyncAzureOpenAI(
124
- api_key=self.config.api_key,
125
- azure_endpoint=self.config.api_base,
126
- api_version=self.config.api_version,
127
- azure_deployment=self.config.deployment_name,
128
- timeout=Timeout(self.config.timeout),
129
- )
130
-
131
- self.supports_json_schema = (
132
- self.config.api_version >= azureStructuredOutputAPIMin
133
- and self.config.api_version in azureStructuredOutputList
134
- )