AtCoderStudyBooster 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atcdr/__init__.py +0 -0
- atcdr/download.py +266 -0
- atcdr/generate.py +167 -0
- atcdr/main.py +35 -0
- atcdr/open.py +36 -0
- atcdr/test.py +276 -0
- atcdr/util/__init__.py +0 -0
- atcdr/util/cost.py +120 -0
- atcdr/util/filename.py +137 -0
- atcdr/util/gpt.py +112 -0
- atcdr/util/problem.py +87 -0
- atcoderstudybooster-0.1.0.dist-info/METADATA +94 -0
- atcoderstudybooster-0.1.0.dist-info/RECORD +15 -0
- atcoderstudybooster-0.1.0.dist-info/WHEEL +4 -0
- atcoderstudybooster-0.1.0.dist-info/entry_points.txt +2 -0
atcdr/__init__.py
ADDED
File without changes
|
atcdr/download.py
ADDED
@@ -0,0 +1,266 @@
|
|
1
|
+
import os
|
2
|
+
import re
|
3
|
+
import time
|
4
|
+
from dataclasses import dataclass
|
5
|
+
from enum import Enum
|
6
|
+
from typing import Callable, List, Match, Optional, Union, cast
|
7
|
+
|
8
|
+
import requests
|
9
|
+
|
10
|
+
from atcdr.util.problem import make_problem_markdown
|
11
|
+
|
12
|
+
|
13
|
+
class Diff(Enum):
    """Task slot within an ABC contest, identified by its letter (A-G).

    Each member's value is the upper-case letter itself; the task URL uses the
    lower-cased value as the task-id suffix.
    """

    A = 'A'
    B = 'B'
    C = 'C'
    D = 'D'
    E = 'E'
    F = 'F'
    G = 'G'
|
21
|
+
|
22
|
+
|
23
|
+
@dataclass
class Problem:
    """One ABC task to fetch: contest number plus task letter."""

    # Contest number, e.g. 120 for ABC120.
    number: int
    # Task letter within that contest.
    difficulty: Diff
|
27
|
+
|
28
|
+
|
29
|
+
def get_problem_html(
    problem: Problem,
    retry_attempts: int = 3,
    retry_wait: float = 1.0,
    timeout: float = 10.0,
) -> Optional[str]:
    """Fetch the task page HTML for *problem* from atcoder.jp.

    Args:
        problem: Contest number and task letter to fetch.
        retry_attempts: Maximum number of requests; only HTTP 429 (Too Many
            Requests) responses are retried.
        retry_wait: Seconds to sleep between rate-limited attempts.
        timeout: Per-request timeout in seconds passed to ``requests.get``.

    Returns:
        The response body on HTTP 200, otherwise ``None`` after printing a
        diagnostic message.
    """
    label = f'abc{problem.number} {problem.difficulty.value}'
    url = (
        f'https://atcoder.jp/contests/abc{problem.number}'
        f'/tasks/abc{problem.number}_{problem.difficulty.value.lower()}'
    )

    for attempt in range(1, retry_attempts + 1):
        try:
            # Without a timeout a stalled connection would hang forever.
            response = requests.get(url, timeout=timeout)
        except requests.RequestException as err:
            # A network failure should skip this problem, not abort a batch.
            print(f'[Error] 通信に失敗しました。{label}: {err}')
            return None

        status = response.status_code
        if status == 200:
            return response.text
        if status == 429:
            print(f'[Error{status}] 再試行します. {label}')
            # Sleeping after the final attempt would only delay the failure.
            if attempt < retry_attempts:
                time.sleep(retry_wait)
            continue

        # Every other status is final: retrying cannot change the outcome.
        if 300 <= status < 400:
            print(f'[Error{status}] リダイレクトが発生しました。{label}')
        elif 400 <= status < 500:
            print(f'[Error{status}] 問題が見つかりません。{label}')
        elif 500 <= status < 600:
            print(f'[Error{status}] サーバーエラーが発生しました。{label}')
        else:
            print(
                f'[Error{status}] {label}に対応するHTMLファイルを取得できませんでした。'
            )
        return None

    print(f'[Error429] 再試行回数の上限に達しました。{label}')
    return None
|
64
|
+
|
65
|
+
|
66
|
+
def repair_html(html: str) -> str:
    """Adapt a downloaded task page for offline, Japanese viewing.

    - Protocol-relative image URLs (``//img.atcoder.jp/...``) get an explicit
      ``https:`` scheme so they still resolve when the file is opened locally.
    - The page's language markers are switched from English to Japanese.
    """
    # The look-behind skips URLs that already have a scheme
    # (``https://img.atcoder.jp``); a plain str.replace would turn those into
    # the broken ``https:https://img.atcoder.jp``.
    html = re.sub(r'(?<!:)//img\.atcoder\.jp', 'https://img.atcoder.jp', html)
    html = html.replace(
        '<meta http-equiv="Content-Language" content="en">',
        '<meta http-equiv="Content-Language" content="ja">',
    )
    html = html.replace('LANG = "en"', 'LANG="ja"')
    return html
|
74
|
+
|
75
|
+
|
76
|
+
def get_title_from_html(html: str) -> Optional[str]:
    """Extract a filesystem-safe title from the page's ``<title>`` element.

    For a title of the form ``"A - Problem Name"`` everything up to the first
    ``-`` is dropped. All whitespace and characters invalid in Windows file
    names (``\\ / * ? : " < > |``) are removed.

    Returns:
        The cleaned title, or ``None`` when there is no ``<title>`` or nothing
        usable is left after cleaning.
    """
    # `[^<]*?` keeps the optional "X - " prefix inside the <title> element
    # instead of letting it scan the rest of the document for a '-'.
    title_match = re.search(
        r'<title>(?:[^<]*?-\s*)?([^<]*)</title>', html, re.IGNORECASE
    )
    if title_match is None:
        return None
    # `\s` also strips newlines/tabs, which must never end up in a file name.
    title = re.sub(r'[\\/*?:"<>|\s]', '', title_match.group(1))
    return title or None
|
85
|
+
|
86
|
+
|
87
|
+
def save_file(file_path: str, html: str) -> None:
    """Write *html* to *file_path* as UTF-8 (overwriting) and announce it."""
    with open(file_path, mode='w', encoding='utf-8') as out:
        out.write(html)
    print(f'[+] ファイルを保存しました :{file_path}')
|
91
|
+
|
92
|
+
|
93
|
+
def mkdir(path: str) -> None:
    """Create *path* (with parents) and report it; existing paths are left as-is."""
    try:
        # EAFP instead of exists()+makedirs(): no window in which another
        # process can create the directory and make makedirs raise.
        os.makedirs(path)
    except FileExistsError:
        return
    print(f'[+] フォルダー: {path} を作成しました')
|
97
|
+
|
98
|
+
|
99
|
+
class GenerateMode:
    """Directory layouts for saved problems, usable as ``gene_path`` callbacks."""

    @staticmethod
    def gene_path_on_diff(base: str, number: int, diff: Diff) -> str:
        """Group by task letter first: ``<base>/<letter>/<number>``."""
        letter_dir = os.path.join(base, diff.name)
        return os.path.join(letter_dir, str(number))

    @staticmethod
    def gene_path_on_num(base: str, number: int, diff: Diff) -> str:
        """Group by contest number first: ``<base>/<number>/<letter>``."""
        contest_dir = os.path.join(base, str(number))
        return os.path.join(contest_dir, diff.name)
|
107
|
+
|
108
|
+
|
109
|
+
def generate_problem_directory(
    base_path: str, problems: List[Problem], gene_path: Callable[[str, int, Diff], str]
) -> None:
    """Download each problem and save it as ``<title>.html`` and ``<title>.md``.

    The target directory of every problem is chosen by *gene_path* (see
    GenerateMode). Problems whose page cannot be fetched are skipped; when no
    title can be extracted, ``problem<number><letter>`` is used as file stem.
    """
    for prob in problems:
        target_dir = gene_path(base_path, prob.number, prob.difficulty)

        raw_html = get_problem_html(prob)
        if raw_html is None:
            continue

        stem = get_title_from_html(raw_html)
        if stem is None:
            print('[Error] タイトルが取得できませんでした')
            stem = f'problem{prob.number}{prob.difficulty.value}'

        mkdir(target_dir)
        # The saved HTML is the repaired copy; the Markdown is built from the
        # page as downloaded.
        save_file(os.path.join(target_dir, f'{stem}.html'), repair_html(raw_html))
        markdown = make_problem_markdown(raw_html, 'ja')
        save_file(os.path.join(target_dir, f'{stem}.md'), markdown)
|
131
|
+
|
132
|
+
|
133
|
+
def parse_range(range_str: str) -> List[int]:
    """Expand an inclusive ``"start..end"`` string into a list of integers.

    >>> parse_range('3..5')
    [3, 4, 5]

    Raises:
        ValueError: if *range_str* is not ``<digits>..<digits>`` or if start
            is greater than end.
    """
    match = re.fullmatch(r'(\d+)\.\.(\d+)', range_str)
    if match is None:
        raise ValueError('Invalid range format')
    start, end = map(int, match.groups())
    if start > end:
        # A reversed range used to expand to [] and silently do nothing.
        raise ValueError(f'Invalid range: start ({start}) is greater than end ({end})')
    return list(range(start, end + 1))
