renard-pipeline 0.4.2__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of renard-pipeline might be problematic. Click here for more details.
- renard/graph_utils.py +11 -4
- renard/ner_utils.py +24 -14
- renard/pipeline/character_unification.py +62 -19
- renard/pipeline/characters_extraction.py +3 -1
- renard/pipeline/core.py +141 -26
- renard/pipeline/corefs/corefs.py +32 -33
- renard/pipeline/graph_extraction.py +281 -192
- renard/pipeline/ner/__init__.py +1 -0
- renard/pipeline/{ner.py → ner/ner.py} +47 -76
- renard/pipeline/ner/retrieval.py +375 -0
- renard/pipeline/progress.py +32 -1
- renard/pipeline/speaker_attribution.py +2 -3
- renard/pipeline/tokenization.py +59 -30
- renard/plot_utils.py +48 -28
- renard/resources/determiners/__init__.py +1 -0
- renard/resources/determiners/determiners.py +41 -0
- renard/resources/hypocorisms/hypocorisms.py +3 -2
- renard/utils.py +57 -1
- {renard_pipeline-0.4.2.dist-info → renard_pipeline-0.6.0.dist-info}/METADATA +45 -20
- renard_pipeline-0.6.0.dist-info/RECORD +39 -0
- renard_pipeline-0.4.2.dist-info/RECORD +0 -35
- {renard_pipeline-0.4.2.dist-info → renard_pipeline-0.6.0.dist-info}/LICENSE +0 -0
- {renard_pipeline-0.4.2.dist-info → renard_pipeline-0.6.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from renard.pipeline.ner.ner import *
|
|
@@ -1,21 +1,32 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
import
|
|
3
|
-
|
|
2
|
+
from typing import (
|
|
3
|
+
TYPE_CHECKING,
|
|
4
|
+
List,
|
|
5
|
+
Dict,
|
|
6
|
+
Any,
|
|
7
|
+
Set,
|
|
8
|
+
Tuple,
|
|
9
|
+
Optional,
|
|
10
|
+
Union,
|
|
11
|
+
Literal,
|
|
12
|
+
)
|
|
4
13
|
from dataclasses import dataclass
|
|
5
14
|
import torch
|
|
6
|
-
from seqeval.metrics import precision_score, recall_score, f1_score
|
|
7
15
|
from renard.nltk_utils import nltk_fix_bio_tags
|
|
8
16
|
from renard.ner_utils import (
|
|
9
17
|
DataCollatorForTokenClassificationWithBatchEncoding,
|
|
10
18
|
NERDataset,
|
|
11
19
|
)
|
|
12
20
|
from renard.pipeline.core import PipelineStep, Mention
|
|
13
|
-
from renard.pipeline.progress import ProgressReporter
|
|
14
21
|
from renard.ner_utils import ner_entities
|
|
15
22
|
|
|
16
23
|
if TYPE_CHECKING:
|
|
17
24
|
from transformers.tokenization_utils_base import BatchEncoding
|
|
18
|
-
from transformers import
|
|
25
|
+
from transformers import (
|
|
26
|
+
PreTrainedModel,
|
|
27
|
+
PreTrainedTokenizerFast,
|
|
28
|
+
)
|
|
29
|
+
from renard.pipeline.ner.retrieval import NERContextRetriever
|
|
19
30
|
|
|
20
31
|
|
|
21
32
|
@dataclass
|
|
@@ -27,7 +38,7 @@ class NEREntity(Mention):
|
|
|
27
38
|
"""
|
|
28
39
|
.. note::
|
|
29
40
|
|
|
30
|
-
This method is
|
|
41
|
+
This method is implemented here to avoid type issues. Since
|
|
31
42
|
:meth:`.Mention.shifted` cannot be annotated as returning
|
|
32
43
|
``Self``, this method annotate the correct return type when
|
|
33
44
|
using :meth:`.NEREntity.shifted`.
|
|
@@ -41,18 +52,21 @@ class NEREntity(Mention):
|
|
|
41
52
|
def score_ner(
|
|
42
53
|
pred_bio_tags: List[str], ref_bio_tags: List[str]
|
|
43
54
|
) -> Tuple[float, float, float]:
|
|
44
|
-
"""Score NER as in CoNLL-2003 shared task using ``seqeval``
|
|
55
|
+
"""Score NER as in CoNLL-2003 shared task using the ``seqeval``
|
|
56
|
+
library, if installed.
|
|
45
57
|
|
|
46
58
|
Precision is the percentage of named entities in ``ref_bio_tags``
|
|
47
|
-
that are correct.
|
|
48
|
-
pred_bio_tags that are in ref_bio_tags.
|
|
49
|
-
both.
|
|
59
|
+
that are correct. Recall is the percentage of named entities in
|
|
60
|
+
pred_bio_tags that are in ref_bio_tags. F1 is the harmonic mean
|
|
61
|
+
of both.
|
|
50
62
|
|
|
51
63
|
:param pred_bio_tags:
|
|
52
64
|
:param ref_bio_tags:
|
|
53
|
-
:return: ``(precision, recall, F1 score)``
|
|
54
65
|
|
|
66
|
+
:return: ``(precision, recall, F1 score)``
|
|
55
67
|
"""
|
|
68
|
+
from seqeval.metrics import precision_score, recall_score, f1_score
|
|
69
|
+
|
|
56
70
|
assert len(pred_bio_tags) == len(ref_bio_tags)
|
|
57
71
|
return (
|
|
58
72
|
precision_score([ref_bio_tags], [pred_bio_tags]),
|
|
@@ -70,12 +84,19 @@ class NLTKNamedEntityRecognizer(PipelineStep):
|
|
|
70
84
|
"""
|
|
71
85
|
import nltk
|
|
72
86
|
|
|
73
|
-
nltk.download("averaged_perceptron_tagger", quiet=True)
|
|
87
|
+
nltk.download(f"averaged_perceptron_tagger", quiet=True)
|
|
74
88
|
nltk.download("maxent_ne_chunker", quiet=True)
|
|
89
|
+
nltk.download("maxent_ne_chunker_tab", quiet=True)
|
|
75
90
|
nltk.download("words", quiet=True)
|
|
76
91
|
|
|
77
92
|
super().__init__()
|
|
78
93
|
|
|
94
|
+
def _pipeline_init_(self, lang: str, **kwargs):
    """Download the NLTK POS tagger resource for ``lang``, then run
    the parent pipeline initialisation.

    :param lang: language identifier, used as the suffix of the NLTK
        ``averaged_perceptron_tagger_`` resource name
    :param kwargs: forwarded untouched to the parent ``_pipeline_init_``
    """
    import nltk

    tagger_resource = f"averaged_perceptron_tagger_{lang}"
    nltk.download(tagger_resource, quiet=True)

    super()._pipeline_init_(lang, **kwargs)
|
|
99
|
+
|
|
79
100
|
def __call__(self, tokens: List[str], **kwargs) -> Dict[str, Any]:
|
|
80
101
|
"""
|
|
81
102
|
:param text:
|
|
@@ -101,64 +122,6 @@ class NLTKNamedEntityRecognizer(PipelineStep):
|
|
|
101
122
|
return {"entities"}
|
|
102
123
|
|
|
103
124
|
|
|
104
|
-
class NERContextRetriever:
|
|
105
|
-
def __call__(self, dataset: NERDataset) -> NERDataset:
|
|
106
|
-
raise NotImplementedError
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
class NERSamenounContextRetriever(NERContextRetriever):
|
|
110
|
-
"""
|
|
111
|
-
Retrieve relevant context using the samenoun strategy as in
|
|
112
|
-
Amalvy et al. 2023.
|
|
113
|
-
"""
|
|
114
|
-
|
|
115
|
-
def __init__(self, k: int) -> None:
|
|
116
|
-
"""
|
|
117
|
-
:param k: the number of sentences to retrieve
|
|
118
|
-
"""
|
|
119
|
-
self.k = k
|
|
120
|
-
|
|
121
|
-
def __call__(self, dataset: NERDataset) -> NERDataset:
|
|
122
|
-
import nltk
|
|
123
|
-
|
|
124
|
-
# NOTE: POS tagging is not incorporated in the pipeline yet,
|
|
125
|
-
# so we manually compute it here.
|
|
126
|
-
elements_names = [
|
|
127
|
-
{t[0] for t in nltk.pos_tag(element) if t[1].startswith("NN")}
|
|
128
|
-
for element in dataset.elements
|
|
129
|
-
]
|
|
130
|
-
|
|
131
|
-
elements_with_context = []
|
|
132
|
-
|
|
133
|
-
for elt_i, elt in enumerate(dataset.elements):
|
|
134
|
-
retrieved_elts = [
|
|
135
|
-
other_elt
|
|
136
|
-
for other_elt_i, other_elt in enumerate(dataset.elements)
|
|
137
|
-
if not other_elt_i == elt_i
|
|
138
|
-
and len(elements_names[elt_i].intersection(elements_names[other_elt_i]))
|
|
139
|
-
> 0
|
|
140
|
-
]
|
|
141
|
-
retrieved_elts = random.sample(
|
|
142
|
-
retrieved_elts, k=min(self.k, len(retrieved_elts))
|
|
143
|
-
)
|
|
144
|
-
elements_with_context.append(
|
|
145
|
-
(
|
|
146
|
-
elt,
|
|
147
|
-
[dataset.tokenizer.sep_token]
|
|
148
|
-
+ list(itertools.chain.from_iterable(retrieved_elts)),
|
|
149
|
-
)
|
|
150
|
-
)
|
|
151
|
-
|
|
152
|
-
return NERDataset(
|
|
153
|
-
[element + context for element, context in elements_with_context],
|
|
154
|
-
dataset.tokenizer,
|
|
155
|
-
[
|
|
156
|
-
[0] * len(element) + [1] * len(context)
|
|
157
|
-
for element, context in elements_with_context
|
|
158
|
-
],
|
|
159
|
-
)
|
|
160
|
-
|
|
161
|
-
|
|
162
125
|
class BertNamedEntityRecognizer(PipelineStep):
|
|
163
126
|
"""An entity recognizer based on BERT"""
|
|
164
127
|
|
|
@@ -214,10 +177,10 @@ class BertNamedEntityRecognizer(PipelineStep):
|
|
|
214
177
|
|
|
215
178
|
super().__init__()
|
|
216
179
|
|
|
217
|
-
def _pipeline_init_(self, lang: str,
|
|
180
|
+
def _pipeline_init_(self, lang: str, **kwargs):
|
|
218
181
|
from transformers import AutoModelForTokenClassification, AutoTokenizer # type: ignore
|
|
219
182
|
|
|
220
|
-
super()._pipeline_init_(lang,
|
|
183
|
+
super()._pipeline_init_(lang, **kwargs)
|
|
221
184
|
|
|
222
185
|
# init model if needed (this happens if the user did not pass
|
|
223
186
|
# the instance of a model)
|
|
@@ -306,7 +269,7 @@ class BertNamedEntityRecognizer(PipelineStep):
|
|
|
306
269
|
batch_i: int,
|
|
307
270
|
wp_labels: List[str],
|
|
308
271
|
tokens: List[str],
|
|
309
|
-
|
|
272
|
+
ctxmask: torch.Tensor,
|
|
310
273
|
) -> List[str]:
|
|
311
274
|
"""Align labels to tokens rather than wordpiece tokens.
|
|
312
275
|
|
|
@@ -317,13 +280,21 @@ class BertNamedEntityRecognizer(PipelineStep):
|
|
|
317
280
|
"""
|
|
318
281
|
batch_labels = ["O"] * len(tokens)
|
|
319
282
|
|
|
283
|
+
try:
|
|
284
|
+
inference_start = ctxmask[batch_i].tolist().index(1)
|
|
285
|
+
except ValueError:
|
|
286
|
+
inference_start = 0
|
|
287
|
+
|
|
320
288
|
for wplabel_j, wp_label in enumerate(wp_labels):
|
|
321
|
-
|
|
322
|
-
continue
|
|
289
|
+
|
|
323
290
|
token_i = batchs.token_to_word(batch_i, wplabel_j)
|
|
324
291
|
if token_i is None:
|
|
325
292
|
continue
|
|
326
|
-
|
|
293
|
+
|
|
294
|
+
if ctxmask[batch_i][token_i] == 0:
|
|
295
|
+
continue
|
|
296
|
+
|
|
297
|
+
batch_labels[token_i - inference_start] = wp_label
|
|
327
298
|
|
|
328
299
|
return batch_labels
|
|
329
300
|
|
|
@@ -0,0 +1,375 @@
|
|
|
1
|
+
from collections.abc import Set
|
|
2
|
+
import sys
|
|
3
|
+
from typing import Union, List, cast, Literal, Optional
|
|
4
|
+
import random
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from more_itertools import flatten
|
|
7
|
+
from renard.ner_utils import NERDataset
|
|
8
|
+
import nltk
|
|
9
|
+
from rank_bm25 import BM25Okapi
|
|
10
|
+
from transformers import (
|
|
11
|
+
BertForSequenceClassification,
|
|
12
|
+
BertTokenizerFast,
|
|
13
|
+
DataCollatorWithPadding,
|
|
14
|
+
)
|
|
15
|
+
from transformers.tokenization_utils_base import BatchEncoding
|
|
16
|
+
import torch
|
|
17
|
+
from torch.utils.data import Dataset, DataLoader
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class NERContextRetrievalMatch:
    """A context element selected by a retriever for a target element."""

    #: tokens of the retrieved element
    element: List[str]
    #: index of the retrieved element in the list of all elements
    element_i: int
    #: position of the retrieved element relative to the target one
    side: Literal["left", "right"]
    #: retrieval score, or ``None`` when the retriever does not score
    score: Optional[float]

    def __hash__(self) -> int:
        # Every field takes part in the hash, so that matches can be
        # stored in sets (see the ensemble retriever).
        return hash((*self.element, self.element_i, self.side, self.score))
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class NERContextRetriever:
    """Base class of context retrievers.

    A retriever augments each element of a :class:`NERDataset` with
    up to ``k`` other elements of the same dataset, placed on the left
    or on the right of the original element.  Subclasses implement
    :meth:`retrieve`, and may override
    :meth:`compute_global_features` to precompute data shared by all
    :meth:`retrieve` calls.
    """

    def __init__(self, k: int) -> None:
        """
        :param k: the max number of elements to retrieve
        """
        self.k = k

    def compute_global_features(self, elements: List[List[str]]) -> dict:
        """Compute features shared by all retrievals.

        :param elements: all the elements of the dataset

        :return: a dict, passed as keyword arguments to
            :meth:`retrieve`.  Empty by default.
        """
        return {}

    def retrieve(
        self, element_i: int, elements: List[List[str]], **kwargs
    ) -> List[NERContextRetrievalMatch]:
        """Retrieve context for ``elements[element_i]``.

        :param element_i: index of the element to retrieve context for
        :param elements: all the elements of the dataset
        :param kwargs: the output of :meth:`compute_global_features`

        :return: at most ``k`` matches
        """
        raise NotImplementedError

    def __call__(self, dataset: NERDataset) -> NERDataset:
        """Build a new dataset where each element is surrounded by its
        retrieved context.

        :param dataset: dataset to augment

        :return: a new :class:`NERDataset` with the same tokenizer.
            Its third argument is a per-token mask, with ``1`` for
            context tokens and ``0`` for the tokens of the original
            element.
        """
        shared_features = self.compute_global_features(dataset.elements)

        augmented_elements = []
        context_masks = []

        for index, element in enumerate(dataset.elements):
            matchs = self.retrieve(index, dataset.elements, **shared_features)
            assert len(matchs) <= self.k

            # context is laid out in document order on each side; the
            # sort is stable, so filtering by side afterwards keeps
            # that order
            ordered = sorted(matchs, key=lambda m: m.element_i)
            left_tokens = [
                token for m in ordered if m.side == "left" for token in m.element
            ]
            right_tokens = [
                token for m in ordered if m.side == "right" for token in m.element
            ]

            augmented_elements.append(left_tokens + element + right_tokens)
            context_masks.append(
                [1] * len(left_tokens) + [0] * len(element) + [1] * len(right_tokens)
            )

        return NERDataset(augmented_elements, dataset.tokenizer, context_masks)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class NERSamenounContextRetriever(NERContextRetriever):
    """
    Retrieve relevant context using the samenoun strategy as in
    Amalvy et al. 2023.

    An element is a candidate context for another one when both share
    at least one token tagged ``NN`` by the NLTK POS tagger.
    """

    def __init__(self, k: int) -> None:
        """
        :param k: the max number of sentences to retrieve
        """
        super().__init__(k)

    def compute_global_features(self, elements: List[List[str]]) -> dict:
        """POS-tag every element once.

        :param elements: all the elements of the dataset

        :return: ``{"NNs": [...]}``, with one set of ``NN``-tagged
            tokens per element
        """
        nouns_per_element = []
        for element in elements:
            tagged_tokens = nltk.pos_tag(element)
            nouns_per_element.append(
                {token for token, tag in tagged_tokens if tag == "NN"}
            )
        return {"NNs": nouns_per_element}

    def retrieve(
        self, element_i: int, elements: List[List[str]], NNs: List[Set[str]], **kwargs
    ) -> List[NERContextRetrievalMatch]:
        """Randomly pick at most ``k`` elements sharing a noun with
        ``elements[element_i]``.

        :param element_i: index of the element to retrieve context for
        :param elements: all the elements of the dataset
        :param NNs: the nouns of each element, as computed by
            :meth:`compute_global_features`

        :return: unscored matches (``score`` is ``None``)
        """
        query_nouns = NNs[element_i]

        candidates = []
        for other_i, other_element in enumerate(elements):
            if other_i == element_i:
                continue
            if not query_nouns.intersection(NNs[other_i]):  # type: ignore
                continue
            side = "left" if other_i < element_i else "right"
            candidates.append(
                NERContextRetrievalMatch(other_element, other_i, side, None)
            )

        return random.sample(candidates, k=min(self.k, len(candidates)))
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class NERNeighborsContextRetriever(NERContextRetriever):
    """A context retriever that chooses nearby elements.

    For each element, up to ``k // 2`` preceding elements and up to
    ``k // 2`` following elements are retrieved.
    """

    def __init__(self, k: int):
        """
        :param k: the max number of elements to retrieve.  Must be
            even, since it is split equally between both sides.
        """
        assert k % 2 == 0
        super().__init__(k)

    def retrieve(
        self, element_i: int, elements: List[List[str]], **kwargs
    ) -> List[NERContextRetrievalMatch]:
        """Retrieve the direct neighbors of ``elements[element_i]``.

        :param element_i: index of the element to retrieve context for
        :param elements: all the elements of the dataset

        :return: unscored matches (``score`` is ``None``): left
            neighbors first, then right neighbors, each in document
            order.  Fewer than ``k`` matches are returned near the
            start or the end of ``elements``.
        """
        half_window = self.k // 2

        # Clamp the start of the left window at 0.  A negative slice
        # start would wrap around to the end of the list, dropping the
        # left context that is actually available for early elements
        # (or selecting unrelated elements with a negative index).
        left_start = max(0, element_i - half_window)
        lctx = [
            NERContextRetrievalMatch(elements[i], i, "left", None)
            for i in range(left_start, element_i)
        ]

        # Slicing past the end of the list is safe: it simply yields
        # fewer elements.
        right_start = element_i + 1
        rctx = [
            NERContextRetrievalMatch(elt, i, "right", None)
            for i, elt in enumerate(
                elements[right_start : right_start + half_window], start=right_start
            )
        ]

        return lctx + rctx
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class NERBM25ContextRetriever(NERContextRetriever):
    """A context retriever that selects elements according to the BM25 ranking formula."""

    def __init__(self, k: int) -> None:
        """
        :param k: the max number of elements to retrieve
        """
        super().__init__(k)

    def compute_global_features(self, elements: List[List[str]]) -> dict:
        """Build a single BM25 model over all elements.

        :param elements: all the elements of the dataset

        :return: ``{"bm25_model": model}``
        """
        return {"bm25_model": BM25Okapi(elements)}

    def retrieve(
        self, element_i: int, elements: List[List[str]], bm25_model: BM25Okapi, **kwargs
    ) -> List[NERContextRetrievalMatch]:
        """Retrieve the elements with the highest BM25 score, using
        ``elements[element_i]`` as the query.

        :param element_i: index of the element to retrieve context for
        :param elements: all the elements of the dataset
        :param bm25_model: the model built by
            :meth:`compute_global_features`

        :return: at most ``k`` matches, by decreasing BM25 score.  The
            query element itself is never returned.
        """
        query = elements[element_i]
        sent_scores = bm25_model.get_scores(query)
        sent_scores[element_i] = float("-Inf")  # don't retrieve self

        # Once the query is excluded, only len - 1 candidates remain.
        # Asking top-k for more than that would return the query
        # itself (with a -Inf score) as its own context.
        candidates_nb = min(self.k, len(sent_scores) - 1)
        if candidates_nb <= 0:
            return []

        topk_values, topk_indexs = torch.topk(
            torch.tensor(sent_scores), k=candidates_nb, dim=0
        )
        return [
            NERContextRetrievalMatch(
                elements[index], index, "left" if index < element_i else "right", value
            )
            for value, index in zip(topk_values.tolist(), topk_indexs.tolist())
            # defensive: never hand back the query element
            if index != element_i
        ]
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@dataclass(frozen=True)
class NERNeuralContextRetrievalExample:
    """An (element, candidate context) pair, as scored by the neural
    context retriever.
    """

    #: tokens of the text on which NER is performed
    element: List[str]

    #: tokens of the candidate context that may assist prediction
    context: List[str]

    #: side of ``element`` the context comes from
    context_side: Literal["left", "right"]
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
class NERNeuralContextRetrievalDataset(Dataset):
    """A torch dataset of context retrieval examples, tokenized on
    access.
    """

    def __init__(
        self,
        examples: List[NERNeuralContextRetrievalExample],
        tokenizer: BertTokenizerFast,
    ) -> None:
        """
        :param examples: the examples of the dataset
        :param tokenizer: tokenizer used to encode examples
        """
        self.examples = examples
        self.tokenizer: BertTokenizerFast = tokenizer

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, index: int) -> BatchEncoding:
        """Encode the example at ``index``.

        The encoded sequence is the context, a ``[SEP]`` token, then
        the element, regardless of the side the context comes from.

        :param index: index of the example to retrieve

        :return: the tokenizer output for that example, truncated to
            512 wordpieces
        """
        example = self.examples[index]
        words = [*example.context, "[SEP]", *example.element]
        return self.tokenizer(
            words,
            is_split_into_words=True,
            truncation=True,
            max_length=512,
        )
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
class NERNeuralContextRetriever(NERContextRetriever):
    """
    A neural context retriever as in Amalvy et al. 2024

    Candidate contexts are first retrieved by a heuristic retriever,
    then scored by a BERT sequence classifier.  The best scored
    candidates are kept.
    """

    def __init__(
        self,
        heuristic_context_selector: NERContextRetriever,
        pretrained_model: Union[
            str, BertForSequenceClassification
        ] = "compnet-renard/bert-base-cased-NER-reranker",
        k: int = 3,
        batch_size: int = 1,
        threshold: float = 0.0,
        device_str: Literal["cuda", "cpu", "auto"] = "auto",
    ) -> None:
        """
        :param heuristic_context_selector: the retriever used to
            select the candidate contexts to score
        :param pretrained_model: either a pretrained model name, used
            to load a :class:`transformers.BertForSequenceClassification`
            and its tokenizer, or an already instantiated model (in
            which case the ``bert-base-cased`` tokenizer is used)
        :param k: max number of sents to retrieve
        :param batch_size: batch size used at inference
        :param threshold: candidates must have a score strictly
            greater than this value to be retrieved
        :param device_str: device used for inference.  ``"auto"``
            selects ``"cuda"`` when available, ``"cpu"`` otherwise.
        """
        from transformers import BertForSequenceClassification, BertTokenizerFast

        if isinstance(pretrained_model, str):
            self.ctx_classifier = BertForSequenceClassification.from_pretrained(
                pretrained_model
            )  # type: ignore
        else:
            self.ctx_classifier = pretrained_model
        self.ctx_classifier = cast(BertForSequenceClassification, self.ctx_classifier)

        self.tokenizer = BertTokenizerFast.from_pretrained(
            pretrained_model if isinstance(pretrained_model, str) else "bert-base-cased"
        )

        self.heuristic_context_selector = heuristic_context_selector

        self.batch_size = batch_size
        self.threshold = threshold

        if device_str == "auto":
            device_str = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_str)

        super().__init__(k)

    def set_heuristic_k_(self, k: int):
        """Set the max number of candidates retrieved by the heuristic
        retriever.

        :param k: new value of ``k`` for the heuristic retriever
        """
        self.heuristic_context_selector.k = k

    def predict(self, examples: List[NERNeuralContextRetrievalExample]) -> torch.Tensor:
        """Score examples with the context classifier.

        :param examples: A list of
            :class:`NERNeuralContextRetrievalExample`
        :return: A tensor of shape ``(len(examples), 2)`` of class
            scores (softmax over the classifier logits).  When
            ``examples`` is empty, an empty tensor of shape ``(0,)``
            is returned instead.
        """
        dataset = NERNeuralContextRetrievalDataset(examples, self.tokenizer)

        self.ctx_classifier = self.ctx_classifier.to(self.device)

        data_collator = DataCollatorWithPadding(dataset.tokenizer)  # type: ignore
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False, collate_fn=data_collator)  # type: ignore

        # inference using self.ctx_classifier
        self.ctx_classifier = self.ctx_classifier.eval()
        with torch.no_grad():
            scores = torch.zeros((0,)).to(self.device)
            for X in dataloader:
                X = X.to(self.device)
                # out.logits is of shape (batch_size, 2)
                out = self.ctx_classifier(
                    X["input_ids"],
                    token_type_ids=X["token_type_ids"],
                    attention_mask=X["attention_mask"],
                )
                # (batch_size, 2)
                pred = torch.softmax(out.logits, dim=1)
                scores = torch.cat([scores, pred], dim=0)

        return scores

    def compute_global_features(self, elements: List[List[str]]) -> dict:
        """Run the heuristic retriever on every element.

        :param elements: all the elements of the dataset

        :return: ``{"heuristic_matchs": [...]}``, with one list of
            candidate matches per element
        """
        features = self.heuristic_context_selector.compute_global_features(elements)
        return {
            "heuristic_matchs": [
                self.heuristic_context_selector.retrieve(i, elements, **features)
                for i in range(len(elements))
            ]
        }

    def retrieve(
        self,
        element_i: int,
        elements: List[List[str]],
        heuristic_matchs: List[List[NERContextRetrievalMatch]],
        **kwargs,
    ) -> List[NERContextRetrievalMatch]:
        """Score the heuristic candidates of ``elements[element_i]``
        and keep the best ones.

        :param element_i: index of the element to retrieve context for
        :param elements: all the elements of the dataset
        :param heuristic_matchs: the candidates of each element, as
            computed by :meth:`compute_global_features`.  The
            ``score`` of the candidates of ``element_i`` is set in
            place.

        :return: at most ``k`` matches with a score strictly greater
            than ``threshold``, by decreasing score
        """
        # NOTE: the model is deliberately not moved to a device here.
        # :meth:`predict` moves it to ``self.device``, which honors the
        # ``device_str`` chosen by the user.  Forcing CUDA here
        # whenever it is available would ignore that choice.

        # no context retrieved by heuristic : nothing to do
        if len(heuristic_matchs) == 0:
            return []

        element = elements[element_i]
        matchs = heuristic_matchs[element_i]

        # no candidate for this element: skip inference altogether
        if len(matchs) == 0:
            return []

        # prepare datas for inference
        ctx_dataset = [
            NERNeuralContextRetrievalExample(element, m.element, m.side) for m in matchs
        ]

        # (len(dataset), 2)
        scores = self.predict(ctx_dataset)
        for i, m in enumerate(matchs):
            # column 1 is used as the score of the candidate
            m.score = float(scores[i, 1].item())

        return [
            m
            for m in sorted(matchs, key=lambda m: -m.score)[: self.k]  # type: ignore
            if m.score > self.threshold  # type: ignore
        ]
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
class NEREnsembleContextRetriever(NERContextRetriever):
    """Combine several context retriever"""

    def __init__(self, retrievers: List[NERContextRetriever], k: int) -> None:
        """
        :param retrievers: the retrievers to combine
        :param k: the max number of elements to retrieve
        """
        self.retrievers = retrievers
        super().__init__(k)

    def compute_global_features(self, elements: List[List[str]]) -> dict:
        """Merge the global features of all retrievers.

        When several retrievers produce a feature with the same name,
        a warning is printed on stderr and the feature of the last
        retriever wins.

        :param elements: all the elements of the dataset

        :return: the merged features
        """
        features = {}
        for retriever in self.retrievers:
            for name, value in retriever.compute_global_features(elements).items():
                if name in features:
                    print(
                        f"[warning] NEREnsembleContextRetriver: incompatible global feature '{name}' between multiple retrievers.",
                        file=sys.stderr,
                    )
                features[name] = value
        return features

    def retrieve(
        self, element_i: int, elements: List[List[str]], **kwargs
    ) -> List[NERContextRetrievalMatch]:
        """Retrieve context using all retrievers.

        :param element_i: index of the element to retrieve context for
        :param elements: all the elements of the dataset
        :param kwargs: the output of :meth:`compute_global_features`,
            passed as is to every retriever

        :return: at most ``k`` distinct matches.  When all matches are
            scored, the best scored ones are returned.  Otherwise,
            matches are picked at random.
        """
        all_matchs = set()
        for retriever in self.retrievers:
            all_matchs.update(retriever.retrieve(element_i, elements, **kwargs))

        if all(m.score is not None for m in all_matchs):
            return sorted(all_matchs, key=lambda m: -m.score)[: self.k]  # type: ignore

        # Sample *without* replacement, and never more than available:
        # ``random.choices`` would return duplicated matches (hence
        # duplicated context) when fewer than ``k`` matches exist.
        candidates = list(all_matchs)
        return random.sample(candidates, k=min(self.k, len(candidates)))
