gpt-pr 0.2.1__py3-none-any.whl → 0.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpt_pr/__init__.py +3 -0
- gpt_pr/checkversion.py +93 -0
- gpt_pr/config.py +104 -0
- gpt_pr/gh.py +44 -0
- {gptpr → gpt_pr}/gitutil.py +26 -12
- gpt_pr/gpt.py +8 -0
- gpt_pr/main.py +117 -0
- gpt_pr/prdata.py +217 -0
- gpt_pr/test_checkversion.py +132 -0
- gpt_pr/test_config.py +138 -0
- gpt_pr/test_gh.py +60 -0
- gpt_pr/test_prdata.py +17 -0
- gpt_pr-0.7.2.dist-info/METADATA +285 -0
- gpt_pr-0.7.2.dist-info/RECORD +17 -0
- {gpt_pr-0.2.1.dist-info → gpt_pr-0.7.2.dist-info}/WHEEL +1 -2
- gpt_pr-0.7.2.dist-info/entry_points.txt +4 -0
- gpt_pr-0.2.1.dist-info/METADATA +0 -49
- gpt_pr-0.2.1.dist-info/RECORD +0 -13
- gpt_pr-0.2.1.dist-info/entry_points.txt +0 -2
- gpt_pr-0.2.1.dist-info/top_level.txt +0 -1
- gptpr/__init__.py +0 -0
- gptpr/gh.py +0 -27
- gptpr/main.py +0 -52
- gptpr/prdata.py +0 -161
- gptpr/test_prdata.py +0 -13
- gptpr/version.py +0 -1
- {gptpr → gpt_pr}/consolecolor.py +0 -0
gpt_pr/__init__.py
ADDED
gpt_pr/checkversion.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
import requests
|
|
2
|
+
import os
|
|
3
|
+
import json
|
|
4
|
+
import tempfile
|
|
5
|
+
from gpt_pr import __version__
|
|
6
|
+
from datetime import datetime, timedelta
|
|
7
|
+
|
|
8
|
+
from gpt_pr import consolecolor as cc
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
PACKAGE_NAME = "gpt-pr"
|
|
12
|
+
CACHE_FILE = os.path.join(os.path.expanduser("~"), ".gpt_pr_update_cache.json")
|
|
13
|
+
CACHE_DURATION = timedelta(days=1)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def cache_daily_version(func):
    """Cache the wrapped function's result on disk for ``CACHE_DURATION``.

    The wrapper returns the cached ``latest_version`` while the cache entry is
    younger than ``CACHE_DURATION`` and holds a truthy version. Otherwise it
    calls ``func``, stores the result together with the current timestamp via
    ``save_cache`` and returns it.

    A cache entry whose timestamp cannot be parsed (or compared) is treated as
    stale instead of raising, so a damaged cache file never aborts the CLI.
    """
    import functools  # local import keeps this block self-contained

    @functools.wraps(func)  # preserve the wrapped function's name/docstring
    def wrapper(*args, **kwargs):
        cache = load_cache()
        last_checked = cache.get("last_checked")
        cached_version = cache.get("latest_version")

        if last_checked and cached_version:
            try:
                age = datetime.now() - datetime.fromisoformat(last_checked)
            except (TypeError, ValueError):
                # Malformed or timezone-aware timestamp: ignore the entry.
                age = None

            if age is not None and age < CACHE_DURATION:
                # Use cached version info
                return cached_version

        latest_version = func(*args, **kwargs)
        save_cache(
            {
                "last_checked": datetime.now().isoformat(),
                "latest_version": latest_version,
            }
        )

        return latest_version

    return wrapper
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def get_cache_file_path():
    """Return the path of the update-check cache file in the system temp dir."""
    cache_name = f"{PACKAGE_NAME}_update_cache.json"
    return os.path.join(tempfile.gettempdir(), cache_name)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@cache_daily_version
