pdd-cli 0.0.2 (pdd_cli-0.0.2-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pdd-cli might be problematic.
- pdd/__init__.py +0 -0
- pdd/auto_deps_main.py +98 -0
- pdd/auto_include.py +175 -0
- pdd/auto_update.py +73 -0
- pdd/bug_main.py +99 -0
- pdd/bug_to_unit_test.py +159 -0
- pdd/change.py +141 -0
- pdd/change_main.py +240 -0
- pdd/cli.py +607 -0
- pdd/cmd_test_main.py +155 -0
- pdd/code_generator.py +117 -0
- pdd/code_generator_main.py +66 -0
- pdd/comment_line.py +35 -0
- pdd/conflicts_in_prompts.py +143 -0
- pdd/conflicts_main.py +90 -0
- pdd/construct_paths.py +251 -0
- pdd/context_generator.py +133 -0
- pdd/context_generator_main.py +73 -0
- pdd/continue_generation.py +140 -0
- pdd/crash_main.py +127 -0
- pdd/data/language_format.csv +61 -0
- pdd/data/llm_model.csv +15 -0
- pdd/detect_change.py +142 -0
- pdd/detect_change_main.py +100 -0
- pdd/find_section.py +28 -0
- pdd/fix_code_loop.py +212 -0
- pdd/fix_code_module_errors.py +143 -0
- pdd/fix_error_loop.py +216 -0
- pdd/fix_errors_from_unit_tests.py +240 -0
- pdd/fix_main.py +138 -0
- pdd/generate_output_paths.py +194 -0
- pdd/generate_test.py +140 -0
- pdd/get_comment.py +55 -0
- pdd/get_extension.py +52 -0
- pdd/get_language.py +41 -0
- pdd/git_update.py +84 -0
- pdd/increase_tests.py +93 -0
- pdd/insert_includes.py +150 -0
- pdd/llm_invoke.py +304 -0
- pdd/load_prompt_template.py +59 -0
- pdd/pdd_completion.fish +72 -0
- pdd/pdd_completion.sh +141 -0
- pdd/pdd_completion.zsh +418 -0
- pdd/postprocess.py +121 -0
- pdd/postprocess_0.py +52 -0
- pdd/preprocess.py +199 -0
- pdd/preprocess_main.py +72 -0
- pdd/process_csv_change.py +182 -0
- pdd/prompts/auto_include_LLM.prompt +230 -0
- pdd/prompts/bug_to_unit_test_LLM.prompt +17 -0
- pdd/prompts/change_LLM.prompt +34 -0
- pdd/prompts/conflict_LLM.prompt +23 -0
- pdd/prompts/continue_generation_LLM.prompt +3 -0
- pdd/prompts/detect_change_LLM.prompt +65 -0
- pdd/prompts/example_generator_LLM.prompt +10 -0
- pdd/prompts/extract_auto_include_LLM.prompt +6 -0
- pdd/prompts/extract_code_LLM.prompt +22 -0
- pdd/prompts/extract_conflict_LLM.prompt +19 -0
- pdd/prompts/extract_detect_change_LLM.prompt +19 -0
- pdd/prompts/extract_program_code_fix_LLM.prompt +16 -0
- pdd/prompts/extract_prompt_change_LLM.prompt +7 -0
- pdd/prompts/extract_prompt_split_LLM.prompt +9 -0
- pdd/prompts/extract_prompt_update_LLM.prompt +8 -0
- pdd/prompts/extract_promptline_LLM.prompt +11 -0
- pdd/prompts/extract_unit_code_fix_LLM.prompt +332 -0
- pdd/prompts/extract_xml_LLM.prompt +7 -0
- pdd/prompts/fix_code_module_errors_LLM.prompt +17 -0
- pdd/prompts/fix_errors_from_unit_tests_LLM.prompt +62 -0
- pdd/prompts/generate_test_LLM.prompt +12 -0
- pdd/prompts/increase_tests_LLM.prompt +16 -0
- pdd/prompts/insert_includes_LLM.prompt +30 -0
- pdd/prompts/split_LLM.prompt +94 -0
- pdd/prompts/summarize_file_LLM.prompt +11 -0
- pdd/prompts/trace_LLM.prompt +30 -0
- pdd/prompts/trim_results_LLM.prompt +83 -0
- pdd/prompts/trim_results_start_LLM.prompt +45 -0
- pdd/prompts/unfinished_prompt_LLM.prompt +18 -0
- pdd/prompts/update_prompt_LLM.prompt +19 -0
- pdd/prompts/xml_convertor_LLM.prompt +54 -0
- pdd/split.py +119 -0
- pdd/split_main.py +103 -0
- pdd/summarize_directory.py +212 -0
- pdd/trace.py +135 -0
- pdd/trace_main.py +108 -0
- pdd/track_cost.py +102 -0
- pdd/unfinished_prompt.py +114 -0
- pdd/update_main.py +96 -0
- pdd/update_prompt.py +115 -0
- pdd/xml_tagger.py +122 -0
- pdd_cli-0.0.2.dist-info/LICENSE +7 -0
- pdd_cli-0.0.2.dist-info/METADATA +225 -0
- pdd_cli-0.0.2.dist-info/RECORD +95 -0
- pdd_cli-0.0.2.dist-info/WHEEL +5 -0
- pdd_cli-0.0.2.dist-info/entry_points.txt +2 -0
- pdd_cli-0.0.2.dist-info/top_level.txt +1 -0
pdd/git_update.py
ADDED
@@ -0,0 +1,84 @@
+import os
+from typing import Tuple, Optional
+from rich import print
+from rich.console import Console
+from rich.panel import Panel
+from .update_prompt import update_prompt
+import git
+
+console = Console()
+
+def git_update(
+    input_prompt: str,
+    modified_code_file: str,
+    strength: float,
+    temperature: float,
+    verbose: bool = False
+) -> Tuple[Optional[str], float, str]:
+    """
+    Read in modified code, restore the prior checked-in version from GitHub,
+    update the prompt, write back the modified code, and return outputs.
+
+    Args:
+        input_prompt (str): The prompt that generated the original code.
+        modified_code_file (str): Filepath of the modified code.
+        strength (float): Strength parameter for the LLM model.
+        temperature (float): Temperature parameter for the LLM model.
+
+    Returns:
+        Tuple[Optional[str], float, str]: Modified prompt, total cost, and model name.
+    """
+    try:
+        # Check if inputs are valid
+        if not input_prompt or not modified_code_file:
+            raise ValueError("Input prompt and modified code file path are required.")
+
+        if not os.path.exists(modified_code_file):
+            raise FileNotFoundError(f"Modified code file not found: {modified_code_file}")
+
+        # Initialize git repository
+        repo = git.Repo(search_parent_directories=True)
+
+        # Get the file's relative path to the repo root
+        repo_root = repo.git.rev_parse("--show-toplevel")
+        relative_path = os.path.relpath(modified_code_file, repo_root)
+
+        # Read the modified code
+        with open(modified_code_file, 'r') as file:
+            modified_code = file.read()
+
+        # Restore the prior checked-in version
+        repo.git.checkout('HEAD', '--', relative_path)
+
+        # Read the original input code
+        with open(modified_code_file, 'r') as file:
+            original_input_code = file.read()
+
+        # Call update_prompt function
+        modified_prompt, total_cost, model_name = update_prompt(
+            input_prompt=input_prompt,
+            input_code=original_input_code,
+            modified_code=modified_code,
+            strength=strength,
+            temperature=temperature,
+            verbose=verbose
+        )
+
+        # Write back the modified code
+        with open(modified_code_file, 'w') as file:
+            file.write(modified_code)
+
+
+        # Pretty print the results
+        console.print(Panel.fit(
+            f"[bold green]Success:[/bold green]\n"
+            f"Modified prompt: {modified_prompt}\n"
+            f"Total cost: ${total_cost:.6f}\n"
+            f"Model name: {model_name}"
+        ))
+
+        return modified_prompt, total_cost, model_name
+
+    except Exception as e:
+        console.print(Panel(f"[bold red]Error:[/bold red] {str(e)}", title="Error", expand=False))
+        return None, 0.0, ""
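git_update expects to run inside a git checkout: it reads the hand-edited file, restores the HEAD version of that file via GitPython, asks update_prompt to reconcile the prompt with the edits, and then writes the edited code back. A minimal caller might look like the sketch below; the file paths, prompt name, and parameter values are illustrative only and are not part of the package.

# Hypothetical caller of pdd.git_update (paths and values are illustrative).
from pdd.git_update import git_update

# The prompt that originally produced calculator.py (hypothetical files).
with open("prompts/calculator_python.prompt", "r", encoding="utf-8") as f:
    original_prompt = f.read()

# Reconcile the hand-edited calculator.py with its prompt; must run inside a git repo.
modified_prompt, cost, model = git_update(
    input_prompt=original_prompt,
    modified_code_file="src/calculator.py",
    strength=0.5,        # 0.5 selects the default model in llm_invoke's selection logic
    temperature=0.0,
    verbose=True,
)

if modified_prompt is not None:
    # Persist the updated prompt alongside the (restored) modified code.
    with open("prompts/calculator_python.prompt", "w", encoding="utf-8") as f:
        f.write(modified_prompt)
    print(f"Prompt updated for ${cost:.6f} using {model}")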
pdd/increase_tests.py
ADDED
@@ -0,0 +1,93 @@
+from typing import Tuple
+from rich.console import Console
+
+from .load_prompt_template import load_prompt_template
+from .llm_invoke import llm_invoke
+from .postprocess import postprocess
+
+def increase_tests(
+    existing_unit_tests: str,
+    coverage_report: str,
+    code: str,
+    prompt_that_generated_code: str,
+    language: str = "python",
+    strength: float = 0.5,
+    temperature: float = 0.0,
+    verbose: bool = False
+) -> Tuple[str, float, str]:
+    """
+    Generate additional unit tests to increase code coverage.
+
+    Args:
+        existing_unit_tests (str): Current unit tests for the code
+        coverage_report (str): Coverage report for the code
+        code (str): Code under test
+        prompt_that_generated_code (str): Original prompt used to generate the code
+        language (str, optional): Programming language. Defaults to "python".
+        strength (float, optional): LLM model strength. Defaults to 0.5.
+        temperature (float, optional): LLM model temperature. Defaults to 0.0.
+        verbose (bool, optional): Verbose output flag. Defaults to False.
+
+    Returns:
+        Tuple containing:
+        - Increased test function (str)
+        - Total cost of generation (float)
+        - Model name used (str)
+    """
+    console = Console()
+
+    # Validate inputs
+    if not all([existing_unit_tests, coverage_report, code, prompt_that_generated_code]):
+        raise ValueError("All input parameters must be non-empty strings")
+
+    # Validate strength and temperature
+    if not (0 <= strength <= 1):
+        raise ValueError("Strength must be between 0 and 1")
+    if not (0 <= temperature <= 1):
+        raise ValueError("Temperature must be between 0 and 1")
+
+    try:
+        # Step 1: Load prompt template
+        prompt_name = "increase_tests_LLM"
+        prompt_template = load_prompt_template(prompt_name)
+
+        if verbose:
+            console.print(f"[blue]Loaded Prompt Template:[/blue]\n{prompt_template}")
+
+        # Step 2: Prepare input for LLM invoke
+        input_json = {
+            "existing_unit_tests": existing_unit_tests,
+            "coverage_report": coverage_report,
+            "code": code,
+            "prompt_that_generated_code": prompt_that_generated_code,
+            "language": language
+        }
+
+        # Invoke LLM with the prompt
+        llm_response = llm_invoke(
+            prompt=prompt_template,
+            input_json=input_json,
+            strength=strength,
+            temperature=temperature,
+            verbose=verbose
+        )
+
+        # Step 3: Postprocess the result
+        increase_test_function, total_cost, model_name = postprocess(
+            llm_response['result'],
+            language,
+            0.89,  # Same strength as LLM invoke
+            temperature,
+            verbose
+        )
+
+        if verbose:
+            console.print(f"[green]Generated Test Function:[/green]\n{increase_test_function}")
+            console.print(f"[blue]Total Cost: ${total_cost:.6f}[/blue]")
+            console.print(f"[blue]Model Used: {model_name}[/blue]")
+
+        return increase_test_function, total_cost, model_name
+
+    except Exception as e:
+        console.print(f"[red]Error in increase_tests: {str(e)}[/red]")
+        raise
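increase_tests feeds the current tests, a coverage report, the code under test, and the original prompt into an LLM and returns a new test function plus its cost. A minimal sketch of how it might be driven after a coverage run is shown below; the file names are hypothetical and the coverage report is assumed to already exist as plain text.

# Hypothetical driver for pdd.increase_tests (paths are illustrative).
from pathlib import Path
from pdd.increase_tests import increase_tests

existing_tests = Path("tests/test_calculator.py").read_text()
coverage_report = Path("coverage.txt").read_text()   # e.g. saved output of `coverage report -m`
code = Path("src/calculator.py").read_text()
prompt = Path("prompts/calculator_python.prompt").read_text()

new_tests, cost, model = increase_tests(
    existing_unit_tests=existing_tests,
    coverage_report=coverage_report,
    code=code,
    prompt_that_generated_code=prompt,
    language="python",
    strength=0.5,
    temperature=0.0,
)

# Append the generated tests and re-run coverage to verify the gain.
Path("tests/test_calculator.py").write_text(existing_tests + "\n\n" + new_tests)
print(f"Added tests with {model} for ${cost:.6f}")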
pdd/insert_includes.py
ADDED
@@ -0,0 +1,150 @@
+from typing import Tuple
+from pathlib import Path
+from rich import print
+from pydantic import BaseModel, Field
+
+from .llm_invoke import llm_invoke
+from .load_prompt_template import load_prompt_template
+from .auto_include import auto_include
+from .preprocess import preprocess
+
+class InsertIncludesOutput(BaseModel):
+    output_prompt: str = Field(description="The prompt with dependencies inserted")
+
+def insert_includes(
+    input_prompt: str,
+    directory_path: str,
+    csv_filename: str,
+    strength: float,
+    temperature: float,
+    verbose: bool = False
+) -> Tuple[str, str, float, str]:
+    """
+    Determine needed dependencies and insert them into a prompt.
+
+    Args:
+        input_prompt (str): The prompt to process
+        directory_path (str): Directory path where the prompt file is located
+        csv_filename (str): Name of the CSV file containing dependencies
+        strength (float): Strength parameter for the LLM model
+        temperature (float): Temperature parameter for the LLM model
+        verbose (bool, optional): Whether to print detailed information. Defaults to False.
+
+    Returns:
+        Tuple[str, str, float, str]: Tuple containing:
+            - output_prompt: The prompt with dependencies inserted
+            - csv_output: Complete CSV output from auto_include
+            - total_cost: Total cost of running the function
+            - model_name: Name of the LLM model used
+    """
+    try:
+        # Step 1: Load the prompt template
+        insert_includes_prompt = load_prompt_template("insert_includes_LLM")
+        if not insert_includes_prompt:
+            raise ValueError("Failed to load insert_includes_LLM.prompt template")
+
+        if verbose:
+            print("[blue]Loaded insert_includes_LLM prompt template[/blue]")
+
+        # Step 2: Read the CSV file
+        try:
+            with open(csv_filename, 'r') as file:
+                csv_content = file.read()
+        except FileNotFoundError:
+            if verbose:
+                print(f"[yellow]CSV file {csv_filename} not found. Creating empty CSV.[/yellow]")
+            csv_content = "full_path,file_summary,date\n"
+            Path(csv_filename).write_text(csv_content)
+
+        # Step 3: Preprocess the prompt template
+        processed_prompt = preprocess(
+            insert_includes_prompt,
+            recursive=False,
+            double_curly_brackets=False
+        )
+
+        if verbose:
+            print("[blue]Preprocessed prompt template[/blue]")
+
+        # Step 4: Get dependencies using auto_include
+        dependencies, csv_output, auto_include_cost, auto_include_model = auto_include(
+            input_prompt=input_prompt,
+            directory_path=directory_path,
+            csv_file=csv_content,
+            strength=strength,
+            temperature=temperature,
+            verbose=verbose
+        )
+
+        if verbose:
+            print("[blue]Retrieved dependencies using auto_include[/blue]")
+            print(f"Dependencies found: {dependencies}")
+
+        # Step 5: Run llm_invoke with the insert includes prompt
+        response = llm_invoke(
+            prompt=processed_prompt,
+            input_json={
+                "actual_prompt_to_update": input_prompt,
+                "actual_dependencies_to_insert": dependencies
+            },
+            strength=strength,
+            temperature=temperature,
+            verbose=verbose,
+            output_pydantic=InsertIncludesOutput
+        )
+
+        if not response or 'result' not in response:
+            raise ValueError("Failed to get valid response from LLM model")
+
+        result: InsertIncludesOutput = response['result']
+        model_name = response['model_name']
+        total_cost = response['cost'] + auto_include_cost
+
+        if verbose:
+            print("[green]Successfully inserted includes into prompt[/green]")
+            print(f"Total cost: ${total_cost:.6f}")
+            print(f"Model used: {model_name}")
+
+        return (
+            result.output_prompt,
+            csv_output,
+            total_cost,
+            model_name
+        )
+
+    except Exception as e:
+        print(f"[red]Error in insert_includes: {str(e)}[/red]")
+        raise
+
+def main():
+    """Example usage of the insert_includes function."""
+    # Example input
+    input_prompt = """% Generate a Python function that processes data
+    <include>data_processing.py</include>
+    """
+    directory_path = "./src"
+    csv_filename = "dependencies.csv"
+    strength = 0.7
+    temperature = 0.5
+
+    try:
+        output_prompt, csv_output, total_cost, model_name = insert_includes(
+            input_prompt=input_prompt,
+            directory_path=directory_path,
+            csv_filename=csv_filename,
+            strength=strength,
+            temperature=temperature,
+            verbose=True
+        )
+
+        print("\n[bold green]Results:[/bold green]")
+        print(f"[white]Output Prompt:[/white]\n{output_prompt}")
+        print(f"\n[white]CSV Output:[/white]\n{csv_output}")
+        print(f"[white]Total Cost: ${total_cost:.6f}[/white]")
+        print(f"[white]Model Used: {model_name}[/white]")
+
+    except Exception as e:
+        print(f"[red]Error in main: {str(e)}[/red]")
+
+if __name__ == "__main__":
+    main()
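The dependency CSV consumed in step 2 above is created with the header "full_path,file_summary,date" when it does not exist; the rows themselves are produced elsewhere in the package (see summarize_directory.py in the file list). A hypothetical populated file might look like the following; the rows and date format are illustrative only, not taken from the package.

full_path,file_summary,date
./src/data_processing.py,"Helpers for loading and cleaning CSV data",2024-11-02
./src/db.py,"Thin wrapper around sqlite3 connections",2024-11-02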
pdd/llm_invoke.py
ADDED
@@ -0,0 +1,304 @@
+# llm_invoke.py
+
+import os
+import csv
+import json
+from pydantic import BaseModel, Field
+from rich import print as rprint
+
+from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
+from langchain_community.cache import SQLiteCache
+from langchain.globals import set_llm_cache
+from langchain_core.output_parsers import PydanticOutputParser, StrOutputParser
+from langchain_core.runnables import RunnablePassthrough, ConfigurableField
+
+from langchain_openai import AzureChatOpenAI
+from langchain_fireworks import Fireworks
+from langchain_anthropic import ChatAnthropic
+from langchain_openai import ChatOpenAI  # Chatbot and conversational tasks
+from langchain_openai import OpenAI  # General language tasks
+from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_groq import ChatGroq
+from langchain_together import Together
+from langchain_ollama.llms import OllamaLLM
+
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.schema import LLMResult
+
+# import logging
+
+# Configure logging to output to the console
+# logging.basicConfig(level=logging.DEBUG)
+
+# Get the LangSmith logger
+# langsmith_logger = logging.getLogger("langsmith")
+
+# Set its logging level to DEBUG
+# langsmith_logger.setLevel(logging.DEBUG)
+
+class CompletionStatusHandler(BaseCallbackHandler):
+    def __init__(self):
+        self.is_complete = False
+        self.finish_reason = None
+        self.input_tokens = None
+        self.output_tokens = None
+
+    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
+        self.is_complete = True
+        if response.generations and response.generations[0]:
+            generation = response.generations[0][0]
+            self.finish_reason = generation.generation_info.get('finish_reason', "").lower()
+
+            # Extract token usage
+            if hasattr(generation.message, 'usage_metadata'):
+                usage_metadata = generation.message.usage_metadata
+                self.input_tokens = usage_metadata.get('input_tokens')
+                self.output_tokens = usage_metadata.get('output_tokens')
+
+
+class ModelInfo:
+    def __init__(self, provider, model, input_cost, output_cost, coding_arena_elo,
+                 base_url, api_key, counter, encoder, max_tokens, max_completion_tokens,
+                 structured_output):
+        self.provider = provider.strip()
+        self.model = model.strip()
+        self.input_cost = float(input_cost) if input_cost else 0.0
+        self.output_cost = float(output_cost) if output_cost else 0.0
+        self.average_cost = (self.input_cost + self.output_cost) / 2
+        self.coding_arena_elo = float(coding_arena_elo) if coding_arena_elo else 0.0
+        self.base_url = base_url.strip() if base_url else None
+        self.api_key = api_key.strip() if api_key else None
+        self.counter = counter.strip() if counter else None
+        self.encoder = encoder.strip() if encoder else None
+        self.max_tokens = int(max_tokens) if max_tokens else None
+        self.max_completion_tokens = int(
+            max_completion_tokens) if max_completion_tokens else None
+        self.structured_output = structured_output.lower(
+        ) == 'true' if structured_output else False
+
+
+def load_models():
+    PDD_PATH = os.environ.get('PDD_PATH', '.')
+    # Assume that llm_model.csv is in PDD_PATH/data
+    models_file = os.path.join(PDD_PATH, 'data', 'llm_model.csv')
+    models = []
+    try:
+        with open(models_file, newline='') as csvfile:
+            reader = csv.DictReader(csvfile)
+            for row in reader:
+                model_info = ModelInfo(
+                    provider=row['provider'],
+                    model=row['model'],
+                    input_cost=row['input'],
+                    output_cost=row['output'],
+                    coding_arena_elo=row['coding_arena_elo'],
+                    base_url=row['base_url'],
+                    api_key=row['api_key'],
+                    counter=row['counter'],
+                    encoder=row['encoder'],
+                    max_tokens=row['max_tokens'],
+                    max_completion_tokens=row['max_completion_tokens'],
+                    structured_output=row['structured_output']
+                )
+                models.append(model_info)
+    except FileNotFoundError:
+        raise FileNotFoundError(f"llm_model.csv not found at {models_file}")
+    return models
+
+
+def select_model(strength, models, base_model_name):
+    # Get the base model
+    base_model = None
+    for model in models:
+        if model.model == base_model_name:
+            base_model = model
+            break
+    if not base_model:
+        raise ValueError(f"Base model {base_model_name} not found in the models list.")
+
+    if strength == 0.5:
+        return base_model
+    elif strength < 0.5:
+        # Models cheaper than or equal to the base model
+        cheaper_models = [
+            model for model in models if model.average_cost <= base_model.average_cost]
+        # Sort models by average_cost ascending
+        cheaper_models.sort(key=lambda m: m.average_cost)
+        if not cheaper_models:
+            return base_model
+        # Interpolate between cheapest model and base model
+        cheapest_model = cheaper_models[0]
+        cost_range = base_model.average_cost - cheapest_model.average_cost
+        target_cost = cheapest_model.average_cost + (strength / 0.5) * cost_range
+        # Find the model with closest average cost to target_cost
+        selected_model = min(
+            cheaper_models, key=lambda m: abs(m.average_cost - target_cost))
+        return selected_model
+    else:
+        # strength > 0.5
+        # Models better than or equal to the base model
+        better_models = [
+            model for model in models if model.coding_arena_elo >= base_model.coding_arena_elo]
+        # Sort models by coding_arena_elo ascending
+        better_models.sort(key=lambda m: m.coding_arena_elo)
+        if not better_models:
+            return base_model
+        # Interpolate between base model and highest ELO model
+        highest_elo_model = better_models[-1]
+        elo_range = highest_elo_model.coding_arena_elo - base_model.coding_arena_elo
+        target_elo = base_model.coding_arena_elo + \
+            ((strength - 0.5) / 0.5) * elo_range
+        # Find the model with closest ELO to target_elo
+        selected_model = min(
+            better_models, key=lambda m: abs(m.coding_arena_elo - target_elo))
+        return selected_model
+
+
+def create_llm_instance(selected_model, temperature, handler):
+    provider = selected_model.provider.lower()
+    model_name = selected_model.model
+    base_url = selected_model.base_url
+    api_key_name = selected_model.api_key
+    max_completion_tokens = selected_model.max_completion_tokens
+    max_tokens = selected_model.max_tokens
+
+    # Retrieve API key from environment variable if needed
+    api_key = os.environ.get(api_key_name) if api_key_name else None
+
+    # Initialize the appropriate LLM class
+    if provider == 'openai':
+        if base_url:
+            llm = ChatOpenAI(model=model_name, temperature=temperature,
+                             openai_api_key=api_key, callbacks=[handler], openai_api_base=base_url)
+        else:
+            if model_name[0] == 'o':
+                llm = ChatOpenAI(model=model_name, temperature=temperature,
+                                 openai_api_key=api_key, callbacks=[handler],
+                                 model_kwargs={'reasoning_effort': 'high'})
+            else:
+                llm = ChatOpenAI(model=model_name, temperature=temperature,
+                                 openai_api_key=api_key, callbacks=[handler])
+    elif provider == 'anthropic':
+        llm = ChatAnthropic(model=model_name, temperature=temperature,
+                            callbacks=[handler])
+    elif provider == 'google':
+        llm = ChatGoogleGenerativeAI(
+            model=model_name, temperature=temperature, callbacks=[handler])
+    elif provider == 'ollama':
+        llm = OllamaLLM(
+            model=model_name, temperature=temperature, callbacks=[handler])
+    elif provider == 'azure':
+        llm = AzureChatOpenAI(
+            model=model_name, temperature=temperature, callbacks=[handler])
+    elif provider == 'fireworks':
+        llm = Fireworks(model=model_name, temperature=temperature,
+                        callbacks=[handler])
+    elif provider == 'together':
+        llm = Together(model=model_name, temperature=temperature,
+                       callbacks=[handler])
+    elif provider == 'groq':
+        llm = ChatGroq(model_name=model_name, temperature=temperature,
+                       callbacks=[handler])
+    else:
+        raise ValueError(f"Unsupported provider: {selected_model.provider}")
+    if max_completion_tokens:
+        llm.model_kwargs = {"max_completion_tokens": max_completion_tokens}
+    else:
+        # Set max tokens if available
+        if max_tokens:
+            if provider == 'google':
+                llm.max_output_tokens = max_tokens
+            else:
+                llm.max_tokens = max_tokens
+    return llm
+
+
+def calculate_cost(handler, selected_model):
+    input_tokens = handler.input_tokens or 0
+    output_tokens = handler.output_tokens or 0
+    input_cost_per_million = selected_model.input_cost
+    output_cost_per_million = selected_model.output_cost
+    # Cost is (tokens / 1_000_000) * cost_per_million
+    total_cost = (input_tokens / 1_000_000) * input_cost_per_million + \
+        (output_tokens / 1_000_000) * output_cost_per_million
+    return total_cost
+
+
+def llm_invoke(prompt, input_json, strength, temperature, verbose=False, output_pydantic=None):
+    # Validate inputs
+    if not prompt:
+        raise ValueError("Prompt is required.")
+    if input_json is None:
+        raise ValueError("Input JSON is required.")
+    if not isinstance(input_json, dict):
+        raise ValueError("Input JSON must be a dictionary.")
+
+    # Set up cache
+    set_llm_cache(SQLiteCache(database_path=".langchain.db"))
+
+    # Get default model
+    base_model_name = os.environ.get('PDD_MODEL_DEFAULT', 'gpt-4o-mini')
+
+    # Load models
+    models = load_models()
+
+    # Select model
+    selected_model = select_model(strength, models, base_model_name)
+
+    # Create the prompt template
+    try:
+        prompt_template = PromptTemplate.from_template(prompt)
+    except Exception as e:
+        raise ValueError(f"Invalid prompt template: {str(e)}")
+
+    # Create a handler to capture token counts
+    handler = CompletionStatusHandler()
+
+    # Prepare LLM instance
+    llm = create_llm_instance(selected_model, temperature, handler)
+
+    # Handle structured output if output_pydantic is provided
+    if output_pydantic:
+        pydantic_model = output_pydantic
+        parser = PydanticOutputParser(pydantic_object=pydantic_model)
+        # Handle models that support structured output
+        if selected_model.structured_output:
+            llm = llm.with_structured_output(pydantic_model)
+            chain = prompt_template | llm
+        else:
+            # Use parser after the LLM
+            chain = prompt_template | llm | parser
+    else:
+        # Output is a string
+        chain = prompt_template | llm | StrOutputParser()
+
+    # Run the chain
+    try:
+        result = chain.invoke(input_json)
+    except Exception as e:
+        raise RuntimeError(f"Error during LLM invocation: {str(e)}")
+
+    # Calculate cost
+    cost = calculate_cost(handler, selected_model)
+
+    # If verbose, print information
+    if verbose:
+        rprint(f"Selected model: {selected_model.model}")
+        rprint(
+            f"Per input token cost: ${selected_model.input_cost} per million tokens")
+        rprint(
+            f"Per output token cost: ${selected_model.output_cost} per million tokens")
+        rprint(f"Number of input tokens: {handler.input_tokens}")
+        rprint(f"Number of output tokens: {handler.output_tokens}")
+        rprint(f"Cost of invoke run: ${cost}")
+        rprint(f"Strength used: {strength}")
+        rprint(f"Temperature used: {temperature}")
+        try:
+            rprint(f"Input JSON: {input_json}")
+        except:
+            print(f"Input JSON: {input_json}")
+        if output_pydantic:
+            rprint(f"Output Pydantic: {output_pydantic}")
+        rprint(f"Result: {result}")
+
+    return {'result': result, 'cost': cost, 'model_name': selected_model.model}
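llm_invoke is the common entry point used by the other modules: it selects a model from data/llm_model.csv based on strength (0.5 returns the PDD_MODEL_DEFAULT base model, lower values interpolate toward cheaper models by average cost, higher values toward higher coding-arena ELO), builds a LangChain PromptTemplate from the prompt string, and returns a dict with the result, cost, and model name. The sketch below shows a structured-output call; the Pydantic schema and prompt text are illustrative, and PDD_PATH must point at the installed pdd data directory so the model CSV can be found.

# Illustrative call only; the schema and prompt text are not part of the package.
from typing import List
from pydantic import BaseModel, Field
from pdd.llm_invoke import llm_invoke

class Summary(BaseModel):
    title: str = Field(description="One-line title")
    bullet_points: List[str] = Field(description="Key points")

response = llm_invoke(
    prompt="Summarize the following code:\n{code}",   # {variable} placeholders, PromptTemplate style
    input_json={"code": "def add(a, b):\n    return a + b"},
    strength=0.5,        # 0.5 = the PDD_MODEL_DEFAULT base model
    temperature=0.0,
    verbose=True,
    output_pydantic=Summary,
)

summary: Summary = response["result"]
print(summary.title, response["cost"], response["model_name"])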
pdd/load_prompt_template.py
ADDED
@@ -0,0 +1,59 @@
+from pathlib import Path
+import os
+from typing import Optional
+import sys
+from rich import print
+
+def print_formatted(message: str) -> None:
+    """Print message with raw formatting tags for testing compatibility."""
+    print(message)
+
+def load_prompt_template(prompt_name: str) -> Optional[str]:
+    """
+    Load a prompt template from a file.
+
+    Args:
+        prompt_name (str): Name of the prompt file to load (without extension)
+
+    Returns:
+        str: The prompt template text
+    """
+    # Type checking
+    if not isinstance(prompt_name, str):
+        print_formatted("[red]Unexpected error loading prompt template[/red]")
+        return None
+
+    # Step 1: Get project path from environment variable
+    project_path = os.getenv('PDD_PATH')
+    if not project_path:
+        print_formatted("[red]PDD_PATH environment variable is not set[/red]")
+        return None
+
+    # Construct the full path to the prompt file
+    prompt_path = Path(project_path) / 'prompts' / f"{prompt_name}.prompt"
+
+    # Step 2: Load and return the prompt template
+    if not prompt_path.exists():
+        print_formatted(f"[red]Prompt file not found: {prompt_path}[/red]")
+        return None
+
+    try:
+        with open(prompt_path, 'r', encoding='utf-8') as file:
+            prompt_template = file.read()
+        print_formatted(f"[green]Successfully loaded prompt: {prompt_name}[/green]")
+        return prompt_template
+
+    except IOError as e:
+        print_formatted(f"[red]Error reading prompt file {prompt_name}: {str(e)}[/red]")
+        return None
+
+    except Exception as e:
+        print_formatted(f"[red]Unexpected error loading prompt template: {str(e)}[/red]")
+        return None
+
+if __name__ == "__main__":
+    # Example usage
+    prompt = load_prompt_template("example_prompt")
+    if prompt:
+        print_formatted("[blue]Loaded prompt template:[/blue]")
+        print_formatted(prompt)