pdd-cli 0.0.2 (pdd_cli-0.0.2-py3-none-any.whl)

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.

Potentially problematic release.


This version of pdd-cli might be problematic.

Files changed (95)
  1. pdd/__init__.py +0 -0
  2. pdd/auto_deps_main.py +98 -0
  3. pdd/auto_include.py +175 -0
  4. pdd/auto_update.py +73 -0
  5. pdd/bug_main.py +99 -0
  6. pdd/bug_to_unit_test.py +159 -0
  7. pdd/change.py +141 -0
  8. pdd/change_main.py +240 -0
  9. pdd/cli.py +607 -0
  10. pdd/cmd_test_main.py +155 -0
  11. pdd/code_generator.py +117 -0
  12. pdd/code_generator_main.py +66 -0
  13. pdd/comment_line.py +35 -0
  14. pdd/conflicts_in_prompts.py +143 -0
  15. pdd/conflicts_main.py +90 -0
  16. pdd/construct_paths.py +251 -0
  17. pdd/context_generator.py +133 -0
  18. pdd/context_generator_main.py +73 -0
  19. pdd/continue_generation.py +140 -0
  20. pdd/crash_main.py +127 -0
  21. pdd/data/language_format.csv +61 -0
  22. pdd/data/llm_model.csv +15 -0
  23. pdd/detect_change.py +142 -0
  24. pdd/detect_change_main.py +100 -0
  25. pdd/find_section.py +28 -0
  26. pdd/fix_code_loop.py +212 -0
  27. pdd/fix_code_module_errors.py +143 -0
  28. pdd/fix_error_loop.py +216 -0
  29. pdd/fix_errors_from_unit_tests.py +240 -0
  30. pdd/fix_main.py +138 -0
  31. pdd/generate_output_paths.py +194 -0
  32. pdd/generate_test.py +140 -0
  33. pdd/get_comment.py +55 -0
  34. pdd/get_extension.py +52 -0
  35. pdd/get_language.py +41 -0
  36. pdd/git_update.py +84 -0
  37. pdd/increase_tests.py +93 -0
  38. pdd/insert_includes.py +150 -0
  39. pdd/llm_invoke.py +304 -0
  40. pdd/load_prompt_template.py +59 -0
  41. pdd/pdd_completion.fish +72 -0
  42. pdd/pdd_completion.sh +141 -0
  43. pdd/pdd_completion.zsh +418 -0
  44. pdd/postprocess.py +121 -0
  45. pdd/postprocess_0.py +52 -0
  46. pdd/preprocess.py +199 -0
  47. pdd/preprocess_main.py +72 -0
  48. pdd/process_csv_change.py +182 -0
  49. pdd/prompts/auto_include_LLM.prompt +230 -0
  50. pdd/prompts/bug_to_unit_test_LLM.prompt +17 -0
  51. pdd/prompts/change_LLM.prompt +34 -0
  52. pdd/prompts/conflict_LLM.prompt +23 -0
  53. pdd/prompts/continue_generation_LLM.prompt +3 -0
  54. pdd/prompts/detect_change_LLM.prompt +65 -0
  55. pdd/prompts/example_generator_LLM.prompt +10 -0
  56. pdd/prompts/extract_auto_include_LLM.prompt +6 -0
  57. pdd/prompts/extract_code_LLM.prompt +22 -0
  58. pdd/prompts/extract_conflict_LLM.prompt +19 -0
  59. pdd/prompts/extract_detect_change_LLM.prompt +19 -0
  60. pdd/prompts/extract_program_code_fix_LLM.prompt +16 -0
  61. pdd/prompts/extract_prompt_change_LLM.prompt +7 -0
  62. pdd/prompts/extract_prompt_split_LLM.prompt +9 -0
  63. pdd/prompts/extract_prompt_update_LLM.prompt +8 -0
  64. pdd/prompts/extract_promptline_LLM.prompt +11 -0
  65. pdd/prompts/extract_unit_code_fix_LLM.prompt +332 -0
  66. pdd/prompts/extract_xml_LLM.prompt +7 -0
  67. pdd/prompts/fix_code_module_errors_LLM.prompt +17 -0
  68. pdd/prompts/fix_errors_from_unit_tests_LLM.prompt +62 -0
  69. pdd/prompts/generate_test_LLM.prompt +12 -0
  70. pdd/prompts/increase_tests_LLM.prompt +16 -0
  71. pdd/prompts/insert_includes_LLM.prompt +30 -0
  72. pdd/prompts/split_LLM.prompt +94 -0
  73. pdd/prompts/summarize_file_LLM.prompt +11 -0
  74. pdd/prompts/trace_LLM.prompt +30 -0
  75. pdd/prompts/trim_results_LLM.prompt +83 -0
  76. pdd/prompts/trim_results_start_LLM.prompt +45 -0
  77. pdd/prompts/unfinished_prompt_LLM.prompt +18 -0
  78. pdd/prompts/update_prompt_LLM.prompt +19 -0
  79. pdd/prompts/xml_convertor_LLM.prompt +54 -0
  80. pdd/split.py +119 -0
  81. pdd/split_main.py +103 -0
  82. pdd/summarize_directory.py +212 -0
  83. pdd/trace.py +135 -0
  84. pdd/trace_main.py +108 -0
  85. pdd/track_cost.py +102 -0
  86. pdd/unfinished_prompt.py +114 -0
  87. pdd/update_main.py +96 -0
  88. pdd/update_prompt.py +115 -0
  89. pdd/xml_tagger.py +122 -0
  90. pdd_cli-0.0.2.dist-info/LICENSE +7 -0
  91. pdd_cli-0.0.2.dist-info/METADATA +225 -0
  92. pdd_cli-0.0.2.dist-info/RECORD +95 -0
  93. pdd_cli-0.0.2.dist-info/WHEEL +5 -0
  94. pdd_cli-0.0.2.dist-info/entry_points.txt +2 -0
  95. pdd_cli-0.0.2.dist-info/top_level.txt +1 -0
pdd/git_update.py ADDED
@@ -0,0 +1,84 @@
+ import os
+ from typing import Tuple, Optional
+ from rich import print
+ from rich.console import Console
+ from rich.panel import Panel
+ from .update_prompt import update_prompt
+ import git
+
+ console = Console()
+
+ def git_update(
+     input_prompt: str,
+     modified_code_file: str,
+     strength: float,
+     temperature: float,
+     verbose: bool = False
+ ) -> Tuple[Optional[str], float, str]:
+     """
+     Read in modified code, restore the prior checked-in version from GitHub,
+     update the prompt, write back the modified code, and return outputs.
+
+     Args:
+         input_prompt (str): The prompt that generated the original code.
+         modified_code_file (str): Filepath of the modified code.
+         strength (float): Strength parameter for the LLM model.
+         temperature (float): Temperature parameter for the LLM model.
+
+     Returns:
+         Tuple[Optional[str], float, str]: Modified prompt, total cost, and model name.
+     """
+     try:
+         # Check if inputs are valid
+         if not input_prompt or not modified_code_file:
+             raise ValueError("Input prompt and modified code file path are required.")
+
+         if not os.path.exists(modified_code_file):
+             raise FileNotFoundError(f"Modified code file not found: {modified_code_file}")
+
+         # Initialize git repository
+         repo = git.Repo(search_parent_directories=True)
+
+         # Get the file's relative path to the repo root
+         repo_root = repo.git.rev_parse("--show-toplevel")
+         relative_path = os.path.relpath(modified_code_file, repo_root)
+
+         # Read the modified code
+         with open(modified_code_file, 'r') as file:
+             modified_code = file.read()
+
+         # Restore the prior checked-in version
+         repo.git.checkout('HEAD', '--', relative_path)
+
+         # Read the original input code
+         with open(modified_code_file, 'r') as file:
+             original_input_code = file.read()
+
+         # Call update_prompt function
+         modified_prompt, total_cost, model_name = update_prompt(
+             input_prompt=input_prompt,
+             input_code=original_input_code,
+             modified_code=modified_code,
+             strength=strength,
+             temperature=temperature,
+             verbose=verbose
+         )
+
+         # Write back the modified code
+         with open(modified_code_file, 'w') as file:
+             file.write(modified_code)
+
+
+         # Pretty print the results
+         console.print(Panel.fit(
+             f"[bold green]Success:[/bold green]\n"
+             f"Modified prompt: {modified_prompt}\n"
+             f"Total cost: ${total_cost:.6f}\n"
+             f"Model name: {model_name}"
+         ))
+
+         return modified_prompt, total_cost, model_name
+
+     except Exception as e:
+         console.print(Panel(f"[bold red]Error:[/bold red] {str(e)}", title="Error", expand=False))
+         return None, 0.0, ""
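git_update() reverts the tracked file to its HEAD version long enough to recover the pre-edit code, asks update_prompt() to reconcile the prompt with the edits, and then restores the edited file. A minimal usage sketch, assuming a git checkout containing the hypothetical prompt and code files named here:

from pdd.git_update import git_update

# Hypothetical paths; any prompt/code pair tracked by the surrounding git repo would do.
prompt_text = open("prompts/calculator_python.prompt").read()

modified_prompt, total_cost, model_name = git_update(
    input_prompt=prompt_text,
    modified_code_file="src/calculator.py",
    strength=0.5,
    temperature=0.0,
    verbose=True,
)
if modified_prompt is None:
    print("git_update failed; see the error panel printed above.")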
pdd/increase_tests.py ADDED
@@ -0,0 +1,93 @@
+ from typing import Tuple
+ from rich.console import Console
+
+ from .load_prompt_template import load_prompt_template
+ from .llm_invoke import llm_invoke
+ from .postprocess import postprocess
+
+ def increase_tests(
+     existing_unit_tests: str,
+     coverage_report: str,
+     code: str,
+     prompt_that_generated_code: str,
+     language: str = "python",
+     strength: float = 0.5,
+     temperature: float = 0.0,
+     verbose: bool = False
+ ) -> Tuple[str, float, str]:
+     """
+     Generate additional unit tests to increase code coverage.
+
+     Args:
+         existing_unit_tests (str): Current unit tests for the code
+         coverage_report (str): Coverage report for the code
+         code (str): Code under test
+         prompt_that_generated_code (str): Original prompt used to generate the code
+         language (str, optional): Programming language. Defaults to "python".
+         strength (float, optional): LLM model strength. Defaults to 0.5.
+         temperature (float, optional): LLM model temperature. Defaults to 0.0.
+         verbose (bool, optional): Verbose output flag. Defaults to False.
+
+     Returns:
+         Tuple containing:
+             - Increased test function (str)
+             - Total cost of generation (float)
+             - Model name used (str)
+     """
+     console = Console()
+
+     # Validate inputs
+     if not all([existing_unit_tests, coverage_report, code, prompt_that_generated_code]):
+         raise ValueError("All input parameters must be non-empty strings")
+
+     # Validate strength and temperature
+     if not (0 <= strength <= 1):
+         raise ValueError("Strength must be between 0 and 1")
+     if not (0 <= temperature <= 1):
+         raise ValueError("Temperature must be between 0 and 1")
+
+     try:
+         # Step 1: Load prompt template
+         prompt_name = "increase_tests_LLM"
+         prompt_template = load_prompt_template(prompt_name)
+
+         if verbose:
+             console.print(f"[blue]Loaded Prompt Template:[/blue]\n{prompt_template}")
+
+         # Step 2: Prepare input for LLM invoke
+         input_json = {
+             "existing_unit_tests": existing_unit_tests,
+             "coverage_report": coverage_report,
+             "code": code,
+             "prompt_that_generated_code": prompt_that_generated_code,
+             "language": language
+         }
+
+         # Invoke LLM with the prompt
+         llm_response = llm_invoke(
+             prompt=prompt_template,
+             input_json=input_json,
+             strength=strength,
+             temperature=temperature,
+             verbose=verbose
+         )
+
+         # Step 3: Postprocess the result
+         increase_test_function, total_cost, model_name = postprocess(
+             llm_response['result'],
+             language,
+             0.89,  # Same strength as LLM invoke
+             temperature,
+             verbose
+         )
+
+         if verbose:
+             console.print(f"[green]Generated Test Function:[/green]\n{increase_test_function}")
+             console.print(f"[blue]Total Cost: ${total_cost:.6f}[/blue]")
+             console.print(f"[blue]Model Used: {model_name}[/blue]")
+
+         return increase_test_function, total_cost, model_name
+
+     except Exception as e:
+         console.print(f"[red]Error in increase_tests: {str(e)}[/red]")
+         raise
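increase_tests() feeds the existing tests, a coverage report, the code, and its originating prompt into the increase_tests_LLM prompt and post-processes the raw completion into test code. A usage sketch with hypothetical inputs (in practice these would come from real files and a coverage run):

from pdd.increase_tests import increase_tests

new_tests, cost, model = increase_tests(
    existing_unit_tests=open("tests/test_calculator.py").read(),
    coverage_report=open("coverage.txt").read(),  # e.g. saved output of a coverage report
    code=open("src/calculator.py").read(),
    prompt_that_generated_code=open("prompts/calculator_python.prompt").read(),
    language="python",
    strength=0.5,
    temperature=0.0,
)
print(f"Model {model} produced additional tests for ${cost:.6f}:\n{new_tests}")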
pdd/insert_includes.py ADDED
@@ -0,0 +1,150 @@
+ from typing import Tuple
+ from pathlib import Path
+ from rich import print
+ from pydantic import BaseModel, Field
+
+ from .llm_invoke import llm_invoke
+ from .load_prompt_template import load_prompt_template
+ from .auto_include import auto_include
+ from .preprocess import preprocess
+
+ class InsertIncludesOutput(BaseModel):
+     output_prompt: str = Field(description="The prompt with dependencies inserted")
+
+ def insert_includes(
+     input_prompt: str,
+     directory_path: str,
+     csv_filename: str,
+     strength: float,
+     temperature: float,
+     verbose: bool = False
+ ) -> Tuple[str, str, float, str]:
+     """
+     Determine needed dependencies and insert them into a prompt.
+
+     Args:
+         input_prompt (str): The prompt to process
+         directory_path (str): Directory path where the prompt file is located
+         csv_filename (str): Name of the CSV file containing dependencies
+         strength (float): Strength parameter for the LLM model
+         temperature (float): Temperature parameter for the LLM model
+         verbose (bool, optional): Whether to print detailed information. Defaults to False.
+
+     Returns:
+         Tuple[str, str, float, str]: Tuple containing:
+             - output_prompt: The prompt with dependencies inserted
+             - csv_output: Complete CSV output from auto_include
+             - total_cost: Total cost of running the function
+             - model_name: Name of the LLM model used
+     """
+     try:
+         # Step 1: Load the prompt template
+         insert_includes_prompt = load_prompt_template("insert_includes_LLM")
+         if not insert_includes_prompt:
+             raise ValueError("Failed to load insert_includes_LLM.prompt template")
+
+         if verbose:
+             print("[blue]Loaded insert_includes_LLM prompt template[/blue]")
+
+         # Step 2: Read the CSV file
+         try:
+             with open(csv_filename, 'r') as file:
+                 csv_content = file.read()
+         except FileNotFoundError:
+             if verbose:
+                 print(f"[yellow]CSV file {csv_filename} not found. Creating empty CSV.[/yellow]")
+             csv_content = "full_path,file_summary,date\n"
+             Path(csv_filename).write_text(csv_content)
+
+         # Step 3: Preprocess the prompt template
+         processed_prompt = preprocess(
+             insert_includes_prompt,
+             recursive=False,
+             double_curly_brackets=False
+         )
+
+         if verbose:
+             print("[blue]Preprocessed prompt template[/blue]")
+
+         # Step 4: Get dependencies using auto_include
+         dependencies, csv_output, auto_include_cost, auto_include_model = auto_include(
+             input_prompt=input_prompt,
+             directory_path=directory_path,
+             csv_file=csv_content,
+             strength=strength,
+             temperature=temperature,
+             verbose=verbose
+         )
+
+         if verbose:
+             print("[blue]Retrieved dependencies using auto_include[/blue]")
+             print(f"Dependencies found: {dependencies}")
+
+         # Step 5: Run llm_invoke with the insert includes prompt
+         response = llm_invoke(
+             prompt=processed_prompt,
+             input_json={
+                 "actual_prompt_to_update": input_prompt,
+                 "actual_dependencies_to_insert": dependencies
+             },
+             strength=strength,
+             temperature=temperature,
+             verbose=verbose,
+             output_pydantic=InsertIncludesOutput
+         )
+
+         if not response or 'result' not in response:
+             raise ValueError("Failed to get valid response from LLM model")
+
+         result: InsertIncludesOutput = response['result']
+         model_name = response['model_name']
+         total_cost = response['cost'] + auto_include_cost
+
+         if verbose:
+             print("[green]Successfully inserted includes into prompt[/green]")
+             print(f"Total cost: ${total_cost:.6f}")
+             print(f"Model used: {model_name}")
+
+         return (
+             result.output_prompt,
+             csv_output,
+             total_cost,
+             model_name
+         )
+
+     except Exception as e:
+         print(f"[red]Error in insert_includes: {str(e)}[/red]")
+         raise
+
+ def main():
+     """Example usage of the insert_includes function."""
+     # Example input
+     input_prompt = """% Generate a Python function that processes data
+     <include>data_processing.py</include>
+     """
+     directory_path = "./src"
+     csv_filename = "dependencies.csv"
+     strength = 0.7
+     temperature = 0.5
+
+     try:
+         output_prompt, csv_output, total_cost, model_name = insert_includes(
+             input_prompt=input_prompt,
+             directory_path=directory_path,
+             csv_filename=csv_filename,
+             strength=strength,
+             temperature=temperature,
+             verbose=True
+         )
+
+         print("\n[bold green]Results:[/bold green]")
+         print(f"[white]Output Prompt:[/white]\n{output_prompt}")
+         print(f"\n[white]CSV Output:[/white]\n{csv_output}")
+         print(f"[white]Total Cost: ${total_cost:.6f}[/white]")
+         print(f"[white]Model Used: {model_name}[/white]")
+
+     except Exception as e:
+         print(f"[red]Error in main: {str(e)}[/red]")
+
+ if __name__ == "__main__":
+     main()
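The dependency CSV that insert_includes() reads (and creates when missing) carries one summarized file per row under the header written above. A hypothetical row, with a made-up summary and date (the real values are produced elsewhere in the package and are not shown in this diff), might look like:

full_path,file_summary,date
./src/data_processing.py,"Helpers for loading and normalizing input records",2024-11-02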
pdd/llm_invoke.py ADDED
@@ -0,0 +1,304 @@
+ # llm_invoke.py
+
+ import os
+ import csv
+ import json
+ from pydantic import BaseModel, Field
+ from rich import print as rprint
+
+ from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
+ from langchain_community.cache import SQLiteCache
+ from langchain.globals import set_llm_cache
+ from langchain_core.output_parsers import PydanticOutputParser, StrOutputParser
+ from langchain_core.runnables import RunnablePassthrough, ConfigurableField
+
+ from langchain_openai import AzureChatOpenAI
+ from langchain_fireworks import Fireworks
+ from langchain_anthropic import ChatAnthropic
+ from langchain_openai import ChatOpenAI  # Chatbot and conversational tasks
+ from langchain_openai import OpenAI  # General language tasks
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_groq import ChatGroq
+ from langchain_together import Together
+ from langchain_ollama.llms import OllamaLLM
+
+ from langchain.callbacks.base import BaseCallbackHandler
+ from langchain.schema import LLMResult
+
+ # import logging
+
+ # Configure logging to output to the console
+ # logging.basicConfig(level=logging.DEBUG)
+
+ # Get the LangSmith logger
+ # langsmith_logger = logging.getLogger("langsmith")
+
+ # Set its logging level to DEBUG
+ # langsmith_logger.setLevel(logging.DEBUG)
+
+ class CompletionStatusHandler(BaseCallbackHandler):
+     def __init__(self):
+         self.is_complete = False
+         self.finish_reason = None
+         self.input_tokens = None
+         self.output_tokens = None
+
+     def on_llm_end(self, response: LLMResult, **kwargs) -> None:
+         self.is_complete = True
+         if response.generations and response.generations[0]:
+             generation = response.generations[0][0]
+             self.finish_reason = generation.generation_info.get('finish_reason', "").lower()
+
+             # Extract token usage
+             if hasattr(generation.message, 'usage_metadata'):
+                 usage_metadata = generation.message.usage_metadata
+                 self.input_tokens = usage_metadata.get('input_tokens')
+                 self.output_tokens = usage_metadata.get('output_tokens')
+
+
+ class ModelInfo:
+     def __init__(self, provider, model, input_cost, output_cost, coding_arena_elo,
+                  base_url, api_key, counter, encoder, max_tokens, max_completion_tokens,
+                  structured_output):
+         self.provider = provider.strip()
+         self.model = model.strip()
+         self.input_cost = float(input_cost) if input_cost else 0.0
+         self.output_cost = float(output_cost) if output_cost else 0.0
+         self.average_cost = (self.input_cost + self.output_cost) / 2
+         self.coding_arena_elo = float(coding_arena_elo) if coding_arena_elo else 0.0
+         self.base_url = base_url.strip() if base_url else None
+         self.api_key = api_key.strip() if api_key else None
+         self.counter = counter.strip() if counter else None
+         self.encoder = encoder.strip() if encoder else None
+         self.max_tokens = int(max_tokens) if max_tokens else None
+         self.max_completion_tokens = int(
+             max_completion_tokens) if max_completion_tokens else None
+         self.structured_output = structured_output.lower(
+         ) == 'true' if structured_output else False
+
+
+ def load_models():
+     PDD_PATH = os.environ.get('PDD_PATH', '.')
+     # Assume that llm_model.csv is in PDD_PATH/data
+     models_file = os.path.join(PDD_PATH, 'data', 'llm_model.csv')
+     models = []
+     try:
+         with open(models_file, newline='') as csvfile:
+             reader = csv.DictReader(csvfile)
+             for row in reader:
+                 model_info = ModelInfo(
+                     provider=row['provider'],
+                     model=row['model'],
+                     input_cost=row['input'],
+                     output_cost=row['output'],
+                     coding_arena_elo=row['coding_arena_elo'],
+                     base_url=row['base_url'],
+                     api_key=row['api_key'],
+                     counter=row['counter'],
+                     encoder=row['encoder'],
+                     max_tokens=row['max_tokens'],
+                     max_completion_tokens=row['max_completion_tokens'],
+                     structured_output=row['structured_output']
+                 )
+                 models.append(model_info)
+     except FileNotFoundError:
+         raise FileNotFoundError(f"llm_model.csv not found at {models_file}")
+     return models
+
+
+ def select_model(strength, models, base_model_name):
+     # Get the base model
+     base_model = None
+     for model in models:
+         if model.model == base_model_name:
+             base_model = model
+             break
+     if not base_model:
+         raise ValueError(f"Base model {base_model_name} not found in the models list.")
+
+     if strength == 0.5:
+         return base_model
+     elif strength < 0.5:
+         # Models cheaper than or equal to the base model
+         cheaper_models = [
+             model for model in models if model.average_cost <= base_model.average_cost]
+         # Sort models by average_cost ascending
+         cheaper_models.sort(key=lambda m: m.average_cost)
+         if not cheaper_models:
+             return base_model
+         # Interpolate between cheapest model and base model
+         cheapest_model = cheaper_models[0]
+         cost_range = base_model.average_cost - cheapest_model.average_cost
+         target_cost = cheapest_model.average_cost + (strength / 0.5) * cost_range
+         # Find the model with closest average cost to target_cost
+         selected_model = min(
+             cheaper_models, key=lambda m: abs(m.average_cost - target_cost))
+         return selected_model
+     else:
+         # strength > 0.5
+         # Models better than or equal to the base model
+         better_models = [
+             model for model in models if model.coding_arena_elo >= base_model.coding_arena_elo]
+         # Sort models by coding_arena_elo ascending
+         better_models.sort(key=lambda m: m.coding_arena_elo)
+         if not better_models:
+             return base_model
+         # Interpolate between base model and highest ELO model
+         highest_elo_model = better_models[-1]
+         elo_range = highest_elo_model.coding_arena_elo - base_model.coding_arena_elo
+         target_elo = base_model.coding_arena_elo + \
+             ((strength - 0.5) / 0.5) * elo_range
+         # Find the model with closest ELO to target_elo
+         selected_model = min(
+             better_models, key=lambda m: abs(m.coding_arena_elo - target_elo))
+         return selected_model
+
+
+ def create_llm_instance(selected_model, temperature, handler):
+     provider = selected_model.provider.lower()
+     model_name = selected_model.model
+     base_url = selected_model.base_url
+     api_key_name = selected_model.api_key
+     max_completion_tokens = selected_model.max_completion_tokens
+     max_tokens = selected_model.max_tokens
+
+     # Retrieve API key from environment variable if needed
+     api_key = os.environ.get(api_key_name) if api_key_name else None
+
+     # Initialize the appropriate LLM class
+     if provider == 'openai':
+         if base_url:
+             llm = ChatOpenAI(model=model_name, temperature=temperature,
+                              openai_api_key=api_key, callbacks=[handler], openai_api_base=base_url)
+         else:
+             if model_name[0] == 'o':
+                 llm = ChatOpenAI(model=model_name, temperature=temperature,
+                                  openai_api_key=api_key, callbacks=[handler],
+                                  model_kwargs={'reasoning_effort': 'high'})
+             else:
+                 llm = ChatOpenAI(model=model_name, temperature=temperature,
+                                  openai_api_key=api_key, callbacks=[handler])
+     elif provider == 'anthropic':
+         llm = ChatAnthropic(model=model_name, temperature=temperature,
+                             callbacks=[handler])
+     elif provider == 'google':
+         llm = ChatGoogleGenerativeAI(
+             model=model_name, temperature=temperature, callbacks=[handler])
+     elif provider == 'ollama':
+         llm = OllamaLLM(
+             model=model_name, temperature=temperature, callbacks=[handler])
+     elif provider == 'azure':
+         llm = AzureChatOpenAI(
+             model=model_name, temperature=temperature, callbacks=[handler])
+     elif provider == 'fireworks':
+         llm = Fireworks(model=model_name, temperature=temperature,
+                         callbacks=[handler])
+     elif provider == 'together':
+         llm = Together(model=model_name, temperature=temperature,
+                        callbacks=[handler])
+     elif provider == 'groq':
+         llm = ChatGroq(model_name=model_name, temperature=temperature,
+                        callbacks=[handler])
+     else:
+         raise ValueError(f"Unsupported provider: {selected_model.provider}")
+     if max_completion_tokens:
+         llm.model_kwargs = {"max_completion_tokens": max_completion_tokens}
+     else:
+         # Set max tokens if available
+         if max_tokens:
+             if provider == 'google':
+                 llm.max_output_tokens = max_tokens
+             else:
+                 llm.max_tokens = max_tokens
+     return llm
+
+
+ def calculate_cost(handler, selected_model):
+     input_tokens = handler.input_tokens or 0
+     output_tokens = handler.output_tokens or 0
+     input_cost_per_million = selected_model.input_cost
+     output_cost_per_million = selected_model.output_cost
+     # Cost is (tokens / 1_000_000) * cost_per_million
+     total_cost = (input_tokens / 1_000_000) * input_cost_per_million + \
+         (output_tokens / 1_000_000) * output_cost_per_million
+     return total_cost
+
+
+ def llm_invoke(prompt, input_json, strength, temperature, verbose=False, output_pydantic=None):
+     # Validate inputs
+     if not prompt:
+         raise ValueError("Prompt is required.")
+     if input_json is None:
+         raise ValueError("Input JSON is required.")
+     if not isinstance(input_json, dict):
+         raise ValueError("Input JSON must be a dictionary.")
+
+     # Set up cache
+     set_llm_cache(SQLiteCache(database_path=".langchain.db"))
+
+     # Get default model
+     base_model_name = os.environ.get('PDD_MODEL_DEFAULT', 'gpt-4o-mini')
+
+     # Load models
+     models = load_models()
+
+     # Select model
+     selected_model = select_model(strength, models, base_model_name)
+
+     # Create the prompt template
+     try:
+         prompt_template = PromptTemplate.from_template(prompt)
+     except Exception as e:
+         raise ValueError(f"Invalid prompt template: {str(e)}")
+
+     # Create a handler to capture token counts
+     handler = CompletionStatusHandler()
+
+     # Prepare LLM instance
+     llm = create_llm_instance(selected_model, temperature, handler)
+
+     # Handle structured output if output_pydantic is provided
+     if output_pydantic:
+         pydantic_model = output_pydantic
+         parser = PydanticOutputParser(pydantic_object=pydantic_model)
+         # Handle models that support structured output
+         if selected_model.structured_output:
+             llm = llm.with_structured_output(pydantic_model)
+             chain = prompt_template | llm
+         else:
+             # Use parser after the LLM
+             chain = prompt_template | llm | parser
+     else:
+         # Output is a string
+         chain = prompt_template | llm | StrOutputParser()
+
+     # Run the chain
+     try:
+         result = chain.invoke(input_json)
+     except Exception as e:
+         raise RuntimeError(f"Error during LLM invocation: {str(e)}")
+
+     # Calculate cost
+     cost = calculate_cost(handler, selected_model)
+
+     # If verbose, print information
+     if verbose:
+         rprint(f"Selected model: {selected_model.model}")
+         rprint(
+             f"Per input token cost: ${selected_model.input_cost} per million tokens")
+         rprint(
+             f"Per output token cost: ${selected_model.output_cost} per million tokens")
+         rprint(f"Number of input tokens: {handler.input_tokens}")
+         rprint(f"Number of output tokens: {handler.output_tokens}")
+         rprint(f"Cost of invoke run: ${cost}")
+         rprint(f"Strength used: {strength}")
+         rprint(f"Temperature used: {temperature}")
+         try:
+             rprint(f"Input JSON: {input_json}")
+         except:
+             print(f"Input JSON: {input_json}")
+         if output_pydantic:
+             rprint(f"Output Pydantic: {output_pydantic}")
+         rprint(f"Result: {result}")
+
+     return {'result': result, 'cost': cost, 'model_name': selected_model.model}
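llm_invoke() is the shared entry point the other modules call: strength 0.5 selects the base model named by PDD_MODEL_DEFAULT, lower values interpolate toward the cheapest model by average cost, and higher values interpolate toward the highest-Elo model. A sketch of a direct call, assuming PDD_PATH points at an installation containing data/llm_model.csv and the relevant provider API key is set; the prompt and schema below are made up:

from pydantic import BaseModel, Field
from pdd.llm_invoke import llm_invoke

class Greeting(BaseModel):
    text: str = Field(description="A one-line greeting")

response = llm_invoke(
    prompt="Write a short greeting for {name} in {language}.",
    input_json={"name": "Ada", "language": "French"},
    strength=0.5,               # 0.5 keeps the base model
    temperature=0.0,
    output_pydantic=Greeting,   # parsed via structured output or PydanticOutputParser
)
print(response["model_name"], f"${response['cost']:.6f}")
print(response["result"])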
pdd/load_prompt_template.py ADDED
@@ -0,0 +1,59 @@
+ from pathlib import Path
+ import os
+ from typing import Optional
+ import sys
+ from rich import print
+
+ def print_formatted(message: str) -> None:
+     """Print message with raw formatting tags for testing compatibility."""
+     print(message)
+
+ def load_prompt_template(prompt_name: str) -> Optional[str]:
+     """
+     Load a prompt template from a file.
+
+     Args:
+         prompt_name (str): Name of the prompt file to load (without extension)
+
+     Returns:
+         str: The prompt template text
+     """
+     # Type checking
+     if not isinstance(prompt_name, str):
+         print_formatted("[red]Unexpected error loading prompt template[/red]")
+         return None
+
+     # Step 1: Get project path from environment variable
+     project_path = os.getenv('PDD_PATH')
+     if not project_path:
+         print_formatted("[red]PDD_PATH environment variable is not set[/red]")
+         return None
+
+     # Construct the full path to the prompt file
+     prompt_path = Path(project_path) / 'prompts' / f"{prompt_name}.prompt"
+
+     # Step 2: Load and return the prompt template
+     if not prompt_path.exists():
+         print_formatted(f"[red]Prompt file not found: {prompt_path}[/red]")
+         return None
+
+     try:
+         with open(prompt_path, 'r', encoding='utf-8') as file:
+             prompt_template = file.read()
+         print_formatted(f"[green]Successfully loaded prompt: {prompt_name}[/green]")
+         return prompt_template
+
+     except IOError as e:
+         print_formatted(f"[red]Error reading prompt file {prompt_name}: {str(e)}[/red]")
+         return None
+
+     except Exception as e:
+         print_formatted(f"[red]Unexpected error loading prompt template: {str(e)}[/red]")
+         return None
+
+ if __name__ == "__main__":
+     # Example usage
+     prompt = load_prompt_template("example_prompt")
+     if prompt:
+         print_formatted("[blue]Loaded prompt template:[/blue]")
+         print_formatted(prompt)