def get_latest_version():
    """Fetch the latest published version of the package from PyPI.

    Returns:
        The version string from the PyPI JSON API, or ``None`` when the
        request fails or the response does not have the expected shape.
    """
    url = f"https://pypi.org/pypi/{PACKAGE_NAME}/json"

    try:
        # Without a timeout requests waits indefinitely, which would hang the
        # CLI on a stalled connection just for an update check.
        response = requests.get(url, timeout=5)
        response.raise_for_status()
        return response.json()["info"]["version"]
    except (
        requests.exceptions.RequestException,
        KeyError,
        TypeError,
        ValueError,
    ) as e:
        # KeyError/TypeError/ValueError cover an unexpected or non-JSON body.
        print(f"Error fetching latest version info: {e}")
        return None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def load_cache():
    """Load the update-check cache from disk.

    Returns:
        The cached dictionary, or an empty dict when the cache file is
        missing, unreadable, not valid JSON, or does not contain a JSON
        object. The cache is best-effort, so none of these cases raise.
    """
    cache_file = get_cache_file_path()

    try:
        with open(cache_file, "r") as file:
            data = json.load(file)
    except (OSError, ValueError):
        # OSError: missing/unreadable file; ValueError: corrupt JSON
        # (json.JSONDecodeError is a ValueError subclass).
        return {}

    # Callers use ``.get`` on the result, so anything but a dict is unusable.
    return data if isinstance(data, dict) else {}
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def save_cache(data):
    """Serialize ``data`` as JSON and overwrite the update-check cache file."""
    payload = json.dumps(data)
    target = get_cache_file_path()
    with open(target, "w") as handle:
        handle.write(payload)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def check_for_updates():
    """Print an upgrade notice when the latest known version differs from the installed one."""
    latest_version = get_latest_version()

    # Nothing to report when the lookup failed or versions match.
    if not latest_version or latest_version == __version__:
        return

    notice = cc.yellow(
        f"A new version of {PACKAGE_NAME} is available ({latest_version}). "
        f"You are using version {__version__}. Please update by running"
    )
    command = cc.green(f"pip install --upgrade {PACKAGE_NAME}.")

    print("")
    print(notice, command)
    print("")
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
if __name__ == "__main__":
|
|
92
|
+
check_for_updates()
|
|
93
|
+
# Your CLI code here
|
gpt_pr/config.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
from copy import deepcopy
|
|
2
|
+
import configparser
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def config_command_example(name, value_sample):
    """Build the example shell command that sets config ``name`` to ``value_sample``."""
    return 'gpt-pr-config set {} {}'.format(name, value_sample)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