|
140
|
+
|
141
|
+
|
142
|
+
def parse_diff_range(range_str: str) -> List[Diff]:
    """Expand an inclusive ``"X..Y"`` task-letter range into Diff members.

    Valid letters are taken from the Diff enum itself, so every defined task
    letter (A-G) can be used as an endpoint.

    Raises:
        ValueError: if the string is malformed, names an unknown letter, or
            its start comes after its end.
    """
    match = re.fullmatch(r'([A-Z])\.\.([A-Z])', range_str)
    if match:
        letters = [diff.value for diff in Diff]
        start, end = match.groups()
        if start in letters and end in letters:
            start_index = letters.index(start)
            end_index = letters.index(end)
            if start_index <= end_index:
                return [Diff(letter) for letter in letters[start_index : end_index + 1]]
    raise ValueError('A..C の形式になっていません')


def convert_arg(arg: Union[str, int]) -> Union[List[int], List[Diff]]:
    """Interpret one CLI argument as contest numbers or task letters.

    Accepted forms: an int or decimal string (``"120"``), a single task letter
    (``"A"``), an inclusive number range (``"120..130"``), or an inclusive
    letter range (``"A..C"``). Letters are case-insensitive.

    Raises:
        ValueError: if the argument matches none of the accepted forms.
    """
    if isinstance(arg, int):
        return [arg]
    if isinstance(arg, str):
        # isdecimal (unlike isdigit) rejects characters such as '²' that
        # int() cannot parse.
        if arg.isdecimal():
            return [int(arg)]
        upper = arg.upper()
        if upper in Diff.__members__:
            return [Diff[upper]]
        if re.fullmatch(r'\d+\.\.\d+', arg):
            return parse_range(arg)
        if re.fullmatch(r'[A-Z]\.\.[A-Z]', upper):
            return parse_diff_range(upper)
    raise ValueError(f'{arg}は認識できません')
|
166
|
+
|
167
|
+
|
168
|
+
def are_all_integers(args: Union[List[int], List[Diff]]) -> bool:
    """Return True if every element of *args* is an int (True for an empty list)."""
    for item in args:
        if not isinstance(item, int):
            return False
    return True
|
170
|
+
|
171
|
+
|
172
|
+
def are_all_diffs(args: Union[List[int], List[Diff]]) -> bool:
    """Return True if every element of *args* is a Diff (True for an empty list)."""
    for item in args:
        if not isinstance(item, Diff):
            return False
    return True
|
174
|
+
|
175
|
+
|
176
|
+
def download(
    first: Union[str, int, None] = None,
    second: Union[str, int, None] = None,
    base_path: str = '.',
) -> None:
    """CLI entry point: download AtCoder ABC problems.

    With no arguments the interactive menu (main) is started. Otherwise the
    arguments are parsed by convert_arg and combined as:

    * numbers first, letters second (or omitted = every task letter): files
      are saved under ``base_path/<number>/<letter>``;
    * letters first, numbers second: files are saved under
      ``base_path/<letter>/<number>``.

    Raises:
        ValueError: if the arguments cannot be combined into a problem list.
    """
    if first is None:
        main()
        return

    first_args = convert_arg(str(first))
    if second is None:
        # `first` arrives from the CLI as str/int and is never a Diff itself,
        # so the "letters only" check must look at the converted values.
        if first_args and are_all_diffs(first_args):
            raise ValueError(
                """難易度だけでなく, 問題番号も指定してコマンドを実行してください.
例 atcdr d A 120 : A問題の120をダウンロードします
例 atcdr d A 120..130 : A問題の120から130をダウンロードします
"""
            )
        second_args: Union[List[int], List[Diff]] = list(Diff)
    else:
        second_args = convert_arg(str(second))

    if are_all_integers(first_args) and are_all_diffs(second_args):
        numbers = cast(List[int], first_args)
        diffs = cast(List[Diff], second_args)
        problems = [Problem(number, diff) for number in numbers for diff in diffs]
        generate_problem_directory(base_path, problems, GenerateMode.gene_path_on_num)
    elif are_all_diffs(first_args) and are_all_integers(second_args):
        diffs = cast(List[Diff], first_args)
        numbers = cast(List[int], second_args)
        problems = [Problem(number, diff) for diff in diffs for number in numbers]
        generate_problem_directory(base_path, problems, GenerateMode.gene_path_on_diff)
    else:
        raise ValueError(
            """次のような形式で問題を指定してください
例 atcdr d A 120..130 : A問題の120から130をダウンロードします
例 atcdr d 120 : ABCのコンテストの問題をダウンロードします
"""
        )
|
223
|
+
|
224
|
+
|
225
|
+
def main() -> None:
    """Interactive menu: download one problem or a range into the current directory.

    Files are laid out as ``./<letter>/<number>``. Invalid input is reported
    instead of raising.
    """
    print('AtCoderの問題のHTMLファイルをダウンロードします')
    print(
        """
1. 番号の範囲を指定してダウンロードする
2. 1ファイルだけダウンロードする
q: 終了
"""
    )

    choice = input('選択してください: ').strip()
    if choice == 'q':
        print('終了します')
        return
    if choice not in ('1', '2'):
        print('無効な選択です')
        return

    try:
        if choice == '1':
            start_end = input(
                '開始と終了のコンテストの番号をスペースで区切って指定してください (例: 223 230): '
            )
            # split() without an argument tolerates repeated/outer whitespace.
            start, end = map(int, start_end.split())
            numbers = list(range(start, end + 1))
        else:
            numbers = [int(input('コンテストの番号を指定してください: '))]
        difficulty = Diff[
            input('ダウンロードする問題の難易度を指定してください (例: A, B, C): ')
            .strip()
            .upper()
        ]
    except (ValueError, KeyError):
        # ValueError: non-numeric input or wrong number of values;
        # KeyError: unknown task letter passed to Diff[...].
        print('無効な入力です')
        return

    problems = [Problem(number, difficulty) for number in numbers]
    generate_problem_directory('.', problems, GenerateMode.gene_path_on_diff)
|
263
|
+
|
264
|
+
|
265
|
+
if __name__ == '__main__':
    # Executing the module directly opens the interactive download menu.
    main()
|
atcdr/generate.py
ADDED
@@ -0,0 +1,167 @@
|
|
1
|
+
import json
|
2
|
+
import os
|
3
|
+
import re
|
4
|
+
|
5
|
+
from atcdr.test import (
|
6
|
+
ResultStatus,
|
7
|
+
create_testcases_from_html,
|
8
|
+
judge_code_from,
|
9
|
+
render_result,
|
10
|
+
)
|
11
|
+
from atcdr.util.filename import (
|
12
|
+
FILE_EXTENSIONS,
|
13
|
+
Filename,
|
14
|
+
Lang,
|
15
|
+
execute_files,
|
16
|
+
lang2str,
|
17
|
+
str2lang,
|
18
|
+
)
|
19
|
+
from atcdr.util.gpt import ChatGPT, set_api_key
|
20
|
+
from atcdr.util.problem import make_problem_markdown
|
21
|
+
|
22
|
+
|
23
|
+
def get_code_from_gpt_output(output: str) -> str:
    """Return the body of the first fenced code block in *output*, or ''.

    The whole info string after the opening fence (``python``, ``c++``,
    ``c#`` ...) is skipped. Surrounding whitespace of the body is stripped.
    """
    # `[^\n`]*\n` consumes the rest of the fence line; the previous `\w+`
    # stopped at the first non-word character, so "```c++" leaked "++" into
    # the extracted code.
    pattern = re.compile(r'```[^\n`]*\n\s*(.*?)\s*```', re.DOTALL)
    match = pattern.search(output)
    return match.group(1) if match else ''
|
27
|
+
|
28
|
+
|
29
|
+
def generate_code(file: Filename, lang: Lang) -> None:
    """Ask ChatGPT for a solution to the problem in *file* and save it untested.

    The problem HTML is converted to English Markdown and sent with a system
    prompt asking for a solution in *lang*. The first fenced code block of
    the reply is saved next to *file* as ``<stem>_by_<model><ext>``.
    """
    # The downloader writes HTML as UTF-8; don't rely on the locale encoding.
    with open(file, 'r', encoding='utf-8') as f:
        html_content = f.read()
    md = make_problem_markdown(html_content, 'en')

    if set_api_key() is None:
        return
    gpt = ChatGPT(
        system_prompt=f"""You are an excellent programmer. You solve problems in competitive programming.When a user provides you with a problem from a programming contest called AtCoder, including the Problem,Constraints, Input, Output, Input Example, and Output Example, please carefully consider these and solve the problem.Make sure that your output code block contains no more than two blocks. Pay close attention to the Input, Input Example, Output, and Output Example.Create the solution in {lang2str(lang)}.""",
    )

    reply = gpt.tell(md)
    code = get_code_from_gpt_output(reply)
    print(f'AI利用にかかったAPIコスト: {gpt.sum_cost}')

    if not code:
        # Writing an empty source file would only look like a broken solution.
        print('[Error]: 出力からコードブロックを取得できませんでした')
        return

    saved_filename = (
        os.path.splitext(file)[0] + f'_by_{gpt.model.value}' + FILE_EXTENSIONS[lang]
    )
    with open(saved_filename, 'w', encoding='utf-8') as f:
        print(f'[+]:{gpt.model.value}の出力したコードを保存しました:{f.name}')
        f.write(code)
