AtCoderStudyBooster 0.3__py3-none-any.whl → 0.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atcdr/generate.py +1 -1
- atcdr/test.py +153 -297
- atcdr/util/filetype.py +8 -22
- {atcoderstudybooster-0.3.dist-info → atcoderstudybooster-0.23.dist-info}/METADATA +1 -1
- {atcoderstudybooster-0.3.dist-info → atcoderstudybooster-0.23.dist-info}/RECORD +7 -7
- {atcoderstudybooster-0.3.dist-info → atcoderstudybooster-0.23.dist-info}/WHEEL +0 -0
- {atcoderstudybooster-0.3.dist-info → atcoderstudybooster-0.23.dist-info}/entry_points.txt +0 -0
    
        atcdr/generate.py
    CHANGED
    
    | @@ -140,8 +140,8 @@ def solve_problem(file: Filename, lang: Lang) -> None: | |
| 140 140 |  | 
| 141 141 | 
             
            	for i in range(1, 4):
         | 
| 142 142 | 
             
            		with console.status(f'{i}回目のコード生成 (by {gpt.model.value})...'):
         | 
| 143 | 
            -
            			test_report = ''
         | 
| 144 143 | 
             
            			if i == 1:
         | 
| 144 | 
            +
            				test_report = ''
         | 
| 145 145 | 
             
            				reply = gpt.tell(md)
         | 
| 146 146 | 
             
            			else:
         | 