CONFIG_PROJECT_REPO_URL = 'https://github.com/alissonperez/gpt-pr'
|
|
11
|
+
CONFIG_README_SECTION = f'{CONFIG_PROJECT_REPO_URL}?tab=readme-ov-file#configuration'
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Config:
    """Persistent gpt-pr settings stored in an INI file.

    The file has two sections: ``DEFAULT`` holds the built-in defaults and
    ``user`` holds overrides. Values are read through the ``user`` section,
    which falls back to ``DEFAULT`` (standard configparser behaviour).
    The file is loaded lazily on first access.
    """

    # Name of the INI file created inside the configuration directory.
    config_filename = '.gpt-pr.ini'

    _default_config = {
        # Amenities
        'ADD_TOOL_SIGNATURE': 'true',  # Add GPT-PR signature to PRs

        # Github
        'GH_TOKEN': '',

        # LLM input MAX Tokens
        'INPUT_MAX_TOKENS': '15000',

        # Open AI info
        'OPENAI_MODEL': 'gpt-4o-mini',
        'OPENAI_API_KEY': '',
    }

    def __init__(self, config_dir=None):
        """Create a config bound to ``config_dir`` (the user's home by default).

        Nothing is read from or written to disk until ``load`` runs.
        """
        self.default_config = deepcopy(self._default_config)
        self._config_dir = config_dir or os.path.expanduser('~')
        self._config = configparser.ConfigParser()
        self._initialized = False

    def load(self):
        """Read the config file once, creating it with defaults if missing."""
        if self._initialized:
            return

        config_file_path = self.get_filepath()

        if os.path.exists(config_file_path):
            self._config.read(config_file_path)
            self._ensure_default_values()
        else:
            self._config['user'] = {}
            self._config['DEFAULT'] = deepcopy(self.default_config)
            self.persist()

        self._initialized = True

    def _ensure_default_values(self):
        """Add anything an existing file lacks and persist if it changed.

        Covers defaults introduced by newer versions and a missing ``user``
        section (e.g. a hand-edited file); without that section every
        get/set/reset would raise ``KeyError: 'user'``.
        """
        added = False

        if not self._config.has_section('user'):
            self._config.add_section('user')
            added = True

        for key, value in self.default_config.items():
            if key not in self._config['DEFAULT']:
                self._config['DEFAULT'][key] = value
                added = True

        if added:
            self.persist()

    def persist(self):
        """Write the current in-memory configuration to the config file."""
        config_file_path = self.get_filepath()

        with open(config_file_path, 'w') as configfile:
            self._config.write(configfile)

    def get_filepath(self):
        """Return the full path of the config file."""
        return os.path.join(self._config_dir, self.config_filename)

    def set_user_config(self, name, value):
        """Set a user override in memory; call ``persist`` to save it."""
        self.load()
        self._config['user'][name] = str(value)

    def reset_user_config(self, name):
        """Set the user value of ``name`` back to its built-in default and save.

        Raises:
            KeyError: if ``name`` is not a known default key.
        """
        self.load()
        self._config['user'][name] = self.default_config[name]
        self.persist()

    def get_user_config(self, name):
        """Return the effective value of ``name`` (user override or default).

        Raises:
            KeyError: if ``name`` exists in neither section.
        """
        self.load()
        return self._config['user'][name]

    def all_values(self):
        """Return every ``(section, option, value)`` triple, DEFAULT first.

        Non-default sections also list options inherited from ``DEFAULT``.
        """
        self.load()

        # iterate over all sections and values and return them in a list
        result = []

        # add default section
        for option in self._config['DEFAULT']:
            result.append(('DEFAULT', option, self._config['DEFAULT'][option]))

        for section in self._config.sections():
            for option in self._config[section]:
                result.append((section, option, self._config[section][option]))

        return result
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
config = Config()
|
gpt_pr/gh.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from github import Github
|
|
3
|
+
from InquirerPy import inquirer
|
|
4
|
+
from gpt_pr.config import config, config_command_example, CONFIG_README_SECTION
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def _get_gh_token():
    """Return the GitHub token from config, falling back to the GH_TOKEN env var.

    Prints setup instructions and exits with status 1 when neither is set.
    """
    token = config.get_user_config("GH_TOKEN") or os.environ.get("GH_TOKEN")
    if token:
        return token

    print(
        'Please set "gh_token" config. Just run:',
        config_command_example("gh_token", "[my gh token]"),
        "more about at",
        CONFIG_README_SECTION,
    )
    raise SystemExit(1)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def create_pr(pr_data, yield_confirmation, gh=None):
    """Create the GitHub pull request described by ``pr_data``.

    Args:
        pr_data: object exposing ``branch_info``, ``title`` and ``create_body()``.
        yield_confirmation: when truthy, skip the interactive confirmation.
        gh: optional GitHub client; one is built from the configured token
            when omitted.
    """
    client = gh if gh else Github(_get_gh_token())

    branch_info = pr_data.branch_info
    repo = client.get_repo(f"{branch_info.owner}/{branch_info.repo}")

    confirmed = (
        yield_confirmation
        or inquirer.confirm(message="Create GitHub PR?", default=True).execute()
    )

    if not confirmed:
        print("cancelling...")
        return

    pull = repo.create_pull(
        title=pr_data.title,
        body=pr_data.create_body(),
        head=branch_info.branch,
        base=branch_info.base_branch,
    )
    print("Pull request created successfully: ", pull.html_url)
