pdd-cli 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pdd-cli might be problematic. Click here for more details.

Files changed (95) hide show
  1. pdd/__init__.py +0 -0
  2. pdd/auto_deps_main.py +98 -0
  3. pdd/auto_include.py +175 -0
  4. pdd/auto_update.py +73 -0
  5. pdd/bug_main.py +99 -0
  6. pdd/bug_to_unit_test.py +159 -0
  7. pdd/change.py +141 -0
  8. pdd/change_main.py +240 -0
  9. pdd/cli.py +607 -0
  10. pdd/cmd_test_main.py +155 -0
  11. pdd/code_generator.py +117 -0
  12. pdd/code_generator_main.py +66 -0
  13. pdd/comment_line.py +35 -0
  14. pdd/conflicts_in_prompts.py +143 -0
  15. pdd/conflicts_main.py +90 -0
  16. pdd/construct_paths.py +251 -0
  17. pdd/context_generator.py +133 -0
  18. pdd/context_generator_main.py +73 -0
  19. pdd/continue_generation.py +140 -0
  20. pdd/crash_main.py +127 -0
  21. pdd/data/language_format.csv +61 -0
  22. pdd/data/llm_model.csv +15 -0
  23. pdd/detect_change.py +142 -0
  24. pdd/detect_change_main.py +100 -0
  25. pdd/find_section.py +28 -0
  26. pdd/fix_code_loop.py +212 -0
  27. pdd/fix_code_module_errors.py +143 -0
  28. pdd/fix_error_loop.py +216 -0
  29. pdd/fix_errors_from_unit_tests.py +240 -0
  30. pdd/fix_main.py +138 -0
  31. pdd/generate_output_paths.py +194 -0
  32. pdd/generate_test.py +140 -0
  33. pdd/get_comment.py +55 -0
  34. pdd/get_extension.py +52 -0
  35. pdd/get_language.py +41 -0
  36. pdd/git_update.py +84 -0
  37. pdd/increase_tests.py +93 -0
  38. pdd/insert_includes.py +150 -0
  39. pdd/llm_invoke.py +304 -0
  40. pdd/load_prompt_template.py +59 -0
  41. pdd/pdd_completion.fish +72 -0
  42. pdd/pdd_completion.sh +141 -0
  43. pdd/pdd_completion.zsh +418 -0
  44. pdd/postprocess.py +121 -0
  45. pdd/postprocess_0.py +52 -0
  46. pdd/preprocess.py +199 -0
  47. pdd/preprocess_main.py +72 -0
  48. pdd/process_csv_change.py +182 -0
  49. pdd/prompts/auto_include_LLM.prompt +230 -0
  50. pdd/prompts/bug_to_unit_test_LLM.prompt +17 -0
  51. pdd/prompts/change_LLM.prompt +34 -0
  52. pdd/prompts/conflict_LLM.prompt +23 -0
  53. pdd/prompts/continue_generation_LLM.prompt +3 -0
  54. pdd/prompts/detect_change_LLM.prompt +65 -0
  55. pdd/prompts/example_generator_LLM.prompt +10 -0
  56. pdd/prompts/extract_auto_include_LLM.prompt +6 -0
  57. pdd/prompts/extract_code_LLM.prompt +22 -0
  58. pdd/prompts/extract_conflict_LLM.prompt +19 -0
  59. pdd/prompts/extract_detect_change_LLM.prompt +19 -0
  60. pdd/prompts/extract_program_code_fix_LLM.prompt +16 -0
  61. pdd/prompts/extract_prompt_change_LLM.prompt +7 -0
  62. pdd/prompts/extract_prompt_split_LLM.prompt +9 -0
  63. pdd/prompts/extract_prompt_update_LLM.prompt +8 -0
  64. pdd/prompts/extract_promptline_LLM.prompt +11 -0
  65. pdd/prompts/extract_unit_code_fix_LLM.prompt +332 -0
  66. pdd/prompts/extract_xml_LLM.prompt +7 -0
  67. pdd/prompts/fix_code_module_errors_LLM.prompt +17 -0
  68. pdd/prompts/fix_errors_from_unit_tests_LLM.prompt +62 -0
  69. pdd/prompts/generate_test_LLM.prompt +12 -0
  70. pdd/prompts/increase_tests_LLM.prompt +16 -0
  71. pdd/prompts/insert_includes_LLM.prompt +30 -0
  72. pdd/prompts/split_LLM.prompt +94 -0
  73. pdd/prompts/summarize_file_LLM.prompt +11 -0
  74. pdd/prompts/trace_LLM.prompt +30 -0
  75. pdd/prompts/trim_results_LLM.prompt +83 -0
  76. pdd/prompts/trim_results_start_LLM.prompt +45 -0
  77. pdd/prompts/unfinished_prompt_LLM.prompt +18 -0
  78. pdd/prompts/update_prompt_LLM.prompt +19 -0
  79. pdd/prompts/xml_convertor_LLM.prompt +54 -0
  80. pdd/split.py +119 -0
  81. pdd/split_main.py +103 -0
  82. pdd/summarize_directory.py +212 -0
  83. pdd/trace.py +135 -0
  84. pdd/trace_main.py +108 -0
  85. pdd/track_cost.py +102 -0
  86. pdd/unfinished_prompt.py +114 -0
  87. pdd/update_main.py +96 -0
  88. pdd/update_prompt.py +115 -0
  89. pdd/xml_tagger.py +122 -0
  90. pdd_cli-0.0.2.dist-info/LICENSE +7 -0
  91. pdd_cli-0.0.2.dist-info/METADATA +225 -0
  92. pdd_cli-0.0.2.dist-info/RECORD +95 -0
  93. pdd_cli-0.0.2.dist-info/WHEEL +5 -0
  94. pdd_cli-0.0.2.dist-info/entry_points.txt +2 -0
  95. pdd_cli-0.0.2.dist-info/top_level.txt +1 -0