|
50
|
+
|
51
|
+
|
52
|
+
def generate_template(file: Filename, lang: Lang) -> None:
    """Ask ChatGPT for an input-handling template for the problem in *file*.

    The template is saved next to *file* as ``<stem><ext>``. If that file
    already exists it is left untouched and no API request is made, so a
    solution the user is working on is never overwritten.
    """
    saved_filename = os.path.splitext(file)[0] + FILE_EXTENSIONS[lang]
    if os.path.exists(saved_filename):
        print(f'[Error]: {saved_filename} は既に存在するため、テンプレートを作成しませんでした.')
        return

    # The downloader writes HTML as UTF-8; don't rely on the locale encoding.
    with open(file, 'r', encoding='utf-8') as f:
        html_content = f.read()
    md = make_problem_markdown(html_content, 'en')

    if set_api_key() is None:
        return
    gpt = ChatGPT(
        system_prompt='You are a highly skilled programmer. Your role is to create a template code for competitive programming.',
        temperature=0.0,
    )

    prompt = f"""
The user will provide a problem from a programming contest called AtCoder. This problem will include the Problem Statement, Constraints, Input, Output, Input Example, and Output Example. You should focus on the Constraints and Input sections to create the template in {lang2str(lang)}.

- First, create the part of the code that handles input. Then, you should read ###Input Block and ###Constraints Block.
- After receiving the input, define variables in the program by reading ###Constraints Block and explain how to use the variables in the comment of your code block with example.
- Last, define variables needed for output. Then you should read ###Output Block and ###Constraints Block.

You must not solve the problem. Please faithfully reproduce the variable names defined in the problem.
"""
    reply = gpt.tell(md + prompt)
    code = get_code_from_gpt_output(reply)
    print(f'AI利用にかかったAPIコスト:{gpt.sum_cost}')

    with open(saved_filename, 'w', encoding='utf-8') as f:
        print(f'[+]:テンプレートファイル{saved_filename}を作成しました.')
        f.write(code)
|
81
|
+
|
82
|
+
|
83
|
+
def _sibling_path(file: Filename, prefix: str, suffix: str) -> str:
    """Return ``<dir of file>/<prefix><stem of file><suffix>``.

    The prefix is applied to the base name only, so the result stays in the
    same directory as *file* even when *file* contains directory components.
    """
    directory, stem = os.path.split(os.path.splitext(file)[0])
    return os.path.join(directory, prefix + stem + suffix)


def solve_problem(file: Filename, lang: Lang, max_attempts: int = 3) -> None:
    """Let ChatGPT solve the problem in *file*, iterating against its samples.

    Attempt *i* is saved next to *file* as ``<i>_<stem>_by_<model><ext>`` and
    judged on the sample cases parsed from the HTML. On failure the test
    report is sent back and a corrected version requested, for at most
    *max_attempts* attempts. The conversation is written to
    ``log_<stem>_by_<model>.json`` beside *file*.
    """
    # The downloader writes HTML as UTF-8; don't rely on the locale encoding.
    with open(file, 'r', encoding='utf-8') as f:
        html_content = f.read()
    md = make_problem_markdown(html_content, 'en')
    labeled_cases = create_testcases_from_html(html_content)

    if set_api_key() is None:
        return
    gpt = ChatGPT(
        system_prompt=f"""You are a brilliant programmer. Your task is to solve an AtCoder problem. AtCoder is a platform that hosts programming competitions where participants write programs to solve algorithmic challenges.Please solve the problem in {lang2str(lang)}.""",
    )
    model_suffix = f'_by_{gpt.model.value}'

    reply = gpt.tell(md)

    for i in range(1, max_attempts + 1):
        code = get_code_from_gpt_output(reply)

        saved_filename = _sibling_path(file, f'{i}_', model_suffix + FILE_EXTENSIONS[lang])
        with open(saved_filename, 'w', encoding='utf-8') as f:
            print(f'[+]:{gpt.model.value}の出力したコードを保存しました:{f.name}')
            f.write(code)

        labeled_results = judge_code_from(labeled_cases, saved_filename)
        test_report = '\n'.join(render_result(lresult) for lresult in labeled_results)

        print(f'{i}回目のコード生成でのテスト結果:---')
        print(test_report)

        if all(
            labeled_result.result.passed == ResultStatus.AC
            for labeled_result in labeled_results
        ):
            print('コードのテストに成功!')
            break

        # After the final attempt a corrected version would never be saved
        # or tested - requesting it would only add API cost.
        if i < max_attempts:
            # Strip ANSI colour codes: they mean nothing to the model.
            plain_report = re.sub(r'\x1b\[[0-9;]*m', '', test_report)
            reply = gpt.tell(f"""The following is the test report for the code you provided:
{plain_report}
Please provide an updated version of the code in {lang2str(lang)}.""")

    log_filename = _sibling_path(file, 'log_', model_suffix + FILE_EXTENSIONS[Lang.JSON])
    with open(log_filename, 'w', encoding='utf-8') as f:
        print(f'[+]:{gpt.model.value}の出力のログを保存しました:{f.name}')
        # ensure_ascii=False keeps Japanese text readable in the log.
        f.write(json.dumps(gpt.messages, indent=2, ensure_ascii=False))
    print(f'AI利用にかかったAPIコスト:{gpt.sum_cost}')
|
140
|
+
|
141
|
+
|
142
|
+
def generate(
    *source: str,
    lang: str = 'Python',
    without_test: bool = False,
    template: bool = False,
) -> None:
    """CLI entry point: produce code with ChatGPT for each selected problem HTML.

    Mode (first match wins): ``template`` creates an input-handling template;
    ``without_test`` saves a single untested solution; otherwise a solution is
    generated and re-requested against the sample tests (solve_problem).
    """
    target_lang = str2lang(lang)

    if template:
        worker = generate_template
    elif without_test:
        worker = generate_code
    else:
        worker = solve_problem

    execute_files(
        *source,
        func=lambda file: worker(file, target_lang),
        target_filetypes=[Lang.HTML],
    )
|
atcdr/main.py
ADDED
@@ -0,0 +1,35 @@
|
|
1
|
+
from importlib.metadata import metadata
|
2
|
+
|
3
|
+
import fire # type: ignore
|
4
|
+
|
5
|
+
from atcdr.download import download
|
6
|
+
from atcdr.generate import generate
|
7
|
+
from atcdr.open import open_files
|
8
|
+
from atcdr.test import test
|
9
|
+
|
10
|
+
|
11
|
+
def get_version() -> None:
    """Print the installed distribution name and version.

    When no installed distribution metadata exists (e.g. running from an
    uninstalled source checkout), a notice is printed instead of a traceback.
    """
    from importlib.metadata import PackageNotFoundError

    try:
        meta = metadata('AtCoderStudyBooster')
    except PackageNotFoundError:
        print('AtCoderStudyBooster (version unknown: package metadata not found)')
        return
    print(meta['Name'], meta['Version'])
|
14
|
+
|
15
|
+
|
16
|
+
# Sub-command name -> handler passed to python-fire. Every command is also
# reachable through a short alias.
MAP_COMMANDS: dict = {
    alias: handler
    for handler, aliases in (
        (test, ('test', 't')),
        (download, ('download', 'd')),
        (open_files, ('open', 'o')),
        (generate, ('generate', 'g')),
        (get_version, ('--version', '-v')),
    )
    for alias in aliases
}
|
28
|
+
|
29
|
+
|
30
|
+
def main():
    """Console entry point: dispatch the command line via python-fire."""
    commands = MAP_COMMANDS
    fire.Fire(commands)
|
32
|
+
|
33
|
+
|
34
|
+
if __name__ == '__main__':
    # Allow running the CLI by executing this module directly.
    main()
|
atcdr/open.py
ADDED
@@ -0,0 +1,36 @@
|
|
1
|
+
import webbrowser
|
2
|
+
|
3
|
+
from bs4 import BeautifulSoup as bs
|
4
|
+
from bs4.element import Tag
|
5
|
+
|
6
|
+
from atcdr.util.filename import Lang, execute_files
|
7
|
+
|
8
|
+
|
9
|
+
def find_link_from(html: str) -> str | None:
    """Return the page URL from the ``og:url`` meta tag of *html*, or None."""
    tag = bs(html, 'html.parser').find('meta', property='og:url')
    if not isinstance(tag, Tag) or 'content' not in tag.attrs:
        return None
    url = tag['content']
    # BeautifulSoup may hand back an attribute as a list; use its first item.
    return url[0] if isinstance(url, list) else url
|
18
|
+
|
19
|
+
|
20
|
+
def open_html(file: str) -> None:
    """Open the original AtCoder page of a saved problem HTML in the browser.

    The URL is taken from the ``og:url`` meta tag. Missing or unreadable
    files and pages without a URL are reported on stdout, not raised.
    """
    try:
        # The downloader saves pages as UTF-8; the locale default (e.g. cp932
        # on Japanese Windows) could fail to decode them.
        with open(file, 'r', encoding='utf-8') as f:
            html_content = f.read()
    except FileNotFoundError:
        print(f"HTMLファイル '{file}' が見つかりません。")
        return
    except (OSError, UnicodeDecodeError) as err:
        print(f"HTMLファイル '{file}' を読み込めませんでした: {err}")
        return

    url = find_link_from(html_content)
    if url:
        webbrowser.open(url)
    else:
        print('URLが見つかりませんでした。')
