AtCoderStudyBooster 0.2__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atcdr/download.py +251 -240
- atcdr/generate.py +184 -193
- atcdr/login.py +133 -0
- atcdr/logout.py +24 -0
- atcdr/main.py +24 -17
- atcdr/markdown.py +22 -22
- atcdr/open.py +30 -30
- atcdr/submit.py +297 -0
- atcdr/test.py +382 -244
- atcdr/util/execute.py +52 -52
- atcdr/util/filetype.py +81 -67
- atcdr/util/gpt.py +102 -96
- atcdr/util/parse.py +206 -0
- atcdr/util/problem.py +94 -91
- atcdr/util/session.py +140 -0
- atcoderstudybooster-0.3.1.dist-info/METADATA +205 -0
- atcoderstudybooster-0.3.1.dist-info/RECORD +21 -0
- {atcoderstudybooster-0.2.dist-info → atcoderstudybooster-0.3.1.dist-info}/WHEEL +1 -1
- atcdr/util/cost.py +0 -120
- atcoderstudybooster-0.2.dist-info/METADATA +0 -96
- atcoderstudybooster-0.2.dist-info/RECORD +0 -17
- {atcoderstudybooster-0.2.dist-info → atcoderstudybooster-0.3.1.dist-info}/entry_points.txt +0 -0
atcdr/util/execute.py
CHANGED
@@ -8,56 +8,56 @@ from atcdr.util.filetype import FILE_EXTENSIONS, Filename, Lang
|
|
8
8
|
|
9
9
|
|
10
10
|
def execute_files(
    *args: str, func: Callable[[Filename], None], target_filetypes: List[Lang]
) -> None:
    """Run ``func`` on source files found in the current directory.

    Candidates are the regular files in ``.`` whose extension is the
    ``FILE_EXTENSIONS`` entry of one of ``target_filetypes``.

    Args:
        *args: File selectors. ``'*'`` selects every candidate, ``'*.ext'``
            selects candidates whose name ends with ``.ext``, and any other
            value must be the exact name of a candidate (otherwise an error
            message is printed and that selector is skipped). With no
            selectors, a lone candidate is used directly; when there are
            several, the user picks one in an interactive prompt.
        func: Called once with each selected file name.
        target_filetypes: Languages whose files are candidates.
    """
    target_extensions = {FILE_EXTENSIONS[lang] for lang in target_filetypes}

    # os.listdir returns entries in arbitrary order; sort so the prompt order
    # and the execution order are stable.
    files = sorted(
        file
        for file in os.listdir('.')
        if os.path.isfile(file) and os.path.splitext(file)[1] in target_extensions
    )

    if not files:
        print(
            '対象のファイルが見つかりません.\n対象ファイルが存在するディレクトリーに移動してから実行してください。'
        )
        return

    if not args:
        if len(files) == 1:
            func(files[0])
            return

        target_file = q.select(
            message='複数のファイルが見つかりました.ファイルを選択してください:',
            choices=[q.Choice(title=file, value=file) for file in files],
            instruction='\n 十字キーで移動, [enter]で実行',
            pointer='❯❯❯',
            qmark='',
            style=q.Style(
                [
                    ('qmark', 'fg:#2196F3 bold'),
                    ('question', 'fg:#2196F3 bold'),
                    ('answer', 'fg:#FFB300 bold'),
                    ('pointer', 'fg:#FFB300 bold'),
                    ('highlighted', 'fg:#FFB300 bold'),
                    ('selected', 'fg:#FFB300 bold'),
                ]
            ),
        ).ask()
        # Per questionary's docs, ask() returns None when the prompt is
        # cancelled (e.g. Ctrl-C); do not hand None on to func.
        if target_file is not None:
            func(target_file)
        return

    target_files = set()
    for arg in args:
        if arg == '*':
            target_files.update(files)
        elif arg.startswith('*.'):
            ext = arg[1:]  # e.g. '*.py' -> '.py'
            target_files.update(file for file in files if file.endswith(ext))
        elif arg in files:
            target_files.add(arg)
        else:
            print(f'エラー: {arg} は存在しません。')

    # A set iterates in string-hash order, which changes between interpreter
    # runs; sort for a deterministic execution order.
    for file in sorted(target_files):
        func(file)
|
atcdr/util/filetype.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
+
import os
|
1
2
|
from enum import Enum
|
2
3
|
from typing import Dict, List, TypeAlias
|
3
4
|
|
@@ -7,85 +8,98 @@ Extension: TypeAlias = str
|
|
7
8
|
|
8
9
|
|
9
10
|
class Lang(Enum):
    """Languages known to atcdr.

    Each member's value is the display name returned by ``lang2str``.
    """

    PYTHON = 'Python'
    JAVASCRIPT = 'JavaScript'
    JAVA = 'Java'
    C = 'C'
    CPP = 'C++'
    CSHARP = 'C#'
    RUBY = 'Ruby'
    PHP = 'php'
    GO = 'Go'
    RUST = 'Rust'
    HTML = 'HTML'
    MARKDOWN = 'markdown'
    JSON = 'json'