pdd/postprocess.py ADDED
@@ -0,0 +1,121 @@
1
+ from typing import Tuple
2
+ from rich import print
3
+ from pydantic import BaseModel, Field
4
+ from .load_prompt_template import load_prompt_template
5
+ from .llm_invoke import llm_invoke
6
+
7
class ExtractedCode(BaseModel):
    """Structured result returned by the code-extraction LLM call.

    Passed to ``llm_invoke`` as ``output_pydantic`` by ``postprocess``.
    """

    # The code recovered from the raw LLM answer.
    extracted_code: str = Field(
        description="The extracted code from the LLM output"
    )
10
+
11
def postprocess_0(text: str) -> str:
    """
    Cheap, model-free code extraction used when strength == 0.

    Returns the lines found between Markdown code fences, joined with "\\n".
    Every line that starts with ``` is treated as a fence: the first one opens
    a block (its language tag is discarded), the next one closes it, and so on.
    Lines inside every fenced block are kept; everything else is dropped. Text
    without fences therefore yields an empty string.
    """
    collected = []
    inside_fence = False

    for line in text.split('\n'):
        if line.startswith('```'):
            # Fence lines themselves are never part of the extracted code.
            inside_fence = not inside_fence
        elif inside_fence:
            collected.append(line)

    return '\n'.join(collected)
33
+
34
def _strip_code_fences(code_text: str) -> str:
    """Remove a leading ```<lang> fence line and a trailing ``` fence line.

    Blank lines surrounding the fences are ignored when locating them, so a
    payload such as "```python\\nx = 1\\n```\\n" is unwrapped to "x = 1".
    Fences may be indented. Text without fences is returned unchanged.
    """
    lines = code_text.split('\n')
    non_blank = [i for i, line in enumerate(lines) if line.strip()]
    if not non_blank:
        return code_text

    first, last = non_blank[0], non_blank[-1]
    start, end = 0, len(lines)
    if lines[first].lstrip().startswith('```'):
        start = first + 1
    # `last >= start` keeps a lone fence line from being consumed twice.
    if last >= start and lines[last].lstrip().startswith('```'):
        end = last
    return '\n'.join(lines[start:end])


