durag 2.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- durag/__init__.py +5 -0
- durag/configs/__init__.py +0 -0
- durag/configs/base.py +81 -0
- durag/configs/embeddings/__init__.py +0 -0
- durag/configs/embeddings/base.py +110 -0
- durag/configs/enums.py +7 -0
- durag/configs/llms/__init__.py +0 -0
- durag/configs/llms/anthropic.py +56 -0
- durag/configs/llms/aws_bedrock.py +196 -0
- durag/configs/llms/azure.py +60 -0
- durag/configs/llms/base.py +67 -0
- durag/configs/llms/deepseek.py +56 -0
- durag/configs/llms/lmstudio.py +59 -0
- durag/configs/llms/minimax.py +56 -0
- durag/configs/llms/ollama.py +56 -0
- durag/configs/llms/openai.py +87 -0
- durag/configs/llms/vllm.py +56 -0
- durag/configs/prompts.py +1062 -0
- durag/configs/rerankers/__init__.py +0 -0
- durag/configs/rerankers/base.py +17 -0
- durag/configs/rerankers/cohere.py +15 -0
- durag/configs/rerankers/config.py +12 -0
- durag/configs/rerankers/huggingface.py +17 -0
- durag/configs/rerankers/llm.py +54 -0
- durag/configs/rerankers/sentence_transformer.py +16 -0
- durag/configs/rerankers/zero_entropy.py +28 -0
- durag/configs/vector_stores/__init__.py +0 -0
- durag/configs/vector_stores/azure_ai_search.py +57 -0
- durag/configs/vector_stores/azure_mysql.py +83 -0
- durag/configs/vector_stores/baidu.py +27 -0
- durag/configs/vector_stores/cassandra.py +76 -0
- durag/configs/vector_stores/chroma.py +58 -0
- durag/configs/vector_stores/databricks.py +61 -0
- durag/configs/vector_stores/elasticsearch.py +68 -0
- durag/configs/vector_stores/faiss.py +37 -0
- durag/configs/vector_stores/langchain.py +30 -0
- durag/configs/vector_stores/milvus.py +42 -0
- durag/configs/vector_stores/mongodb.py +27 -0
- durag/configs/vector_stores/neptune.py +25 -0
- durag/configs/vector_stores/opensearch.py +43 -0
- durag/configs/vector_stores/pgvector.py +54 -0
- durag/configs/vector_stores/pinecone.py +55 -0
- durag/configs/vector_stores/qdrant.py +47 -0
- durag/configs/vector_stores/redis.py +24 -0
- durag/configs/vector_stores/s3_vectors.py +28 -0
- durag/configs/vector_stores/supabase.py +46 -0
- durag/configs/vector_stores/turbopuffer.py +45 -0
- durag/configs/vector_stores/upstash_vector.py +34 -0
- durag/configs/vector_stores/valkey.py +32 -0
- durag/configs/vector_stores/vertex_ai_vector_search.py +28 -0
- durag/configs/vector_stores/weaviate.py +41 -0
- durag/embeddings/__init__.py +0 -0
- durag/embeddings/aws_bedrock.py +100 -0
- durag/embeddings/azure_openai.py +72 -0
- durag/embeddings/base.py +47 -0
- durag/embeddings/configs.py +31 -0
- durag/embeddings/fastembed.py +32 -0
- durag/embeddings/gemini.py +39 -0
- durag/embeddings/huggingface.py +44 -0
- durag/embeddings/langchain.py +35 -0
- durag/embeddings/lmstudio.py +29 -0
- durag/embeddings/mock.py +11 -0
- durag/embeddings/ollama.py +65 -0
- durag/embeddings/openai.py +76 -0
- durag/embeddings/together.py +31 -0
- durag/embeddings/vertexai.py +64 -0
- durag/exceptions.py +485 -0
- durag/llms/__init__.py +0 -0
- durag/llms/anthropic.py +112 -0
- durag/llms/aws_bedrock.py +713 -0
- durag/llms/azure_openai.py +144 -0
- durag/llms/azure_openai_structured.py +119 -0
- durag/llms/base.py +141 -0
- durag/llms/configs.py +35 -0
- durag/llms/deepseek.py +109 -0
- durag/llms/gemini.py +204 -0
- durag/llms/groq.py +88 -0
- durag/llms/langchain.py +94 -0
- durag/llms/litellm.py +87 -0
- durag/llms/lmstudio.py +114 -0
- durag/llms/minimax.py +114 -0
- durag/llms/ollama.py +143 -0
- durag/llms/openai.py +149 -0
- durag/llms/openai_structured.py +52 -0
- durag/llms/sarvam.py +89 -0
- durag/llms/together.py +88 -0
- durag/llms/vllm.py +109 -0
- durag/llms/xai.py +52 -0
- durag/memory/__init__.py +0 -0
- durag/memory/base.py +63 -0
- durag/memory/main.py +3222 -0
- durag/memory/setup.py +154 -0
- durag/memory/storage.py +347 -0
- durag/memory/telemetry.py +235 -0
- durag/memory/utils.py +295 -0
- durag/reranker/__init__.py +9 -0
- durag/reranker/base.py +20 -0
- durag/reranker/cohere_reranker.py +85 -0
- durag/reranker/huggingface_reranker.py +147 -0
- durag/reranker/llm_reranker.py +169 -0
- durag/reranker/sentence_transformer_reranker.py +110 -0
- durag/reranker/zero_entropy_reranker.py +96 -0
- durag/utils/entity_extraction.py +357 -0
- durag/utils/factory.py +264 -0
- durag/utils/gcp_auth.py +167 -0
- durag/utils/lemmatization.py +50 -0
- durag/utils/scoring.py +121 -0
- durag/utils/spacy_models.py +91 -0
- durag/vector_stores/__init__.py +0 -0
- durag/vector_stores/azure_ai_search.py +424 -0
- durag/vector_stores/azure_mysql.py +545 -0
- durag/vector_stores/baidu.py +411 -0
- durag/vector_stores/base.py +92 -0
- durag/vector_stores/cassandra.py +506 -0
- durag/vector_stores/chroma.py +332 -0
- durag/vector_stores/configs.py +67 -0
- durag/vector_stores/databricks.py +875 -0
- durag/vector_stores/elasticsearch.py +289 -0
- durag/vector_stores/faiss.py +631 -0
- durag/vector_stores/langchain.py +180 -0
- durag/vector_stores/milvus.py +346 -0
- durag/vector_stores/mongodb.py +400 -0
- durag/vector_stores/neptune_analytics.py +467 -0
- durag/vector_stores/opensearch.py +380 -0
- durag/vector_stores/pgvector.py +526 -0
- durag/vector_stores/pinecone.py +418 -0
- durag/vector_stores/qdrant.py +556 -0
- durag/vector_stores/redis.py +351 -0
- durag/vector_stores/s3_vectors.py +206 -0
- durag/vector_stores/supabase.py +237 -0
- durag/vector_stores/turbopuffer.py +337 -0
- durag/vector_stores/upstash_vector.py +332 -0
- durag/vector_stores/valkey.py +837 -0
- durag/vector_stores/vertex_ai_vector_search.py +644 -0
- durag/vector_stores/weaviate.py +392 -0
- durag-2.0.4.dist-info/METADATA +72 -0
- durag-2.0.4.dist-info/RECORD +139 -0
- durag-2.0.4.dist-info/WHEEL +4 -0
- durag-2.0.4.dist-info/licenses/LICENSE +201 -0
durag/__init__.py
ADDED
|
File without changes
|
durag/configs/base.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Any, Dict, Optional
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel, Field
|
|
5
|
+
|
|
6
|
+
from durag.configs.rerankers.config import RerankerConfig
|
|
7
|
+
from durag.embeddings.configs import EmbedderConfig
|
|
8
|
+
from durag.llms.configs import LlmConfig
|
|
9
|
+
from durag.vector_stores.configs import VectorStoreConfig
|
|
10
|
+
|
|
11
|
+
# Resolve where durag keeps its local state. A non-empty DURAG_DIR environment
# variable takes precedence; otherwise the ".durag" folder inside the user's
# home directory is used.
home_dir = os.path.expanduser("~")
durag_dir = os.environ.get("DURAG_DIR", "") or os.path.join(home_dir, ".durag")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class MemoryItem(BaseModel):
    # One memory record plus its bookkeeping fields. Only `id` and `memory`
    # are required (declared with `...`); every other field defaults to None.

    id: str = Field(..., description="The unique identifier for the text data")
    # TODO After prompt changes from platform, update this
    memory: str = Field(..., description="The memory deduced from the text data")
    hash: Optional[str] = Field(default=None, description="The hash of the memory")
    # Metadata values are not restricted to strings, hence Dict[str, Any].
    metadata: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Additional metadata for the text data",
    )
    score: Optional[float] = Field(
        default=None,
        description="The score associated with the text data",
    )
    # Timestamps are carried as strings; the format is not fixed by this model.
    created_at: Optional[str] = Field(
        default=None,
        description="The timestamp when the memory was created",
    )
    updated_at: Optional[str] = Field(
        default=None,
        description="The timestamp when the memory was updated",
    )
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class MemoryConfig(BaseModel):
    # Top-level configuration bundle: one sub-config per component
    # (vector store, LLM, embedder, optional reranker) plus a few scalars.

    # The three component configs are built by their own default factories
    # when the caller does not supply them.
    vector_store: VectorStoreConfig = Field(
        default_factory=VectorStoreConfig,
        description="Configuration for the vector store",
    )
    llm: LlmConfig = Field(
        default_factory=LlmConfig,
        description="Configuration for the language model",
    )
    embedder: EmbedderConfig = Field(
        default_factory=EmbedderConfig,
        description="Configuration for the embedding model",
    )
    # The default path is computed once, when this class body is evaluated,
    # from the module-level `durag_dir`.
    history_db_path: str = Field(
        default=os.path.join(durag_dir, "history.db"),
        description="Path to the history database",
    )
    # No reranker is configured unless one is given explicitly.
    reranker: Optional[RerankerConfig] = Field(
        default=None,
        description="Configuration for the reranker",
    )
    version: str = Field(
        default="v1.1",
        description="The version of the API",
    )
    custom_instructions: Optional[str] = Field(
        default=None,
        description="Custom instructions for fact extraction",
    )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class AzureConfig(BaseModel):
    """
    Configuration settings for Azure.

    Every field is optional and defaults to None.

    Args:
        api_key (Optional[str]): The API key used for authenticating with the Azure service.
        azure_deployment (Optional[str]): The name of the Azure deployment.
        azure_endpoint (Optional[str]): The endpoint URL for the Azure service.
        api_version (Optional[str]): The version of the Azure API being used.
        default_headers (Optional[Dict[str, str]]): Headers to include in requests to the Azure API.
    """

    # The string fields are annotated Optional[str] because each one defaults
    # to None; a bare `str` annotation contradicts that default (and, under
    # Pydantic v2 validation, rejects an explicitly passed None).
    api_key: Optional[str] = Field(
        description="The API key used for authenticating with the Azure service.",
        default=None,
    )
    azure_deployment: Optional[str] = Field(description="The name of the Azure deployment.", default=None)
    azure_endpoint: Optional[str] = Field(description="The endpoint URL for the Azure service.", default=None)
    api_version: Optional[str] = Field(description="The version of the Azure API being used.", default=None)
    default_headers: Optional[Dict[str, str]] = Field(
        description="Headers to include in requests to the Azure API.", default=None
    )
