sembr 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sembr/__init__.py +0 -0
- sembr/cli.py +159 -0
- sembr/dataset.py +93 -0
- sembr/eval.py +36 -0
- sembr/inference.py +119 -0
- sembr/process.py +302 -0
- sembr/sembr2023.py +60 -0
- sembr/train.py +128 -0
- sembr/utils.py +44 -0
- sembr-0.0.1.dist-info/LICENSE.txt +33 -0
- sembr-0.0.1.dist-info/METADATA +256 -0
- sembr-0.0.1.dist-info/RECORD +15 -0
- sembr-0.0.1.dist-info/WHEEL +5 -0
- sembr-0.0.1.dist-info/entry_points.txt +2 -0
- sembr-0.0.1.dist-info/top_level.txt +1 -0
sembr/__init__.py
ADDED
|
File without changes
|
sembr/cli.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import traceback
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# Turn off the HuggingFace tokenizers' own parallelism before any tokenizer
# is created (presumably to avoid its fork warning when serving requests).
os.environ.update(TOKENIZERS_PARALLELISM='false')
# Put the directory that contains the package directory on the import path.
sys.path.append(os.path.split(os.path.split(__file__)[0])[0])
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def parse_args(argv=None):
    """Parse the command-line options of the ``sembr`` CLI.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]`` exactly as before. Passing a list lets
            the parser be used from code and tests.

    Returns:
        An ``argparse.Namespace`` with the parsed options.
    """
    import argparse
    parser = argparse.ArgumentParser(
        description='Rewrap text using a token-classification model, '
                    'locally or through a running sembr server.')
    model_name = 'admko/sembr2023-bert-small'
    parser.add_argument(
        '-m', '--model-name', type=str, default=model_name,
        help='Model name or path for from_pretrained (default: %(default)s).')
    parser.add_argument(
        '-i', '--input-file', type=str, default=None,
        help='Read text from this file instead of stdin.')
    parser.add_argument(
        '-o', '--output-file', type=str, default=None,
        help='Write the result to this file instead of stdout.')
    # NOTE(review): --words-per-line is parsed but never read by main();
    # kept so existing command lines keep working.
    parser.add_argument(
        '-w', '--words-per-line', type=int, default=10,
        help='Currently unused.')
    parser.add_argument(
        '-b', '--batch-size', type=int, default=8,
        help='Paragraphs per inference batch (default: %(default)s).')
    parser.add_argument(
        '-d', '--overlap-divisor', type=int, default=8,
        help='Tiles of long inputs overlap by the model max_position_embeddings '
             'divided by this value (default: %(default)s).')
    parser.add_argument(
        '-f', '--predict-func', type=str,
        choices=['argmax', 'breaks_first', 'logit_adjustment'],
        default='argmax',
        help='How labels are chosen from the logits (default: %(default)s).')
    parser.add_argument(
        '-t', '--tokens-per-line', type=int, default=10,
        help='Passed through to the predict function (default: %(default)s).')
    parser.add_argument(
        '-s', '--server', type=str, default='127.0.0.1',
        help='Host of a sembr server to try before loading the model '
             'locally (default: %(default)s).')
    parser.add_argument(
        '-l', '--listen', action='store_true',
        help='Start a sembr server instead of rewrapping input.')
    parser.add_argument(
        '-p', '--port', type=int, default=8384,
        help='Port of the sembr server (default: %(default)s).')
    return parser.parse_args(argv)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def init(model_name):
    """Load the tokenizer, model and processor used for inference.

    The model is put in eval mode. It is then moved to CUDA when available,
    otherwise to Apple MPS when available; otherwise it is left where
    ``from_pretrained`` placed it.

    Returns:
        ``(tokenizer, model, processor)``.
    """
    import torch
    from transformers import AutoModelForTokenClassification, AutoTokenizer
    from .process import SemBrProcessor

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    model.eval()
    target = None
    if torch.cuda.is_available():
        target = 'cuda'
    elif torch.backends.mps.is_available():
        target = 'mps'
    if target is not None:
        model = model.to(target)
    return tokenizer, model, SemBrProcessor()
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def start_server(port, tokenizer, model, processor):
    """Serve the loaded model over HTTP; blocks until the server stops.

    Routes:
        ``GET /check``: the component class names plus
            ``status: 'success'`` (probed by ``check_server``).
        ``POST /rewrap``: form field ``text`` plus optional ``batch_size``,
            ``predict_func``, ``tokens_per_line`` and ``overlap_divisor``.
            It returns the class names with either ``status: 'success'``
            and ``text``, or ``status: 'error'`` with ``error`` and
            ``traceback``.

    Args:
        port: Port passed to ``app.run`` (no ``host`` is passed).
        tokenizer, model, processor: Objects returned by ``init``.
    """
    from flask import Flask, request
    from .inference import sembr

    app = Flask(__name__)
    base_rv = {
        'model': model.__class__.__name__,
        'tokenizer': tokenizer.__class__.__name__,
        'processor': processor.__class__.__name__,
    }

    @app.route('/check')
    def check():
        return {**base_rv, 'status': 'success'}

    @app.route('/rewrap', methods=['POST'])
    def rewrap():
        try:
            # Form parsing lives inside the try. A missing ``text`` field or a
            # non-integer option is then reported in the same JSON error
            # format as inference failures, which ``rewrap_on_server`` reads,
            # instead of as an HTML error page.
            form = request.form
            text = form['text']
            kwargs = {
                'batch_size': int(form.get('batch_size', 8)),
                'predict_func': form.get('predict_func', 'argmax'),
                'tokens_per_line': int(form.get('tokens_per_line', 10)),
                'overlap_divisor': int(form.get('overlap_divisor', 8)),
            }
            results = sembr(text, tokenizer, model, processor, **kwargs)
        except Exception as e:
            # Request boundary: report any failure to the client as JSON.
            return {
                **base_rv,
                'status': 'error',
                'error': str(e),
                'traceback': traceback.format_exc(),
            }
        return {**base_rv, 'status': 'success', 'text': results}

    app.run(port=port)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def check_server(server, port):
    """Return True if a sembr server answers ``/check`` at ``server:port``.

    Every failure returns False so the caller can fall back to local
    inference instead of crashing. Failures include:

    - an empty ``server``;
    - any ``requests`` error, including the 0.3 s timeout;
    - a non-200 status;
    - a response that is not JSON or whose ``status`` is not ``'success'``.
    """
    if not server:
        return False
    import requests
    try:
        response = requests.get(f'http://{server}:{port}/check', timeout=0.3)
    except requests.exceptions.RequestException:
        return False
    if response.status_code != 200:
        return False
    try:
        payload = response.json()
    except ValueError:
        # Something other than a sembr server is listening on this port.
        return False
    return isinstance(payload, dict) and payload.get('status') == 'success'
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def rewrap_on_server(text, server, port, kwargs, timeout=None):
    """Rewrap ``text`` by POSTing it to a running sembr server.

    Args:
        text: Input text.
        server: Host of the server started with ``--listen``.
        port: Port of that server.
        kwargs: Inference options, sent as form fields next to ``text``.
        timeout: Optional ``requests`` timeout in seconds. ``None`` (the
            default) waits indefinitely, as before.

    Returns:
        The rewrapped text returned by the server.

    Raises:
        ValueError: On a connection failure, a non-200 response, or an
            error status reported by the server.
    """
    import requests
    payload = {'text': text, **kwargs}
    try:
        response = requests.post(
            f'http://{server}:{port}/rewrap', data=payload, timeout=timeout)
    except requests.exceptions.RequestException as e:
        raise ValueError(f'Connection Error: {e}') from e
    if response.status_code != 200:
        raise ValueError(
            f'Connection Error: {response.status_code}: {response.text}')
    data = response.json()
    if data['status'] != 'success':
        raise ValueError(
            f'Status: {data["status"]}\n'
            f'Exception: {data.get("error")}\n'
            f'{data.get("traceback", "")}')
    return data['text']
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def main(args=None):
    """Entry point of the ``sembr`` command.

    With ``--listen``, load the model and serve it over HTTP. Otherwise:

    1. Read text from ``--input-file`` (or stdin).
    2. Rewrap it on the server when ``check_server`` finds one, else load
       the model and rewrap locally.
    3. Write the result to ``--output-file``, or print it.
    """
    args = parse_args() if args is None else args
    if args.listen:
        tokenizer, model, processor = init(args.model_name)
        return start_server(args.port, tokenizer, model, processor)

    if args.input_file is None:
        text = sys.stdin.read()
    else:
        with open(args.input_file, 'r', encoding='utf-8') as f:
            text = f.read()

    options = dict(
        batch_size=args.batch_size,
        predict_func=args.predict_func,
        tokens_per_line=args.tokens_per_line,
        overlap_divisor=args.overlap_divisor,
    )
    if not check_server(args.server, args.port):
        from .inference import sembr
        tokenizer, model, processor = init(args.model_name)
        result = sembr(text, tokenizer, model, processor, **options)
    else:
        result = rewrap_on_server(text, args.server, args.port, options)

    if args.output_file is not None:
        with open(args.output_file, 'w', encoding='utf-8') as f:
            f.write(result)
    else:
        print(result)