def postprocess(
    llm_output: str,
    language: str,
    strength: float = 0.9,
    temperature: float = 0,
    verbose: bool = False
) -> Tuple[str, float, str]:
    """
    Extract code from LLM output string.

    When ``strength`` is 0 the code is pulled out locally by ``postprocess_0``.
    That path makes no model call and returns cost 0.0 and the model name
    "simple_extraction".

    Otherwise the ``extract_code_LLM`` prompt template is run through
    ``llm_invoke`` with ``ExtractedCode`` as the structured output. Any
    Markdown fence that still wraps the returned code is then removed.

    Args:
        llm_output (str): The string output from the LLM containing code sections
        language (str): The programming language of the code to extract
        strength (float): The strength of the LLM model to use (0-1)
        temperature (float): The temperature parameter for the LLM (0-1)
        verbose (bool): Whether to print detailed processing information

    Returns:
        Tuple[str, float, str]: (extracted_code, total_cost, model_name)

    Raises:
        ValueError: If an argument is invalid, the prompt template cannot be
            loaded, or the LLM response has no 'result'. Any exception raised
            here is printed before being re-raised.
    """
    try:
        # Input validation
        if not llm_output or not isinstance(llm_output, str):
            raise ValueError("llm_output must be a non-empty string")
        if not language or not isinstance(language, str):
            raise ValueError("language must be a non-empty string")
        if not 0 <= strength <= 1:
            raise ValueError("strength must be between 0 and 1")
        if not 0 <= temperature <= 1:
            raise ValueError("temperature must be between 0 and 1")

        # Step 1: If strength is 0, use simple extraction
        if strength == 0:
            if verbose:
                print("[blue]Using simple code extraction (strength = 0)[/blue]")
            return (postprocess_0(llm_output), 0.0, "simple_extraction")

        # Step 2: Load the prompt template
        prompt_template = load_prompt_template("extract_code_LLM")
        if not prompt_template:
            raise ValueError("Failed to load prompt template")

        if verbose:
            print("[blue]Loaded prompt template for code extraction[/blue]")

        # Step 3: Process using llm_invoke
        input_json = {
            "llm_output": llm_output,
            "language": language
        }

        response = llm_invoke(
            prompt=prompt_template,
            input_json=input_json,
            strength=strength,
            temperature=temperature,
            verbose=verbose,
            output_pydantic=ExtractedCode
        )

        if not response or 'result' not in response:
            raise ValueError("Failed to get valid response from LLM")

        extracted_code: ExtractedCode = response['result']

        # Step 3c: the model may still wrap its answer in ```lang ... ``` fences,
        # possibly with surrounding blank lines; strip them.
        final_code = _strip_code_fences(extracted_code.extracted_code)

        if verbose:
            print("[green]Successfully extracted code[/green]")

        # Step 4: Return the results
        return (
            final_code,
            response['cost'],
            response['model_name']
        )

    except Exception as e:
        print(f"[red]Error in postprocess: {str(e)}[/red]")
        raise
pdd/postprocess_0.py ADDED
@@ -0,0 +1,52 @@
1
+ #Here's the implementation of the `postprocess_0` function based on your requirements:
2
+ #
3
+ #```python
4
+ from .get_comment import get_comment
5
+ from .comment_line import comment_line
6
+ from .find_section import find_section
7
+
8
def postprocess_0(llm_output: str, language: str) -> str:
    """
    Keep only the largest fenced code section in *language*; comment out the rest.

    Args:
        llm_output: Raw LLM response, possibly mixing prose and fenced code.
        language: Language whose section should be kept. It is compared
            case-insensitively with the language reported by ``find_section``.

    Returns:
        The lines of *llm_output* joined with "\\n". Every line outside the body
        of the largest matching section, including that section's own
        boundary lines, is passed through ``comment_line``. If no section
        matches, every line is commented out.
    """
    # Step 1: Get the comment character for the specified language
    comment_char = get_comment(language)

    # Step 2: Find code sections in the llm_output.
    # Each section is unpacked as (lang, start, end). start and end are
    # presumably the indices of the opening/closing fence lines; only lines
    # strictly between them are kept below.
    lines = llm_output.splitlines()
    sections = find_section(lines)

    # Step 3: Find the largest section of the specified language (first wins on ties)
    target = language.lower()
    largest_section = None
    max_size = 0
    for code_lang, start, end in sections:
        if code_lang.lower() == target:
            size = end - start
            if size > max_size:
                max_size = size
                largest_section = (start, end)

    # Steps 4 & 5: comment out every line outside the section body.
    # A membership test replaces the old flag toggled on the boundary lines.
    # For an empty section (end == start + 1) both boundaries land on the same
    # line, and the flag version then left all remaining lines uncommented.
    processed_lines = []
    for i, line in enumerate(lines):
        in_code_section = (
            largest_section is not None
            and largest_section[0] < i < largest_section[1]
        )
        if in_code_section:
            processed_lines.append(line)
        else:
            processed_lines.append(comment_line(line, comment_char))

    # Return the processed string
    return '\n'.join(processed_lines)