# Map from each language to its file extension (including the leading dot).
FILE_EXTENSIONS: Dict[Lang, str] = {
    Lang.PYTHON: '.py',
    Lang.JAVASCRIPT: '.js',
    Lang.JAVA: '.java',
    Lang.C: '.c',
    Lang.CPP: '.cpp',
    Lang.CSHARP: '.cs',
    Lang.RUBY: '.rb',
    Lang.PHP: '.php',
    Lang.GO: '.go',
    Lang.RUST: '.rs',
    Lang.HTML: '.html',
    Lang.MARKDOWN: '.md',
    Lang.JSON: '.json',
}

# Reverse of FILE_EXTENSIONS, built once for detect_language.
_EXTENSION_TO_LANG: Dict[str, Lang] = {
    extension: lang for lang, extension in FILE_EXTENSIONS.items()
}

# Document (non-program) languages.
DOCUMENT_LANGUAGES: List[Lang] = [
    Lang.HTML,
    Lang.MARKDOWN,
    Lang.JSON,
]

# Compiled languages.
COMPILED_LANGUAGES: List[Lang] = [
    Lang.JAVA,
    Lang.C,
    Lang.CPP,
    Lang.CSHARP,
    Lang.GO,
    Lang.RUST,
]

# Interpreted languages.
INTERPRETED_LANGUAGES: List[Lang] = [
    Lang.PYTHON,
    Lang.JAVASCRIPT,
    Lang.RUBY,
    Lang.PHP,
]

# Lower-case spellings accepted by str2lang.
_LANG_ALIASES: Dict[str, Lang] = {
    'py': Lang.PYTHON,
    'python': Lang.PYTHON,
    'js': Lang.JAVASCRIPT,
    'javascript': Lang.JAVASCRIPT,
    'java': Lang.JAVA,
    'c': Lang.C,
    'cpp': Lang.CPP,
    'c++': Lang.CPP,
    'csharp': Lang.CSHARP,
    'cs': Lang.CSHARP,
    'c#': Lang.CSHARP,
    'rb': Lang.RUBY,
    'ruby': Lang.RUBY,
    'php': Lang.PHP,
    'go': Lang.GO,
    'rs': Lang.RUST,
    'rust': Lang.RUST,
    'html': Lang.HTML,
    'md': Lang.MARKDOWN,
    'markdown': Lang.MARKDOWN,
    'json': Lang.JSON,
}


def str2lang(lang: str) -> Lang:
    """Parse a language name or extension-like alias (case-insensitive).

    Surrounding whitespace is ignored, so ``' C++ '`` parses as ``Lang.CPP``.

    Raises:
        KeyError: if ``lang`` is not a known alias.
    """
    return _LANG_ALIASES[lang.strip().lower()]


def lang2str(lang: Lang) -> str:
    """Return the display name of ``lang`` (its enum value)."""
    return lang.value


def detect_language(path: str) -> Lang:
    """Return the language whose extension matches ``path``'s extension.

    The comparison is an exact match against ``FILE_EXTENSIONS``.

    Raises:
        ValueError: if the extension is not in ``FILE_EXTENSIONS``.
    """
    ext = os.path.splitext(path)[1]
    try:
        return _EXTENSION_TO_LANG[ext]
    except KeyError:
        # A bare StopIteration escaping here (as next() without a default
        # would raise) is unhelpful, and inside a calling generator it is
        # turned into RuntimeError (PEP 479); raise a descriptive error.
        raise ValueError(
            f'unsupported file extension {ext!r} for path {path!r}'
        ) from None
|
atcdr/util/gpt.py
CHANGED
@@ -1,112 +1,118 @@
|
|
1
1
|
import os
|
2
|
+
from enum import Enum
|
2
3
|
from typing import Dict, List, Optional
|
3
4
|
|
4
5
|
import requests
|
5
6
|
|
6
|
-
|
7
|
+
|
8
|
+
class Model(Enum):
    """OpenAI chat model identifiers; the value is sent as the ``model`` field."""

    GPT4O = 'gpt-4o'
    GPT41 = 'gpt-4.1'
    GPT41_MINI = 'gpt-4.1-mini'
    GPT41_NANO = 'gpt-4.1-nano'
    GPT4O_MINI = 'gpt-4o-mini'
    O1_PREVIEW = 'o1-preview'
    O1 = 'o1'
    O3 = 'o3'
    O1_MINI = 'o1-mini'
    O3_MINI = 'o3-mini'
    O4_MINI = 'o4-mini'

    @property
    def is_reasoning(self) -> bool:
        """True for o-series reasoning models, which reject a custom temperature."""
        return self.value.startswith('o')