|
33
|
+
|
34
|
+
|
35
|
+
def open_files(*args: str) -> None:
    """CLI entry point: open the AtCoder page of every selected HTML file."""
    html_only = [Lang.HTML]
    execute_files(*args, func=open_html, target_filetypes=html_only)
|
atcdr/test.py
ADDED
@@ -0,0 +1,276 @@
|
|
1
|
+
import os
|
2
|
+
import subprocess
|
3
|
+
import tempfile
|
4
|
+
import time
|
5
|
+
from dataclasses import dataclass
|
6
|
+
from enum import Enum
|
7
|
+
from typing import Callable, Dict, List, Optional, Union
|
8
|
+
|
9
|
+
import colorama
|
10
|
+
from bs4 import BeautifulSoup as bs
|
11
|
+
from colorama import Fore
|
12
|
+
|
13
|
+
from atcdr.util.filename import FILE_EXTENSIONS, SOURCE_LANGUAGES, Lang, execute_files
|
14
|
+
|
15
|
+
# Enable ANSI colours (also on Windows) and reset styling after each print,
# so coloured test reports never bleed into later output.
colorama.init(autoreset=True)
|
16
|
+
|
17
|
+
|
18
|
+
@dataclass
class TestCase:
    """One sample: stdin for the program and the stdout it should produce."""

    # Text fed to the program's standard input.
    input: str
    # Expected standard output (run_code compares it after strip()).
    output: str
|
22
|
+
|
23
|
+
|
24
|
+
@dataclass
class LabeledTestCase:
    """A TestCase together with a display label such as ``"Sample 1"``."""

    # Human-readable name shown in reports.
    label: str
    # The input/expected-output pair.
    case: TestCase
|
28
|
+
|
29
|
+
|
30
|
+
class ResultStatus(Enum):
    """Judge verdicts, using AtCoder's abbreviations as member names."""

    CE = 'Compilation Error'
    MLE = 'Memory Limit Exceeded'
    TLE = 'Time Limit Exceeded'
    RE = 'Runtime Error'
    WA = 'Wrong Answer'
    AC = 'Accepted'
|
37
|
+
|
38
|
+
|
39
|
+
@dataclass
class TestCaseResult:
    """Outcome of running a program on one test case."""

    # Stripped stdout for AC/WA; otherwise stderr or an error description.
    output: str
    # Run time in milliseconds; None when the run did not complete normally.
    executed_time: Union[int, None]
    # NOTE: memory usage is not measured yet (a `memory_usage` field is planned).
    passed: ResultStatus
|
45
|
+
|
46
|
+
|
47
|
+
@dataclass
class LabeledTestCaseResult:
    """A judged sample: its label, the case itself and the run outcome."""

    label: str
    testcase: TestCase
    # TODO: labels may be unnecessary - we could drop the concept and simply
    # name the cases Test1, Test2, ... when printing.
    result: TestCaseResult
|
53
|
+
|
54
|
+
|
55
|
+
def create_testcases_from_html(html: str) -> List[LabeledTestCase]:
    """Parse the sample cases out of an AtCoder task page.

    Looks for ``<h3>Sample Input i</h3>`` / ``<h3>Sample Output i</h3>``
    headings for i = 1, 2, ... and takes the first ``<pre>`` after each one
    (stripped). Parsing stops at the first index for which either heading is
    missing, so any number of samples is supported.
    """
    soup = bs(html, 'html.parser')
    test_cases: List[LabeledTestCase] = []

    i = 1
    while True:
        # `string=` is the current name of bs4's deprecated `text=` filter.
        input_heading = soup.find('h3', string=f'Sample Input {i}')
        output_heading = soup.find('h3', string=f'Sample Output {i}')
        if not input_heading or not output_heading:
            break

        input_pre = input_heading.find_next('pre')
        output_pre = output_heading.find_next('pre')
        test_case = TestCase(
            input=input_pre.get_text(strip=True) if input_pre is not None else '',
            output=output_pre.get_text(strip=True) if output_pre is not None else '',
        )
        test_cases.append(LabeledTestCase(label=f'Sample {i}', case=test_case))
        i += 1

    return test_cases
|
84
|
+
|
85
|
+
|
86
|
+
def run_code(cmd: list, case: TestCase, timeout: float = 4) -> TestCaseResult:
    """Run *cmd* with the case input on stdin and judge its stdout.

    Args:
        cmd: Argument list to execute (no shell).
        case: Input to feed and output to expect.
        timeout: Wall-clock limit in seconds before the run counts as TLE.

    Returns:
        TLE on timeout; RE (exception text or stderr) if the process cannot
        be run or exits non-zero; otherwise AC or WA by comparing stripped
        stdout with the stripped expected output. Times are in milliseconds.
    """
    try:
        # perf_counter is monotonic; time.time() can jump if the clock changes.
        start_time = time.perf_counter()
        proc = subprocess.run(
            cmd, input=case.input, text=True, capture_output=True, timeout=timeout
        )
        execution_time = int((time.perf_counter() - start_time) * 1000)
    except subprocess.TimeoutExpired:
        return TestCaseResult(
            output='Time Limit Exceeded', executed_time=None, passed=ResultStatus.TLE
        )
    except Exception as e:
        # Launch or output-decoding failures (e.g. interpreter not on PATH)
        # are reported as a runtime error instead of aborting the whole run.
        return TestCaseResult(output=str(e), executed_time=None, passed=ResultStatus.RE)

    if proc.returncode != 0:
        return TestCaseResult(
            output=proc.stderr, executed_time=None, passed=ResultStatus.RE
        )

    actual_output = proc.stdout.strip()
    status = (
        ResultStatus.AC if actual_output == case.output.strip() else ResultStatus.WA
    )
    return TestCaseResult(
        output=actual_output, executed_time=execution_time, passed=status
    )
|
120
|
+
|
121
|
+
|
122
|
+
def run_python(path: str, case: TestCase) -> TestCaseResult:
    """Judge a Python source by running it with the ``python3`` on PATH."""
    command = ['python3', path]
    return run_code(command, case)
|
124
|
+
|
125
|
+
|
126
|
+
def run_javascript(path: str, case: TestCase) -> TestCaseResult:
    """Judge a JavaScript source by running it with ``node``."""
    command = ['node', path]
    return run_code(command, case)
|
128
|
+
|
129
|
+
|
130
|
+
def _run_compiled(compiler: List[str], path: str, case: TestCase) -> TestCaseResult:
    """Compile *path* via ``compiler + [path, '-o', exe]``, then run it on *case*.

    The executable is written into a private temporary directory that is
    removed afterwards. The previous approach compiled into an still-open
    NamedTemporaryFile, which cannot be reopened by another process on
    Windows. Compile failures (or a missing compiler) are reported as CE;
    compiler warnings are printed.
    """
    with tempfile.TemporaryDirectory() as build_dir:
        # On Windows the binary needs an .exe name (gcc/g++ would append it).
        exe_name = 'main.exe' if os.name == 'nt' else 'main'
        exec_path = os.path.join(build_dir, exe_name)
        try:
            compile_result = subprocess.run(
                [*compiler, path, '-o', exec_path], capture_output=True, text=True
            )
        except FileNotFoundError:
            return TestCaseResult(
                output=f'コンパイラー {compiler[0]} が見つかりません。',
                executed_time=None,
                passed=ResultStatus.CE,
            )
        if compile_result.returncode != 0:
            return TestCaseResult(
                output=compile_result.stderr, executed_time=None, passed=ResultStatus.CE
            )
        if compile_result.stderr:
            print(f'コンパイラーからのメッセージ\n{compile_result.stderr}')
        return run_code([exec_path], case)


def run_c(path: str, case: TestCase) -> TestCaseResult:
    """Compile *path* with gcc and judge it on *case*."""
    return _run_compiled(['gcc'], path, case)


def run_cpp(path: str, case: TestCase) -> TestCaseResult:
    """Compile *path* with g++ and judge it on *case*."""
    return _run_compiled(['g++'], path, case)


def run_rust(path: str, case: TestCase) -> TestCaseResult:
    """Compile *path* with rustc and judge it on *case*."""
    return _run_compiled(['rustc'], path, case)
|
173
|
+
|
174
|
+
|
175
|
+
def run_java(path: str, case: TestCase) -> TestCaseResult:
    """Compile *path* with javac and judge it with java on *case*.

    Classes are compiled into a temporary directory (``javac -d``) that is
    used as the class path. Sources outside the current directory therefore
    work, and no ``.class`` files (including nested classes) are left behind.
    The class to run is the file stem, which javac requires to match a
    public top-level class.
    """
    with tempfile.TemporaryDirectory() as class_dir:
        try:
            compile_result = subprocess.run(
                ['javac', '-d', class_dir, path], capture_output=True, text=True
            )
        except FileNotFoundError:
            return TestCaseResult(
                output='コンパイラー javac が見つかりません。',
                executed_time=None,
                passed=ResultStatus.CE,
            )
        if compile_result.returncode != 0:
            return TestCaseResult(
                output=compile_result.stderr, executed_time=None, passed=ResultStatus.CE
            )
        class_name = os.path.splitext(os.path.basename(path))[0]
        return run_code(['java', '-cp', class_dir, class_name], case)