42
+ #```
43
+ #
44
+ #This implementation follows the steps you outlined:
45
+ #
46
+ #1. We use `get_comment` to get the appropriate comment character for the specified language.
47
+ #2. We use `find_section` to identify all code sections in the input.
48
+ #3. We find the largest section of code in the specified language.
49
+ #4. We iterate through the lines, commenting out everything outside the largest section of the specified language using `comment_line`.
50
+ #5. Finally, we join the processed lines and return the result as a string.
51
+ #
52
+ #This function will produce a string where only the largest section of code in the specified language is left uncommented, while all other text and code sections are commented out using the appropriate comment character for the language.
pdd/preprocess.py ADDED
@@ -0,0 +1,199 @@
1
+ import os
2
+ import re
3
+ import subprocess
4
+ from typing import List
5
+ from rich import print
6
+ from rich.console import Console
7
+ from rich.panel import Panel
8
+
9
console = Console()


def preprocess(prompt: str, recursive: bool = False, double_curly_brackets: bool = True, exclude_keys: List[str] = None) -> str:
    """
    Run the prompt-preprocessing pipeline on *prompt*.

    Stages, in order:
      1. ```<path>``` backtick includes are expanded (process_backtick_includes);
      2. <include>, <pdd> and <shell> tags are handled (process_specific_tags);
      3. optionally, curly brackets are doubled (double_curly).

    :param prompt: The input text to preprocess.
    :param recursive: Whether included files are themselves preprocessed.
    :param double_curly_brackets: Whether to double curly brackets in the result.
    :param exclude_keys: Keys whose ``{key}`` placeholders must not be doubled.
    :return: The preprocessed text.
    """
    console.print(Panel("Starting preprocessing", style="bold green"))

    text = process_backtick_includes(prompt, recursive)
    text = process_specific_tags(text, recursive)
    if double_curly_brackets:
        text = double_curly(text, exclude_keys)

    console.print(Panel("Preprocessing complete", style="bold green"))
    return text
35
+
36
+
37
def process_backtick_includes(text: str, recursive: bool) -> str:
    """
    Expand ```<path>``` include markers in *text*.

    Each marker is replaced by ```<file contents>```, with the path resolved
    through get_file_path. When *recursive* is True, the included contents
    are first preprocessed without curly-bracket doubling.

    If a marker's file cannot be read, the marker is left in place and a
    warning is printed.

    All markers are expanded in a single pass over the original text, so
    markers that appear inside included content are only expanded via the
    recursive pass.

    :param text: The input text containing backtick includes.
    :param recursive: Whether to recursively preprocess included content.
    :return: The text with includes processed.
    """
    pattern = r"```<(.*?)>```"

    def expand(match: re.Match) -> str:
        include_name = match.group(1)
        console.print(f"Processing include: [cyan]{include_name}[/cyan]")
        file_path = get_file_path(include_name)
        try:
            with open(file_path, 'r') as file:
                content = file.read()
        except FileNotFoundError:
            console.print(f"[bold red]Warning:[/bold red] File not found: {file_path}")
            return match.group(0)
        except OSError as e:
            # e.g. the path is a directory or is not readable; keep the marker.
            console.print(
                f"[bold red]Warning:[/bold red] Could not read file: {file_path} ({type(e).__name__})"
            )
            return match.group(0)
        if recursive:
            content = preprocess(content, recursive, False)
        return f"```{content}```"

    return re.sub(pattern, expand, text)