|
{gptpr → gpt_pr}/gitutil.py
RENAMED
|
@@ -32,12 +32,9 @@ class FileChange:
|
|
|
32
32
|
return f'{self.file_path} (+{(self.lines_added)} -{self.lines_removed})'
|
|
33
33
|
|
|
34
34
|
|
|
35
|
-
def get_branch_info(base_branch, yield_confirmation):
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
# Instantiate the repository
|
|
40
|
-
repo = Repo(current_dir)
|
|
35
|
+
def get_branch_info(base_branch, origin, yield_confirmation):
|
|
36
|
+
git_dir = fetch_nearest_git_dir(os.getcwd())
|
|
37
|
+
repo = Repo(git_dir)
|
|
41
38
|
|
|
42
39
|
# Check that the repository loaded correctly
|
|
43
40
|
if not repo.bare:
|
|
@@ -57,7 +54,7 @@ def get_branch_info(base_branch, yield_confirmation):
|
|
|
57
54
|
if not _branch_exists(repo, base_branch):
|
|
58
55
|
raise Exception(f'Base branch {base_branch} does not exist.')
|
|
59
56
|
|
|
60
|
-
owner, repo_name = _get_remote_info(repo)
|
|
57
|
+
owner, repo_name = _get_remote_info(repo, origin)
|
|
61
58
|
|
|
62
59
|
commits = _get_diff_messages_against_base_branch(repo, current_branch.name, base_branch)
|
|
63
60
|
commits = _get_valid_commits(commits, yield_confirmation)
|
|
@@ -79,6 +76,18 @@ def get_branch_info(base_branch, yield_confirmation):
|
|
|
79
76
|
)
|
|
80
77
|
|
|
81
78
|
|
|
79
|
+
def fetch_nearest_git_dir(current_dir):
    """Return the closest directory at or above ``current_dir`` containing ``.git``.

    ``.git`` may be a directory (regular clone) or a file (linked worktrees
    and submodules store a ``gitdir:`` pointer file), so both are accepted.

    Raises:
        FileNotFoundError: when the filesystem root is reached without
            finding a ``.git`` entry.
    """
    # Goes upwards until it finds a .git entry
    path = os.path.abspath(current_dir)
    while True:
        if os.path.exists(os.path.join(path, '.git')):
            return path
        parent = os.path.dirname(path)
        if parent == path:  # Reached root
            raise FileNotFoundError(f"Could not find a .git directory in or above '{current_dir}'")
        path = parent
|
|
89
|
+
|
|
90
|
+
|
|
82
91
|
def _branch_exists(repo, branch_name):
|
|
83
92
|
if branch_name in repo.branches:
|
|
84
93
|
return True
|
|
@@ -125,9 +134,9 @@ def _get_highlight_commits(commits, yield_confirmation):
|
|
|
125
134
|
return highlight_commits
|
|
126
135
|
|
|
127
136
|
|
|
128
|
-
def _get_remote_info(repo):
|
|
137
|
+
def _get_remote_info(repo, origin):
|
|
129
138
|
for remote in repo.remotes:
|
|
130
|
-
if remote.name !=
|
|
139
|
+
if remote.name != origin:
|
|
131
140
|
continue
|
|
132
141
|
|
|
133
142
|
remote_urls_joined = ','.join([str(url) for url in remote.urls])
|
|
@@ -137,7 +146,7 @@ def _get_remote_info(repo):
|
|
|
137
146
|
for url in remote.urls:
|
|
138
147
|
return _extract_owner_and_repo(url)
|
|
139
148
|
|
|
140
|
-
raise Exception('Could not find origin remote.')
|
|
149
|
+
raise Exception(f'Could not find \'{origin}\' remote.')
|
|
141
150
|
|
|
142
151
|
|
|
143
152
|
def _extract_owner_and_repo(repo_url):
|
|
@@ -183,10 +192,15 @@ def _get_stats(repo, base_branch, branch):
|
|
|
183
192
|
continue
|
|
184
193
|
|
|
185
194
|
line = line.split('\t')
|
|
195
|
+
|
|
196
|
+
# Binary files will not have stats (just "-")
|
|
197
|
+
added = int(line[0]) if line[0] and line[0].isdigit() else 0
|
|
198
|
+
removed = int(line[1]) if line[1] and line[1].isdigit() else 0
|
|
199
|
+
|
|
186
200
|
files_changed.append(FileChange(
|
|
187
201
|
file_path=line[2],
|
|
188
|
-
lines_added=
|
|
189
|
-
lines_removed=
|
|
202
|
+
lines_added=added,
|
|
203
|
+
lines_removed=removed
|
|
190
204
|
))
|
|
191
205
|
|
|
192
206
|
return files_changed
|
gpt_pr/gpt.py
ADDED
gpt_pr/main.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
import fire
|
|
2
|
+
from InquirerPy import inquirer
|
|
3
|
+
|
|
4
|
+
from gpt_pr.gitutil import get_branch_info
|
|
5
|
+
from gpt_pr.gh import create_pr
|
|
6
|
+
from gpt_pr.prdata import get_pr_data
|
|
7
|
+
from gpt_pr import __version__
|
|
8
|
+
from gpt_pr.config import config, config_command_example, CONFIG_README_SECTION
|
|
9
|
+
from gpt_pr import consolecolor as cc
|
|
10
|
+
from gpt_pr.checkversion import check_for_updates
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def run(base_branch="main", origin="origin", yield_confirmation=False, version=False):
    """
    Create Pull Requests from current branch with base branch (default 'main' branch)
    """
    if version:
        print("Current version:", __version__)
        return

    branch_info = get_branch_info(base_branch, origin, yield_confirmation)
    if not branch_info:
        return

    separator = "#########################################"

    # Keep regenerating until the user accepts the proposal (or confirmations
    # are skipped altogether).
    while True:
        pr_data = get_pr_data(branch_info)
        print("")
        print(separator)
        print(pr_data.to_display())
        print(separator)
        print("")

        if yield_confirmation:
            break

        accepted = inquirer.confirm(
            message="Create PR with this? If 'no', let's try again...", default=True
        ).execute()
        if accepted:
            break

        print("Generating another PR data...")

    create_pr(pr_data, yield_confirmation)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def set_config(name, value):
    """Store ``value`` under the upper-cased config ``name`` and save it."""
    key = name.upper()
    config.set_user_config(key, value)
    config.persist()

    print("Config value", cc.bold(key), "set to", cc.yellow(value))
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def get_config(name):
    """Print the current value of config ``name`` (looked up upper-cased)."""
    label = cc.bold(name)
    current = cc.yellow(config.get_user_config(name.upper()))
    print("Config value", label, "=", current)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def reset_config(name):
    """Reset config ``name`` to its default and print the resulting value."""
    key = name.upper()
    config.reset_user_config(key)

    label = cc.bold(name)
    current = cc.yellow(config.get_user_config(key))
    print("Config value", label, "=", current)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def print_config():
    """Print the config file location, usage hints and every stored value."""
    print("Config values at", cc.yellow(config.get_filepath()))
    print("")
    usage = cc.yellow(config_command_example("[config name]", "[value]"))
    print("To set values, just run:", usage)
    print("More about at", cc.yellow(CONFIG_README_SECTION))
    print("")

    previous_section = None
    for section, option, value in config.all_values():
        # Blank line between consecutive sections.
        if section != previous_section:
            print("")
            previous_section = section

        print(f"[{cc.bold(section)}]", option, "=", cc.yellow(value))
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def main():
    """Check for a newer release, then expose ``run`` as a command line via Fire."""
    check_for_updates()
    fire.Fire(run)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def run_config():
    """Check for a newer release, then expose the config sub-commands via Fire."""
    check_for_updates()

    commands = {
        "set": set_config,
        "get": get_config,
        "print": print_config,
        "reset": reset_config,
    }
    fire.Fire(commands)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