|
renard/pipeline/progress.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
1
2
|
from typing import Iterable, Literal, Optional, TypeVar, Generator
|
|
3
|
+
import sys
|
|
2
4
|
from tqdm import tqdm
|
|
3
5
|
|
|
4
6
|
|
|
@@ -20,6 +22,10 @@ class ProgressReporter:
|
|
|
20
22
|
"""Update reporter current message."""
|
|
21
23
|
pass
|
|
22
24
|
|
|
25
|
+
def get_subreporter(self) -> ProgressReporter:
    """Return the reporter in charge of reporting the progress of
    sub-tasks of this reporter.

    Must be implemented by subclasses.

    :raises NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError
|
|
28
|
+
|
|
23
29
|
|
|
24
30
|
class NoopProgressReporter(ProgressReporter):
|
|
25
31
|
def reset_(self):
|
|
@@ -28,6 +34,28 @@ class NoopProgressReporter(ProgressReporter):
|
|
|
28
34
|
def update_progress_(self, added_progress: int):
|
|
29
35
|
pass
|
|
30
36
|
|
|
37
|
+
def get_subreporter(self) -> ProgressReporter:
    """Return a new no-op reporter to use as subreporter."""
    noop_subreporter = NoopProgressReporter()
    return noop_subreporter
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class TQDMSubProgressReporter(ProgressReporter):
    """Report the progress of a sub-task in the postfix of the tqdm
    bar of a parent :class:`TQDMProgressReporter`.
    """

    def __init__(self, reporter: TQDMProgressReporter) -> None:
        """
        :param reporter: the parent reporter, whose tqdm bar is
            updated
        """
        super().__init__()
        self.reporter = reporter

    def start_(self, total: int):
        """Start reporting, with a progress counter reset to 0."""
        super().start_(total)
        self.progress = 0

    def _step_label(self) -> str:
        """Format current progress as ``(progress/total)``."""
        # NOTE(review): ``self.total`` is presumably set by
        # ``ProgressReporter.start_`` - confirm
        return f"({self.progress}/{self.total})"

    def update_progress_(self, added_progress: int):
        """Add ``added_progress`` to the counter, and display it."""
        self.progress += added_progress
        self.reporter.tqdm.set_postfix(step=self._step_label())

    def update_message_(self, message: str):
        """Display ``message`` next to the current progress."""
        self.reporter.tqdm.set_postfix(step=self._step_label(), message=message)
|
|
58
|
+
|
|
31
59
|
|
|
32
60
|
class TQDMProgressReporter(ProgressReporter):
|
|
33
61
|
def start_(self, total: int):
|
|
@@ -40,6 +68,9 @@ class TQDMProgressReporter(ProgressReporter):
|
|
|
40
68
|
def update_message_(self, message: str):
|
|
41
69
|
self.tqdm.set_description_str(message)
|
|
42
70
|
|
|
71
|
+
def get_subreporter(self) -> ProgressReporter:
    """Return a subreporter that writes to the tqdm bar of this
    reporter.
    """
    subreporter = TQDMSubProgressReporter(self)
    return subreporter
|
|
73
|
+
|
|
43
74
|
|
|
44
75
|
T = TypeVar("T")
|
|
45
76
|
|
|
@@ -62,5 +93,5 @@ def get_progress_reporter(name: Optional[Literal["tqdm"]]) -> ProgressReporter:
|
|
|
62
93
|
return NoopProgressReporter()
|
|
63
94
|
if name == "tqdm":
|
|
64
95
|
return TQDMProgressReporter()
|
|
65
|
-
print(f"[warning] unknown progress reporter: {name}")
|
|
96
|
+
print(f"[warning] unknown progress reporter: {name}", file=sys.stderr)
|
|
66
97
|
return NoopProgressReporter()
|
|
@@ -49,13 +49,12 @@ class BertSpeakerDetector(PipelineStep):
|
|
|
49
49
|
|
|
50
50
|
super().__init__()
|
|
51
51
|
|
|
52
|
-
def _pipeline_init_(self, lang: str,
|
|
52
|
+
def _pipeline_init_(self, lang: str, **kwargs):
|
|
53
53
|
from transformers import AutoTokenizer
|
|
54
54
|
|
|
55
|
-
super()._pipeline_init_(lang,
|
|
55
|
+
super()._pipeline_init_(lang, **kwargs)
|
|
56
56
|
|
|
57
57
|
if self.model is None:
|
|
58
|
-
|
|
59
58
|
# the user supplied a huggingface ID: load model from the HUB
|
|
60
59
|
if not self.huggingface_model_id is None:
|
|
61
60
|
self.model = SpeakerAttributionModel.from_pretrained(
|