|
|
File without changes
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from abc import ABC
|
|
3
|
+
from typing import Dict, Optional, Union
|
|
4
|
+
|
|
5
|
+
import httpx
|
|
6
|
+
|
|
7
|
+
from durag.configs.base import AzureConfig
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BaseEmbedderConfig(ABC):
    """
    Config for Embeddings.

    A single flat container holding the settings of every supported embedding
    provider; each provider reads only the attributes relevant to it.
    """

    def __init__(
        self,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        embedding_dims: Optional[int] = None,
        # Ollama specific
        ollama_base_url: Optional[str] = None,
        # Openai specific
        openai_base_url: Optional[str] = None,
        # Huggingface specific
        model_kwargs: Optional[dict] = None,
        huggingface_base_url: Optional[str] = None,
        # AzureOpenAI specific
        azure_kwargs: Optional[Union[AzureConfig, Dict]] = None,
        http_client_proxies: Optional[Union[Dict, str]] = None,
        # VertexAI specific
        vertex_credentials_json: Optional[str] = None,
        memory_add_embedding_type: Optional[str] = None,
        memory_update_embedding_type: Optional[str] = None,
        memory_search_embedding_type: Optional[str] = None,
        # Gemini specific
        output_dimensionality: Optional[str] = None,
        # LM Studio specific
        lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
        # AWS Bedrock specific
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_region: Optional[str] = None,
    ):
        """
        Initializes a configuration class instance for the Embeddings.

        :param model: Embedding model to use, defaults to None
        :type model: Optional[str], optional
        :param api_key: API key to be use, defaults to None
        :type api_key: Optional[str], optional
        :param embedding_dims: The number of dimensions in the embedding, defaults to None
        :type embedding_dims: Optional[int], optional
        :param ollama_base_url: Base URL for the Ollama API, defaults to None
        :type ollama_base_url: Optional[str], optional
        :param openai_base_url: Openai base URL to be use, defaults to None
        :type openai_base_url: Optional[str], optional
        :param model_kwargs: key-value arguments for the huggingface embedding model;
            None (or any empty value) is stored as a new empty dict
        :type model_kwargs: Optional[dict], optional
        :param huggingface_base_url: Huggingface base URL to be use, defaults to None
        :type huggingface_base_url: Optional[str], optional
        :param azure_kwargs: settings for the AzureOpenAI embedding model, given either
            as a dict of AzureConfig fields or as a ready AzureConfig instance;
            None or an empty dict yields an AzureConfig with all defaults
        :type azure_kwargs: Optional[Union[AzureConfig, Dict]], optional
        :param http_client_proxies: The proxy server settings used to create self.http_client, defaults to None
        :type http_client_proxies: Optional[Dict | str], optional
        :param vertex_credentials_json: The path to the Vertex AI credentials JSON file, defaults to None
        :type vertex_credentials_json: Optional[str], optional
        :param memory_add_embedding_type: The type of embedding to use for the add memory action, defaults to None
        :type memory_add_embedding_type: Optional[str], optional
        :param memory_update_embedding_type: The type of embedding to use for the update memory action, defaults to None
        :type memory_update_embedding_type: Optional[str], optional
        :param memory_search_embedding_type: The type of embedding to use for the search memory action, defaults to None
        :type memory_search_embedding_type: Optional[str], optional
        :param output_dimensionality: Gemini-specific output dimensionality setting, defaults to None
        :type output_dimensionality: Optional[str], optional
        :param lmstudio_base_url: LM Studio base URL to be use, defaults to "http://localhost:1234/v1"
        :type lmstudio_base_url: Optional[str], optional
        :param aws_access_key_id: AWS access key id for Bedrock, defaults to None
        :type aws_access_key_id: Optional[str], optional
        :param aws_secret_access_key: AWS secret access key for Bedrock, defaults to None
        :type aws_secret_access_key: Optional[str], optional
        :param aws_region: AWS region for Bedrock; when not given, the AWS_REGION
            environment variable is used, and failing that "us-west-2"
        :type aws_region: Optional[str], optional
        """

        self.model = model
        self.api_key = api_key
        self.openai_base_url = openai_base_url
        self.embedding_dims = embedding_dims

        # AzureOpenAI specific
        # An HTTP client is only created when proxy settings were supplied.
        # NOTE(review): recent httpx releases dropped the `proxies` keyword in
        # favour of `proxy`/`mounts` - confirm against the pinned httpx version.
        self.http_client = httpx.Client(proxies=http_client_proxies) if http_client_proxies else None

        # Ollama specific
        self.ollama_base_url = ollama_base_url

        # Huggingface specific
        self.model_kwargs = model_kwargs or {}
        self.huggingface_base_url = huggingface_base_url

        # AzureOpenAI specific
        # Always store an AzureConfig. An existing instance is kept as-is; a
        # dict (or None) is expanded into a new one. Guarding with `or {}`
        # before unpacking keeps an explicit azure_kwargs=None from raising.
        if isinstance(azure_kwargs, AzureConfig):
            self.azure_kwargs = azure_kwargs
        else:
            self.azure_kwargs = AzureConfig(**(azure_kwargs or {}))

        # VertexAI specific
        self.vertex_credentials_json = vertex_credentials_json
        self.memory_add_embedding_type = memory_add_embedding_type
        self.memory_update_embedding_type = memory_update_embedding_type
        self.memory_search_embedding_type = memory_search_embedding_type

        # Gemini specific
        self.output_dimensionality = output_dimensionality

        # LM Studio specific
        self.lmstudio_base_url = lmstudio_base_url

        # AWS Bedrock specific
        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key
        self.aws_region = aws_region or os.environ.get("AWS_REGION") or "us-west-2"
