hyperbase-parser-ab 0.1.0__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyperbase_parser_ab-0.2.0/.pre-commit-config.yaml +7 -0
- {hyperbase_parser_ab-0.1.0 → hyperbase_parser_ab-0.2.0}/CHANGELOG.md +6 -0
- {hyperbase_parser_ab-0.1.0 → hyperbase_parser_ab-0.2.0}/PKG-INFO +3 -2
- hyperbase_parser_ab-0.2.0/VERSION +1 -0
- {hyperbase_parser_ab-0.1.0 → hyperbase_parser_ab-0.2.0}/pyproject.toml +15 -2
- hyperbase_parser_ab-0.2.0/scripts/generate_alpha_training_data.py +113 -0
- {hyperbase_parser_ab-0.1.0 → hyperbase_parser_ab-0.2.0}/scripts/train_atomizer.py +29 -27
- {hyperbase_parser_ab-0.1.0 → hyperbase_parser_ab-0.2.0}/src/hyperbase_parser_ab/alpha.py +32 -20
- {hyperbase_parser_ab-0.1.0 → hyperbase_parser_ab-0.2.0}/src/hyperbase_parser_ab/atomizer.py +33 -24
- hyperbase_parser_ab-0.2.0/src/hyperbase_parser_ab/lang_models.py +9 -0
- {hyperbase_parser_ab-0.1.0 → hyperbase_parser_ab-0.2.0}/src/hyperbase_parser_ab/parser.py +317 -251
- {hyperbase_parser_ab-0.1.0 → hyperbase_parser_ab-0.2.0}/src/hyperbase_parser_ab/rules.py +32 -24
- {hyperbase_parser_ab-0.1.0 → hyperbase_parser_ab-0.2.0}/src/hyperbase_parser_ab/sentensizer.py +2 -2
- {hyperbase_parser_ab-0.1.0 → hyperbase_parser_ab-0.2.0}/tests/test_parser.py +95 -81
- hyperbase_parser_ab-0.2.0/tests/test_parser_helpers.py +257 -0
- {hyperbase_parser_ab-0.1.0 → hyperbase_parser_ab-0.2.0}/tests/test_rules.py +45 -46
- hyperbase_parser_ab-0.1.0/VERSION +0 -1
- hyperbase_parser_ab-0.1.0/scripts/generate_alpha_training_data.py +0 -107
- hyperbase_parser_ab-0.1.0/src/hyperbase_parser_ab/lang_models.py +0 -50
- hyperbase_parser_ab-0.1.0/tests/test_parser_helpers.py +0 -250
- {hyperbase_parser_ab-0.1.0 → hyperbase_parser_ab-0.2.0}/.github/workflows/publish.yml +0 -0
- {hyperbase_parser_ab-0.1.0 → hyperbase_parser_ab-0.2.0}/.gitignore +0 -0
- {hyperbase_parser_ab-0.1.0 → hyperbase_parser_ab-0.2.0}/LICENSE +0 -0
- {hyperbase_parser_ab-0.1.0 → hyperbase_parser_ab-0.2.0}/README.md +0 -0
- {hyperbase_parser_ab-0.1.0 → hyperbase_parser_ab-0.2.0}/src/hyperbase_parser_ab/__init__.py +0 -0
- {hyperbase_parser_ab-0.1.0 → hyperbase_parser_ab-0.2.0}/tests/__init__.py +0 -0
--- hyperbase_parser_ab-0.1.0/PKG-INFO
+++ hyperbase_parser_ab-0.2.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hyperbase-parser-ab
-Version: 0.1.0
+Version: 0.2.0
 Summary: Semantic Hypergraph AlphaBeta Parser
 Project-URL: Homepage, https://hyperquest.ai/hyperbase
 Author-email: "Telmo Menezes et al." <telmo@telmomenezes.net>
@@ -15,7 +15,7 @@ Classifier: Programming Language :: Python :: 3
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Classifier: Topic :: Scientific/Engineering :: Information Analysis
 Requires-Python: >=3.10
-Requires-Dist: hyperbase>=0.
+Requires-Dist: hyperbase>=0.9.0
 Requires-Dist: pip
 Requires-Dist: scikit-learn>=1.3.0
 Requires-Dist: spacy>=3.8.0
@@ -24,6 +24,7 @@ Requires-Dist: transformers>=4.46.0
 Provides-Extra: dev
 Requires-Dist: coverage>=7.4.3; extra == 'dev'
 Requires-Dist: datasets>=4.0.0; extra == 'dev'
+Requires-Dist: evaluate>=0.4.6; extra == 'dev'
 Requires-Dist: mypy>=1.8.0; extra == 'dev'
 Requires-Dist: pre-commit>=3.6.2; extra == 'dev'
 Requires-Dist: pytest>=9.0.0; extra == 'dev'

--- /dev/null
+++ hyperbase_parser_ab-0.2.0/VERSION
@@ -0,0 +1 @@
+0.2.0
--- hyperbase_parser_ab-0.1.0/pyproject.toml
+++ hyperbase_parser_ab-0.2.0/pyproject.toml
@@ -26,12 +26,12 @@ classifiers = [
     "Topic :: Scientific/Engineering :: Information Analysis",
 ]
 dependencies = [
-    "hyperbase>=0.
+    "hyperbase>=0.9.0",
     "scikit-learn>=1.3.0",
     "spacy>=3.8.0",
     "torch>=2.0.0",
     "transformers>=4.46.0",
-    "pip",
+    "pip", # so that spaCy models can be easily installed with uv
 ]
 
 [tool.uv.sources]
@@ -46,6 +46,7 @@ dev = [
     "coverage>=7.4.3",
     "datasets>=4.0.0",
     "pytest>=9.0.0",
+    "evaluate>=0.4.6",
 ]
 
 [project.urls]
@@ -70,3 +71,15 @@ strict = true
 
 [tool.ruff]
 target-version = "py310"
+
+[tool.ruff.lint]
+select = ["E", "F", "W", "I", "UP", "B", "SIM", "RUF", "Q", "C4", "PT", "N", "ANN"]
+
+[tool.ruff.lint.per-file-ignores]
+"tests/*" = ["E501", "ANN001", "ANN201", "ANN202", "ANN205", "D100", "D101", "D102", "D400", "D415"]
+
+[tool.ruff.lint.flake8-quotes]
+inline-quotes = "double"
+
+[tool.ruff.format]
+quote-style = "double"
--- /dev/null
+++ hyperbase_parser_ab-0.2.0/scripts/generate_alpha_training_data.py
@@ -0,0 +1,113 @@
+import argparse
+import json
+
+from hyperbase import hedge
+
+from hyperbase_parser_ab import AlphaBetaParser
+
+if __name__ == "__main__":
+    arg_parser = argparse.ArgumentParser(description="Generate alpha training data.")
+    arg_parser.add_argument("infile", type=str, help="input jsonl file")
+    arg_parser.add_argument("outfile", type=str, help="output tsv file")
+    arg_parser.add_argument(
+        "--lang", type=str, default="en", help="language (default: en)"
+    )
+    args = arg_parser.parse_args()
+
+    total_sentences = 0
+    ignored_sentences = 0
+    failed_parses = 0
+    total_atoms = 0
+
+    parser = AlphaBetaParser(lang=args.lang)
+
+    with open(args.infile) as infile, open(args.outfile, "w") as outfile:
+        for line in infile.readlines():
+            case = json.loads(line)
+            sentence = case["sentence"]
+            atoms = case["atoms"]
+            parses = parser.parse_sentence(sentence)
+            spacy_sentence = next(iter(parser.doc.sents)) if parser.doc else None
+            if not spacy_sentence or not parses:
+                failed_parses += 1
+            elif case["ignore"]:
+                ignored_sentences += 1
+            elif len(atoms) == len(spacy_sentence):
+                total_sentences += 1
+                total_atoms += len(atoms)
+
+                for i in range(len(atoms)):
+                    atom = atoms[i]
+                    token = spacy_sentence[i]
+                    atom_edge = hedge(atom)
+                    if atom_edge is None:
+                        continue
+
+                    word_before = ""
+                    word_after = ""
+                    pos_before = ""
+                    pos_after = ""
+                    tag_before = ""
+                    tag_after = ""
+                    dep_before = ""
+                    dep_after = ""
+                    punct_before = False
+                    punct_after = False
+                    if i > 0:
+                        word_before = str(spacy_sentence[i - 1])
+                        pos_before = spacy_sentence[i - 1].pos_
+                        tag_before = spacy_sentence[i - 1].tag_
+                        dep_before = spacy_sentence[i - 1].dep_
+                        if spacy_sentence[i - 1].pos_ == "PUNCT":
+                            punct_before = True
+                    if i < len(atoms) - 1:
+                        word_after = str(spacy_sentence[i + 1])
+                        pos_after = spacy_sentence[i + 1].pos_
+                        tag_after = spacy_sentence[i + 1].tag_
+                        dep_after = spacy_sentence[i + 1].dep_
+                        if spacy_sentence[i + 1].pos_ == "PUNCT":
+                            punct_after = True
+
+                    head = token.head
+                    is_root = head is None
+                    has_lefts = token.n_lefts > 0
+                    has_rights = token.n_rights > 0
+                    outfile.write(
+                        ("{}" + "\t{}" * 25 + "\n").format(
+                            atom_edge.mtype(),
+                            str(token),
+                            token.pos_,
+                            token.tag_,
+                            token.dep_,
+                            str(head) if head else "",
+                            head.pos_ if head else "",
+                            head.tag_ if head else "",
+                            head.dep_ if head else "",
+                            is_root,
+                            has_lefts,
+                            has_rights,
+                            token.ent_type_,
+                            token.shape_[:2],
+                            word_before,
+                            word_after,
+                            punct_before,
+                            punct_after,
+                            pos_before,
+                            pos_after,
+                            tag_before,
+                            tag_after,
+                            dep_before,
+                            dep_after,
+                            case["correct"],
+                            case["source"],
+                        )
+                    )
+            else:
+                failed_parses += 1
+    print(
+        f"sentences: {total_sentences}; "
+        f"ignored: {ignored_sentences}; "
+        f"failed: {failed_parses}; "
+        f"atoms: {total_atoms}"
+    )
+    print("done.")
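Note: each input line of this script is a JSON object with the fields it reads (`sentence`, `atoms` with one hyperbase atom per spaCy token, `ignore`, `correct`, `source`). A hypothetical record, with made-up atom strings:

{"sentence": "Mary likes astronomy.", "atoms": ["mary/C", "likes/P", "astronomy/C", "./X"], "ignore": false, "correct": true, "source": "manual"}

Each atom/token pair in a sentence that passes the length check becomes one 26-column row in the output TSV (the `("{}" + "\t{}" * 25)` format string), with the atom's `mtype()` in the first column.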
--- hyperbase_parser_ab-0.1.0/scripts/train_atomizer.py
+++ hyperbase_parser_ab-0.2.0/scripts/train_atomizer.py
@@ -1,27 +1,28 @@
 import json
 
+import evaluate
 import numpy as np
-from numpy.typing import NDArray
 from datasets import Dataset
+from numpy.typing import NDArray
 from transformers import (
-    AutoTokenizer,
     AutoModelForTokenClassification,
+    AutoTokenizer,
+    Trainer,
     TrainingArguments,
-    Trainer
 )
 
 
 def tokenize_and_align_labels(examples: dict[str, list]) -> dict[str, list]:
     """Tokenize each sample and align the original token labels
-
+    to the new subword (tokenized) structure."""
 
     tokenized_outputs = tokenizer(
         examples["tokens"],
         truncation=True,
-        is_split_into_words=True,
+        is_split_into_words=True,  # Important for token-based tasks
         return_offsets_mapping=True,  # We'll use this if needed
-        padding="max_length",
-        max_length=200
+        padding="max_length",  # or "longest" / "do_not_pad"
+        max_length=200,  # adjust as needed
     )
 
     labels_aligned: list[list[int]] = []
@@ -31,7 +32,6 @@ def tokenize_and_align_labels(examples: dict[str, list]) -> dict[str, list]:
         # repeating the label for all subwords of the original token.
         word_ids: list[int | None] = tokenized_outputs.word_ids(batch_index=i)
         label_ids: list[int] = []
-        previous_word_idx: int | None = None
 
         for word_idx in word_ids:
             if word_idx is None:
@@ -39,7 +39,6 @@ def tokenize_and_align_labels(examples: dict[str, list]) -> dict[str, list]:
                 label_ids.append(-100)
             else:
                 label_ids.append(label_to_id[labels[word_idx]])
-                previous_word_idx = word_idx
 
         labels_aligned.append(label_ids)
 
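With the now-unused `previous_word_idx` removed, the label is simply repeated for every subword of its word; only special tokens get the `-100` ignore index. A minimal sketch of the resulting alignment, assuming a WordPiece-style tokenizer and hypothetical labels:

# words:      ["Mary", "dislikes", "astronomy"]
# labels:     ["C",    "P",        "C"]
# subwords:   [CLS]  Mary  dis  ##likes  astronomy  [SEP]
# word_ids:   None   0     1    1        2          None
# label_ids:  -100   C     P    P        C          -100   (label ids, shown here as labels)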
@@ -52,8 +51,8 @@ def tokenize_and_align_labels(examples: dict[str, list]) -> dict[str, list]:
 
 def compute_metrics(eval_pred: tuple[NDArray, NDArray]) -> dict[str, float]:
     """Compute accuracy at the token level (simple example).
-
-
+    You can also compute F1, precision, recall, etc. by ignoring
+    the -100 special tokens."""
     logits: NDArray
     labels: NDArray
     logits, labels = eval_pred
@@ -62,33 +61,35 @@ def compute_metrics(eval_pred: tuple[NDArray, NDArray]) -> dict[str, float]:
     # Flatten ignoring -100
     true_predictions: list[int] = []
     true_labels: list[int] = []
-    for pred, lab in zip(predictions, labels):
-        for
-
-
-
+    for pred, lab in zip(predictions, labels, strict=True):
+        for _pred, _lab in zip(
+            pred,
+            lab,
+            strict=False,
+        ):
+            if _lab != -100:  # skip special tokens
+                true_predictions.append(_pred)
+                true_labels.append(_lab)
 
     results: dict[str, float] = accuracy_metric.compute(
-        references=true_labels,
-        predictions=true_predictions
+        references=true_labels, predictions=true_predictions
     )
     return {"accuracy": results["accuracy"]}
 
 
-if __name__ == '__main__':
-    with open("sentences.jsonl"
+if __name__ == "__main__":
+    with open("sentences.jsonl") as f:
         sentences: list[dict] = [json.loads(line) for line in f]
 
     dataset_dict: dict[str, list] = {
         "tokens": [sentence["words"] for sentence in sentences],
-        "labels": [sentence["types"] for sentence in sentences]
+        "labels": [sentence["types"] for sentence in sentences],
     }
 
     full_dataset: Dataset = Dataset.from_dict(dataset_dict)
 
     max_words: int = max([len(sentence["words"]) for sentence in sentences])
 
-
     labels: set[str] = set()
     for sentence in sentences:
         labels |= set(sentence["types"])
@@ -103,9 +104,10 @@ if __name__ == '__main__':
     print("Num train samples:", len(train_dataset))
     print("Num test samples: ", len(test_dataset))
 
-
     model_checkpoint: str = "distilbert-base-multilingual-cased"
-    tokenizer = AutoTokenizer.from_pretrained(
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_checkpoint, use_fast=True, add_prefix_space=True
+    )
 
     # Apply to train/test datasets
     train_dataset = train_dataset.map(tokenize_and_align_labels, batched=True)
@@ -123,7 +125,7 @@ if __name__ == '__main__':
         model_checkpoint,
         num_labels=len(labels),
         id2label=id_to_label,
-        label2id=label_to_id
+        label2id=label_to_id,
     )
 
     accuracy_metric = evaluate.load("accuracy")  # type: ignore[attr-defined]
@@ -139,7 +141,7 @@ if __name__ == '__main__':
         weight_decay=0.01,
         logging_dir="./logs",
         logging_steps=10,
-        report_to="none"  # Set to "tensorboard" if you want logs
+        report_to="none",  # Set to "tensorboard" if you want logs
     )
 
     trainer: Trainer = Trainer(
@@ -148,7 +150,7 @@ if __name__ == '__main__':
         train_dataset=train_dataset,
         eval_dataset=test_dataset,
         processing_class=tokenizer,
-        compute_metrics=compute_metrics
+        compute_metrics=compute_metrics,
     )
 
     trainer.train()
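Note: the training file `sentences.jsonl` carries parallel `words` and `types` lists per line; the label set is built at runtime via `labels |= set(sentence["types"])`. A hypothetical record:

{"words": ["Mary", "likes", "astronomy", "."], "types": ["C", "P", "C", "X"]}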
--- hyperbase_parser_ab-0.1.0/src/hyperbase_parser_ab/alpha.py
+++ hyperbase_parser_ab-0.2.0/src/hyperbase_parser_ab/alpha.py
@@ -8,20 +8,22 @@ from spacy.tokens import Span
 from hyperbase_parser_ab.atomizer import Atomizer
 
 
-class Alpha(object):
-    def __init__(
+class Alpha:
+    def __init__(
+        self, cases_str: str | None = None, use_atomizer: bool = False
+    ) -> None:
         if use_atomizer:
             self.atomizer: Atomizer | None = Atomizer()
         elif cases_str:
             self.atomizer = None
 
-
+            x: list[tuple[str, str, str, str, str]] = []
             y: list[list[str]] = []
 
-            for line in cases_str.strip().split(
+            for line in cases_str.strip().split("\n"):
                 sline: str = line.strip()
                 if len(sline) > 0:
-                    row: list[str] = sline.strip().split(
+                    row: list[str] = sline.strip().split("\t")
                     true_value: str = row[0]
                     tag: str = row[3]
                     dep: str = row[4]
@@ -30,40 +32,50 @@ class Alpha(object):
                     pos_after: str = row[19]
 
                     y.append([true_value])
-
+                    x.append((tag, dep, hpos, hdep, pos_after))
 
             if len(y) > 0:
                 self.empty: bool = False
 
-                self.encX: OneHotEncoder = OneHotEncoder(
-
-
+                self.encX: OneHotEncoder = OneHotEncoder(
+                    handle_unknown="ignore", sparse_output=False
+                )
+                self.encX.fit(np.array(x))
+                self.ency: OneHotEncoder = OneHotEncoder(
+                    handle_unknown="ignore", sparse_output=False
+                )
                 self.ency.fit(np.array(y))
 
-
+                x_: NDArray | spmatrix = self.encX.transform(np.array(x))
                 y_: NDArray | spmatrix = self.ency.transform(np.array(y))
 
-                self.clf: RandomForestClassifier = RandomForestClassifier(
-
+                self.clf: RandomForestClassifier = RandomForestClassifier(
+                    random_state=777
+                )
+                self.clf.fit(x_, y_)
             else:
                 self.empty = True
 
-    def predict(
+    def predict(
+        self, sentence: Span, features: list[tuple[str, str, str, str, str]]
+    ) -> tuple[str, ...] | list[str]:
         if self.atomizer:
             preds: list[tuple[str, str]] = self.atomizer.atomize(
-                sentence=str(sentence),
-
+                sentence=str(sentence), tokens=[str(token) for token in sentence]
+            )
             atom_types: list[str] = [pred[1] for pred in preds]
 
             # force known cases
             for i in range(len(atom_types)):
-                if sentence[i].pos_ ==
-                    atom_types[i] =
+                if sentence[i].pos_ == "VERB":
+                    atom_types[i] = "P"
             return atom_types
         else:
             # an empty classifier always predicts 'C'
             if self.empty:
-                return tuple(
+                return tuple("C" for _ in range(len(features)))
            _features: NDArray | spmatrix = self.encX.transform(np.array(features))
-            preds_arr: NDArray | spmatrix = self.ency.inverse_transform(
-
+            preds_arr: NDArray | spmatrix = self.ency.inverse_transform(
+                self.clf.predict(_features)
+            )
+            return tuple(pred[0] if pred else "C" for pred in preds_arr)
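A rough usage sketch for the refactored `Alpha` class; the file name below is hypothetical, and the TSV must follow the column layout produced by `scripts/generate_alpha_training_data.py` (type in column 0, tag in 3, dep in 4, pos-after in 19):

from hyperbase_parser_ab.alpha import Alpha

# Classical path: fit OneHotEncoder + RandomForest on generated cases.
with open("alpha_cases.tsv") as f:
    alpha = Alpha(cases_str=f.read())

# Neural path: delegate to the transformer-based Atomizer instead;
# tokens spaCy tags as VERB are then forced to type "P" (see above).
alpha_nn = Alpha(use_atomizer=True)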
--- hyperbase_parser_ab-0.1.0/src/hyperbase_parser_ab/atomizer.py
+++ hyperbase_parser_ab-0.2.0/src/hyperbase_parser_ab/atomizer.py
@@ -1,8 +1,12 @@
 from collections import Counter
 
 import torch
-from transformers import
-
+from transformers import (
+    AutoModelForTokenClassification,
+    AutoTokenizer,
+    PreTrainedModel,
+    PreTrainedTokenizerBase,
+)
 
 HF_REPO: str = "hyperquest/atom-classifier"
 
@@ -11,21 +15,21 @@ class Atomizer:
     def __init__(self, model_path: str | None = None) -> None:
         model_id: str = model_path or HF_REPO
         self.model_path: str = model_id
-        self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
-
+        self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
+            model_id, use_fast=True
+        )
+        self.model: PreTrainedModel = AutoModelForTokenClassification.from_pretrained(
+            model_id
+        )
         assert self.model.config.id2label
         self.id2label: dict[int, str] = self.model.config.id2label
 
-    def atomize(
-
-
-    ) -> list[tuple[str, str]]:
+    def atomize(
+        self, sentence: str, tokens: list[str] | None = None
+    ) -> list[tuple[str, str]]:
         # Tokenize the raw sentence and request offsets
         encoded = self.tokenizer(
-            sentence,
-            return_tensors="pt",
-            truncation=True,
-            return_offsets_mapping=True
+            sentence, return_tensors="pt", truncation=True, return_offsets_mapping=True
         )
 
         offset_mapping = encoded.pop("offset_mapping")  # remove so model doesn't see it
@@ -39,7 +43,9 @@ class Atomizer:
 
         if tokens is not None:
             # Map provided tokens to model predictions based on character offsets
-            return self._map_tokens_to_predictions(
+            return self._map_tokens_to_predictions(
+                sentence, tokens, word_ids, pred_ids, offset_mapping
+            )
 
         predicted_labels: list[tuple[str, str]] = []
         current_word_id: int | None = None
@@ -79,13 +85,14 @@ class Atomizer:
 
         return predicted_labels
 
-    def _map_tokens_to_predictions(
-
-
-
-
-
-
+    def _map_tokens_to_predictions(
+        self,
+        sentence: str,
+        tokens: list[str],
+        word_ids: list[int | None],
+        pred_ids: list[int],
+        offset_mapping: list[list[int]],
+    ) -> list[tuple[str, str]]:
         """
         Maps provided tokens to model predictions by finding character offsets
         and assigning the most appropriate label based on overlapping model tokens.
@@ -105,10 +112,10 @@ class Atomizer:
 
         # For each provided token, collect overlapping model predictions
         result: list[tuple[str, str]] = []
-        for token, positions in zip(tokens, token_positions):
+        for token, positions in zip(tokens, token_positions, strict=True):
             if positions is None:
                 # Token not found in sentence - assign default label
-                result.append((token,
+                result.append((token, "C"))
                 continue
 
             token_start: int
@@ -133,10 +140,12 @@ class Atomizer:
             # Assign the most common label, or first label if tie
             if overlapping_labels:
                 # Use most common label
-                most_common_label: str = Counter(overlapping_labels).most_common(1)[0][
+                most_common_label: str = Counter(overlapping_labels).most_common(1)[0][
+                    0
+                ]
                 result.append((token, most_common_label))
             else:
                 # No overlap found - use default
-                result.append((token,
+                result.append((token, "C"))
 
         return result
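A minimal usage sketch for `Atomizer` (the sentence and predicted types are hypothetical; weights come from the `hyperquest/atom-classifier` repo unless `model_path` is given):

from hyperbase_parser_ab.atomizer import Atomizer

atomizer = Atomizer()

# Without tokens, labels follow the model's own word segmentation:
atomizer.atomize("Mary likes astronomy.")

# With pre-tokenized input, predictions are remapped onto the given tokens
# via character offsets; unmatched tokens default to "C":
atomizer.atomize("Mary likes astronomy.", tokens=["Mary", "likes", "astronomy", "."])
# hypothetical output: [("Mary", "C"), ("likes", "P"), ("astronomy", "C"), (".", "X")]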
--- /dev/null
+++ hyperbase_parser_ab-0.2.0/src/hyperbase_parser_ab/lang_models.py
@@ -0,0 +1,9 @@
+SPACY_MODELS: dict[str, list[str]] = {
+    "de": ["de_dep_news_trf", "de_core_news_lg", "de_core_news_md", "de_core_news_sm"],
+    "en": ["en_core_web_trf", "en_core_web_lg", "en_core_web_md", "en_core_web_sm"],
+    "es": ["es_dep_news_trf", "es_core_news_lg", "es_core_news_md", "es_core_news_sm"],
+    "fr": ["fr_dep_news_trf", "fr_core_news_lg", "fr_core_news_md", "fr_core_news_sm"],
+    "it": ["it_core_news_lg", "it_core_news_md", "it_core_news_sm"],
+    "pt": ["pt_core_news_lg", "pt_core_news_md", "pt_core_news_sm"],
+    "zh": ["zh_core_news_trf", "zh_core_news_lg", "zh_core_news_md", "zh_core_news_sm"],
+}
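The per-language lists are ordered best-first (transformer, then lg/md/sm), which suggests fallback loading; a hypothetical sketch of how a caller might consume the table (`load_best_model` is illustrative, not an API of this package):

import spacy

from hyperbase_parser_ab.lang_models import SPACY_MODELS


def load_best_model(lang: str) -> spacy.language.Language:
    """Try each candidate model for the language, best first."""
    for name in SPACY_MODELS[lang]:
        try:
            return spacy.load(name)
        except OSError:  # model not installed
            continue
    raise RuntimeError(f"no spaCy model installed for {lang!r}")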