|
188
|
+
|
189
|
+
|
190
|
+
# Source languages that can be judged locally, mapped to their runner.
# Languages absent here (e.g. Go, Ruby) make choose_lang() return None.
LANGUAGE_RUNNERS: Dict[Lang, Callable[[str, TestCase], TestCaseResult]] = dict(
    [
        (Lang.PYTHON, run_python),
        (Lang.JAVASCRIPT, run_javascript),
        (Lang.C, run_c),
        (Lang.CPP, run_cpp),
        (Lang.RUST, run_rust),
        (Lang.JAVA, run_java),
    ]
)
|
198
|
+
|
199
|
+
|
200
|
+
def choose_lang(path: str) -> Optional[Callable[[str, TestCase], TestCaseResult]]:
    """Return the runner for *path*'s file extension, or None if unsupported."""
    suffix = os.path.splitext(path)[1]
    for language, extension in FILE_EXTENSIONS.items():
        if extension == suffix:
            return LANGUAGE_RUNNERS.get(language)
    return None
|
209
|
+
|
210
|
+
|
211
|
+
def judge_code_from(
    lcases: List[LabeledTestCase], path: str
) -> List[LabeledTestCaseResult]:
    """Run the source at *path* on every labelled case, in order.

    Raises:
        ValueError: if no runner exists for the file's extension.
    """
    runner = choose_lang(path)
    if runner is None:
        raise ValueError(f'ランナーが見つかりませんでした。指定されたパス: {path}')

    results: List[LabeledTestCaseResult] = []
    for lcase in lcases:
        outcome = runner(path, lcase.case)
        results.append(LabeledTestCaseResult(lcase.label, lcase.case, outcome))
    return results
|
222
|
+
|
223
|
+
|
224
|
+
# Result glyphs used in test reports.
CHECK_MARK = '✓'  # U+2713 CHECK MARK
CROSS_MARK = '×'  # U+00D7 MULTIPLICATION SIGN
|
226
|
+
|
227
|
+
|
228
|
+
def render_result(lresult: LabeledTestCaseResult) -> str:
    """Format one judged sample as a coloured, human-readable report.

    Every status section ends with exactly one newline, so reports joined
    with '\\n' stay separated even when a RE/CE message had no trailing
    newline.
    """
    result = lresult.result
    testcase = lresult.testcase
    status = result.passed

    if status == ResultStatus.AC:
        body = Fore.GREEN + f'{CHECK_MARK} Accepted !! Time: {result.executed_time} ms\n'
    elif status == ResultStatus.WA:
        body = (
            Fore.RED
            + f'{CROSS_MARK} Wrong Answer ! Time: {result.executed_time} ms\nOutput:\n{result.output}\nExpected Output:\n{testcase.output}\n'
        )
    elif status == ResultStatus.RE:
        detail = result.output.rstrip('\n')
        body = Fore.YELLOW + f'[RE] Runtime Error\n Output:\n{detail}\n'
    elif status == ResultStatus.TLE:
        body = Fore.YELLOW + '[TLE] Time Limit Exceeded\n'
    elif status == ResultStatus.CE:
        detail = result.output.rstrip('\n')
        body = Fore.YELLOW + f'[CE] Compile Error\n Output:\n{detail}\n'
    elif status == ResultStatus.MLE:
        # Label matches the ResultStatus member name, like the other sections.
        body = Fore.YELLOW + '[MLE] Memory Limit Exceeded\n'
    else:
        body = ''

    return f'{Fore.CYAN}{lresult.label} of Test:\n' + body + Fore.RESET
|
254
|
+
|
255
|
+
|
256
|
+
def run_test(path_of_code: str) -> None:
    """Judge *path_of_code* against the samples of the problem HTML beside it.

    The HTML is looked up in the directory containing the source file, not
    the current working directory. If several HTML files are present, one
    named after the source's stem is preferred; otherwise the alphabetically
    first is used, so the choice is deterministic.
    """
    code_dir = os.path.dirname(path_of_code) or '.'
    html_names = sorted(f for f in os.listdir(code_dir) if f.endswith('.html'))
    if not html_names:
        print(
            '問題のファイルが見つかりません。\n問題のファイルが存在するディレクトリーに移動してから実行してください。'
        )
        return

    preferred = os.path.splitext(os.path.basename(path_of_code))[0] + '.html'
    html_name = preferred if preferred in html_names else html_names[0]

    # The downloader saves HTML as UTF-8; don't rely on the locale encoding.
    with open(os.path.join(code_dir, html_name), 'r', encoding='utf-8') as file:
        html = file.read()

    test_cases = create_testcases_from_html(html)
    print(f'{path_of_code}をテストします。\n' + '-' * 20 + '\n')
    test_results = judge_code_from(test_cases, path_of_code)
    print('\n'.join(render_result(lresult) for lresult in test_results))
|
273
|
+
|
274
|
+
|
275
|
+
def test(*args: str) -> None:
    """CLI entry point: judge every selected source file against its samples."""
    source_types = SOURCE_LANGUAGES
    execute_files(*args, func=run_test, target_filetypes=source_types)
|
atcdr/util/__init__.py
ADDED
File without changes
|
atcdr/util/cost.py
ADDED
@@ -0,0 +1,120 @@
|
|
1
|
+
from enum import Enum
|
2
|
+
from typing import Optional
|
3
|
+
|
4
|
+
import tiktoken
|
5
|
+
import yfinance as yf # type: ignore
|
6
|
+
|
7
|
+
|
8
|
+
class Model(Enum):
    """OpenAI chat models with known prices; values are the API model ids."""

    GPT4O = 'gpt-4o'
    GPT4O_MINI = 'gpt-4o-mini'
|
11
|
+
|
12
|
+
|
13
|
+
class CostType(Enum):
    """Whether tokens are billed as prompt (input) or completion (output)."""

    INPUT = 'input'
    OUTPUT = 'output'
|
16
|
+
|
17
|
+
|
18
|
+
class Currency:
    """A monetary amount held in both USD and JPY.

    USD is canonical: arithmetic and comparisons use ``usd``, and JPY is
    derived with the USD/JPY rate. The rate is fetched from Yahoo Finance
    (via yfinance) once per process and then reused.
    """

    # Process-wide USD/JPY rate; None until first requested.
    _rate_cache: Optional[float] = None

    def __init__(
        self, usd: Optional[float] = None, jpy: Optional[float] = None
    ) -> None:
        self._exchange_rate = self.get_exchange_rate()
        self._usd = usd if usd is not None else 0.0
        self._jpy = jpy if jpy is not None else self.convert_usd_to_jpy(self._usd)

    @staticmethod
    def get_exchange_rate() -> float:
        """Return the latest USD/JPY close, fetching it only on first use.

        Every instance (including each arithmetic result) needs the rate, so
        it is cached rather than re-fetched over the network per object. If
        the lookup fails or returns no data, NaN is cached: JPY amounts then
        show as ``nan`` instead of crashing the caller.
        """
        if Currency._rate_cache is None:
            try:
                history = yf.Ticker('USDJPY=X').history(period='1d')
                Currency._rate_cache = float(history['Close'].iloc[-1])
            except Exception:  # best effort: offline, empty frame, API change
                Currency._rate_cache = float('nan')
        return Currency._rate_cache

    def convert_usd_to_jpy(self, usd: float) -> float:
        return usd * self._exchange_rate

    def convert_jpy_to_usd(self, jpy: float) -> float:
        return jpy / self._exchange_rate

    @property
    def usd(self) -> float:
        return self._usd

    @usd.setter
    def usd(self, value: float) -> None:
        self._usd = value
        self._jpy = self.convert_usd_to_jpy(value)

    @property
    def jpy(self) -> float:
        return self._jpy

    @jpy.setter
    def jpy(self, value: float) -> None:
        self._jpy = value
        self._usd = self.convert_jpy_to_usd(value)

    def __add__(self, other: 'Currency') -> 'Currency':
        if not isinstance(other, Currency):
            return NotImplemented
        return Currency(usd=self.usd + other.usd)

    def __radd__(self, other: object) -> 'Currency':
        # Lets sum() work: it starts from the int 0.
        if other == 0:
            return Currency(usd=self.usd)
        return NotImplemented

    def __sub__(self, other: 'Currency') -> 'Currency':
        if not isinstance(other, Currency):
            return NotImplemented
        return Currency(usd=self.usd - other.usd)

    def __mul__(self, factor: float) -> 'Currency':
        return Currency(usd=self.usd * factor)

    def __truediv__(self, factor: float) -> 'Currency':
        return Currency(usd=self.usd / factor)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Currency):
            return NotImplemented
        epsilon = 1e-9  # tolerance for float rounding
        return abs(self.usd - other.usd) < epsilon

    def __lt__(self, other: 'Currency') -> bool:
        if not isinstance(other, Currency):
            return NotImplemented
        return self.usd < other.usd

    def __repr__(self) -> str:
        return f'Currency(usd={self.usd:.2f}, jpy={self.jpy:.2f})'

    def __str__(self) -> str:
        return f'USD: {self.usd:.2f}, JPY: {self.jpy:.2f}'