if __name__ == '__main__':
    main()
|
sembr/dataset.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import argparse
|
|
3
|
+
import functools
|
|
4
|
+
|
|
5
|
+
import datasets
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _process_examples(
|
|
9
|
+
examples, processor, tokenizer, mode_names, max_indent, label2id
|
|
10
|
+
):
|
|
11
|
+
examples['modes'] = [
|
|
12
|
+
[mode_names[i] for i in bm] for bm in examples['modes']]
|
|
13
|
+
transposed = [dict(zip(examples, c)) for c in zip(*examples.values())]
|
|
14
|
+
results = processor.tokenize_with_modes(tokenizer, transposed)
|
|
15
|
+
for r in results:
|
|
16
|
+
r['labels'] = labels = []
|
|
17
|
+
indents = []
|
|
18
|
+
for m, i in zip(r.pop('modes'), r.pop('indents')):
|
|
19
|
+
i = min(i, max_indent)
|
|
20
|
+
indents.append(i)
|
|
21
|
+
if m == 'off':
|
|
22
|
+
label = 'off'
|
|
23
|
+
else:
|
|
24
|
+
label = f'{m}-{i}'
|
|
25
|
+
labels.append(label2id[label])
|
|
26
|
+
r.pop('base_indent')
|
|
27
|
+
keys = ['input_ids', 'labels']
|
|
28
|
+
return {k: [d[k] for d in results] for k in keys}
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def chunk_examples(examples, max_length=512, stride=None):
    """Split token/label sequences into overlapping windows.

    A window of up to ``max_length`` items starts every ``stride`` items, so
    consecutive windows overlap by ``max_length - stride``. The defaults
    (``max_length=512`` and ``stride=512 // 10 == 51``) reproduce the
    previous hard-coded values.

    Args:
        examples: Batch with ``input_ids`` and ``labels`` lists. Both keys
            are popped from it.
        max_length: Maximum window length.
        stride: Distance between window starts. ``None`` means
            ``max(1, max_length // 10)``.

    Returns:
        ``{'input_ids': [...], 'labels': [...]}`` with one entry per window.
    """
    if stride is None:
        # At least 1: a zero step would make ``range`` raise.
        stride = max(1, max_length // 10)
    id_chunks, label_chunks = [], []
    for ids, labels in zip(examples.pop('input_ids'), examples.pop('labels')):
        id_chunks += [
            ids[i:i + max_length] for i in range(0, len(ids), stride)]
        label_chunks += [
            labels[i:i + max_length] for i in range(0, len(labels), stride)]
    return {'input_ids': id_chunks, 'labels': label_chunks}
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def process_dataset(dataset, processor, tokenizer, max_indent, label2id):
    """Tokenize, label and chunk one split of the raw SemBr dataset.

    First maps ``_process_examples`` over the split in batches, dropping the
    raw columns. Then maps ``chunk_examples`` to split long sequences into
    windows.
    """
    encode = functools.partial(
        _process_examples,
        processor=processor,
        tokenizer=tokenizer,
        mode_names=dataset.features['modes'].feature.names,
        max_indent=max_indent,
        label2id=label2id,
    )
    raw_columns = [
        'flat_lines', 'modes', 'mode_offsets', 'indents', 'base_indent']
    encoded = dataset.map(encode, batched=True, remove_columns=raw_columns)
    return encoded.map(chunk_examples, batched=True)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def init_dataset():
    """Build the dataset from the ``sembr2023.py`` script beside this file."""
    here = os.path.dirname(__file__)
    script = os.path.join(here, 'sembr2023.py')
    return datasets.load_dataset(script)
|
|
60
|
+
|
|
61
|
+
def push_to_hub(dataset, hub_user, dataset_name, private):
    """Upload ``dataset`` to the Hub repository ``<hub_user>/<dataset_name>``."""
    repo_id = '{}/{}'.format(hub_user, dataset_name)
    print(f'Pushing dataset to {repo_id}...')
    dataset.push_to_hub(repo_id, private=private)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def save_to_disk(dataset, save_dir, dataset_name):
    """Save ``dataset`` under ``<save_dir>/<dataset_name>``."""
    target = os.path.join(save_dir, dataset_name)
    print('Saving dataset to {}...'.format(target))
    dataset.save_to_disk(target)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def parse_args(argv=None):
    """Parse options for building, saving and uploading the dataset.

    Args:
        argv: Optional list of argument strings. ``None`` parses
            ``sys.argv[1:]`` as before.

    Returns:
        An ``argparse.Namespace`` with the parsed options.
    """
    parser = argparse.ArgumentParser(
        description='Build the SemBr2023 dataset, then save and/or push it.')
    parser.add_argument(
        '-n', '--dataset-name', type=str, default='sembr2023',
        help='Name of the saved directory and Hub repository '
             '(default: %(default)s).')
    parser.add_argument(
        '-s', '--save-dir', type=str, default='data',
        help='Directory to save the dataset under (default: %(default)s).')
    parser.add_argument(
        '-u', '--hub-user', type=str, default=None,
        help='Hub user or organisation to push to; no push when omitted.')
    parser.add_argument(
        '-p', '--private', action='store_true',
        help='Make the pushed Hub dataset private.')
    return parser.parse_args(argv)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def main(args=None):
    """Build the dataset, then save it locally and/or push it to the Hub."""
    args = args if args is not None else parse_args()
    dataset = init_dataset()
    if args.save_dir is not None:
        save_to_disk(dataset, args.save_dir, args.dataset_name)
    if args.hub_user is not None:
        push_to_hub(dataset, args.hub_user, args.dataset_name, args.private)


if __name__ == '__main__':
    main(parse_args())
|
sembr/eval.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import glob
|
|
2
|
+
|
|
3
|
+
import datasets
|
|
4
|
+
import evaluate
|
|
5
|
+
from transformers import AutoModelForTokenClassification, AutoTokenizer
|
|
6
|
+
|
|
7
|
+
from .inference import inference
|
|
8
|
+
from .process import SemBrProcessor
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def checkpoints(root='checkpoints'):
    """Return the ``<root>/*/checkpoint-*`` paths in a stable order.

    ``glob`` yields paths in filesystem order, which varies between runs and
    platforms. Results are therefore sorted by run directory and then
    numerically by step, so ``checkpoint-9`` precedes ``checkpoint-10``.

    Args:
        root: Directory holding one sub-directory per training run. The
            default is ``'checkpoints'``, relative to the working directory,
            as before.
    """
    import os
    import re

    def _sort_key(path):
        match = re.search(r'checkpoint-(\d+)$', path)
        step = int(match.group(1)) if match else -1
        return os.path.dirname(path), step, path

    pattern = os.path.join(root, '*', 'checkpoint-*')
    return sorted(glob.glob(pattern), key=_sort_key)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def eval_model(dataset, processor, checkpoint, metric, device=None):
    """Evaluate one checkpoint on the ``test`` split.

    Args:
        dataset: Mapping with a ``'test'`` split that has a ``flat_lines``
            column.
        processor: ``SemBrProcessor`` used for inference and regeneration.
        checkpoint: Path or name passed to ``from_pretrained``.
        metric: Object with ``compute(predictions=..., references=...)``.
        device: Torch device to run on. ``None`` picks CUDA when available,
            then MPS, then CPU. Previously CUDA was required.

    Returns:
        Whatever ``metric.compute`` returns.
    """
    import torch
    if device is None:
        if torch.cuda.is_available():
            device = 'cuda'
        elif torch.backends.mps.is_available():
            device = 'mps'
        else:
            device = 'cpu'
    model = AutoModelForTokenClassification.from_pretrained(checkpoint)
    model = model.to(device)
    # Explicit, so inference never runs with dropout active.
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    paras = dataset['test']['flat_lines']
    results = inference(
        paras, tokenizer, model, processor, batch_size=8, overlap_divisor=8)
    generated = processor.generate(results, join=False)
    return metric.compute(predictions=generated, references=paras)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def main():
    """Score every checkpoint from ``checkpoints()`` with WER and print it."""
    dataset = datasets.load_from_disk('./data/sembr2023')
    processor, wer = SemBrProcessor(), evaluate.load('wer')
    for checkpoint in checkpoints():
        metrics = eval_model(dataset, processor, checkpoint, wer)
        print(f'{checkpoint=}, {metrics=}')


if __name__ == '__main__':
    main()
|
sembr/inference.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
from tqdm import trange
|
|
4
|
+
|
|
5
|
+
from transformers import DataCollatorForTokenClassification
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _tiled_inference(model, collator, results, batch_size, overlap_divisor):
|
|
9
|
+
device = model.device
|
|
10
|
+
max_length = model.config.max_position_embeddings
|
|
11
|
+
overlap_length = int(max_length / overlap_divisor)
|
|
12
|
+
input_ids = [{'input_ids': r['input_ids']} for r in results]
|
|
13
|
+
num_paras = len(input_ids)
|
|
14
|
+
lengths = [len(i['input_ids']) for i in input_ids]
|
|
15
|
+
sorted_indices = sorted(
|
|
16
|
+
range(num_paras), key=lambda i: lengths[i], reverse=True)
|
|
17
|
+
logits = torch.zeros(
|
|
18
|
+
(num_paras, max(lengths), model.config.num_labels), device=device)
|
|
19
|
+
counts = torch.zeros(
|
|
20
|
+
(num_paras, max(lengths)), dtype=torch.long, device=device)
|
|
21
|
+
for b in trange(0, num_paras, batch_size):
|
|
22
|
+
bslice = slice(b, min(num_paras, b + batch_size))
|
|
23
|
+
bindices = sorted_indices[bslice]
|
|
24
|
+
binids = [input_ids[i] for i in bindices]
|
|
25
|
+
data = collator(binids, return_tensors='pt').to(device)
|
|
26
|
+
num_tokens = data['input_ids'].shape[1]
|
|
27
|
+
for i in range(0, num_tokens, max_length - overlap_length):
|
|
28
|
+
islice = slice(i, min(num_tokens, i + max_length))
|
|
29
|
+
inids = data['input_ids'][:, islice]
|
|
30
|
+
attns = data['attention_mask'][:, islice]
|
|
31
|
+
with torch.no_grad():
|
|
32
|
+
outputs = model(
|
|
33
|
+
input_ids=inids, attention_mask=attns, return_dict=True)
|
|
34
|
+
logits[bindices, islice] += outputs.logits
|
|
35
|
+
counts[bindices, islice] += attns
|
|
36
|
+
attns = counts > 0
|
|
37
|
+
logits /= counts.unsqueeze(-1)
|
|
38
|
+
logits[~attns] = 0
|
|
39
|
+
return logits, attns
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _format_labels(id2label, preds, attns, results):
|
|
43
|
+
modes, indents = [], []
|
|
44
|
+
for i, (p, a) in enumerate(zip(preds, attns)):
|
|
45
|
+
para_modes, para_indents = [], []
|
|
46
|
+
for name in [id2label[int(t)] for t in p[a]]:
|
|
47
|
+
if name == 'off':
|
|
48
|
+
mode, indent = 'off', 0
|
|
49
|
+
else:
|
|
50
|
+
mode, indent = name.split('-')
|
|
51
|
+
para_modes.append(mode)
|
|
52
|
+
para_indents.append(int(indent))
|
|
53
|
+
modes.append(para_modes)
|
|
54
|
+
indents.append(para_indents)
|
|
55
|
+
for r, m, i in zip(results, modes, indents):
|
|
56
|
+
r['modes'] = m
|
|
57
|
+
r['indents'] = i
|
|
58
|
+
return results
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def predict_argmax(logits, counts, tokens_per_line):
    """Choose the highest-scoring label at every position (along dim 2).

    ``counts`` and ``tokens_per_line`` are ignored. They exist so all
    predict functions share one signature.
    """
    return torch.argmax(logits, dim=2)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def zero_runs(a, dim):
    """Return the ``[start, end)`` index pairs of the runs of zeros in ``a``.

    Only 1-D input is supported. For N-D input the previous code padded
    along ``dim`` but differenced along the last axis, and it kept only the
    first-axis coordinates, so its result was meaningless; it now raises
    instead.

    Args:
        a: 1-D array (anything ``np.asarray`` accepts).
        dim: Axis of ``a``; must be ``0`` or ``-1`` for 1-D input.

    Returns:
        Integer array of shape ``(num_runs, 2)``, for example
        ``zero_runs(np.array([1, 0, 0, 2, 0]), 0)`` gives
        ``[[1, 3], [4, 5]]``.

    Raises:
        ValueError: If ``a`` is not 1-D.
    """
    a = np.asarray(a)
    if a.ndim != 1:
        raise ValueError(f'zero_runs expects a 1-D array, got shape {a.shape}.')
    if a.size == 0:
        return np.empty((0, 2), dtype=np.intp)
    # 1 where a is 0, padded with a 0 at each end so every run has both a
    # rising and a falling edge.
    pad = np.zeros_like(a.take([0], dim))
    iszero = np.concatenate([pad, (a == 0).view(np.int8), pad], dim)
    absdiff = np.abs(np.diff(iszero, axis=dim))
    # Runs start and end where absdiff is 1.
    return np.where(absdiff == 1)[0].reshape(-1, 2)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def predict_logit_adjustment(logits, counts, tokens_per_line, delta=1.0):
    """Take the argmax after biasing against label 0 (``'off'``).

    ``delta`` is subtracted from label 0 and ``delta / num_labels`` is added
    to every other label before the argmax along dim 2. The adjustment is
    applied to a copy; previously the caller's ``logits`` tensor was
    modified in place. ``counts`` and ``tokens_per_line`` are ignored.
    """
    adjusted = logits.clone()
    adjusted[:, :, 0] -= delta
    adjusted[:, :, 1:] += delta / logits.shape[2]
    return adjusted.argmax(dim=2)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def predict_breaks_first(logits, counts, tokens_per_line, delta=0.0):
    """Decide break vs. no-break first, then choose the best break label.

    A position gets ``1 + argmax(logits[..., 1:])`` when any non-zero
    label's logit plus ``delta / num_labels`` exceeds label 0's logit minus
    ``delta``. Otherwise it gets label 0. With the default ``delta=0.0``
    (the previously hard-coded value) this agrees with a plain argmax.
    ``counts`` and ``tokens_per_line`` are ignored.
    """
    off_logits = logits[:, :, 0] - delta
    break_logits = logits[:, :, 1:] + delta / logits.shape[2]
    breaks = (break_logits > off_logits.unsqueeze(-1)).any(2)
    break_preds = 1 + break_logits.argmax(dim=2)
    return torch.where(breaks, break_preds, torch.zeros_like(break_preds))
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# ``predict_func`` name (as accepted by the CLI and the server) -> decoding
# function called as ``fn(logits, counts, tokens_per_line)``.
PREDICT_FUNC_MAP = dict(
    argmax=predict_argmax,
    breaks_first=predict_breaks_first,
    logit_adjustment=predict_logit_adjustment,
)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def inference(
    text, tokenizer, model, processor,
    predict_func='argmax', tokens_per_line=10,
    batch_size=8, overlap_divisor=8,
):
    """Predict per-token modes and indents for ``text``.

    Args:
        text: A string, which is split into paragraphs at blank lines, or a
            list of paragraph strings.
        tokenizer, model, processor: As returned by ``cli.init``.
        predict_func: Key of ``PREDICT_FUNC_MAP``.
        tokens_per_line: Forwarded to the predict function.
        batch_size: Paragraphs per forward pass.
        overlap_divisor: Controls the tile overlap in ``_tiled_inference``.

    Returns:
        Processor result dicts with ``modes`` and ``indents`` filled in,
        ready for ``processor.generate``.

    Raises:
        ValueError: If ``predict_func`` is unknown. This is checked before
            the model runs.
    """
    try:
        predict = PREDICT_FUNC_MAP[predict_func]
    except KeyError:
        raise ValueError(
            f'Unknown predict_func {predict_func!r}; expected one of '
            f'{sorted(PREDICT_FUNC_MAP)}.') from None
    collator = DataCollatorForTokenClassification(tokenizer, padding='longest')
    results = processor(text, split=isinstance(text, str))
    results = processor.tokenize_with_modes(tokenizer, results)
    logits, counts = _tiled_inference(
        model, collator, results, batch_size, overlap_divisor)
    preds = predict(logits, counts, tokens_per_line)
    return _format_labels(model.config.id2label, preds, counts, results)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def sembr(
    text, tokenizer, model, processor,
    predict_func='argmax', tokens_per_line=10,
    batch_size=8, overlap_divisor=8,
):
    """Rewrap ``text`` and return the result as one string.

    Runs ``inference`` with the given options, then joins the regenerated
    paragraphs via ``processor.generate(..., join=True)``.
    """
    labelled = inference(
        text, tokenizer, model, processor,
        predict_func=predict_func,
        tokens_per_line=tokens_per_line,
        batch_size=batch_size,
        overlap_divisor=overlap_divisor,
    )
    return processor.generate(labelled, join=True)
|
sembr/process.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
1
|
+
import re
|
|
2
|
+
|
|
3
|
+
from transformers import AutoTokenizer
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SemBrProcessor(object):
|
|
7
|
+
def __init__(self, spaces=4):
|
|
8
|
+
super().__init__()
|
|
9
|
+
self.spaces = spaces
|
|
10
|
+
self.replace_tokens = {
|
|
11
|
+
# r'\n(?:\s*\n)+': '[par]',
|
|
12
|
+
# '\t': '[indent]',
|
|
13
|
+
# ' ' * self.spaces: '[indent]',
|
|
14
|
+
'\\%': '[percent]',
|
|
15
|
+
'\n': '[newline]',
|
|
16
|
+
}
|
|
17
|
+
self.reverse_replace_tokens = {
|
|
18
|
+
v: k for k, v in self.replace_tokens.items()}
|
|
19
|
+
|
|
20
|
+
def prepare_tokenizer(self, tokenizer):
|
|
21
|
+
tokenizer.add_tokens(list(self.replace_tokens.values()))
|
|
22
|
+
|
|
23
|
+
def _process_specials(self, lines):
|
|
24
|
+
for k, v in self.replace_tokens.items():
|
|
25
|
+
# lines = [re.sub(k, v, l) for l in lines]
|
|
26
|
+
lines = [l.replace(k, v) for l in lines]
|
|
27
|
+
return lines
|
|
28
|
+
|
|
29
|
+
def _process_indents(self, lines):
|
|
30
|
+
nlines = []
|
|
31
|
+
indents = []
|
|
32
|
+
# get indent levels
|
|
33
|
+
for line in lines:
|
|
34
|
+
indent = 0
|
|
35
|
+
for c in line:
|
|
36
|
+
if c == ' ':
|
|
37
|
+
indent += 1
|
|
38
|
+
elif c == '\t':
|
|
39
|
+
indent += self.spaces
|
|
40
|
+
else:
|
|
41
|
+
break
|
|
42
|
+
indent_level = int(indent / self.spaces)
|
|
43
|
+
nlines.append(line[indent_level * self.spaces:].rstrip())
|
|
44
|
+
indents.append(indent_level)
|
|
45
|
+
return nlines, indents
|
|
46
|
+
|
|
47
|
+
def _process_comments(self, lines, indents):
|
|
48
|
+
# normalize comments, ['xxx % comment'] -> ['xxx', '% comment']
|
|
49
|
+
nclines = []
|
|
50
|
+
ncindents = []
|
|
51
|
+
for line, indent in zip(lines, indents):
|
|
52
|
+
if '%' in line:
|
|
53
|
+
normal, *comment = line.split('%')
|
|
54
|
+
comment = '%'.join(comment).strip()
|
|
55
|
+
if normal.strip():
|
|
56
|
+
if comment:
|
|
57
|
+
nclines += [normal, f'%{comment}']
|
|
58
|
+
ncindents += [indent, indent]
|
|
59
|
+
continue
|
|
60
|
+
line = f'{normal}%'
|
|
61
|
+
nclines.append(line)
|
|
62
|
+
ncindents.append(indent)
|
|
63
|
+
return nclines, ncindents
|
|
64
|
+
|
|
65
|
+
def _process_modes(self, lines):
|
|
66
|
+
new_lines = []
|
|
67
|
+
modes = []
|
|
68
|
+
prev_status = 'start'
|
|
69
|
+
for line in lines:
|
|
70
|
+
if line.startswith('%'):
|
|
71
|
+
status = 'comment'
|
|
72
|
+
elif line.endswith('%'):
|
|
73
|
+
status = 'percent'
|
|
74
|
+
line = line.rstrip('%')
|
|
75
|
+
else:
|
|
76
|
+
status = 'normal'
|
|
77
|
+
match (prev_status, status):
|
|
78
|
+
case ('start', _):
|
|
79
|
+
pass
|
|
80
|
+
case ('normal', _):
|
|
81
|
+
modes.append('space')
|
|
82
|
+
case ('percent', _):
|
|
83
|
+
modes.append('nospace')
|
|
84
|
+
case ('comment', 'normal'):
|
|
85
|
+
modes.append('break')
|
|
86
|
+
case ('comment', 'percent'):
|
|
87
|
+
modes.append('break')
|
|
88
|
+
case ('comment', 'comment'):
|
|
89
|
+
modes.append('comment')
|
|
90
|
+
case (_, 'comment'):
|
|
91
|
+
modes.append('comment')
|
|
92
|
+
case _:
|
|
93
|
+
raise ValueError(
|
|
94
|
+
'Unknown status transition: '
|
|
95
|
+
f'{prev_status} -> {status}.')
|
|
96
|
+
new_lines.append(line)
|
|
97
|
+
prev_status = status
|
|
98
|
+
# last transition always force a break
|
|
99
|
+
modes.append('break')
|
|
100
|
+
return new_lines, modes
|
|
101
|
+
|
|
102
|
+
def _flatten_with_modes(self, lines, modes):
|
|
103
|
+
in_comment = 0
|
|
104
|
+
flat_lines, flat_modes, offsets = [], [], []
|
|
105
|
+
prev_len = flat_len = 0
|
|
106
|
+
for line, mode in zip(lines, modes):
|
|
107
|
+
if in_comment >= 1:
|
|
108
|
+
line = re.sub(r'^\s*%', '', line)
|
|
109
|
+
if mode == 'break':
|
|
110
|
+
in_comment = 0
|
|
111
|
+
line = f'{line}[newline]'
|
|
112
|
+
mode = 'off'
|
|
113
|
+
elif mode == 'comment':
|
|
114
|
+
in_comment += 1
|
|
115
|
+
mode = 'space'
|
|
116
|
+
elif mode == 'space':
|
|
117
|
+
line = f'{line} '
|
|
118
|
+
elif mode == 'nospace':
|
|
119
|
+
pass
|
|
120
|
+
else:
|
|
121
|
+
raise ValueError(f'Unknown mode: {mode}.')
|
|
122
|
+
flat_lines.append(line)
|
|
123
|
+
flat_modes.append(mode)
|
|
124
|
+
prev_len = flat_len
|
|
125
|
+
flat_len += len(line)
|
|
126
|
+
offsets.append((prev_len, flat_len))
|
|
127
|
+
return ''.join(flat_lines), flat_modes, offsets
|
|
128
|
+
|
|
129
|
+
def _process_paragraph(self, text):
|
|
130
|
+
lines = text.split('\n')
|
|
131
|
+
lines = self._process_specials(lines)
|
|
132
|
+
lines, indents = self._process_indents(lines)
|
|
133
|
+
base_indent = min(indents)
|
|
134
|
+
indents = [i - base_indent for i in indents]
|
|
135
|
+
lines, indents = self._process_comments(lines, indents)
|
|
136
|
+
lines, modes = self._process_modes(lines)
|
|
137
|
+
flat_lines, modes, mode_offsets = self._flatten_with_modes(lines, modes)
|
|
138
|
+
return {
|
|
139
|
+
'flat_lines': flat_lines,
|
|
140
|
+
'modes': modes,
|
|
141
|
+
'mode_offsets': mode_offsets,
|
|
142
|
+
'indents': indents,
|
|
143
|
+
'base_indent': base_indent,
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
def _tokenize_with_modes(
|
|
147
|
+
self, tokenizer, text, line_modes, line_mode_offsets, line_indents
|
|
148
|
+
):
|
|
149
|
+
enc = tokenizer(text, return_offsets_mapping=True)
|
|
150
|
+
words, modes, indents = [], [], []
|
|
151
|
+
pos = mode_idx = 0
|
|
152
|
+
# fill empty words in offset mapping
|
|
153
|
+
offset_mapping = []
|
|
154
|
+
for start, end in enc.offset_mapping:
|
|
155
|
+
offset_mapping.append((min(start, pos), end))
|
|
156
|
+
pos = end
|
|
157
|
+
pos = 0
|
|
158
|
+
input_ids = []
|
|
159
|
+
for tid, (start, end) in zip(enc.input_ids, offset_mapping):
|
|
160
|
+
if pos >= len(text):
|
|
161
|
+
break
|
|
162
|
+
mode_offset = line_mode_offsets[mode_idx][1]
|
|
163
|
+
word = text[pos:end]
|
|
164
|
+
input_ids.append(tid)
|
|
165
|
+
words.append(word)
|
|
166
|
+
indents.append(line_indents[mode_idx])
|
|
167
|
+
pos = max(pos, end)
|
|
168
|
+
if mode_offset >= end:
|
|
169
|
+
modes.append('off')
|
|
170
|
+
continue
|
|
171
|
+
mode = line_modes[mode_idx]
|
|
172
|
+
modes.append(mode)
|
|
173
|
+
mode_idx += 1
|
|
174
|
+
# current word is on a new line
|
|
175
|
+
indents[-1] = line_indents[mode_idx]
|
|
176
|
+
return input_ids, words, modes, indents
|
|
177
|
+
|
|
178
|
+
def tokenize_with_modes(self, tokenizer, results):
|
|
179
|
+
self.prepare_tokenizer(tokenizer)
|
|
180
|
+
new_results = []
|
|
181
|
+
for r in results:
|
|
182
|
+
flat_lines = r['flat_lines']
|
|
183
|
+
modes = r['modes']
|
|
184
|
+
mode_offsets = r['mode_offsets']
|
|
185
|
+
indents = r['indents']
|
|
186
|
+
base_indent = r['base_indent']
|
|
187
|
+
input_ids, words, modes, indents = self._tokenize_with_modes(
|
|
188
|
+
tokenizer, flat_lines, modes, mode_offsets, indents)
|
|
189
|
+
tokenized = {
|
|
190
|
+
'input_ids': input_ids,
|
|
191
|
+
'words': words,
|
|
192
|
+
'modes': modes,
|
|
193
|
+
'indents': indents,
|
|
194
|
+
'base_indent': base_indent,
|
|
195
|
+
}
|
|
196
|
+
keys = ['input_ids', 'words', 'modes', 'indents']
|
|
197
|
+
if len(set(len(tokenized[k]) for k in keys)) != 1:
|
|
198
|
+
len_dict = {k: len(tokenized[k]) for k in keys}
|
|
199
|
+
raise ValueError(
|
|
200
|
+
f'Lengths do not match. Found: {len_dict}.')
|
|
201
|
+
new_results.append(tokenized)
|
|
202
|
+
return new_results
|
|
203
|
+
|
|
204
|
+
def __call__(self, text, split=True):
|
|
205
|
+
if split:
|
|
206
|
+
text = re.split(r'\n(?:\s*\n)+', text)
|
|
207
|
+
elif isinstance(text, str):
|
|
208
|
+
raise ValueError(
|
|
209
|
+
'Text must be a list of strings if split=True.')
|
|
210
|
+
paras = []
|
|
211
|
+
for p in text:
|
|
212
|
+
if not p.strip():
|
|
213
|
+
continue
|
|
214
|
+
paras.append(self._process_paragraph(p))
|
|
215
|
+
return paras
|
|
216
|
+
|
|
217
|
+
def _replace_newlines(self, words, modes, indents):
|
|
218
|
+
new_words, new_modes, new_indents = [], [], []
|
|
219
|
+
next_mode = None
|
|
220
|
+
for word, mode, indent in zip(words, modes, indents):
|
|
221
|
+
if word == '[newline]':
|
|
222
|
+
next_mode = 'break'
|
|
223
|
+
continue
|
|
224
|
+
if next_mode:
|
|
225
|
+
# if mode != 'off':
|
|
226
|
+
# raise ValueError(
|
|
227
|
+
# f'Cannot set mode {next_mode} '
|
|
228
|
+
# f'when mode is {mode}.')
|
|
229
|
+
mode = next_mode
|
|
230
|
+
next_mode = None
|
|
231
|
+
new_words.append(word)
|
|
232
|
+
new_modes.append(mode)
|
|
233
|
+
new_indents.append(indent)
|
|
234
|
+
return new_words, new_modes, new_indents
|
|
235
|
+
|
|
236
|
+
def _generate_lines(self, words, modes, indents):
|
|
237
|
+
lbs = [
|
|
238
|
+
(o, m) for o, m in enumerate(modes)
|
|
239
|
+
if m in ('space', 'nospace', 'break')]
|
|
240
|
+
if not lbs or lbs[-1][0] < len(words):
|
|
241
|
+
lbs.append((len(words), 'space'))
|
|
242
|
+
lines, line_indents = [], []
|
|
243
|
+
pos = in_comment = 0
|
|
244
|
+
for o, m in lbs:
|
|
245
|
+
line = ''.join(words[pos:o]).strip()
|
|
246
|
+
if line.startswith('%'):
|
|
247
|
+
in_comment = 1
|
|
248
|
+
if m == 'nospace':
|
|
249
|
+
line = f'{line}%'
|
|
250
|
+
if m in ('space', 'break'):
|
|
251
|
+
if in_comment > 1:
|
|
252
|
+
line = f'% {line}'
|
|
253
|
+
if in_comment:
|
|
254
|
+
in_comment += 1
|
|
255
|
+
if m == 'break':
|
|
256
|
+
in_comment = 0
|
|
257
|
+
lines.append(line)
|
|
258
|
+
line_indents.append(indents[pos:o])
|
|
259
|
+
pos = o
|
|
260
|
+
# line_indents = [Counter(l).most_common(1)[0][0] for l in line_indents]
|
|
261
|
+
line_indents = [l[0] for l in line_indents]
|
|
262
|
+
return lines, line_indents
|
|
263
|
+
|
|
264
|
+
def _indent_lines(self, lines, indents, base_indent):
|
|
265
|
+
spaces = ' ' * self.spaces
|
|
266
|
+
return [
|
|
267
|
+
f'{spaces * (i + base_indent)}{l}'
|
|
268
|
+
for i, l in zip(indents, lines)]
|
|
269
|
+
|
|
270
|
+
def _generate_paragraph(self, processed):
|
|
271
|
+
words = processed['words']
|
|
272
|
+
modes = processed['modes']
|
|
273
|
+
indents = processed['indents']
|
|
274
|
+
base_indent = processed['base_indent']
|
|
275
|
+
words, modes, indents = self._replace_newlines(words, modes, indents)
|
|
276
|
+
lines, indents = self._generate_lines(words, modes, indents)
|
|
277
|
+
lines = self._indent_lines(lines, indents, base_indent)
|
|
278
|
+
text = '\n'.join(lines)
|
|
279
|
+
for k, v in self.reverse_replace_tokens.items():
|
|
280
|
+
text = text.replace(k, v)
|
|
281
|
+
return text
|
|
282
|
+
|
|
283
|
+
def generate(self, paragraphs, join=True):
|
|
284
|
+
paragraphs = [self._generate_paragraph(p) for p in paragraphs]
|
|
285
|
+
if join:
|
|
286
|
+
return '\n\n'.join(paragraphs)
|
|
287
|
+
return paragraphs
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
if __name__ == '__main__':
    # Manual smoke test. Round-trip ./data/example.tex through processing
    # and generation, then regenerate again with every non-'break' mode
    # forced to 'off'.
    with open('./data/example.tex', 'r', encoding='utf-8') as f:
        test = f.read()
    processor = SemBrProcessor()
    tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
    results = processor(test)
    results = processor.tokenize_with_modes(tokenizer, results)
    print('--- Processed ---')
    print(processor.generate(results))
    for r in results:
        r['modes'] = ['off' if m != 'break' else m for m in r['modes']]
    print('--- Flattened ---')
    print(processor.generate(results))
|
sembr/sembr2023.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import glob
|
|
3
|
+
|
|
4
|
+
import datasets
|
|
5
|
+
|
|
6
|
+
from .process import SemBrProcessor
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
logger = datasets.logging.get_logger(name=__name__)

# Largest indent level the ``indents`` ClassLabel feature can encode
# (label names '0' .. str(MAX_INDENT)).
MAX_INDENT = 10
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SemBr2023(datasets.GeneratorBasedBuilder):
    """``datasets`` builder for the SemBr2023 paragraphs.

    Reads every file directly under ``./data/raw/train/`` and
    ``./data/raw/test/`` (relative to the working directory). Each file is
    split into paragraphs by ``SemBrProcessor``, and each paragraph is
    yielded as one example.
    """

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(
            name='sembr2023',
            version=datasets.Version('1.0.0'),
            description='SemBr2023 dataset'),
    ]

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.processor = SemBrProcessor()

    def _info(self):
        """Declare the schema of the dicts produced by ``SemBrProcessor``."""
        modes = ['off', 'space', 'nospace']
        indents = [str(i) for i in range(MAX_INDENT + 1)]
        return datasets.DatasetInfo(
            features=datasets.Features({
                'flat_lines': datasets.Value('string'),
                'modes': datasets.Sequence(
                    datasets.features.ClassLabel(names=modes)),
                'mode_offsets': datasets.Sequence(
                    datasets.Sequence(datasets.Value('int32'))),
                'indents': datasets.Sequence(
                    datasets.features.ClassLabel(names=indents)),
                'base_indent': datasets.Value('int32'),
            })
        )

    def _split_generators(self, dl_manager):
        """Map the train and test splits to their raw-text directories."""
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={'root': './data/raw/train/'}),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                gen_kwargs={'root': './data/raw/test/'}),
        ]

    def _generate_examples(self, root):
        """Yield ``(id, paragraph_dict)`` for every paragraph of every file in ``root``."""
        eid = 0
        # Sort the paths so example order and ids do not depend on the
        # filesystem's listing order. Skip directories, which ``open``
        # cannot read.
        paths = sorted(
            p for p in glob.glob(os.path.join(root, '*')) if os.path.isfile(p))
        for path in paths:
            logger.info(f'Generating examples from {path!r}...')
            with open(path, 'r', encoding='utf-8') as f:
                text = f.read()
            for p in self.processor(text):
                # ``indents`` is a ClassLabel with MAX_INDENT + 1 names.
                # Deeper levels cannot be encoded, so clamp them.
                p['indents'] = [min(i, MAX_INDENT) for i in p['indents']]
                yield eid, p
                eid += 1