|
|
110
|
+
|
durag/configs/enums.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from durag.configs.llms.base import BaseLlmConfig
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class AnthropicConfig(BaseLlmConfig):
    """
    Anthropic flavour of the LLM configuration.

    Carries every common BaseLlmConfig setting plus the Anthropic base URL.
    """

    def __init__(
        self,
        # Base parameters
        model: Optional[str] = None,
        temperature: float = 0.1,
        api_key: Optional[str] = None,
        max_tokens: int = 2000,
        top_p: Optional[float] = None,
        top_k: int = 1,
        enable_vision: bool = False,
        vision_details: Optional[str] = "auto",
        http_client_proxies: Optional[dict] = None,
        # Anthropic-specific parameters
        anthropic_base_url: Optional[str] = None,
    ):
        """
        Build an Anthropic configuration.

        Args:
            model: Anthropic model to use, defaults to None
            temperature: Controls randomness, defaults to 0.1
            api_key: Anthropic API key, defaults to None
            max_tokens: Maximum tokens to generate, defaults to 2000
            top_p: Nucleus sampling parameter, defaults to None (omitted to avoid conflict with temperature)
            top_k: Top-k sampling parameter, defaults to 1
            enable_vision: Enable vision capabilities, defaults to False
            vision_details: Vision detail level, defaults to "auto"
            http_client_proxies: HTTP client proxy settings, defaults to None
            anthropic_base_url: Anthropic API base URL, defaults to None
        """
        # Gather the provider-independent settings and hand them to the base
        # class in one call.
        shared_settings = {
            "model": model,
            "temperature": temperature,
            "api_key": api_key,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "top_k": top_k,
            "enable_vision": enable_vision,
            "vision_details": vision_details,
            "http_client_proxies": http_client_proxies,
        }
        super().__init__(**shared_settings)

        # The only Anthropic-only setting is stored directly on the instance.
        self.anthropic_base_url = anthropic_base_url
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Any, Dict, List, Optional
|
|
3
|
+
|
|
4
|
+
from durag.configs.llms.base import BaseLlmConfig
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class AWSBedrockConfig(BaseLlmConfig):
    """
    Configuration class for AWS Bedrock LLM integration.

    The provider is derived from the model identifier, which is expected to
    look like "<provider>.<model-name>".
    """

    def __init__(
        self,
        model: Optional[str] = None,
        temperature: float = 0.1,
        max_tokens: int = 2000,
        top_p: Optional[float] = None,
        top_k: int = 1,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_region: str = "",
        aws_session_token: Optional[str] = None,
        aws_profile: Optional[str] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
        Initialize AWS Bedrock configuration.

        Args:
            model: Bedrock model identifier (e.g., "amazon.nova-3-mini-20241119-v1:0").
                Falls back to "anthropic.claude-3-5-sonnet-20240620-v1:0" when empty.
            temperature: Controls randomness (0.0 to 2.0)
            max_tokens: Maximum tokens to generate
            top_p: Nucleus sampling (0.0-1.0). Default None omits topP on Converse
                (required for Anthropic, which rejects temperature and topP together).
            top_k: Top-k sampling parameter (1 to 40)
            aws_access_key_id: AWS access key (optional)
            aws_secret_access_key: AWS secret key (optional)
            aws_region: AWS region for Bedrock service. When empty, the AWS_REGION
                environment variable is used, and failing that "us-west-2".
            aws_session_token: AWS session token for temporary credentials
            aws_profile: AWS profile name for credentials
            model_kwargs: Additional model-specific parameters
            **kwargs: Additional arguments passed to base class
        """
        super().__init__(
            model=model or "anthropic.claude-3-5-sonnet-20240620-v1:0",
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            top_k=top_k,
            **kwargs,
        )

        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key
        # Chained `or` so that an AWS_REGION that is set but empty still falls
        # through to the hard-coded default instead of yielding "".
        self.aws_region = aws_region or os.getenv("AWS_REGION") or "us-west-2"
        self.aws_session_token = aws_session_token
        self.aws_profile = aws_profile
        self.model_kwargs = model_kwargs or {}

    @property
    def provider(self) -> str:
        """Return the text before the first "." of the model id, or "unknown" if there is none."""
        if not self.model or "." not in self.model:
            return "unknown"
        return self.model.split(".")[0]

    @property
    def model_name(self) -> Optional[str]:
        """Return the model id without its provider prefix.

        When the id is empty or contains no ".", it is returned unchanged
        (which may be None).
        """
        if not self.model or "." not in self.model:
            return self.model
        return ".".join(self.model.split(".")[1:])

    def get_model_config(self) -> Dict[str, Any]:
        """Return the sampling parameters merged with any custom model kwargs.

        `top_p` is only present when it was set; entries in `model_kwargs`
        override the base values on key collisions.
        """
        base_config: Dict[str, Any] = {
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_k": self.top_k,
        }

        # Only include top_p when explicitly set by the user.
        if self.top_p is not None:
            base_config["top_p"] = self.top_p

        # Add custom model kwargs
        base_config.update(self.model_kwargs)

        return base_config

    def get_aws_config(self) -> Dict[str, Any]:
        """Return the AWS connection settings as keyword arguments.

        `region_name` is always present. Credentials and the profile name are
        included only when they were supplied to this config; unset values are
        omitted rather than read from the environment here.
        NOTE(review): presumably the AWS SDK's own credential resolution picks
        up environment variables for omitted keys - confirm at the call site.
        """
        config: Dict[str, Any] = {"region_name": self.aws_region}

        optional_settings = (
            ("aws_access_key_id", self.aws_access_key_id),
            ("aws_secret_access_key", self.aws_secret_access_key),
            ("aws_session_token", self.aws_session_token),
            ("profile_name", self.aws_profile),
        )
        config.update({key: value for key, value in optional_settings if value})

        return config

    def validate_model_format(self) -> bool:
        """
        Validate that the model identifier follows Bedrock naming convention.

        Returns:
            True if the id has the form "<provider>.<name>" with a recognised
            provider and a non-empty name, False otherwise
        """
        if not self.model or "." not in self.model:
            return False

        provider, model_name = self.model.split(".", 1)

        # Prefixes accepted as the provider part of a model id.
        valid_providers = (
            "ai21", "amazon", "anthropic", "cohere", "meta", "mistral",
            "stability", "writer", "deepseek", "gpt-oss", "perplexity",
            "snowflake", "titan", "command", "j2", "llama",
        )

        return provider in valid_providers and bool(model_name)

    def get_supported_regions(self) -> List[str]:
        """Return the hard-coded list of AWS regions treated as supporting Bedrock."""
        return [
            "us-east-1",
            "us-west-2",
            "us-east-2",
            "eu-west-1",
            "ap-southeast-1",
            "ap-northeast-1",
        ]

    def get_model_capabilities(self) -> Dict[str, Any]:
        """Return the capability flags for the current provider.

        Every flag starts as False; providers listed below switch on the
        flags named for them. Any other provider keeps all flags False.
        """
        capabilities = {
            "supports_tools": False,
            "supports_vision": False,
            "supports_streaming": False,
            "supports_multimodal": False,
        }

        # Flags each known provider enables.
        enabled_by_provider = {
            "anthropic": ("supports_tools", "supports_vision", "supports_streaming", "supports_multimodal"),
            "amazon": ("supports_tools", "supports_vision", "supports_streaming", "supports_multimodal"),
            "cohere": ("supports_tools", "supports_streaming"),
            "meta": ("supports_vision", "supports_streaming"),
            "mistral": ("supports_vision", "supports_streaming"),
        }

        for flag in enabled_by_provider.get(self.provider, ()):
            capabilities[flag] = True

        return capabilities
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from typing import Any, Dict, Optional
|
|
2
|
+
|
|
3
|
+
from durag.configs.base import AzureConfig
|
|
4
|
+
from durag.configs.llms.base import BaseLlmConfig
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class AzureOpenAIConfig(BaseLlmConfig):
    """
    Azure OpenAI flavour of the LLM configuration.

    Carries every common BaseLlmConfig setting plus an AzureConfig built from
    the supplied Azure keyword arguments.
    """

    def __init__(
        self,
        # Base parameters
        model: Optional[str] = None,
        temperature: float = 0.1,
        api_key: Optional[str] = None,
        max_tokens: int = 2000,
        top_p: float = 0.1,
        top_k: int = 1,
        enable_vision: bool = False,
        vision_details: Optional[str] = "auto",
        reasoning_effort: Optional[str] = None,
        http_client_proxies: Optional[dict] = None,
        # Azure OpenAI-specific parameters
        azure_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Build an Azure OpenAI configuration.

        Args:
            model: Azure OpenAI model to use, defaults to None
            temperature: Controls randomness, defaults to 0.1
            api_key: Azure OpenAI API key, defaults to None
            max_tokens: Maximum tokens to generate, defaults to 2000
            top_p: Nucleus sampling parameter, defaults to 0.1
            top_k: Top-k sampling parameter, defaults to 1
            enable_vision: Enable vision capabilities, defaults to False
            vision_details: Vision detail level, defaults to "auto"
            reasoning_effort: Effort level for reasoning models ("low", "medium", "high"), defaults to None
            http_client_proxies: HTTP client proxy settings, defaults to None
            azure_kwargs: Azure-specific configuration, defaults to None
        """
        # Gather the provider-independent settings and hand them to the base
        # class in one call.
        shared_settings = {
            "model": model,
            "temperature": temperature,
            "api_key": api_key,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "top_k": top_k,
            "enable_vision": enable_vision,
            "vision_details": vision_details,
            "reasoning_effort": reasoning_effort,
            "http_client_proxies": http_client_proxies,
        }
        super().__init__(**shared_settings)

        # A missing or empty mapping produces an AzureConfig with all defaults.
        azure_settings = azure_kwargs if azure_kwargs else {}
        self.azure_kwargs = AzureConfig(**azure_settings)
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
from typing import Dict, Optional, Union
|
|
3
|
+
|
|
4
|
+
import httpx
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class BaseLlmConfig(ABC):
    """
    Base configuration for LLMs with only common parameters.
    Provider-specific configurations should be handled by separate config classes.

    This class contains only the parameters that are common across all LLM providers.
    For provider-specific parameters, use the appropriate provider config class.
    """

    def __init__(
        self,
        model: Optional[Union[str, Dict]] = None,
        temperature: float = 0.1,
        api_key: Optional[str] = None,
        max_tokens: int = 2000,
        top_p: Optional[float] = 0.1,
        top_k: int = 1,
        enable_vision: bool = False,
        vision_details: Optional[str] = "auto",
        reasoning_effort: Optional[str] = None,
        http_client_proxies: Optional[Union[Dict, str]] = None,
    ):
        """
        Initialize a base configuration class instance for the LLM.

        Args:
            model: The model identifier to use (e.g., "gpt-4.1-nano-2025-04-14", "claude-3-5-sonnet-20240620")
                Defaults to None (will be set by provider-specific configs)
            temperature: Controls the randomness of the model's output.
                Higher values (closer to 1) make output more random, lower values make it more deterministic.
                Range: 0.0 to 2.0. Defaults to 0.1
            api_key: API key for the LLM provider. If None, will try to get from environment variables.
                Defaults to None
            max_tokens: Maximum number of tokens to generate in the response.
                Range: 1 to 4096 (varies by model). Defaults to 2000
            top_p: Nucleus sampling parameter. Controls diversity via nucleus sampling.
                Higher values (closer to 1) make word selection more diverse.
                Range: 0.0 to 1.0. Defaults to 0.1. None is accepted and stored
                as-is; some provider configs pass None to leave it unset.
            top_k: Top-k sampling parameter. Limits the number of tokens considered for each step.
                Higher values make word selection more diverse.
                Range: 1 to 40. Defaults to 1
            enable_vision: Whether to enable vision capabilities for the model.
                Only applicable to vision-enabled models. Defaults to False
            vision_details: Level of detail for vision processing.
                Options: "low", "high", "auto". Defaults to "auto"
            reasoning_effort: Effort level for reasoning models (e.g., o1, o3, gpt-5).
                Options: "low", "medium", "high". Only applicable to reasoning models.
                Defaults to None (uses the model's default reasoning effort)
            http_client_proxies: Proxy settings for HTTP client.
                Can be a dict or string. Defaults to None
        """
        # Values are stored verbatim; no range validation happens here.
        self.model = model
        self.temperature = temperature
        self.api_key = api_key
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.top_k = top_k
        self.enable_vision = enable_vision
        self.vision_details = vision_details
        self.reasoning_effort = reasoning_effort
        # An HTTP client is only created when proxy settings were supplied.
        # NOTE(review): recent httpx releases dropped the `proxies` keyword in
        # favour of `proxy`/`mounts` - confirm against the pinned httpx version.
        self.http_client = httpx.Client(proxies=http_client_proxies) if http_client_proxies else None
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from durag.configs.llms.base import BaseLlmConfig
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class DeepSeekConfig(BaseLlmConfig):
    """
    DeepSeek flavour of the LLM configuration.

    Carries every common BaseLlmConfig setting plus the DeepSeek base URL.
    """

    def __init__(
        self,
        # Base parameters
        model: Optional[str] = None,
        temperature: float = 0.1,
        api_key: Optional[str] = None,
        max_tokens: int = 2000,
        top_p: float = 0.1,
        top_k: int = 1,
        enable_vision: bool = False,
        vision_details: Optional[str] = "auto",
        http_client_proxies: Optional[dict] = None,
        # DeepSeek-specific parameters
        deepseek_base_url: Optional[str] = None,
    ):
        """
        Build a DeepSeek configuration.

        Args:
            model: DeepSeek model to use, defaults to None
            temperature: Controls randomness, defaults to 0.1
            api_key: DeepSeek API key, defaults to None
            max_tokens: Maximum tokens to generate, defaults to 2000
            top_p: Nucleus sampling parameter, defaults to 0.1
            top_k: Top-k sampling parameter, defaults to 1
            enable_vision: Enable vision capabilities, defaults to False
            vision_details: Vision detail level, defaults to "auto"
            http_client_proxies: HTTP client proxy settings, defaults to None
            deepseek_base_url: DeepSeek API base URL, defaults to None
        """
        # Gather the provider-independent settings and hand them to the base
        # class in one call.
        shared_settings = {
            "model": model,
            "temperature": temperature,
            "api_key": api_key,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "top_k": top_k,
            "enable_vision": enable_vision,
            "vision_details": vision_details,
            "http_client_proxies": http_client_proxies,
        }
        super().__init__(**shared_settings)

        # The only DeepSeek-only setting is stored directly on the instance.
        self.deepseek_base_url = deepseek_base_url
|