khmerns 0.0.3__tar.gz → 0.0.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {khmerns-0.0.3 → khmerns-0.0.4}/PKG-INFO +9 -3
- {khmerns-0.0.3 → khmerns-0.0.4}/README.md +8 -2
- {khmerns-0.0.3 → khmerns-0.0.4}/pyproject.toml +2 -1
- khmerns-0.0.4/src/khmerns/__init__.py +4 -0
- {khmerns-0.0.3 → khmerns-0.0.4}/src/khmerns/__init__.pyi +4 -0
- {khmerns-0.0.3/training → khmerns-0.0.4/src/khmerns}/khnormal.py +3 -2
- khmerns-0.0.3/.github/workflows/wheels.yml +0 -93
- khmerns-0.0.3/img/graph.png +0 -0
- khmerns-0.0.3/src/khmerns/__init__.py +0 -3
- khmerns-0.0.3/test/example.py +0 -9
- khmerns-0.0.3/training/best_model.pt +0 -0
- khmerns-0.0.3/training/convert_to_gguf.py +0 -124
- khmerns-0.0.3/training/data.py +0 -78
- khmerns-0.0.3/training/export_onnx.py +0 -169
- khmerns-0.0.3/training/generate.py +0 -52
- khmerns-0.0.3/training/infer.py +0 -63
- khmerns-0.0.3/training/model.py +0 -41
- khmerns-0.0.3/training/requirements.txt +0 -7
- khmerns-0.0.3/training/segmenter.onnx +0 -0
- khmerns-0.0.3/training/tokenizer.py +0 -202
- khmerns-0.0.3/training/train.py +0 -112
- {khmerns-0.0.3 → khmerns-0.0.4}/.gitignore +0 -0
- {khmerns-0.0.3 → khmerns-0.0.4}/CMakeLists.txt +0 -0
- {khmerns-0.0.3 → khmerns-0.0.4}/LICENSE +0 -0
- {khmerns-0.0.3 → khmerns-0.0.4}/src/crf.cpp +0 -0
- {khmerns-0.0.3 → khmerns-0.0.4}/src/crf.h +0 -0
- {khmerns-0.0.3 → khmerns-0.0.4}/src/khmer-segmenter.cpp +0 -0
- {khmerns-0.0.3 → khmerns-0.0.4}/src/khmer-segmenter.h +0 -0
- {khmerns-0.0.3 → khmerns-0.0.4}/src/main.cpp +0 -0
- {khmerns-0.0.3 → khmerns-0.0.4}/src/model_data.h +0 -0
- {khmerns-0.0.3 → khmerns-0.0.4}/src/tokenizer.cpp +0 -0
- {khmerns-0.0.3 → khmerns-0.0.4}/src/tokenizer.h +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: khmerns
|
|
3
|
-
Version: 0.0.3
|
|
3
|
+
Version: 0.0.4
|
|
4
4
|
Summary: Khmer Neural Segmenter
|
|
5
5
|
Keywords: khmer,nlp,segmentation,tokenization,neural-network
|
|
6
6
|
Author-Email: Seanghay Yath <seanghay.dev@gmail.com>
|
|
@@ -35,11 +35,15 @@ pip install khmerns
|
|
|
35
35
|
## Usage
|
|
36
36
|
|
|
37
37
|
```python
|
|
38
|
-
from khmerns import tokenize
|
|
38
|
+
from khmerns import tokenize, normalize
|
|
39
39
|
|
|
40
40
|
# Returns a list of words
|
|
41
41
|
words = tokenize("សួស្តីបងប្អូន")
|
|
42
|
-
# ['សួស្តី', 'បង', 'ប្អូន']
|
|
42
|
+
# => ['សួស្តី', 'បង', 'ប្អូន']
|
|
43
|
+
|
|
44
|
+
# normalize and reorder Khmer characters
|
|
45
|
+
words = tokenize(normalize("សួស្តីបងប្អូន"))
|
|
46
|
+
# => ['សួស្តី', 'បង', 'ប្អូន']
|
|
43
47
|
```
|
|
44
48
|
|
|
45
49
|
You can also use the class-based API if you prefer:
|
|
@@ -48,8 +52,10 @@ You can also use the class-based API if you prefer:
|
|
|
48
52
|
from khmerns import KhmerSegmenter
|
|
49
53
|
|
|
50
54
|
segmenter = KhmerSegmenter()
|
|
55
|
+
|
|
51
56
|
words = segmenter.tokenize("សួស្តីបងប្អូន")
|
|
52
57
|
# or
|
|
58
|
+
|
|
53
59
|
words = segmenter("សួស្តីបងប្អូន")
|
|
54
60
|
```
|
|
55
61
|
|
|
@@ -13,11 +13,15 @@ pip install khmerns
|
|
|
13
13
|
## Usage
|
|
14
14
|
|
|
15
15
|
```python
|
|
16
|
-
from khmerns import tokenize
|
|
16
|
+
from khmerns import tokenize, normalize
|
|
17
17
|
|
|
18
18
|
# Returns a list of words
|
|
19
19
|
words = tokenize("សួស្តីបងប្អូន")
|
|
20
|
-
# ['សួស្តី', 'បង', 'ប្អូន']
|
|
20
|
+
# => ['សួស្តី', 'បង', 'ប្អូន']
|
|
21
|
+
|
|
22
|
+
# normalize and reorder Khmer characters
|
|
23
|
+
words = tokenize(normalize("សួស្តីបងប្អូន"))
|
|
24
|
+
# => ['សួស្តី', 'បង', 'ប្អូន']
|
|
21
25
|
```
|
|
22
26
|
|
|
23
27
|
You can also use the class-based API if you prefer:
|
|
@@ -26,8 +30,10 @@ You can also use the class-based API if you prefer:
|
|
|
26
30
|
from khmerns import KhmerSegmenter
|
|
27
31
|
|
|
28
32
|
segmenter = KhmerSegmenter()
|
|
33
|
+
|
|
29
34
|
words = segmenter.tokenize("សួស្តីបងប្អូន")
|
|
30
35
|
# or
|
|
36
|
+
|
|
31
37
|
words = segmenter("សួស្តីបងប្អូន")
|
|
32
38
|
```
|
|
33
39
|
|
|
@@ -4,7 +4,7 @@ build-backend = "scikit_build_core.build"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "khmerns"
|
|
7
|
-
version = "0.0.3"
|
|
7
|
+
version = "0.0.4"
|
|
8
8
|
license = "MIT"
|
|
9
9
|
license-files = ["LICENSE"]
|
|
10
10
|
description = "Khmer Neural Segmenter"
|
|
@@ -31,6 +31,7 @@ Issues = "https://github.com/seanghay/khmer-neural-segmenter/issues"
|
|
|
31
31
|
[tool.scikit-build]
|
|
32
32
|
minimum-version = "build-system.requires"
|
|
33
33
|
wheel.packages = ["src/khmerns"]
|
|
34
|
+
sdist.exclude = ["training", "img", "test", ".github"]
|
|
34
35
|
|
|
35
36
|
[tool.cibuildwheel]
|
|
36
37
|
build-frontend = "build[uv]"
|
|
@@ -2,7 +2,8 @@
|
|
|
2
2
|
# Copyright (c) 2021-2024, SIL Global.
|
|
3
3
|
# Licensed under MIT license: https://opensource.org/licenses/MIT
|
|
4
4
|
|
|
5
|
-
import enum
|
|
5
|
+
import enum
|
|
6
|
+
import re
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
class Cats(enum.Enum):
|
|
@@ -137,7 +138,7 @@ def lunar(m, base):
|
|
|
137
138
|
return chr(v + base)
|
|
138
139
|
|
|
139
140
|
|
|
140
|
-
def khnormal(txt, lang="km"):
|
|
141
|
+
def normalize(txt, lang="km"):
|
|
141
142
|
"""Returns khmer normalised string, without fixing or marking errors"""
|
|
142
143
|
# Mark final coengs in Middle Khmer
|
|
143
144
|
if lang == "xhm":
|
|
@@ -1,93 +0,0 @@
|
|
|
1
|
-
name: Wheels
|
|
2
|
-
|
|
3
|
-
on:
|
|
4
|
-
workflow_dispatch:
|
|
5
|
-
pull_request:
|
|
6
|
-
push:
|
|
7
|
-
branches:
|
|
8
|
-
- main
|
|
9
|
-
release:
|
|
10
|
-
types:
|
|
11
|
-
- published
|
|
12
|
-
|
|
13
|
-
env:
|
|
14
|
-
FORCE_COLOR: 3
|
|
15
|
-
|
|
16
|
-
concurrency:
|
|
17
|
-
group: ${{ github.workflow }}-${{ github.ref }}
|
|
18
|
-
cancel-in-progress: true
|
|
19
|
-
|
|
20
|
-
jobs:
|
|
21
|
-
build_sdist:
|
|
22
|
-
name: Build SDist
|
|
23
|
-
runs-on: ubuntu-latest
|
|
24
|
-
steps:
|
|
25
|
-
- uses: actions/checkout@v5
|
|
26
|
-
with:
|
|
27
|
-
submodules: true
|
|
28
|
-
|
|
29
|
-
- name: Build SDist
|
|
30
|
-
run: pipx run build --sdist
|
|
31
|
-
|
|
32
|
-
- name: Check metadata
|
|
33
|
-
run: pipx run twine check dist/*
|
|
34
|
-
|
|
35
|
-
- uses: actions/upload-artifact@v5
|
|
36
|
-
with:
|
|
37
|
-
name: cibw-sdist
|
|
38
|
-
path: dist/*.tar.gz
|
|
39
|
-
|
|
40
|
-
build_wheels:
|
|
41
|
-
name: Wheels on ${{ matrix.os }}
|
|
42
|
-
runs-on: ${{ matrix.os }}
|
|
43
|
-
strategy:
|
|
44
|
-
fail-fast: false
|
|
45
|
-
matrix:
|
|
46
|
-
os: [ubuntu-latest, macos-latest, macos-15-intel, windows-latest, ubuntu-24.04-arm]
|
|
47
|
-
env:
|
|
48
|
-
MACOSX_DEPLOYMENT_TARGET: "11.0"
|
|
49
|
-
steps:
|
|
50
|
-
- uses: actions/checkout@v5
|
|
51
|
-
with:
|
|
52
|
-
submodules: true
|
|
53
|
-
|
|
54
|
-
- uses: astral-sh/setup-uv@v7
|
|
55
|
-
|
|
56
|
-
- uses: pypa/cibuildwheel@v3.3
|
|
57
|
-
|
|
58
|
-
- name: Verify clean directory
|
|
59
|
-
run: git diff --exit-code
|
|
60
|
-
shell: bash
|
|
61
|
-
|
|
62
|
-
- uses: actions/upload-artifact@v5
|
|
63
|
-
with:
|
|
64
|
-
name: cibw-wheels-${{ matrix.os }}
|
|
65
|
-
path: wheelhouse/*.whl
|
|
66
|
-
|
|
67
|
-
upload_all:
|
|
68
|
-
name: Upload if release
|
|
69
|
-
needs: [build_wheels, build_sdist]
|
|
70
|
-
runs-on: ubuntu-latest
|
|
71
|
-
if: github.event_name == 'release' && github.event.action == 'published'
|
|
72
|
-
environment: pypi
|
|
73
|
-
permissions:
|
|
74
|
-
id-token: write
|
|
75
|
-
attestations: write
|
|
76
|
-
|
|
77
|
-
steps:
|
|
78
|
-
- uses: actions/setup-python@v6
|
|
79
|
-
with:
|
|
80
|
-
python-version: "3.x"
|
|
81
|
-
|
|
82
|
-
- uses: actions/download-artifact@v6
|
|
83
|
-
with:
|
|
84
|
-
pattern: cibw-*
|
|
85
|
-
merge-multiple: true
|
|
86
|
-
path: dist
|
|
87
|
-
|
|
88
|
-
- name: Generate artifact attestation for sdist and wheels
|
|
89
|
-
uses: actions/attest-build-provenance@v3
|
|
90
|
-
with:
|
|
91
|
-
subject-path: "dist/*"
|
|
92
|
-
|
|
93
|
-
- uses: pypa/gh-action-pypi-publish@release/v1
|
khmerns-0.0.3/img/graph.png
DELETED
|
Binary file
|
khmerns-0.0.3/test/example.py
DELETED
|
@@ -1,9 +0,0 @@
|
|
|
1
|
-
from khmerns import tokenize
|
|
2
|
-
|
|
3
|
-
print(
|
|
4
|
-
tokenize(
|
|
5
|
-
"តុលាការបារាំងបើកការស៊ើបអង្កេតលើក្រុមហ៊ុនបណ្តាញសង្គម X (Twitter) ក្នុងសំណុំរឿងឧក្រិដ្ឋកម្មអនឡាញ"
|
|
6
|
-
)
|
|
7
|
-
)
|
|
8
|
-
|
|
9
|
-
# ['តុលាការ', 'បារាំង', 'បើក', 'ការ', 'ស៊ើប', 'អង្កេត', 'លើ', 'ក្រុមហ៊ុន', 'បណ្តាញ', 'សង្គម', ' ', 'X', ' ', '(', 'T', 'w', 'i', 't', 't', 'e', 'r', ')', ' ', 'ក្នុង', 'សំណុំ', 'រឿង', 'ឧក្រិដ្ឋ', 'កម្ម', 'អនឡាញ']
|
|
Binary file
|
|
@@ -1,124 +0,0 @@
|
|
|
1
|
-
#!/usr/bin/env python3
|
|
2
|
-
"""Convert PyTorch Khmer segmenter model to GGUF format."""
|
|
3
|
-
|
|
4
|
-
import argparse
|
|
5
|
-
import sys
|
|
6
|
-
from pathlib import Path
|
|
7
|
-
|
|
8
|
-
import numpy as np
|
|
9
|
-
import torch
|
|
10
|
-
|
|
11
|
-
# Add parent directory for imports
|
|
12
|
-
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
13
|
-
from model import Segmenter
|
|
14
|
-
from tokenizer import Tokenizer
|
|
15
|
-
|
|
16
|
-
try:
|
|
17
|
-
from gguf import GGUFWriter
|
|
18
|
-
except ImportError:
|
|
19
|
-
print("Error: gguf package not installed. Run: pip install gguf")
|
|
20
|
-
sys.exit(1)
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def convert_to_gguf(model_path: str, output_path: str):
|
|
24
|
-
"""Convert PyTorch model to GGUF format."""
|
|
25
|
-
|
|
26
|
-
# Load tokenizer and model
|
|
27
|
-
print(f"Loading model from {model_path}...")
|
|
28
|
-
tokenizer = Tokenizer()
|
|
29
|
-
model = Segmenter(
|
|
30
|
-
vocab_size=len(tokenizer),
|
|
31
|
-
embedding_dim=256,
|
|
32
|
-
hidden_dim=256,
|
|
33
|
-
num_labels=3,
|
|
34
|
-
)
|
|
35
|
-
model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
|
|
36
|
-
model.eval()
|
|
37
|
-
|
|
38
|
-
# Create GGUF writer
|
|
39
|
-
print(f"Creating GGUF file: {output_path}")
|
|
40
|
-
writer = GGUFWriter(output_path, "khmer-segmenter")
|
|
41
|
-
|
|
42
|
-
# Write metadata
|
|
43
|
-
writer.add_uint32("khmer.vocab_size", len(tokenizer))
|
|
44
|
-
writer.add_uint32("khmer.embedding_dim", 256)
|
|
45
|
-
writer.add_uint32("khmer.hidden_dim", 256)
|
|
46
|
-
writer.add_uint32("khmer.num_labels", 3)
|
|
47
|
-
|
|
48
|
-
# Write tensors
|
|
49
|
-
print("Writing tensors...")
|
|
50
|
-
|
|
51
|
-
# Embedding: [vocab_size, embedding_dim]
|
|
52
|
-
embed_weight = model.embedding.weight.detach().numpy().astype(np.float32)
|
|
53
|
-
writer.add_tensor("embedding.weight", embed_weight)
|
|
54
|
-
print(f" embedding.weight: {embed_weight.shape}")
|
|
55
|
-
|
|
56
|
-
# GRU forward weights
|
|
57
|
-
# PyTorch GRU weight_ih_l0: [3*hidden, input]
|
|
58
|
-
# PyTorch GRU weight_hh_l0: [3*hidden, hidden]
|
|
59
|
-
gru = model.gru
|
|
60
|
-
writer.add_tensor("gru.weight_ih_l0", gru.weight_ih_l0.detach().numpy().astype(np.float32))
|
|
61
|
-
writer.add_tensor("gru.weight_hh_l0", gru.weight_hh_l0.detach().numpy().astype(np.float32))
|
|
62
|
-
writer.add_tensor("gru.bias_ih_l0", gru.bias_ih_l0.detach().numpy().astype(np.float32))
|
|
63
|
-
writer.add_tensor("gru.bias_hh_l0", gru.bias_hh_l0.detach().numpy().astype(np.float32))
|
|
64
|
-
print(f" gru.weight_ih_l0: {gru.weight_ih_l0.shape}")
|
|
65
|
-
print(f" gru.weight_hh_l0: {gru.weight_hh_l0.shape}")
|
|
66
|
-
|
|
67
|
-
# GRU backward (reverse) weights
|
|
68
|
-
writer.add_tensor("gru.weight_ih_l0_reverse", gru.weight_ih_l0_reverse.detach().numpy().astype(np.float32))
|
|
69
|
-
writer.add_tensor("gru.weight_hh_l0_reverse", gru.weight_hh_l0_reverse.detach().numpy().astype(np.float32))
|
|
70
|
-
writer.add_tensor("gru.bias_ih_l0_reverse", gru.bias_ih_l0_reverse.detach().numpy().astype(np.float32))
|
|
71
|
-
writer.add_tensor("gru.bias_hh_l0_reverse", gru.bias_hh_l0_reverse.detach().numpy().astype(np.float32))
|
|
72
|
-
print(f" gru.weight_ih_l0_reverse: {gru.weight_ih_l0_reverse.shape}")
|
|
73
|
-
|
|
74
|
-
# Linear layer: [num_labels, 2*hidden]
|
|
75
|
-
fc_weight = model.fc.weight.detach().numpy().astype(np.float32)
|
|
76
|
-
fc_bias = model.fc.bias.detach().numpy().astype(np.float32)
|
|
77
|
-
writer.add_tensor("fc.weight", fc_weight)
|
|
78
|
-
writer.add_tensor("fc.bias", fc_bias)
|
|
79
|
-
print(f" fc.weight: {fc_weight.shape}")
|
|
80
|
-
print(f" fc.bias: {fc_bias.shape}")
|
|
81
|
-
|
|
82
|
-
# CRF parameters
|
|
83
|
-
crf = model.crf
|
|
84
|
-
writer.add_tensor("crf.start_transitions", crf.start_transitions.detach().numpy().astype(np.float32))
|
|
85
|
-
writer.add_tensor("crf.end_transitions", crf.end_transitions.detach().numpy().astype(np.float32))
|
|
86
|
-
writer.add_tensor("crf.transitions", crf.transitions.detach().numpy().astype(np.float32))
|
|
87
|
-
print(f" crf.start_transitions: {crf.start_transitions.shape}")
|
|
88
|
-
print(f" crf.end_transitions: {crf.end_transitions.shape}")
|
|
89
|
-
print(f" crf.transitions: {crf.transitions.shape}")
|
|
90
|
-
|
|
91
|
-
# Finalize
|
|
92
|
-
writer.write_header_to_file()
|
|
93
|
-
writer.write_kv_data_to_file()
|
|
94
|
-
writer.write_tensors_to_file()
|
|
95
|
-
writer.close()
|
|
96
|
-
|
|
97
|
-
print(f"\nGGUF model saved to: {output_path}")
|
|
98
|
-
|
|
99
|
-
# Print file size
|
|
100
|
-
size_mb = Path(output_path).stat().st_size / (1024 * 1024)
|
|
101
|
-
print(f"File size: {size_mb:.2f} MB")
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
def main():
|
|
105
|
-
parser = argparse.ArgumentParser(
|
|
106
|
-
description="Convert PyTorch Khmer segmenter to GGUF format"
|
|
107
|
-
)
|
|
108
|
-
parser.add_argument(
|
|
109
|
-
"model_path",
|
|
110
|
-
type=str,
|
|
111
|
-
help="Path to PyTorch model file (best_model.pt)",
|
|
112
|
-
)
|
|
113
|
-
parser.add_argument(
|
|
114
|
-
"output_path",
|
|
115
|
-
type=str,
|
|
116
|
-
help="Output GGUF file path",
|
|
117
|
-
)
|
|
118
|
-
args = parser.parse_args()
|
|
119
|
-
|
|
120
|
-
convert_to_gguf(args.model_path, args.output_path)
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
if __name__ == "__main__":
|
|
124
|
-
main()
|
khmerns-0.0.3/training/data.py
DELETED
|
@@ -1,78 +0,0 @@
|
|
|
1
|
-
import random
|
|
2
|
-
import torch
|
|
3
|
-
import re
|
|
4
|
-
from torch.utils.data import DataLoader
|
|
5
|
-
from tokenizer import Tokenizer
|
|
6
|
-
from torch.utils.data import Dataset
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
def yield_chunks(data, n, s):
|
|
10
|
-
for i in range(0, len(data), s):
|
|
11
|
-
yield data[i : i + n]
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
re_khmer = re.compile(r"[\u1780-\u17ff]+")
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
class TextDataset(Dataset):
|
|
18
|
-
def __init__(self, tokenizer: Tokenizer, split="train", train_ratio=0.95, seed=42):
|
|
19
|
-
super().__init__()
|
|
20
|
-
self.tokenizer = tokenizer
|
|
21
|
-
with open("data/train.txt") as infile:
|
|
22
|
-
lines = [line.rstrip("\n") for line in infile]
|
|
23
|
-
|
|
24
|
-
rng = random.Random(seed)
|
|
25
|
-
all_items = [c for c in yield_chunks(lines, 128, rng.randint(1, 128))]
|
|
26
|
-
split_idx = int(len(all_items) * train_ratio)
|
|
27
|
-
|
|
28
|
-
if split == "train":
|
|
29
|
-
self.items = all_items[:split_idx]
|
|
30
|
-
else:
|
|
31
|
-
self.items = all_items[split_idx:]
|
|
32
|
-
|
|
33
|
-
def __len__(self):
|
|
34
|
-
return len(self.items)
|
|
35
|
-
|
|
36
|
-
def __getitem__(self, i):
|
|
37
|
-
inputs = []
|
|
38
|
-
tags = []
|
|
39
|
-
for w in self.items[i]:
|
|
40
|
-
is_khmer = re_khmer.search(w)
|
|
41
|
-
token_ids = self.tokenizer.encode(w)
|
|
42
|
-
for idx, token_id in enumerate(token_ids):
|
|
43
|
-
inputs.append(token_id)
|
|
44
|
-
if is_khmer:
|
|
45
|
-
if idx == 0:
|
|
46
|
-
tags.append(1)
|
|
47
|
-
else:
|
|
48
|
-
tags.append(2)
|
|
49
|
-
else:
|
|
50
|
-
tags.append(0)
|
|
51
|
-
|
|
52
|
-
inputs = [self.tokenizer.bos_id] + inputs + [self.tokenizer.eos_id]
|
|
53
|
-
tags = [0] + tags + [0]
|
|
54
|
-
|
|
55
|
-
return torch.LongTensor(inputs), torch.LongTensor(tags)
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
def collate_fn(batch):
|
|
59
|
-
inputs, tags = zip(*batch)
|
|
60
|
-
lengths = [len(x) for x in inputs]
|
|
61
|
-
max_len = max(lengths)
|
|
62
|
-
|
|
63
|
-
padded_inputs = torch.zeros(len(batch), max_len, dtype=torch.long)
|
|
64
|
-
padded_tags = torch.zeros(len(batch), max_len, dtype=torch.long)
|
|
65
|
-
mask = torch.zeros(len(batch), max_len, dtype=torch.bool)
|
|
66
|
-
|
|
67
|
-
for i, (inp, tag) in enumerate(zip(inputs, tags)):
|
|
68
|
-
padded_inputs[i, : lengths[i]] = inp
|
|
69
|
-
padded_tags[i, : lengths[i]] = tag
|
|
70
|
-
mask[i, : lengths[i]] = True
|
|
71
|
-
|
|
72
|
-
return padded_inputs, padded_tags, mask
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
if __name__ == "__main__":
|
|
76
|
-
dataset = TextDataset(tokenizer=Tokenizer())
|
|
77
|
-
inputs, targets = dataset[1]
|
|
78
|
-
# print(inputs, targets)
|
|
@@ -1,169 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
import torch.nn as nn
|
|
3
|
-
import numpy as np
|
|
4
|
-
from model import Segmenter
|
|
5
|
-
from tokenizer import Tokenizer
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
class SegmenterEmissions(nn.Module):
|
|
9
|
-
"""Wrapper that outputs emissions only (for ONNX export)."""
|
|
10
|
-
|
|
11
|
-
def __init__(self, segmenter):
|
|
12
|
-
super().__init__()
|
|
13
|
-
self.embedding = segmenter.embedding
|
|
14
|
-
self.gru = segmenter.gru
|
|
15
|
-
self.fc = segmenter.fc
|
|
16
|
-
|
|
17
|
-
def forward(self, x):
|
|
18
|
-
embedded = self.embedding(x)
|
|
19
|
-
gru_out, _ = self.gru(embedded)
|
|
20
|
-
emissions = self.fc(gru_out)
|
|
21
|
-
return emissions
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
def export_to_onnx(
|
|
25
|
-
model_path="best_model.pt",
|
|
26
|
-
onnx_path="segmenter.onnx",
|
|
27
|
-
crf_path="crf_params.npz",
|
|
28
|
-
):
|
|
29
|
-
tokenizer = Tokenizer()
|
|
30
|
-
model = Segmenter(
|
|
31
|
-
vocab_size=len(tokenizer),
|
|
32
|
-
embedding_dim=256,
|
|
33
|
-
hidden_dim=256,
|
|
34
|
-
num_labels=3,
|
|
35
|
-
)
|
|
36
|
-
model.load_state_dict(torch.load(model_path, map_location="cpu"))
|
|
37
|
-
model.eval()
|
|
38
|
-
|
|
39
|
-
# Extract CRF parameters
|
|
40
|
-
crf = model.crf
|
|
41
|
-
print(crf.start_transitions)
|
|
42
|
-
print(crf.end_transitions)
|
|
43
|
-
print(crf.transitions)
|
|
44
|
-
|
|
45
|
-
np.savez(
|
|
46
|
-
crf_path,
|
|
47
|
-
start_transitions=crf.start_transitions.detach().numpy(),
|
|
48
|
-
end_transitions=crf.end_transitions.detach().numpy(),
|
|
49
|
-
transitions=crf.transitions.detach().numpy(),
|
|
50
|
-
)
|
|
51
|
-
print(f"Saved CRF parameters to {crf_path}")
|
|
52
|
-
|
|
53
|
-
# Create emissions-only model
|
|
54
|
-
emissions_model = SegmenterEmissions(model)
|
|
55
|
-
emissions_model.eval()
|
|
56
|
-
|
|
57
|
-
# Create dummy input for tracing
|
|
58
|
-
dummy_input = torch.randint(0, len(tokenizer), (1, 32), dtype=torch.long)
|
|
59
|
-
|
|
60
|
-
# Export to ONNX (use legacy export to avoid dynamo issues)
|
|
61
|
-
torch.onnx.export(
|
|
62
|
-
emissions_model,
|
|
63
|
-
dummy_input,
|
|
64
|
-
onnx_path,
|
|
65
|
-
input_names=["input_ids"],
|
|
66
|
-
output_names=["emissions"],
|
|
67
|
-
dynamic_axes={
|
|
68
|
-
"input_ids": {0: "batch_size", 1: "sequence_length"},
|
|
69
|
-
"emissions": {0: "batch_size", 1: "sequence_length"},
|
|
70
|
-
},
|
|
71
|
-
opset_version=14,
|
|
72
|
-
dynamo=False,
|
|
73
|
-
)
|
|
74
|
-
print(f"Exported ONNX model to {onnx_path}")
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
def viterbi_decode(emissions, start_transitions, end_transitions, transitions):
|
|
78
|
-
"""Viterbi decoding for CRF inference."""
|
|
79
|
-
seq_length, _ = emissions.shape
|
|
80
|
-
|
|
81
|
-
# Initialize
|
|
82
|
-
score = start_transitions + emissions[0]
|
|
83
|
-
history = []
|
|
84
|
-
|
|
85
|
-
# Forward pass
|
|
86
|
-
for i in range(1, seq_length):
|
|
87
|
-
broadcast_score = score.reshape(-1, 1)
|
|
88
|
-
broadcast_emissions = emissions[i].reshape(1, -1)
|
|
89
|
-
next_score = broadcast_score + transitions + broadcast_emissions
|
|
90
|
-
indices = next_score.argmax(axis=0)
|
|
91
|
-
score = next_score.max(axis=0)
|
|
92
|
-
history.append(indices)
|
|
93
|
-
|
|
94
|
-
# Add end transitions
|
|
95
|
-
score += end_transitions
|
|
96
|
-
|
|
97
|
-
# Backtrack
|
|
98
|
-
best_tags = [int(score.argmax())]
|
|
99
|
-
for hist in reversed(history):
|
|
100
|
-
best_tags.append(int(hist[best_tags[-1]]))
|
|
101
|
-
best_tags.reverse()
|
|
102
|
-
|
|
103
|
-
return best_tags
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
def segment_onnx(text, session, tokenizer, crf_params):
|
|
107
|
-
"""Segment text using ONNX Runtime."""
|
|
108
|
-
token_ids = tokenizer.encode(text)
|
|
109
|
-
inputs = [tokenizer.bos_id] + token_ids + [tokenizer.eos_id]
|
|
110
|
-
input_array = np.array([inputs], dtype=np.int64)
|
|
111
|
-
|
|
112
|
-
# Run inference
|
|
113
|
-
emissions = session.run(None, {"input_ids": input_array})[0][0]
|
|
114
|
-
|
|
115
|
-
# Viterbi decode
|
|
116
|
-
predictions = viterbi_decode(
|
|
117
|
-
emissions,
|
|
118
|
-
crf_params["start_transitions"],
|
|
119
|
-
crf_params["end_transitions"],
|
|
120
|
-
crf_params["transitions"],
|
|
121
|
-
)
|
|
122
|
-
|
|
123
|
-
# Remove BOS/EOS predictions
|
|
124
|
-
predictions = predictions[1:-1]
|
|
125
|
-
|
|
126
|
-
# Segment based on B-WORD (1) tags
|
|
127
|
-
words = []
|
|
128
|
-
current_word = []
|
|
129
|
-
|
|
130
|
-
for char, tag in zip(text, predictions):
|
|
131
|
-
if tag == 1: # B-WORD
|
|
132
|
-
if current_word:
|
|
133
|
-
words.append("".join(current_word))
|
|
134
|
-
current_word = [char]
|
|
135
|
-
elif tag == 2: # I-WORD
|
|
136
|
-
current_word.append(char)
|
|
137
|
-
else: # 0 (non-Khmer)
|
|
138
|
-
if current_word:
|
|
139
|
-
words.append("".join(current_word))
|
|
140
|
-
current_word = []
|
|
141
|
-
words.append(char)
|
|
142
|
-
|
|
143
|
-
if current_word:
|
|
144
|
-
words.append("".join(current_word))
|
|
145
|
-
|
|
146
|
-
return words
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
if __name__ == "__main__":
|
|
150
|
-
# Export model
|
|
151
|
-
export_to_onnx()
|
|
152
|
-
|
|
153
|
-
# Test ONNX inference
|
|
154
|
-
import onnxruntime as ort
|
|
155
|
-
|
|
156
|
-
tokenizer = Tokenizer()
|
|
157
|
-
session = ort.InferenceSession("segmenter.onnx")
|
|
158
|
-
crf_params = np.load("crf_params.npz")
|
|
159
|
-
|
|
160
|
-
text = "គិតចាប់ពី ខែធ្នូ ឆ្នាំ២០២៤ មកដល់ថ្ងៃទី១១".replace("\u200b", "")
|
|
161
|
-
words = segment_onnx(text, session, tokenizer, crf_params)
|
|
162
|
-
print(f"ONNX result: {'|'.join(words)}")
|
|
163
|
-
|
|
164
|
-
# Compare with PyTorch
|
|
165
|
-
from infer import load_model, segment
|
|
166
|
-
|
|
167
|
-
model, tokenizer = load_model()
|
|
168
|
-
words_pt = segment(text, model, tokenizer)
|
|
169
|
-
print(f"PyTorch result: {'|'.join(words_pt)}")
|
|
@@ -1,52 +0,0 @@
|
|
|
1
|
-
import os
|
|
2
|
-
import regex as re
|
|
3
|
-
from khmersegment import Segmenter
|
|
4
|
-
from nltk.tokenize import TweetTokenizer
|
|
5
|
-
from khnormal import khnormal
|
|
6
|
-
|
|
7
|
-
tknzr = TweetTokenizer(reduce_len=True, strip_handles=False)
|
|
8
|
-
|
|
9
|
-
segmenter = Segmenter("-m assets/km-5tag-seg-model")
|
|
10
|
-
re_pre_segment = re.compile(
|
|
11
|
-
r"([\u1780-\u17dd]+)|([\u17e0-\u17e90-9]+)|([^\u1780-\u17ff]+)"
|
|
12
|
-
)
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
def segment(text: str):
|
|
16
|
-
for m in re_pre_segment.finditer(text):
|
|
17
|
-
if m[2]:
|
|
18
|
-
yield m[2]
|
|
19
|
-
continue
|
|
20
|
-
if m[1]:
|
|
21
|
-
for segment in segmenter(m[1], deep=True):
|
|
22
|
-
yield segment
|
|
23
|
-
continue
|
|
24
|
-
|
|
25
|
-
if len(m[0].strip()) == 0:
|
|
26
|
-
yield m[0]
|
|
27
|
-
continue
|
|
28
|
-
|
|
29
|
-
tokens = tknzr.tokenize(m[0])
|
|
30
|
-
if len(tokens) == 0:
|
|
31
|
-
yield m[0]
|
|
32
|
-
continue
|
|
33
|
-
yield from tokens
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
if __name__ == "__main__":
|
|
37
|
-
text_path = "/Users/seanghay/Projects/github/khmer-text-crawler/train.txt"
|
|
38
|
-
|
|
39
|
-
os.makedirs("data", exist_ok=True)
|
|
40
|
-
c = 0
|
|
41
|
-
with open("data/train.txt", "w") as outfile:
|
|
42
|
-
with open(text_path) as infile:
|
|
43
|
-
for line in infile:
|
|
44
|
-
line = line.rstrip("\n")
|
|
45
|
-
#print(line)
|
|
46
|
-
line = khnormal(line)
|
|
47
|
-
for s in segment(line):
|
|
48
|
-
c += 1
|
|
49
|
-
outfile.write(s + "\n")
|
|
50
|
-
print(c)
|
|
51
|
-
if c > 10_000_000:
|
|
52
|
-
break
|
khmerns-0.0.3/training/infer.py
DELETED
|
@@ -1,63 +0,0 @@
|
|
|
1
|
-
from khmercut import tokenize
|
|
2
|
-
import torch
|
|
3
|
-
from model import Segmenter
|
|
4
|
-
from tokenizer import Tokenizer
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
def load_model(path, device="cpu"):
|
|
8
|
-
tokenizer = Tokenizer()
|
|
9
|
-
model = Segmenter(
|
|
10
|
-
vocab_size=len(tokenizer),
|
|
11
|
-
embedding_dim=256,
|
|
12
|
-
hidden_dim=256,
|
|
13
|
-
num_labels=3,
|
|
14
|
-
)
|
|
15
|
-
model.load_state_dict(torch.load(path, map_location=device))
|
|
16
|
-
model.to(device)
|
|
17
|
-
model.eval()
|
|
18
|
-
return model, tokenizer
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
def segment(text, model, tokenizer, device="cpu"):
|
|
22
|
-
token_ids = tokenizer.encode(text)
|
|
23
|
-
inputs = [tokenizer.bos_id] + token_ids + [tokenizer.eos_id]
|
|
24
|
-
inputs = torch.LongTensor(inputs).unsqueeze(0).to(device)
|
|
25
|
-
|
|
26
|
-
with torch.no_grad():
|
|
27
|
-
predictions = model(inputs)[0]
|
|
28
|
-
|
|
29
|
-
# Remove BOS/EOS predictions
|
|
30
|
-
predictions = predictions[1:-1]
|
|
31
|
-
|
|
32
|
-
# Segment based on B-WORD (1) tags
|
|
33
|
-
words = []
|
|
34
|
-
current_word = []
|
|
35
|
-
|
|
36
|
-
for char, tag in zip(text, predictions):
|
|
37
|
-
if tag == 1: # B-WORD
|
|
38
|
-
if current_word:
|
|
39
|
-
words.append("".join(current_word))
|
|
40
|
-
current_word = [char]
|
|
41
|
-
elif tag == 2: # I-WORD
|
|
42
|
-
current_word.append(char)
|
|
43
|
-
else: # 0 (non-Khmer)
|
|
44
|
-
if current_word:
|
|
45
|
-
words.append("".join(current_word))
|
|
46
|
-
current_word = []
|
|
47
|
-
words.append(char)
|
|
48
|
-
|
|
49
|
-
if current_word:
|
|
50
|
-
words.append("".join(current_word))
|
|
51
|
-
|
|
52
|
-
return words
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
if __name__ == "__main__":
|
|
56
|
-
device = "cpu"
|
|
57
|
-
model, tokenizer = load_model("best_model.pt", device=device)
|
|
58
|
-
text = "ប្រជាជនទីបេរស់នៅក្រៅស្រុក ទូទាំងពិភពលោក បានចាប់ផ្តើមនីតិវិធីបោះឆ្នោត ដើម្បីជ្រើសរើសថ្នាក់ដឹកនាំរដ្ឋាភិបាលភៀសខ្លួន ដែលមានទីតាំងស្ថិតនៅទីក្រុង Dharamsala ភាគខាងជើងប្រទេសឥណ្ឌា។ ជាជំហានដំបូង ថ្ងៃទី១កុម្ភៈ ប្រជាជនទីបេត្រូវបោះឆ្នោត តែងតាំងបេក្ខជនជាមុនសិន ហើយជំហានបន្ទាប់ នៅថ្ងៃទី២៦មេសា គឺត្រូវសម្រេចជ្រើសរើសក្នុងចំណោមបេក្ខជនឈរឈ្មោះទាំងអស់។ លទ្ធផលជាស្ថាពរចុងក្រោយ នឹងត្រូវប្រកាសនៅថ្ងៃទី១៣ខែឧសភា។".replace(
|
|
59
|
-
"\u200b", ""
|
|
60
|
-
)
|
|
61
|
-
|
|
62
|
-
words = segment(text, model, tokenizer, device=device)
|
|
63
|
-
print("|".join(words))
|
khmerns-0.0.3/training/model.py
DELETED
|
@@ -1,41 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
import torch.nn as nn
|
|
3
|
-
from torchcrf import CRF
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
class Segmenter(nn.Module):
|
|
7
|
-
def __init__(self, vocab_size, embedding_dim, hidden_dim, num_labels):
|
|
8
|
-
super(Segmenter, self).__init__()
|
|
9
|
-
self.embedding = nn.Embedding(vocab_size, embedding_dim)
|
|
10
|
-
self.gru = nn.GRU(
|
|
11
|
-
embedding_dim,
|
|
12
|
-
hidden_dim,
|
|
13
|
-
bidirectional=True,
|
|
14
|
-
batch_first=True,
|
|
15
|
-
)
|
|
16
|
-
|
|
17
|
-
self.fc = nn.Linear(hidden_dim * 2, num_labels)
|
|
18
|
-
self.crf = CRF(num_labels, batch_first=True)
|
|
19
|
-
|
|
20
|
-
def forward(self, x, tags=None, mask=None):
|
|
21
|
-
embedded = self.embedding(x)
|
|
22
|
-
gru_out, _ = self.gru(embedded)
|
|
23
|
-
emissions = self.fc(gru_out)
|
|
24
|
-
if tags is not None:
|
|
25
|
-
log_likelihood = self.crf(emissions, tags, mask=mask, reduction="mean")
|
|
26
|
-
return -log_likelihood
|
|
27
|
-
else:
|
|
28
|
-
return self.crf.decode(emissions, mask=mask)
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
if __name__ == "__main__":
|
|
32
|
-
model = Segmenter(vocab_size=200, embedding_dim=256, hidden_dim=256, num_labels=5)
|
|
33
|
-
input_data = torch.randint(0, 200, (4, 10)).long()
|
|
34
|
-
target_tags = torch.randint(0, 5, (4, 10)).long()
|
|
35
|
-
|
|
36
|
-
loss = model(input_data, tags=target_tags)
|
|
37
|
-
loss.backward()
|
|
38
|
-
|
|
39
|
-
with torch.no_grad():
|
|
40
|
-
best_paths = model(input_data)
|
|
41
|
-
print(f"Predicted Tag Sequence: {best_paths}")
|
|
Binary file
|
|
@@ -1,202 +0,0 @@
|
|
|
1
|
-
class Tokenizer:
|
|
2
|
-
def __init__(self):
|
|
3
|
-
self.vocab = [
|
|
4
|
-
" ",
|
|
5
|
-
"!",
|
|
6
|
-
"#",
|
|
7
|
-
"$",
|
|
8
|
-
"%",
|
|
9
|
-
"&",
|
|
10
|
-
"(",
|
|
11
|
-
")",
|
|
12
|
-
"+",
|
|
13
|
-
",",
|
|
14
|
-
"-",
|
|
15
|
-
".",
|
|
16
|
-
"/",
|
|
17
|
-
"0",
|
|
18
|
-
"1",
|
|
19
|
-
"2",
|
|
20
|
-
"3",
|
|
21
|
-
"4",
|
|
22
|
-
"5",
|
|
23
|
-
"6",
|
|
24
|
-
"7",
|
|
25
|
-
"8",
|
|
26
|
-
"9",
|
|
27
|
-
":",
|
|
28
|
-
";",
|
|
29
|
-
"=",
|
|
30
|
-
"?",
|
|
31
|
-
"@",
|
|
32
|
-
"A",
|
|
33
|
-
"B",
|
|
34
|
-
"C",
|
|
35
|
-
"D",
|
|
36
|
-
"E",
|
|
37
|
-
"F",
|
|
38
|
-
"G",
|
|
39
|
-
"H",
|
|
40
|
-
"I",
|
|
41
|
-
"J",
|
|
42
|
-
"K",
|
|
43
|
-
"L",
|
|
44
|
-
"M",
|
|
45
|
-
"N",
|
|
46
|
-
"O",
|
|
47
|
-
"P",
|
|
48
|
-
"Q",
|
|
49
|
-
"R",
|
|
50
|
-
"S",
|
|
51
|
-
"T",
|
|
52
|
-
"U",
|
|
53
|
-
"V",
|
|
54
|
-
"W",
|
|
55
|
-
"X",
|
|
56
|
-
"Y",
|
|
57
|
-
"Z",
|
|
58
|
-
"_",
|
|
59
|
-
"a",
|
|
60
|
-
"b",
|
|
61
|
-
"c",
|
|
62
|
-
"d",
|
|
63
|
-
"e",
|
|
64
|
-
"f",
|
|
65
|
-
"g",
|
|
66
|
-
"h",
|
|
67
|
-
"i",
|
|
68
|
-
"j",
|
|
69
|
-
"k",
|
|
70
|
-
"l",
|
|
71
|
-
"m",
|
|
72
|
-
"n",
|
|
73
|
-
"o",
|
|
74
|
-
"p",
|
|
75
|
-
"q",
|
|
76
|
-
"r",
|
|
77
|
-
"s",
|
|
78
|
-
"t",
|
|
79
|
-
"u",
|
|
80
|
-
"v",
|
|
81
|
-
"w",
|
|
82
|
-
"x",
|
|
83
|
-
"y",
|
|
84
|
-
"z",
|
|
85
|
-
"«",
|
|
86
|
-
"°",
|
|
87
|
-
"»",
|
|
88
|
-
"á",
|
|
89
|
-
"é",
|
|
90
|
-
"ë",
|
|
91
|
-
"ó",
|
|
92
|
-
"ö",
|
|
93
|
-
"ü",
|
|
94
|
-
"ក",
|
|
95
|
-
"ខ",
|
|
96
|
-
"គ",
|
|
97
|
-
"ឃ",
|
|
98
|
-
"ង",
|
|
99
|
-
"ច",
|
|
100
|
-
"ឆ",
|
|
101
|
-
"ជ",
|
|
102
|
-
"ឈ",
|
|
103
|
-
"ញ",
|
|
104
|
-
"ដ",
|
|
105
|
-
"ឋ",
|
|
106
|
-
"ឌ",
|
|
107
|
-
"ឍ",
|
|
108
|
-
"ណ",
|
|
109
|
-
"ត",
|
|
110
|
-
"ថ",
|
|
111
|
-
"ទ",
|
|
112
|
-
"ធ",
|
|
113
|
-
"ន",
|
|
114
|
-
"ប",
|
|
115
|
-
"ផ",
|
|
116
|
-
"ព",
|
|
117
|
-
"ភ",
|
|
118
|
-
"ម",
|
|
119
|
-
"យ",
|
|
120
|
-
"រ",
|
|
121
|
-
"ល",
|
|
122
|
-
"វ",
|
|
123
|
-
"ស",
|
|
124
|
-
"ហ",
|
|
125
|
-
"ឡ",
|
|
126
|
-
"អ",
|
|
127
|
-
"ឤ",
|
|
128
|
-
"ឥ",
|
|
129
|
-
"ឦ",
|
|
130
|
-
"ឧ",
|
|
131
|
-
"ឪ",
|
|
132
|
-
"ឫ",
|
|
133
|
-
"ឬ",
|
|
134
|
-
"ឭ",
|
|
135
|
-
"ឮ",
|
|
136
|
-
"ឯ",
|
|
137
|
-
"ឱ",
|
|
138
|
-
"ឲ",
|
|
139
|
-
"ា",
|
|
140
|
-
"ិ",
|
|
141
|
-
"ី",
|
|
142
|
-
"ឹ",
|
|
143
|
-
"ឺ",
|
|
144
|
-
"ុ",
|
|
145
|
-
"ូ",
|
|
146
|
-
"ួ",
|
|
147
|
-
"ើ",
|
|
148
|
-
"ឿ",
|
|
149
|
-
"ៀ",
|
|
150
|
-
"េ",
|
|
151
|
-
"ែ",
|
|
152
|
-
"ៃ",
|
|
153
|
-
"ោ",
|
|
154
|
-
"ៅ",
|
|
155
|
-
"ំ",
|
|
156
|
-
"ះ",
|
|
157
|
-
"ៈ",
|
|
158
|
-
"៉",
|
|
159
|
-
"៊",
|
|
160
|
-
"់",
|
|
161
|
-
"៌",
|
|
162
|
-
"៍",
|
|
163
|
-
"៏",
|
|
164
|
-
"័",
|
|
165
|
-
"្",
|
|
166
|
-
"។",
|
|
167
|
-
"៕",
|
|
168
|
-
"៖",
|
|
169
|
-
"ៗ",
|
|
170
|
-
"៘",
|
|
171
|
-
"៛",
|
|
172
|
-
"០",
|
|
173
|
-
"១",
|
|
174
|
-
"២",
|
|
175
|
-
"៣",
|
|
176
|
-
"៤",
|
|
177
|
-
"៥",
|
|
178
|
-
"៦",
|
|
179
|
-
"៧",
|
|
180
|
-
"៨",
|
|
181
|
-
"៩",
|
|
182
|
-
]
|
|
183
|
-
|
|
184
|
-
self.pad_id = 0
|
|
185
|
-
self.bos_id = 1
|
|
186
|
-
self.eos_id = 2
|
|
187
|
-
self.unk_id = 3
|
|
188
|
-
|
|
189
|
-
def __len__(self):
|
|
190
|
-
return len(self.vocab) + 4
|
|
191
|
-
|
|
192
|
-
def decode(self, ids) -> str:
|
|
193
|
-
return "".join([self.vocab[i - 4] for i in ids if i - 4 >= 0])
|
|
194
|
-
|
|
195
|
-
def encode(self, text: str):
|
|
196
|
-
return [(self.vocab.index(c) + 4) if c in self.vocab else self.unk_id for c in text]
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
if __name__ == "__main__":
|
|
200
|
-
t = Tokenizer()
|
|
201
|
-
ids = t.encode("មិនដឹង")
|
|
202
|
-
print(t.decode(ids), len(t))
|
khmerns-0.0.3/training/train.py
DELETED
|
@@ -1,112 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from torch.utils.data import DataLoader
|
|
3
|
-
from torch.optim import AdamW
|
|
4
|
-
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
|
5
|
-
from model import Segmenter
|
|
6
|
-
from data import TextDataset, collate_fn
|
|
7
|
-
from tokenizer import Tokenizer
|
|
8
|
-
from tqdm import tqdm
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def train():
|
|
12
|
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
13
|
-
print(f"Using device: {device}")
|
|
14
|
-
|
|
15
|
-
tokenizer = Tokenizer()
|
|
16
|
-
|
|
17
|
-
train_dataset = TextDataset(tokenizer=tokenizer, split="train")
|
|
18
|
-
eval_dataset = TextDataset(tokenizer=tokenizer, split="eval")
|
|
19
|
-
|
|
20
|
-
train_loader = DataLoader(
|
|
21
|
-
train_dataset,
|
|
22
|
-
batch_size=256,
|
|
23
|
-
shuffle=True,
|
|
24
|
-
collate_fn=collate_fn,
|
|
25
|
-
)
|
|
26
|
-
|
|
27
|
-
eval_loader = DataLoader(
|
|
28
|
-
eval_dataset,
|
|
29
|
-
batch_size=256,
|
|
30
|
-
shuffle=False,
|
|
31
|
-
collate_fn=collate_fn,
|
|
32
|
-
)
|
|
33
|
-
|
|
34
|
-
print(f"Train samples: {len(train_dataset)}, Eval samples: {len(eval_dataset)}")
|
|
35
|
-
|
|
36
|
-
model = Segmenter(
|
|
37
|
-
vocab_size=len(tokenizer),
|
|
38
|
-
embedding_dim=256,
|
|
39
|
-
hidden_dim=256,
|
|
40
|
-
num_labels=3,
|
|
41
|
-
)
|
|
42
|
-
|
|
43
|
-
model.to(device)
|
|
44
|
-
|
|
45
|
-
optimizer = AdamW(model.parameters(), lr=1e-5)
|
|
46
|
-
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=1)
|
|
47
|
-
num_epochs = 20
|
|
48
|
-
best_eval_loss = float("inf")
|
|
49
|
-
|
|
50
|
-
for epoch in range(num_epochs):
|
|
51
|
-
# Training
|
|
52
|
-
model.train()
|
|
53
|
-
total_loss = 0.0
|
|
54
|
-
|
|
55
|
-
for batch_idx, (inputs, tags, mask) in enumerate(tqdm(train_loader, desc="Train")):
|
|
56
|
-
inputs = inputs.to(device)
|
|
57
|
-
tags = tags.to(device)
|
|
58
|
-
mask = mask.to(device)
|
|
59
|
-
|
|
60
|
-
optimizer.zero_grad()
|
|
61
|
-
loss = model(inputs, tags=tags, mask=mask)
|
|
62
|
-
loss.backward()
|
|
63
|
-
optimizer.step()
|
|
64
|
-
|
|
65
|
-
total_loss += loss.item()
|
|
66
|
-
|
|
67
|
-
avg_train_loss = total_loss / len(train_loader)
|
|
68
|
-
|
|
69
|
-
# Evaluation
|
|
70
|
-
model.eval()
|
|
71
|
-
eval_loss = 0.0
|
|
72
|
-
correct = 0
|
|
73
|
-
total = 0
|
|
74
|
-
|
|
75
|
-
with torch.no_grad():
|
|
76
|
-
for inputs, tags, mask in tqdm(eval_loader, desc="Eval"):
|
|
77
|
-
inputs = inputs.to(device)
|
|
78
|
-
tags = tags.to(device)
|
|
79
|
-
mask = mask.to(device)
|
|
80
|
-
|
|
81
|
-
loss = model(inputs, tags=tags, mask=mask)
|
|
82
|
-
eval_loss += loss.item()
|
|
83
|
-
|
|
84
|
-
predictions = model(inputs, mask=mask)
|
|
85
|
-
for pred, target, m in zip(predictions, tags, mask):
|
|
86
|
-
for p, t, valid in zip(pred, target, m):
|
|
87
|
-
if valid:
|
|
88
|
-
total += 1
|
|
89
|
-
if p == t.item():
|
|
90
|
-
correct += 1
|
|
91
|
-
|
|
92
|
-
avg_eval_loss = eval_loss / len(eval_loader)
|
|
93
|
-
accuracy = correct / total if total > 0 else 0
|
|
94
|
-
|
|
95
|
-
current_lr = optimizer.param_groups[0]["lr"]
|
|
96
|
-
print(
|
|
97
|
-
f"Epoch [{epoch + 1}/{num_epochs}] Train Loss: {avg_train_loss:.4f}, Eval Loss: {avg_eval_loss:.4f}, Accuracy: {accuracy:.4f}, LR: {current_lr:.6f}"
|
|
98
|
-
)
|
|
99
|
-
|
|
100
|
-
scheduler.step(avg_eval_loss)
|
|
101
|
-
|
|
102
|
-
if avg_eval_loss < best_eval_loss:
|
|
103
|
-
best_eval_loss = avg_eval_loss
|
|
104
|
-
torch.save(model.state_dict(), "best_model.pt")
|
|
105
|
-
print(f"Best model saved with eval loss: {best_eval_loss:.4f}")
|
|
106
|
-
|
|
107
|
-
torch.save(model.state_dict(), "model.pt")
|
|
108
|
-
print("Final model saved to model.pt")
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
if __name__ == "__main__":
|
|
112
|
-
train()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|