|
sembr/train.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import argparse
|
|
4
|
+
|
|
5
|
+
import datasets
|
|
6
|
+
from transformers import (
|
|
7
|
+
AutoTokenizer, AutoModelForTokenClassification,
|
|
8
|
+
TrainingArguments, Trainer, DataCollatorForTokenClassification)
|
|
9
|
+
|
|
10
|
+
from .process import SemBrProcessor
|
|
11
|
+
from .dataset import process_dataset
|
|
12
|
+
from .utils import compute_metrics
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def parse_args(argv=None):
    """Parse training options.

    With ``--debug``, this waits for a debugpy client on port 5678 and then
    sets ``report_to='none'`` and ``hub_user=None``.

    Args:
        argv: Optional list of argument strings. ``None`` parses
            ``sys.argv[1:]`` as before.

    Returns:
        An ``argparse.Namespace`` with the parsed options.
    """
    parser = argparse.ArgumentParser(
        description='Fine-tune a token-classification model for SemBr.')
    parser.add_argument('model', type=str, help='Base model name or path.')
    parser.add_argument(
        '-dn', '--dataset-name', type=str, default='admko/sembr2023',
        help='Dataset passed to datasets.load_dataset (default: %(default)s).')
    parser.add_argument(
        '-lr', '--learning-rate', type=float, default=1e-5,
        help='Learning rate (default: %(default)s).')
    parser.add_argument(
        '-tb', '--train-batch-size', type=int, default=64,
        help='Training batch size (default: %(default)s).')
    parser.add_argument(
        '-eb', '--eval-batch-size', type=int, default=128,
        help='Evaluation batch size (default: %(default)s).')
    parser.add_argument(
        '-mi', '--max-indent', type=int, default=3,
        help='Indent levels above this are clamped (default: %(default)s).')
    parser.add_argument(
        '-hu', '--hub-user', type=str, default=None,
        help='Hub user to push to; no push when omitted.')
    parser.add_argument(
        '-ms', '--max-steps', type=int, default=5000,
        help='Maximum training steps (default: %(default)s).')
    parser.add_argument(
        '-es', '--eval-steps', type=int, default=10,
        help='Evaluation interval in steps (default: %(default)s).')
    parser.add_argument(
        '-ss', '--save-steps', type=int, default=100,
        help='Checkpoint interval in steps (default: %(default)s).')
    parser.add_argument(
        '-rt', '--report-to', type=str, default='all',
        help='Reporting integrations (default: %(default)s).')
    parser.add_argument(
        '-d', '--debug', action='store_true',
        help='Wait for a debugpy client and disable reporting/pushing.')
    args = parser.parse_args(argv)
    if args.debug:
        import debugpy
        debugpy.listen(5678)
        print('Waiting for debugger...')
        debugpy.wait_for_client()
        args.report_to = 'none'
        args.hub_user = None
    return args
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class DataCollatorForTokenClassificationWithTruncation(
        DataCollatorForTokenClassification):
    """Token-classification collator that clips features before padding.

    Every value of every feature dict is sliced to its first ``max_length``
    items, and the clipped features are then handed to the parent collator.
    """

    def __init__(self, tokenizer, max_length=512, **kwargs):
        super().__init__(tokenizer, **kwargs)
        # Assigned after the parent initializer so this value wins.
        # NOTE(review): the parent presumably reads ``self.max_length`` when
        # padding (init_dataset passes padding='max_length'); confirm against
        # the installed transformers version.
        self.max_length = max_length

    def __call__(self, features, return_tensors=None):
        limit = self.max_length
        clipped = [
            {key: value[:limit] for key, value in feature.items()}
            for feature in features
        ]
        return super().__call__(clipped, return_tensors)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def init_dataset(args, label2id, max_length):
    """Load and label the train/test splits and build the batch collator.

    Args:
        args: parsed CLI namespace; ``dataset_name``, ``model`` and
            ``max_indent`` are read.
        label2id: mapping from label name to class id used for the targets.
        max_length: length the collator truncates and pads sequences to.

    Returns:
        ``(train_dataset, test_dataset, tokenizer, collator)``.
    """
    raw = datasets.load_dataset(args.dataset_name)
    processor = SemBrProcessor()
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    # The tokenizer is adjusted before any split is tokenized; main later
    # resizes the model embeddings to len(tokenizer).
    processor.prepare_tokenizer(tokenizer)
    train_dataset, test_dataset = (
        process_dataset(
            raw[split], processor, tokenizer, args.max_indent, label2id)
        for split in ('train', 'test'))
    print(f'{len(train_dataset)=}, {len(test_dataset)=}')
    collator = DataCollatorForTokenClassificationWithTruncation(
        tokenizer, padding='max_length', max_length=max_length)
    return train_dataset, test_dataset, tokenizer, collator
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def init_model(model_name, max_indent):
    """Load *model_name* as a token classifier over SemBr break labels.

    Label ids are assigned in this order:
    ``off`` (0), ``space-0`` … ``space-<max_indent>``, then
    ``nospace-0`` … ``nospace-<max_indent>``. The classification head is
    re-initialised where its size differs from the checkpoint's
    (``ignore_mismatched_sizes=True``).

    Returns:
        The loaded model; its config carries ``id2label``/``label2id``.
    """
    indents = range(max_indent + 1)
    label_names = ['off']
    label_names += [f'space-{i}' for i in indents]
    label_names += [f'nospace-{i}' for i in indents]
    id2label = dict(enumerate(label_names))
    label2id = {name: idx for idx, name in id2label.items()}
    return AutoModelForTokenClassification.from_pretrained(
        model_name, ignore_mismatched_sizes=True,
        num_labels=len(id2label), id2label=id2label, label2id=label2id)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def main(args):
    """Fine-tune ``args.model`` as a SemBr line-break token classifier.

    Builds the model and datasets, trains with a cosine learning-rate
    schedule, keeps the checkpoint with the best eval ``f1`` (reloaded at the
    end of training) and, when ``args.hub_user`` is set, pushes the result to
    ``<hub_user>/sembr2023-<model basename>``.

    Args:
        args: the namespace produced by ``parse_args``.
    """
    model = init_model(args.model, args.max_indent)
    # The collator truncates and pads every sequence to this length.
    max_length = model.config.max_position_embeddings
    train_dataset, test_dataset, tokenizer, collator = init_dataset(
        args, model.config.label2id, max_length)
    # Record max_indent on the config so it travels with the model.
    model.config.max_indent = args.max_indent
    # init_dataset presumably grows the tokenizer (prepare_tokenizer), so the
    # embedding matrix is resized to match.
    model.resize_token_embeddings(len(tokenizer))
    model_name = args.model.split('/')[-1]
    run_name = f'sembr2023-{model_name}'
    push = args.hub_user is not None
    training_args = TrainingArguments(
        output_dir=f'checkpoints/{run_name}',
        run_name=run_name,
        report_to=args.report_to,
        learning_rate=args.learning_rate,
        lr_scheduler_type='cosine',
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        weight_decay=1e-5,
        # NOTE(review): newer transformers releases rename this argument to
        # ``eval_strategy``; update together with the supported version.
        evaluation_strategy='steps',
        max_steps=args.max_steps,
        eval_steps=args.eval_steps,
        save_strategy='steps',
        save_steps=args.save_steps,
        save_total_limit=1,
        metric_for_best_model='f1',
        load_best_model_at_end=True,
        logging_steps=1,
        push_to_hub=push,
        hub_strategy='end',
        # Only form a repo id when pushing; otherwise the saved and logged
        # training arguments would carry a bogus 'None/...' id.
        hub_model_id=f'{args.hub_user}/{run_name}' if push else None,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
        data_collator=collator,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    if push:
        trainer.push_to_hub()
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
if __name__ == '__main__':
    # Script entry point: parse the CLI, then run fine-tuning.
    cli_args = parse_args()
    main(cli_args)
|
sembr/utils.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def binary_metrics(fp, tp, fn, tn=None, prefix=None):
    """Compute precision/recall-style scores from binary confusion counts.

    Args:
        fp: false-positive count.
        tp: true-positive count.
        fn: false-negative count.
        tn: true-negative count; when given, ``accuracy`` and
            ``balanced_accuracy`` are reported as well.
        prefix: if not ``None``, every key becomes ``f'{prefix}_{key}'``.

    Returns:
        A dict mapping metric name to value. A ratio whose denominator is 0
        is reported as 0 rather than raising ``ZeroDivisionError``.
    """
    def safe_div(a, b):
        return a / b if b > 0 else 0

    result = {
        'precision': safe_div(tp, tp + fp),
        'recall': safe_div(tp, tp + fn),
        'f1': safe_div(2 * tp, 2 * tp + fp + fn),
        'iou': safe_div(tp, tp + fp + fn),
    }
    if tn is not None:
        result['accuracy'] = safe_div(tp + tn, tn + fn + tp + fp)
        result['balanced_accuracy'] = (
            safe_div(tp, tp + fn) + safe_div(tn, tn + fp)) / 2
    if prefix is not None:
        result = {f'{prefix}_{k}': v for k, v in result.items()}
    return result


def compute_metrics(result):
    """Evaluation callback for ``transformers.Trainer``.

    Args:
        result: a ``(logits, labels)`` pair (or an EvalPrediction) where
            ``logits`` has a trailing class axis and ``labels`` holds class
            ids, with -100 marking positions to ignore.

    Returns:
        Binary break-detection metrics (any non-zero label counts as a
        break), plus ``overall_accuracy`` over exact labels and the mean
        cross-entropy ``loss`` as a Python float.
    """
    # Index instead of unpacking two names: an EvalPrediction may presumably
    # carry extra fields (e.g. inputs) that would break 2-tuple unpacking.
    logits, labels = result[0], result[1]
    preds = np.argmax(logits, axis=2)
    # -100 is the conventional ignore index for special/padding tokens.
    keep = labels != -100
    logits, preds, labels = logits[keep], preds[keep], labels[keep]
    loss = torch.nn.functional.cross_entropy(
        torch.as_tensor(logits, dtype=torch.float32),
        torch.as_tensor(labels, dtype=torch.long), reduction='mean')
    # Class 0 is "no break"; every other class is a break of some kind.
    # Vectorized counts replace per-index Python set arithmetic.
    pred_break = preds != 0
    label_break = labels != 0
    tp = int(np.count_nonzero(pred_break & label_break))
    fp = int(np.count_nonzero(pred_break & ~label_break))
    fn = int(np.count_nonzero(~pred_break & label_break))
    tn = int(np.count_nonzero(~pred_break & ~label_break))
    metrics = {
        **binary_metrics(fp, tp, fn, tn),
        'overall_accuracy': (preds == labels).mean(),
        'loss': loss.item(),
    }
    return metrics
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
The MIT License (MIT)
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2023 Xitong Gao
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted,
|
|
6
|
+
free of charge,
|
|
7
|
+
to any person obtaining a copy of this software
|
|
8
|
+
and associated documentation files (the "Software"),
|
|
9
|
+
to deal in the Software without restriction,
|
|
10
|
+
including without limitation
|
|
11
|
+
the rights to use, copy, modify, merge,
|
|
12
|
+
publish, distribute, sublicense,
|
|
13
|
+
and/or sell copies of the Software,
|
|
14
|
+
and to permit persons to whom the Software
|
|
15
|
+
is furnished to do so,
|
|
16
|
+
subject to the following conditions:
|
|
17
|
+
|
|
18
|
+
The above copyright notice
|
|
19
|
+
and this permission notice
|
|
20
|
+
shall be included in all copies
|
|
21
|
+
or substantial portions of the Software.
|
|
22
|
+
|
|
23
|
+
THE SOFTWARE IS PROVIDED "AS IS",
|
|
24
|
+
WITHOUT WARRANTY OF ANY KIND,
|
|
25
|
+
EXPRESS OR IMPLIED,
|
|
26
|
+
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
27
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
|
28
|
+
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
|
|
29
|
+
BE LIABLE FOR ANY CLAIM,
|
|
30
|
+
DAMAGES OR OTHER LIABILITY,
|
|
31
|
+
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
|
|
32
|
+
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
|
|
33
|
+
OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: sembr
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: A semantic linebreaker powered by transformers
|
|
5
|
+
Author: admk
|
|
6
|
+
Project-URL: Homepage, https://github.com/admk/sembr
|
|
7
|
+
Project-URL: Issues, https://github.com/admk/sembr/issues
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
10
|
+
Classifier: Operating System :: OS Independent
|
|
11
|
+
Requires-Python: >=3.10
|
|
12
|
+
Description-Content-Type: text/markdown
|
|
13
|
+
License-File: LICENSE.txt
|
|
14
|
+
Requires-Dist: transformers
|
|
15
|
+
Requires-Dist: torch
|
|
16
|
+
Requires-Dist: numpy
|
|
17
|
+
Requires-Dist: tqdm
|
|
18
|
+
Requires-Dist: requests
|
|
19
|
+
Requires-Dist: flask
|
|
20
|
+
|
|
21
|
+
# Semantic Line Breaker (SemBr)
|
|
22
|
+
|
|
23
|
+
```
|
|
24
|
+
> When writing text
|
|
25
|
+
> with a compatible markup language,
|
|
26
|
+
> add a line break
|
|
27
|
+
> after each substantial unit of thought.
|
|
28
|
+
```
|
|
29
|
+
|
|
30
|
+
## What is SemBr?
|
|
31
|
+
|
|
32
|
+
SemBr is a tool
|
|
33
|
+
that breaks lines in a text file
|
|
34
|
+
at semantic boundaries.
|
|
35
|
+
|
|
36
|
+
### Installation
|
|
37
|
+
|
|
38
|
+
SemBr is available as a Python package
|
|
39
|
+
on PyPI.
|
|
40
|
+
To install it,
|
|
41
|
+
simply run the following command
|
|
42
|
+
in your terminal,
|
|
43
|
+
assuming that you have Python 3.10 or later installed:
|
|
44
|
+
```shell
|
|
45
|
+
pip install sembr
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
### Supported Platforms
|
|
49
|
+
|
|
50
|
+
SemBr is supported on Linux, Mac and Windows.
|
|
51
|
+
On machines with CUDA devices,
|
|
52
|
+
or on Apple Silicon Macs,
|
|
53
|
+
SemBr will use the GPU
|
|
54
|
+
to accelerate inference.
|
|
55
|
+
|
|
56
|
+
### Usage
|
|
57
|
+
|
|
58
|
+
To use SemBr,
|
|
59
|
+
run the following command
|
|
60
|
+
in your terminal:
|
|
61
|
+
```shell
|
|
62
|
+
sembr -i <input_file> -o <output_file>
|
|
63
|
+
```
|
|
64
|
+
where `<input_file>` and `<output_file>` are the paths to the input and output files.
|
|
65
|
+
|
|
66
|
+
Alternatively,
|
|
67
|
+
you can pipe the input
|
|
68
|
+
into `sembr`,
|
|
69
|
+
and the output can also be printed
|
|
70
|
+
to the terminal:
|
|
71
|
+
```shell
|
|
72
|
+
cat <input_file> | sembr
|
|
73
|
+
```
|
|
74
|
+
This is especially useful
|
|
75
|
+
if you want to use SemBr
|
|
76
|
+
with clipboard managers,
|
|
77
|
+
for instance, on a Mac:
|
|
78
|
+
```shell
|
|
79
|
+
pbpaste | sembr | pbcopy
|
|
80
|
+
```
|
|
81
|
+
|
|
82
|
+
Additionally,
|
|
83
|
+
you can specify the following options
|
|
84
|
+
to customize the behavior of SemBr:
|
|
85
|
+
* `-m <model_name>`: The name of the Hugging Face model to use.
|
|
86
|
+
The default is `admko/sembr2023-bert-small`.
|
|
87
|
+
* `-l`: Serves the SemBr API on a local server.
|
|
88
|
+
Each subsequent run of `sembr`
|
|
89
|
+
will detect if the API is accessible,
|
|
90
|
+
and if not it will run the model on its own.
|
|
91
|
+
This option is useful
|
|
92
|
+
to avoid the time taken to initialize the model
|
|
93
|
+
by keeping it in memory in a separate process.
|
|
94
|
+
* `-p <port>`: The port to serve the SemBr API on.
|
|
95
|
+
The default is `8384`.
|
|
96
|
+
* `-s <ip>`: The IP address to serve the SemBr API on.
|
|
97
|
+
The default is `127.0.0.1`.
|
|
98
|
+
|
|
99
|
+
## What are Semantic Line Breaks?
|
|
100
|
+
|
|
101
|
+
[Semantic Line Breaks](https://sembr.org)
|
|
102
|
+
or [
|
|
103
|
+
Semantic Linefeeds
|
|
104
|
+
](https://rhodesmill.org/brandon/2012/one-sentence-per-line/)
|
|
105
|
+
describe a set of conventions
|
|
106
|
+
for using insensitive vertical whitespace
|
|
107
|
+
to structure prose along semantic boundaries.
|
|
108
|
+
|
|
109
|
+
## Why use Semantic Line Breaks?
|
|
110
|
+
|
|
111
|
+
Semantic Line Breaks have the following advantages:
|
|
112
|
+
|
|
113
|
+
* Breaking lines by splitting clauses
|
|
114
|
+
reflects the logical, grammatical and semantic structure
|
|
115
|
+
of the text.
|
|
116
|
+
|
|
117
|
+
* It enhances the ease of editing and version control
|
|
118
|
+
for a text file.
|
|
119
|
+
Merge conflicts are less likely to occur
|
|
120
|
+
when small changes are made,
|
|
121
|
+
and the changes are easier to identify.
|
|
122
|
+
|
|
123
|
+
* Documents written with semantic line breaks
|
|
124
|
+
are easier to navigate and edit
|
|
125
|
+
with Vim and other text editors
|
|
126
|
+
that use Vim keybindings.
|
|
127
|
+
|
|
128
|
+
* Semantic line breaks
|
|
129
|
+
are invisible to readers.
|
|
130
|
+
The final rendered output
|
|
131
|
+
shows no changes to the source text.
|
|
132
|
+
|
|
133
|
+
## Why SemBr?
|
|
134
|
+
|
|
135
|
+
Converting existing text not written
|
|
136
|
+
with semantic line breaks
|
|
137
|
+
takes a long time to do manually,
|
|
138
|
+
and it is surprisingly difficult
|
|
139
|
+
to do it automatically
|
|
140
|
+
with rule-based methods.
|
|
141
|
+
|
|
142
|
+
### Challenges of rule-based methods
|
|
143
|
+
|
|
144
|
+
Rule-based heuristics do not work well
|
|
145
|
+
with the actual semantic structure of the text,
|
|
146
|
+
often leading to incorrect semantic boundaries.
|
|
147
|
+
Moreover,
|
|
148
|
+
semantic boundaries are hierarchical and nested,
|
|
149
|
+
and a rule-based approach
|
|
150
|
+
cannot capture this structure.
|
|
151
|
+
A semantic line break
|
|
152
|
+
may occur after a dependent clause,
|
|
153
|
+
but not all clauses should be broken into lines.
|
|
154
|
+
For instance:
|
|
155
|
+
|
|
156
|
+
* A rule that breaks lines at punctuation marks
|
|
157
|
+
will not work well
|
|
158
|
+
with sentences that contain
|
|
159
|
+
periods in abbreviations or mathematical expressions.
|
|
160
|
+
|
|
161
|
+
* For example,
|
|
162
|
+
"I like to eat apples and oranges
|
|
163
|
+
because they are healthy."
|
|
164
|
+
should be broken into lines as follows:
|
|
165
|
+
```
|
|
166
|
+
I like to eat apples and oranges
|
|
167
|
+
because they are healthy.
|
|
168
|
+
```
|
|
169
|
+
rather than:
|
|
170
|
+
```
|
|
171
|
+
I like to eat apples
|
|
172
|
+
and oranges because they are healthy.
|
|
173
|
+
```
|
|
174
|
+
|
|
175
|
+
For this reason,
|
|
176
|
+
I have created SemBr,
|
|
177
|
+
which uses finetuned Transformer models
|
|
178
|
+
to predict line breaks at semantic boundaries.
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
## How does SemBr work?
|
|
182
|
+
|
|
183
|
+
SemBr uses a Transformer model
|
|
184
|
+
to predict line breaks
|
|
185
|
+
at semantic boundaries.
|
|
186
|
+
A small dataset of text
|
|
187
|
+
with semantic line breaks was created
|
|
188
|
+
from my existing LaTeX documents.
|
|
189
|
+
The dataset was split into training
|
|
190
|
+
(46,295 lines, 170,681 words and 1,492,952 characters)
|
|
191
|
+
and test
|
|
192
|
+
(2,187 lines, 7,564 words and 72,231 characters)
|
|
193
|
+
datasets.
|
|
194
|
+
|
|
195
|
+
The data was prepared
|
|
196
|
+
by extracting line breaks and indent levels
|
|
197
|
+
from the files,
|
|
198
|
+
and then converting the files
|
|
199
|
+
into strings of paragraphs
|
|
200
|
+
with line breaks removed.
|
|
201
|
+
The data can then be tokenized
|
|
202
|
+
using the tokenizer
|
|
203
|
+
and converted into a dataset
|
|
204
|
+
with tokens,
|
|
205
|
+
where each token has a label
|
|
206
|
+
denoting:
|
|
207
|
+
* no line break (label = 0), or
|
|
208
|
+
* a line break
|
|
209
|
+
that adds a space in LaTeX documents
|
|
210
|
+
at the token with an indent level
|
|
211
|
+
(label in [1, 2, ..., MAX_INDENT + 1]), or
|
|
212
|
+
* a line break that adds no space
|
|
213
|
+
(label in [MAX_INDENT + 2, MAX_INDENT + 3, ..., 2 * MAX_INDENT + 2]).
|
|
214
|
+
|
|
215
|
+
The pretrained masked language model
|
|
216
|
+
is then finetuned as a token classifier
|
|
217
|
+
on the training dataset
|
|
218
|
+
to predict the labels
|
|
219
|
+
of the tokens.
|
|
220
|
+
We save the model
|
|
221
|
+
with the best F1 score
|
|
222
|
+
on correctly predicting line breaks of any kind
|
|
223
|
+
on the test set.
|
|
224
|
+
The finetuning logs
|
|
225
|
+
for all models including the following
|
|
226
|
+
can be found
|
|
227
|
+
on this [WandB](https://wandb.ai/admko/sembr2023) report:
|
|
228
|
+
* [`distilbert-base-uncased`](https://huggingface.co/distilbert-base-uncased).
|
|
229
|
+
* [`distilbert-base-cased`](https://huggingface.co/distilbert-base-cased).
|
|
230
|
+
* [`distilbert-base-uncased-finetuned-sst-2-english`](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english).
|
|
231
|
+
* [`prajjwal1/bert-tiny`](https://huggingface.co/prajjwal1/bert-tiny).
|
|
232
|
+
* [`prajjwal1/bert-small`](https://huggingface.co/prajjwal1/bert-small).
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
## Performance
|
|
236
|
+
|
|
237
|
+
Current inference speed
|
|
238
|
+
on an M2 MacBook Pro
|
|
239
|
+
is about 1,500 characters per second
|
|
240
|
+
on `bert-small`,
|
|
241
|
+
and the memory usage is about 1.70 GB.
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
## Improvements and TODOs
|
|
245
|
+
|
|
246
|
+
* [ ] Support natural languages other than English.
|
|
247
|
+
* [ ] Support other markup languages
|
|
248
|
+
such as Markdown.
|
|
249
|
+
* [ ] Some lines are too long
|
|
250
|
+
without a line break.
|
|
251
|
+
The inference algorithm
|
|
252
|
+
can be improved to penalize long lines.
|
|
253
|
+
* [ ] Performance benchmarking.
|
|
254
|
+
* [ ] Improve inference speed.
|
|
255
|
+
* [ ] Reduce memory usage.
|
|
256
|
+
* [ ] Improve indent level prediction.
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
sembr/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
sembr/cli.py,sha256=3uP01_v3-MRhwL3_46u1aG6p2BXq46DIE8SPA35Opu8,5179
|
|
3
|
+
sembr/dataset.py,sha256=FZqXwfrXY8veN7YNI5hu0FM3fCNcxNBWzEsK8XySR9U,3177
|
|
4
|
+
sembr/eval.py,sha256=fWIMlcRapnD3z0wj8ejvyaPBK4lx06n5bDke9oB4rHE,1066
|
|
5
|
+
sembr/inference.py,sha256=80bWRBa1FSMfN30SN72ii9eboTuO9DHXyzGdBJjEhqQ,4361
|
|
6
|
+
sembr/process.py,sha256=QgP2LbsgS2lboNvtKquiyWPhw3DSdPguB6WEg3-cRJ4,11017
|
|
7
|
+
sembr/sembr2023.py,sha256=tEdZR6IThAML69gJM0ZtWi6mBIh6xTSGFnz8mE2xgGI,1878
|
|
8
|
+
sembr/train.py,sha256=BpBybR8s0m8kLWEhOy5XOFNRffg6hqbddcMRYJzUnQw,4791
|
|
9
|
+
sembr/utils.py,sha256=dp4xY_OME3P7aETYS16Xn55lEfqVg3K871Un7jllu_4,1436
|
|
10
|
+
sembr-0.0.1.dist-info/LICENSE.txt,sha256=ox_0dcgQRQpWRLa3ZuFBzCoAbnCoKxFRMHnQPh7Mue0,1077
|
|
11
|
+
sembr-0.0.1.dist-info/METADATA,sha256=cCgJMS2QtC56V0-iLcyZ50y-XFW8I4J2Ib6EukSzBPA,6775
|
|
12
|
+
sembr-0.0.1.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
|
13
|
+
sembr-0.0.1.dist-info/entry_points.txt,sha256=TPj7T_mm1qglnW5yf8JWQjEY8BaAb3CWyH8TpFE1u8U,41
|
|
14
|
+
sembr-0.0.1.dist-info/top_level.txt,sha256=q-9hKQBeQBW_3ERrGmDUlJjcaZla-VxX9PTK13v-u9E,6
|
|
15
|
+
sembr-0.0.1.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
sembr
|