renard-pipeline 0.3.1__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- renard/ner_utils.py +304 -42
- renard/pipeline/character_unification.py +10 -11
- renard/pipeline/characters_extraction.py +1 -1
- renard/pipeline/core.py +51 -34
- renard/pipeline/graph_extraction.py +7 -10
- renard/pipeline/ner.py +79 -58
- renard/pipeline/stanford_corenlp.py +1 -1
- renard/py.typed +0 -0
- renard/utils.py +1 -52
- {renard_pipeline-0.3.1.dist-info → renard_pipeline-0.4.1.dist-info}/METADATA +42 -4
- {renard_pipeline-0.3.1.dist-info → renard_pipeline-0.4.1.dist-info}/RECORD +13 -12
- {renard_pipeline-0.3.1.dist-info → renard_pipeline-0.4.1.dist-info}/LICENSE +0 -0
- {renard_pipeline-0.3.1.dist-info → renard_pipeline-0.4.1.dist-info}/WHEEL +0 -0
renard/ner_utils.py
CHANGED
@@ -1,9 +1,26 @@
-from …
+from __future__ import annotations
+from typing import TYPE_CHECKING, List, Optional, Union, Dict, Tuple
+import os, re
+import itertools as it
+import functools as ft
+from more_itertools import flatten
 import torch
 from torch.utils.data import Dataset
-from …
+from datasets import Dataset as HGDataset
+from datasets import Sequence, ClassLabel
+from transformers import (
+    AutoModelForTokenClassification,
+    AutoTokenizer,
+    PreTrainedTokenizerFast,
+    PreTrainedModel,
+    Trainer,
+    TrainingArguments,
+)
 from transformers.tokenization_utils_base import BatchEncoding

+if TYPE_CHECKING:
+    from renard.pipeline.ner import NEREntity
+

 class DataCollatorForTokenClassificationWithBatchEncoding:
     """Same as ``transformers.DataCollatorForTokenClassification``,
@@ -20,61 +37,306 @@ class DataCollatorForTokenClassificationWithBatchEncoding:
     ) -> None:
         self.tokenizer = tokenizer
         self.pad_to_multiple_of = pad_to_multiple_of
-        self.…
-…
-    def __call__(self, features) -> Union[dict, BatchEncoding]:
-…
-…
-…
-…
-…
-        batch = …
-…
-…
-…
-…
-…
+        self.pad_token_id = {"label": -100, "labels": -100}
+
+    def __call__(self, features: List[dict]) -> Union[dict, BatchEncoding]:
+        keys = features[0].keys()
+        sequence_len = max([len(f["input_ids"]) for f in features])
+
+        # We do the padding and collating manually instead of calling
+        # self.tokenizer.pad, because pad does not work on arbitrary
+        # features.
+        batch = BatchEncoding({})
+        for key in keys:
+            if self.tokenizer.padding_side == "right":
+                batch[key] = [
+                    f[key]
+                    + [self.pad_token_id.get(key, 0)] * (sequence_len - len(f[key]))
+                    for f in features
+                ]
+            else:
+                batch[key] = [
+                    [
+                        self.pad_token_id.get(key, 0) * (sequence_len - len(f[key]))
+                        + f[key]
+                        for f in features
+                    ]
+                ]
+
         batch._encodings = [f.encodings[0] for f in features]

-…
-…
-…
-        sequence_length = torch.tensor(batch["input_ids"]).shape[1]
-        padding_side = self.tokenizer.padding_side
-        if padding_side == "right":
-            batch[label_name] = [
-                list(label) + [self.label_pad_token_id] * (sequence_length - len(label))
-                for label in labels
-            ]
-        else:
-            batch[label_name] = [
-                [self.label_pad_token_id] * (sequence_length - len(label)) + list(label)
-                for label in labels
-            ]
+        for k, v in batch.items():
+            batch[k] = torch.tensor(v)

         return batch


 class NERDataset(Dataset):
+    """
+    :ivar _context_mask: for each element, a mask indicating which
+        tokens are part of the context (1 for context, 0 for text on
+        which to perform inference). The mask allows to discard
+        predictions made for context at inference time, even though
+        the context can still be passed as input to the model.
+    """
+
     def __init__(
-        self,…
+        self,
+        elements: List[List[str]],
+        tokenizer: PreTrainedTokenizerFast,
+        context_mask: Optional[List[List[int]]] = None,
     ) -> None:
-        self.…
+        self.elements = elements
+
+        if context_mask:
+            assert all(
+                [len(cm) == len(elt) for elt, cm in zip(self.elements, context_mask)]
+            )
+        self._context_mask = context_mask or [[0] * len(elt) for elt in self.elements]
+
         self.tokenizer = tokenizer

-    def __getitem__(self, index) -> BatchEncoding:
+    def __getitem__(self, index: Union[int, List[int]]) -> BatchEncoding:
+        element = self.elements[index]
+
         batch = self.tokenizer(
-…
-            return_tensors="pt",
-            padding=True,
+            element,
             truncation=True,
+            max_length=512,  # TODO
             is_split_into_words=True,
         )
-…
-…
+
+        batch["context_mask"] = [0] * len(batch["input_ids"])
+        elt_context_mask = self._context_mask[index]
+        for i in range(len(element)):
+            w2t = batch.word_to_tokens(0, i)
+            mask_value = elt_context_mask[i]
+            tokens_mask = [mask_value] * (w2t.end - w2t.start)
+            batch["context_mask"][w2t.start : w2t.end] = tokens_mask
+
         return batch

     def __len__(self) -> int:
-        return len(self.…
+        return len(self.elements)
+
+
+def ner_entities(
+    tokens: List[str], bio_tags: List[str], resolve_inconsistencies: bool = True
+) -> List[NEREntity]:
+    """Extract NER entities from a list of BIO tags
+
+    :param tokens: a list of tokens
+    :param bio_tags: a list of BIO tags. In particular, BIO tags
+        should be in the CoNLL-2002 form (such as 'B-PER I-PER')
+
+    :return: A list of ner entities, in apparition order
+    """
+    from renard.pipeline.ner import NEREntity
+
+    assert len(tokens) == len(bio_tags)
+
+    entities = []
+    current_tag: Optional[str] = None
+    current_tag_start_idx: Optional[int] = None
+
+    for i, tag in enumerate(bio_tags):
+        if not current_tag is None and not tag.startswith("I-"):
+            assert not current_tag_start_idx is None
+            entities.append(
+                NEREntity(
+                    tokens[current_tag_start_idx:i],
+                    current_tag_start_idx,
+                    i,
+                    current_tag,
+                )
+            )
+            current_tag = None
+            current_tag_start_idx = None
+
+        if tag.startswith("B-"):
+            current_tag = tag[2:]
+            current_tag_start_idx = i
+
+        elif tag.startswith("I-"):
+            if current_tag is None and resolve_inconsistencies:
+                current_tag = tag[2:]
+                current_tag_start_idx = i
+                continue
+
+    if not current_tag is None:
+        assert not current_tag_start_idx is None
+        entities.append(
+            NEREntity(
+                tokens[current_tag_start_idx : len(tokens)],
+                current_tag_start_idx,
+                len(bio_tags),
+                current_tag,
+            )
+        )
+
+    return entities
+
+
+def load_conll2002_bio(
+    path: str,
+    tag_conversion_map: Optional[Dict[str, str]] = None,
+    separator: str = "\t",
+    **kwargs,
+) -> Tuple[List[List[str]], List[str], List[NEREntity]]:
+    """Load a file under CoNLL-2002 BIO format. Sentences are expected
+    to be separated by end of lines. Tags should be in the CoNLL-2002
+    format (such as 'B-PER I-PER') - If this is not the case, see the
+    ``tag_conversion_map`` argument.
+
+    :param path: path to the CoNLL-2002 formatted file
+    :param separator: separator between token and BIO tags
+    :param tag_conversion_map: conversion map for tags found in the
+        input file. Example : ``{'B': 'B-PER', 'I': 'I-PER'}``
+    :param kwargs: additional kwargs for ``open`` (such as
+        ``encoding`` or ``newline``).
+
+    :return: ``(sentences, tokens, entities)``
+    """
+    tag_conversion_map = tag_conversion_map or {}
+
+    with open(os.path.expanduser(path), **kwargs) as f:
+        raw_data = f.read()
+
+    sents = []
+    sent_tokens = []
+    tags = []
+    for line in raw_data.split("\n"):
+        line = line.strip("\n")
+        if re.fullmatch(r"\s*", line):
+            if len(sent_tokens) == 0:
+                continue
+            sents.append(sent_tokens)
+            sent_tokens = []
+            continue
+        token, tag = line.split(separator)
+        sent_tokens.append(token)
+        tags.append(tag_conversion_map.get(tag, tag))
+
+    tokens = list(flatten(sents))
+    entities = ner_entities(tokens, tags)
+
+    return sents, list(flatten(sents)), entities
+
+
+def hgdataset_from_conll2002(
+    path: str,
+    tag_conversion_map: Optional[Dict[str, str]] = None,
+    separator: str = "\t",
+    **kwargs,
+) -> HGDataset:
+    """Load a CoNLL-2002 file as a Huggingface Dataset.
+
+    :param path: passed to :func:`.load_conll2002_bio`
+    :param tag_conversion_map: passed to :func:`load_conll2002_bio`
+    :param separator: passed to :func:`load_conll2002_bio`
+    :param kwargs: passed to :func:`load_conll2002_bio`
+
+    :return: a :class:`datasets.Dataset` with features 'tokens' and 'labels'.
+    """
+    sentences, tokens, entities = load_conll2002_bio(
+        path, tag_conversion_map, separator, **kwargs
+    )
+
+    # convert entities to labels
+    tags = ["O"] * len(tokens)
+    for entity in entities:
+        entity_len = entity.end_idx - entity.start_idx
+        tags[entity.start_idx : entity.end_idx] = [f"B-{entity.tag}"] + [
+            f"I-{entity.tag}"
+        ] * (entity_len - 1)
+
+    # cut into sentences
+    sent_ends = list(it.accumulate([len(s) for s in sentences]))
+    sent_starts = [0] + sent_ends[:-1]
+    sent_tags = [
+        tags[sent_start:sent_end]
+        for sent_start, sent_end in zip(sent_starts, sent_ends)
+    ]
+
+    dataset = HGDataset.from_dict({"tokens": sentences, "labels": sent_tags})
+    dataset = dataset.cast_column(
+        "labels", Sequence(ClassLabel(names=sorted(set(tags))))
+    )
+    return dataset
+
+
+def _tokenize_and_align_labels(
+    examples, tokenizer: PreTrainedTokenizerFast, label_all_tokens: bool = True
+):
+    """Adapted from https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/token_classification.ipynb#scrollTo=vc0BSBLIIrJQ
+
+    :param examples: an object with keys 'tokens' and 'labels'
+    """
+    tokenized_inputs = tokenizer(
+        examples["tokens"], truncation=True, is_split_into_words=True
+    )
+
+    labels = []
+    for i, label in enumerate(examples[f"labels"]):
+        word_ids = tokenized_inputs.word_ids(batch_index=i)
+        previous_word_idx = None
+        label_ids = []
+        for word_idx in word_ids:
+            # Special tokens have a word id that is None. We set the
+            # label to -100 so they are automatically ignored in the
+            # loss function.
+            if word_idx is None:
+                label_ids.append(-100)
+            # We set the label for the first token of each word.
+            elif word_idx != previous_word_idx:
+                label_ids.append(label[word_idx])
+            # For the other tokens in a word, we set the label to
+            # either the current label or -100, depending on the
+            # label_all_tokens flag.
+            else:
+                label_ids.append(label[word_idx] if label_all_tokens else -100)
+            previous_word_idx = word_idx
+
+        labels.append(label_ids)
+
+    tokenized_inputs["labels"] = labels
+
+    return tokenized_inputs
+
+
+def train_ner_model(
+    hg_id: str,
+    dataset: HGDataset,
+    targs: TrainingArguments,
+) -> PreTrainedModel:
+    from transformers import DataCollatorForTokenClassification
+
+    # BERT tokenizer splits tokens into subtokens. The
+    # tokenize_and_align_labels function correctly aligns labels and
+    # subtokens.
+    tokenizer = AutoTokenizer.from_pretrained(hg_id)
+    dataset = dataset.map(
+        ft.partial(_tokenize_and_align_labels, tokenizer=tokenizer), batched=True
+    )
+    dataset = dataset.train_test_split(test_size=0.1)
+
+    label_lst = dataset["train"].features["labels"].feature.names
+    model = AutoModelForTokenClassification.from_pretrained(
+        hg_id,
+        num_labels=len(label_lst),
+        id2label={i: label for i, label in enumerate(label_lst)},
+        label2id={label: i for i, label in enumerate(label_lst)},
+    )
+
+    trainer = Trainer(
+        model,
+        targs,
+        train_dataset=dataset["train"],
+        eval_dataset=dataset["test"],
+        # data_collator=DataCollatorForTokenClassificationWithBatchEncoding(tokenizer),
+        data_collator=DataCollatorForTokenClassification(tokenizer),
+        tokenizer=tokenizer,
+    )
+    trainer.train()
+
+    return model
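The new `ner_utils` module now owns the CoNLL-2002 helpers and adds dataset/training utilities. A minimal usage sketch of the functions added above, assuming a local CoNLL-2002 file at `./my_corpus.conll` (hypothetical path; the base model ID is just an example checkpoint):

```python
from transformers import TrainingArguments
from renard.ner_utils import (
    ner_entities,
    hgdataset_from_conll2002,
    train_ner_model,
)

# decode BIO tags into NEREntity objects
tokens = ["Emma", "Woodhouse", "lived", "at", "Hartfield"]
tags = ["B-PER", "I-PER", "O", "O", "B-LOC"]
entities = ner_entities(tokens, tags)  # two entities: "Emma Woodhouse" and "Hartfield"

# load a CoNLL-2002 file as a Huggingface dataset and fine-tune a model on it
dataset = hgdataset_from_conll2002("./my_corpus.conll", separator="\t")
model = train_ner_model(
    "bert-base-cased",  # any token-classification checkpoint on the HF hub
    dataset,
    TrainingArguments(output_dir="./checkpoints"),
)
```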
renard/pipeline/character_unification.py
CHANGED

@@ -54,8 +54,8 @@ def _assign_coreference_mentions(
     :param corefs:
     """

-    char_mentions: Dict[Character, …
-        character: character.mentions for character in characters
+    char_mentions: Dict[Character, Set[Mention]] = {
+        character: set(character.mentions) for character in characters
     }

     # we assign each chain to the character with highest name
@@ -80,12 +80,12 @@ def _assign_coreference_mentions(

         # assign the chain to the character with the most occurences
         for mention in chain:
-            # TODO: complexity
             if not mention in char_mentions[best_character]:
-                char_mentions[best_character].…
+                char_mentions[best_character].add(mention)

     return [
-        Character(c.names, mentions, …
+        Character(c.names, sorted(mentions, key=lambda m: m.start_idx), c.gender)
+        for c, mentions in char_mentions.items()
     ]

@@ -209,7 +209,6 @@ class GraphRulesCharacterUnifier(PipelineStep):

         # * link nodes based on several rules
         for name1, name2 in combinations(G.nodes(), 2):
-
             # is one name a known hypocorism of the other ? (also
             # checks if both names are the same)
             if self.hypocorism_gazetteer.are_related(name1, name2):
@@ -263,7 +262,6 @@ class GraphRulesCharacterUnifier(PipelineStep):
             pass

         for name1, name2 in combinations(G.nodes(), 2):
-
             # check if characters have the same last name but a
             # different first name.
             human_name1 = HumanName(name1, constants=hname_constants)
@@ -333,10 +331,11 @@ class GraphRulesCharacterUnifier(PipelineStep):
         self, name1: str, name2: str, hname_constants: Constants
     ) -> bool:
         """Check if two names are related after removing their titles"""
-…
-…
-        raw_name1 = HumanName(name1, constants=…
-        raw_name2 = HumanName(name2, constants=…
+        old_string_format = hname_constants.string_format
+        hname_constants.string_format = "{first} {middle} {last}"
+        raw_name1 = HumanName(name1, constants=hname_constants).full_name
+        raw_name2 = HumanName(name2, constants=hname_constants).full_name
+        hname_constants.string_format = old_string_format

         if raw_name1 == "" or raw_name2 == "":
             return False
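For reference, a minimal sketch of the title-stripping trick the last hunk relies on, using the `nameparser` package's `Constants` and `HumanName` exactly as the diffed code does (the example names are made up):

```python
from nameparser import HumanName
from nameparser.config import Constants

constants = Constants()
# render names without the {title} field, so titles are dropped
constants.string_format = "{first} {middle} {last}"

# both should render to the same raw name, "John Smith"
print(HumanName("Mr. John Smith", constants=constants).full_name)
print(HumanName("John Smith", constants=constants).full_name)
```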
renard/pipeline/characters_extraction.py
CHANGED

@@ -1,7 +1,7 @@
 import renard.pipeline.character_unification as cu

 print(
-    "[warning] the characters_extraction module is deprecated. Use …
+    "[warning] the characters_extraction module is deprecated. Use character_unification instead."
 )

 Character = cu.Character
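Since the module now only re-exports names from `character_unification`, old imports keep working (with the warning above printed at import time):

```python
# deprecated path, still functional through the alias:
from renard.pipeline.characters_extraction import Character
# preferred path as of 0.4.x:
from renard.pipeline.character_unification import Character
```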
renard/pipeline/core.py
CHANGED
@@ -16,7 +16,7 @@ from typing import (
     Type,
     TYPE_CHECKING,
 )
-import os
+import os, sys

 import networkx as nx
 from networkx.readwrite.gexf import GEXFWriter
@@ -50,6 +50,13 @@ class Mention:
         self_dict["end_idx"] = self.end_idx + shift
         return self.__class__(**self_dict)

+    def __eq__(self, other: Mention) -> bool:
+        return (
+            self.tokens == other.tokens
+            and self.start_idx == other.start_idx
+            and self.end_idx == other.end_idx
+        )
+
     def __hash__(self) -> int:
         return hash(tuple(self.tokens) + (self.start_idx, self.end_idx))

@@ -171,8 +178,18 @@ class PipelineState:
     #: detected characters
     characters: Optional[List[Character]] = None

-    #: …
-    …
+    #: character network (or list of network in the case of a dynamic
+    #: network)
+    character_network: Optional[Union[List[nx.Graph], nx.Graph]] = None
+
+    def get_characters_graph(self) -> Optional[Union[List[nx.Graph], nx.Graph]]:
+        print(
+            "[warning] the characters_graph attribute is deprecated, use character_network instead",
+            file=sys.stderr,
+        )
+        return self.character_network
+
+    characters_graph = property(get_characters_graph)

     def get_character(
         self, name: str, partial_match: bool = True
@@ -228,8 +245,8 @@ class PipelineState:
         for more details
         """
         path = os.path.expanduser(path)
-        if isinstance(self.…
-            G = dynamic_graph_to_gephi_graph(self.…
+        if isinstance(self.character_network, list):
+            G = dynamic_graph_to_gephi_graph(self.character_network)
             G = graph_with_names(G, name_style)
             # HACK: networkx cannot set a dynamic "weight" attribute
             # in gexf since "weight" has a specific meaning in
@@ -251,7 +268,7 @@ class PipelineState:
                 attvalue.set("for", "weight")
             writer.write(path)
         else:
-            G = graph_with_names(self.…
+            G = graph_with_names(self.character_network, name_style)
             nx.write_gexf(G, path)

     def plot_graphs_to_dir(
@@ -280,23 +297,23 @@ class PipelineState:
         """
         import matplotlib.pyplot as plt

-        assert not self.…
-        if isinstance(self.…
+        assert not self.character_network is None
+        if isinstance(self.character_network, nx.Graph):
             raise ValueError("this function is supposed to be used on a dynamic graph")

         directory = directory.rstrip("/")
         directory = os.path.expanduser(directory)
         os.makedirs(directory, exist_ok=True)

-        graphs = self.…
+        graphs = self.character_network
         if cumulative:
-            graphs = cumulative_graph(self.…
+            graphs = cumulative_graph(self.character_network)

         if stable_layout:
             layout_graph = (
                 graphs[-1]
                 if cumulative
-                else cumulative_graph(self.…
+                else cumulative_graph(self.character_network)[-1]
             )
             layout = layout_nx_graph_reasonably(layout_graph)

@@ -330,13 +347,13 @@ class PipelineState:
         """
         import matplotlib.pyplot as plt

-        assert not self.…
-        if isinstance(self.…
+        assert not self.character_network is None
+        if isinstance(self.character_network, list):
             raise ValueError("this function is supposed to be used on a static graph")

         if not layout is None:
-            layout = layout_with_names(self.…
-        G = graph_with_names(self.…
+            layout = layout_with_names(self.character_network, layout, name_style)
+        G = graph_with_names(self.character_network, name_style=name_style)
         if fig is None:
             # default values for a sufficiently sized graph
             fig = plt.gcf()
@@ -359,7 +376,7 @@ class PipelineState:
         stable_layout: bool = False,
         layout: Optional[CharactersGraphLayout] = None,
     ):
-        """Plot ``self.…
+        """Plot ``self.character_network`` using reasonable default
         parameters

         .. note::
@@ -372,13 +389,13 @@ class PipelineState:
             details
         :param fig: if specified, this matplotlib figure will be used
             for plotting
-        :param cumulative: if ``True`` and ``self.…
+        :param cumulative: if ``True`` and ``self.character_network``
             is dynamic, plot a cumulative graph instead of a
             sequential one
-        :param graph_start_idx: When ``self.…
+        :param graph_start_idx: When ``self.character_network`` is
             dynamic, index of the first graph to plot, starting at 1
             (not 0, since the graph slider starts at 1)
-        :param stable_layout: if ``self.…
+        :param stable_layout: if ``self.character_network`` is dynamic
             and this parameter is ``True``, characters will keep the
             same position in space at each timestep. Characters'
             positions are based on the final cumulative graph layout.
@@ -387,13 +404,13 @@ class PipelineState:
         import matplotlib.pyplot as plt
         from matplotlib.widgets import Slider

-        assert not self.…
+        assert not self.character_network is None

-        # self.…
-        if isinstance(self.…
+        # self.character_network is a static graph
+        if isinstance(self.character_network, nx.Graph):
             if not layout is None:
-                layout = layout_with_names(self.…
-            G = graph_with_names(self.…
+                layout = layout_with_names(self.character_network, layout, name_style)
+            G = graph_with_names(self.character_network, name_style)
             if fig is None:
                 # default value for a sufficiently sized graph
                 fig = plt.gcf()
@@ -404,9 +421,9 @@ class PipelineState:
             plot_nx_graph_reasonably(G, ax=ax, layout=layout)
             return

-        if not isinstance(self.…
+        if not isinstance(self.character_network, list):
             raise TypeError
-        # self.…
+        # self.character_network is a list: plot a dynamic graph

         if fig is None:
             fig, ax = plt.subplots()
@@ -417,18 +434,18 @@ class PipelineState:
             ax = fig.add_subplot(111)
         assert not fig is None

-        …
+        cumulative_character_networks = cumulative_graph(self.character_network)
         if stable_layout:
-            layout = layout_nx_graph_reasonably(…
+            layout = layout_nx_graph_reasonably(cumulative_character_networks[-1])

         def update(slider_value):
-            assert isinstance(self.…
+            assert isinstance(self.character_network, list)

-            …
+            character_networks = self.character_network
             if cumulative:
-                …
+                character_networks = cumulative_character_networks

-            G = …
+            G = character_networks[int(slider_value) - 1]

             local_layout = layout
             if not local_layout is None:
@@ -447,8 +464,8 @@ class PipelineState:
                 ax=slider_ax,
                 label="Graph",
                 valmin=1,
-                valmax=len(self.…
-                valstep=[i + 1 for i in range(len(self.…
+                valmax=len(self.character_network),
+                valstep=[i + 1 for i in range(len(self.character_network))],
             )
             fig.slider.on_changed(update)  # type: ignore
             fig.slider.set_val(graph_start_idx)  # type: ignore
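In short, `PipelineState.characters_graph` is renamed to `character_network`, with a deprecation property kept for backward compatibility. A sketch of the resulting behaviour (`out` stands for any pipeline output, as in the README tutorial further down):

```python
# new canonical attribute: an nx.Graph, or a List[nx.Graph] for dynamic networks
G = out.character_network

# the old attribute still resolves to the same object, but prints a
# deprecation warning on stderr
assert out.characters_graph is G
```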
renard/pipeline/graph_extraction.py
CHANGED

@@ -158,7 +158,7 @@ class CoOccurrencesGraphExtractor(PipelineStep):

         :param characters:

-        :return: a ``dict`` with key ``'…
+        :return: a ``dict`` with key ``'character_network'`` and a
             :class:`nx.Graph` or a list of :class:`nx.Graph` as
             value.
         """
@@ -170,7 +170,7 @@ class CoOccurrencesGraphExtractor(PipelineStep):

         if self.dynamic:
             return {
-                "…
+                "character_network": self._extract_dynamic_graph(
                     mentions,
                     self.dynamic_window,
                     self.dynamic_overlap,
@@ -180,7 +180,7 @@ class CoOccurrencesGraphExtractor(PipelineStep):
             )
         }
         return {
-            "…
+            "character_network": self._extract_graph(
                 mentions, sentences, sentences_polarities
             )
         }
@@ -419,7 +419,7 @@ class CoOccurrencesGraphExtractor(PipelineStep):
         return needs

     def production(self) -> Set[str]:
-        return {"…
+        return {"character_network"}

     def optional_needs(self) -> Set[str]:
         return {"sentences_polarities"}
@@ -475,20 +475,17 @@ class ConversationalGraphExtractor(PipelineStep):
         characters: Set[Character],
         **kwargs,
     ) -> Dict[str, Any]:
-
         G = nx.Graph()
         for character in characters:
             G.add_node(character)

         for i, (quote_1, speaker_1) in enumerate(zip(quotes, speakers)):
-
             # no speaker prediction: ignore
             if speaker_1 is None:
                 continue

             # check ahead for co-occurences
             for quote_2, speaker_2 in zip(quotes[i + 1 :], speakers[i + 1 :]):
-
                 # no speaker prediction: ignore
                 if speaker_2 is None:
                     continue
@@ -507,12 +504,12 @@ class ConversationalGraphExtractor(PipelineStep):
                 G.add_edge(speaker_1, speaker_2, weight=0)
             G.edges[speaker_1, speaker_2]["weight"] += 1

-        return {"…
+        return {"character_network": G}

     def needs(self) -> Set[str]:
         """sentences, quotes, speakers, characters"""
         return {"sentences", "quotes", "speakers", "characters"}

     def production(self) -> Set[str]:
-        """…
-        return {"…
+        """character_network"""
+        return {"character_network"}
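Correspondingly, both graph extractors now advertise and fill the `character_network` output key instead of the old one. A quick check, using the constructor argument shown in the README tutorial below:

```python
from renard.pipeline.graph_extraction import CoOccurrencesGraphExtractor

extractor = CoOccurrencesGraphExtractor(co_occurrences_dist=25)
assert extractor.production() == {"character_network"}
```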
renard/pipeline/ner.py
CHANGED
@@ -1,4 +1,5 @@
 from __future__ import annotations
+import random, itertools
 from typing import TYPE_CHECKING, List, Dict, Any, Set, Tuple, Optional, Union, Literal
 from dataclasses import dataclass
 import torch
@@ -10,6 +11,7 @@ from renard.ner_utils import (
 )
 from renard.pipeline.core import PipelineStep, Mention
 from renard.pipeline.progress import ProgressReporter
+from renard.ner_utils import ner_entities

 if TYPE_CHECKING:
     from transformers.tokenization_utils_base import BatchEncoding
@@ -32,60 +34,8 @@ class NEREntity(Mention):
         """
         return super().shifted(shift)  # type: ignore

-
-def ner_entities(
-    tokens: List[str], bio_tags: List[str], resolve_inconsistencies: bool = True
-) -> List[NEREntity]:
-    """Extract NER entities from a list of BIO tags
-
-    :param tokens: a list of tokens
-    :param bio_tags: a list of BIO tags. In particular, BIO tags
-        should be in the CoNLL-2002 form (such as 'B-PER I-PER')
-
-    :return: A list of ner entities, in apparition order
-    """
-    assert len(tokens) == len(bio_tags)
-
-    entities = []
-    current_tag: Optional[str] = None
-    current_tag_start_idx: Optional[int] = None
-
-    for i, tag in enumerate(bio_tags):
-        if not current_tag is None and not tag.startswith("I-"):
-            assert not current_tag_start_idx is None
-            entities.append(
-                NEREntity(
-                    tokens[current_tag_start_idx:i],
-                    current_tag_start_idx,
-                    i,
-                    current_tag,
-                )
-            )
-            current_tag = None
-            current_tag_start_idx = None
-
-        if tag.startswith("B-"):
-            current_tag = tag[2:]
-            current_tag_start_idx = i
-
-        elif tag.startswith("I-"):
-            if current_tag is None and resolve_inconsistencies:
-                current_tag = tag[2:]
-                current_tag_start_idx = i
-                continue
-
-    if not current_tag is None:
-        assert not current_tag_start_idx is None
-        entities.append(
-            NEREntity(
-                tokens[current_tag_start_idx : len(tokens)],
-                current_tag_start_idx,
-                len(bio_tags),
-                current_tag,
-            )
-        )
-
-    return entities
+    def __hash__(self) -> int:
+        return hash(tuple(self.tokens) + (self.start_idx, self.end_idx, self.tag))


 def score_ner(
@@ -151,11 +101,69 @@ class NLTKNamedEntityRecognizer(PipelineStep):
         return {"entities"}


+class NERContextRetriever:
+    def __call__(self, dataset: NERDataset) -> NERDataset:
+        raise NotImplementedError
+
+
+class NERSamenounContextRetriever(NERContextRetriever):
+    """
+    Retrieve relevant context using the samenoun strategy as in
+    Amalvy et al. 2023.
+    """
+
+    def __init__(self, k: int) -> None:
+        """
+        :param k: the number of sentences to retrieve
+        """
+        self.k = k
+
+    def __call__(self, dataset: NERDataset) -> NERDataset:
+        import nltk
+
+        # NOTE: POS tagging is not incorporated in the pipeline yet,
+        # so we manually compute it here.
+        elements_names = [
+            {t[0] for t in nltk.pos_tag(element) if t[1].startswith("NN")}
+            for element in dataset.elements
+        ]
+
+        elements_with_context = []
+
+        for elt_i, elt in enumerate(dataset.elements):
+            retrieved_elts = [
+                other_elt
+                for other_elt_i, other_elt in enumerate(dataset.elements)
+                if not other_elt_i == elt_i
+                and len(elements_names[elt_i].intersection(elements_names[other_elt_i]))
+                > 0
+            ]
+            retrieved_elts = random.sample(
+                retrieved_elts, k=min(self.k, len(retrieved_elts))
+            )
+            elements_with_context.append(
+                (
+                    elt,
+                    [dataset.tokenizer.sep_token]
+                    + list(itertools.chain.from_iterable(retrieved_elts)),
+                )
+            )
+
+        return NERDataset(
+            [element + context for element, context in elements_with_context],
+            dataset.tokenizer,
+            [
+                [0] * len(element) + [1] * len(context)
+                for element, context in elements_with_context
+            ],
+        )
+
+
 class BertNamedEntityRecognizer(PipelineStep):
     """An entity recognizer based on BERT"""

     LANG_TO_MODELS = {
-        "fra": "…
+        "fra": "compnet-renard/camembert-base-literary-NER",
         "eng": "compnet-renard/bert-base-cased-literary-NER",
     }

@@ -165,6 +173,7 @@ class BertNamedEntityRecognizer(PipelineStep):
         batch_size: int = 4,
         device: Literal["cpu", "cuda", "auto"] = "auto",
         tokenizer: Optional[PreTrainedTokenizerFast] = None,
+        context_retriever: Optional[NERContextRetriever] = None,
     ):
         """
         :param model: Either:
@@ -181,6 +190,9 @@ class BertNamedEntityRecognizer(PipelineStep):
         :param batch_size: batch size at inference
         :param device: computation device
         :param tokenizer: a custom tokenizer
+        :param context_retriever: if specified, use
+            ``context_retriever`` to retrieve relevant global context
+            at run time, generally trading runtime for NER performance.
         """
         if isinstance(model, str):
             self.huggingface_model_id = model
@@ -198,6 +210,8 @@ class BertNamedEntityRecognizer(PipelineStep):
         else:
             self.device = torch.device(device)

+        self.context_retriever = context_retriever
+
         super().__init__()

     def _pipeline_init_(self, lang: str, progress_reporter: ProgressReporter):
@@ -208,7 +222,6 @@ class BertNamedEntityRecognizer(PipelineStep):
         # init model if needed (this happens if the user did not pass
         # the instance of a model)
         if self.model is None:
-
             # the user supplied a huggingface ID: load model from the HUB
             if not self.huggingface_model_id is None:
                 self.model = AutoModelForTokenClassification.from_pretrained(
@@ -251,6 +264,10 @@ class BertNamedEntityRecognizer(PipelineStep):
         self.model = self.model.to(self.device)

         dataset = NERDataset(sentences, self.tokenizer)
+
+        if not self.context_retriever is None:
+            dataset = self.context_retriever(dataset)
+
         dataloader = DataLoader(
             dataset,
             batch_size=self.batch_size,
@@ -262,7 +279,6 @@ class BertNamedEntityRecognizer(PipelineStep):
         labels = []

         with torch.no_grad():
-
             for batch_i, batch in enumerate(self._progress_(dataloader)):
                 out = self.model(
                     batch["input_ids"].to(self.device),
@@ -277,7 +293,9 @@ class BertNamedEntityRecognizer(PipelineStep):
                     for tens in batch_classes_tens[i]
                 ]
                 sent_tokens = sentences[self.batch_size * batch_i + i]
-                sent_labels = self.batch_labels(…
+                sent_labels = self.batch_labels(
+                    batch, i, wp_labels, sent_tokens, batch["context_mask"]
+                )
                 labels += sent_labels

         return {"entities": ner_entities(tokens, labels)}
@@ -288,6 +306,7 @@ class BertNamedEntityRecognizer(PipelineStep):
         batch_i: int,
         wp_labels: List[str],
         tokens: List[str],
+        context_mask: torch.Tensor,
     ) -> List[str]:
         """Align labels to tokens rather than wordpiece tokens.

@@ -299,6 +318,8 @@ class BertNamedEntityRecognizer(PipelineStep):
         batch_labels = ["O"] * len(tokens)

         for wplabel_j, wp_label in enumerate(wp_labels):
+            if context_mask[batch_i][wplabel_j] == 1:
+                continue
             token_i = batchs.token_to_word(batch_i, wplabel_j)
             if token_i is None:
                 continue
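A hedged usage sketch of the new retrieval-augmented NER step, combining the classes added above (the model ID is the English one listed in `LANG_TO_MODELS`):

```python
from renard.pipeline.ner import (
    BertNamedEntityRecognizer,
    NERSamenounContextRetriever,
)

ner_step = BertNamedEntityRecognizer(
    model="compnet-renard/bert-base-cased-literary-NER",
    batch_size=4,
    # retrieve up to 2 context sentences per input with the samenoun strategy;
    # retrieved tokens are excluded from predictions via the context mask
    context_retriever=NERSamenounContextRetriever(k=2),
)
```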
renard/py.typed
ADDED
File without changes
renard/utils.py
CHANGED
@@ -1,12 +1,7 @@
-from typing import List, Tuple, TypeVar, Collection, Iterable, …
-import re, os
-from more_itertools import flatten
+from typing import List, Tuple, TypeVar, Collection, Iterable, cast
 from more_itertools.more import windowed
 import torch

-from renard.pipeline.ner import NEREntity, ner_entities
-
-
 T = TypeVar("T")


@@ -81,49 +76,3 @@ def search_pattern(seq: Iterable[R], pattern: List[R]) -> List[int]:
         if list(subseq) == pattern:
             start_indices.append(subseq_i)
     return start_indices
-
-
-def load_conll2002_bio(
-    path: str,
-    tag_conversion_map: Optional[Dict[str, str]] = None,
-    separator: str = "\t",
-    **kwargs
-) -> Tuple[List[List[str]], List[str], List[NEREntity]]:
-    """Load a file under CoNLL2022 BIO format. Sentences are expected
-    to be separated by end of lines. Tags should be in the CoNLL-2002
-    format (such as 'B-PER I-PER') - If this is not the case, see the
-    ``tag_conversion_map`` argument.
-
-    :param path: path to the CoNLL-2002 formatted file
-    :param separator: separator between token and BIO tags
-    :param tag_conversion_map: conversion map for tags found in the
-        input file. Example : ``{'B': 'B-PER', 'I': 'I-PER'}``
-    :param kwargs: additional kwargs for ``open`` (such as
-        ``encoding`` or ``newline``).
-
-    :return: ``(sentences, tokens, entities)``
-    """
-
-    if tag_conversion_map is None:
-        tag_conversion_map = {}
-
-    with open(os.path.expanduser(path), **kwargs) as f:
-        raw_data = f.read()
-
-    sents = []
-    sent_tokens = []
-    tags = []
-    for line in raw_data.split("\n"):
-        line = line.strip("\n")
-        if re.fullmatch(r"\s*", line):
-            sents.append(sent_tokens)
-            sent_tokens = []
-            continue
-        token, tag = line.split(separator)
-        sent_tokens.append(token)
-        tags.append(tag_conversion_map.get(tag, tag))
-
-    tokens = list(flatten(sents))
-    entities = ner_entities(tokens, tags)
-
-    return sents, list(flatten(sents)), entities
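Net effect for users: `load_conll2002_bio` (and the `NEREntity` helpers it depends on) moved out of `renard.utils`; as of 0.4.1 it should be imported from `renard.ner_utils`:

```python
# old (0.3.x): from renard.utils import load_conll2002_bio
from renard.ner_utils import load_conll2002_bio  # new location in 0.4.1
```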
{renard_pipeline-0.3.1.dist-info → renard_pipeline-0.4.1.dist-info}/METADATA
CHANGED

@@ -1,7 +1,8 @@
 Metadata-Version: 2.1
 Name: renard-pipeline
-Version: 0.3.1
+Version: 0.4.1
 Summary: Relationships Extraction from NARrative Documents
+Home-page: https://github.com/CompNet/Renard
 License: GPL-3.0-only
 Author: Arthur Amalvy
 Author-email: arthur.amalvy@univ-avignon.fr
@@ -14,6 +15,7 @@ Classifier: Programming Language :: Python :: 3.10
 Provides-Extra: spacy
 Provides-Extra: stanza
 Requires-Dist: coreferee (>=1.4.0,<2.0.0) ; extra == "spacy"
+Requires-Dist: datasets (>=2.16.1,<3.0.0)
 Requires-Dist: grimbert (>=0.1.0,<0.2.0)
 Requires-Dist: matplotlib (>=3.5.3,<4.0.0)
 Requires-Dist: more-itertools (>=10.1.0,<11.0.0)
@@ -26,15 +28,19 @@ Requires-Dist: seqeval (==1.2.2)
 Requires-Dist: spacy (>=3.5.0,<4.0.0) ; extra == "spacy"
 Requires-Dist: spacy-transformers (>=1.2.1,<2.0.0) ; extra == "spacy"
 Requires-Dist: stanza (>=1.3.0,<2.0.0) ; extra == "stanza"
-Requires-Dist: tibert (>=0.…
+Requires-Dist: tibert (>=0.3.0,<0.4.0)
 Requires-Dist: torch (>=2.0.0,!=2.0.1)
 Requires-Dist: tqdm (>=4.62.3,<5.0.0)
 Requires-Dist: transformers (>=4.36.0,<5.0.0)
+Project-URL: Documentation, https://compnet.github.io/Renard/
+Project-URL: Repository, https://github.com/CompNet/Renard
 Description-Content-Type: text/markdown

 # Renard

-Relationships Extraction from NARrative Documents
+Renard (Relationships Extraction from NARrative Documents) is a library for creating and using custom character network extraction pipelines. Renard can extract dynamic as well as static character networks.
+
+…


 # Installation
@@ -43,6 +49,8 @@ You can install the latest version using pip:

 > pip install renard-pipeline

+Currently, Renard supports Python 3.8, 3.9 and 3.10.
+

 # Documentation

@@ -53,7 +61,32 @@ If you need local documentation, it can be generated using `Sphinx`. From the `d

 # Tutorial

-…
+Renard's central concept is the `Pipeline`. A `Pipeline` is a list of `PipelineStep` that are run sequentially in order to extract a character graph from a document. Here is a simple example:
+
+```python
+from renard.pipeline import Pipeline
+from renard.pipeline.tokenization import NLTKTokenizer
+from renard.pipeline.ner import NLTKNamedEntityRecognizer
+from renard.pipeline.character_unification import GraphRulesCharacterUnifier
+from renard.pipeline.graph_extraction import CoOccurrencesGraphExtractor
+
+with open("./my_doc.txt") as f:
+    text = f.read()
+
+pipeline = Pipeline(
+    [
+        NLTKTokenizer(),
+        NLTKNamedEntityRecognizer(),
+        GraphRulesCharacterUnifier(min_appearance=10),
+        CoOccurrencesGraphExtractor(co_occurrences_dist=25)
+    ]
+)
+
+out = pipeline(text)
+```
+
+For more information, see `renard_tutorial.py`, which is a tutorial in the `jupytext` format. You can open it as a notebook in Jupyter Notebook (or export it as a notebook with `jupytext --to ipynb renard-tutorial.py`).
+


 # Running tests
@@ -64,3 +97,8 @@ If you need local documentation, it can be generated using `Sphinx`. From the `d

 Expensive tests are disabled by default. These can be run by setting the environment variable `RENARD_TEST_ALL` to `1`.

+
+# Contributing
+
+see [the "Contributing" section of the documentation](https://compnet.github.io/Renard/contributing.html).
+
{renard_pipeline-0.3.1.dist-info → renard_pipeline-0.4.1.dist-info}/RECORD
CHANGED

@@ -1,24 +1,25 @@
 renard/gender.py,sha256=HDtJQKOqIkV8F-Mxva95XFXWJoKRKckQ3fc93OBM6sw,102
 renard/graph_utils.py,sha256=5jwky9JgJ-WMVHfeaiXkAAQwEfhR2BFSrWhck1Qmpgo,5812
-renard/ner_utils.py,sha256=…
+renard/ner_utils.py,sha256=jN1AQkaV0Kx-Bc0oc3SYBEmSUuKPBbzXqByOlaqH62k,11263
 renard/nltk_utils.py,sha256=mUJiwMrEDZV4Fla7WuMR-hA_OC2ZIwSXgW_0Ew18VSo,977
 renard/pipeline/__init__.py,sha256=8Yim2mmny8YGvM7N5-na5zK-C9UDxUb77K9ml-VirUA,35
-renard/pipeline/character_unification.py,sha256=…
-renard/pipeline/characters_extraction.py,sha256=…
-renard/pipeline/core.py,sha256=…
+renard/pipeline/character_unification.py,sha256=GcnC8UYqn1RBOGVhYS9LVcTNqpxm9YoT-lPsE3vodek,14818
+renard/pipeline/characters_extraction.py,sha256=NzF8H9X19diW6rqwS5ERrRku7rFueO3S077H5C6kb7I,363
+renard/pipeline/core.py,sha256=luKNUTCDtZfwKzxVIaImyIMwFFvIknfT1LdQtongj24,22570
 renard/pipeline/corefs/__init__.py,sha256=9c9AaXBcRrDBf1jhTtJ7DyjOJhX_Zej3FjlcGak7MK8,44
 renard/pipeline/corefs/corefs.py,sha256=nzYT6S9ify3FlgGB3FSDpAhs2UQYgW9c3CL2GRYzTms,11508
-renard/pipeline/graph_extraction.py,sha256=…
-renard/pipeline/ner.py,sha256=…
+renard/pipeline/graph_extraction.py,sha256=n0T_nzNGiwE9bDubpPknHe7bbDhJ4ndnqmoMmyfbeWg,19468
+renard/pipeline/ner.py,sha256=5zqZlEjhO__0iuRQAN9rvhCbcd9QmNCcH9_NP_BaTbc,11261
 renard/pipeline/preconfigured.py,sha256=j4-0OUZrmtC8rQfwGWEAAGNxc8-4hlY7N823Uami5lk,5392
 renard/pipeline/preprocessing.py,sha256=OsdsYzmRweAiQV_CtP7uiz--OGogZtQlsdR8XX5DCk0,952
 renard/pipeline/progress.py,sha256=VQsIxTuz0QQnepXPevHhMU-dHXMa1RWsjmMfBgoWdiY,1684
 renard/pipeline/quote_detection.py,sha256=FyldJhynIT843fB7rwVtHmDZJqTKkjGml6qTLjsIhMA,2045
 renard/pipeline/sentiment_analysis.py,sha256=76MPin4L1-vSswJe5yGrbCSSDim1LYxSEgNj_BdQDvk,1464
 renard/pipeline/speaker_attribution.py,sha256=qCY-Z1haDDgZy8L4k8pAc6xIcSFmtcuuESu631QxRUY,4366
-renard/pipeline/stanford_corenlp.py,sha256=…
+renard/pipeline/stanford_corenlp.py,sha256=14b6Ee6oPz1EL-bNRT688aNxVTk_Jwa_vJ20FiBODC4,8189
 renard/pipeline/tokenization.py,sha256=RllOxSjaV_Sdu3CH8vKIbceNj3Noeey31mKircxWoyM,1806
 renard/plot_utils.py,sha256=bmIBybleFJ-YiVPLPPWYW8x1UHpkuXTE7O9lQlRiWrk,2133
+renard/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 renard/resources/hypocorisms/__init__.py,sha256=vlsY9PqxQCIpijxm79Y0KYh2c0S4S1pgrC9w-AUQGvE,55
 renard/resources/hypocorisms/datas/License.txt,sha256=tAkwu8-AdEyGxGoSvJ2gVmQdcicWw3j1ZZueVV74M-E,11357
 renard/resources/hypocorisms/datas/hypocorisms.csv,sha256=CKTo7A5i14NzN6JRBz7U2NJnxrEo8VOlmmdhzEZnqlI,21470
@@ -27,8 +28,8 @@ renard/resources/pronouns/__init__.py,sha256=62h0zuXp8kCToTLTyg8D8rJ-MXQpT8Vyc6m
 renard/resources/pronouns/pronouns.py,sha256=YJ8hM6H8QHrF2Xx6O5blqc-Sqe1D1YFL0sRdqO_rroE,817
 renard/resources/titles/__init__.py,sha256=Jcg4B7stsWiAaXbFgNl_L3ICtCQmFe9bo3YjdkVL50w,45
 renard/resources/titles/titles.py,sha256=GsFccVJuTkgDWiAqWZpFd2R9pGvFKQZBOk4RWWuWDkw,968
-renard/utils.py,sha256=…
-renard_pipeline-0.…
-renard_pipeline-0.…
-renard_pipeline-0.…
-renard_pipeline-0.…
+renard/utils.py,sha256=8J3swFqSi4YqhgYNXvttJ0s-DmJbl_yEYri6JpGEWH8,2340
+renard_pipeline-0.4.1.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
+renard_pipeline-0.4.1.dist-info/METADATA,sha256=KgpnPAR6BtLS4RNjsxIBWqUygUcoRdJfkqHigzZMSqU,3697
+renard_pipeline-0.4.1.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
+renard_pipeline-0.4.1.dist-info/RECORD,,
{renard_pipeline-0.3.1.dist-info → renard_pipeline-0.4.1.dist-info}/LICENSE
File without changes

{renard_pipeline-0.3.1.dist-info → renard_pipeline-0.4.1.dist-info}/WHEEL
File without changes