61
+
62
+
63
def process_specific_tags(text: str, recursive: bool) -> str:
    """
    Expand the <include>, <pdd> and <shell> tags in *text*.

    - ``<include>path</include>``: replaced by the file's contents. When
      *recursive* is True, the contents are first preprocessed without
      curly-bracket doubling. If the file cannot be read, a warning is printed
      and the tag is dropped.
    - ``<pdd>...</pdd>``: removed entirely.
    - ``<shell>cmd</shell>``: replaced by the command's stdout. If the command
      exits non-zero, it is replaced by an ``Error: ...`` string instead.

    Self-closing forms (``<pdd/>``, ``<pdd />``) are consumed entirely.
    Whitespace matched immediately before and after a tag is preserved, and
    no closing tags are added to the output.

    :param text: The input text containing specific tags.
    :param recursive: Whether to recursively preprocess included content.
    :return: The text with specific tags processed.
    """
    def process_tag(match: re.Match) -> str:
        pre_whitespace = match.group(1)
        tag = match.group(2)
        content = match.group(3) if match.group(3) else ""
        post_whitespace = match.group(4)

        if tag == 'include':
            file_path = get_file_path(content.strip())
            console.print(f"Processing XML include: [cyan]{file_path}[/cyan]")
            try:
                with open(file_path, 'r') as file:
                    included_content = file.read()
            except FileNotFoundError:
                console.print(f"[bold red]Warning:[/bold red] File not found: {file_path}")
                return pre_whitespace + post_whitespace
            except OSError as e:
                # e.g. an empty <include></include> resolves to the directory './'
                console.print(
                    f"[bold red]Warning:[/bold red] Could not read file: {file_path} ({type(e).__name__})"
                )
                return pre_whitespace + post_whitespace
            if recursive:
                included_content = preprocess(included_content, recursive, False)
            return pre_whitespace + included_content + post_whitespace
        elif tag == 'pdd':
            return pre_whitespace + post_whitespace
        elif tag == 'shell':
            command = content.strip()
            # NOTE(security): the command text comes straight from the prompt and
            # runs with shell=True, so a prompt can execute arbitrary commands.
            # Only preprocess prompts you trust.
            console.print(f"Executing shell command: [cyan]{command}[/cyan]")
            try:
                result = subprocess.run(command, shell=True, check=True, capture_output=True, text=True)
                return pre_whitespace + result.stdout + post_whitespace
            except subprocess.CalledProcessError as e:
                console.print(f"[bold red]Error:[/bold red] Shell command failed: {e}")
                return pre_whitespace + f"Error: {e}" + post_whitespace
        else:
            return match.group(0)  # Return the original match for any other tags

    # Groups: 1 = leading whitespace, 2 = tag name, 3 = body (paired form only),
    # 4 = trailing whitespace.
    # - The attribute part is lazy, and '/>' is tried first, so a self-closing
    #   tag is consumed whole.
    # - A self-closing tag is therefore never paired with a later closing tag.
    pattern = r'(\s*)<(include|pdd|shell)(?:\s+[^>]*?)?(?:/>|>(.*?)</\2>|>)(\s*)'
    return re.sub(pattern, process_tag, text, flags=re.DOTALL)
106
+
107
+
108
def get_file_path(file_name: str) -> str:
    """
    Resolve *file_name* against the current working directory.

    The name is joined onto './' with os.path.join. An absolute *file_name*
    is therefore returned as-is, and an empty name yields './'.

    :param file_name: The name of the file to locate.
    :return: The joined path.
    """
    base_directory = './'
    full_path = os.path.join(base_directory, file_name)
    return full_path
