renard-pipeline 0.5.0__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of renard-pipeline might be problematic; see the package's page on the registry for more details.
- renard/ner_utils.py +24 -18
- renard/pipeline/character_unification.py +48 -15
- renard/pipeline/core.py +20 -0
- renard/pipeline/corefs/corefs.py +2 -2
- renard/pipeline/ner/__init__.py +1 -0
- renard/pipeline/{ner.py → ner/ner.py} +45 -75
- renard/pipeline/ner/retrieval.py +375 -0
- renard/pipeline/tokenization.py +1 -0
- renard/plot_utils.py +10 -3
- renard/resources/determiners/__init__.py +1 -0
- renard/resources/determiners/determiners.py +41 -0
- {renard_pipeline-0.5.0.dist-info → renard_pipeline-0.6.1.dist-info}/METADATA +21 -20
- {renard_pipeline-0.5.0.dist-info → renard_pipeline-0.6.1.dist-info}/RECORD +15 -11
- {renard_pipeline-0.5.0.dist-info → renard_pipeline-0.6.1.dist-info}/LICENSE +0 -0
- {renard_pipeline-0.5.0.dist-info → renard_pipeline-0.6.1.dist-info}/WHEEL +0 -0
renard/ner_utils.py
CHANGED
|
@@ -74,7 +74,7 @@ class DataCollatorForTokenClassificationWithBatchEncoding:
|
|
|
74
74
|
class NERDataset(Dataset):
|
|
75
75
|
"""
|
|
76
76
|
:ivar _context_mask: for each element, a mask indicating which
|
|
77
|
-
tokens are part of the context (
|
|
77
|
+
tokens are part of the context (0 for context, 1 for text on
|
|
78
78
|
which to perform inference). The mask allows to discard
|
|
79
79
|
predictions made for context at inference time, even though
|
|
80
80
|
the context can still be passed as input to the model.
|
|
@@ -92,11 +92,11 @@ class NERDataset(Dataset):
|
|
|
92
92
|
assert all(
|
|
93
93
|
[len(cm) == len(elt) for elt, cm in zip(self.elements, context_mask)]
|
|
94
94
|
)
|
|
95
|
-
self._context_mask = context_mask or [[
|
|
95
|
+
self._context_mask = context_mask or [[1] * len(elt) for elt in self.elements]
|
|
96
96
|
|
|
97
97
|
self.tokenizer = tokenizer
|
|
98
98
|
|
|
99
|
-
def __getitem__(self, index:
|
|
99
|
+
def __getitem__(self, index: int) -> BatchEncoding:
|
|
100
100
|
element = self.elements[index]
|
|
101
101
|
|
|
102
102
|
batch = self.tokenizer(
|
|
@@ -104,19 +104,18 @@ class NERDataset(Dataset):
|
|
|
104
104
|
truncation=True,
|
|
105
105
|
max_length=512, # TODO
|
|
106
106
|
is_split_into_words=True,
|
|
107
|
+
return_length=True,
|
|
107
108
|
)
|
|
108
109
|
|
|
109
|
-
batch["
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
tokens_mask = [mask_value] * (w2t.end - w2t.start)
|
|
119
|
-
batch["context_mask"][w2t.start : w2t.end] = tokens_mask
|
|
110
|
+
length = batch["length"][0]
|
|
111
|
+
del batch["length"]
|
|
112
|
+
if self.tokenizer.truncation_side == "right":
|
|
113
|
+
batch["context_mask"] = self._context_mask[index][:length]
|
|
114
|
+
else:
|
|
115
|
+
assert self.tokenizer.truncation_side == "left"
|
|
116
|
+
batch["context_mask"] = self._context_mask[index][
|
|
117
|
+
len(batch["input_ids"]) - length :
|
|
118
|
+
]
|
|
120
119
|
|
|
121
120
|
return batch
|
|
122
121
|
|
|
@@ -185,6 +184,7 @@ def load_conll2002_bio(
|
|
|
185
184
|
path: str,
|
|
186
185
|
tag_conversion_map: Optional[Dict[str, str]] = None,
|
|
187
186
|
separator: str = "\t",
|
|
187
|
+
max_sent_len: Optional[int] = None,
|
|
188
188
|
**kwargs,
|
|
189
189
|
) -> Tuple[List[List[str]], List[str], List[NEREntity]]:
|
|
190
190
|
"""Load a file under CoNLL2022 BIO format. Sentences are expected
|
|
@@ -196,7 +196,9 @@ def load_conll2002_bio(
|
|
|
196
196
|
:param separator: separator between token and BIO tags
|
|
197
197
|
:param tag_conversion_map: conversion map for tags found in the
|
|
198
198
|
input file. Example : ``{'B': 'B-PER', 'I': 'I-PER'}``
|
|
199
|
-
:param
|
|
199
|
+
:param max_sent_len: if specified, maximum length, in tokens, of
|
|
200
|
+
sentences.
|
|
201
|
+
:param kwargs: additional kwargs for :func:`open` (such as
|
|
200
202
|
``encoding`` or ``newline``).
|
|
201
203
|
|
|
202
204
|
:return: ``(sentences, tokens, entities)``
|
|
@@ -211,7 +213,9 @@ def load_conll2002_bio(
|
|
|
211
213
|
tags = []
|
|
212
214
|
for line in raw_data.split("\n"):
|
|
213
215
|
line = line.strip("\n")
|
|
214
|
-
if re.fullmatch(r"\s*", line)
|
|
216
|
+
if re.fullmatch(r"\s*", line) or (
|
|
217
|
+
not max_sent_len is None and len(sent_tokens) >= max_sent_len
|
|
218
|
+
):
|
|
215
219
|
if len(sent_tokens) == 0:
|
|
216
220
|
continue
|
|
217
221
|
sents.append(sent_tokens)
|
|
@@ -231,6 +235,7 @@ def hgdataset_from_conll2002(
|
|
|
231
235
|
path: str,
|
|
232
236
|
tag_conversion_map: Optional[Dict[str, str]] = None,
|
|
233
237
|
separator: str = "\t",
|
|
238
|
+
max_sent_len: Optional[int] = None,
|
|
234
239
|
**kwargs,
|
|
235
240
|
) -> HGDataset:
|
|
236
241
|
"""Load a CoNLL-2002 file as a Huggingface Dataset.
|
|
@@ -238,12 +243,13 @@ def hgdataset_from_conll2002(
|
|
|
238
243
|
:param path: passed to :func:`.load_conll2002_bio`
|
|
239
244
|
:param tag_conversion_map: passed to :func:`load_conll2002_bio`
|
|
240
245
|
:param separator: passed to :func:`load_conll2002_bio`
|
|
241
|
-
:param
|
|
246
|
+
:param max_sent_len: passed to :func:`load_conll2002_bio`
|
|
247
|
+
:param kwargs: additional kwargs for :func:`open`
|
|
242
248
|
|
|
243
249
|
:return: a :class:`datasets.Dataset` with features 'tokens' and 'labels'.
|
|
244
250
|
"""
|
|
245
251
|
sentences, tokens, entities = load_conll2002_bio(
|
|
246
|
-
path, tag_conversion_map, separator, **kwargs
|
|
252
|
+
path, tag_conversion_map, separator, max_sent_len, **kwargs
|
|
247
253
|
)
|
|
248
254
|
|
|
249
255
|
# convert entities to labels
|
|
renard/pipeline/character_unification.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from typing import Any, Dict, List, FrozenSet, Set, Optional, Tuple, Union, Literal
|
|
2
|
-
import
|
|
2
|
+
import re, sys
|
|
3
3
|
from itertools import combinations
|
|
4
4
|
from collections import defaultdict, Counter
|
|
5
5
|
from dataclasses import dataclass
|
|
@@ -11,6 +11,7 @@ from renard.pipeline.ner import NEREntity
|
|
|
11
11
|
from renard.pipeline.progress import ProgressReporter
|
|
12
12
|
from renard.resources.hypocorisms import HypocorismGazetteer
|
|
13
13
|
from renard.resources.pronouns import is_a_female_pronoun, is_a_male_pronoun
|
|
14
|
+
from renard.resources.determiners import singular_determiners
|
|
14
15
|
from renard.resources.titles import is_a_male_title, is_a_female_title, all_titles
|
|
15
16
|
|
|
16
17
|
|
|
@@ -167,6 +168,7 @@ class GraphRulesCharacterUnifier(PipelineStep):
|
|
|
167
168
|
additional_hypocorisms: Optional[List[Tuple[str, List[str]]]] = None,
|
|
168
169
|
link_corefs_mentions: bool = False,
|
|
169
170
|
ignore_lone_titles: Optional[Set[str]] = None,
|
|
171
|
+
ignore_leading_determiner: bool = False,
|
|
170
172
|
) -> None:
|
|
171
173
|
"""
|
|
172
174
|
:param min_appearances: minimum number of appearances of a
|
|
@@ -181,15 +183,20 @@ class GraphRulesCharacterUnifier(PipelineStep):
|
|
|
181
183
|
extract a lot of spurious links. However, linking by
|
|
182
184
|
coref is sometimes the only way to resolve a character
|
|
183
185
|
alias.
|
|
184
|
-
:param ignore_lone_titles: a set of titles to ignore when
|
|
185
|
-
|
|
186
|
+
:param ignore_lone_titles: a set of titles to ignore when they
|
|
187
|
+
stand on their own. This avoids extracting false
|
|
186
188
|
positives characters such as 'Mr.' or 'Miss'.
|
|
189
|
+
:param ignore_leading_determiner: if ``True``, will ignore the
|
|
190
|
+
leading determiner when applying unification rules. This
|
|
191
|
+
is useful if the NER model used in the pipeline adds
|
|
192
|
+
leading determiners as part of entites.
|
|
187
193
|
"""
|
|
188
194
|
self.min_appearances = min_appearances
|
|
189
195
|
self.additional_hypocorisms = additional_hypocorisms
|
|
190
196
|
self.link_corefs_mentions = link_corefs_mentions
|
|
191
197
|
self.ignore_lone_titles = ignore_lone_titles or set()
|
|
192
198
|
self.character_ner_tag = "PER" # a default value, will be set by _pipeline_init
|
|
199
|
+
self.ignore_leading_determiner = ignore_leading_determiner
|
|
193
200
|
|
|
194
201
|
super().__init__()
|
|
195
202
|
|
|
@@ -229,23 +236,28 @@ class GraphRulesCharacterUnifier(PipelineStep):
|
|
|
229
236
|
|
|
230
237
|
# * link nodes based on several rules
|
|
231
238
|
for name1, name2 in combinations(G.nodes(), 2):
|
|
239
|
+
|
|
240
|
+
# preprocess name when needed
|
|
241
|
+
pname1 = self._preprocess_name(name1)
|
|
242
|
+
pname2 = self._preprocess_name(name2)
|
|
243
|
+
|
|
232
244
|
# is one name a known hypocorism of the other ? (also
|
|
233
245
|
# checks if both names are the same)
|
|
234
|
-
if self.hypocorism_gazetteer.are_related(
|
|
246
|
+
if self.hypocorism_gazetteer.are_related(pname1, pname2):
|
|
235
247
|
G.add_edge(name1, name2)
|
|
236
248
|
continue
|
|
237
249
|
|
|
238
250
|
# if we remove the title, is one name related to the other
|
|
239
251
|
# ?
|
|
240
252
|
if self.names_are_related_after_title_removal(
|
|
241
|
-
|
|
253
|
+
pname1, pname2, hname_constants
|
|
242
254
|
):
|
|
243
255
|
G.add_edge(name1, name2)
|
|
244
256
|
continue
|
|
245
257
|
|
|
246
258
|
# add an edge if two characters have the same family names
|
|
247
|
-
human_name1 = HumanName(
|
|
248
|
-
human_name2 = HumanName(
|
|
259
|
+
human_name1 = HumanName(pname1, constants=hname_constants)
|
|
260
|
+
human_name2 = HumanName(pname2, constants=hname_constants)
|
|
249
261
|
if (
|
|
250
262
|
len(human_name1.last) > 0
|
|
251
263
|
and human_name1.last.lower() == human_name2.last.lower()
|
|
@@ -282,10 +294,15 @@ class GraphRulesCharacterUnifier(PipelineStep):
|
|
|
282
294
|
pass
|
|
283
295
|
|
|
284
296
|
for name1, name2 in combinations(G.nodes(), 2):
|
|
297
|
+
|
|
298
|
+
# preprocess names when needed
|
|
299
|
+
pname1 = self._preprocess_name(name1)
|
|
300
|
+
pname2 = self._preprocess_name(name2)
|
|
301
|
+
|
|
285
302
|
# check if characters have the same last name but a
|
|
286
303
|
# different first name.
|
|
287
|
-
human_name1 = HumanName(
|
|
288
|
-
human_name2 = HumanName(
|
|
304
|
+
human_name1 = HumanName(pname1, constants=hname_constants)
|
|
305
|
+
human_name2 = HumanName(pname2, constants=hname_constants)
|
|
289
306
|
if (
|
|
290
307
|
len(human_name1.last) > 0
|
|
291
308
|
and len(human_name2.last) > 0
|
|
@@ -337,6 +354,17 @@ class GraphRulesCharacterUnifier(PipelineStep):
|
|
|
337
354
|
|
|
338
355
|
return {"characters": characters}
|
|
339
356
|
|
|
357
|
+
def _preprocess_name(self, name) -> str:
|
|
358
|
+
if self.ignore_leading_determiner:
|
|
359
|
+
if not self.lang in singular_determiners:
|
|
360
|
+
print(
|
|
361
|
+
f"[warning] can't ignore leading determiners for {self.lang}",
|
|
362
|
+
file=sys.stderr,
|
|
363
|
+
)
|
|
364
|
+
for determiner in singular_determiners.get(self.lang, []):
|
|
365
|
+
name = re.sub(f"^{determiner} ", " ", name, flags=re.I)
|
|
366
|
+
return name
|
|
367
|
+
|
|
340
368
|
def _make_hname_constants(self) -> Constants:
|
|
341
369
|
if self.lang == "eng":
|
|
342
370
|
return Constants()
|
|
@@ -365,13 +393,18 @@ class GraphRulesCharacterUnifier(PipelineStep):
|
|
|
365
393
|
or self.hypocorism_gazetteer.are_related(raw_name1, raw_name2)
|
|
366
394
|
)
|
|
367
395
|
|
|
368
|
-
def names_are_in_coref(
|
|
396
|
+
def names_are_in_coref(
|
|
397
|
+
self, name1: str, name2: str, corefs: List[List[Mention]]
|
|
398
|
+
) -> bool:
|
|
399
|
+
once_together = False
|
|
369
400
|
for coref_chain in corefs:
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
):
|
|
373
|
-
return
|
|
374
|
-
|
|
401
|
+
name1_in = any([name1 == " ".join(m.tokens) for m in coref_chain])
|
|
402
|
+
name2_in = any([name2 == " ".join(m.tokens) for m in coref_chain])
|
|
403
|
+
if name1_in == (not name2_in):
|
|
404
|
+
return False
|
|
405
|
+
elif name1_in and name2_in:
|
|
406
|
+
once_together = True
|
|
407
|
+
return once_together
|
|
375
408
|
|
|
376
409
|
def infer_name_gender(
|
|
377
410
|
self,
|
renard/pipeline/core.py
CHANGED
|
@@ -289,6 +289,7 @@ class PipelineState:
|
|
|
289
289
|
node_kwargs: Optional[List[Dict[str, Any]]] = None,
|
|
290
290
|
edge_kwargs: Optional[List[Dict[str, Any]]] = None,
|
|
291
291
|
label_kwargs: Optional[List[Dict[str, Any]]] = None,
|
|
292
|
+
legend: bool = False,
|
|
292
293
|
):
|
|
293
294
|
"""Plot ``self.character_graph`` using reasonable default
|
|
294
295
|
parameters, and save the produced figures in the specified
|
|
@@ -306,6 +307,7 @@ class PipelineState:
|
|
|
306
307
|
:param node_kwargs: passed to :func:`nx.draw_networkx_nodes`
|
|
307
308
|
:param edge_kwargs: passed to :func:`nx.draw_networkx_nodes`
|
|
308
309
|
:param label_kwargs: passed to :func:`nx.draw_networkx_labels`
|
|
310
|
+
:param legend: passed to :func:`.plot_nx_graph_reasonably`
|
|
309
311
|
"""
|
|
310
312
|
import matplotlib.pyplot as plt
|
|
311
313
|
|
|
@@ -346,6 +348,7 @@ class PipelineState:
|
|
|
346
348
|
node_kwargs=node_kwargs[i],
|
|
347
349
|
edge_kwargs=edge_kwargs[i],
|
|
348
350
|
label_kwargs=label_kwargs[i],
|
|
351
|
+
legend=legend,
|
|
349
352
|
)
|
|
350
353
|
plt.savefig(f"{directory}/{i}.png")
|
|
351
354
|
plt.close()
|
|
@@ -361,6 +364,8 @@ class PipelineState:
|
|
|
361
364
|
node_kwargs: Optional[Dict[str, Any]] = None,
|
|
362
365
|
edge_kwargs: Optional[Dict[str, Any]] = None,
|
|
363
366
|
label_kwargs: Optional[Dict[str, Any]] = None,
|
|
367
|
+
tight_layout: bool = False,
|
|
368
|
+
legend: bool = False,
|
|
364
369
|
):
|
|
365
370
|
"""Plot ``self.character_graph`` using reasonable parameters,
|
|
366
371
|
and save the produced figure to a file
|
|
@@ -373,6 +378,8 @@ class PipelineState:
|
|
|
373
378
|
:param node_kwargs: passed to :func:`nx.draw_networkx_nodes`
|
|
374
379
|
:param edge_kwargs: passed to :func:`nx.draw_networkx_nodes`
|
|
375
380
|
:param label_kwargs: passed to :func:`nx.draw_networkx_labels`
|
|
381
|
+
:param tight_layout: if ``True``, will use matplotlib's tight_layout
|
|
382
|
+
:param legend: passed to :func:`.plot_nx_graph_reasonably`
|
|
376
383
|
"""
|
|
377
384
|
import matplotlib.pyplot as plt
|
|
378
385
|
|
|
@@ -397,7 +404,10 @@ class PipelineState:
|
|
|
397
404
|
node_kwargs=node_kwargs,
|
|
398
405
|
edge_kwargs=edge_kwargs,
|
|
399
406
|
label_kwargs=label_kwargs,
|
|
407
|
+
legend=legend,
|
|
400
408
|
)
|
|
409
|
+
if tight_layout:
|
|
410
|
+
fig.tight_layout()
|
|
401
411
|
plt.savefig(path)
|
|
402
412
|
plt.close()
|
|
403
413
|
|
|
@@ -414,6 +424,8 @@ class PipelineState:
|
|
|
414
424
|
node_kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
|
415
425
|
edge_kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
|
416
426
|
label_kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
|
427
|
+
tight_layout: bool = False,
|
|
428
|
+
legend: bool = False,
|
|
417
429
|
):
|
|
418
430
|
"""Plot ``self.character_network`` using reasonable default
|
|
419
431
|
parameters
|
|
@@ -442,6 +454,8 @@ class PipelineState:
|
|
|
442
454
|
:param node_kwargs: passed to :func:`nx.draw_networkx_nodes`
|
|
443
455
|
:param edge_kwargs: passed to :func:`nx.draw_networkx_nodes`
|
|
444
456
|
:param label_kwargs: passed to :func:`nx.draw_networkx_labels`
|
|
457
|
+
:param tight_layout: if ``True``, will use matplotlib's tight_layout
|
|
458
|
+
:param legend: passed to :func:`.plot_nx_graph_reasonably`
|
|
445
459
|
"""
|
|
446
460
|
import matplotlib.pyplot as plt
|
|
447
461
|
from matplotlib.widgets import Slider
|
|
@@ -463,6 +477,8 @@ class PipelineState:
|
|
|
463
477
|
assert not isinstance(node_kwargs, list)
|
|
464
478
|
assert not isinstance(edge_kwargs, list)
|
|
465
479
|
assert not isinstance(label_kwargs, list)
|
|
480
|
+
if tight_layout:
|
|
481
|
+
fig.tight_layout()
|
|
466
482
|
plot_nx_graph_reasonably(
|
|
467
483
|
G,
|
|
468
484
|
ax=ax,
|
|
@@ -470,6 +486,7 @@ class PipelineState:
|
|
|
470
486
|
node_kwargs=node_kwargs,
|
|
471
487
|
edge_kwargs=edge_kwargs,
|
|
472
488
|
label_kwargs=label_kwargs,
|
|
489
|
+
legend=legend,
|
|
473
490
|
)
|
|
474
491
|
return
|
|
475
492
|
|
|
@@ -520,11 +537,14 @@ class PipelineState:
|
|
|
520
537
|
node_kwargs=node_kwargs[slider_i],
|
|
521
538
|
edge_kwargs=edge_kwargs[slider_i],
|
|
522
539
|
label_kwargs=label_kwargs[slider_i],
|
|
540
|
+
legend=legend,
|
|
523
541
|
)
|
|
524
542
|
ax.set_xlim(-1.2, 1.2)
|
|
525
543
|
ax.set_ylim(-1.2, 1.2)
|
|
526
544
|
|
|
527
545
|
slider_ax = fig.add_axes([0.1, 0.05, 0.8, 0.04])
|
|
546
|
+
if tight_layout:
|
|
547
|
+
fig.tight_layout()
|
|
528
548
|
# HACK: we save the slider to the figure. This ensure the
|
|
529
549
|
# slider is still alive at plotting time.
|
|
530
550
|
fig.slider = Slider( # type: ignore
|
renard/pipeline/corefs/corefs.py
CHANGED
|
@@ -20,7 +20,7 @@ class BertCoreferenceResolver(PipelineStep):
|
|
|
20
20
|
def __init__(
|
|
21
21
|
self,
|
|
22
22
|
model: Optional[Union[BertForCoreferenceResolution]] = None,
|
|
23
|
-
|
|
23
|
+
huggingface_model_id: Optional[str] = None,
|
|
24
24
|
batch_size: int = 1,
|
|
25
25
|
device: Literal["auto", "cuda", "cpu"] = "auto",
|
|
26
26
|
tokenizer: Optional[PreTrainedTokenizerFast] = None,
|
|
@@ -47,7 +47,7 @@ class BertCoreferenceResolver(PipelineStep):
|
|
|
47
47
|
inference on the whole document.
|
|
48
48
|
"""
|
|
49
49
|
if isinstance(model, str):
|
|
50
|
-
self.hugginface_model_id =
|
|
50
|
+
self.hugginface_model_id = huggingface_model_id
|
|
51
51
|
self.model = None # model will be init by _pipeline_init_
|
|
52
52
|
else:
|
|
53
53
|
self.hugginface_model_id = None
|
|
renard/pipeline/ner/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from renard.pipeline.ner.ner import *
|
|
renard/pipeline/{ner.py → ner/ner.py}
CHANGED
|
@@ -1,22 +1,32 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
import
|
|
3
|
-
|
|
2
|
+
from typing import (
|
|
3
|
+
TYPE_CHECKING,
|
|
4
|
+
List,
|
|
5
|
+
Dict,
|
|
6
|
+
Any,
|
|
7
|
+
Set,
|
|
8
|
+
Tuple,
|
|
9
|
+
Optional,
|
|
10
|
+
Union,
|
|
11
|
+
Literal,
|
|
12
|
+
)
|
|
4
13
|
from dataclasses import dataclass
|
|
5
14
|
import torch
|
|
6
|
-
from seqeval.metrics import precision_score, recall_score, f1_score
|
|
7
15
|
from renard.nltk_utils import nltk_fix_bio_tags
|
|
8
16
|
from renard.ner_utils import (
|
|
9
17
|
DataCollatorForTokenClassificationWithBatchEncoding,
|
|
10
18
|
NERDataset,
|
|
11
19
|
)
|
|
12
20
|
from renard.pipeline.core import PipelineStep, Mention
|
|
13
|
-
from renard.pipeline.progress import ProgressReporter
|
|
14
21
|
from renard.ner_utils import ner_entities
|
|
15
22
|
|
|
16
23
|
if TYPE_CHECKING:
|
|
17
24
|
from transformers.tokenization_utils_base import BatchEncoding
|
|
18
|
-
from transformers import
|
|
19
|
-
|
|
25
|
+
from transformers import (
|
|
26
|
+
PreTrainedModel,
|
|
27
|
+
PreTrainedTokenizerFast,
|
|
28
|
+
)
|
|
29
|
+
from renard.pipeline.ner.retrieval import NERContextRetriever
|
|
20
30
|
|
|
21
31
|
|
|
22
32
|
@dataclass
|
|
@@ -28,7 +38,7 @@ class NEREntity(Mention):
|
|
|
28
38
|
"""
|
|
29
39
|
.. note::
|
|
30
40
|
|
|
31
|
-
This method is
|
|
41
|
+
This method is implemented here to avoid type issues. Since
|
|
32
42
|
:meth:`.Mention.shifted` cannot be annotated as returning
|
|
33
43
|
``Self``, this method annotate the correct return type when
|
|
34
44
|
using :meth:`.NEREntity.shifted`.
|
|
@@ -42,18 +52,21 @@ class NEREntity(Mention):
|
|
|
42
52
|
def score_ner(
|
|
43
53
|
pred_bio_tags: List[str], ref_bio_tags: List[str]
|
|
44
54
|
) -> Tuple[float, float, float]:
|
|
45
|
-
"""Score NER as in CoNLL-2003 shared task using ``seqeval``
|
|
55
|
+
"""Score NER as in CoNLL-2003 shared task using the ``seqeval``
|
|
56
|
+
library, if installed.
|
|
46
57
|
|
|
47
58
|
Precision is the percentage of named entities in ``ref_bio_tags``
|
|
48
|
-
that are correct.
|
|
49
|
-
pred_bio_tags that are in ref_bio_tags.
|
|
50
|
-
both.
|
|
59
|
+
that are correct. Recall is the percentage of named entities in
|
|
60
|
+
pred_bio_tags that are in ref_bio_tags. F1 is the harmonic mean
|
|
61
|
+
of both.
|
|
51
62
|
|
|
52
63
|
:param pred_bio_tags:
|
|
53
64
|
:param ref_bio_tags:
|
|
54
|
-
:return: ``(precision, recall, F1 score)``
|
|
55
65
|
|
|
66
|
+
:return: ``(precision, recall, F1 score)``
|
|
56
67
|
"""
|
|
68
|
+
from seqeval.metrics import precision_score, recall_score, f1_score
|
|
69
|
+
|
|
57
70
|
assert len(pred_bio_tags) == len(ref_bio_tags)
|
|
58
71
|
return (
|
|
59
72
|
precision_score([ref_bio_tags], [pred_bio_tags]),
|
|
@@ -71,12 +84,19 @@ class NLTKNamedEntityRecognizer(PipelineStep):
|
|
|
71
84
|
"""
|
|
72
85
|
import nltk
|
|
73
86
|
|
|
74
|
-
nltk.download("averaged_perceptron_tagger", quiet=True)
|
|
87
|
+
nltk.download(f"averaged_perceptron_tagger", quiet=True)
|
|
75
88
|
nltk.download("maxent_ne_chunker", quiet=True)
|
|
89
|
+
nltk.download("maxent_ne_chunker_tab", quiet=True)
|
|
76
90
|
nltk.download("words", quiet=True)
|
|
77
91
|
|
|
78
92
|
super().__init__()
|
|
79
93
|
|
|
94
|
+
def _pipeline_init_(self, lang: str, **kwargs):
|
|
95
|
+
import nltk
|
|
96
|
+
|
|
97
|
+
nltk.download(f"averaged_perceptron_tagger_{lang}", quiet=True)
|
|
98
|
+
super()._pipeline_init_(lang, **kwargs)
|
|
99
|
+
|
|
80
100
|
def __call__(self, tokens: List[str], **kwargs) -> Dict[str, Any]:
|
|
81
101
|
"""
|
|
82
102
|
:param text:
|
|
@@ -102,64 +122,6 @@ class NLTKNamedEntityRecognizer(PipelineStep):
|
|
|
102
122
|
return {"entities"}
|
|
103
123
|
|
|
104
124
|
|
|
105
|
-
class NERContextRetriever:
|
|
106
|
-
def __call__(self, dataset: NERDataset) -> NERDataset:
|
|
107
|
-
raise NotImplementedError
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
class NERSamenounContextRetriever(NERContextRetriever):
|
|
111
|
-
"""
|
|
112
|
-
Retrieve relevant context using the samenoun strategy as in
|
|
113
|
-
Amalvy et al. 2023.
|
|
114
|
-
"""
|
|
115
|
-
|
|
116
|
-
def __init__(self, k: int) -> None:
|
|
117
|
-
"""
|
|
118
|
-
:param k: the number of sentences to retrieve
|
|
119
|
-
"""
|
|
120
|
-
self.k = k
|
|
121
|
-
|
|
122
|
-
def __call__(self, dataset: NERDataset) -> NERDataset:
|
|
123
|
-
import nltk
|
|
124
|
-
|
|
125
|
-
# NOTE: POS tagging is not incorporated in the pipeline yet,
|
|
126
|
-
# so we manually compute it here.
|
|
127
|
-
elements_names = [
|
|
128
|
-
{t[0] for t in nltk.pos_tag(element) if t[1].startswith("NN")}
|
|
129
|
-
for element in dataset.elements
|
|
130
|
-
]
|
|
131
|
-
|
|
132
|
-
elements_with_context = []
|
|
133
|
-
|
|
134
|
-
for elt_i, elt in enumerate(dataset.elements):
|
|
135
|
-
retrieved_elts = [
|
|
136
|
-
other_elt
|
|
137
|
-
for other_elt_i, other_elt in enumerate(dataset.elements)
|
|
138
|
-
if not other_elt_i == elt_i
|
|
139
|
-
and len(elements_names[elt_i].intersection(elements_names[other_elt_i]))
|
|
140
|
-
> 0
|
|
141
|
-
]
|
|
142
|
-
retrieved_elts = random.sample(
|
|
143
|
-
retrieved_elts, k=min(self.k, len(retrieved_elts))
|
|
144
|
-
)
|
|
145
|
-
elements_with_context.append(
|
|
146
|
-
(
|
|
147
|
-
elt,
|
|
148
|
-
[dataset.tokenizer.sep_token]
|
|
149
|
-
+ list(itertools.chain.from_iterable(retrieved_elts)),
|
|
150
|
-
)
|
|
151
|
-
)
|
|
152
|
-
|
|
153
|
-
return NERDataset(
|
|
154
|
-
[element + context for element, context in elements_with_context],
|
|
155
|
-
dataset.tokenizer,
|
|
156
|
-
[
|
|
157
|
-
[0] * len(element) + [1] * len(context)
|
|
158
|
-
for element, context in elements_with_context
|
|
159
|
-
],
|
|
160
|
-
)
|
|
161
|
-
|
|
162
|
-
|
|
163
125
|
class BertNamedEntityRecognizer(PipelineStep):
|
|
164
126
|
"""An entity recognizer based on BERT"""
|
|
165
127
|
|
|
@@ -307,7 +269,7 @@ class BertNamedEntityRecognizer(PipelineStep):
|
|
|
307
269
|
batch_i: int,
|
|
308
270
|
wp_labels: List[str],
|
|
309
271
|
tokens: List[str],
|
|
310
|
-
|
|
272
|
+
ctxmask: torch.Tensor,
|
|
311
273
|
) -> List[str]:
|
|
312
274
|
"""Align labels to tokens rather than wordpiece tokens.
|
|
313
275
|
|
|
@@ -318,13 +280,21 @@ class BertNamedEntityRecognizer(PipelineStep):
|
|
|
318
280
|
"""
|
|
319
281
|
batch_labels = ["O"] * len(tokens)
|
|
320
282
|
|
|
283
|
+
try:
|
|
284
|
+
inference_start = ctxmask[batch_i].tolist().index(1)
|
|
285
|
+
except ValueError:
|
|
286
|
+
inference_start = 0
|
|
287
|
+
|
|
321
288
|
for wplabel_j, wp_label in enumerate(wp_labels):
|
|
322
|
-
|
|
323
|
-
continue
|
|
289
|
+
|
|
324
290
|
token_i = batchs.token_to_word(batch_i, wplabel_j)
|
|
325
291
|
if token_i is None:
|
|
326
292
|
continue
|
|
327
|
-
|
|
293
|
+
|
|
294
|
+
if ctxmask[batch_i][token_i] == 0:
|
|
295
|
+
continue
|
|
296
|
+
|
|
297
|
+
batch_labels[token_i - inference_start] = wp_label
|
|
328
298
|
|
|
329
299
|
return batch_labels
|
|
330
300
|
|
|
renard/pipeline/ner/retrieval.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
1
|
+
from collections.abc import Set
|
|
2
|
+
import sys
|
|
3
|
+
from typing import Union, List, cast, Literal, Optional
|
|
4
|
+
import random
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from more_itertools import flatten
|
|
7
|
+
from renard.ner_utils import NERDataset
|
|
8
|
+
import nltk
|
|
9
|
+
from rank_bm25 import BM25Okapi
|
|
10
|
+
from transformers import (
|
|
11
|
+
BertForSequenceClassification,
|
|
12
|
+
BertTokenizerFast,
|
|
13
|
+
DataCollatorWithPadding,
|
|
14
|
+
)
|
|
15
|
+
from transformers.tokenization_utils_base import BatchEncoding
|
|
16
|
+
import torch
|
|
17
|
+
from torch.utils.data import Dataset, DataLoader
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
|
|
21
|
+
class NERContextRetrievalMatch:
|
|
22
|
+
element: List[str]
|
|
23
|
+
element_i: int
|
|
24
|
+
side: Literal["left", "right"]
|
|
25
|
+
score: Optional[float]
|
|
26
|
+
|
|
27
|
+
def __hash__(self) -> int:
|
|
28
|
+
return hash(tuple(self.element) + (self.element_i, self.side, self.score))
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class NERContextRetriever:
|
|
32
|
+
def __init__(self, k: int) -> None:
|
|
33
|
+
self.k = k
|
|
34
|
+
|
|
35
|
+
def compute_global_features(self, elements: List[List[str]]) -> dict:
|
|
36
|
+
return {}
|
|
37
|
+
|
|
38
|
+
def retrieve(
|
|
39
|
+
self, element_i: int, elements: List[List[str]], **kwargs
|
|
40
|
+
) -> List[NERContextRetrievalMatch]:
|
|
41
|
+
raise NotImplementedError
|
|
42
|
+
|
|
43
|
+
def __call__(self, dataset: NERDataset) -> NERDataset:
|
|
44
|
+
# [(left_ctx, element, right_ctx), ...]
|
|
45
|
+
elements_with_context = []
|
|
46
|
+
|
|
47
|
+
global_features = self.compute_global_features(dataset.elements)
|
|
48
|
+
|
|
49
|
+
for elt_i, elt in enumerate(dataset.elements):
|
|
50
|
+
matchs = self.retrieve(elt_i, dataset.elements, **global_features)
|
|
51
|
+
assert len(matchs) <= self.k
|
|
52
|
+
|
|
53
|
+
lctx = sorted(
|
|
54
|
+
(m for m in matchs if m.side == "left"),
|
|
55
|
+
key=lambda m: m.element_i,
|
|
56
|
+
)
|
|
57
|
+
lctx = list(flatten([m.element for m in lctx]))
|
|
58
|
+
|
|
59
|
+
rctx = sorted(
|
|
60
|
+
(m for m in matchs if m.side == "right"),
|
|
61
|
+
key=lambda m: m.element_i,
|
|
62
|
+
)
|
|
63
|
+
rctx = list(flatten([m.element for m in rctx]))
|
|
64
|
+
|
|
65
|
+
elements_with_context.append((lctx, elt, rctx))
|
|
66
|
+
|
|
67
|
+
return NERDataset(
|
|
68
|
+
[lctx + element + rctx for lctx, element, rctx in elements_with_context],
|
|
69
|
+
dataset.tokenizer,
|
|
70
|
+
[
|
|
71
|
+
[1] * len(lctx) + [0] * len(element) + [1] * len(rctx)
|
|
72
|
+
for lctx, element, rctx in elements_with_context
|
|
73
|
+
],
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class NERSamenounContextRetriever(NERContextRetriever):
|
|
78
|
+
"""
|
|
79
|
+
Retrieve relevant context using the samenoun strategy as in
|
|
80
|
+
Amalvy et al. 2023.
|
|
81
|
+
"""
|
|
82
|
+
|
|
83
|
+
def __init__(self, k: int) -> None:
|
|
84
|
+
"""
|
|
85
|
+
:param k: the max number of sentences to retrieve
|
|
86
|
+
"""
|
|
87
|
+
super().__init__(k)
|
|
88
|
+
|
|
89
|
+
def compute_global_features(self, elements: List[List[str]]) -> dict:
|
|
90
|
+
return {
|
|
91
|
+
"NNs": [
|
|
92
|
+
{t[0] for t in nltk.pos_tag(element) if t[1] == "NN"}
|
|
93
|
+
for element in elements
|
|
94
|
+
]
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
def retrieve(
|
|
98
|
+
self, element_i: int, elements: List[List[str]], NNs: List[Set[str]], **kwargs
|
|
99
|
+
) -> List[NERContextRetrievalMatch]:
|
|
100
|
+
matchs = [
|
|
101
|
+
NERContextRetrievalMatch(
|
|
102
|
+
other_elt,
|
|
103
|
+
other_elt_i,
|
|
104
|
+
"left" if other_elt_i < element_i else "right",
|
|
105
|
+
None,
|
|
106
|
+
)
|
|
107
|
+
for other_elt_i, other_elt in enumerate(elements)
|
|
108
|
+
if not other_elt_i == element_i
|
|
109
|
+
and len(NNs[element_i].intersection(NNs[other_elt_i])) > 0 # type: ignore
|
|
110
|
+
]
|
|
111
|
+
return random.sample(matchs, k=min(self.k, len(matchs)))
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class NERNeighborsContextRetriever(NERContextRetriever):
|
|
115
|
+
"""A context retriever that chooses nearby elements."""
|
|
116
|
+
|
|
117
|
+
def __init__(self, k: int):
|
|
118
|
+
assert k % 2 == 0
|
|
119
|
+
super().__init__(k)
|
|
120
|
+
|
|
121
|
+
def retrieve(
|
|
122
|
+
self, element_i: int, elements: List[List[str]], **kwargs
|
|
123
|
+
) -> List[NERContextRetrievalMatch]:
|
|
124
|
+
left_nb = self.k // 2
|
|
125
|
+
right_nb = left_nb
|
|
126
|
+
|
|
127
|
+
lctx = []
|
|
128
|
+
for i, elt in enumerate(elements[element_i - left_nb : element_i]):
|
|
129
|
+
lctx.append(
|
|
130
|
+
NERContextRetrievalMatch(elt, element_i - left_nb + i, "left", None)
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
rctx = []
|
|
134
|
+
for i, elt in enumerate(elements[element_i + 1 : element_i + 1 + right_nb]):
|
|
135
|
+
rctx.append(NERContextRetrievalMatch(elt, element_i + 1 + i, "right", None))
|
|
136
|
+
|
|
137
|
+
return lctx + rctx
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class NERBM25ContextRetriever(NERContextRetriever):
|
|
141
|
+
"""A context retriever that selects elements according to the BM25 ranking formula."""
|
|
142
|
+
|
|
143
|
+
def __init__(self, k: int) -> None:
|
|
144
|
+
super().__init__(k)
|
|
145
|
+
|
|
146
|
+
def compute_global_features(self, elements: List[List[str]]) -> dict:
|
|
147
|
+
return {"bm25_model": BM25Okapi(elements)}
|
|
148
|
+
|
|
149
|
+
def retrieve(
|
|
150
|
+
self, element_i: int, elements: List[List[str]], bm25_model: BM25Okapi, **kwargs
|
|
151
|
+
) -> List[NERContextRetrievalMatch]:
|
|
152
|
+
query = elements[element_i]
|
|
153
|
+
sent_scores = bm25_model.get_scores(query)
|
|
154
|
+
sent_scores[element_i] = float("-Inf") # don't retrieve self
|
|
155
|
+
topk_values, topk_indexs = torch.topk(
|
|
156
|
+
torch.tensor(sent_scores), k=min(self.k, len(sent_scores)), dim=0
|
|
157
|
+
)
|
|
158
|
+
return [
|
|
159
|
+
NERContextRetrievalMatch(
|
|
160
|
+
elements[index], index, "left" if index < element_i else "right", value
|
|
161
|
+
)
|
|
162
|
+
for value, index in zip(topk_values.tolist(), topk_indexs.tolist())
|
|
163
|
+
]
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@dataclass(frozen=True)
|
|
167
|
+
class NERNeuralContextRetrievalExample:
|
|
168
|
+
"""A context retrieval example."""
|
|
169
|
+
|
|
170
|
+
#: text on which NER is performed
|
|
171
|
+
element: List[str]
|
|
172
|
+
#: context to assist during prediction
|
|
173
|
+
context: List[str]
|
|
174
|
+
#: context side (does the context comes from the left or the right of ``sent`` ?)
|
|
175
|
+
context_side: Literal["left", "right"]
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
class NERNeuralContextRetrievalDataset(Dataset):
    """A dataset of :class:`NERNeuralContextRetrievalExample`, encoded
    as inputs for a BERT sequence classifier.

    Each example is encoded as the single word sequence
    ``context + ["[SEP]"] + element``.
    """

    def __init__(
        self,
        examples: List[NERNeuralContextRetrievalExample],
        tokenizer: BertTokenizerFast,
        max_length: int = 512,
    ) -> None:
        """
        :param examples: the examples to encode
        :param tokenizer: tokenizer used to encode examples
        :param max_length: maximum length (in tokenizer tokens) of an
            encoded example. Longer examples are truncated.
        """
        self.examples = examples
        self.tokenizer: BertTokenizerFast = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, index: int) -> BatchEncoding:
        """Get a BatchEncoding representing example at index.

        :param index: index of the example to retrieve

        :return: the ``BatchEncoding`` produced by the tokenizer for
            ``context + ["[SEP]"] + element``. No label is set.
        """
        example = self.examples[index]

        tokens = example.context + ["[SEP]"] + example.element

        # NOTE(review): with the tokenizer's default right-side
        # truncation, overly long inputs lose tokens from ``element``
        # (the end of the sequence) rather than from ``context``.
        batch: BatchEncoding = self.tokenizer(
            tokens,
            is_split_into_words=True,
            truncation=True,
            max_length=self.max_length,
        )

        return batch
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
class NERNeuralContextRetriever(NERContextRetriever):
    """A neural context retriever, as in Amalvy et al. 2024.

    Candidate contexts are first retrieved by a heuristic retriever,
    then re-ranked by a BERT sequence classifier. Each candidate is
    scored with the softmax probability of class ``1``. The ``k`` best
    candidates whose score is strictly above ``threshold`` are kept.
    """

    def __init__(
        self,
        heuristic_context_selector: NERContextRetriever,
        pretrained_model: Union[
            str, BertForSequenceClassification
        ] = "compnet-renard/bert-base-cased-NER-reranker",
        k: int = 3,
        batch_size: int = 1,
        threshold: float = 0.0,
        device_str: Literal["cuda", "cpu", "auto"] = "auto",
    ) -> None:
        """
        :param heuristic_context_selector: context retriever used to
            retrieve candidate contexts, which are then re-ranked
        :param pretrained_model: either a pretrained model name, used
            to load a :class:`transformers.BertForSequenceClassification`,
            or an already loaded model
        :param k: max number of elements to retrieve
        :param batch_size: batch size used at inference
        :param threshold: a candidate is kept only if its score is
            strictly greater than this value
        :param device_str: device on which to run the classifier.
            ``'auto'`` selects ``'cuda'`` if available, ``'cpu'``
            otherwise.
        """
        from transformers import BertForSequenceClassification, BertTokenizerFast

        if isinstance(pretrained_model, str):
            self.ctx_classifier = BertForSequenceClassification.from_pretrained(
                pretrained_model
            )  # type: ignore
        else:
            self.ctx_classifier = pretrained_model
        self.ctx_classifier = cast(BertForSequenceClassification, self.ctx_classifier)

        # NOTE: when a model instance is given, its tokenizer is not
        # known here, so the 'bert-base-cased' tokenizer is assumed.
        self.tokenizer = BertTokenizerFast.from_pretrained(
            pretrained_model if isinstance(pretrained_model, str) else "bert-base-cased"
        )

        self.heuristic_context_selector = heuristic_context_selector

        self.batch_size = batch_size
        self.threshold = threshold

        if device_str == "auto":
            device_str = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_str)

        super().__init__(k)

    def set_heuristic_k_(self, k: int):
        """Set (in place) the max number of candidates retrieved by the
        heuristic retriever.

        :param k: the new max number of candidates
        """
        self.heuristic_context_selector.k = k

    def predict(self, examples: List[NERNeuralContextRetrievalExample]) -> torch.Tensor:
        """Score examples with the context classifier.

        :param examples: a list of :class:`NERNeuralContextRetrievalExample`
        :return: a tensor of shape ``(len(examples), num_labels)`` of
            class probabilities, located on ``self.device``
        """
        dataset = NERNeuralContextRetrievalDataset(examples, self.tokenizer)

        self.ctx_classifier = self.ctx_classifier.to(self.device)

        data_collator = DataCollatorWithPadding(dataset.tokenizer)  # type: ignore
        dataloader = DataLoader(
            dataset, batch_size=self.batch_size, shuffle=False, collate_fn=data_collator
        )  # type: ignore

        # inference using self.ctx_classifier
        self.ctx_classifier = self.ctx_classifier.eval()
        batch_scores = []
        with torch.no_grad():
            for X in dataloader:
                X = X.to(self.device)
                # out.logits is of shape (batch_size, num_labels)
                out = self.ctx_classifier(
                    X["input_ids"],
                    token_type_ids=X["token_type_ids"],
                    attention_mask=X["attention_mask"],
                )
                batch_scores.append(torch.softmax(out.logits, dim=1))

        if len(batch_scores) == 0:
            # keep a 2D shape even when there is nothing to score
            return torch.zeros(
                (0, self.ctx_classifier.config.num_labels), device=self.device
            )
        return torch.cat(batch_scores, dim=0)

    def compute_global_features(self, elements: List[List[str]]) -> dict:
        """Retrieve candidate contexts for every element using the
        heuristic retriever.

        :param elements: the tokenized elements of the document
        :return: a dict with a single ``'heuristic_matchs'`` key: for
            each element, the list of matches found by the heuristic
            retriever
        """
        features = self.heuristic_context_selector.compute_global_features(elements)
        return {
            "heuristic_matchs": [
                self.heuristic_context_selector.retrieve(i, elements, **features)
                for i in range(len(elements))
            ]
        }

    def retrieve(
        self,
        element_i: int,
        elements: List[List[str]],
        heuristic_matchs: List[List[NERContextRetrievalMatch]],
        **kwargs,
    ) -> List[NERContextRetrievalMatch]:
        """Re-rank the heuristic candidates of an element.

        :param element_i: index of the element in ``elements``
        :param elements: the tokenized elements of the document
        :param heuristic_matchs: candidates computed by
            :meth:`compute_global_features`
        :return: at most ``k`` matches, sorted by decreasing classifier
            score, each with a score strictly above ``threshold``
        """
        # NOTE: the classifier is placed on ``self.device`` by
        # :meth:`predict`. It is not moved here, so that the user's
        # ``device_str`` choice is respected.
        if len(heuristic_matchs) == 0:
            return []
        matchs = heuristic_matchs[element_i]
        # no context retrieved by heuristic: nothing to do
        if len(matchs) == 0:
            return []

        element = elements[element_i]

        # prepare datas for inference
        ctx_dataset = [
            NERNeuralContextRetrievalExample(element, m.element, m.side) for m in matchs
        ]

        # (len(matchs), num_labels)
        scores = self.predict(ctx_dataset)
        # NOTE: this updates the heuristic matches in place
        for m, score in zip(matchs, scores[:, 1].tolist()):
            m.score = float(score)

        assert all(m.score is not None for m in matchs)
        best = sorted(matchs, key=lambda m: -m.score)[: self.k]  # type: ignore
        return [m for m in best if m.score > self.threshold]  # type: ignore
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
class NEREnsembleContextRetriever(NERContextRetriever):
    """Combine several context retrievers.

    The matches of all retrievers are merged. If every merged match has
    a score, the ``k`` best-scored matches are returned. Otherwise, at
    most ``k`` distinct matches are sampled at random.
    """

    def __init__(self, retrievers: List[NERContextRetriever], k: int) -> None:
        """
        :param retrievers: the retrievers to combine
        :param k: max number of elements to retrieve
        """
        self.retrievers = retrievers
        super().__init__(k)

    def compute_global_features(self, elements: List[List[str]]) -> dict:
        """Merge the global features of all retrievers.

        When several retrievers compute a feature with the same name, a
        warning is printed on stderr and the value of the last retriever
        is kept.

        :param elements: the tokenized elements of the document
        """
        features = {}
        for retriever in self.retrievers:
            for name, value in retriever.compute_global_features(elements).items():
                if name in features:
                    print(
                        f"[warning] NEREnsembleContextRetriever: incompatible global feature '{name}' between multiple retrievers.",
                        file=sys.stderr,
                    )
                features[name] = value
        return features

    def retrieve(
        self, element_i: int, elements: List[List[str]], **kwargs
    ) -> List[NERContextRetrievalMatch]:
        """Retrieve context for an element using all retrievers.

        :param element_i: index of the element in ``elements``
        :param elements: the tokenized elements of the document
        :return: at most ``k`` distinct matches
        """
        all_matchs = set()
        for retriever in self.retrievers:
            all_matchs.update(retriever.retrieve(element_i, elements, **kwargs))

        if all(m.score is not None for m in all_matchs):
            return sorted(all_matchs, key=lambda m: -m.score)[: self.k]  # type: ignore

        # Some matches are unscored, so they cannot be ranked: sample
        # *distinct* matches. random.choices samples with replacement,
        # which could return duplicates (or k items from fewer matches).
        return random.sample(list(all_matchs), k=min(self.k, len(all_matchs)))
|
renard/pipeline/tokenization.py
CHANGED
renard/plot_utils.py
CHANGED
|
@@ -25,6 +25,7 @@ def plot_nx_graph_reasonably(
|
|
|
25
25
|
node_kwargs: Optional[Dict[str, Any]] = None,
|
|
26
26
|
edge_kwargs: Optional[Dict[str, Any]] = None,
|
|
27
27
|
label_kwargs: Optional[Dict[str, Any]] = None,
|
|
28
|
+
legend: bool = False,
|
|
28
29
|
):
|
|
29
30
|
"""Try to plot a :class:`nx.Graph` with 'reasonable' parameters
|
|
30
31
|
|
|
@@ -35,6 +36,7 @@ def plot_nx_graph_reasonably(
|
|
|
35
36
|
:param node_kwargs: passed to :func:`nx.draw_networkx_nodes`
|
|
36
37
|
:param edge_kwargs: passed to :func:`nx.draw_networkx_nodes`
|
|
37
38
|
:param label_kwargs: passed to :func:`nx.draw_networkx_labels`
|
|
39
|
+
:param legend: if ``True``, will try to plot an additional legend.
|
|
38
40
|
"""
|
|
39
41
|
pos = layout
|
|
40
42
|
if pos is None:
|
|
@@ -48,7 +50,12 @@ def plot_nx_graph_reasonably(
|
|
|
48
50
|
node_kwargs["node_size"] = node_kwargs.get(
|
|
49
51
|
"node_size", [1 + degree * 10 for _, degree in G.degree]
|
|
50
52
|
)
|
|
51
|
-
nx.draw_networkx_nodes(G, pos, ax=ax, **node_kwargs)
|
|
53
|
+
scatter = nx.draw_networkx_nodes(G, pos, ax=ax, **node_kwargs)
|
|
54
|
+
if legend:
|
|
55
|
+
if ax:
|
|
56
|
+
ax.legend(*scatter.legend_elements("sizes"))
|
|
57
|
+
else:
|
|
58
|
+
plt.legend(*scatter.legend_elements("sizes"))
|
|
52
59
|
|
|
53
60
|
edge_kwargs = edge_kwargs or {}
|
|
54
61
|
edges_attrs = graph_edges_attributes(G)
|
|
@@ -64,11 +71,11 @@ def plot_nx_graph_reasonably(
|
|
|
64
71
|
edge_kwargs["edge_cmap"] = None
|
|
65
72
|
else:
|
|
66
73
|
edge_kwargs["edge_color"] = edge_kwargs.get(
|
|
67
|
-
"edge_color", [math.log(d
|
|
74
|
+
"edge_color", [math.log(d.get("weight", 1)) for *_, d in G.edges.data()]
|
|
68
75
|
)
|
|
69
76
|
edge_kwargs["edge_cmap"] = edge_kwargs.get("edge_cmap", plt.get_cmap("viridis"))
|
|
70
77
|
edge_kwargs["width"] = edge_kwargs.get(
|
|
71
|
-
"width", [1 + math.log(d
|
|
78
|
+
"width", [1 + math.log(d.get("weight", 1)) for _, _, d in G.edges.data()]
|
|
72
79
|
)
|
|
73
80
|
edge_kwargs["alpha"] = edge_kwargs.get("alpha", 0.35)
|
|
74
81
|
nx.draw_networkx_edges(G, pos, ax=ax, **edge_kwargs)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from renard.resources.determiners.determiners import *
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# Determiners, keyed by three-letter language code. Each value is a set
# of determiner strings for that language.
singular_determiners = {
    "eng": {
        # articles
        "a",
        "the",
        # quantifiers
        "some",
        "such",
        # demonstratives
        "this",
        "that",
        # possessives
        "my",
        "your",
        "his",
        "her",
        "its",
        "our",
        "their",
    },
    "fra": {
        # articles. "l '" is presumably the tokenized form of the
        # elided article "l'" -- TODO confirm against the tokenizer.
        # NOTE(review): "les" is a plural article, which is surprising
        # in a set named singular_determiners; confirm it is intended.
        "le",
        "la",
        "les",
        "l '",
        "un",
        "une",
        # partitive / contracted forms
        "du",
        "de",
        "de la",
        "au",
        # NOTE(review): "à" is a preposition rather than a determiner.
        "à",
        # demonstratives
        "ce",
        "cette",
        # possessives
        "mon",
        "ma",
        "ton",
        "ta",
        "son",
        "sa",
        "notre",
        "votre",
        "leur",
    },
}
|
|
@@ -1,37 +1,38 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: renard-pipeline
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.6.1
|
|
4
4
|
Summary: Relationships Extraction from NARrative Documents
|
|
5
5
|
Home-page: https://github.com/CompNet/Renard
|
|
6
6
|
License: GPL-3.0-only
|
|
7
7
|
Author: Arthur Amalvy
|
|
8
8
|
Author-email: arthur.amalvy@univ-avignon.fr
|
|
9
|
-
Requires-Python: >=3.8,<3.
|
|
9
|
+
Requires-Python: >=3.8,<3.12
|
|
10
10
|
Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
|
|
11
11
|
Classifier: Programming Language :: Python :: 3
|
|
12
12
|
Classifier: Programming Language :: Python :: 3.8
|
|
13
13
|
Classifier: Programming Language :: Python :: 3.9
|
|
14
14
|
Classifier: Programming Language :: Python :: 3.10
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
15
16
|
Provides-Extra: spacy
|
|
16
17
|
Provides-Extra: stanza
|
|
17
|
-
Requires-Dist: coreferee (>=1.4
|
|
18
|
-
Requires-Dist: datasets (>=
|
|
19
|
-
Requires-Dist: grimbert (>=0.1
|
|
20
|
-
Requires-Dist: matplotlib (>=3.5
|
|
21
|
-
Requires-Dist: more-itertools (>=10.
|
|
22
|
-
Requires-Dist: nameparser (>=1.1
|
|
23
|
-
Requires-Dist: networkx (>=
|
|
24
|
-
Requires-Dist: nltk (>=3.
|
|
25
|
-
Requires-Dist: pandas (>=2.0
|
|
26
|
-
Requires-Dist: pytest (>=
|
|
27
|
-
Requires-Dist:
|
|
28
|
-
Requires-Dist: spacy (>=3.5
|
|
29
|
-
Requires-Dist: spacy-transformers (>=1.
|
|
30
|
-
Requires-Dist: stanza (>=1.3
|
|
31
|
-
Requires-Dist: tibert (>=0.
|
|
18
|
+
Requires-Dist: coreferee (>=1.4,<2.0) ; extra == "spacy"
|
|
19
|
+
Requires-Dist: datasets (>=3.0,<4.0)
|
|
20
|
+
Requires-Dist: grimbert (>=0.1,<0.2)
|
|
21
|
+
Requires-Dist: matplotlib (>=3.5,<4.0)
|
|
22
|
+
Requires-Dist: more-itertools (>=10.5,<11.0)
|
|
23
|
+
Requires-Dist: nameparser (>=1.1,<2.0)
|
|
24
|
+
Requires-Dist: networkx (>=3.0,<4.0)
|
|
25
|
+
Requires-Dist: nltk (>=3.9,<4.0)
|
|
26
|
+
Requires-Dist: pandas (>=2.0,<3.0)
|
|
27
|
+
Requires-Dist: pytest (>=8.3.0,<9.0.0)
|
|
28
|
+
Requires-Dist: rank-bm25 (>=0.2.2,<0.3.0)
|
|
29
|
+
Requires-Dist: spacy (>=3.5,<4.0) ; extra == "spacy"
|
|
30
|
+
Requires-Dist: spacy-transformers (>=1.3,<2.0) ; extra == "spacy"
|
|
31
|
+
Requires-Dist: stanza (>=1.3,<2.0) ; extra == "stanza"
|
|
32
|
+
Requires-Dist: tibert (>=0.5,<0.6)
|
|
32
33
|
Requires-Dist: torch (>=2.0.0,!=2.0.1)
|
|
33
34
|
Requires-Dist: tqdm (>=4.62.3,<5.0.0)
|
|
34
|
-
Requires-Dist: transformers (>=4.36
|
|
35
|
+
Requires-Dist: transformers (>=4.36,<5.0)
|
|
35
36
|
Project-URL: Documentation, https://compnet.github.io/Renard/
|
|
36
37
|
Project-URL: Repository, https://github.com/CompNet/Renard
|
|
37
38
|
Description-Content-Type: text/markdown
|
|
@@ -40,7 +41,7 @@ Description-Content-Type: text/markdown
|
|
|
40
41
|
|
|
41
42
|
[](https://doi.org/10.21105/joss.06574)
|
|
42
43
|
|
|
43
|
-
Renard (
|
|
44
|
+
Renard (Relationship Extraction from NARrative Documents) is a library for creating and using custom character network extraction pipelines. Renard can extract dynamic as well as static character networks.
|
|
44
45
|
|
|
45
46
|

|
|
46
47
|
|
|
@@ -51,7 +52,7 @@ You can install the latest version using pip:
|
|
|
51
52
|
|
|
52
53
|
> pip install renard-pipeline
|
|
53
54
|
|
|
54
|
-
Currently, Renard supports Python
|
|
55
|
+
Currently, Renard supports Python>=3.8,<=3.11
|
|
55
56
|
|
|
56
57
|
|
|
57
58
|
# Documentation
|
|
@@ -1,15 +1,17 @@
|
|
|
1
1
|
renard/gender.py,sha256=HDtJQKOqIkV8F-Mxva95XFXWJoKRKckQ3fc93OBM6sw,102
|
|
2
2
|
renard/graph_utils.py,sha256=EV0_56KtI3VOElCu7wxd2kL8QVPsOu7itE6wGJAJsNA,6073
|
|
3
|
-
renard/ner_utils.py,sha256=
|
|
3
|
+
renard/ner_utils.py,sha256=SFZoyJM6c2avE7-NDkCSzkx-O8ppzS00a8EyHt64iGI,11628
|
|
4
4
|
renard/nltk_utils.py,sha256=mUJiwMrEDZV4Fla7WuMR-hA_OC2ZIwSXgW_0Ew18VSo,977
|
|
5
5
|
renard/pipeline/__init__.py,sha256=8Yim2mmny8YGvM7N5-na5zK-C9UDxUb77K9ml-VirUA,35
|
|
6
|
-
renard/pipeline/character_unification.py,sha256=
|
|
6
|
+
renard/pipeline/character_unification.py,sha256=SsMaBHfGgRAvZyYbVcm6pxnIqHqD_JyQndGvwSjsGCc,17074
|
|
7
7
|
renard/pipeline/characters_extraction.py,sha256=bMic8dtlYKUmAlTzQqDPraYy5VsGWoGkho35mA8w3_Y,396
|
|
8
|
-
renard/pipeline/core.py,sha256=
|
|
8
|
+
renard/pipeline/core.py,sha256=LILUIQZp9f3FzqjBocUS7dKzX7lHQQVdL29jyqU1UeY,27754
|
|
9
9
|
renard/pipeline/corefs/__init__.py,sha256=9c9AaXBcRrDBf1jhTtJ7DyjOJhX_Zej3FjlcGak7MK8,44
|
|
10
|
-
renard/pipeline/corefs/corefs.py,sha256=
|
|
10
|
+
renard/pipeline/corefs/corefs.py,sha256=d47Sd8ekwhQQV6rQ0F9QyAX2GOTqUnkDUA-eKgMtMS4,11417
|
|
11
11
|
renard/pipeline/graph_extraction.py,sha256=Ga3wfUW9tDtatcTv2taLrNky9jz2wUwZ8uzoXJoSVk8,22928
|
|
12
|
-
renard/pipeline/ner.py,sha256=
|
|
12
|
+
renard/pipeline/ner/__init__.py,sha256=Dqxcf_EKhK1UwiCscZ3gGHInlcxJyvpR4o-ZCLEyV48,38
|
|
13
|
+
renard/pipeline/ner/ner.py,sha256=8zUtaqaGNirfGFRyMpDzdqtO3abrRLyLtjmwnqBNwUI,9893
|
|
14
|
+
renard/pipeline/ner/retrieval.py,sha256=JIU3fi0Q1gl_YGP6kYx6zC9xz4UN6gnqdVuzWVXzzyM,12853
|
|
13
15
|
renard/pipeline/preconfigured.py,sha256=j4-0OUZrmtC8rQfwGWEAAGNxc8-4hlY7N823Uami5lk,5392
|
|
14
16
|
renard/pipeline/preprocessing.py,sha256=OsdsYzmRweAiQV_CtP7uiz--OGogZtQlsdR8XX5DCk0,952
|
|
15
17
|
renard/pipeline/progress.py,sha256=PJ174ssaqr5qHaTrVQ8HqJtvpvX6QhtHM5PHT893_Xk,2689
|
|
@@ -17,9 +19,11 @@ renard/pipeline/quote_detection.py,sha256=FyldJhynIT843fB7rwVtHmDZJqTKkjGml6qTLj
|
|
|
17
19
|
renard/pipeline/sentiment_analysis.py,sha256=76MPin4L1-vSswJe5yGrbCSSDim1LYxSEgNj_BdQDvk,1464
|
|
18
20
|
renard/pipeline/speaker_attribution.py,sha256=Uts6JdUo_sbWyIb2AJ6SO5JuUbgROIpcbUNTg4dHo4U,4329
|
|
19
21
|
renard/pipeline/stanford_corenlp.py,sha256=14b6Ee6oPz1EL-bNRT688aNxVTk_Jwa_vJ20FiBODC4,8189
|
|
20
|
-
renard/pipeline/tokenization.py,sha256=
|
|
21
|
-
renard/plot_utils.py,sha256=
|
|
22
|
+
renard/pipeline/tokenization.py,sha256=gZP0ZpAa0rhtUDPk6W0PiXRxmiC3IcSyRF_E7KaP19A,2957
|
|
23
|
+
renard/plot_utils.py,sha256=qsQI-wbk_5KCXDvt1tPerq4UW4VWLrJpoCet4qkONwE,3344
|
|
22
24
|
renard/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
25
|
+
renard/resources/determiners/__init__.py,sha256=dAcx2hWb_aAd5Rv9rif7CQOvjKcSdIY_mCXJBQQtw60,55
|
|
26
|
+
renard/resources/determiners/determiners.py,sha256=lQ5XGmKWK8h6dcBp0tB2TcEJbkQ9KCHkACJ_gqWjexU,594
|
|
23
27
|
renard/resources/hypocorisms/__init__.py,sha256=vlsY9PqxQCIpijxm79Y0KYh2c0S4S1pgrC9w-AUQGvE,55
|
|
24
28
|
renard/resources/hypocorisms/datas/License.txt,sha256=tAkwu8-AdEyGxGoSvJ2gVmQdcicWw3j1ZZueVV74M-E,11357
|
|
25
29
|
renard/resources/hypocorisms/datas/hypocorisms.csv,sha256=CKTo7A5i14NzN6JRBz7U2NJnxrEo8VOlmmdhzEZnqlI,21470
|
|
@@ -29,7 +33,7 @@ renard/resources/pronouns/pronouns.py,sha256=YJ8hM6H8QHrF2Xx6O5blqc-Sqe1D1YFL0sR
|
|
|
29
33
|
renard/resources/titles/__init__.py,sha256=Jcg4B7stsWiAaXbFgNl_L3ICtCQmFe9bo3YjdkVL50w,45
|
|
30
34
|
renard/resources/titles/titles.py,sha256=GsFccVJuTkgDWiAqWZpFd2R9pGvFKQZBOk4RWWuWDkw,968
|
|
31
35
|
renard/utils.py,sha256=WL6djr3iu5Kzo2Jq6qDllHXgvZcEnmqBxPkQf1drq7c,4072
|
|
32
|
-
renard_pipeline-0.
|
|
33
|
-
renard_pipeline-0.
|
|
34
|
-
renard_pipeline-0.
|
|
35
|
-
renard_pipeline-0.
|
|
36
|
+
renard_pipeline-0.6.1.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
37
|
+
renard_pipeline-0.6.1.dist-info/METADATA,sha256=vijGA3DMBq0Tkn2SJxMKacOw8zI5Z4IDSmIBWBuMEuM,4374
|
|
38
|
+
renard_pipeline-0.6.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
39
|
+
renard_pipeline-0.6.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|