diagram-to-iac 0.6.0-py3-none-any.whl → 0.8.0-py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their public registries, and is provided for informational purposes only.
- diagram_to_iac/__init__.py +10 -0
- diagram_to_iac/actions/__init__.py +7 -0
- diagram_to_iac/actions/git_entry.py +174 -0
- diagram_to_iac/actions/supervisor_entry.py +116 -0
- diagram_to_iac/actions/terraform_agent_entry.py +207 -0
- diagram_to_iac/agents/__init__.py +26 -0
- diagram_to_iac/agents/demonstrator_langgraph/__init__.py +10 -0
- diagram_to_iac/agents/demonstrator_langgraph/agent.py +826 -0
- diagram_to_iac/agents/git_langgraph/__init__.py +10 -0
- diagram_to_iac/agents/git_langgraph/agent.py +1018 -0
- diagram_to_iac/agents/git_langgraph/pr.py +146 -0
- diagram_to_iac/agents/hello_langgraph/__init__.py +9 -0
- diagram_to_iac/agents/hello_langgraph/agent.py +621 -0
- diagram_to_iac/agents/policy_agent/__init__.py +15 -0
- diagram_to_iac/agents/policy_agent/agent.py +507 -0
- diagram_to_iac/agents/policy_agent/integration_example.py +191 -0
- diagram_to_iac/agents/policy_agent/tools/__init__.py +14 -0
- diagram_to_iac/agents/policy_agent/tools/tfsec_tool.py +259 -0
- diagram_to_iac/agents/shell_langgraph/__init__.py +21 -0
- diagram_to_iac/agents/shell_langgraph/agent.py +122 -0
- diagram_to_iac/agents/shell_langgraph/detector.py +50 -0
- diagram_to_iac/agents/supervisor_langgraph/__init__.py +17 -0
- diagram_to_iac/agents/supervisor_langgraph/agent.py +1947 -0
- diagram_to_iac/agents/supervisor_langgraph/demonstrator.py +22 -0
- diagram_to_iac/agents/supervisor_langgraph/guards.py +23 -0
- diagram_to_iac/agents/supervisor_langgraph/pat_loop.py +49 -0
- diagram_to_iac/agents/supervisor_langgraph/router.py +9 -0
- diagram_to_iac/agents/terraform_langgraph/__init__.py +15 -0
- diagram_to_iac/agents/terraform_langgraph/agent.py +1216 -0
- diagram_to_iac/agents/terraform_langgraph/parser.py +76 -0
- diagram_to_iac/core/__init__.py +7 -0
- diagram_to_iac/core/agent_base.py +19 -0
- diagram_to_iac/core/enhanced_memory.py +302 -0
- diagram_to_iac/core/errors.py +4 -0
- diagram_to_iac/core/issue_tracker.py +49 -0
- diagram_to_iac/core/memory.py +132 -0
- diagram_to_iac/services/__init__.py +10 -0
- diagram_to_iac/services/observability.py +59 -0
- diagram_to_iac/services/step_summary.py +77 -0
- diagram_to_iac/tools/__init__.py +11 -0
- diagram_to_iac/tools/api_utils.py +108 -26
- diagram_to_iac/tools/git/__init__.py +45 -0
- diagram_to_iac/tools/git/git.py +956 -0
- diagram_to_iac/tools/hello/__init__.py +30 -0
- diagram_to_iac/tools/hello/cal_utils.py +31 -0
- diagram_to_iac/tools/hello/text_utils.py +97 -0
- diagram_to_iac/tools/llm_utils/__init__.py +20 -0
- diagram_to_iac/tools/llm_utils/anthropic_driver.py +87 -0
- diagram_to_iac/tools/llm_utils/base_driver.py +90 -0
- diagram_to_iac/tools/llm_utils/gemini_driver.py +89 -0
- diagram_to_iac/tools/llm_utils/openai_driver.py +93 -0
- diagram_to_iac/tools/llm_utils/router.py +303 -0
- diagram_to_iac/tools/sec_utils.py +4 -2
- diagram_to_iac/tools/shell/__init__.py +17 -0
- diagram_to_iac/tools/shell/shell.py +415 -0
- diagram_to_iac/tools/text_utils.py +277 -0
- diagram_to_iac/tools/tf/terraform.py +851 -0
- diagram_to_iac-0.8.0.dist-info/METADATA +99 -0
- diagram_to_iac-0.8.0.dist-info/RECORD +64 -0
- {diagram_to_iac-0.6.0.dist-info → diagram_to_iac-0.8.0.dist-info}/WHEEL +1 -1
- diagram_to_iac-0.8.0.dist-info/entry_points.txt +4 -0
- diagram_to_iac/agents/codegen_agent.py +0 -0
- diagram_to_iac/agents/consensus_agent.py +0 -0
- diagram_to_iac/agents/deployment_agent.py +0 -0
- diagram_to_iac/agents/github_agent.py +0 -0
- diagram_to_iac/agents/interpretation_agent.py +0 -0
- diagram_to_iac/agents/question_agent.py +0 -0
- diagram_to_iac/agents/supervisor.py +0 -0
- diagram_to_iac/agents/vision_agent.py +0 -0
- diagram_to_iac/core/config.py +0 -0
- diagram_to_iac/tools/cv_utils.py +0 -0
- diagram_to_iac/tools/gh_utils.py +0 -0
- diagram_to_iac/tools/tf_utils.py +0 -0
- diagram_to_iac-0.6.0.dist-info/METADATA +0 -16
- diagram_to_iac-0.6.0.dist-info/RECORD +0 -32
- diagram_to_iac-0.6.0.dist-info/entry_points.txt +0 -2
- {diagram_to_iac-0.6.0.dist-info → diagram_to_iac-0.8.0.dist-info}/top_level.txt +0 -0
diagram_to_iac/tools/hello/__init__.py

```diff
@@ -0,0 +1,30 @@
+"""
+Hello LangGraph Agent Tools Package
+
+This package contains tools specific to the HelloAgent:
+- cal_utils: Arithmetic operations (addition, multiplication)
+- text_utils: Text processing utilities
+
+These tools have been moved from agents/hello_langgraph/tools/ to provide
+better organization and reusability across the codebase.
+"""
+
+# Import text utilities (no external dependencies)
+from .text_utils import extract_numbers_from_text, extract_numbers_from_text_with_duplicates
+
+# Import calculation utilities (requires langchain_core)
+try:
+    from .cal_utils import add_two, multiply_two
+    _cal_utils_available = True
+except ImportError:
+    _cal_utils_available = False
+    add_two = None
+    multiply_two = None
+
+__all__ = [
+    "extract_numbers_from_text",
+    "extract_numbers_from_text_with_duplicates"
+]
+
+if _cal_utils_available:
+    __all__.extend(["add_two", "multiply_two"])
```
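The try/except guard above means `add_two` and `multiply_two` may be `None` when `langchain_core` is missing. A minimal sketch of how calling code might cope with that, assuming the wheel is installed (the consumer code below is illustrative, not part of the diff):

```python
# Illustrative consumer code -- not part of the packaged diff above.
from diagram_to_iac.tools import hello

if hello.add_two is not None:
    # cal_utils imported cleanly; add_two is a LangChain tool, invoked with a dict.
    result = hello.add_two.invoke({"x": 4, "y": 5})
else:
    # langchain_core is absent; fall back to plain arithmetic.
    result = 4 + 5
print(result)  # 9
```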
diagram_to_iac/tools/hello/cal_utils.py

```diff
@@ -0,0 +1,31 @@
+from langchain_core.tools import tool  # Updated import for modern LangChain
+from pydantic import BaseModel, Field
+
+# --- Pydantic Schemas for Tool Inputs ---
+class AddToolInput(BaseModel):
+    x: int = Field(..., description="The first number for addition")
+    y: int = Field(..., description="The second number for addition")
+
+class MultiplyToolInput(BaseModel):
+    x: int = Field(..., description="The first number for multiplication")
+    y: int = Field(..., description="The second number for multiplication")
+
+# --- Tool Definitions ---
+@tool(args_schema=AddToolInput)
+def add_two(x: int, y: int) -> int:
+    """Add two numbers. Expects input according to AddToolInput schema."""
+    # The decorator @tool with args_schema handles parsing the input dict
+    # into AddToolInput and then passes the fields (x, y) to this function.
+    # So, the function signature remains simple (x: int, y: int).
+    # If we wanted the function to receive the Pydantic model itself,
+    # the tool definition and invocation would be slightly different, often
+    # by defining a custom Tool class or using StructuredTool.from_function
+    # where the function takes the Pydantic model.
+    # For @tool, it unpacks the validated args.
+    return x + y
+
+@tool(args_schema=MultiplyToolInput)
+def multiply_two(x: int, y: int) -> int:
+    """Multiply two numbers. Expects input according to MultiplyToolInput schema."""
+    # Similar to add_two, @tool unpacks the validated fields from MultiplyToolInput.
+    return x * y
```
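Because these functions are wrapped by `@tool(args_schema=...)`, they become LangChain `StructuredTool` objects rather than plain callables, so they are invoked with a dict that Pydantic validates first. A short sketch, assuming `langchain_core` is installed:

```python
# Illustrative usage of the tools defined above.
from diagram_to_iac.tools.hello.cal_utils import add_two, multiply_two

print(add_two.invoke({"x": 4, "y": 5}))       # 9
print(multiply_two.invoke({"x": 6, "y": 7}))  # 42

# Arguments are validated against the Pydantic schemas before the body runs;
# a non-coercible value such as {"x": "four", "y": 5} raises a validation error.
```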
diagram_to_iac/tools/hello/text_utils.py

```diff
@@ -0,0 +1,97 @@
+import re
+from typing import List
+
+def extract_numbers_from_text(text: str) -> List[int]:
+    """
+    Extract all integer numbers from a given text string, handling various formats.
+
+    This function finds sequences of digits and also extracts numbers from patterns
+    like "number operator number" (e.g., "4*5", "7 + 3"). It returns a list of unique
+    integers found, preserving the order of their first appearance in the text.
+
+    Args:
+        text: The input string from which to extract numbers.
+
+    Returns:
+        A list of unique integers found in the text, in order of appearance.
+        Returns an empty list if no numbers are found.
+
+    Supports:
+        - Space-separated numbers: "4 + 5", "add 7 and 3"
+        - Adjacent numbers with operators: "4*5", "6+2"
+        - Basic extraction of digit sequences.
+    """
+    numbers = []
+
+    # First, find all digit sequences using regex, avoiding phone number confusion
+    digit_matches = re.findall(r'\b\d+\b', text)
+    numbers.extend([int(match) for match in digit_matches])
+
+    # Handle cases where numbers are adjacent to operators like "4*5" or "7+3"
+    # Use specific operators (not dash) to avoid phone number confusion
+    operator_adjacent = re.findall(r'(\d+)\s*[+*/×÷%]\s*(\d+)', text)
+    for match in operator_adjacent:
+        numbers.extend([int(match[0]), int(match[1])])
+
+    # Handle negative numbers in explicit mathematical contexts
+    # Look for patterns like "What is -10 + 5?" but avoid phone numbers
+    math_negative = re.findall(r'(?:is|add|subtract|plus|minus|\+|\*|/)\s+(-\d+)', text, re.IGNORECASE)
+    for match in math_negative:
+        numbers.append(int(match))
+
+    # Also catch negative numbers at the start of mathematical expressions
+    start_negative = re.findall(r'(?:^|\s)(-\d+)\s*[+*/×÷%]', text)
+    for match in start_negative:
+        numbers.append(int(match))
+
+    # Remove duplicates while preserving order of first appearance
+    seen = set()
+    unique_numbers = []
+    for num_val in numbers:  # Renamed 'num' to 'num_val' to avoid conflict if 'num' was a global
+        if num_val not in seen:
+            unique_numbers.append(num_val)
+            seen.add(num_val)
+
+    return unique_numbers
+
+def extract_numbers_from_text_with_duplicates(text: str) -> List[int]:
+    """
+    Extract all integer numbers from a given text string, preserving duplicates.
+
+    This function is similar to extract_numbers_from_text but preserves duplicate
+    numbers, which is useful for operations like "10 * 10" where both numbers
+    are needed even if they're the same.
+
+    Args:
+        text: The input string from which to extract numbers.
+
+    Returns:
+        A list of integers found in the text, in order of appearance, including duplicates.
+        Returns an empty list if no numbers are found.
+
+    Examples:
+        >>> extract_numbers_from_text_with_duplicates("Calculate 10 * 10")
+        [10, 10]
+        >>> extract_numbers_from_text_with_duplicates("What is 5 + 7?")
+        [5, 7]
+    """
+    # Use a comprehensive regex to find all numbers (positive and negative) in one pass
+    # This regex matches:
+    # 1. Negative numbers in math contexts: "What is -10 + 5?"
+    # 2. Regular positive numbers: "10", "5", etc.
+    # 3. Avoids phone number dashes by being specific about negative contexts
+
+    results = []
+
+    # Combined pattern that captures both positive and negative numbers
+    # Negative numbers: preceded by math keywords or operators, followed by math operators
+    # Positive numbers: standalone digit sequences
+    pattern = r'(?:(?:is|add|subtract|plus|minus|\+|\*|/|^|\s)[-\s]*(-\d+)(?=\s*[+\-*/×÷%\s]|$))|(?:\b(\d+)\b)'
+
+    for match in re.finditer(pattern, text, re.IGNORECASE):
+        if match.group(1):  # Negative number
+            results.append(int(match.group(1)))
+        elif match.group(2):  # Positive number
+            results.append(int(match.group(2)))
+
+    return results
```
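The docstrings above pin down the observable difference between the two extractors: the first deduplicates while preserving first-appearance order, the second keeps duplicates. Expected behavior, per those docstrings (illustrative):

```python
# Expected behavior implied by the docstrings above.
from diagram_to_iac.tools.hello.text_utils import (
    extract_numbers_from_text,
    extract_numbers_from_text_with_duplicates,
)

extract_numbers_from_text("What is 4*5?")                       # [4, 5]
extract_numbers_from_text("Calculate 10 * 10")                  # [10]  (duplicates collapsed)
extract_numbers_from_text_with_duplicates("Calculate 10 * 10")  # [10, 10]
extract_numbers_from_text_with_duplicates("What is 5 + 7?")     # [5, 7]
```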
diagram_to_iac/tools/llm_utils/__init__.py

```diff
@@ -0,0 +1,20 @@
+"""
+LLM Utils Package
+
+Provides LLM routing, drivers, and utilities for the diagram-to-iac project.
+"""
+
+from .router import LLMRouter, get_llm
+from .base_driver import BaseLLMDriver
+from .openai_driver import OpenAIDriver
+from .anthropic_driver import AnthropicDriver
+from .gemini_driver import GoogleDriver
+
+__all__ = [
+    "LLMRouter",
+    "get_llm",
+    "BaseLLMDriver",
+    "OpenAIDriver",
+    "AnthropicDriver",
+    "GoogleDriver"
+]
```
diagram_to_iac/tools/llm_utils/anthropic_driver.py

```diff
@@ -0,0 +1,87 @@
+"""
+Anthropic LLM Driver
+
+Provides Anthropic Claude-specific optimizations and features.
+"""
+
+import os
+from typing import Dict, Any, List
+from langchain_anthropic import ChatAnthropic
+from .base_driver import BaseLLMDriver
+
+
+class AnthropicDriver(BaseLLMDriver):
+    """Anthropic Claude-specific LLM driver."""
+
+    SUPPORTED_MODELS = [
+        "claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022",
+        "claude-3-opus-20240229", "claude-3-sonnet-20240229",
+        "claude-3-haiku-20240307"
+    ]
+
+    MODEL_CAPABILITIES = {
+        "claude-3-5-sonnet-20241022": {"context_length": 200000, "function_calling": True, "vision": True},
+        "claude-3-5-haiku-20241022": {"context_length": 200000, "function_calling": True, "vision": True},
+        "claude-3-opus-20240229": {"context_length": 200000, "function_calling": True, "vision": True},
+        "claude-3-sonnet-20240229": {"context_length": 200000, "function_calling": True, "vision": True},
+        "claude-3-haiku-20240307": {"context_length": 200000, "function_calling": True, "vision": False},
+    }
+
+    def validate_config(self, config: Dict[str, Any]) -> bool:
+        """Validate Anthropic-specific configuration."""
+        if not os.getenv("ANTHROPIC_API_KEY"):
+            raise ValueError("ANTHROPIC_API_KEY environment variable not set")
+
+        model = config.get("model")
+        if model and model not in self.SUPPORTED_MODELS:
+            raise ValueError(f"Unsupported Anthropic model: {model}. Supported models: {self.SUPPORTED_MODELS}")
+
+        # Validate temperature range
+        temperature = config.get("temperature")
+        if temperature is not None and (temperature < 0 or temperature > 1):
+            raise ValueError("Anthropic temperature must be between 0 and 1")
+
+        return True
+
+    def create_llm(self, config: Dict[str, Any]) -> ChatAnthropic:
+        """Create optimized Anthropic LLM instance."""
+        self.validate_config(config)
+
+        # Anthropic-specific optimizations
+        llm_config = {
+            "model": config["model"],
+            "temperature": config.get("temperature", 0.0),
+            "max_tokens": config.get("max_tokens", 1024),
+            "top_p": config.get("top_p"),
+            "top_k": config.get("top_k"),  # Anthropic-specific parameter
+        }
+
+        # Remove None values
+        llm_config = {k: v for k, v in llm_config.items() if v is not None}
+
+        return ChatAnthropic(**llm_config)
+
+    def get_supported_models(self) -> List[str]:
+        """Return list of supported Anthropic models."""
+        return self.SUPPORTED_MODELS.copy()
+
+    def get_model_capabilities(self, model: str) -> Dict[str, Any]:
+        """Return capabilities for specific Anthropic model."""
+        return self.MODEL_CAPABILITIES.get(model, {})
+
+    def estimate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
+        """Estimate cost based on Anthropic pricing (as of 2024)."""
+        # Pricing per 1K tokens in USD
+        pricing = {
+            "claude-3-5-sonnet-20241022": {"input": 0.003, "output": 0.015},
+            "claude-3-5-haiku-20241022": {"input": 0.00025, "output": 0.00125},
+            "claude-3-opus-20240229": {"input": 0.015, "output": 0.075},
+            "claude-3-sonnet-20240229": {"input": 0.003, "output": 0.015},
+            "claude-3-haiku-20240307": {"input": 0.00025, "output": 0.00125},
+        }
+
+        if model not in pricing:
+            return 0.0
+
+        rates = pricing[model]
+        return (input_tokens / 1000 * rates["input"]) + (output_tokens / 1000 * rates["output"])
```
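A worked example of the pricing table: 2,000 input tokens and 500 output tokens on claude-3-5-sonnet-20241022 cost (2000/1000) × 0.003 + (500/1000) × 0.015 = 0.0135 USD. An illustrative check; note that `estimate_cost` needs no API key, while `create_llm` does because it calls `validate_config`:

```python
# Illustrative cost check against the table above; no API key needed for this call.
from diagram_to_iac.tools.llm_utils.anthropic_driver import AnthropicDriver

driver = AnthropicDriver()
cost = driver.estimate_cost("claude-3-5-sonnet-20241022", input_tokens=2000, output_tokens=500)
print(f"${cost:.4f}")  # $0.0135

# create_llm would additionally require ANTHROPIC_API_KEY to be set:
# llm = driver.create_llm({"model": "claude-3-5-haiku-20241022", "temperature": 0.2})
```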
diagram_to_iac/tools/llm_utils/base_driver.py

```diff
@@ -0,0 +1,90 @@
+"""
+Base LLM Driver Interface
+
+This module provides the abstract base class for all LLM provider drivers.
+Each driver implements provider-specific optimizations and features.
+"""
+
+from abc import ABC, abstractmethod
+from typing import Dict, Any, List, Optional
+from langchain_core.language_models.chat_models import BaseChatModel
+
+
+class BaseLLMDriver(ABC):
+    """Abstract base class for all LLM provider drivers."""
+
+    @abstractmethod
+    def validate_config(self, config: Dict[str, Any]) -> bool:
+        """
+        Validate provider-specific configuration.
+
+        Args:
+            config: Configuration dictionary containing model parameters
+
+        Returns:
+            bool: True if configuration is valid
+
+        Raises:
+            ValueError: If configuration is invalid
+        """
+        pass
+
+    @abstractmethod
+    def create_llm(self, config: Dict[str, Any]) -> BaseChatModel:
+        """
+        Create and configure LLM instance.
+
+        Args:
+            config: Configuration dictionary containing model parameters
+
+        Returns:
+            BaseChatModel: Configured LLM instance
+        """
+        pass
+
+    @abstractmethod
+    def get_supported_models(self) -> List[str]:
+        """
+        Return list of supported models for this provider.
+
+        Returns:
+            List[str]: List of supported model names
+        """
+        pass
+
+    @abstractmethod
+    def get_model_capabilities(self, model: str) -> Dict[str, Any]:
+        """
+        Return capabilities for specific model.
+
+        Args:
+            model: Model name
+
+        Returns:
+            Dict containing capabilities like context_length, function_calling, vision, etc.
+        """
+        pass
+
+    @abstractmethod
+    def estimate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
+        """
+        Estimate cost for given token usage.
+
+        Args:
+            model: Model name
+            input_tokens: Number of input tokens
+            output_tokens: Number of output tokens
+
+        Returns:
+            float: Estimated cost in USD
+        """
+        pass
+
+    def get_provider_name(self) -> str:
+        """
+        Get the provider name for this driver.
+
+        Returns:
+            str: Provider name (e.g., 'openai', 'anthropic', 'google')
+        """
+        return self.__class__.__name__.lower().replace('driver', '')
```
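The base class fixes a five-method contract, so adding a provider is a matter of subclassing. A minimal sketch of a hypothetical driver (the "echo" provider below is invented for illustration and is not in the package):

```python
# Hypothetical driver sketch showing the contract; "echo" is not a real provider.
from typing import Dict, Any, List
from langchain_core.language_models.chat_models import BaseChatModel
from diagram_to_iac.tools.llm_utils.base_driver import BaseLLMDriver

class EchoDriver(BaseLLMDriver):
    SUPPORTED_MODELS = ["echo-1"]

    def validate_config(self, config: Dict[str, Any]) -> bool:
        if config.get("model") not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported model: {config.get('model')}")
        return True

    def create_llm(self, config: Dict[str, Any]) -> BaseChatModel:
        # A real driver would return a configured chat model here.
        raise NotImplementedError

    def get_supported_models(self) -> List[str]:
        return self.SUPPORTED_MODELS.copy()

    def get_model_capabilities(self, model: str) -> Dict[str, Any]:
        return {"context_length": 0, "function_calling": False, "vision": False}

    def estimate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
        return 0.0

print(EchoDriver().get_provider_name())  # "echo" -- derived from the class name
```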
diagram_to_iac/tools/llm_utils/gemini_driver.py

```diff
@@ -0,0 +1,89 @@
+"""
+Google Gemini LLM Driver
+
+Provides Google Gemini-specific optimizations and features.
+"""
+
+import os
+from typing import Dict, Any, List
+from langchain_google_genai import ChatGoogleGenerativeAI
+from .base_driver import BaseLLMDriver
+
+
+class GoogleDriver(BaseLLMDriver):
+    """Google Gemini-specific LLM driver."""
+
+    SUPPORTED_MODELS = [
+        "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b",
+        "gemini-1.0-pro", "gemini-pro", "gemini-pro-vision"
+    ]
+
+    MODEL_CAPABILITIES = {
+        "gemini-1.5-pro": {"context_length": 1000000, "function_calling": True, "vision": True},
+        "gemini-1.5-flash": {"context_length": 1000000, "function_calling": True, "vision": True},
+        "gemini-1.5-flash-8b": {"context_length": 1000000, "function_calling": True, "vision": True},
+        "gemini-1.0-pro": {"context_length": 30720, "function_calling": True, "vision": False},
+        "gemini-pro": {"context_length": 30720, "function_calling": True, "vision": False},
+        "gemini-pro-vision": {"context_length": 12288, "function_calling": False, "vision": True},
+    }
+
+    def validate_config(self, config: Dict[str, Any]) -> bool:
+        """Validate Google-specific configuration."""
+        if not os.getenv("GOOGLE_API_KEY"):
+            raise ValueError("GOOGLE_API_KEY environment variable not set")
+
+        model = config.get("model")
+        if model and model not in self.SUPPORTED_MODELS:
+            raise ValueError(f"Unsupported Google model: {model}. Supported models: {self.SUPPORTED_MODELS}")
+
+        # Validate temperature range
+        temperature = config.get("temperature")
+        if temperature is not None and (temperature < 0 or temperature > 1):
+            raise ValueError("Google temperature must be between 0 and 1")
+
+        return True
+
+    def create_llm(self, config: Dict[str, Any]) -> ChatGoogleGenerativeAI:
+        """Create optimized Google LLM instance."""
+        self.validate_config(config)
+
+        # Google-specific optimizations
+        llm_config = {
+            "model": config["model"],
+            "temperature": config.get("temperature", 0.0),
+            "max_tokens": config.get("max_tokens"),
+            "top_p": config.get("top_p"),
+            "top_k": config.get("top_k"),  # Google-specific parameter
+            "google_api_key": os.getenv("GOOGLE_API_KEY"),
+        }
+
+        # Remove None values (except google_api_key)
+        llm_config = {k: v for k, v in llm_config.items() if v is not None}
+
+        return ChatGoogleGenerativeAI(**llm_config)
+
+    def get_supported_models(self) -> List[str]:
+        """Return list of supported Google models."""
+        return self.SUPPORTED_MODELS.copy()
+
+    def get_model_capabilities(self, model: str) -> Dict[str, Any]:
+        """Return capabilities for specific Google model."""
+        return self.MODEL_CAPABILITIES.get(model, {})
+
+    def estimate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
+        """Estimate cost based on Google pricing (as of 2024)."""
+        # Pricing per 1K tokens in USD
+        pricing = {
+            "gemini-1.5-pro": {"input": 0.00125, "output": 0.005},
+            "gemini-1.5-flash": {"input": 0.000075, "output": 0.0003},
+            "gemini-1.5-flash-8b": {"input": 0.0000375, "output": 0.00015},
+            "gemini-1.0-pro": {"input": 0.0005, "output": 0.0015},
+            "gemini-pro": {"input": 0.0005, "output": 0.0015},
+            "gemini-pro-vision": {"input": 0.00025, "output": 0.0005},
+        }
+
+        if model not in pricing:
+            return 0.0
+
+        rates = pricing[model]
+        return (input_tokens / 1000 * rates["input"]) + (output_tokens / 1000 * rates["output"])
```
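Since every driver publishes a capabilities table, callers can filter models by feature before routing. An illustrative query against the table above, which needs no API key:

```python
# Illustrative: list the Gemini models above that support vision input.
from diagram_to_iac.tools.llm_utils.gemini_driver import GoogleDriver

driver = GoogleDriver()
vision_models = [
    m for m in driver.get_supported_models()
    if driver.get_model_capabilities(m).get("vision")
]
print(vision_models)
# ['gemini-1.5-pro', 'gemini-1.5-flash', 'gemini-1.5-flash-8b', 'gemini-pro-vision']
```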
diagram_to_iac/tools/llm_utils/openai_driver.py

```diff
@@ -0,0 +1,93 @@
+"""
+OpenAI LLM Driver
+
+Provides OpenAI-specific optimizations and features for ChatGPT models.
+"""
+
+import os
+from typing import Dict, Any, List
+from langchain_openai import ChatOpenAI
+from .base_driver import BaseLLMDriver
+
+
+class OpenAIDriver(BaseLLMDriver):
+    """OpenAI-specific LLM driver with advanced features."""
+
+    SUPPORTED_MODELS = [
+        "gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-4",
+        "gpt-3.5-turbo", "gpt-3.5-turbo-16k"
+    ]
+
+    MODEL_CAPABILITIES = {
+        "gpt-4o": {"context_length": 128000, "function_calling": True, "vision": True},
+        "gpt-4o-mini": {"context_length": 128000, "function_calling": True, "vision": True},
+        "gpt-4-turbo": {"context_length": 128000, "function_calling": True, "vision": True},
+        "gpt-4": {"context_length": 8192, "function_calling": True, "vision": False},
+        "gpt-3.5-turbo": {"context_length": 16385, "function_calling": True, "vision": False},
+        "gpt-3.5-turbo-16k": {"context_length": 16385, "function_calling": True, "vision": False},
+    }
+
+    def validate_config(self, config: Dict[str, Any]) -> bool:
+        """Validate OpenAI-specific configuration."""
+        # In testing environments the OPENAI_API_KEY may not be set. Instead of
+        # raising an error immediately we log a warning and allow initialization
+        # to proceed so that the driver can be mocked.
+        if not os.getenv("OPENAI_API_KEY"):
+            print("Warning: OPENAI_API_KEY environment variable not set. Using placeholder key for testing.")
+            os.environ.setdefault("OPENAI_API_KEY", "test-key")
+
+        model = config.get("model")
+        if model and model not in self.SUPPORTED_MODELS:
+            raise ValueError(f"Unsupported OpenAI model: {model}. Supported models: {self.SUPPORTED_MODELS}")
+
+        # Validate temperature range
+        temperature = config.get("temperature")
+        if temperature is not None and (temperature < 0 or temperature > 2):
+            raise ValueError("OpenAI temperature must be between 0 and 2")
+
+        return True
+
+    def create_llm(self, config: Dict[str, Any]) -> ChatOpenAI:
+        """Create optimized OpenAI LLM instance."""
+        self.validate_config(config)
+
+        # OpenAI-specific optimizations
+        llm_config = {
+            "model": config["model"],
+            "temperature": config.get("temperature", 0.0),
+            "max_tokens": config.get("max_tokens"),
+            "top_p": config.get("top_p"),
+            "frequency_penalty": config.get("frequency_penalty"),
+            "presence_penalty": config.get("presence_penalty"),
+        }
+
+        # Remove None values
+        llm_config = {k: v for k, v in llm_config.items() if v is not None}
+
+        return ChatOpenAI(**llm_config)
+
+    def get_supported_models(self) -> List[str]:
+        """Return list of supported OpenAI models."""
+        return self.SUPPORTED_MODELS.copy()
+
+    def get_model_capabilities(self, model: str) -> Dict[str, Any]:
+        """Return capabilities for specific OpenAI model."""
+        return self.MODEL_CAPABILITIES.get(model, {})
+
+    def estimate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
+        """Estimate cost based on OpenAI pricing (as of 2024)."""
+        # Pricing per 1K tokens in USD
+        pricing = {
+            "gpt-4o": {"input": 0.005, "output": 0.015},
+            "gpt-4o-mini": {"input": 0.000150, "output": 0.000600},
+            "gpt-4-turbo": {"input": 0.01, "output": 0.03},
+            "gpt-4": {"input": 0.03, "output": 0.06},
+            "gpt-3.5-turbo": {"input": 0.001, "output": 0.002},
+            "gpt-3.5-turbo-16k": {"input": 0.003, "output": 0.004},
+        }
+
+        if model not in pricing:
+            return 0.0
+
+        rates = pricing[model]
+        return (input_tokens / 1000 * rates["input"]) + (output_tokens / 1000 * rates["output"])
```
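Note the deliberate asymmetry in key handling: a missing `OPENAI_API_KEY` only prints a warning and injects a placeholder key so test environments keep working, whereas the Anthropic and Google drivers raise. An illustrative contrast, assuming the optional provider packages are installed so the imports resolve:

```python
# Illustrative contrast of API-key handling across drivers.
import os
from diagram_to_iac.tools.llm_utils import AnthropicDriver, OpenAIDriver

os.environ.pop("OPENAI_API_KEY", None)
os.environ.pop("ANTHROPIC_API_KEY", None)

print(OpenAIDriver().validate_config({"model": "gpt-4o-mini"}))  # warns, then True

try:
    AnthropicDriver().validate_config({"model": "claude-3-5-haiku-20241022"})
except ValueError as err:
    print(err)  # ANTHROPIC_API_KEY environment variable not set
```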