webtoonmtl 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- webtoonmtl-0.1.0/PKG-INFO +88 -0
- webtoonmtl-0.1.0/README.md +72 -0
- webtoonmtl-0.1.0/pyproject.toml +25 -0
- webtoonmtl-0.1.0/src/webtoonmtl/__init__.py +4 -0
- webtoonmtl-0.1.0/src/webtoonmtl/__main__.py +4 -0
- webtoonmtl-0.1.0/src/webtoonmtl/_utils/__init__.py +0 -0
- webtoonmtl-0.1.0/src/webtoonmtl/_utils/logger.py +98 -0
- webtoonmtl-0.1.0/src/webtoonmtl/_utils/logging_config.json +61 -0
- webtoonmtl-0.1.0/src/webtoonmtl/_utils/trainer.py +244 -0
- webtoonmtl-0.1.0/src/webtoonmtl/cli.py +15 -0
- webtoonmtl-0.1.0/src/webtoonmtl/core/__init__.py +0 -0
- webtoonmtl-0.1.0/src/webtoonmtl/core/mtlcore.py +98 -0
- webtoonmtl-0.1.0/src/webtoonmtl/core/setup.py +127 -0
- webtoonmtl-0.1.0/src/webtoonmtl/core/translator.py +111 -0
- webtoonmtl-0.1.0/src/webtoonmtl/py.typed +0 -0
- webtoonmtl-0.1.0/src/webtoonmtl/ui/__init__.py +4 -0
- webtoonmtl-0.1.0/src/webtoonmtl/ui/colors.py +13 -0
- webtoonmtl-0.1.0/src/webtoonmtl/ui/main_window.py +556 -0
- webtoonmtl-0.1.0/src/webtoonmtl/ui/setup_wizard.py +269 -0
- webtoonmtl-0.1.0/src/webtoonmtl/ui/widgets.py +365 -0
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: webtoonmtl
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Webtoon machine translation application.
|
|
5
|
+
Author: jyyhuang
|
|
6
|
+
Author-email: jyyhuang <jyyhuanggit@gmail.com>
|
|
7
|
+
Requires-Dist: appdirs>=1.4.4
|
|
8
|
+
Requires-Dist: easyocr>=1.7.2
|
|
9
|
+
Requires-Dist: pyqt6>=6.10.2
|
|
10
|
+
Requires-Dist: sentencepiece>=0.2.1
|
|
11
|
+
Requires-Dist: timm>=1.0.24
|
|
12
|
+
Requires-Dist: torch>=2.10.0
|
|
13
|
+
Requires-Dist: transformers>=5.0.0
|
|
14
|
+
Requires-Python: >=3.14
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
|
|
17
|
+
# Webtoon MTL
|
|
18
|
+
|
|
19
|
+
A Python application for extracting and translating Korean text from webtoon images using OCR and neural machine translation.
|
|
20
|
+
|
|
21
|
+
## Project Status
|
|
22
|
+
|
|
23
|
+
Still under development, expect breaking changes
|
|
24
|
+
|
|
25
|
+
## Table of Contents
|
|
26
|
+
|
|
27
|
+
- [Webtoon MTL](#webtoon-mtl)
|
|
28
|
+
- [Table of Contents](#table-of-contents)
|
|
29
|
+
- [Demo](#demo)
|
|
30
|
+
- [Background](#background)
|
|
31
|
+
- [Features](#features)
|
|
32
|
+
- [Installation](#installation)
|
|
33
|
+
- [Usage](#usage)
|
|
34
|
+
- [Roadmap](#roadmap)
|
|
35
|
+
- [License](#license)
|
|
36
|
+
|
|
37
|
+
## Demo
|
|
38
|
+
https://github.com/user-attachments/assets/6e52bf87-991e-4488-935f-eeeb72c1f76a
|
|
39
|
+
|
|
40
|
+
## Background
|
|
41
|
+
|
|
42
|
+
After years of reading webtoons and manhwa, I repeatedly ran into the same frustration of reaching the latest available chapter in English, only to find that newer chapters exist only in Korean. While fan translations eventually appear, they are often delayed or incomplete.
|
|
43
|
+
|
|
44
|
+
With recent advances in optical character recognition (OCR) and neural machine translation (NMT), I decided it was finally time to address this problem myself. This project is personal, but feel free to use it too!
|
|
45
|
+
|
|
46
|
+
## Features
|
|
47
|
+
|
|
48
|
+
- **Simple GUI**: Desktop GUI with PyQt6
|
|
49
|
+
- **OCR Extraction**: Extract Korean text from images using EasyOCR
|
|
50
|
+
- **Neural Translation**: Translate Korean to English using transformer models
|
|
51
|
+
- **Model Training**: Fine-tune translation models on Korean-English datasets
|
|
52
|
+
|
|
53
|
+
## Installation
|
|
54
|
+
|
|
55
|
+
1. Clone the repository:
|
|
56
|
+
|
|
57
|
+
```bash
|
|
58
|
+
git clone https://github.com/jyyhuang/webtoonmtl.git
|
|
59
|
+
cd webtoonmtl
|
|
60
|
+
```
|
|
61
|
+
|
|
62
|
+
2. Install dependencies:
|
|
63
|
+
|
|
64
|
+
```bash
|
|
65
|
+
pip install -r requirements.txt
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
## Usage
|
|
69
|
+
|
|
70
|
+
```
|
|
71
|
+
webtoonmtl start
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
## Roadmap
|
|
75
|
+
|
|
76
|
+
- ✅ Extract text from images using OCR
|
|
77
|
+
- ✅ Use transformer to translate Korean text
|
|
78
|
+
- ✅ Add support for fine-tuning the translation model
|
|
79
|
+
- ✅ Create command line usage
|
|
80
|
+
- ✅ Desktop GUI with PyQt
|
|
81
|
+
- ⬜ Add better testing
|
|
82
|
+
- ⬜ Cache OCR, model, and translations
|
|
83
|
+
- ⬜ Improve logging, error handling, and progress reports
|
|
84
|
+
- ⬜ Package as a pip library
|
|
85
|
+
|
|
86
|
+
## License
|
|
87
|
+
|
|
88
|
+
See LICENSE file for details.
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
# Webtoon MTL
|
|
2
|
+
|
|
3
|
+
A Python application for extracting and translating Korean text from webtoon images using OCR and neural machine translation.
|
|
4
|
+
|
|
5
|
+
## Project Status
|
|
6
|
+
|
|
7
|
+
Still under development, expect breaking changes
|
|
8
|
+
|
|
9
|
+
## Table of Contents
|
|
10
|
+
|
|
11
|
+
- [Webtoon MTL](#webtoon-mtl)
|
|
12
|
+
- [Table of Contents](#table-of-contents)
|
|
13
|
+
- [Demo](#demo)
|
|
14
|
+
- [Background](#background)
|
|
15
|
+
- [Features](#features)
|
|
16
|
+
- [Installation](#installation)
|
|
17
|
+
- [Usage](#usage)
|
|
18
|
+
- [Roadmap](#roadmap)
|
|
19
|
+
- [License](#license)
|
|
20
|
+
|
|
21
|
+
## Demo
|
|
22
|
+
https://github.com/user-attachments/assets/6e52bf87-991e-4488-935f-eeeb72c1f76a
|
|
23
|
+
|
|
24
|
+
## Background
|
|
25
|
+
|
|
26
|
+
After years of reading webtoons and manhwa, I repeatedly ran into the same frustration of reaching the latest available chapter in English, only to find that newer chapters exist only in Korean. While fan translations eventually appear, they are often delayed or incomplete.
|
|
27
|
+
|
|
28
|
+
With recent advances in optical character recognition (OCR) and neural machine translation (NMT), I decided it was finally time to address this problem myself. This project is personal, but feel free to use it too!
|
|
29
|
+
|
|
30
|
+
## Features
|
|
31
|
+
|
|
32
|
+
- **Simple GUI**: Desktop GUI with PyQt6
|
|
33
|
+
- **OCR Extraction**: Extract Korean text from images using EasyOCR
|
|
34
|
+
- **Neural Translation**: Translate Korean to English using transformer models
|
|
35
|
+
- **Model Training**: Fine-tune translation models on Korean-English datasets
|
|
36
|
+
|
|
37
|
+
## Installation
|
|
38
|
+
|
|
39
|
+
1. Clone the repository:
|
|
40
|
+
|
|
41
|
+
```bash
|
|
42
|
+
git clone https://github.com/jyyhuang/webtoonmtl.git
|
|
43
|
+
cd webtoonmtl
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
2. Install dependencies:
|
|
47
|
+
|
|
48
|
+
```bash
|
|
49
|
+
pip install -r requirements.txt
|
|
50
|
+
```
|
|
51
|
+
|
|
52
|
+
## Usage
|
|
53
|
+
|
|
54
|
+
```
|
|
55
|
+
webtoonmtl start
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
## Roadmap
|
|
59
|
+
|
|
60
|
+
- ✅ Extract text from images using OCR
|
|
61
|
+
- ✅ Use transformer to translate Korean text
|
|
62
|
+
- ✅ Add support for fine-tuning the translation model
|
|
63
|
+
- ✅ Create command line usage
|
|
64
|
+
- ✅ Desktop GUI with PyQt
|
|
65
|
+
- ⬜ Add better testing
|
|
66
|
+
- ⬜ Cache OCR, model, and translations
|
|
67
|
+
- ⬜ Improve logging, error handling, and progress reports
|
|
68
|
+
- ⬜ Package as a pip library
|
|
69
|
+
|
|
70
|
+
## License
|
|
71
|
+
|
|
72
|
+
See LICENSE file for details.
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "webtoonmtl"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Webtoon machine translation application."
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
authors = [
|
|
7
|
+
{ name = "jyyhuang", email = "jyyhuanggit@gmail.com" }
|
|
8
|
+
]
|
|
9
|
+
requires-python = ">=3.14"
|
|
10
|
+
dependencies = [
|
|
11
|
+
"appdirs>=1.4.4",
|
|
12
|
+
"easyocr>=1.7.2",
|
|
13
|
+
"pyqt6>=6.10.2",
|
|
14
|
+
"sentencepiece>=0.2.1",
|
|
15
|
+
"timm>=1.0.24",
|
|
16
|
+
"torch>=2.10.0",
|
|
17
|
+
"transformers>=5.0.0",
|
|
18
|
+
]
|
|
19
|
+
|
|
20
|
+
[build-system]
|
|
21
|
+
requires = ["uv_build>=0.9.26,<0.10.0"]
|
|
22
|
+
build-backend = "uv_build"
|
|
23
|
+
|
|
24
|
+
[project.scripts]
|
|
25
|
+
webtoonmtl = "webtoonmtl.cli:cli"
|
|
File without changes
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
import datetime as dt
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
from typing import override
|
|
5
|
+
|
|
6
|
+
import appdirs
|
|
7
|
+
import atexit
|
|
8
|
+
import logging.config
|
|
9
|
+
import pathlib
|
|
10
|
+
|
|
11
|
+
LOG_RECORD_BUILTIN_ATTRS = {
|
|
12
|
+
"args",
|
|
13
|
+
"asctime",
|
|
14
|
+
"created",
|
|
15
|
+
"exc_info",
|
|
16
|
+
"exc_text",
|
|
17
|
+
"filename",
|
|
18
|
+
"funcName",
|
|
19
|
+
"levelname",
|
|
20
|
+
"levelno",
|
|
21
|
+
"lineno",
|
|
22
|
+
"module",
|
|
23
|
+
"msecs",
|
|
24
|
+
"message",
|
|
25
|
+
"msg",
|
|
26
|
+
"name",
|
|
27
|
+
"pathname",
|
|
28
|
+
"process",
|
|
29
|
+
"processName",
|
|
30
|
+
"relativeCreated",
|
|
31
|
+
"stack_info",
|
|
32
|
+
"thread",
|
|
33
|
+
"threadName",
|
|
34
|
+
"taskName",
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class JSONFormatter(logging.Formatter):
|
|
39
|
+
def __init__(self, *, fmt_keys: dict[str, str] | None = None):
|
|
40
|
+
super().__init__()
|
|
41
|
+
self.fmt_keys = fmt_keys if fmt_keys is not None else {}
|
|
42
|
+
|
|
43
|
+
@override
|
|
44
|
+
def format(self, record: logging.LogRecord) -> str:
|
|
45
|
+
message = self._make_log_dict(record)
|
|
46
|
+
return json.dumps(message, default=str)
|
|
47
|
+
|
|
48
|
+
def _make_log_dict(self, record: logging.LogRecord):
|
|
49
|
+
always = {
|
|
50
|
+
"message": record.getMessage(),
|
|
51
|
+
"timestamp": dt.datetime.fromtimestamp(
|
|
52
|
+
record.created, tz=dt.timezone.utc
|
|
53
|
+
).isoformat(),
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
if record.exc_info is not None:
|
|
57
|
+
always["exc_info"] = self.formatException(record.exc_info)
|
|
58
|
+
if record.stack_info is not None:
|
|
59
|
+
always["stack_info"] = self.formatStack(record.stack_info)
|
|
60
|
+
message = {
|
|
61
|
+
key: (
|
|
62
|
+
msg_val
|
|
63
|
+
if (msg_val := always.pop(val, None)) is not None
|
|
64
|
+
else getattr(record, val)
|
|
65
|
+
)
|
|
66
|
+
for key, val in self.fmt_keys.items()
|
|
67
|
+
}
|
|
68
|
+
message.update(always)
|
|
69
|
+
for key, val in record.__dict__.items():
|
|
70
|
+
if key not in LOG_RECORD_BUILTIN_ATTRS:
|
|
71
|
+
message[key] = val
|
|
72
|
+
|
|
73
|
+
return message
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class StdoutFilter(logging.Filter):
|
|
77
|
+
@override
|
|
78
|
+
def filter(self, record: logging.LogRecord) -> bool | logging.LogRecord:
|
|
79
|
+
return record.levelno <= logging.INFO
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def setup_logging() -> None:
    """Configure application-wide logging from the bundled JSON config.

    Loads ``logging_config.json`` (shipped next to this module), points the
    rotating JSON file handler at the per-user log directory, applies the
    configuration, and starts the queue listener so records are handled off
    the calling thread.
    """
    config_path = pathlib.Path(__file__).resolve().parent / "logging_config.json"

    # Per-user platform-specific log location, created on first run.
    log_dir = pathlib.Path(appdirs.user_log_dir("webtoonmtl"))
    log_dir.mkdir(parents=True, exist_ok=True)

    config = json.loads(config_path.read_text())

    # The static config file cannot know the user's log directory, so the
    # rotating-file handler's target is filled in at runtime.
    config["handlers"]["json_file"]["filename"] = str(log_dir / "webtoonmtl.log.jsonl")

    logging.config.dictConfig(config)

    # dictConfig creates the QueueHandler but does not start its listener;
    # start it here and make sure it drains on interpreter shutdown.
    queue_handler = logging.getHandlerByName("queue_handler")
    if queue_handler is not None:
        queue_handler.listener.start()
        atexit.register(queue_handler.listener.stop)
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
{
|
|
2
|
+
"version": 1,
|
|
3
|
+
"disable_existing_loggers": false,
|
|
4
|
+
"formatters": {
|
|
5
|
+
"simple": {
|
|
6
|
+
"format": "%(asctime)s [%(levelname)s] %(module)s: %(message)s",
|
|
7
|
+
"datefmt": "%Y-%m-%dT%H:%M:%S%z"
|
|
8
|
+
},
|
|
9
|
+
"json": {
|
|
10
|
+
"()": "webtoonmtl._utils.logger.JSONFormatter",
|
|
11
|
+
"fmt_keys": {
|
|
12
|
+
"level": "levelname",
|
|
13
|
+
"message": "message",
|
|
14
|
+
"timestamp": "timestamp",
|
|
15
|
+
"logger": "name",
|
|
16
|
+
"module": "module",
|
|
17
|
+
"function": "funcName",
|
|
18
|
+
"line": "lineno",
|
|
19
|
+
"thread_name": "threadName"
|
|
20
|
+
}
|
|
21
|
+
}
|
|
22
|
+
},
|
|
23
|
+
"filters": {
|
|
24
|
+
"stdout_filter": {
|
|
25
|
+
"()": "webtoonmtl._utils.logger.StdoutFilter"
|
|
26
|
+
}
|
|
27
|
+
},
|
|
28
|
+
"handlers": {
|
|
29
|
+
"stdout": {
|
|
30
|
+
"class": "logging.StreamHandler",
|
|
31
|
+
"level": "DEBUG",
|
|
32
|
+
"filters": ["stdout_filter"],
|
|
33
|
+
"formatter": "simple",
|
|
34
|
+
"stream": "ext://sys.stdout"
|
|
35
|
+
},
|
|
36
|
+
"stderr": {
|
|
37
|
+
"class": "logging.StreamHandler",
|
|
38
|
+
"level": "WARNING",
|
|
39
|
+
"formatter": "simple",
|
|
40
|
+
"stream": "ext://sys.stderr"
|
|
41
|
+
},
|
|
42
|
+
"json_file": {
|
|
43
|
+
"class": "logging.handlers.RotatingFileHandler",
|
|
44
|
+
"level": "DEBUG",
|
|
45
|
+
"formatter": "json",
|
|
46
|
+
"maxBytes": 10000,
|
|
47
|
+
"backupCount": 3
|
|
48
|
+
},
|
|
49
|
+
"queue_handler": {
|
|
50
|
+
"class": "logging.handlers.QueueHandler",
|
|
51
|
+
"handlers": ["stderr", "json_file"],
|
|
52
|
+
"respect_handler_level": true
|
|
53
|
+
}
|
|
54
|
+
},
|
|
55
|
+
"loggers": {
|
|
56
|
+
"root": {
|
|
57
|
+
"level": "DEBUG",
|
|
58
|
+
"handlers": ["queue_handler"]
|
|
59
|
+
}
|
|
60
|
+
}
|
|
61
|
+
}
|
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
#!pip install evaluate
|
|
2
|
+
#!pip install torch
|
|
3
|
+
#!pip install datasets
|
|
4
|
+
#!pip install transformers
|
|
5
|
+
#!pip install sacrebleu
|
|
6
|
+
#!pip install sacremoses
|
|
7
|
+
|
|
8
|
+
import inspect
|
|
9
|
+
import json
|
|
10
|
+
from dataclasses import asdict, dataclass
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import evaluate
|
|
15
|
+
import torch
|
|
16
|
+
from datasets import load_dataset
|
|
17
|
+
from transformers import (
|
|
18
|
+
AutoModelForSeq2SeqLM,
|
|
19
|
+
AutoTokenizer,
|
|
20
|
+
DataCollatorForSeq2Seq,
|
|
21
|
+
Seq2SeqTrainer,
|
|
22
|
+
Seq2SeqTrainingArguments,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
@dataclass
class TrainingConfig:
    """Hyperparameters and output paths for fine-tuning the Ko->En model.

    Many field names intentionally match Seq2SeqTrainingArguments parameters:
    train() forwards them via asdict() after filtering against that
    signature, so renaming a field here silently changes trainer behavior.
    """

    # Base HuggingFace model and dataset identifiers.
    model_name: str = "Helsinki-NLP/opus-mt-ko-en"
    dataset_name: str = "lemon-mint/Korean-FineTome-100k"
    # Max token length applied to both source and target sequences.
    max_length: int = 128

    # Local output locations; resolved once at class-definition time.
    output_dir: Path = Path("~/.webtoonmtl/model").expanduser().resolve()
    output_data_dir: Path = Path("~/.webtoonmtl/data").expanduser().resolve()
    # Trades compute for memory during backprop.
    gradient_checkpointing: bool = True
    per_device_train_batch_size: int = 16
    per_device_eval_batch_size: int = 16
    learning_rate: float = 2e-5
    warmup_steps: int = 100
    num_train_epochs: int = 5
    optim: str = "adamw_torch"
    # Checkpoint selection metric; computed by compute_metrics().
    metric_for_best_model: str = "bleu"
    # NOTE(review): fp16=True assumes a CUDA-capable device — confirm on CPU-only runs.
    fp16: bool = True
    # Generate text (not just logits) during eval so BLEU/chrF can be scored.
    predict_with_generate: bool = True
    eval_strategy: str = "steps"
    eval_steps: int = 500
    save_strategy: str = "steps"
    save_steps: int = 500
    logging_steps: int = 100
    # Keep at most 3 checkpoints on disk.
    save_total_limit: int = 3
    load_best_model_at_end: bool = True
    push_to_hub: bool = False
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def load_model(model_name: str, gradient_checkpointing: bool = True):
    """
    Load tokenizer and model from HuggingFace.

    Args:
        model_name: str path to HF base model
        gradient_checkpointing: whether to enable gradient checkpointing

    Returns:
        tuple: (tokenizer, model)

    Raises:
        Propagates any exception from the HuggingFace loaders unchanged
        (missing model id, network failure, incompatible weights, ...).
    """
    # The previous `try/except Exception as e: raise` was a no-op that only
    # bound an unused variable; errors now propagate directly to the caller.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    if gradient_checkpointing:
        # use_reentrant=False is the non-reentrant checkpointing variant,
        # required for compatibility with newer torch autograd behavior.
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )

    return tokenizer, model
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def tokenize_function(examples, tokenizer, max_length: int = 128):
    """
    Tokenize a batch of examples for seq2seq translation.

    Only the first message of each conversation is used; pairs missing
    either the Korean ``content`` or English ``content_en`` field become
    empty strings so batch positions stay aligned.

    Args:
        examples: batch of examples from dataset (expects a "messages" column)
        tokenizer: tokenizer to use
        max_length: max sequence length

    Returns:
        dict: tokenized inputs (with labels from ``text_target``)
    """
    sources: list[str] = []
    references: list[str] = []

    for conversation in examples["messages"]:
        first = conversation[0]
        has_pair = "content" in first and "content_en" in first
        sources.append(first["content"] if has_pair else "")
        references.append(first["content_en"] if has_pair else "")

    return tokenizer(
        sources,
        text_target=references,
        max_length=max_length,
        truncation=True,
        padding="max_length",
    )
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def prepare_datasets(dataset_name: str, tokenizer, max_length: int = 128):
    """
    Load a dataset and tokenize it into train/test splits.

    Args:
        dataset_name: name of dataset to load
        tokenizer: tokenizer to use
        max_length: max sequence length

    Returns:
        tuple: (train_dataset, test_dataset) — both tokenized
    """
    # 80/20 split of the single "train" split the source dataset ships with.
    split = load_dataset(dataset_name, split="train").train_test_split(test_size=0.2)

    def _tokenize(subset):
        # Drop the raw columns so only model inputs/labels remain.
        return subset.map(
            lambda batch: tokenize_function(batch, tokenizer, max_length),
            batched=True,
            remove_columns=subset.column_names,
        )

    return _tokenize(split["train"]), _tokenize(split["test"])
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def compute_metrics(eval_preds, tokenizer, bleu_metric, chrf_metric):
    """
    Compute BLEU and chrF scores for a batch of generated translations.

    Args:
        eval_preds: tuple of (predictions, labels); predictions may itself
            be a tuple, in which case only the first element is used
        tokenizer: tokenizer for decoding
        bleu_metric: BLEU metric object
        chrf_metric: chrF metric object

    Returns:
        dict: {"bleu": <score>, "chrf": <score>}
    """
    predictions, labels = eval_preds
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # -100 marks ignored (padded) label positions; replace with the pad
    # token id so batch_decode does not choke on them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    hypotheses = [text.strip() for text in decoded_preds]
    # sacrebleu-style metrics expect one list of references per hypothesis.
    references = [[text.strip()] for text in decoded_labels]

    bleu = bleu_metric.compute(predictions=hypotheses, references=references)
    chrf = chrf_metric.compute(predictions=hypotheses, references=references)

    return {"bleu": bleu["score"], "chrf": chrf["score"]}
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def train(
    config: TrainingConfig | None = None,
    resume_from_checkpoint: str | None = None,
):
    """
    Fine-tune the translation model and persist the result.

    Loads the base model and dataset named in *config*, runs a
    Seq2SeqTrainer, saves the final model to ``config.output_dir`` and the
    combined train/eval metrics to ``config.output_data_dir/metrics.json``.

    Args:
        config: Training configuration; defaults to ``TrainingConfig()``.
        resume_from_checkpoint: checkpoint path to resume from

    Raises:
        Propagates any exception from dataset loading, training, or saving
        unchanged (the previous ``except Exception: raise`` was a no-op and
        has been removed).
    """
    config = config or TrainingConfig()

    tokenizer, model = load_model(config.model_name, config.gradient_checkpointing)

    train_dataset, test_dataset = prepare_datasets(
        config.dataset_name, tokenizer, config.max_length
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    bleu_metric = evaluate.load("sacrebleu")
    chrf_metric = evaluate.load("chrf")

    # Forward only the config fields that Seq2SeqTrainingArguments accepts;
    # the remainder (dataset_name, output_data_dir, ...) are consumed here.
    valid_args = inspect.signature(Seq2SeqTrainingArguments).parameters
    trainer_settings = {k: v for k, v in asdict(config).items() if k in valid_args}
    training_args = Seq2SeqTrainingArguments(**trainer_settings)

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=data_collator,
        compute_metrics=lambda eval_preds: compute_metrics(
            eval_preds, tokenizer, bleu_metric, chrf_metric
        ),
        processing_class=tokenizer,
    )

    train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    config.output_dir.mkdir(parents=True, exist_ok=True)
    config.output_data_dir.mkdir(parents=True, exist_ok=True)

    trainer.save_model(str(config.output_dir))

    # Merge a final evaluation pass into the training metrics before saving.
    metrics = train_result.metrics
    metrics.update(trainer.evaluate())

    metrics_path = config.output_data_dir / "metrics.json"
    with metrics_path.open("w") as f:
        json.dump(metrics, f, indent=2)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import click
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
# Root command group; subcommands (e.g. `start`) attach via @cli.command().
# `webtoonmtl --version` is handled by the version option below.
@click.group()
@click.version_option("0.1.0", prog_name="webtoonmtl")
def cli():
    """Machine Translations for Korean webtoons to English."""
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@cli.command()
def start():
    """Starts the Webtoon MTL GUI application."""
    # Imported lazily so `webtoonmtl --help` stays fast and does not load
    # the GUI stack (PyQt6 and friends) unless the GUI is actually launched.
    from webtoonmtl.ui import run_gui

    run_gui()
|
|
File without changes
|