|
82
|
+
|
83
|
+
|
84
|
+
class Rate:
    """Per-token API prices (USD) and the cost calculation built on them."""

    # USD per token, written as (USD per 1M tokens) / 1M.
    _COST_RATES = {
        Model.GPT4O: {
            CostType.INPUT: 5 / 1000**2,
            CostType.OUTPUT: 15 / 1000**2,
        },
        Model.GPT4O_MINI: {
            CostType.INPUT: 0.15 / 1000**2,
            CostType.OUTPUT: 0.60 / 1000**2,
        },
    }

    @staticmethod
    def calc_cost(model: Model, cost_type: CostType, token_count: int) -> Currency:
        """Return the price of *token_count* tokens of *cost_type* on *model*."""
        per_token_usd = Rate._COST_RATES[model][cost_type]
        return Currency(usd=per_token_usd * token_count)
|
97
|
+
|
98
|
+
|
99
|
+
class ApiCostCalculator:
    """Estimate the token count and API cost of a piece of text for one model."""

    def __init__(self, text: str, cost_type: CostType, model: Model) -> None:
        self.text = text
        self.cost_type = cost_type
        self.model = model

        # tiktoken encoder for this model, looked up by the model's name.
        self.token_model = tiktoken.encoding_for_model(model.value)

    @property
    def token_count(self) -> int:
        """Number of tokens in ``self.text`` (the text is re-encoded on every access)."""
        return len(self.token_model.encode(self.text))

    @property
    def cost(self) -> Currency:
        """Cost of ``self.text`` at the ``self.cost_type`` rate for ``self.model``."""
        return Rate.calc_cost(self.model, self.cost_type, self.token_count)

    def __str__(self) -> str:
        # Tokenize once and reuse the count: each access to ``token_count`` or
        # ``cost`` re-encodes the whole text.
        token_count = self.token_count
        cost = Rate.calc_cost(self.model, self.cost_type, token_count)
        return (
            f'Token count: {token_count}\n'
            f'Cost ({self.cost_type.value}): ${cost.usd:.2f} / ¥{cost.jpy:.2f}'
        )

    def __repr__(self) -> str:
        # !r quotes/escapes the text so the repr is unambiguous.
        return (
            f'ApiCostCalculator(text={self.text!r}, '
            f'cost_type={self.cost_type}, model={self.model})'
        )
|
atcdr/util/filename.py
ADDED
@@ -0,0 +1,137 @@
|
|
1
|
+
import os
|
2
|
+
from enum import Enum
|
3
|
+
from typing import Callable, Dict, List, TypeAlias
|
4
|
+
|
5
|
+
# Type aliases for a file name and a file extension
# (extensions include the leading dot, e.g. '.py' — see FILE_EXTENSIONS).
Filename: TypeAlias = str
Extension: TypeAlias = str
|
8
|
+
|
9
|
+
|
10
|
+
class Lang(Enum):
    """Programming languages and document formats handled by the tool.

    Each value is the display name returned by ``lang2str``; note that the
    casing is mixed ('php', 'markdown' and 'json' are lower-case).
    """

    PYTHON = 'Python'
    JAVASCRIPT = 'JavaScript'
    JAVA = 'Java'
    C = 'C'
    CPP = 'C++'
    CSHARP = 'C#'
    RUBY = 'Ruby'
    PHP = 'php'
    GO = 'Go'
    RUST = 'Rust'
    HTML = 'HTML'
    MARKDOWN = 'markdown'
    JSON = 'json'
|
23
|
+
JSON = 'json'
|
24
|
+
|
25
|
+
|
26
|
+
# Lang -> file extension (with the leading dot) used to recognise files by suffix.
FILE_EXTENSIONS: Dict[Lang, Extension] = {
    Lang.PYTHON: '.py',
    Lang.JAVASCRIPT: '.js',
    Lang.JAVA: '.java',
    Lang.C: '.c',
    Lang.CPP: '.cpp',
    Lang.CSHARP: '.cs',
    Lang.RUBY: '.rb',
    Lang.PHP: '.php',
    Lang.GO: '.go',
    Lang.RUST: '.rs',
    Lang.HTML: '.html',
    Lang.MARKDOWN: '.md',
    Lang.JSON: '.json',
}

# File types that are documents/data rather than solution source code.
DOCUMENT_LANGUAGES: List[Lang] = [Lang.HTML, Lang.MARKDOWN, Lang.JSON]

# Source-code languages: every language with a registered extension that is not
# a document type, in the same order as FILE_EXTENSIONS.
SOURCE_LANGUAGES: List[Lang] = [
    lang for lang in FILE_EXTENSIONS if lang not in DOCUMENT_LANGUAGES
]
|
62
|
+
|
63
|
+
|
64
|
+
# Accepted lower-case spellings for each language: its name and its file
# extension without the dot (built once at import time, not per call).
_LANG_ALIASES: Dict[str, Lang] = {
    'py': Lang.PYTHON,
    'python': Lang.PYTHON,
    'js': Lang.JAVASCRIPT,
    'javascript': Lang.JAVASCRIPT,
    'java': Lang.JAVA,
    'c': Lang.C,
    'cpp': Lang.CPP,
    'c++': Lang.CPP,
    'cs': Lang.CSHARP,  # '.cs' extension, matching the py/js/rb/rs/md aliases
    'csharp': Lang.CSHARP,
    'c#': Lang.CSHARP,
    'rb': Lang.RUBY,
    'ruby': Lang.RUBY,
    'php': Lang.PHP,
    'go': Lang.GO,
    'rs': Lang.RUST,
    'rust': Lang.RUST,
    'html': Lang.HTML,
    'md': Lang.MARKDOWN,
    'markdown': Lang.MARKDOWN,
    'json': Lang.JSON,
}


def str2lang(lang: str) -> Lang:
    """Map a language name or file extension to a ``Lang`` member.

    Matching is case-insensitive; surrounding whitespace and a single leading
    dot (as in ``'.py'``) are ignored.

    Raises:
        KeyError: if *lang* is not a recognised language name or extension.
    """
    key = lang.strip().lower()
    if key.startswith('.'):
        key = key[1:]
    try:
        return _LANG_ALIASES[key]
    except KeyError:
        raise KeyError(f'Unsupported language: {lang!r}') from None
|
88
|
+
|
89
|
+
|
90
|
+
def lang2str(lang: Lang) -> str:
    """Return the display name of *lang*, i.e. its enum value."""
    display_name = lang.value
    return display_name
|
92
|
+
|
93
|
+
|
94
|
+
def execute_files(
    *args: str, func: Callable[[Filename], None], target_filetypes: List[Lang]
) -> None:
    """Run *func* on files in the current directory whose type is in *target_filetypes*.

    With no *args*: a single matching file is used directly; if several match,
    the user picks one from a numbered menu. With *args*, each one selects files:
    ``'*'`` selects every matching file, ``'*.ext'`` selects those ending in
    ``.ext``, and anything else must be the exact name of a matching file
    (unknown names are reported and skipped). Selected files are processed in
    sorted order.
    """
    target_extensions = {FILE_EXTENSIONS[lang] for lang in target_filetypes}

    # os.listdir order is arbitrary; sort so the menu numbering and the
    # processing order are stable between runs.
    files = sorted(
        file
        for file in os.listdir('.')
        if os.path.isfile(file) and os.path.splitext(file)[1] in target_extensions
    )

    if not files:
        print(
            '対象のファイルが見つかりません.\n対象ファイルが存在するディレクトリーに移動してから実行してください。'
        )
        return

    if not args:
        if len(files) == 1:
            func(files[0])
            return
        print('複数のファイルが見つかりました。以下のファイルから選択してください:')
        for number, file in enumerate(files, start=1):
            print(f'{number}. {file}')
        try:
            choice = int(input('ファイル番号を入力してください: ')) - 1
        except ValueError:
            # Non-numeric input is rejected like an out-of-range number
            # instead of crashing with a traceback.
            choice = -1
        if 0 <= choice < len(files):
            func(files[choice])
        else:
            print('無効な選択です')
        return

    target_files: set[Filename] = set()
    for arg in args:
        if arg == '*':
            target_files.update(files)
        elif arg.startswith('*.'):
            ext = arg[1:]  # '*.py' -> '.py'
            target_files.update(file for file in files if file.endswith(ext))
        elif arg in files:
            target_files.add(arg)
        else:
            print(f'エラー: {arg} は存在しません。')

    # Plain loop (func is called for its side effects) in a deterministic order.
    for file in sorted(target_files):
        func(file)
|
atcdr/util/gpt.py
ADDED
@@ -0,0 +1,112 @@
|
|
1
|
+
import os
|
2
|
+
from typing import Dict, List, Optional
|
3
|
+
|
4
|
+
import requests
|
5
|
+
|
6
|
+
from atcdr.util.cost import CostType, Currency, Model, Rate
|
7
|
+
|
8
|
+
|
9
|
+
def set_api_key() -> Optional[str]:
    """Return a validated OpenAI API key, prompting the user when necessary.

    The ``OPENAI_API_KEY`` environment variable is tried first. Otherwise the
    user is asked for a key. A valid key is put into ``os.environ`` for this
    process and, if the user agrees, appended to ``~/.zshrc``.

    Returns:
        The validated key, or None if no valid key was obtained.
    """
    api_key = os.getenv('OPENAI_API_KEY')
    if api_key and validate_api_key(api_key):
        return api_key
    if api_key:
        print('環境変数に設定されているAPIキーの検証に失敗しました ')

    # Strip stray whitespace that often comes along with a pasted key.
    api_key = input(
        'https://platform.openai.com/api-keys からchatGPTのAPIキーを入手しましょう。\nAPIキー入力してください: '
    ).strip()
    if not validate_api_key(api_key):
        print('コード生成にはAPIキーが必要です。')
        return None

    print('APIキーのテストに成功しました。')
    print('以下, ~/.zshrcにAPIキーを保存しますか? [y/n]')
    # Accept 'y', 'Y', 'yes' and surrounding whitespace, not only a bare 'y'.
    if input().strip().lower() in ('y', 'yes'):
        zshrc_path = os.path.expanduser('~/.zshrc')
        with open(zshrc_path, 'a', encoding='utf-8') as f:
            f.write(f'export OPENAI_API_KEY={api_key}\n')
        print(
            f'APIキーを {zshrc_path} に保存しました。次回シェル起動時に読み込まれます。'
        )
    os.environ['OPENAI_API_KEY'] = api_key
    return api_key