117
+
118
+
119
def double_curly(text: str, exclude_keys: List[str] = None) -> str:
    """
    Escape curly brackets in *text* by doubling them ("{" -> "{{", "}" -> "}}").

    The text is split into fenced code blocks and the prose between them.
    A code block is a ``` fence with an optional word tag and a newline,
    running up to the next ```.

    - Inside code blocks, every lone brace is doubled.
    - In prose, ``{key}`` for each key in *exclude_keys* is kept single-braced,
      ``{}`` becomes ``{{}}``, and every other lone brace is doubled.

    A brace adjacent to another brace of the same kind is left untouched.

    :param text: The input text with single curly brackets.
    :param exclude_keys: Keys whose ``{key}`` placeholders must stay single-braced.
    :return: The text with doubled curly brackets.
    """
    console.print("Doubling curly brackets")
    keys = [] if exclude_keys is None else exclude_keys

    fence_re = r"```[\w]*\n[\s\S]*?```"
    lone_open = r'(?<!{){(?!{)'
    lone_close = r'(?<!})}(?!})'

    def double_lone_braces(chunk: str) -> str:
        # Opening braces first, then closing braces.
        return re.sub(lone_close, '}}', re.sub(lone_open, '{{', chunk))

    # Placeholder token -> text restored after doubling. The mapping and the
    # counter are shared by all prose segments, so every token is unique.
    restore = {}
    counter = 0
    pieces = []

    for segment in re.split(f"({fence_re})", text):
        if re.match(fence_re, segment):
            console.print("Processing code block for curly brackets")
            header_end = segment.find('\n') + 1
            body = segment[header_end:-3]  # between the tag line and the closing ```
            pieces.append(segment[:header_end] + double_lone_braces(body) + segment[-3:])
            continue

        # Prose: hide protected placeholders, double, then restore them.
        for key in keys:
            token = f"__EXCLUDE_KEY_PLACEHOLDER_{counter}__"
            segment = re.sub(r'\{' + re.escape(key) + r'\}', token, segment)
            restore[token] = f"{{{key}}}"
            counter += 1

        token = f"__EMPTY_BRACE_PLACEHOLDER_{counter}__"
        segment = re.sub(r'\{\}', token, segment)
        restore[token] = '{{}}'
        counter += 1

        segment = double_lone_braces(segment)

        # Excluded keys are restored before the empty-brace tokens.
        for token, original in restore.items():
            if original != '{{}}':
                segment = segment.replace(token, original)
        for token, original in restore.items():
            if original == '{{}}':
                segment = segment.replace(token, original)

        pieces.append(segment)

    return ''.join(pieces)
pdd/preprocess_main.py ADDED
@@ -0,0 +1,72 @@
1
+ import csv
2
+ import sys
3
+ from typing import Tuple, Optional
4
+ import click
5
+ from rich import print as rprint
6
+
7
+ from .construct_paths import construct_paths
8
+ from .preprocess import preprocess
9
+ from .xml_tagger import xml_tagger
10
+
11
def preprocess_main(
    ctx: click.Context, prompt_file: str, output: Optional[str], xml: bool, recursive: bool, double: bool, exclude: list
) -> Tuple[str, float, str]:
    """
    CLI entry point for the ``preprocess`` command.

    The prompt is read through construct_paths and then processed in one of
    two ways:
    - with xml=True, it is tagged by xml_tagger using the strength, temperature
      and verbose settings from ``ctx.obj``;
    - otherwise it goes through preprocess, with no model and zero cost.

    The result is written to the resolved output path. Any exception is
    reported (unless quiet) and ends the process with exit status 1.

    :param ctx: Click context; ``ctx.obj`` is read as a dict of global options.
    :param prompt_file: Path to the prompt file to preprocess.
    :param output: Optional path where to save the preprocessed prompt.
    :param xml: If True, insert XML delimiters instead of running preprocess.
    :param recursive: If True, recursively preprocess included files.
    :param double: If True, curly brackets will be doubled.
    :param exclude: List of keys to exclude from curly bracket doubling.
    :return: Tuple of (preprocessed prompt, total cost, model name).
    """
    try:
        input_strings, output_file_paths, _ = construct_paths(
            input_file_paths={"prompt_file": prompt_file},
            force=ctx.obj.get("force", False),
            quiet=ctx.obj.get("quiet", False),
            command="preprocess",
            command_options={"output": output},
        )
        prompt = input_strings["prompt_file"]

        if not xml:
            processed_prompt = preprocess(prompt, recursive, double, exclude_keys=exclude)
            total_cost, model_name = 0.0, "N/A"
        else:
            processed_prompt, total_cost, model_name = xml_tagger(
                prompt,
                ctx.obj.get("strength", 0.5),
                ctx.obj.get("temperature", 0.0),
                ctx.obj.get("verbose", False),
            )

        output_path = output_file_paths["output"]
        with open(output_path, "w") as f:
            f.write(processed_prompt)

        if not ctx.obj.get("quiet", False):
            rprint("[bold green]Prompt preprocessing completed successfully.[/bold green]")
            if xml:
                rprint(f"[bold]XML Tagging used: {model_name}[/bold]")
            else:
                rprint(f"[bold]Model used: {model_name}[/bold]")
            rprint(f"[bold]Total cost: ${total_cost:.6f}[/bold]")
            rprint(f"[bold]Preprocessed prompt saved to:[/bold] {output_path}")

        return processed_prompt, total_cost, model_name

    except Exception as e:
        if not ctx.obj.get("quiet", False):
            rprint(f"[bold red]Error during preprocessing:[/bold red] {e}")
        sys.exit(1)