def set_api_key() -> Optional[str]:
    """Return a validated OpenAI API key, prompting the user if necessary.

    The ``OPENAI_API_KEY`` environment variable is tried first. If it is unset
    or fails validation, the user is asked to paste a key. A valid key is put
    into ``os.environ`` for the rest of this process and, if the user answers
    yes, appended as an ``export`` line to ``~/.zshrc``.

    Returns:
        The validated key, or ``None`` if no valid key was entered.
    """
    api_key = os.getenv('OPENAI_API_KEY')
    if api_key:
        if validate_api_key(api_key):
            return api_key
        print('環境変数に設定されているAPIキーの検証に失敗しました ')

    # Pasted keys often carry stray whitespace, which would fail validation.
    api_key = input(
        'https://platform.openai.com/api-keys からchatGPTのAPIキーを入手しましょう。\nAPIキー入力してください: '
    ).strip()
    # Skip the network round-trip for an empty answer.
    if not api_key or not validate_api_key(api_key):
        print('コード生成にはAPIキーが必要です。')
        return None

    print('APIキーのテストに成功しました。')
    print('以下, ~/.zshrcにAPIキーを保存しますか? [y/n]')
    if input().strip().lower() in ('y', 'yes'):
        zshrc_path = os.path.expanduser('~/.zshrc')
        with open(zshrc_path, 'a', encoding='utf-8') as f:
            f.write(f'export OPENAI_API_KEY={api_key}\n')
        print(
            f'APIキーを {zshrc_path} に保存しました。次回シェル起動時に読み込まれます。'
        )
    os.environ['OPENAI_API_KEY'] = api_key
    return api_key


def validate_api_key(api_key: str) -> bool:
    """Check ``api_key`` by listing models on the OpenAI API.

    Returns True only for an HTTP 200 response. Any other status, or a network
    failure, prints a message and returns False.
    """
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {api_key}',
    }

    try:
        # Without a timeout an unreachable network would hang the CLI forever.
        response = requests.get(
            'https://api.openai.com/v1/models', headers=headers, timeout=10
        )
    except requests.RequestException as e:
        print(f'APIキーの検証に失敗しました。({e})')
        return False

    if response.status_code == 200:
        return True
    print('APIキーの検証に失敗しました。')
    return False


class ChatGPT:
    """Minimal stateful client for the OpenAI Chat Completions endpoint.

    ``messages`` holds the conversation; the whole list is sent with every
    request, and ``tell`` appends the user message and (on success) the reply.
    """

    API_URL = 'https://api.openai.com/v1/chat/completions'
    # (connect, read) timeouts in seconds. The read timeout is generous because
    # the request is not streamed: nothing arrives until the reply is complete.
    TIMEOUT = (10, 600)

    # API usage: https://platform.openai.com/docs/api-reference/making-requests
    def __init__(
        self,
        api_key: Optional[str] = None,
        model: Model = Model.GPT41_MINI,
        max_tokens: int = 3000,
        temperature: float = 0.7,
        messages: Optional[List[Dict[str, str]]] = None,
        system_prompt: str = 'You are a helpful assistant.',
    ) -> None:
        """Create a client.

        Args:
            api_key: API key; falls back to the ``OPENAI_API_KEY`` env variable.
            model: Model to query.
            max_tokens: Upper bound on generated tokens per reply.
            temperature: Sampling temperature (not sent for reasoning models).
            messages: Initial history; copied so ``tell`` does not mutate the
                caller's list. When omitted, history starts with a system
                message containing ``system_prompt``.
            system_prompt: System message used only when ``messages`` is None.
        """
        self.api_key = api_key or os.getenv('OPENAI_API_KEY')
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.messages = (
            list(messages)
            if messages is not None
            else [{'role': 'system', 'content': system_prompt}]
        )

        self.__headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.api_key}',
        }

    def _build_payload(self) -> Dict[str, object]:
        """Build the JSON request body for the current history."""
        payload: Dict[str, object] = {
            'model': self.model.value,
            'messages': self.messages,
            # max_tokens is deprecated and rejected by o-series models;
            # max_completion_tokens is accepted by every chat model.
            'max_completion_tokens': self.max_tokens,
        }
        # o-series models only accept the default temperature.
        if not self.model.is_reasoning:
            payload['temperature'] = self.temperature
        return payload

    def tell(self, message: str) -> str:
        """Send ``message`` and return the assistant's reply.

        On failure (network error, non-JSON body, or unexpected response
        shape) an error is printed and ``'Error: Unable to retrieve response.'``
        is returned. In that case the user message stays in ``messages`` but
        no assistant reply is added.
        """
        self.messages.append({'role': 'user', 'content': message})

        try:
            response = requests.post(
                self.API_URL,
                headers=self.__headers,
                json=self._build_payload(),
                timeout=self.TIMEOUT,
            )
            responsej = response.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError covers JSON decode errors on older requests versions.
            print(f'Error:APIへのリクエストに失敗しました. \n{e}')
            return 'Error: Unable to retrieve response.'

        try:
            reply = responsej['choices'][0]['message']['content']
        except (KeyError, IndexError, TypeError):
            reply = None
        # content can also be null (e.g. a refusal), which is not a usable reply.
        if not isinstance(reply, str):
            print('Error:レスポンスの形式が正しくありません. \n' + str(responsej))
            return 'Error: Unable to retrieve response.'

        self.messages.append({'role': 'assistant', 'content': reply})
        return reply
|