if __name__ == "__main__":
|
|
117
|
+
main()
|
gpt_pr/prdata.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
import tiktoken
|
|
6
|
+
from pydantic import BaseModel, Field
|
|
7
|
+
from pydantic_ai import Agent
|
|
8
|
+
from pydantic_ai.models.openai import OpenAIChatModel
|
|
9
|
+
from pydantic_ai.providers.openai import OpenAIProvider
|
|
10
|
+
|
|
11
|
+
from gpt_pr.gitutil import BranchInfo, fetch_nearest_git_dir
|
|
12
|
+
from gpt_pr.config import config, CONFIG_PROJECT_REPO_URL
|
|
13
|
+
import gpt_pr.consolecolor as cc
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Structured output schema for the generated pull request; get_pr_data passes
# it to the Agent as output_type and reads .title / .description from the result.
# NOTE(review): documented with comments rather than a class docstring, since a
# docstring would presumably end up in the schema sent to the model; confirm.
class PRTemplateModel(BaseModel):
|
|
17
|
+
title: str = Field(description="Title of the pull request")
|
|
18
|
+
description: str = Field(description="Description of the pull request")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
TOKENIZER_RATIO = 4
|
|
22
|
+
|
|
23
|
+
DEFAULT_PR_TEMPLATE = (
|
|
24
|
+
"### Ref. [Link]\n\n## What was done?\n[Fill here]\n\n"
|
|
25
|
+
"## How was it done?\n[Fill here]\n\n"
|
|
26
|
+
"## How was it tested?\n[Fill here with test information from diff content or commits]"
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
SYSTEM_PROMPT = '''You are a generator of pull request data based on diff changes.
|
|
30
|
+
|
|
31
|
+
By analyzing the diff content and commit messages between two branches, you must strictly adhere to
|
|
32
|
+
the provided Pull Request template and produce a complete, ready-to-use PR output.
|
|
33
|
+
|
|
34
|
+
Your response must include:
|
|
35
|
+
- A clear and concise PR title.
|
|
36
|
+
- A PR description that:
|
|
37
|
+
- Details the work accomplished.
|
|
38
|
+
- Describes the methodology used, including testing procedures.
|
|
39
|
+
- Lists significant changes in bullet points.
|
|
40
|
+
|
|
41
|
+
Rules:
|
|
42
|
+
- Do not include raw diff content of any size.
|
|
43
|
+
- Do not add any explanations, suggestions, or messages directed to the user.
|
|
44
|
+
|
|
45
|
+
Pull Request Template:
|
|
46
|
+
---
|
|
47
|
+
{pr_template}
|
|
48
|
+
---
|
|
49
|
+
'''
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _get_pr_template():
    """Return the repository's PR template, or the default one.

    Looks in ``<repo root>/.github`` for regular files whose name starts with
    ``pull_request_template`` (case-insensitive) and uses the first one in
    sorted order, so the choice is deterministic. Falls back to
    ``DEFAULT_PR_TEMPLATE`` when no such file exists, it cannot be read, or
    it is empty.
    """
    pr_template = DEFAULT_PR_TEMPLATE
    git_dir = fetch_nearest_git_dir(os.getcwd())
    github_dir = os.path.join(git_dir, ".github")

    try:
        # Only regular files qualify: a ``PULL_REQUEST_TEMPLATE/`` directory
        # matches the prefix too but cannot be opened as a template.
        candidates = sorted(
            name
            for name in os.listdir(github_dir)
            if name.lower().startswith("pull_request_template")
            and os.path.isfile(os.path.join(github_dir, name))
        )
    except OSError:
        # Missing or unreadable .github directory.
        candidates = []

    if not candidates:
        print("PR template not found in .github dir. Using default template.")
        return pr_template

    pr_template_file_path = os.path.join(github_dir, candidates[0])

    try:
        with open(pr_template_file_path, "r", encoding="utf-8") as f:
            local_pr_template = f.read()
    except (OSError, UnicodeDecodeError) as error:
        print(
            "Could not read PR template at:",
            pr_template_file_path,
            f"({error}).",
            "Using default template.",
        )
        return pr_template

    if local_pr_template.strip() != "":
        print("Found PR template at:", pr_template_file_path)
        pr_template = local_pr_template
    else:
        print(
            "Empty PR template at:",
            pr_template_file_path,
            "using default template.",
        )

    return pr_template
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _get_open_ai_key():
    """Return the OpenAI API key from config, falling back to the OPENAI_API_KEY env var.

    Prints setup instructions and exits with status 1 when neither is set.
    """
    key = config.get_user_config("OPENAI_API_KEY") or os.environ.get("OPENAI_API_KEY")
    if key:
        return key

    print(
        'Please set "openai_api_key" config, just run:',
        cc.yellow("gpt-pr-config set openai_api_key [open ai key]"),
    )
    raise SystemExit(1)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _count_tokens(text: str) -> int:
    """Returns the number of tokens in a text string."""
    model_name = config.get_user_config("OPENAI_MODEL")

    try:
        encoder = tiktoken.encoding_for_model(model_name)
    except KeyError:
        # Model without a registered encoding: fall back to cl100k_base.
        encoder = tiktoken.get_encoding("cl100k_base")

    tokens = encoder.encode(text)
    return len(tokens)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@dataclass
class PrData:
    """Generated pull request content plus the branch it belongs to."""

    # Branch/repository information the PR is created for.
    branch_info: BranchInfo
    # Generated PR title.
    title: str
    # Generated PR description, without the optional tool signature.
    body: str

    def to_display(self):
        """Return a colorized, multi-line summary of the PR for the terminal."""
        return "\n".join(
            [
                f"{cc.bold('Repository')}: {cc.yellow(self.branch_info.owner)}/{cc.yellow(self.branch_info.repo)}",
                f"{cc.bold('Title')}: {cc.yellow(self.title)}",
                f"{cc.bold('Branch name')}: {cc.yellow(self.branch_info.branch)}",
                f"{cc.bold('Base branch')}: {cc.yellow(self.branch_info.base_branch)}",
                f"{cc.bold('PR Description')}:\n{self.create_body()}",
            ]
        )

    def create_body(self):
        """Return the PR body, with the GPT-PR signature appended when enabled."""
        body = self.body

        # Compare case-insensitively: a value entered as ``True`` on the
        # command line is stored as "True" and must still enable the signature.
        add_signature = str(config.get_user_config("ADD_TOOL_SIGNATURE")).strip().lower()

        if add_signature == "true":
            pr_signature = f"Generated by [GPT-PR]({CONFIG_PROJECT_REPO_URL})"
            body += "\n\n---\n\n" + pr_signature

        return body
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def get_pr_data(branch_info):
    """Ask the configured OpenAI model for a PR title and description.

    Returns:
        A ``PrData`` holding ``branch_info`` and the stripped title and
        description produced by the model.
    """
    system_prompt, user_prompt = _get_messages(branch_info)

    model_name = config.get_user_config("OPENAI_MODEL")
    provider = OpenAIProvider(api_key=_get_open_ai_key())
    model = OpenAIChatModel(model_name, provider=provider)

    agent = Agent(
        model=model,  # TODO: make configurable for other providers
        output_type=PRTemplateModel,
        instructions=system_prompt,
    )

    print("Generating changes description using OpenAI model", cc.yellow(model_name), '. This may take time...')

    generated = agent.run_sync(user_prompt).output

    return PrData(
        branch_info=branch_info,
        title=generated.title.strip(),
        body=generated.description.strip(),
    )
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def _get_messages(branch_info):
    """Build the system prompt and user message for the PR generation request.

    The user message always contains the commit messages; the diff is added
    only when the total stays within the ``INPUT_MAX_TOKENS`` config.

    Returns:
        A ``(system_prompt, user_message)`` tuple of strings.

    Raises:
        Exception: when the system prompt plus commit messages alone exceed
            ``INPUT_MAX_TOKENS``.
    """
    system_prompt = SYSTEM_PROMPT.format(pr_template=_get_pr_template())

    messages = []

    if len(branch_info.highlight_commits) > 0:
        messages.append("main commits:\n" + "\n".join(branch_info.highlight_commits))
        messages.append("---")
        messages.append("secondary commits:\n" + "\n".join(branch_info.commits))
    else:
        messages.append("git commits:\n" + "\n".join(branch_info.commits))

    joined_messages = "\n".join(messages)
    # Count the formatted prompt that is actually sent (it includes the PR
    # template), not the raw SYSTEM_PROMPT with its placeholder.
    current_total_tokens = _count_tokens(joined_messages) + _count_tokens(system_prompt)

    input_max_tokens = int(config.get_user_config("INPUT_MAX_TOKENS"))

    if current_total_tokens > input_max_tokens:
        exp_message = (
            f"Length of {current_total_tokens} tokens for basic prompt "
            f"(description and commits) is greater than max tokens {input_max_tokens} "
            "(config 'input_max_tokens')"
        )
        raise Exception(exp_message)

    total_tokens_with_diff = current_total_tokens + _count_tokens(branch_info.diff)
    if total_tokens_with_diff > input_max_tokens:
        print_msg = (
            f"Length git changes with diff is too big (total is {total_tokens_with_diff}, "
            f"'input_max_tokens' config is {input_max_tokens})."
        )
        print(print_msg, cc.red("Skipping changes diff content..."))
    else:
        messages.append("Diff changes:\n" + branch_info.diff)

    return system_prompt, '\n'.join(messages)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def _parse_json(content):
|
|
197
|
+
"""
|
|
198
|
+
A bit of a hack to parse the json content from the chat completion
|
|
199
|
+
Sometimes it returns a string with invalid json content (line breaks) that
|
|
200
|
+
makes it hard to parse.
|
|
201
|
+
example:
|
|
202
|
+
|
|
203
|
+
content = '{\n"title": "feat(dependencies): pin dependencies versions",\n"description":
|
|
204
|
+
"### Ref. [Link]\n\n## What was done? ..."\n}'
|
|
205
|
+
"""
|
|
206
|
+
|
|
207
|
+
try:
|
|
208
|
+
content = content.replace('{\n"title":', '{"title":')
|
|
209
|
+
content = content.replace(',\n"description":', ',"description":')
|
|
210
|
+
content = content.replace("\n}", "}")
|
|
211
|
+
content = content.replace("\n", "\\n")
|
|
212
|
+
|
|
213
|
+
return json.loads(content)
|
|
214
|
+
except Exception as e:
|
|
215
|
+
print("Error to decode message:", e)
|
|
216
|
+
print("Content:", content)
|
|
217
|
+
raise e
|