@@ -0,0 +1,182 @@
1
+ # process_csv_change.py
2
+
3
+ from typing import List, Dict, Tuple
4
+ import os
5
+ import csv
6
+ from pathlib import Path
7
+ import logging
8
+
9
+ from rich.console import Console
10
+ from rich.pretty import Pretty
11
+ from rich.panel import Panel
12
+
13
+ from .change import change # Relative import for the internal change function
14
+
15
console = Console()

# Set up logging.
# NOTE: basicConfig configures the root logger as an import-time side effect of
# this module (it does nothing if the root logger already has handlers).
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


def process_csv_change(
    csv_file: str,
    strength: float,
    temperature: float,
    code_directory: str,
    language: str,
    extension: str,
    budget: float
) -> Tuple[bool, List[Dict[str, str]], float, str]:
    """
    Processes a CSV file to apply changes to code prompts using an LLM model.

    For every CSV row, two files are read and passed to ``change`` together
    with the row's 'change_instructions':
    - the prompt file named in 'prompt_name', resolved relative to the current
      working directory;
    - its code file in *code_directory*.

    The code file name is the prompt file's stem with its last ``_<suffix>``
    removed (intended to be the ``_<language>`` part), followed by *extension*.
    For example, ``foo_python.prompt`` with ``".py"`` gives ``foo.py``.

    Rows with a blank or missing field, a missing file, or an error while
    changing them are skipped, and they make the overall result unsuccessful.
    Processing stops as soon as the accumulated cost exceeds *budget*. The
    modified prompt of the row that crossed the budget is not included in the
    results.

    Args:
        csv_file (str): Path to the CSV file containing 'prompt_name' and 'change_instructions' columns.
        strength (float): Strength parameter for the LLM model (0.0 to 1.0).
        temperature (float): Temperature parameter for the LLM model (0.0 to 1.0).
        code_directory (str): Path to the directory where code files are stored.
        language (str): Programming language of the code files (currently not used by this function).
        extension (str): Suffix appended verbatim to the code file base name (include the dot, e.g. ".py").
        budget (float): Maximum allowed total cost for the change process.

    Returns:
        Tuple[bool, List[Dict[str, str]], float, str]:
            - success (bool): Indicates if changes were successfully made within the budget and without errors.
            - list_of_jsons (List[Dict[str, str]]): List of dictionaries with 'file_name' and 'modified_prompt'.
            - total_cost (float): Total accumulated cost of all change attempts.
            - model_name (str): Name of the LLM model used.
    """
    list_of_jsons: List[Dict[str, str]] = []
    total_cost: float = 0.0
    model_name: str = ""
    success: bool = False
    any_failures: bool = False  # Track if any failures occur

    # Validate inputs
    if not os.path.isfile(csv_file):
        console.print(f"[bold red]Error:[/bold red] CSV file '{csv_file}' does not exist.")
        return success, list_of_jsons, total_cost, model_name

    if not (0.0 <= strength <= 1.0):
        console.print(f"[bold red]Error:[/bold red] 'strength' must be between 0 and 1. Given: {strength}")
        return success, list_of_jsons, total_cost, model_name

    if not (0.0 <= temperature <= 1.0):
        console.print(f"[bold red]Error:[/bold red] 'temperature' must be between 0 and 1. Given: {temperature}")
        return success, list_of_jsons, total_cost, model_name

    code_dir_path = Path(code_directory)
    if not code_dir_path.is_dir():
        console.print(f"[bold red]Error:[/bold red] Code directory '{code_directory}' does not exist or is not a directory.")
        return success, list_of_jsons, total_cost, model_name

    try:
        with open(csv_file, mode='r', newline='', encoding='utf-8') as csvfile:
            reader = csv.DictReader(csvfile)
            # fieldnames is None for an empty file. Treat that as "columns
            # missing" rather than letting the `in` test raise TypeError.
            fieldnames = reader.fieldnames or []
            if 'prompt_name' not in fieldnames or 'change_instructions' not in fieldnames:
                console.print("[bold red]Error:[/bold red] CSV file must contain 'prompt_name' and 'change_instructions' columns.")
                return success, list_of_jsons, total_cost, model_name

            for row_number, row in enumerate(reader, start=1):
                # DictReader fills cells missing from a short row with None
                # (restval). Coerce them to '' so the row is skipped below
                # instead of .strip() aborting the whole run.
                prompt_name = (row.get('prompt_name') or '').strip()
                change_instructions = (row.get('change_instructions') or '').strip()

                if not prompt_name:
                    console.print(f"[yellow]Warning:[/yellow] Missing 'prompt_name' in row {row_number}. Skipping.")
                    any_failures = True
                    continue

                if not change_instructions:
                    console.print(f"[yellow]Warning:[/yellow] Missing 'change_instructions' in row {row_number}. Skipping.")
                    any_failures = True
                    continue

                # Parse the prompt_name to get the input_code filename
                try:
                    prompt_path = Path(prompt_name)
                    base_name = prompt_path.stem  # Removes suffix
                    # Remove the '_<language>' part if present. Note: whatever
                    # follows the last underscore is removed.
                    if '_' in base_name:
                        base_name = base_name.rsplit('_', 1)[0]
                    input_code_name = f"{base_name}{extension}"
                    input_code_path = code_dir_path / input_code_name

                    if not input_code_path.is_file():
                        console.print(f"[yellow]Warning:[/yellow] Input code file '{input_code_path}' does not exist. Skipping row {row_number}.")
                        logger.warning(f"Input code file '{input_code_path}' does not exist for row {row_number}")
                        any_failures = True
                        continue

                    # Check if prompt file exists
                    if not prompt_path.is_file():
                        console.print(f"[yellow]Warning:[/yellow] Prompt file '{prompt_name}' does not exist. Skipping row {row_number}.")
                        logger.warning(f"Prompt file '{prompt_name}' does not exist for row {row_number}")
                        any_failures = True
                        continue

                    # Read the input_code from the file
                    with open(input_code_path, 'r', encoding='utf-8') as code_file:
                        input_code = code_file.read()

                    # Read the input_prompt from the prompt file
                    with open(prompt_path, 'r', encoding='utf-8') as prompt_file:
                        input_prompt = prompt_file.read()

                    # Call the change function
                    modified_prompt, cost, current_model_name = change(
                        input_prompt=input_prompt,
                        input_code=input_code,
                        change_prompt=change_instructions,
                        strength=strength,
                        temperature=temperature
                    )

                    # Accumulate the total cost
                    total_cost += cost

                    # Check if budget is exceeded
                    if total_cost > budget:
                        console.print(f"[bold red]Budget exceeded after row {row_number}. Stopping further processing.[/bold red]")
                        any_failures = True
                        break

                    # Set the model_name (assumes all calls use the same model)
                    if not model_name:
                        model_name = current_model_name
                    elif model_name != current_model_name:
                        console.print(f"[yellow]Warning:[/yellow] Model name changed from '{model_name}' to '{current_model_name}' in row {row_number}.")

                    # Add to the list_of_jsons
                    list_of_jsons.append({
                        "file_name": prompt_name,
                        "modified_prompt": modified_prompt
                    })

                    console.print(Panel(f"[green]Row {row_number} processed successfully.[/green]"))

                except Exception as e:
                    console.print(f"[red]Error:[/red] Failed to process 'prompt_name' in row {row_number}: {str(e)}")
                    logger.exception(f"Failed to process row {row_number}")
                    any_failures = True
                    continue

        # Determine success based on whether total_cost is within budget and no failures occurred
        success = (total_cost <= budget) and not any_failures

        # Pretty print the results
        console.print(Panel(f"[bold]Processing Complete[/bold]\n"
                            f"Success: {'Yes' if success else 'No'}\n"
                            f"Total Cost: ${total_cost:.6f}\n"
                            f"Model Used: {model_name if model_name else 'N/A'}"))

        # Optionally, pretty print the list of modified prompts
        if list_of_jsons:
            console.print(Panel("[bold]List of Modified Prompts[/bold]"))
            console.print(Pretty(list_of_jsons))

        return success, list_of_jsons, total_cost, model_name

    except Exception as e:
        console.print(f"[bold red]Unexpected Error:[/bold red] {str(e)}")
        logger.exception("Unexpected error occurred")
        return success, list_of_jsons, total_cost, model_name