|
36
|
+
|
37
|
+
|
38
|
+
def validate_api_key(api_key: str) -> bool:
    """Return True if OpenAI's ``/v1/models`` endpoint accepts *api_key*.

    Prints a failure message and returns False when the key is rejected or the
    request cannot be completed (connection error, timeout, ...).
    """
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {api_key}',
    }

    try:
        # Without a timeout, requests can wait indefinitely on a stalled connection.
        response = requests.get(
            'https://api.openai.com/v1/models', headers=headers, timeout=10
        )
    except requests.RequestException as e:
        print(f'APIキーの検証に失敗しました。({e})')
        return False

    if response.status_code == 200:
        return True
    print('APIキーの検証に失敗しました。')
    return False
|
51
|
+
|
52
|
+
|
53
|
+
class ChatGPT:
    """Minimal OpenAI Chat Completions client that keeps the conversation state.

    The running conversation is kept in ``self.messages``. The cost of every
    successful call (computed with ``Rate.calc_cost``) is accumulated in
    ``self.sum_cost``.
    """

    API_URL = 'https://api.openai.com/v1/chat/completions'
    # Seconds to wait for a completion; generous because long answers
    # (up to ``max_tokens``) can take a while to generate.
    TIMEOUT = 120

    # API usage: https://platform.openai.com/docs/api-reference/making-requests
    def __init__(
        self,
        api_key: Optional[str] = None,
        model: Model = Model.GPT4O_MINI,
        max_tokens: int = 3000,
        temperature: float = 0.7,
        messages: Optional[List[Dict[str, str]]] = None,
        system_prompt: str = 'You are a helpful assistant.',
    ) -> None:
        # Falls back to the OPENAI_API_KEY environment variable.
        self.api_key = api_key or os.getenv('OPENAI_API_KEY')
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        # A caller-supplied history is used as-is (and appended to in place);
        # otherwise the conversation starts with just the system prompt.
        self.messages = (
            messages
            if messages is not None
            else [{'role': 'system', 'content': system_prompt}]
        )

        self.sum_cost: Currency = Currency(usd=0)
        self.__headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.api_key}',
        }

    def _fail(self, detail: str) -> str:
        """Undo the pending user turn, print *detail*, and return the error placeholder."""
        # Drop the user message appended by ``tell`` so the history does not end
        # with a turn that never received a reply.
        self.messages.pop()
        print(detail)
        return 'Error: Unable to retrieve response.'

    def tell(self, message: str) -> str:
        """Send *message* as the next user turn and return the assistant's reply.

        On success the reply is appended to ``self.messages`` and the call's
        cost is added to ``self.sum_cost``. On any failure (network error,
        non-JSON body, unexpected response shape) an error is printed, the user
        turn is removed again, and ``'Error: Unable to retrieve response.'`` is
        returned.
        """
        self.messages.append({'role': 'user', 'content': message})

        settings = {
            'model': self.model.value,
            'messages': self.messages,
            'max_tokens': self.max_tokens,
            'temperature': self.temperature,
        }

        try:
            response = requests.post(
                self.API_URL,
                headers=self.__headers,
                json=settings,
                timeout=self.TIMEOUT,
            )
        except requests.RequestException as e:
            return self._fail(f'Error:APIへのリクエストに失敗しました. \n{e}')

        try:
            responsej = response.json()
        except ValueError:
            # The body is not JSON (e.g. an HTML error page).
            return self._fail(
                'Error:レスポンスの形式が正しくありません. \n' + response.text
            )

        try:
            reply = responsej['choices'][0]['message']['content']
        except (KeyError, IndexError, TypeError):
            # e.g. an API error object instead of a completion.
            return self._fail(
                'Error:レスポンスの形式が正しくありません. \n' + str(responsej)
            )

        self.messages.append({'role': 'assistant', 'content': reply})

        # Missing usage data counts as zero tokens rather than crashing.
        usage = responsej.get('usage') or {}
        input_tokens = usage.get('prompt_tokens', 0)
        output_tokens = usage.get('completion_tokens', 0)
        self.sum_cost += Rate.calc_cost(
            model=self.model, cost_type=CostType.INPUT, token_count=input_tokens
        )
        self.sum_cost += Rate.calc_cost(
            model=self.model, cost_type=CostType.OUTPUT, token_count=output_tokens
        )

        return reply
|
atcdr/util/problem.py
ADDED
@@ -0,0 +1,87 @@
|
|
1
|
+
import re
|
2
|
+
from enum import Enum
|
3
|
+
from typing import Optional, Union
|
4
|
+
|
5
|
+
from bs4 import BeautifulSoup as bs
|
6
|
+
from bs4 import NavigableString, Tag
|
7
|
+
from markdownify import MarkdownConverter # type: ignore
|
8
|
+
|
9
|
+
|
10
|
+
# TODO: eventually consolidate all of the HTML parsing done in generate.py, test.py and open.py here.
class Lang(Enum):
    """Statement language: 'ja' or 'en', the values abstract_problem_part accepts."""

    JA = 'ja'
    EN = 'en'
|
14
|
+
|
15
|
+
|
16
|
+
class ProblemStruct:
    """HTML fragments (as strings) extracted from a task-statement element.

    Every attribute starts as None and is filled in by ``divide_problem_part``.
    """

    def __init__(self) -> None:
        # Section meanings are presumably AtCoder's layout (statement,
        # constraints, I/O format, samples) — TODO confirm against real pages.
        self.problem_part: Optional[str] = None  # 1st div.part
        self.condition_part: Optional[str] = None  # 2nd div.part
        self.io_part: Optional[str] = None  # 1st div.part inside div.io-style
        self.test_part: Optional[list[str]] = None  # remaining div.part in div.io-style

    def divide_problem_part(self, task_statement: Union[Tag, NavigableString]) -> None:
        """Split *task_statement* into the section attributes.

        Non-Tag input is ignored. ``problem_part``/``condition_part`` are set
        only when at least two ``div.part`` elements exist; ``io_part`` and
        ``test_part`` only when a ``div.io-style`` with at least one
        ``div.part`` exists.
        """
        if not isinstance(task_statement, Tag):
            return

        part_divs = task_statement.find_all('div', {'class': 'part'})
        if len(part_divs) >= 2:
            first_div, second_div = part_divs[0], part_divs[1]
            self.problem_part = str(first_div)
            self.condition_part = str(second_div)

        io_container = task_statement.find('div', {'class': 'io-style'})
        if not isinstance(io_container, Tag):
            return

        io_divs = io_container.find_all('div', {'class': 'part'})
        if io_divs:
            # Each element is a Tag; str() serializes it back to HTML.
            self.io_part = str(io_divs[0])
            # Every div.part after the first one goes into test_part.
            self.test_part = [str(div) for div in io_divs[1:]]
|
44
|
+
|
45
|
+
|
46
|
+
class CustomMarkdownConverter(MarkdownConverter):
    """markdownify converter tuned for AtCoder statements.

    ``<var>`` becomes inline TeX ``\\(...\\)``. ``<pre>`` becomes a fenced code
    block built from the element's plain text, so nested tags such as
    ``<var>`` are dropped.
    """

    # markdownify passes a third argument whose name and calling convention
    # differ between versions: positional ``convert_as_inline`` in 0.x, keyword
    # ``parent_tags`` in 1.x. Accepting either keeps these overrides working
    # across the unpinned ``markdownify>=0.13.1`` requirement.
    def convert_var(self, el, text, *args, **kwargs):
        var_text = el.text.strip()
        return f'\\({var_text}\\)'

    def convert_pre(self, el, text, *args, **kwargs):
        pre_text = el.text.strip()
        return f'```\n{pre_text}\n```'
|
54
|
+
|
55
|
+
|
56
|
+
def custom_markdownify(html, **options):
    """Convert *html* to Markdown using ``CustomMarkdownConverter``.

    Keyword *options* are forwarded unchanged to the converter's constructor.
    """
    converter = CustomMarkdownConverter(**options)
    return converter.convert(html)
|
58
|
+
|
59
|
+
|
60
|
+
def remove_unnecessary_emptylines(md_text):
    """Collapse runs of two or more blank lines into one, then trim both ends.

    Lines containing only whitespace count as blank.
    """
    collapsed = re.sub(r'\n\s*\n\s*\n+', '\n\n', md_text)
    return collapsed.strip()
