renard-pipeline 0.4.2__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of renard-pipeline might be problematic. Click here for more details.
- renard/graph_utils.py +11 -4
- renard/ner_utils.py +24 -14
- renard/pipeline/character_unification.py +62 -19
- renard/pipeline/characters_extraction.py +3 -1
- renard/pipeline/core.py +141 -26
- renard/pipeline/corefs/corefs.py +32 -33
- renard/pipeline/graph_extraction.py +281 -192
- renard/pipeline/ner/__init__.py +1 -0
- renard/pipeline/{ner.py → ner/ner.py} +47 -76
- renard/pipeline/ner/retrieval.py +375 -0
- renard/pipeline/progress.py +32 -1
- renard/pipeline/speaker_attribution.py +2 -3
- renard/pipeline/tokenization.py +59 -30
- renard/plot_utils.py +48 -28
- renard/resources/determiners/__init__.py +1 -0
- renard/resources/determiners/determiners.py +41 -0
- renard/resources/hypocorisms/hypocorisms.py +3 -2
- renard/utils.py +57 -1
- {renard_pipeline-0.4.2.dist-info → renard_pipeline-0.6.0.dist-info}/METADATA +45 -20
- renard_pipeline-0.6.0.dist-info/RECORD +39 -0
- renard_pipeline-0.4.2.dist-info/RECORD +0 -35
- {renard_pipeline-0.4.2.dist-info → renard_pipeline-0.6.0.dist-info}/LICENSE +0 -0
- {renard_pipeline-0.4.2.dist-info → renard_pipeline-0.6.0.dist-info}/WHEEL +0 -0
renard/graph_utils.py
CHANGED
|
@@ -70,10 +70,17 @@ def graph_with_names(
|
|
|
70
70
|
else:
|
|
71
71
|
name_style_fn = name_style
|
|
72
72
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
73
|
+
mapping = {}
|
|
74
|
+
for character in G.nodes():
|
|
75
|
+
# NOTE: it is *possible* to have a graph where nodes are not
|
|
76
|
+
# characters (for example, simple strings). Therefore, we are
|
|
77
|
+
# lenient here
|
|
78
|
+
try:
|
|
79
|
+
mapping[character] = name_style_fn(character)
|
|
80
|
+
except AttributeError:
|
|
81
|
+
mapping[character] = character
|
|
82
|
+
|
|
83
|
+
return nx.relabel_nodes(G, mapping)
|
|
77
84
|
|
|
78
85
|
|
|
79
86
|
def layout_with_names(
|
renard/ner_utils.py
CHANGED
|
@@ -74,7 +74,7 @@ class DataCollatorForTokenClassificationWithBatchEncoding:
|
|
|
74
74
|
class NERDataset(Dataset):
|
|
75
75
|
"""
|
|
76
76
|
:ivar _context_mask: for each element, a mask indicating which
|
|
77
|
-
tokens are part of the context (
|
|
77
|
+
tokens are part of the context (0 for context, 1 for text on
|
|
78
78
|
which to perform inference). The mask allows to discard
|
|
79
79
|
predictions made for context at inference time, even though
|
|
80
80
|
the context can still be passed as input to the model.
|
|
@@ -92,11 +92,11 @@ class NERDataset(Dataset):
|
|
|
92
92
|
assert all(
|
|
93
93
|
[len(cm) == len(elt) for elt, cm in zip(self.elements, context_mask)]
|
|
94
94
|
)
|
|
95
|
-
self._context_mask = context_mask or [[
|
|
95
|
+
self._context_mask = context_mask or [[1] * len(elt) for elt in self.elements]
|
|
96
96
|
|
|
97
97
|
self.tokenizer = tokenizer
|
|
98
98
|
|
|
99
|
-
def __getitem__(self, index:
|
|
99
|
+
def __getitem__(self, index: int) -> BatchEncoding:
|
|
100
100
|
element = self.elements[index]
|
|
101
101
|
|
|
102
102
|
batch = self.tokenizer(
|
|
@@ -104,15 +104,18 @@ class NERDataset(Dataset):
|
|
|
104
104
|
truncation=True,
|
|
105
105
|
max_length=512, # TODO
|
|
106
106
|
is_split_into_words=True,
|
|
107
|
+
return_length=True,
|
|
107
108
|
)
|
|
108
109
|
|
|
109
|
-
batch["
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
batch["context_mask"]
|
|
110
|
+
length = batch["length"][0]
|
|
111
|
+
del batch["length"]
|
|
112
|
+
if self.tokenizer.truncation_side == "right":
|
|
113
|
+
batch["context_mask"] = self._context_mask[index][:length]
|
|
114
|
+
else:
|
|
115
|
+
assert self.tokenizer.truncation_side == "left"
|
|
116
|
+
batch["context_mask"] = self._context_mask[index][
|
|
117
|
+
len(batch["input_ids"]) - length :
|
|
118
|
+
]
|
|
116
119
|
|
|
117
120
|
return batch
|
|
118
121
|
|
|
@@ -181,6 +184,7 @@ def load_conll2002_bio(
|
|
|
181
184
|
path: str,
|
|
182
185
|
tag_conversion_map: Optional[Dict[str, str]] = None,
|
|
183
186
|
separator: str = "\t",
|
|
187
|
+
max_sent_len: Optional[int] = None,
|
|
184
188
|
**kwargs,
|
|
185
189
|
) -> Tuple[List[List[str]], List[str], List[NEREntity]]:
|
|
186
190
|
"""Load a file under CoNLL2022 BIO format. Sentences are expected
|
|
@@ -192,7 +196,9 @@ def load_conll2002_bio(
|
|
|
192
196
|
:param separator: separator between token and BIO tags
|
|
193
197
|
:param tag_conversion_map: conversion map for tags found in the
|
|
194
198
|
input file. Example : ``{'B': 'B-PER', 'I': 'I-PER'}``
|
|
195
|
-
:param
|
|
199
|
+
:param max_sent_len: if specified, maximum length, in tokens, of
|
|
200
|
+
sentences.
|
|
201
|
+
:param kwargs: additional kwargs for :func:`open` (such as
|
|
196
202
|
``encoding`` or ``newline``).
|
|
197
203
|
|
|
198
204
|
:return: ``(sentences, tokens, entities)``
|
|
@@ -207,7 +213,9 @@ def load_conll2002_bio(
|
|
|
207
213
|
tags = []
|
|
208
214
|
for line in raw_data.split("\n"):
|
|
209
215
|
line = line.strip("\n")
|
|
210
|
-
if re.fullmatch(r"\s*", line)
|
|
216
|
+
if re.fullmatch(r"\s*", line) or (
|
|
217
|
+
not max_sent_len is None and len(sent_tokens) >= max_sent_len
|
|
218
|
+
):
|
|
211
219
|
if len(sent_tokens) == 0:
|
|
212
220
|
continue
|
|
213
221
|
sents.append(sent_tokens)
|
|
@@ -227,6 +235,7 @@ def hgdataset_from_conll2002(
|
|
|
227
235
|
path: str,
|
|
228
236
|
tag_conversion_map: Optional[Dict[str, str]] = None,
|
|
229
237
|
separator: str = "\t",
|
|
238
|
+
max_sent_len: Optional[int] = None,
|
|
230
239
|
**kwargs,
|
|
231
240
|
) -> HGDataset:
|
|
232
241
|
"""Load a CoNLL-2002 file as a Huggingface Dataset.
|
|
@@ -234,12 +243,13 @@ def hgdataset_from_conll2002(
|
|
|
234
243
|
:param path: passed to :func:`.load_conll2002_bio`
|
|
235
244
|
:param tag_conversion_map: passed to :func:`load_conll2002_bio`
|
|
236
245
|
:param separator: passed to :func:`load_conll2002_bio`
|
|
237
|
-
:param
|
|
246
|
+
:param max_sent_len: passed to :func:`load_conll2002_bio`
|
|
247
|
+
:param kwargs: additional kwargs for :func:`open`
|
|
238
248
|
|
|
239
249
|
:return: a :class:`datasets.Dataset` with features 'tokens' and 'labels'.
|
|
240
250
|
"""
|
|
241
251
|
sentences, tokens, entities = load_conll2002_bio(
|
|
242
|
-
path, tag_conversion_map, separator, **kwargs
|
|
252
|
+
path, tag_conversion_map, separator, max_sent_len, **kwargs
|
|
243
253
|
)
|
|
244
254
|
|
|
245
255
|
# convert entities to labels
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from typing import Any, Dict, List, FrozenSet, Set, Optional, Tuple, Union, Literal
|
|
2
|
-
import
|
|
2
|
+
import re, sys
|
|
3
3
|
from itertools import combinations
|
|
4
4
|
from collections import defaultdict, Counter
|
|
5
5
|
from dataclasses import dataclass
|
|
@@ -11,6 +11,7 @@ from renard.pipeline.ner import NEREntity
|
|
|
11
11
|
from renard.pipeline.progress import ProgressReporter
|
|
12
12
|
from renard.resources.hypocorisms import HypocorismGazetteer
|
|
13
13
|
from renard.resources.pronouns import is_a_female_pronoun, is_a_male_pronoun
|
|
14
|
+
from renard.resources.determiners import singular_determiners
|
|
14
15
|
from renard.resources.titles import is_a_male_title, is_a_female_title, all_titles
|
|
15
16
|
|
|
16
17
|
|
|
@@ -61,6 +62,8 @@ def _assign_coreference_mentions(
|
|
|
61
62
|
# we assign each chain to the character with highest name
|
|
62
63
|
# occurrence in it
|
|
63
64
|
for chain in corefs:
|
|
65
|
+
if len(char_mentions) == 0:
|
|
66
|
+
break
|
|
64
67
|
# determine the characters with the highest number of
|
|
65
68
|
# occurrences
|
|
66
69
|
occ_counter = {}
|
|
@@ -98,8 +101,13 @@ class NaiveCharacterUnifier(PipelineStep):
|
|
|
98
101
|
character for it to be valid
|
|
99
102
|
"""
|
|
100
103
|
self.min_appearances = min_appearances
|
|
104
|
+
# a default value, will be set by _pipeline_init_
|
|
105
|
+
self.character_ner_tag = "PER"
|
|
101
106
|
super().__init__()
|
|
102
107
|
|
|
108
|
+
def _pipeline_init_(self, lang: str, character_ner_tag: str, **kwargs):
|
|
109
|
+
self.character_ner_tag = character_ner_tag
|
|
110
|
+
|
|
103
111
|
def __call__(
|
|
104
112
|
self,
|
|
105
113
|
text: str,
|
|
@@ -112,7 +120,7 @@ class NaiveCharacterUnifier(PipelineStep):
|
|
|
112
120
|
:param tokens:
|
|
113
121
|
:param entities:
|
|
114
122
|
"""
|
|
115
|
-
persons = [e for e in entities if e.tag ==
|
|
123
|
+
persons = [e for e in entities if e.tag == self.character_ner_tag]
|
|
116
124
|
|
|
117
125
|
characters = defaultdict(list)
|
|
118
126
|
for entity in persons:
|
|
@@ -160,6 +168,7 @@ class GraphRulesCharacterUnifier(PipelineStep):
|
|
|
160
168
|
additional_hypocorisms: Optional[List[Tuple[str, List[str]]]] = None,
|
|
161
169
|
link_corefs_mentions: bool = False,
|
|
162
170
|
ignore_lone_titles: Optional[Set[str]] = None,
|
|
171
|
+
ignore_leading_determiner: bool = False,
|
|
163
172
|
) -> None:
|
|
164
173
|
"""
|
|
165
174
|
:param min_appearances: minimum number of appearances of a
|
|
@@ -174,24 +183,32 @@ class GraphRulesCharacterUnifier(PipelineStep):
|
|
|
174
183
|
extract a lot of spurious links. However, linking by
|
|
175
184
|
coref is sometimes the only way to resolve a character
|
|
176
185
|
alias.
|
|
177
|
-
:param ignore_lone_titles: a set of titles to ignore when
|
|
178
|
-
|
|
186
|
+
:param ignore_lone_titles: a set of titles to ignore when they
|
|
187
|
+
stand on their own. This avoids extracting false
|
|
179
188
|
positive characters such as 'Mr.' or 'Miss'.
|
|
189
|
+
:param ignore_leading_determiner: if ``True``, will ignore the
|
|
190
|
+
leading determiner when applying unification rules. This
|
|
191
|
+
is useful if the NER model used in the pipeline adds
|
|
192
|
+
leading determiners as part of entities.
|
|
180
193
|
"""
|
|
181
194
|
self.min_appearances = min_appearances
|
|
182
195
|
self.additional_hypocorisms = additional_hypocorisms
|
|
183
196
|
self.link_corefs_mentions = link_corefs_mentions
|
|
184
197
|
self.ignore_lone_titles = ignore_lone_titles or set()
|
|
198
|
+
self.character_ner_tag = "PER" # a default value, will be set by _pipeline_init
|
|
199
|
+
self.ignore_leading_determiner = ignore_leading_determiner
|
|
185
200
|
|
|
186
201
|
super().__init__()
|
|
187
202
|
|
|
188
|
-
def _pipeline_init_(self, lang: str,
|
|
203
|
+
def _pipeline_init_(self, lang: str, character_ner_tag: str, **kwargs):
|
|
189
204
|
self.hypocorism_gazetteer = HypocorismGazetteer(lang=lang)
|
|
190
205
|
if not self.additional_hypocorisms is None:
|
|
191
206
|
for name, nicknames in self.additional_hypocorisms:
|
|
192
207
|
self.hypocorism_gazetteer._add_hypocorism_(name, nicknames)
|
|
193
208
|
|
|
194
|
-
|
|
209
|
+
self.character_ner_tag = character_ner_tag
|
|
210
|
+
|
|
211
|
+
return super()._pipeline_init_(lang, **kwargs)
|
|
195
212
|
|
|
196
213
|
def __call__(
|
|
197
214
|
self,
|
|
@@ -201,7 +218,7 @@ class GraphRulesCharacterUnifier(PipelineStep):
|
|
|
201
218
|
) -> Dict[str, Any]:
|
|
202
219
|
import networkx as nx
|
|
203
220
|
|
|
204
|
-
mentions = [m for m in entities if m.tag ==
|
|
221
|
+
mentions = [m for m in entities if m.tag == self.character_ner_tag]
|
|
205
222
|
mentions_str = set(
|
|
206
223
|
filter(
|
|
207
224
|
lambda m: not m in self.ignore_lone_titles,
|
|
@@ -219,23 +236,28 @@ class GraphRulesCharacterUnifier(PipelineStep):
|
|
|
219
236
|
|
|
220
237
|
# * link nodes based on several rules
|
|
221
238
|
for name1, name2 in combinations(G.nodes(), 2):
|
|
239
|
+
|
|
240
|
+
# preprocess name when needed
|
|
241
|
+
pname1 = self._preprocess_name(name1)
|
|
242
|
+
pname2 = self._preprocess_name(name2)
|
|
243
|
+
|
|
222
244
|
# is one name a known hypocorism of the other ? (also
|
|
223
245
|
# checks if both names are the same)
|
|
224
|
-
if self.hypocorism_gazetteer.are_related(
|
|
246
|
+
if self.hypocorism_gazetteer.are_related(pname1, pname2):
|
|
225
247
|
G.add_edge(name1, name2)
|
|
226
248
|
continue
|
|
227
249
|
|
|
228
250
|
# if we remove the title, is one name related to the other
|
|
229
251
|
# ?
|
|
230
252
|
if self.names_are_related_after_title_removal(
|
|
231
|
-
|
|
253
|
+
pname1, pname2, hname_constants
|
|
232
254
|
):
|
|
233
255
|
G.add_edge(name1, name2)
|
|
234
256
|
continue
|
|
235
257
|
|
|
236
258
|
# add an edge if two characters have the same family names
|
|
237
|
-
human_name1 = HumanName(
|
|
238
|
-
human_name2 = HumanName(
|
|
259
|
+
human_name1 = HumanName(pname1, constants=hname_constants)
|
|
260
|
+
human_name2 = HumanName(pname2, constants=hname_constants)
|
|
239
261
|
if (
|
|
240
262
|
len(human_name1.last) > 0
|
|
241
263
|
and human_name1.last.lower() == human_name2.last.lower()
|
|
@@ -272,10 +294,15 @@ class GraphRulesCharacterUnifier(PipelineStep):
|
|
|
272
294
|
pass
|
|
273
295
|
|
|
274
296
|
for name1, name2 in combinations(G.nodes(), 2):
|
|
297
|
+
|
|
298
|
+
# preprocess names when needed
|
|
299
|
+
pname1 = self._preprocess_name(name1)
|
|
300
|
+
pname2 = self._preprocess_name(name2)
|
|
301
|
+
|
|
275
302
|
# check if characters have the same last name but a
|
|
276
303
|
# different first name.
|
|
277
|
-
human_name1 = HumanName(
|
|
278
|
-
human_name2 = HumanName(
|
|
304
|
+
human_name1 = HumanName(pname1, constants=hname_constants)
|
|
305
|
+
human_name2 = HumanName(pname2, constants=hname_constants)
|
|
279
306
|
if (
|
|
280
307
|
len(human_name1.last) > 0
|
|
281
308
|
and len(human_name2.last) > 0
|
|
@@ -327,6 +354,17 @@ class GraphRulesCharacterUnifier(PipelineStep):
|
|
|
327
354
|
|
|
328
355
|
return {"characters": characters}
|
|
329
356
|
|
|
357
|
+
def _preprocess_name(self, name) -> str:
|
|
358
|
+
if self.ignore_leading_determiner:
|
|
359
|
+
if not self.lang in singular_determiners:
|
|
360
|
+
print(
|
|
361
|
+
f"[warning] can't ignore leading determiners for {self.lang}",
|
|
362
|
+
file=sys.stderr,
|
|
363
|
+
)
|
|
364
|
+
for determiner in singular_determiners.get(self.lang, []):
|
|
365
|
+
name = re.sub(f"^{determiner} ", " ", name, flags=re.I)
|
|
366
|
+
return name
|
|
367
|
+
|
|
330
368
|
def _make_hname_constants(self) -> Constants:
|
|
331
369
|
if self.lang == "eng":
|
|
332
370
|
return Constants()
|
|
@@ -355,13 +393,18 @@ class GraphRulesCharacterUnifier(PipelineStep):
|
|
|
355
393
|
or self.hypocorism_gazetteer.are_related(raw_name1, raw_name2)
|
|
356
394
|
)
|
|
357
395
|
|
|
358
|
-
def names_are_in_coref(
|
|
396
|
+
def names_are_in_coref(
|
|
397
|
+
self, name1: str, name2: str, corefs: List[List[Mention]]
|
|
398
|
+
) -> bool:
|
|
399
|
+
once_together = False
|
|
359
400
|
for coref_chain in corefs:
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
):
|
|
363
|
-
return
|
|
364
|
-
|
|
401
|
+
name1_in = any([name1 == " ".join(m.tokens) for m in coref_chain])
|
|
402
|
+
name2_in = any([name2 == " ".join(m.tokens) for m in coref_chain])
|
|
403
|
+
if name1_in == (not name2_in):
|
|
404
|
+
return False
|
|
405
|
+
elif name1_in and name2_in:
|
|
406
|
+
once_together = True
|
|
407
|
+
return once_together
|
|
365
408
|
|
|
366
409
|
def infer_name_gender(
|
|
367
410
|
self,
|
|
@@ -1,7 +1,9 @@
|
|
|
1
|
+
import sys
|
|
1
2
|
import renard.pipeline.character_unification as cu
|
|
2
3
|
|
|
3
4
|
print(
|
|
4
|
-
"[warning] the characters_extraction module is deprecated. Use character_unification instead."
|
|
5
|
+
"[warning] the characters_extraction module is deprecated. Use character_unification instead.",
|
|
6
|
+
file=sys.stderr,
|
|
5
7
|
)
|
|
6
8
|
|
|
7
9
|
Character = cu.Character
|
renard/pipeline/core.py
CHANGED
|
@@ -79,11 +79,18 @@ class PipelineStep:
|
|
|
79
79
|
"""Initialize the :class:`PipelineStep` with a given configuration."""
|
|
80
80
|
pass
|
|
81
81
|
|
|
82
|
-
def _pipeline_init_(
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
82
|
+
def _pipeline_init_(
|
|
83
|
+
self, lang: str, progress_reporter: ProgressReporter, **kwargs
|
|
84
|
+
) -> Optional[Dict[Pipeline.PipelineParameter, Any]]:
|
|
85
|
+
"""Set the step configuration that is common to the whole
|
|
86
|
+
pipeline.
|
|
87
|
+
|
|
88
|
+
:param lang: the lang of the whole pipeline
|
|
89
|
+
:param progress_reporter:
|
|
90
|
+
:param kwargs: additional pipeline parameters.
|
|
91
|
+
|
|
92
|
+
:return: a step can return a dictionary of pipeline params if
|
|
93
|
+
it wishes to modify some of these.
|
|
87
94
|
"""
|
|
88
95
|
supported_langs = self.supported_langs()
|
|
89
96
|
if not supported_langs == "any" and not lang in supported_langs:
|
|
@@ -150,13 +157,14 @@ class PipelineState:
|
|
|
150
157
|
#: input text
|
|
151
158
|
text: Optional[str]
|
|
152
159
|
|
|
153
|
-
#: text split into
|
|
154
|
-
|
|
160
|
+
#: text split into blocks of texts. When dynamic blocks are given,
|
|
161
|
+
#: the final network is dynamic, and split according to blocks.
|
|
162
|
+
dynamic_blocks: Optional[List[Tuple[int, int]]] = None
|
|
155
163
|
|
|
156
164
|
#: text split into tokens
|
|
157
165
|
tokens: Optional[List[str]] = None
|
|
158
|
-
#:
|
|
159
|
-
|
|
166
|
+
#: mapping from a character to its corresponding token
|
|
167
|
+
char2token: Optional[List[int]] = None
|
|
160
168
|
#: text split into sentences, each sentence being a list of
|
|
161
169
|
#: tokens
|
|
162
170
|
sentences: Optional[List[List[str]]] = None
|
|
@@ -182,14 +190,12 @@ class PipelineState:
|
|
|
182
190
|
#: network)
|
|
183
191
|
character_network: Optional[Union[List[nx.Graph], nx.Graph]] = None
|
|
184
192
|
|
|
193
|
+
# aliases of self.character_network
|
|
185
194
|
def get_characters_graph(self) -> Optional[Union[List[nx.Graph], nx.Graph]]:
|
|
186
|
-
print(
|
|
187
|
-
"[warning] the characters_graph attribute is deprecated, use character_network instead",
|
|
188
|
-
file=sys.stderr,
|
|
189
|
-
)
|
|
190
195
|
return self.character_network
|
|
191
196
|
|
|
192
197
|
characters_graph = property(get_characters_graph)
|
|
198
|
+
character_graph = property(get_characters_graph)
|
|
193
199
|
|
|
194
200
|
def get_character(
|
|
195
201
|
self, name: str, partial_match: bool = True
|
|
@@ -280,6 +286,10 @@ class PipelineState:
|
|
|
280
286
|
cumulative: bool = False,
|
|
281
287
|
stable_layout: bool = False,
|
|
282
288
|
layout: Optional[CharactersGraphLayout] = None,
|
|
289
|
+
node_kwargs: Optional[List[Dict[str, Any]]] = None,
|
|
290
|
+
edge_kwargs: Optional[List[Dict[str, Any]]] = None,
|
|
291
|
+
label_kwargs: Optional[List[Dict[str, Any]]] = None,
|
|
292
|
+
legend: bool = False,
|
|
283
293
|
):
|
|
284
294
|
"""Plot ``self.character_graph`` using reasonable default
|
|
285
295
|
parameters, and save the produced figures in the specified
|
|
@@ -294,6 +304,10 @@ class PipelineState:
|
|
|
294
304
|
timestep. Characters' positions are based on the final
|
|
295
305
|
cumulative graph layout.
|
|
296
306
|
:param layout: pre-computed graph layout
|
|
307
|
+
:param node_kwargs: passed to :func:`nx.draw_networkx_nodes`
|
|
308
|
+
:param edge_kwargs: passed to :func:`nx.draw_networkx_edges`
|
|
309
|
+
:param label_kwargs: passed to :func:`nx.draw_networkx_labels`
|
|
310
|
+
:param legend: passed to :func:`.plot_nx_graph_reasonably`
|
|
297
311
|
"""
|
|
298
312
|
import matplotlib.pyplot as plt
|
|
299
313
|
|
|
@@ -317,13 +331,25 @@ class PipelineState:
|
|
|
317
331
|
)
|
|
318
332
|
layout = layout_nx_graph_reasonably(layout_graph)
|
|
319
333
|
|
|
334
|
+
node_kwargs = node_kwargs or [{} for _ in range(len(self.character_network))]
|
|
335
|
+
edge_kwargs = edge_kwargs or [{} for _ in range(len(self.character_network))]
|
|
336
|
+
label_kwargs = label_kwargs or [{} for _ in range(len(self.character_network))]
|
|
337
|
+
|
|
320
338
|
for i, G in enumerate(graphs):
|
|
321
339
|
_, ax = plt.subplots()
|
|
322
340
|
local_layout = layout
|
|
323
341
|
if not local_layout is None:
|
|
324
342
|
local_layout = layout_with_names(G, local_layout, name_style)
|
|
325
343
|
G = graph_with_names(G, name_style=name_style)
|
|
326
|
-
plot_nx_graph_reasonably(
|
|
344
|
+
plot_nx_graph_reasonably(
|
|
345
|
+
G,
|
|
346
|
+
ax=ax,
|
|
347
|
+
layout=local_layout,
|
|
348
|
+
node_kwargs=node_kwargs[i],
|
|
349
|
+
edge_kwargs=edge_kwargs[i],
|
|
350
|
+
label_kwargs=label_kwargs[i],
|
|
351
|
+
legend=legend,
|
|
352
|
+
)
|
|
327
353
|
plt.savefig(f"{directory}/{i}.png")
|
|
328
354
|
plt.close()
|
|
329
355
|
|
|
@@ -335,6 +361,11 @@ class PipelineState:
|
|
|
335
361
|
] = "most_frequent",
|
|
336
362
|
layout: Optional[CharactersGraphLayout] = None,
|
|
337
363
|
fig: Optional[plt.Figure] = None,
|
|
364
|
+
node_kwargs: Optional[Dict[str, Any]] = None,
|
|
365
|
+
edge_kwargs: Optional[Dict[str, Any]] = None,
|
|
366
|
+
label_kwargs: Optional[Dict[str, Any]] = None,
|
|
367
|
+
tight_layout: bool = False,
|
|
368
|
+
legend: bool = False,
|
|
338
369
|
):
|
|
339
370
|
"""Plot ``self.character_graph`` using reasonable parameters,
|
|
340
371
|
and save the produced figure to a file
|
|
@@ -344,6 +375,11 @@ class PipelineState:
|
|
|
344
375
|
:param layout: pre-computed graph layout
|
|
345
376
|
:param fig: if specified, this matplotlib figure will be used
|
|
346
377
|
for plotting
|
|
378
|
+
:param node_kwargs: passed to :func:`nx.draw_networkx_nodes`
|
|
379
|
+
:param edge_kwargs: passed to :func:`nx.draw_networkx_edges`
|
|
380
|
+
:param label_kwargs: passed to :func:`nx.draw_networkx_labels`
|
|
381
|
+
:param tight_layout: if ``True``, will use matplotlib's tight_layout
|
|
382
|
+
:param legend: passed to :func:`.plot_nx_graph_reasonably`
|
|
347
383
|
"""
|
|
348
384
|
import matplotlib.pyplot as plt
|
|
349
385
|
|
|
@@ -361,7 +397,17 @@ class PipelineState:
|
|
|
361
397
|
fig.set_dpi(300)
|
|
362
398
|
fig.set_size_inches(24, 24)
|
|
363
399
|
ax = fig.add_subplot(111)
|
|
364
|
-
plot_nx_graph_reasonably(
|
|
400
|
+
plot_nx_graph_reasonably(
|
|
401
|
+
G,
|
|
402
|
+
ax=ax,
|
|
403
|
+
layout=layout,
|
|
404
|
+
node_kwargs=node_kwargs,
|
|
405
|
+
edge_kwargs=edge_kwargs,
|
|
406
|
+
label_kwargs=label_kwargs,
|
|
407
|
+
legend=legend,
|
|
408
|
+
)
|
|
409
|
+
if tight_layout:
|
|
410
|
+
fig.tight_layout()
|
|
365
411
|
plt.savefig(path)
|
|
366
412
|
plt.close()
|
|
367
413
|
|
|
@@ -375,6 +421,11 @@ class PipelineState:
|
|
|
375
421
|
graph_start_idx: int = 1,
|
|
376
422
|
stable_layout: bool = False,
|
|
377
423
|
layout: Optional[CharactersGraphLayout] = None,
|
|
424
|
+
node_kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
|
425
|
+
edge_kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
|
426
|
+
label_kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
|
427
|
+
tight_layout: bool = False,
|
|
428
|
+
legend: bool = False,
|
|
378
429
|
):
|
|
379
430
|
"""Plot ``self.character_network`` using reasonable default
|
|
380
431
|
parameters
|
|
@@ -400,6 +451,11 @@ class PipelineState:
|
|
|
400
451
|
same position in space at each timestep. Characters'
|
|
401
452
|
positions are based on the final cumulative graph layout.
|
|
402
453
|
:param layout: pre-computed graph layout
|
|
454
|
+
:param node_kwargs: passed to :func:`nx.draw_networkx_nodes`
|
|
455
|
+
:param edge_kwargs: passed to :func:`nx.draw_networkx_edges`
|
|
456
|
+
:param label_kwargs: passed to :func:`nx.draw_networkx_labels`
|
|
457
|
+
:param tight_layout: if ``True``, will use matplotlib's tight_layout
|
|
458
|
+
:param legend: passed to :func:`.plot_nx_graph_reasonably`
|
|
403
459
|
"""
|
|
404
460
|
import matplotlib.pyplot as plt
|
|
405
461
|
from matplotlib.widgets import Slider
|
|
@@ -418,13 +474,33 @@ class PipelineState:
|
|
|
418
474
|
fig.set_dpi(300)
|
|
419
475
|
fig.set_size_inches(24, 24)
|
|
420
476
|
ax = fig.add_subplot(111)
|
|
421
|
-
|
|
477
|
+
assert not isinstance(node_kwargs, list)
|
|
478
|
+
assert not isinstance(edge_kwargs, list)
|
|
479
|
+
assert not isinstance(label_kwargs, list)
|
|
480
|
+
if tight_layout:
|
|
481
|
+
fig.tight_layout()
|
|
482
|
+
plot_nx_graph_reasonably(
|
|
483
|
+
G,
|
|
484
|
+
ax=ax,
|
|
485
|
+
layout=layout,
|
|
486
|
+
node_kwargs=node_kwargs,
|
|
487
|
+
edge_kwargs=edge_kwargs,
|
|
488
|
+
label_kwargs=label_kwargs,
|
|
489
|
+
legend=legend,
|
|
490
|
+
)
|
|
422
491
|
return
|
|
423
492
|
|
|
424
493
|
if not isinstance(self.character_network, list):
|
|
425
494
|
raise TypeError
|
|
426
495
|
# self.character_network is a list: plot a dynamic graph
|
|
427
496
|
|
|
497
|
+
node_kwargs = node_kwargs or [{} for _ in range(len(self.character_network))]
|
|
498
|
+
assert isinstance(node_kwargs, list)
|
|
499
|
+
edge_kwargs = edge_kwargs or [{} for _ in range(len(self.character_network))]
|
|
500
|
+
assert isinstance(edge_kwargs, list)
|
|
501
|
+
label_kwargs = label_kwargs or [{} for _ in range(len(self.character_network))]
|
|
502
|
+
assert isinstance(label_kwargs, list)
|
|
503
|
+
|
|
428
504
|
if fig is None:
|
|
429
505
|
fig, ax = plt.subplots()
|
|
430
506
|
assert not fig is None
|
|
@@ -440,12 +516,13 @@ class PipelineState:
|
|
|
440
516
|
|
|
441
517
|
def update(slider_value):
|
|
442
518
|
assert isinstance(self.character_network, list)
|
|
519
|
+
slider_i = int(slider_value) - 1
|
|
443
520
|
|
|
444
521
|
character_networks = self.character_network
|
|
445
522
|
if cumulative:
|
|
446
523
|
character_networks = cumulative_character_networks
|
|
447
524
|
|
|
448
|
-
G = character_networks[
|
|
525
|
+
G = character_networks[slider_i]
|
|
449
526
|
|
|
450
527
|
local_layout = layout
|
|
451
528
|
if not local_layout is None:
|
|
@@ -453,11 +530,21 @@ class PipelineState:
|
|
|
453
530
|
G = graph_with_names(G, name_style)
|
|
454
531
|
|
|
455
532
|
ax.clear()
|
|
456
|
-
plot_nx_graph_reasonably(
|
|
533
|
+
plot_nx_graph_reasonably(
|
|
534
|
+
G,
|
|
535
|
+
ax=ax,
|
|
536
|
+
layout=local_layout,
|
|
537
|
+
node_kwargs=node_kwargs[slider_i],
|
|
538
|
+
edge_kwargs=edge_kwargs[slider_i],
|
|
539
|
+
label_kwargs=label_kwargs[slider_i],
|
|
540
|
+
legend=legend,
|
|
541
|
+
)
|
|
457
542
|
ax.set_xlim(-1.2, 1.2)
|
|
458
543
|
ax.set_ylim(-1.2, 1.2)
|
|
459
544
|
|
|
460
545
|
slider_ax = fig.add_axes([0.1, 0.05, 0.8, 0.04])
|
|
546
|
+
if tight_layout:
|
|
547
|
+
fig.tight_layout()
|
|
461
548
|
# HACK: we save the slider to the figure. This ensures the
|
|
462
549
|
# slider is still alive at plotting time.
|
|
463
550
|
fig.slider = Slider( # type: ignore
|
|
@@ -474,6 +561,10 @@ class PipelineState:
|
|
|
474
561
|
class Pipeline:
|
|
475
562
|
"""A flexible NLP pipeline"""
|
|
476
563
|
|
|
564
|
+
#: all the possible parameters of the whole pipeline, that are
|
|
565
|
+
#: shared between steps
|
|
566
|
+
PipelineParameter = Literal["lang", "progress_reporter", "character_ner_tag"]
|
|
567
|
+
|
|
477
568
|
def __init__(
|
|
478
569
|
self,
|
|
479
570
|
steps: List[PipelineStep],
|
|
@@ -496,17 +587,27 @@ class Pipeline:
|
|
|
496
587
|
self.progress_reporter = get_progress_reporter(progress_report)
|
|
497
588
|
|
|
498
589
|
self.lang = lang
|
|
590
|
+
self.character_ner_tag = "PER"
|
|
499
591
|
self.warn = warn
|
|
500
592
|
|
|
501
|
-
def
|
|
502
|
-
"""
|
|
593
|
+
def _pipeline_init_steps_(self, ignored_steps: Optional[List[str]] = None):
|
|
594
|
+
"""Initialise steps with global pipeline parameters.
|
|
595
|
+
|
|
503
596
|
:param ignored_steps: a list of step productions. All steps
|
|
504
597
|
with a production in ``ignored_steps`` will be ignored.
|
|
505
598
|
"""
|
|
506
|
-
steps_progress_reporter =
|
|
599
|
+
steps_progress_reporter = self.progress_reporter.get_subreporter()
|
|
507
600
|
steps = self._non_ignored_steps(ignored_steps)
|
|
601
|
+
pipeline_params = {
|
|
602
|
+
"progress_reporter": steps_progress_reporter,
|
|
603
|
+
"character_ner_tag": self.character_ner_tag,
|
|
604
|
+
}
|
|
508
605
|
for step in steps:
|
|
509
|
-
step._pipeline_init_(self.lang,
|
|
606
|
+
step_additional_params = step._pipeline_init_(self.lang, **pipeline_params)
|
|
607
|
+
if not step_additional_params is None:
|
|
608
|
+
for key, value in step_additional_params.items():
|
|
609
|
+
setattr(self, key, value)
|
|
610
|
+
pipeline_params[key] = value
|
|
510
611
|
|
|
511
612
|
def _non_ignored_steps(
|
|
512
613
|
self, ignored_steps: Optional[List[str]]
|
|
@@ -549,13 +650,27 @@ class Pipeline:
|
|
|
549
650
|
return (
|
|
550
651
|
False,
|
|
551
652
|
[
|
|
552
|
-
|
|
653
|
+
"".join(
|
|
654
|
+
[
|
|
655
|
+
f"step {i + 1} ({step.__class__.__name__}) has unsatisfied needs. "
|
|
656
|
+
+ f"needs: {step.needs()}. "
|
|
657
|
+
+ f"available: {pipeline_state}). "
|
|
658
|
+
+ f"missing: {step.needs() - pipeline_state}."
|
|
659
|
+
]
|
|
660
|
+
),
|
|
553
661
|
],
|
|
554
662
|
)
|
|
555
663
|
|
|
556
664
|
if not step.optional_needs().issubset(pipeline_state):
|
|
557
665
|
warnings.append(
|
|
558
|
-
|
|
666
|
+
"".join(
|
|
667
|
+
[
|
|
668
|
+
f"step {i + 1} ({step.__class__.__name__}) has unsatisfied optional needs. "
|
|
669
|
+
+ f"needs: {step.optional_needs()}. "
|
|
670
|
+
+ f"available: {pipeline_state}). "
|
|
671
|
+
+ f"missing: {step.optional_needs() - pipeline_state}."
|
|
672
|
+
]
|
|
673
|
+
)
|
|
559
674
|
)
|
|
560
675
|
|
|
561
676
|
pipeline_state = pipeline_state.union(step.production())
|
|
@@ -582,9 +697,9 @@ class Pipeline:
|
|
|
582
697
|
raise ValueError(warnings_or_errors)
|
|
583
698
|
if self.warn:
|
|
584
699
|
for warning in warnings_or_errors:
|
|
585
|
-
print(f"[warning] : {warning}")
|
|
700
|
+
print(f"[warning] : {warning}", file=sys.stderr)
|
|
586
701
|
|
|
587
|
-
self.
|
|
702
|
+
self._pipeline_init_steps_(ignored_steps)
|
|
588
703
|
|
|
589
704
|
state = PipelineState(text)
|
|
590
705
|
# sets attributes to PipelineState dynamically. This ensures
|