| 147 147 | 
             
            				reply = gpt.tell(f"""The following is the test report for the code you provided:
         | 
    
        atcdr/test.py
    CHANGED
    
    | @@ -2,28 +2,19 @@ import os | |
| 2 2 | 
             
            import subprocess
         | 
| 3 3 | 
             
            import tempfile
         | 
| 4 4 | 
             
            import time
         | 
| 5 | 
            -
            from dataclasses import dataclass | 
| 5 | 
            +
            from dataclasses import dataclass
         | 
| 6 6 | 
             
            from enum import Enum
         | 
| 7 | 
            -
            from typing import  | 
| 7 | 
            +
            from typing import Callable, Dict, List, Optional, Union
         | 
| 8 8 |  | 
| 9 9 | 
             
            from bs4 import BeautifulSoup as bs
         | 
| 10 | 
            -
            from rich.console import Console | 
| 10 | 
            +
            from rich.console import Console
         | 
| 11 11 | 
             
            from rich.markup import escape
         | 
| 12 12 | 
             
            from rich.panel import Panel
         | 
| 13 | 
            -
            from rich.rule import Rule
         | 
| 14 | 
            -
            from rich.style import Style
         | 
| 15 | 
            -
            from rich.syntax import Syntax
         | 
| 16 13 | 
             
            from rich.table import Table
         | 
| 17 14 | 
             
            from rich.text import Text
         | 
| 18 15 |  | 
| 19 16 | 
             
            from atcdr.util.execute import execute_files
         | 
| 20 | 
            -
            from atcdr.util.filetype import  | 
| 21 | 
            -
            	COMPILED_LANGUAGES,
         | 
| 22 | 
            -
            	INTERPRETED_LANGUAGES,
         | 
| 23 | 
            -
            	Lang,
         | 
| 24 | 
            -
            	detect_language,
         | 
| 25 | 
            -
            	lang2str,
         | 
| 26 | 
            -
            )
         | 
| 17 | 
            +
            from atcdr.util.filetype import FILE_EXTENSIONS, SOURCE_LANGUAGES, Lang
         | 
| 27 18 |  | 
| 28 19 |  | 
| 29 20 | 
             
            @dataclass
         | 
| @@ -39,13 +30,12 @@ class LabeledTestCase: | |
| 39 30 |  | 
| 40 31 |  | 
| 41 32 | 
             
            class ResultStatus(Enum):
         | 
| 42 | 
            -
            	 | 
| 43 | 
            -
            	WA = 'Wrong Answer'
         | 
| 44 | 
            -
            	TLE = 'Time Limit Exceeded'
         | 
| 33 | 
            +
            	CE = 'Compilation Error'
         | 
| 45 34 | 
             
            	MLE = 'Memory Limit Exceeded'
         | 
| 35 | 
            +
            	TLE = 'Time Limit Exceeded'
         | 
| 46 36 | 
             
            	RE = 'Runtime Error'
         | 
| 47 | 
            -
            	 | 
| 48 | 
            -
            	 | 
| 37 | 
            +
            	WA = 'Wrong Answer'
         | 
| 38 | 
            +
            	AC = 'Accepted'
         | 
| 49 39 |  | 
| 50 40 |  | 
| 51 41 | 
             
            @dataclass
         | 
| @@ -60,20 +50,10 @@ class TestCaseResult: | |
| 60 50 | 
             
            class LabeledTestCaseResult:
         | 
| 61 51 | 
             
            	label: str
         | 
| 62 52 | 
             
            	testcase: TestCase
         | 
| 53 | 
            +
            	# TODO : 実はラベル自体を使わない方がいいかもしれない.ラベルという概念が削除してプリントするときに適当にTest1, Test2と適当に名前をつけてもいいかも.
         | 
| 63 54 | 
             
            	result: TestCaseResult
         | 
| 64 55 |  | 
| 65 56 |  | 
| 66 | 
            -
            @dataclass
         | 
| 67 | 
            -
            class TestInformation:
         | 
| 68 | 
            -
            	lang: Lang
         | 
| 69 | 
            -
            	sourcename: str
         | 
| 70 | 
            -
            	case_number: int
         | 
| 71 | 
            -
            	result_summary: ResultStatus = ResultStatus.WJ
         | 
| 72 | 
            -
            	resultlist: List[LabeledTestCaseResult] = field(default_factory=list)  # 修正
         | 
| 73 | 
            -
            	compiler_message: str = ''
         | 
| 74 | 
            -
            	compile_time : Optional[int] = None
         | 
| 75 | 
            -
             | 
| 76 | 
            -
             | 
| 77 57 | 
             
            def create_testcases_from_html(html: str) -> List[LabeledTestCase]:
         | 
| 78 58 | 
             
            	soup = bs(html, 'html.parser')
         | 
| 79 59 | 
             
            	test_cases = []
         | 
| @@ -106,312 +86,192 @@ def create_testcases_from_html(html: str) -> List[LabeledTestCase]: | |
| 106 86 |  | 
| 107 87 |  | 
| 108 88 | 
             
            def run_code(cmd: list, case: TestCase) -> TestCaseResult:
         | 
| 109 | 
            -
            	start_time = time.time()
         | 
| 110 89 | 
             
            	try:
         | 
| 90 | 
            +
            		start_time = time.time()
         | 
| 111 91 | 
             
            		proc = subprocess.run(
         | 
| 112 92 | 
             
            			cmd, input=case.input, text=True, capture_output=True, timeout=4
         | 
| 113 93 | 
             
            		)
         | 
| 114 | 
            -
            		 | 
| 115 | 
            -
            	except subprocess.TimeoutExpired as e_proc:
         | 
| 116 | 
            -
            		executed_time = int((time.time() - start_time) * 1000)
         | 
| 117 | 
            -
            		stdout_text = e_proc.stdout.decode('utf-8') if e_proc.stdout is not None else ''
         | 
| 118 | 
            -
            		stderr_text = e_proc.stderr.decode('utf-8') if e_proc.stderr is not None else ''
         | 
| 119 | 
            -
            		text = stdout_text + '\n' + stderr_text
         | 
| 120 | 
            -
            		return TestCaseResult(
         | 
| 121 | 
            -
            			output=text, executed_time=executed_time, passed=ResultStatus.TLE
         | 
| 122 | 
            -
            		)
         | 
| 94 | 
            +
            		end_time = time.time()
         | 
| 123 95 |  | 
| 124 | 
            -
             | 
| 125 | 
            -
             | 
| 126 | 
            -
            		 | 
| 127 | 
            -
            			 | 
| 128 | 
            -
             | 
| 129 | 
            -
            			 | 
| 130 | 
            -
            		)
         | 
| 96 | 
            +
            		execution_time = int((end_time - start_time) * 1000)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
            		if proc.returncode != 0:
         | 
| 99 | 
            +
            			return TestCaseResult(
         | 
| 100 | 
            +
            				output=proc.stderr, executed_time=None, passed=ResultStatus.RE
         | 
| 101 | 
            +
            			)
         | 
| 131 102 |  | 
| 132 | 
            -
             | 
| 133 | 
            -
             | 
| 134 | 
            -
             | 
| 103 | 
            +
            		actual_output = proc.stdout.strip()
         | 
| 104 | 
            +
            		expected_output = case.output.strip()
         | 
| 105 | 
            +
             | 
| 106 | 
            +
            		if actual_output != expected_output:
         | 
| 107 | 
            +
            			return TestCaseResult(
         | 
| 108 | 
            +
            				output=actual_output,
         | 
| 109 | 
            +
            				executed_time=execution_time,
         | 
| 110 | 
            +
            				passed=ResultStatus.WA,
         | 
| 111 | 
            +
            			)
         | 
| 135 112 |  | 
| 136 | 
            -
            	if actual_output != expected_output:
         | 
| 137 113 | 
             
            		return TestCaseResult(
         | 
| 138 | 
            -
            			output=actual_output,
         | 
| 139 | 
            -
            			executed_time=executed_time,
         | 
| 140 | 
            -
            			passed=ResultStatus.WA,
         | 
| 114 | 
            +
            			output=actual_output, executed_time=execution_time, passed=ResultStatus.AC
         | 
| 141 115 | 
             
            		)
         | 
| 142 | 
            -
            	 | 
| 116 | 
            +
            	except subprocess.TimeoutExpired:
         | 
| 143 117 | 
             
            		return TestCaseResult(
         | 
| 144 | 
            -
            			output= | 
| 118 | 
            +
            			output='Time Limit Exceeded', executed_time=None, passed=ResultStatus.TLE
         | 
| 145 119 | 
             
            		)
         | 
| 120 | 
            +
            	except Exception as e:
         | 
| 121 | 
            +
            		return TestCaseResult(output=str(e), executed_time=None, passed=ResultStatus.RE)
         | 
| 146 122 |  | 
| 147 123 |  | 
| 148 | 
            -
             | 
| 149 | 
            -
            	 | 
| 150 | 
            -
            	Lang.JAVASCRIPT: ['node', '{source_path}'],
         | 
| 151 | 
            -
            	Lang.C: ['{exec_path}'],
         | 
| 152 | 
            -
            	Lang.CPP: ['{exec_path}'],
         | 
| 153 | 
            -
            	Lang.RUST: ['{exec_path}'],
         | 
| 154 | 
            -
            	Lang.JAVA: ['java', '{exec_path}'],
         | 
| 155 | 
            -
            }
         | 
| 124 | 
            +
            def run_python(path: str, case: TestCase) -> TestCaseResult:
         | 
| 125 | 
            +
            	return run_code(['python3', path], case)
         | 
| 156 126 |  | 
| 157 | 
            -
             | 
| 158 | 
            -
             | 
| 159 | 
            -
            	 | 
| 160 | 
            -
            	Lang.RUST: ['rustc', '{source_path}', '-o', '{exec_path}'],
         | 
| 161 | 
            -
            	Lang.JAVA: ['javac', '{source_path}'],
         | 
| 162 | 
            -
            }
         | 
| 127 | 
            +
             | 
| 128 | 
            +
            def run_javascript(path: str, case: TestCase) -> TestCaseResult:
         | 
| 129 | 
            +
            	return run_code(['node', path], case)
         | 
| 163 130 |  | 
| 164 131 |  | 
| 165 | 
            -
            def  | 
| 132 | 
            +
            def run_c(path: str, case: TestCase) -> TestCaseResult:
         | 
| 166 133 | 
             
            	with tempfile.NamedTemporaryFile(delete=True) as tmp:
         | 
| 167 134 | 
             
            		exec_path = tmp.name
         | 
| 168 | 
            -
             | 
| 169 | 
            -
             | 
| 170 | 
            -
            		 | 
| 171 | 
            -
            	]
         | 
| 172 | 
            -
            	start_time = time.time()
         | 
| 173 | 
            -
            	compile_result = subprocess.run(cmd, capture_output=True, text=True)
         | 
| 174 | 
            -
            	compile_time = int((time.time() - start_time) * 1000)
         | 
| 175 | 
            -
             | 
| 176 | 
            -
            	return exec_path, compile_result, compile_time
         | 
| 177 | 
            -
             | 
| 178 | 
            -
             | 
| 179 | 
            -
            def judge_code_from(
         | 
| 180 | 
            -
            	lcases: List[LabeledTestCase], path: str
         | 
| 181 | 
            -
            ) -> Generator[
         | 
| 182 | 
            -
            	Union[LabeledTestCaseResult, TestInformation],  # type: ignore
         | 
| 183 | 
            -
            	None,
         | 
| 184 | 
            -
            	None,
         | 
| 185 | 
            -
            ]:
         | 
| 186 | 
            -
            	lang = detect_language(path)
         | 
| 187 | 
            -
            	if lang in COMPILED_LANGUAGES:
         | 
| 188 | 
            -
            		exe_path, compile_result, compile_time = run_compile(path, lang)
         | 
| 135 | 
            +
            		compile_result = subprocess.run(
         | 
| 136 | 
            +
            			['gcc', path, '-o', exec_path], capture_output=True, text=True
         | 
| 137 | 
            +
            		)
         | 
| 189 138 | 
             
            		if compile_result.returncode != 0:
         | 
| 190 | 
            -
            			 | 
| 191 | 
            -
            				 | 
| 192 | 
            -
            				sourcename=path,
         | 
| 193 | 
            -
            				case_number=len(lcases),
         | 
| 194 | 
            -
            				result_summary=ResultStatus.CE,
         | 
| 195 | 
            -
            				compiler_message=compile_result.stderr,
         | 
| 196 | 
            -
            			)
         | 
| 197 | 
            -
            			return
         | 
| 198 | 
            -
            		else:
         | 
| 199 | 
            -
            			yield TestInformation(
         | 
| 200 | 
            -
            				lang=lang,
         | 
| 201 | 
            -
            				sourcename=path,
         | 
| 202 | 
            -
            				case_number=len(lcases),
         | 
| 203 | 
            -
            				compiler_message=compile_result.stderr,
         | 
| 204 | 
            -
            				compile_time=compile_time,
         | 
| 139 | 
            +
            			return TestCaseResult(
         | 
| 140 | 
            +
            				output=compile_result.stderr, executed_time=None, passed=ResultStatus.CE
         | 
| 205 141 | 
             
            			)
         | 
| 142 | 
            +
            		return run_code([exec_path], case)
         | 
| 206 143 |  | 
| 207 | 
            -
            			cmd = [
         | 
| 208 | 
            -
            				arg.format(exec_path=exe_path) for arg in LANGUAGE_RUN_COMMANDS[lang]
         | 
| 209 | 
            -
            			]
         | 
| 210 | 
            -
             | 
| 211 | 
            -
            			for lcase in lcases:
         | 
| 212 | 
            -
            				yield LabeledTestCaseResult(
         | 
| 213 | 
            -
            					lcase.label, lcase.case, run_code(cmd, lcase.case)
         | 
| 214 | 
            -
            				)
         | 
| 215 | 
            -
             | 
| 216 | 
            -
            			if os.path.exists(exe_path):
         | 
| 217 | 
            -
            				os.remove(exe_path)
         | 
| 218 144 |  | 
| 219 | 
            -
             | 
| 220 | 
            -
             | 
| 221 | 
            -
             | 
| 222 | 
            -
             | 
| 223 | 
            -
            			 | 
| 145 | 
            +
            def run_cpp(path: str, case: TestCase) -> TestCaseResult:
         | 
| 146 | 
            +
            	with tempfile.NamedTemporaryFile(delete=True) as tmp:
         | 
| 147 | 
            +
            		exec_path = tmp.name
         | 
| 148 | 
            +
            		compile_result = subprocess.run(
         | 
| 149 | 
            +
            			['g++', path, '-o', exec_path], capture_output=True, text=True
         | 
| 224 150 | 
             
            		)
         | 
| 225 | 
            -
            		 | 
| 226 | 
            -
             | 
| 227 | 
            -
             | 
| 228 | 
            -
            				lcase.label, lcase.case, run_code(cmd, lcase.case)
         | 
| 151 | 
            +
            		if compile_result.returncode != 0:
         | 
| 152 | 
            +
            			return TestCaseResult(
         | 
| 153 | 
            +
            				output=compile_result.stderr, executed_time=None, passed=ResultStatus.CE
         | 
| 229 154 | 
             
            			)
         | 
| 230 | 
            -
             | 
| 231 | 
            -
            		raise ValueError('適切な言語が見つかりませんでした.')
         | 
| 232 | 
            -
             | 
| 233 | 
            -
             | 
| 234 | 
            -
            COLOR_MAP = {
         | 
| 235 | 
            -
            	ResultStatus.AC: 'green',
         | 
| 236 | 
            -
            	ResultStatus.WA: 'red',
         | 
| 237 | 
            -
            	ResultStatus.TLE: 'yellow',
         | 
| 238 | 
            -
            	ResultStatus.MLE: 'yellow',
         | 
| 239 | 
            -
            	ResultStatus.RE: 'yellow',
         | 
| 240 | 
            -
            	ResultStatus.CE: 'yellow',
         | 
| 241 | 
            -
            	ResultStatus.WJ: 'grey',
         | 
| 242 | 
            -
            }
         | 
| 243 | 
            -
             | 
| 244 | 
            -
            STATUS_TEXT_MAP = {
         | 
| 245 | 
            -
            	ResultStatus.AC: Text.assemble(
         | 
| 246 | 
            -
            		('\u2713 ', 'green'),
         | 
| 247 | 
            -
            		(
         | 
| 248 | 
            -
            			f'{ResultStatus.AC.value}',
         | 
| 249 | 
            -
            			Style(bgcolor=COLOR_MAP[ResultStatus.AC], bold=True),
         | 
| 250 | 
            -
            		),
         | 
| 251 | 
            -
            	),
         | 
| 252 | 
            -
            	ResultStatus.WA: Text(
         | 
| 253 | 
            -
            		f'\u00d7 {ResultStatus.WA.value}', style=COLOR_MAP[ResultStatus.WA]
         | 
| 254 | 
            -
            	),
         | 
| 255 | 
            -
            	ResultStatus.TLE: Text(
         | 
| 256 | 
            -
            		f'\u00d7 {ResultStatus.TLE.value}', style=COLOR_MAP[ResultStatus.TLE]
         | 
| 257 | 
            -
            	),
         | 
| 258 | 
            -
            	ResultStatus.MLE: Text(
         | 
| 259 | 
            -
            		f'\u00d7 {ResultStatus.MLE.value}', style=COLOR_MAP[ResultStatus.MLE]
         | 
| 260 | 
            -
            	),
         | 
| 261 | 
            -
            	ResultStatus.RE: Text(
         | 
| 262 | 
            -
            		f'\u00d7 {ResultStatus.RE.value}', style=COLOR_MAP[ResultStatus.RE]
         | 
| 263 | 
            -
            	),
         | 
| 264 | 
            -
            	ResultStatus.CE: Text(
         | 
| 265 | 
            -
            		f'\u00d7 {ResultStatus.CE.value}', style=COLOR_MAP[ResultStatus.CE]
         | 
| 266 | 
            -
            	),
         | 
| 267 | 
            -
            	ResultStatus.WJ: Text(
         | 
| 268 | 
            -
            		f'\u23f3 {ResultStatus.WJ.value}', style=COLOR_MAP[ResultStatus.WJ]
         | 
| 269 | 
            -
            	),
         | 
| 270 | 
            -
            }
         | 
| 271 | 
            -
             | 
| 272 | 
            -
             | 
| 273 | 
            -
            def create_renderable_test_info(test_info: TestInformation) -> RenderableType:
         | 
| 274 | 
            -
            	components = []
         | 
| 155 | 
            +
            		return run_code([exec_path], case)
         | 
| 275 156 |  | 
| 276 | 
            -
            	success_count = sum(
         | 
| 277 | 
            -
            		1 for result in test_info.resultlist if result.result.passed == ResultStatus.AC
         | 
| 278 | 
            -
            	)
         | 
| 279 | 
            -
            	total_count = test_info.case_number
         | 
| 280 | 
            -
             | 
| 281 | 
            -
            	# 結果に応じたスタイル付きのテキストを取得
         | 
| 282 | 
            -
            	status_text = STATUS_TEXT_MAP[test_info.result_summary]
         | 
| 283 157 |  | 
| 284 | 
            -
             | 
| 285 | 
            -
            	 | 
| 286 | 
            -
            		 | 
| 287 | 
            -
            		 | 
| 288 | 
            -
             | 
| 289 | 
            -
            		Text.from_markup(
         | 
| 290 | 
            -
            			f'  [{COLOR_MAP[test_info.result_summary]} bold]{success_count}[/] / [white bold]{total_count}[/]'
         | 
| 291 | 
            -
            		),
         | 
| 292 | 
            -
            	)
         | 
| 293 | 
            -
             | 
| 294 | 
            -
            	components.append(Panel(header_text, expand=False))
         | 
| 295 | 
            -
             | 
| 296 | 
            -
            	if test_info.compiler_message:
         | 
| 297 | 
            -
            		rule = Rule(
         | 
| 298 | 
            -
            			title='コンパイラーのメッセージ',
         | 
| 299 | 
            -
            			style=COLOR_MAP[ResultStatus.CE],
         | 
| 300 | 
            -
            		)
         | 
| 301 | 
            -
            		components.append(rule)
         | 
| 302 | 
            -
            		error_message = Syntax(
         | 
| 303 | 
            -
            			test_info.compiler_message, lang2str(test_info.lang), line_numbers=False
         | 
| 158 | 
            +
            def run_rust(path: str, case: TestCase) -> TestCaseResult:
         | 
| 159 | 
            +
            	with tempfile.NamedTemporaryFile(delete=True) as tmp:
         | 
| 160 | 
            +
            		exec_path = tmp.name
         | 
| 161 | 
            +
            		compile_result = subprocess.run(
         | 
| 162 | 
            +
            			['rustc', path, '-o', exec_path], capture_output=True, text=True
         | 
| 304 163 | 
             
            		)
         | 
| 305 | 
            -
            		 | 
| 306 | 
            -
             | 
| 307 | 
            -
             | 
| 308 | 
            -
             | 
| 164 | 
            +
            		if compile_result.returncode != 0:
         | 
| 165 | 
            +
            			return TestCaseResult(
         | 
| 166 | 
            +
            				output=compile_result.stderr, executed_time=None, passed=ResultStatus.CE
         | 
| 167 | 
            +
            			)
         | 
| 168 | 
            +
            		return run_code([exec_path], case)
         | 
| 309 169 |  | 
| 310 | 
            -
            def update_test_info(
         | 
| 311 | 
            -
            	test_info: TestInformation, test_result: LabeledTestCaseResult
         | 
| 312 | 
            -
            ) -> None:
         | 
| 313 | 
            -
            	test_info.resultlist.append(test_result)
         | 
| 314 170 |  | 
| 315 | 
            -
             | 
| 316 | 
            -
             | 
| 317 | 
            -
             | 
| 318 | 
            -
            		 | 
| 319 | 
            -
             | 
| 320 | 
            -
            		 | 
| 321 | 
            -
             | 
| 322 | 
            -
            	 | 
| 171 | 
            +
            def run_java(path: str, case: TestCase) -> TestCaseResult:
         | 
| 172 | 
            +
            	compile_result = subprocess.run(['javac', path], capture_output=True, text=True)
         | 
| 173 | 
            +
            	if compile_result.returncode != 0:
         | 
| 174 | 
            +
            		return TestCaseResult(
         | 
| 175 | 
            +
            			output=compile_result.stderr, executed_time=None, passed=ResultStatus.CE
         | 
| 176 | 
            +
            		)
         | 
| 177 | 
            +
            	class_file = os.path.splitext(path)[0]
         | 
| 178 | 
            +
            	try:
         | 
| 179 | 
            +
            		return run_code(['java', class_file], case)
         | 
| 180 | 
            +
            	finally:
         | 
| 181 | 
            +
            		class_path = class_file + '.class'
         | 
| 182 | 
            +
            		if os.path.exists(class_path):
         | 
| 183 | 
            +
            			os.remove(class_path)
         | 
| 184 | 
            +
             | 
| 185 | 
            +
             | 
| 186 | 
            +
            LANGUAGE_RUNNERS: Dict[Lang, Callable[[str, TestCase], TestCaseResult]] = {
         | 
| 187 | 
            +
            	Lang.PYTHON: run_python,
         | 
| 188 | 
            +
            	Lang.JAVASCRIPT: run_javascript,
         | 
| 189 | 
            +
            	Lang.C: run_c,
         | 
| 190 | 
            +
            	Lang.CPP: run_cpp,
         | 
| 191 | 
            +
            	Lang.RUST: run_rust,
         | 
| 192 | 
            +
            	Lang.JAVA: run_java,
         | 
| 193 | 
            +
            }
         | 
| 323 194 |  | 
| 324 | 
            -
            	# 現在の結果の中で最も高い優先順位のステータスを見つける
         | 
| 325 | 
            -
            	highest_priority_status = (
         | 
| 326 | 
            -
            		test_info.result_summary
         | 
| 327 | 
            -
            	)  # デフォルトはWJまたは現在のサマリー
         | 
| 328 | 
            -
            	for result in test_info.resultlist:
         | 
| 329 | 
            -
            		status = result.result.passed
         | 
| 330 | 
            -
            		if priority_order.index(status) < priority_order.index(highest_priority_status):
         | 
| 331 | 
            -
            			highest_priority_status = status
         | 
| 332 | 
            -
             | 
| 333 | 
            -
            	# 特殊ケース: すべてのテストケースがACである場合(途中でも)
         | 
| 334 | 
            -
            	if all(result.result.passed == ResultStatus.AC for result in test_info.resultlist):
         | 
| 335 | 
            -
            		test_info.result_summary = ResultStatus.AC
         | 
| 336 | 
            -
            	else:
         | 
| 337 | 
            -
            		test_info.result_summary = highest_priority_status
         | 
| 338 | 
            -
             | 
| 339 | 
            -
             | 
| 340 | 
            -
            def create_renderable_test_result(
         | 
| 341 | 
            -
            	i: int,
         | 
| 342 | 
            -
            	test_result: LabeledTestCaseResult,
         | 
| 343 | 
            -
            ) -> RenderableType:
         | 
| 344 | 
            -
            	rule = Rule(
         | 
| 345 | 
            -
            		title=f'No.{i+1} {test_result.label}',
         | 
| 346 | 
            -
            		style=COLOR_MAP[test_result.result.passed],
         | 
| 347 | 
            -
            	)
         | 
| 348 195 |  | 
| 349 | 
            -
             | 
| 350 | 
            -
            	 | 
| 351 | 
            -
             | 
| 352 | 
            -
            		 | 
| 196 | 
            +
            def choose_lang(path: str) -> Optional[Callable[[str, TestCase], TestCaseResult]]:
         | 
| 197 | 
            +
            	ext = os.path.splitext(path)[1]
         | 
| 198 | 
            +
            	lang = next(
         | 
| 199 | 
            +
            		(lang for lang, extension in FILE_EXTENSIONS.items() if extension == ext), None
         | 
| 353 200 | 
             
            	)
         | 
| 201 | 
            +
            	# lang が None でない場合のみ get を呼び出す
         | 
| 202 | 
            +
            	if lang is not None:
         | 
| 203 | 
            +
            		return LANGUAGE_RUNNERS.get(lang)
         | 
| 204 | 
            +
            	return None
         | 
| 354 205 |  | 
| 355 | 
            -
            	execution_time_text = None
         | 
| 356 | 
            -
            	if test_result.result.executed_time is not None:
         | 
| 357 | 
            -
            		execution_time_text = Text.from_markup(
         | 
| 358 | 
            -
            			f'実行時間   [cyan]{test_result.result.executed_time}[/cyan] ms'
         | 
| 359 | 
            -
            		)
         | 
| 360 206 |  | 
| 361 | 
            -
             | 
| 362 | 
            -
            	 | 
| 363 | 
            -
             | 
| 364 | 
            -
            	 | 
| 365 | 
            -
             | 
| 366 | 
            -
             | 
| 367 | 
            -
             | 
| 368 | 
            -
             | 
| 369 | 
            -
            		 | 
| 370 | 
            -
             | 
| 371 | 
            -
            			escape(test_result.result.output),
         | 
| 372 | 
            -
            			escape(test_result.testcase.output),
         | 
| 373 | 
            -
            		)
         | 
| 374 | 
            -
            	else:
         | 
| 375 | 
            -
            		table.add_column(
         | 
| 376 | 
            -
            			'出力', style=COLOR_MAP[test_result.result.passed], min_width=10
         | 
| 377 | 
            -
            		)
         | 
| 378 | 
            -
            		table.add_row(
         | 
| 379 | 
            -
            			escape(test_result.testcase.input), escape(test_result.result.output)
         | 
| 380 | 
            -
            		)
         | 
| 381 | 
            -
             | 
| 382 | 
            -
            	components = [
         | 
| 383 | 
            -
            		rule,
         | 
| 384 | 
            -
            		status_header,
         | 
| 385 | 
            -
            		execution_time_text if execution_time_text else '',
         | 
| 386 | 
            -
            		table,
         | 
| 207 | 
            +
            def judge_code_from(
         | 
| 208 | 
            +
            	lcases: List[LabeledTestCase], path: str
         | 
| 209 | 
            +
            ) -> List[LabeledTestCaseResult]:
         | 
| 210 | 
            +
            	runner = choose_lang(path)
         | 
| 211 | 
            +
            	if runner is None:
         | 
| 212 | 
            +
            		raise ValueError(f'ランナーが見つかりませんでした。指定されたパス: {path}')
         | 
| 213 | 
            +
             | 
| 214 | 
            +
            	return [
         | 
| 215 | 
            +
            		LabeledTestCaseResult(lcase.label, lcase.case, runner(path, lcase.case))
         | 
| 216 | 
            +
            		for lcase in lcases
         | 
| 387 217 | 
             
            	]
         | 
| 388 218 |  | 
| 389 | 
            -
            	return Group(*components)
         | 
| 390 219 |  | 
| 220 | 
            +
            class CustomFormatStyle(Enum):
         | 
| 221 | 
            +
            	SUCCESS = 'green'
         | 
| 222 | 
            +
            	FAILURE = 'red'
         | 
| 223 | 
            +
            	WARNING = 'yellow'
         | 
| 224 | 
            +
            	INFO = 'blue'
         | 
| 391 225 |  | 
| 392 | 
            -
            def render_results(
         | 
| 393 | 
            -
            	results: Generator[Union[LabeledTestCaseResult, TestInformation], None, None],
         | 
| 394 | 
            -
            ) -> None:
         | 
| 395 | 
            -
            	console = Console()
         | 
| 396 226 |  | 
| 397 | 
            -
             | 
| 398 | 
            -
            	 | 
| 399 | 
            -
            	 | 
| 400 | 
            -
            		 | 
| 401 | 
            -
            	 | 
| 227 | 
            +
            def render_results(path: str, results: List[LabeledTestCaseResult]) -> None:
         | 
| 228 | 
            +
            	console = Console()
         | 
| 229 | 
            +
            	success_count = sum(
         | 
| 230 | 
            +
            		1 for result in results if result.result.passed == ResultStatus.AC
         | 
| 231 | 
            +
            	)
         | 
| 232 | 
            +
            	total_count = len(results)
         | 
| 402 233 |  | 
| 403 | 
            -
            	 | 
| 234 | 
            +
            	# ヘッダー
         | 
| 235 | 
            +
            	header_text = Text.assemble(
         | 
| 236 | 
            +
            		f'{path}のテスト  ',
         | 
| 237 | 
            +
            		(
         | 
| 238 | 
            +
            			f'{success_count}/{total_count} ',
         | 
| 239 | 
            +
            			'green' if success_count == total_count else 'red',
         | 
| 240 | 
            +
            		),
         | 
| 241 | 
            +
            	)
         | 
| 242 | 
            +
            	console.print(Panel(header_text, expand=False))
         | 
| 404 243 |  | 
| 405 | 
            -
            	 | 
| 244 | 
            +
            	CHECK_MARK = '\u2713'
         | 
| 245 | 
            +
            	CROSS_MARK = '\u00d7'
         | 
| 406 246 | 
             
            	# 各テストケースの結果表示
         | 
| 407 247 | 
             
            	for i, result in enumerate(results):
         | 
| 408 | 
            -
            		if  | 
| 409 | 
            -
            			 | 
| 410 | 
            -
            			 | 
| 411 | 
            -
             | 
| 412 | 
            -
            			raise ValueError('テスト結果がyieldする型はLabeledTestCaseResultです')
         | 
| 248 | 
            +
            		if result.result.passed == ResultStatus.AC:
         | 
| 249 | 
            +
            			status_text = f'[green]{CHECK_MARK}[/] [white on green]{result.result.passed.value}[/]'
         | 
| 250 | 
            +
            			console.rule(title=f'No.{i+1} {result.label}', style='green')
         | 
| 251 | 
            +
            			console.print(f'[bold]ステータス:[/] {status_text}')
         | 
| 413 252 |  | 
| 414 | 
            -
             | 
| 253 | 
            +
            		else:
         | 
| 254 | 
            +
            			status_text = f'[red]{CROSS_MARK} {result.result.passed.value}[/]'
         | 
| 255 | 
            +
            			console.rule(title=f'No.{i+1} {result.label}', style='red')
         | 
| 256 | 
            +
            			console.print(f'[bold]ステータス:[/] {status_text}')
         | 
| 257 | 
            +
             | 
| 258 | 
            +
            		if result.result.executed_time is not None:
         | 
| 259 | 
            +
            			console.print(f'[bold]実行時間:[/] {result.result.executed_time} ms')
         | 
| 260 | 
            +
             | 
| 261 | 
            +
            		table = Table(show_header=True, header_style='bold')
         | 
| 262 | 
            +
            		table.add_column('入力', style='cyan', min_width=10)
         | 
| 263 | 
            +
            		if result.result.passed != ResultStatus.AC:
         | 
| 264 | 
            +
            			table.add_column('出力', style='red', min_width=10)
         | 
| 265 | 
            +
            			table.add_column('正解の出力', style='green', min_width=10)
         | 
| 266 | 
            +
            			table.add_row(
         | 
| 267 | 
            +
            				escape(result.testcase.input),
         | 
| 268 | 
            +
            				escape(result.result.output),
         | 
| 269 | 
            +
            				escape(result.testcase.output),
         | 
| 270 | 
            +
            			)
         | 
| 271 | 
            +
            		else:
         | 
| 272 | 
            +
            			table.add_column('出力', style='green', min_width=10)
         | 
| 273 | 
            +
            			table.add_row(escape(result.testcase.input), escape(result.result.output))
         | 
| 274 | 
            +
            		console.print(table)
         | 
| 415 275 |  | 
| 416 276 |  | 
| 417 277 | 
             
            def run_test(path_of_code: str) -> None:
         | 
| @@ -427,12 +287,8 @@ def run_test(path_of_code: str) -> None: | |
| 427 287 |  | 
| 428 288 | 
             
            	test_cases = create_testcases_from_html(html)
         | 
| 429 289 | 
             
            	test_results = judge_code_from(test_cases, path_of_code)
         | 
| 430 | 
            -
            	render_results(test_results)
         | 
| 290 | 
            +
            	render_results(path_of_code, test_results)
         | 
| 431 291 |  | 
| 432 292 |  | 
| 433 293 | 
             
            def test(*args: str) -> None:
         | 
| 434 | 
            -
            	execute_files(
         | 
| 435 | 
            -
            		*args,
         | 
| 436 | 
            -
            		func=run_test,
         | 
| 437 | 
            -
            		target_filetypes=INTERPRETED_LANGUAGES + COMPILED_LANGUAGES,
         | 
| 438 | 
            -
            	)
         | 
| 294 | 
            +
            	execute_files(*args, func=run_test, target_filetypes=SOURCE_LANGUAGES)
         | 
    
        atcdr/util/filetype.py
    CHANGED
    
    | @@ -1,6 +1,5 @@ | |
| 1 | 
            -
            import os
         | 
| 2 1 | 
             
            from enum import Enum
         | 
| 3 | 
            -
            from typing import Dict, List,  | 
| 2 | 
            +
            from typing import Dict, List, TypeAlias
         | 
| 4 3 |  | 
| 5 4 | 
             
            # ファイル名と拡張子の型エイリアスを定義
         | 
| 6 5 | 
             
            Filename: TypeAlias = str
         | 
| @@ -24,7 +23,7 @@ class Lang(Enum): | |
| 24 23 |  | 
| 25 24 |  | 
| 26 25 | 
             
            # ファイル拡張子と対応する言語の辞書
         | 
| 27 | 
            -
            FILE_EXTENSIONS: Dict[Lang,  | 
| 26 | 
            +
            FILE_EXTENSIONS: Dict[Lang, Extension] = {
         | 
| 28 27 | 
             
            	Lang.PYTHON: '.py',
         | 
| 29 28 | 
             
            	Lang.JAVASCRIPT: '.js',
         | 
| 30 29 | 
             
            	Lang.JAVA: '.java',
         | 
| @@ -40,29 +39,24 @@ FILE_EXTENSIONS: Dict[Lang, str] = { | |
| 40 39 | 
             
            	Lang.JSON: '.json',
         | 
| 41 40 | 
             
            }
         | 
| 42 41 |  | 
| 43 | 
            -
            # ドキュメント言語のリスト
         | 
| 44 42 | 
             
            DOCUMENT_LANGUAGES: List[Lang] = [
         | 
| 45 43 | 
             
            	Lang.HTML,
         | 
| 46 44 | 
             
            	Lang.MARKDOWN,
         | 
| 47 45 | 
             
            	Lang.JSON,
         | 
| 48 46 | 
             
            ]
         | 
| 49 47 |  | 
| 50 | 
            -
            #  | 
| 51 | 
            -
             | 
| 48 | 
            +
            # ソースコードファイルと言語のリスト
         | 
| 49 | 
            +
            SOURCE_LANGUAGES: List[Lang] = [
         | 
| 50 | 
            +
            	Lang.PYTHON,
         | 
| 51 | 
            +
            	Lang.JAVASCRIPT,
         | 
| 52 52 | 
             
            	Lang.JAVA,
         | 
| 53 53 | 
             
            	Lang.C,
         | 
| 54 54 | 
             
            	Lang.CPP,
         | 
| 55 55 | 
             
            	Lang.CSHARP,
         | 
| 56 | 
            -
            	Lang.GO,
         | 
| 57 | 
            -
            	Lang.RUST,
         | 
| 58 | 
            -
            ]
         | 
| 59 | 
            -
             | 
| 60 | 
            -
            # インタプリター型言語のリスト
         | 
| 61 | 
            -
            INTERPRETED_LANGUAGES: List[Lang] = [
         | 
| 62 | 
            -
            	Lang.PYTHON,
         | 
| 63 | 
            -
            	Lang.JAVASCRIPT,
         | 
| 64 56 | 
             
            	Lang.RUBY,
         | 
| 65 57 | 
             
            	Lang.PHP,
         | 
| 58 | 
            +
            	Lang.GO,
         | 
| 59 | 
            +
            	Lang.RUST,
         | 
| 66 60 | 
             
            ]
         | 
| 67 61 |  | 
| 68 62 |  | 
| @@ -95,11 +89,3 @@ def str2lang(lang: str) -> Lang: | |
| 95 89 |  | 
| 96 90 | 
             
            def lang2str(lang: Lang) -> str:
         | 
| 97 91 | 
             
            	return lang.value
         | 
| 98 | 
            -
             | 
| 99 | 
            -
             | 
| 100 | 
            -
            def detect_language(path: str) -> Optional[Lang]:
         | 
| 101 | 
            -
            	ext = os.path.splitext(path)[1]  # ファイルの拡張子を取得
         | 
| 102 | 
            -
            	lang = next(
         | 
| 103 | 
            -
            		(lang for lang, extension in FILE_EXTENSIONS.items() if extension == ext), None
         | 
| 104 | 
            -
            	)
         | 
| 105 | 
            -
            	return lang
         | 
| @@ -1,6 +1,6 @@ | |
| 1 1 | 
             
            Metadata-Version: 2.3
         | 
| 2 2 | 
             
            Name: AtCoderStudyBooster
         | 
| 3 | 
            -
            Version: 0. | 
| 3 | 
            +
            Version: 0.23
         | 
| 4 4 | 
             
            Summary: A tool to download and manage AtCoder problems.
         | 
| 5 5 | 
             
            Project-URL: Homepage, https://github.com/yuta6/AtCoderStudyBooster
         | 
| 6 6 | 
             
            Author-email: yuta6 <46110512+yuta6@users.noreply.github.com>
         | 
| @@ -1,17 +1,17 @@ | |
| 1 1 | 
             
            atcdr/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 2 2 | 
             
            atcdr/download.py,sha256=aqJEmrLop_Mj8GNJjoWRcSzWU0z9iNc9vtZOjhWXx3U,8945
         | 
| 3 | 
            -
            atcdr/generate.py,sha256= | 
| 3 | 
            +
            atcdr/generate.py,sha256=MKoip-0jTEkMH0hEusgJgtfQB9QCSiiZ2jlU_ccmB8E,6973
         | 
| 4 4 | 
             
            atcdr/main.py,sha256=y2IkXwcAyKZ_1y5PgU93GpXzo5lKak9oxo0XV_9d5Fo,727
         | 
| 5 5 | 
             
            atcdr/markdown.py,sha256=jEktnYgrDYcgIuhxRpJImAzNpFmfSPkRikAesfMxAVk,1125
         | 
| 6 6 | 
             
            atcdr/open.py,sha256=2UlmNWdieoMrPu1xSUWf-8sBB9Y19r0t6V9zDRBSPes,924
         | 
| 7 | 
            -
            atcdr/test.py,sha256= | 
| 7 | 
            +
            atcdr/test.py,sha256=hAhttwVJiDJX8IAWcnpKj04yTTs4cmr8GQ-NsldBAGc,8468
         | 
| 8 8 | 
             
            atcdr/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 9 9 | 
             
            atcdr/util/cost.py,sha256=0c9H8zLley7xZDLuYU4zJmB8m71qcO1WEIQOoEavD_4,3168
         | 
| 10 10 | 
             
            atcdr/util/execute.py,sha256=tcYflnVo_38LdaOGDUAuqfSfcA54bTrCaTRShH7kwUw,1750
         | 
| 11 | 
            -
            atcdr/util/filetype.py,sha256= | 
| 11 | 
            +
            atcdr/util/filetype.py,sha256=NyTkBbL44VbPwGXps381odbC_JEx_eYxRYPaYwRHfZ0,1647
         | 
| 12 12 | 
             
            atcdr/util/gpt.py,sha256=Lto6SJHZGer8cC_Nq8lJVnaET2R7apFQteo6ZEFpjdM,3304
         | 
| 13 13 | 
             
            atcdr/util/problem.py,sha256=WprmpOZm6xpyvksIS3ou1uHqFnBO1FUZWadsLziG1bY,2484
         | 
| 14 | 
            -
            atcoderstudybooster-0. | 
| 15 | 
            -
            atcoderstudybooster-0. | 
| 16 | 
            -
            atcoderstudybooster-0. | 
| 17 | 
            -
            atcoderstudybooster-0. | 
| 14 | 
            +
            atcoderstudybooster-0.23.dist-info/METADATA,sha256=dPsYm8RVvwjgwQjXCQS_gRz8JFIZEIfK6ig_Fsvb1k0,4468
         | 
| 15 | 
            +
            atcoderstudybooster-0.23.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
         | 
| 16 | 
            +
            atcoderstudybooster-0.23.dist-info/entry_points.txt,sha256=_bhz0R7vp2VubKl_eIokDO8Wz9TdqvYA7Q59uWfy6Sk,42
         | 
| 17 | 
            +
            atcoderstudybooster-0.23.dist-info/RECORD,,
         | 
| 
            File without changes
         | 
| 
            File without changes
         |