|
64
|
+
|
65
|
+
|
66
|
+
def abstract_problem_part(html_content: str, lang: Union[str, Lang]) -> str:
    """Return the HTML of the *lang* statement span inside ``div#task-statement``.

    Args:
        html_content: Full HTML of a task page.
        lang: ``'ja'`` or ``'en'``; a ``Lang`` member is also accepted.

    Returns:
        ``str()`` of the ``span.lang-<lang>`` element, or ``''`` if the task
        statement or the requested language span is missing.

    Raises:
        ValueError: if *lang* is not ``'ja'``/``'en'`` (previously this case hit
            an UnboundLocalError).
    """
    lang_code = lang.value if isinstance(lang, Lang) else lang

    soup = bs(html_content, 'html.parser')
    task_statement = soup.find('div', {'id': 'task-statement'})

    if not isinstance(task_statement, Tag):
        return ''

    if lang_code not in ('ja', 'en'):
        raise ValueError(f"Unsupported language {lang!r}; expected 'ja' or 'en'")

    span = task_statement.find('span', {'class': f'lang-{lang_code}'})
    # The requested language may be absent from the page; str(None) would
    # otherwise leak the literal text 'None' into the generated Markdown.
    if span is None:
        return ''
    return str(span)
|
81
|
+
|
82
|
+
|
83
|
+
def make_problem_markdown(html_content: str, lang: str) -> str:
    """Convert a task page's HTML into tidy Markdown for the *lang* statement.

    Pipeline: extract the language span, convert it to Markdown, then collapse
    excess blank lines.
    """
    statement_html = abstract_problem_part(html_content, lang)
    return remove_unnecessary_emptylines(custom_markdownify(statement_html))
|
@@ -0,0 +1,94 @@
|
|
1
|
+
Metadata-Version: 2.3
|
2
|
+
Name: AtCoderStudyBooster
|
3
|
+
Version: 0.1.0
|
4
|
+
Summary: A tool to download and manage AtCoder problems.
|
5
|
+
Author-email: yuta6 <46110512+yuta6@users.noreply.github.com>
|
6
|
+
Requires-Python: >=3.8
|
7
|
+
Requires-Dist: beautifulsoup4
|
8
|
+
Requires-Dist: colorama
|
9
|
+
Requires-Dist: fire
|
10
|
+
Requires-Dist: markdownify>=0.13.1
|
11
|
+
Requires-Dist: requests
|
12
|
+
Requires-Dist: tiktoken
|
13
|
+
Requires-Dist: types-beautifulsoup4>=4.12.0.20240511
|
14
|
+
Requires-Dist: types-colorama>=0.4.15.20240311
|
15
|
+
Requires-Dist: types-requests>=2.32.0.20240712
|
16
|
+
Requires-Dist: yfinance
|
17
|
+
Description-Content-Type: text/markdown
|
18
|
+
|
19
|
+
# AtCoderStudyBooster
|
20
|
+
|
21
|
+
## 概要
|
22
|
+
|
23
|
+
AtCoderStudyBoosterはAtCoderの学習を加速させるためのツールです。問題をローカルにダウンロードし、テストや解答の作成をサポートします。利用にはPythonが必要です。Pythonが入っている環境なら、`pip install AtCoderStudyBooster`でインストールできます。
|
24
|
+
|
25
|
+
このツールは以下のプロジェクトに強く影響を受けています。
|
26
|
+
[online-judge-tools](https://github.com/online-judge-tools)
|
27
|
+
[atcoder-cli](https://github.com/Tatamo/atcoder-cli)
|
28
|
+
これらとの違いとして、本ツールはAtCoderのコンテスト中の利用を想定しておらず、初心者の学習のサポートのみを意識しています。そのため、現時点では提出機能を備えていません。また、ChatGPT APIによる解答作成のサポート機能を備えています。
|
29
|
+
|
30
|
+
## 利用ケース
|
31
|
+
|
32
|
+
### B問題を練習したい場合
|
33
|
+
|
34
|
+
ABC223からABC226までのB問題だけを集中的に練習したい場合、次のコマンドを実行します。
|
35
|
+
|
36
|
+
```sh
|
37
|
+
atcdr download B 223..226
|
38
|
+
```
|
39
|
+
|
40
|
+
コマンドを実行すると、次のようなフォルダーを作成して、各フォルダーに問題をダウンロードします。
|
41
|
+
|
42
|
+
```text
|
43
|
+
B
|
44
|
+
├── 223
|
45
|
+
│ ├── StringShifting.html
|
46
|
+
│ └── StringShifting.md
|
47
|
+
├── 224
|
48
|
+
│ ├── Mongeness.html
|
49
|
+
│ └── Mongeness.md
|
50
|
+
├── 225
|
51
|
+
│ ├── StarorNot.html
|
52
|
+
│ └── StarorNot.md
|
53
|
+
└── 226
|
54
|
+
├── CountingArrays.html
|
55
|
+
└── CountingArrays.md
|
56
|
+
```
|
57
|
+
|
58
|
+
MarkdownファイルあるいはHTMLファイルをVS CodeのHTML Preview, Markdown Previewで開くと問題を確認できます。VS Codeで開くと左側にテキストエディターを表示して、右側で問題をみながら問題に取り組めます。
|
59
|
+
|
60
|
+

|
61
|
+
|
62
|
+
### サンプルをローカルでテストする
|
63
|
+
|
64
|
+
問題をダウンロードしたフォルダーに移動します。
|
65
|
+
|
66
|
+
```sh
|
67
|
+
cd B/224
|
68
|
+
```
|
69
|
+
|
70
|
+
移動したフォルダーで解答ファイルを作成した後、次のコマンドを実行すると自動的にテストします。
|
71
|
+
|
72
|
+
```sh
|
73
|
+
▷ ~/.../B/224
|
74
|
+
atcdr t
|
75
|
+
```
|
76
|
+
|
77
|
+
```sh
|
78
|
+
solution.pyをテストします。
|
79
|
+
--------------------
|
80
|
+
|
81
|
+
Sample 1 of Test:
|
82
|
+
✓ Accepted !! Time: 24 ms
|
83
|
+
|
84
|
+
Sample 2 of Test:
|
85
|
+
✓ Accepted !! Time: 15 ms
|
86
|
+
```
|
87
|
+
|
88
|
+
このように、作成したソースコードに対してHTMLに書かれているテストケースを読み込んで実行し、Passするかを判定します。
|
89
|
+
|
90
|
+
## 解答生成機能generateコマンドに関する注意点
|
91
|
+
|
92
|
+
本ツールにはChatGPT APIを利用したコード生成機能があります。[AtCoder生成AI対策ルール](https://info.atcoder.jp/entry/llm-abc-rules-ja)によると、AtCoder Beginner Contestにおいて問題文を生成AIに直接与えることは禁止されています。ただし、このルールは過去問を練習している際には適用されません。
|
93
|
+
|
94
|
+
現時点で本ツールにはログイン機能がないため、コンテスト中の問題を`download`コマンドでダウンロードすることはできません。`generate`コマンドは`download`コマンドに依存しており、ダウンロードした問題のHTMLファイルをパースしてGPTが解釈しやすいMarkdownを与えることで実現しています。したがって、このコマンドがAtCoder Beginner Contest中に[AtCoder生成AI対策ルール](https://info.atcoder.jp/entry/llm-abc-rules-ja)に抵触することはありません。
|
@@ -0,0 +1,15 @@
|
|
1
|
+
atcdr/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
+
atcdr/download.py,sha256=UYOkKoFaN7Rj2p-13MOjvgq26QzdvoNmWKBaUW6cCKs,7995
|
3
|
+
atcdr/generate.py,sha256=Bh0QHRUVeI8u5FvXbJssbs6gr55XtUkNT2p897FlUgs,5521
|
4
|
+
atcdr/main.py,sha256=i-TFxFk7bFMtKZxtDgI7aPoZAF-dsXqNoz3O_ZsGvb4,605
|
5
|
+
atcdr/open.py,sha256=vbOy3fthklhZ7_WIWNGyS2H3iK2FHLeClDt_tloJ_b0,924
|
6
|
+
atcdr/test.py,sha256=it3QjFxdlR0GY6Hc0c2Qdke71Z4dj5eLhfuWVZU9cZA,7969
|
7
|
+
atcdr/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
8
|
+
atcdr/util/cost.py,sha256=0c9H8zLley7xZDLuYU4zJmB8m71qcO1WEIQOoEavD_4,3168
|
9
|
+
atcdr/util/filename.py,sha256=taCgSwIpB5iCjWZrYWAGRncRyUUl9exoNfsP-KLF2bs,2984
|
10
|
+
atcdr/util/gpt.py,sha256=Lto6SJHZGer8cC_Nq8lJVnaET2R7apFQteo6ZEFpjdM,3304
|
11
|
+
atcdr/util/problem.py,sha256=iDfNGfoCk_sy9RQRZ4vVqd1ViyT8HSWe_ekKUb4PdKs,2412
|
12
|
+
atcoderstudybooster-0.1.0.dist-info/METADATA,sha256=VO3J_117RJuCbT7euQAOyb375fMS65cfJMI_lORtquE,4397
|
13
|
+
atcoderstudybooster-0.1.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
|
14
|
+
atcoderstudybooster-0.1.0.dist-info/entry_points.txt,sha256=_bhz0R7vp2VubKl_eIokDO8Wz9TdqvYA7Q59uWfy6Sk,42
|
15
|
+
atcoderstudybooster-0.1.0.dist-info/RECORD,,
|