diagram-to-iac 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. diagram_to_iac/__init__.py +10 -0
  2. diagram_to_iac/actions/__init__.py +7 -0
  3. diagram_to_iac/actions/git_entry.py +174 -0
  4. diagram_to_iac/actions/supervisor_entry.py +116 -0
  5. diagram_to_iac/actions/terraform_agent_entry.py +207 -0
  6. diagram_to_iac/agents/__init__.py +26 -0
  7. diagram_to_iac/agents/demonstrator_langgraph/__init__.py +10 -0
  8. diagram_to_iac/agents/demonstrator_langgraph/agent.py +826 -0
  9. diagram_to_iac/agents/git_langgraph/__init__.py +10 -0
  10. diagram_to_iac/agents/git_langgraph/agent.py +1018 -0
  11. diagram_to_iac/agents/git_langgraph/pr.py +146 -0
  12. diagram_to_iac/agents/hello_langgraph/__init__.py +9 -0
  13. diagram_to_iac/agents/hello_langgraph/agent.py +621 -0
  14. diagram_to_iac/agents/policy_agent/__init__.py +15 -0
  15. diagram_to_iac/agents/policy_agent/agent.py +507 -0
  16. diagram_to_iac/agents/policy_agent/integration_example.py +191 -0
  17. diagram_to_iac/agents/policy_agent/tools/__init__.py +14 -0
  18. diagram_to_iac/agents/policy_agent/tools/tfsec_tool.py +259 -0
  19. diagram_to_iac/agents/shell_langgraph/__init__.py +21 -0
  20. diagram_to_iac/agents/shell_langgraph/agent.py +122 -0
  21. diagram_to_iac/agents/shell_langgraph/detector.py +50 -0
  22. diagram_to_iac/agents/supervisor_langgraph/__init__.py +17 -0
  23. diagram_to_iac/agents/supervisor_langgraph/agent.py +1947 -0
  24. diagram_to_iac/agents/supervisor_langgraph/demonstrator.py +22 -0
  25. diagram_to_iac/agents/supervisor_langgraph/guards.py +23 -0
  26. diagram_to_iac/agents/supervisor_langgraph/pat_loop.py +49 -0
  27. diagram_to_iac/agents/supervisor_langgraph/router.py +9 -0
  28. diagram_to_iac/agents/terraform_langgraph/__init__.py +15 -0
  29. diagram_to_iac/agents/terraform_langgraph/agent.py +1216 -0
  30. diagram_to_iac/agents/terraform_langgraph/parser.py +76 -0
  31. diagram_to_iac/core/__init__.py +7 -0
  32. diagram_to_iac/core/agent_base.py +19 -0
  33. diagram_to_iac/core/enhanced_memory.py +302 -0
  34. diagram_to_iac/core/errors.py +4 -0
  35. diagram_to_iac/core/issue_tracker.py +49 -0
  36. diagram_to_iac/core/memory.py +132 -0
  37. diagram_to_iac/services/__init__.py +10 -0
  38. diagram_to_iac/services/observability.py +59 -0
  39. diagram_to_iac/services/step_summary.py +77 -0
  40. diagram_to_iac/tools/__init__.py +11 -0
  41. diagram_to_iac/tools/api_utils.py +108 -26
  42. diagram_to_iac/tools/git/__init__.py +45 -0
  43. diagram_to_iac/tools/git/git.py +956 -0
  44. diagram_to_iac/tools/hello/__init__.py +30 -0
  45. diagram_to_iac/tools/hello/cal_utils.py +31 -0
  46. diagram_to_iac/tools/hello/text_utils.py +97 -0
  47. diagram_to_iac/tools/llm_utils/__init__.py +20 -0
  48. diagram_to_iac/tools/llm_utils/anthropic_driver.py +87 -0
  49. diagram_to_iac/tools/llm_utils/base_driver.py +90 -0
  50. diagram_to_iac/tools/llm_utils/gemini_driver.py +89 -0
  51. diagram_to_iac/tools/llm_utils/openai_driver.py +93 -0
  52. diagram_to_iac/tools/llm_utils/router.py +303 -0
  53. diagram_to_iac/tools/sec_utils.py +4 -2
  54. diagram_to_iac/tools/shell/__init__.py +17 -0
  55. diagram_to_iac/tools/shell/shell.py +415 -0
  56. diagram_to_iac/tools/text_utils.py +277 -0
  57. diagram_to_iac/tools/tf/terraform.py +851 -0
  58. diagram_to_iac-0.8.0.dist-info/METADATA +99 -0
  59. diagram_to_iac-0.8.0.dist-info/RECORD +64 -0
  60. {diagram_to_iac-0.6.0.dist-info → diagram_to_iac-0.8.0.dist-info}/WHEEL +1 -1
  61. diagram_to_iac-0.8.0.dist-info/entry_points.txt +4 -0
  62. diagram_to_iac/agents/codegen_agent.py +0 -0
  63. diagram_to_iac/agents/consensus_agent.py +0 -0
  64. diagram_to_iac/agents/deployment_agent.py +0 -0
  65. diagram_to_iac/agents/github_agent.py +0 -0
  66. diagram_to_iac/agents/interpretation_agent.py +0 -0
  67. diagram_to_iac/agents/question_agent.py +0 -0
  68. diagram_to_iac/agents/supervisor.py +0 -0
  69. diagram_to_iac/agents/vision_agent.py +0 -0
  70. diagram_to_iac/core/config.py +0 -0
  71. diagram_to_iac/tools/cv_utils.py +0 -0
  72. diagram_to_iac/tools/gh_utils.py +0 -0
  73. diagram_to_iac/tools/tf_utils.py +0 -0
  74. diagram_to_iac-0.6.0.dist-info/METADATA +0 -16
  75. diagram_to_iac-0.6.0.dist-info/RECORD +0 -32
  76. diagram_to_iac-0.6.0.dist-info/entry_points.txt +0 -2
  77. {diagram_to_iac-0.6.0.dist-info → diagram_to_iac-0.8.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,30 @@
1
+ """
2
+ Hello LangGraph Agent Tools Package
3
+
4
+ This package contains tools specific to the HelloAgent:
5
+ - cal_utils: Arithmetic operations (addition, multiplication)
6
+ - text_utils: Text processing utilities
7
+
8
+ These tools have been moved from agents/hello_langgraph/tools/ to provide
9
+ better organization and reusability across the codebase.
10
+ """
11
+
12
+ # Import text utilities (no external dependencies)
13
+ from .text_utils import extract_numbers_from_text, extract_numbers_from_text_with_duplicates
14
+
15
+ # Import calculation utilities (requires langchain_core)
16
# The calculation tools are optional: cal_utils imports langchain_core, so when
# that import fails the package still loads and exposes None placeholders.
try:
    from .cal_utils import add_two, multiply_two
except ImportError:
    _cal_utils_available = False
    add_two = multiply_two = None
else:
    _cal_utils_available = True

# The text helpers are always part of the public API; the calculation tools are
# only advertised when they could actually be imported.
__all__ = [
    "extract_numbers_from_text",
    "extract_numbers_from_text_with_duplicates",
]

if _cal_utils_available:
    __all__ += ["add_two", "multiply_two"]
@@ -0,0 +1,31 @@
1
+ from langchain_core.tools import tool # Updated import for modern LangChain
2
+ from pydantic import BaseModel, Field
3
+
4
+ # --- Pydantic Schemas for Tool Inputs ---
5
class AddToolInput(BaseModel):
    # Argument schema for the add_two tool; both operands are required
    # integers (the `...` default marks a field as mandatory).
    x: int = Field(
        ...,
        description="The first number for addition",
    )
    y: int = Field(
        ...,
        description="The second number for addition",
    )
8
+
9
class MultiplyToolInput(BaseModel):
    # Argument schema for the multiply_two tool; both operands are required
    # integers (the `...` default marks a field as mandatory).
    x: int = Field(
        ...,
        description="The first number for multiplication",
    )
    y: int = Field(
        ...,
        description="The second number for multiplication",
    )
12
+
13
+ # --- Tool Definitions ---
14
@tool(args_schema=AddToolInput)
def add_two(x: int, y: int) -> int:
    """Add two numbers. Expects input according to AddToolInput schema."""
    # With args_schema set, @tool parses the incoming input dict into
    # AddToolInput and passes the validated fields on as plain keyword
    # arguments, which is why the signature is just (x, y). Receiving the
    # Pydantic model object itself would need a different setup (a custom
    # Tool class or StructuredTool.from_function).
    total = x + y
    return total
26
+
27
@tool(args_schema=MultiplyToolInput)
def multiply_two(x: int, y: int) -> int:
    """Multiply two numbers. Expects input according to MultiplyToolInput schema."""
    # As with add_two, @tool unpacks the validated MultiplyToolInput fields
    # into the x and y arguments before this body runs.
    product = x * y
    return product
@@ -0,0 +1,97 @@
1
+ import re
2
+ from typing import List
3
+
4
def extract_numbers_from_text(text: str) -> List[int]:
    """
    Extract the unique integers mentioned in a text string.

    Numbers are recognised in three ways:
    - standalone digit sequences ("add 7 and 3"),
    - operands on either side of an operator, even when glued to other
      characters ("4*5", "x4*5"),
    - negative numbers in an explicit mathematical context
      ("What is -10 + 5?", "-3 * 4").

    A dash between digit groups (as in "555-1234") is never treated as a
    minus sign, so phone-number style text yields positive values only.

    Every number is recorded by the position of its first digit, so the
    result is in true order of appearance and a negative number is reported
    once, with its sign, instead of also contributing its magnitude as a
    separate positive value.

    Args:
        text: The input string from which to extract numbers.

    Returns:
        A list of unique integers found in the text, in order of first
        appearance. Returns an empty list if no numbers are found.
    """
    # Offset of a number's first digit -> signed value. Keying on the digit
    # offset lets a negative match and the plain-digit match for the same
    # digits collapse into one entry.
    found = {}

    negative_patterns = (
        # Negative operand after a math keyword or operator: "What is -10 + 5?"
        (r'(?:is|add|subtract|plus|minus|\+|\*|/)\s+(-\d+)', re.IGNORECASE),
        # Negative number opening an expression: "-3 * 4"
        (r'(?:^|\s)(-\d+)\s*[+*/×÷%]', 0),
    )
    # Negatives are registered first so they win over the unsigned match.
    for pattern, flags in negative_patterns:
        for match in re.finditer(pattern, text, flags):
            # +1 skips the minus sign so the key is the first digit's offset.
            found.setdefault(match.start(1) + 1, int(match.group(1)))

    # Standalone digit sequences.
    for match in re.finditer(r'\b\d+\b', text):
        found.setdefault(match.start(), int(match.group()))

    # Operands adjacent to an operator; this also catches digits glued to
    # letters, which the word-boundary pattern above misses. The dash is
    # deliberately not an operator here to avoid phone-number confusion.
    for match in re.finditer(r'(\d+)\s*[+*/×÷%]\s*(\d+)', text):
        for group in (1, 2):
            found.setdefault(match.start(group), int(match.group(group)))

    # Emit values in positional order, keeping only the first occurrence.
    unique_numbers = []
    seen = set()
    for _, value in sorted(found.items()):
        if value not in seen:
            seen.add(value)
            unique_numbers.append(value)

    return unique_numbers
56
+
57
def extract_numbers_from_text_with_duplicates(text: str) -> List[int]:
    """
    Extract every integer from a text string, keeping repeated values.

    Unlike extract_numbers_from_text, nothing is de-duplicated, so an
    expression such as "10 * 10" yields both operands.

    Args:
        text: The input string from which to extract numbers.

    Returns:
        The integers found in the text, in order of appearance and including
        duplicates. Returns an empty list if no numbers are found.

    Examples:
        >>> extract_numbers_from_text_with_duplicates("Calculate 10 * 10")
        [10, 10]
        >>> extract_numbers_from_text_with_duplicates("What is 5 + 7?")
        [5, 7]
    """
    # One left-to-right scan with two alternatives:
    #   group 1 - a negative number, accepted only after a math keyword,
    #             operator, whitespace or start of text and only when followed
    #             by an operator, whitespace or end of text (so the dash in a
    #             phone number is not read as a minus sign);
    #   group 2 - a plain standalone digit sequence.
    number_pattern = re.compile(
        r'(?:(?:is|add|subtract|plus|minus|\+|\*|/|^|\s)[-\s]*(-\d+)(?=\s*[+\-*/×÷%\s]|$))|(?:\b(\d+)\b)',
        re.IGNORECASE,
    )

    # Exactly one of the two groups is populated for each match.
    tokens = (hit.group(1) or hit.group(2) for hit in number_pattern.finditer(text))
    return [int(token) for token in tokens if token]
@@ -0,0 +1,20 @@
1
+ """
2
+ LLM Utils Package
3
+
4
+ Provides LLM routing, drivers, and utilities for the diagram-to-iac project.
5
+ """
6
+
7
+ from .router import LLMRouter, get_llm
8
+ from .base_driver import BaseLLMDriver
9
+ from .openai_driver import OpenAIDriver
10
+ from .anthropic_driver import AnthropicDriver
11
+ from .gemini_driver import GoogleDriver
12
+
13
# Public API of the llm_utils package: the router entry points, the abstract
# driver interface, and one concrete driver per provider.
__all__ = [
    "LLMRouter",
    "get_llm",
    "BaseLLMDriver",
    "OpenAIDriver",
    "AnthropicDriver",
    "GoogleDriver",
]
@@ -0,0 +1,87 @@
1
+ """
2
+ Anthropic LLM Driver
3
+
4
+ Provides Anthropic Claude-specific optimizations and features.
5
+ """
6
+
7
+ import os
8
+ from typing import Dict, Any, List
9
+ from langchain_anthropic import ChatAnthropic
10
+ from .base_driver import BaseLLMDriver
11
+
12
+
13
class AnthropicDriver(BaseLLMDriver):
    """Anthropic Claude-specific LLM driver.

    Validates configuration against the models listed below, builds
    ChatAnthropic instances, and exposes static capability and pricing
    metadata for those models.
    """

    SUPPORTED_MODELS = [
        "claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022",
        "claude-3-opus-20240229", "claude-3-sonnet-20240229",
        "claude-3-haiku-20240307"
    ]

    # NOTE(review): claude-3-haiku is marked vision-capable here, following
    # Anthropic's published model table for the Claude 3 family; confirm.
    MODEL_CAPABILITIES = {
        "claude-3-5-sonnet-20241022": {"context_length": 200000, "function_calling": True, "vision": True},
        "claude-3-5-haiku-20241022": {"context_length": 200000, "function_calling": True, "vision": True},
        "claude-3-opus-20240229": {"context_length": 200000, "function_calling": True, "vision": True},
        "claude-3-sonnet-20240229": {"context_length": 200000, "function_calling": True, "vision": True},
        "claude-3-haiku-20240307": {"context_length": 200000, "function_calling": True, "vision": True},
    }

    # Pricing per 1K tokens in USD. Kept at class level so the table is built
    # once rather than on every estimate_cost call.
    # NOTE(review): the claude-3-5-haiku rates follow Anthropic's published
    # list price ($0.80 / $4 per million tokens); re-check when prices change.
    MODEL_PRICING = {
        "claude-3-5-sonnet-20241022": {"input": 0.003, "output": 0.015},
        "claude-3-5-haiku-20241022": {"input": 0.0008, "output": 0.004},
        "claude-3-opus-20240229": {"input": 0.015, "output": 0.075},
        "claude-3-sonnet-20240229": {"input": 0.003, "output": 0.015},
        "claude-3-haiku-20240307": {"input": 0.00025, "output": 0.00125},
    }

    def validate_config(self, config: Dict[str, Any]) -> bool:
        """Validate Anthropic-specific configuration.

        Args:
            config: Configuration dictionary; the "model" and "temperature"
                keys are checked when present.

        Returns:
            True when the configuration passes every check.

        Raises:
            ValueError: If ANTHROPIC_API_KEY is not set, the model is not in
                SUPPORTED_MODELS, or the temperature is outside [0, 1].
        """
        if not os.getenv("ANTHROPIC_API_KEY"):
            raise ValueError("ANTHROPIC_API_KEY environment variable not set")

        model = config.get("model")
        if model and model not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported Anthropic model: {model}. Supported models: {self.SUPPORTED_MODELS}")

        # Validate temperature range
        temperature = config.get("temperature")
        if temperature is not None and (temperature < 0 or temperature > 1):
            raise ValueError("Anthropic temperature must be between 0 and 1")

        return True

    def create_llm(self, config: Dict[str, Any]) -> ChatAnthropic:
        """Create a ChatAnthropic instance from the given configuration.

        Args:
            config: Configuration dictionary. "model" is required;
                "temperature" defaults to 0.0 and "max_tokens" to 1024;
                "top_p" and "top_k" are forwarded only when provided.

        Returns:
            The configured ChatAnthropic instance.

        Raises:
            ValueError: If validate_config rejects the configuration.
            KeyError: If "model" is missing from the configuration.
        """
        self.validate_config(config)

        llm_config = {
            "model": config["model"],
            "temperature": config.get("temperature", 0.0),
            "max_tokens": config.get("max_tokens", 1024),
            "top_p": config.get("top_p"),
            "top_k": config.get("top_k"),  # Anthropic-specific parameter
        }

        # Drop unset options so ChatAnthropic falls back to its own defaults.
        llm_config = {k: v for k, v in llm_config.items() if v is not None}

        return ChatAnthropic(**llm_config)

    def get_supported_models(self) -> List[str]:
        """Return a copy of the list of supported Anthropic models."""
        return self.SUPPORTED_MODELS.copy()

    def get_model_capabilities(self, model: str) -> Dict[str, Any]:
        """Return capabilities for a specific Anthropic model.

        A copy is returned so callers cannot mutate the class-level table,
        mirroring get_supported_models. Unknown models yield an empty dict.
        """
        return dict(self.MODEL_CAPABILITIES.get(model, {}))

    def estimate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
        """Estimate cost in USD based on Anthropic pricing (as of 2024).

        Args:
            model: Model name to price.
            input_tokens: Number of input tokens.
            output_tokens: Number of output tokens.

        Returns:
            The estimated cost in USD, or 0.0 when the model has no entry in
            MODEL_PRICING.
        """
        rates = self.MODEL_PRICING.get(model)
        if rates is None:
            return 0.0

        return (input_tokens / 1000 * rates["input"]) + (output_tokens / 1000 * rates["output"])
@@ -0,0 +1,90 @@
1
+ """
2
+ Base LLM Driver Interface
3
+
4
+ This module provides the abstract base class for all LLM provider drivers.
5
+ Each driver implements provider-specific optimizations and features.
6
+ """
7
+
8
+ from abc import ABC, abstractmethod
9
+ from typing import Dict, Any, List, Optional
10
+ from langchain_core.language_models.chat_models import BaseChatModel
11
+
12
+
13
class BaseLLMDriver(ABC):
    """Abstract base class for all LLM provider drivers.

    Concrete drivers must implement configuration validation, LLM creation,
    and the model metadata queries declared below.
    """

    @abstractmethod
    def validate_config(self, config: Dict[str, Any]) -> bool:
        """
        Check a provider-specific configuration.

        Args:
            config: Configuration dictionary containing model parameters

        Returns:
            bool: True if configuration is valid

        Raises:
            ValueError: If configuration is invalid
        """

    @abstractmethod
    def create_llm(self, config: Dict[str, Any]) -> BaseChatModel:
        """
        Build and configure an LLM instance.

        Args:
            config: Configuration dictionary containing model parameters

        Returns:
            BaseChatModel: Configured LLM instance
        """

    @abstractmethod
    def get_supported_models(self) -> List[str]:
        """
        List the models this provider supports.

        Returns:
            List[str]: List of supported model names
        """

    @abstractmethod
    def get_model_capabilities(self, model: str) -> Dict[str, Any]:
        """
        Describe the capabilities of a specific model.

        Args:
            model: Model name

        Returns:
            Dict containing capabilities like context_length, function_calling, vision, etc.
        """

    @abstractmethod
    def estimate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
        """
        Estimate the cost of a given token usage.

        Args:
            model: Model name
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens

        Returns:
            float: Estimated cost in USD
        """

    def get_provider_name(self) -> str:
        """
        Derive the provider name from the driver's class name.

        The class name is lower-cased and every occurrence of "driver" is
        removed, e.g. OpenAIDriver -> "openai".

        Returns:
            str: Provider name (e.g., 'openai', 'anthropic', 'google')
        """
        lowered_name = self.__class__.__name__.lower()
        provider = lowered_name.replace('driver', '')
        return provider
@@ -0,0 +1,89 @@
1
+ """
2
+ Google Gemini LLM Driver
3
+
4
+ Provides Google Gemini-specific optimizations and features.
5
+ """
6
+
7
+ import os
8
+ from typing import Dict, Any, List
9
+ from langchain_google_genai import ChatGoogleGenerativeAI
10
+ from .base_driver import BaseLLMDriver
11
+
12
+
13
class GoogleDriver(BaseLLMDriver):
    """Google Gemini-specific LLM driver.

    Validates configuration against the models listed below, builds
    ChatGoogleGenerativeAI instances, and exposes static capability and
    pricing metadata for those models.
    """

    SUPPORTED_MODELS = [
        "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b",
        "gemini-1.0-pro", "gemini-pro", "gemini-pro-vision"
    ]

    MODEL_CAPABILITIES = {
        "gemini-1.5-pro": {"context_length": 1000000, "function_calling": True, "vision": True},
        "gemini-1.5-flash": {"context_length": 1000000, "function_calling": True, "vision": True},
        "gemini-1.5-flash-8b": {"context_length": 1000000, "function_calling": True, "vision": True},
        "gemini-1.0-pro": {"context_length": 30720, "function_calling": True, "vision": False},
        "gemini-pro": {"context_length": 30720, "function_calling": True, "vision": False},
        "gemini-pro-vision": {"context_length": 12288, "function_calling": False, "vision": True},
    }

    # Pricing per 1K tokens in USD. Kept at class level so the table is built
    # once rather than on every estimate_cost call.
    MODEL_PRICING = {
        "gemini-1.5-pro": {"input": 0.00125, "output": 0.005},
        "gemini-1.5-flash": {"input": 0.000075, "output": 0.0003},
        "gemini-1.5-flash-8b": {"input": 0.0000375, "output": 0.00015},
        "gemini-1.0-pro": {"input": 0.0005, "output": 0.0015},
        "gemini-pro": {"input": 0.0005, "output": 0.0015},
        "gemini-pro-vision": {"input": 0.00025, "output": 0.0005},
    }

    def validate_config(self, config: Dict[str, Any]) -> bool:
        """Validate Google-specific configuration.

        Args:
            config: Configuration dictionary; the "model" and "temperature"
                keys are checked when present.

        Returns:
            True when the configuration passes every check.

        Raises:
            ValueError: If GOOGLE_API_KEY is not set, the model is not in
                SUPPORTED_MODELS, or the temperature is outside [0, 1].
        """
        if not os.getenv("GOOGLE_API_KEY"):
            raise ValueError("GOOGLE_API_KEY environment variable not set")

        model = config.get("model")
        if model and model not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported Google model: {model}. Supported models: {self.SUPPORTED_MODELS}")

        # Validate temperature range
        # NOTE(review): some Gemini models may accept temperatures above 1;
        # confirm whether this upper bound should be model-dependent.
        temperature = config.get("temperature")
        if temperature is not None and (temperature < 0 or temperature > 1):
            raise ValueError("Google temperature must be between 0 and 1")

        return True

    def create_llm(self, config: Dict[str, Any]) -> ChatGoogleGenerativeAI:
        """Create a ChatGoogleGenerativeAI instance from the configuration.

        Args:
            config: Configuration dictionary. "model" is required;
                "temperature" defaults to 0.0; "max_tokens", "top_p" and
                "top_k" are forwarded only when provided.

        Returns:
            The configured ChatGoogleGenerativeAI instance.

        Raises:
            ValueError: If validate_config rejects the configuration.
            KeyError: If "model" is missing from the configuration.
        """
        self.validate_config(config)

        llm_config = {
            "model": config["model"],
            "temperature": config.get("temperature", 0.0),
            "max_tokens": config.get("max_tokens"),
            "top_p": config.get("top_p"),
            "top_k": config.get("top_k"),  # Google-specific parameter
            "google_api_key": os.getenv("GOOGLE_API_KEY"),
        }

        # Drop unset options. google_api_key is never dropped in practice
        # because validate_config has already required it to be set.
        llm_config = {k: v for k, v in llm_config.items() if v is not None}

        return ChatGoogleGenerativeAI(**llm_config)

    def get_supported_models(self) -> List[str]:
        """Return a copy of the list of supported Google models."""
        return self.SUPPORTED_MODELS.copy()

    def get_model_capabilities(self, model: str) -> Dict[str, Any]:
        """Return capabilities for a specific Google model.

        A copy is returned so callers cannot mutate the class-level table,
        mirroring get_supported_models. Unknown models yield an empty dict.
        """
        return dict(self.MODEL_CAPABILITIES.get(model, {}))

    def estimate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
        """Estimate cost in USD based on Google pricing (as of 2024).

        Args:
            model: Model name to price.
            input_tokens: Number of input tokens.
            output_tokens: Number of output tokens.

        Returns:
            The estimated cost in USD, or 0.0 when the model has no entry in
            MODEL_PRICING.
        """
        rates = self.MODEL_PRICING.get(model)
        if rates is None:
            return 0.0

        return (input_tokens / 1000 * rates["input"]) + (output_tokens / 1000 * rates["output"])
@@ -0,0 +1,93 @@
1
+ """
2
+ OpenAI LLM Driver
3
+
4
+ Provides OpenAI-specific optimizations and features for ChatGPT models.
5
+ """
6
+
7
+ import os
8
+ from typing import Dict, Any, List
9
+ from langchain_openai import ChatOpenAI
10
+ from .base_driver import BaseLLMDriver
11
+
12
+
13
class OpenAIDriver(BaseLLMDriver):
    """OpenAI-specific LLM driver with advanced features.

    Validates configuration against the models listed below, builds
    ChatOpenAI instances, and exposes static capability and pricing metadata
    for those models.
    """

    SUPPORTED_MODELS = [
        "gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-4",
        "gpt-3.5-turbo", "gpt-3.5-turbo-16k"
    ]

    MODEL_CAPABILITIES = {
        "gpt-4o": {"context_length": 128000, "function_calling": True, "vision": True},
        "gpt-4o-mini": {"context_length": 128000, "function_calling": True, "vision": True},
        "gpt-4-turbo": {"context_length": 128000, "function_calling": True, "vision": True},
        "gpt-4": {"context_length": 8192, "function_calling": True, "vision": False},
        "gpt-3.5-turbo": {"context_length": 16385, "function_calling": True, "vision": False},
        "gpt-3.5-turbo-16k": {"context_length": 16385, "function_calling": True, "vision": False},
    }

    # Pricing per 1K tokens in USD. Kept at class level so the table is built
    # once rather than on every estimate_cost call.
    MODEL_PRICING = {
        "gpt-4o": {"input": 0.005, "output": 0.015},
        "gpt-4o-mini": {"input": 0.000150, "output": 0.000600},
        "gpt-4-turbo": {"input": 0.01, "output": 0.03},
        "gpt-4": {"input": 0.03, "output": 0.06},
        "gpt-3.5-turbo": {"input": 0.001, "output": 0.002},
        "gpt-3.5-turbo-16k": {"input": 0.003, "output": 0.004},
    }

    def validate_config(self, config: Dict[str, Any]) -> bool:
        """Validate OpenAI-specific configuration.

        A missing or empty OPENAI_API_KEY does not fail validation: a warning
        is printed and a placeholder key is written to the process
        environment so the driver can be initialised and mocked in tests.

        Args:
            config: Configuration dictionary; the "model" and "temperature"
                keys are checked when present.

        Returns:
            True when the configuration passes every check.

        Raises:
            ValueError: If the model is not in SUPPORTED_MODELS or the
                temperature is outside [0, 2].
        """
        if not os.getenv("OPENAI_API_KEY"):
            print("Warning: OPENAI_API_KEY environment variable not set. Using placeholder key for testing.")
            # Assign rather than setdefault: setdefault would leave an
            # empty-string value untouched even though the warning above
            # announces that the placeholder is in use.
            os.environ["OPENAI_API_KEY"] = "test-key"

        model = config.get("model")
        if model and model not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported OpenAI model: {model}. Supported models: {self.SUPPORTED_MODELS}")

        # Validate temperature range
        temperature = config.get("temperature")
        if temperature is not None and (temperature < 0 or temperature > 2):
            raise ValueError("OpenAI temperature must be between 0 and 2")

        return True

    def create_llm(self, config: Dict[str, Any]) -> ChatOpenAI:
        """Create a ChatOpenAI instance from the given configuration.

        Args:
            config: Configuration dictionary. "model" is required;
                "temperature" defaults to 0.0; "max_tokens", "top_p",
                "frequency_penalty" and "presence_penalty" are forwarded only
                when provided.

        Returns:
            The configured ChatOpenAI instance.

        Raises:
            ValueError: If validate_config rejects the configuration.
            KeyError: If "model" is missing from the configuration.
        """
        self.validate_config(config)

        llm_config = {
            "model": config["model"],
            "temperature": config.get("temperature", 0.0),
            "max_tokens": config.get("max_tokens"),
            "top_p": config.get("top_p"),
            "frequency_penalty": config.get("frequency_penalty"),
            "presence_penalty": config.get("presence_penalty"),
        }

        # Drop unset options so ChatOpenAI falls back to its own defaults.
        llm_config = {k: v for k, v in llm_config.items() if v is not None}

        return ChatOpenAI(**llm_config)

    def get_supported_models(self) -> List[str]:
        """Return a copy of the list of supported OpenAI models."""
        return self.SUPPORTED_MODELS.copy()

    def get_model_capabilities(self, model: str) -> Dict[str, Any]:
        """Return capabilities for a specific OpenAI model.

        A copy is returned so callers cannot mutate the class-level table,
        mirroring get_supported_models. Unknown models yield an empty dict.
        """
        return dict(self.MODEL_CAPABILITIES.get(model, {}))

    def estimate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
        """Estimate cost in USD based on OpenAI pricing (as of 2024).

        Args:
            model: Model name to price.
            input_tokens: Number of input tokens.
            output_tokens: Number of output tokens.

        Returns:
            The estimated cost in USD, or 0.0 when the model has no entry in
            MODEL_PRICING.
        """
        rates = self.MODEL_PRICING.get(model)
        if rates is None:
            return 0.0

        return (input_tokens / 1000 * rates["input"]) + (output_tokens / 1000 * rates["output"])