renard-pipeline 0.4.2__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of renard-pipeline might be problematic. Click here for more details.
- renard/graph_utils.py +11 -4
- renard/ner_utils.py +4 -0
- renard/pipeline/character_unification.py +14 -4
- renard/pipeline/characters_extraction.py +3 -1
- renard/pipeline/core.py +121 -26
- renard/pipeline/corefs/corefs.py +30 -31
- renard/pipeline/graph_extraction.py +281 -192
- renard/pipeline/ner.py +3 -2
- renard/pipeline/progress.py +32 -1
- renard/pipeline/speaker_attribution.py +2 -3
- renard/pipeline/tokenization.py +59 -30
- renard/plot_utils.py +41 -28
- renard/resources/hypocorisms/hypocorisms.py +3 -2
- renard/utils.py +57 -1
- {renard_pipeline-0.4.2.dist-info → renard_pipeline-0.5.0.dist-info}/METADATA +27 -3
- {renard_pipeline-0.4.2.dist-info → renard_pipeline-0.5.0.dist-info}/RECORD +18 -18
- {renard_pipeline-0.4.2.dist-info → renard_pipeline-0.5.0.dist-info}/LICENSE +0 -0
- {renard_pipeline-0.4.2.dist-info → renard_pipeline-0.5.0.dist-info}/WHEEL +0 -0
renard/graph_utils.py
CHANGED
|
@@ -70,10 +70,17 @@ def graph_with_names(
|
|
|
70
70
|
else:
|
|
71
71
|
name_style_fn = name_style
|
|
72
72
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
73
|
+
mapping = {}
|
|
74
|
+
for character in G.nodes():
|
|
75
|
+
# NOTE: it is *possible* to have a graph where nodes are not
|
|
76
|
+
# characters (for example, simple strings). Therefore, we are
|
|
77
|
+
# lenient here
|
|
78
|
+
try:
|
|
79
|
+
mapping[character] = name_style_fn(character)
|
|
80
|
+
except AttributeError:
|
|
81
|
+
mapping[character] = character
|
|
82
|
+
|
|
83
|
+
return nx.relabel_nodes(G, mapping)
|
|
77
84
|
|
|
78
85
|
|
|
79
86
|
def layout_with_names(
|
renard/ner_utils.py
CHANGED
|
@@ -110,6 +110,10 @@ class NERDataset(Dataset):
|
|
|
110
110
|
elt_context_mask = self._context_mask[index]
|
|
111
111
|
for i in range(len(element)):
|
|
112
112
|
w2t = batch.word_to_tokens(0, i)
|
|
113
|
+
# w2t can be None in case of truncation, which can happen
|
|
114
|
+
# if `element' is too long
|
|
115
|
+
if w2t is None:
|
|
116
|
+
continue
|
|
113
117
|
mask_value = elt_context_mask[i]
|
|
114
118
|
tokens_mask = [mask_value] * (w2t.end - w2t.start)
|
|
115
119
|
batch["context_mask"][w2t.start : w2t.end] = tokens_mask
|
|
@@ -61,6 +61,8 @@ def _assign_coreference_mentions(
|
|
|
61
61
|
# we assign each chain to the character with highest name
|
|
62
62
|
# occurrence in it
|
|
63
63
|
for chain in corefs:
|
|
64
|
+
if len(char_mentions) == 0:
|
|
65
|
+
break
|
|
64
66
|
# determine the characters with the highest number of
|
|
65
67
|
# occurrences
|
|
66
68
|
occ_counter = {}
|
|
@@ -98,8 +100,13 @@ class NaiveCharacterUnifier(PipelineStep):
|
|
|
98
100
|
character for it to be valid
|
|
99
101
|
"""
|
|
100
102
|
self.min_appearances = min_appearances
|
|
103
|
+
# a default value, will be set by _pipeline_init_
|
|
104
|
+
self.character_ner_tag = "PER"
|
|
101
105
|
super().__init__()
|
|
102
106
|
|
|
107
|
+
def _pipeline_init_(self, lang: str, character_ner_tag: str, **kwargs):
|
|
108
|
+
self.character_ner_tag = character_ner_tag
|
|
109
|
+
|
|
103
110
|
def __call__(
|
|
104
111
|
self,
|
|
105
112
|
text: str,
|
|
@@ -112,7 +119,7 @@ class NaiveCharacterUnifier(PipelineStep):
|
|
|
112
119
|
:param tokens:
|
|
113
120
|
:param entities:
|
|
114
121
|
"""
|
|
115
|
-
persons = [e for e in entities if e.tag ==
|
|
122
|
+
persons = [e for e in entities if e.tag == self.character_ner_tag]
|
|
116
123
|
|
|
117
124
|
characters = defaultdict(list)
|
|
118
125
|
for entity in persons:
|
|
@@ -182,16 +189,19 @@ class GraphRulesCharacterUnifier(PipelineStep):
|
|
|
182
189
|
self.additional_hypocorisms = additional_hypocorisms
|
|
183
190
|
self.link_corefs_mentions = link_corefs_mentions
|
|
184
191
|
self.ignore_lone_titles = ignore_lone_titles or set()
|
|
192
|
+
self.character_ner_tag = "PER" # a default value, will be set by _pipeline_init
|
|
185
193
|
|
|
186
194
|
super().__init__()
|
|
187
195
|
|
|
188
|
-
def _pipeline_init_(self, lang: str,
|
|
196
|
+
def _pipeline_init_(self, lang: str, character_ner_tag: str, **kwargs):
|
|
189
197
|
self.hypocorism_gazetteer = HypocorismGazetteer(lang=lang)
|
|
190
198
|
if not self.additional_hypocorisms is None:
|
|
191
199
|
for name, nicknames in self.additional_hypocorisms:
|
|
192
200
|
self.hypocorism_gazetteer._add_hypocorism_(name, nicknames)
|
|
193
201
|
|
|
194
|
-
|
|
202
|
+
self.character_ner_tag = character_ner_tag
|
|
203
|
+
|
|
204
|
+
return super()._pipeline_init_(lang, **kwargs)
|
|
195
205
|
|
|
196
206
|
def __call__(
|
|
197
207
|
self,
|
|
@@ -201,7 +211,7 @@ class GraphRulesCharacterUnifier(PipelineStep):
|
|
|
201
211
|
) -> Dict[str, Any]:
|
|
202
212
|
import networkx as nx
|
|
203
213
|
|
|
204
|
-
mentions = [m for m in entities if m.tag ==
|
|
214
|
+
mentions = [m for m in entities if m.tag == self.character_ner_tag]
|
|
205
215
|
mentions_str = set(
|
|
206
216
|
filter(
|
|
207
217
|
lambda m: not m in self.ignore_lone_titles,
|
|
@@ -1,7 +1,9 @@
|
|
|
1
|
+
import sys
|
|
1
2
|
import renard.pipeline.character_unification as cu
|
|
2
3
|
|
|
3
4
|
print(
|
|
4
|
-
"[warning] the characters_extraction module is deprecated. Use character_unification instead."
|
|
5
|
+
"[warning] the characters_extraction module is deprecated. Use character_unification instead.",
|
|
6
|
+
file=sys.stderr,
|
|
5
7
|
)
|
|
6
8
|
|
|
7
9
|
Character = cu.Character
|
renard/pipeline/core.py
CHANGED
|
@@ -79,11 +79,18 @@ class PipelineStep:
|
|
|
79
79
|
"""Initialize the :class:`PipelineStep` with a given configuration."""
|
|
80
80
|
pass
|
|
81
81
|
|
|
82
|
-
def _pipeline_init_(
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
82
|
+
def _pipeline_init_(
|
|
83
|
+
self, lang: str, progress_reporter: ProgressReporter, **kwargs
|
|
84
|
+
) -> Optional[Dict[Pipeline.PipelineParameter, Any]]:
|
|
85
|
+
"""Set the step configuration that is common to the whole
|
|
86
|
+
pipeline.
|
|
87
|
+
|
|
88
|
+
:param lang: the lang of the whole pipeline
|
|
89
|
+
:param progress_reporter:
|
|
90
|
+
:param kwargs: additional pipeline parameters.
|
|
91
|
+
|
|
92
|
+
:return: a step can return a dictionary of pipeline params if
|
|
93
|
+
it wishes to modify some of these.
|
|
87
94
|
"""
|
|
88
95
|
supported_langs = self.supported_langs()
|
|
89
96
|
if not supported_langs == "any" and not lang in supported_langs:
|
|
@@ -150,13 +157,14 @@ class PipelineState:
|
|
|
150
157
|
#: input text
|
|
151
158
|
text: Optional[str]
|
|
152
159
|
|
|
153
|
-
#: text split into
|
|
154
|
-
|
|
160
|
+
#: text split into blocks of texts. When dynamic blocks are given,
|
|
161
|
+
#: the final network is dynamic, and split according to blocks.
|
|
162
|
+
dynamic_blocks: Optional[List[Tuple[int, int]]] = None
|
|
155
163
|
|
|
156
164
|
#: text split into tokens
|
|
157
165
|
tokens: Optional[List[str]] = None
|
|
158
|
-
#:
|
|
159
|
-
|
|
166
|
+
#: mapping from a character to its corresponding token
|
|
167
|
+
char2token: Optional[List[int]] = None
|
|
160
168
|
#: text split into sentences, each sentence being a list of
|
|
161
169
|
#: tokens
|
|
162
170
|
sentences: Optional[List[List[str]]] = None
|
|
@@ -182,14 +190,12 @@ class PipelineState:
|
|
|
182
190
|
#: network)
|
|
183
191
|
character_network: Optional[Union[List[nx.Graph], nx.Graph]] = None
|
|
184
192
|
|
|
193
|
+
# aliases of self.character_network
|
|
185
194
|
def get_characters_graph(self) -> Optional[Union[List[nx.Graph], nx.Graph]]:
|
|
186
|
-
print(
|
|
187
|
-
"[warning] the characters_graph attribute is deprecated, use character_network instead",
|
|
188
|
-
file=sys.stderr,
|
|
189
|
-
)
|
|
190
195
|
return self.character_network
|
|
191
196
|
|
|
192
197
|
characters_graph = property(get_characters_graph)
|
|
198
|
+
character_graph = property(get_characters_graph)
|
|
193
199
|
|
|
194
200
|
def get_character(
|
|
195
201
|
self, name: str, partial_match: bool = True
|
|
@@ -280,6 +286,9 @@ class PipelineState:
|
|
|
280
286
|
cumulative: bool = False,
|
|
281
287
|
stable_layout: bool = False,
|
|
282
288
|
layout: Optional[CharactersGraphLayout] = None,
|
|
289
|
+
node_kwargs: Optional[List[Dict[str, Any]]] = None,
|
|
290
|
+
edge_kwargs: Optional[List[Dict[str, Any]]] = None,
|
|
291
|
+
label_kwargs: Optional[List[Dict[str, Any]]] = None,
|
|
283
292
|
):
|
|
284
293
|
"""Plot ``self.character_graph`` using reasonable default
|
|
285
294
|
parameters, and save the produced figures in the specified
|
|
@@ -294,6 +303,9 @@ class PipelineState:
|
|
|
294
303
|
timestep. Characters' positions are based on the final
|
|
295
304
|
cumulative graph layout.
|
|
296
305
|
:param layout: pre-computed graph layout
|
|
306
|
+
:param node_kwargs: passed to :func:`nx.draw_networkx_nodes`
|
|
307
|
+
:param edge_kwargs: passed to :func:`nx.draw_networkx_edges`
|
|
308
|
+
:param label_kwargs: passed to :func:`nx.draw_networkx_labels`
|
|
297
309
|
"""
|
|
298
310
|
import matplotlib.pyplot as plt
|
|
299
311
|
|
|
@@ -317,13 +329,24 @@ class PipelineState:
|
|
|
317
329
|
)
|
|
318
330
|
layout = layout_nx_graph_reasonably(layout_graph)
|
|
319
331
|
|
|
332
|
+
node_kwargs = node_kwargs or [{} for _ in range(len(self.character_network))]
|
|
333
|
+
edge_kwargs = edge_kwargs or [{} for _ in range(len(self.character_network))]
|
|
334
|
+
label_kwargs = label_kwargs or [{} for _ in range(len(self.character_network))]
|
|
335
|
+
|
|
320
336
|
for i, G in enumerate(graphs):
|
|
321
337
|
_, ax = plt.subplots()
|
|
322
338
|
local_layout = layout
|
|
323
339
|
if not local_layout is None:
|
|
324
340
|
local_layout = layout_with_names(G, local_layout, name_style)
|
|
325
341
|
G = graph_with_names(G, name_style=name_style)
|
|
326
|
-
plot_nx_graph_reasonably(
|
|
342
|
+
plot_nx_graph_reasonably(
|
|
343
|
+
G,
|
|
344
|
+
ax=ax,
|
|
345
|
+
layout=local_layout,
|
|
346
|
+
node_kwargs=node_kwargs[i],
|
|
347
|
+
edge_kwargs=edge_kwargs[i],
|
|
348
|
+
label_kwargs=label_kwargs[i],
|
|
349
|
+
)
|
|
327
350
|
plt.savefig(f"{directory}/{i}.png")
|
|
328
351
|
plt.close()
|
|
329
352
|
|
|
@@ -335,6 +358,9 @@ class PipelineState:
|
|
|
335
358
|
] = "most_frequent",
|
|
336
359
|
layout: Optional[CharactersGraphLayout] = None,
|
|
337
360
|
fig: Optional[plt.Figure] = None,
|
|
361
|
+
node_kwargs: Optional[Dict[str, Any]] = None,
|
|
362
|
+
edge_kwargs: Optional[Dict[str, Any]] = None,
|
|
363
|
+
label_kwargs: Optional[Dict[str, Any]] = None,
|
|
338
364
|
):
|
|
339
365
|
"""Plot ``self.character_graph`` using reasonable parameters,
|
|
340
366
|
and save the produced figure to a file
|
|
@@ -344,6 +370,9 @@ class PipelineState:
|
|
|
344
370
|
:param layout: pre-computed graph layout
|
|
345
371
|
:param fig: if specified, this matplotlib figure will be used
|
|
346
372
|
for plotting
|
|
373
|
+
:param node_kwargs: passed to :func:`nx.draw_networkx_nodes`
|
|
374
|
+
:param edge_kwargs: passed to :func:`nx.draw_networkx_edges`
|
|
375
|
+
:param label_kwargs: passed to :func:`nx.draw_networkx_labels`
|
|
347
376
|
"""
|
|
348
377
|
import matplotlib.pyplot as plt
|
|
349
378
|
|
|
@@ -361,7 +390,14 @@ class PipelineState:
|
|
|
361
390
|
fig.set_dpi(300)
|
|
362
391
|
fig.set_size_inches(24, 24)
|
|
363
392
|
ax = fig.add_subplot(111)
|
|
364
|
-
plot_nx_graph_reasonably(
|
|
393
|
+
plot_nx_graph_reasonably(
|
|
394
|
+
G,
|
|
395
|
+
ax=ax,
|
|
396
|
+
layout=layout,
|
|
397
|
+
node_kwargs=node_kwargs,
|
|
398
|
+
edge_kwargs=edge_kwargs,
|
|
399
|
+
label_kwargs=label_kwargs,
|
|
400
|
+
)
|
|
365
401
|
plt.savefig(path)
|
|
366
402
|
plt.close()
|
|
367
403
|
|
|
@@ -375,6 +411,9 @@ class PipelineState:
|
|
|
375
411
|
graph_start_idx: int = 1,
|
|
376
412
|
stable_layout: bool = False,
|
|
377
413
|
layout: Optional[CharactersGraphLayout] = None,
|
|
414
|
+
node_kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
|
415
|
+
edge_kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
|
416
|
+
label_kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
|
378
417
|
):
|
|
379
418
|
"""Plot ``self.character_network`` using reasonable default
|
|
380
419
|
parameters
|
|
@@ -400,6 +439,9 @@ class PipelineState:
|
|
|
400
439
|
same position in space at each timestep. Characters'
|
|
401
440
|
positions are based on the final cumulative graph layout.
|
|
402
441
|
:param layout: pre-computed graph layout
|
|
442
|
+
:param node_kwargs: passed to :func:`nx.draw_networkx_nodes`
|
|
443
|
+
:param edge_kwargs: passed to :func:`nx.draw_networkx_edges`
|
|
444
|
+
:param label_kwargs: passed to :func:`nx.draw_networkx_labels`
|
|
403
445
|
"""
|
|
404
446
|
import matplotlib.pyplot as plt
|
|
405
447
|
from matplotlib.widgets import Slider
|
|
@@ -418,13 +460,30 @@ class PipelineState:
|
|
|
418
460
|
fig.set_dpi(300)
|
|
419
461
|
fig.set_size_inches(24, 24)
|
|
420
462
|
ax = fig.add_subplot(111)
|
|
421
|
-
|
|
463
|
+
assert not isinstance(node_kwargs, list)
|
|
464
|
+
assert not isinstance(edge_kwargs, list)
|
|
465
|
+
assert not isinstance(label_kwargs, list)
|
|
466
|
+
plot_nx_graph_reasonably(
|
|
467
|
+
G,
|
|
468
|
+
ax=ax,
|
|
469
|
+
layout=layout,
|
|
470
|
+
node_kwargs=node_kwargs,
|
|
471
|
+
edge_kwargs=edge_kwargs,
|
|
472
|
+
label_kwargs=label_kwargs,
|
|
473
|
+
)
|
|
422
474
|
return
|
|
423
475
|
|
|
424
476
|
if not isinstance(self.character_network, list):
|
|
425
477
|
raise TypeError
|
|
426
478
|
# self.character_network is a list: plot a dynamic graph
|
|
427
479
|
|
|
480
|
+
node_kwargs = node_kwargs or [{} for _ in range(len(self.character_network))]
|
|
481
|
+
assert isinstance(node_kwargs, list)
|
|
482
|
+
edge_kwargs = edge_kwargs or [{} for _ in range(len(self.character_network))]
|
|
483
|
+
assert isinstance(edge_kwargs, list)
|
|
484
|
+
label_kwargs = label_kwargs or [{} for _ in range(len(self.character_network))]
|
|
485
|
+
assert isinstance(label_kwargs, list)
|
|
486
|
+
|
|
428
487
|
if fig is None:
|
|
429
488
|
fig, ax = plt.subplots()
|
|
430
489
|
assert not fig is None
|
|
@@ -440,12 +499,13 @@ class PipelineState:
|
|
|
440
499
|
|
|
441
500
|
def update(slider_value):
|
|
442
501
|
assert isinstance(self.character_network, list)
|
|
502
|
+
slider_i = int(slider_value) - 1
|
|
443
503
|
|
|
444
504
|
character_networks = self.character_network
|
|
445
505
|
if cumulative:
|
|
446
506
|
character_networks = cumulative_character_networks
|
|
447
507
|
|
|
448
|
-
G = character_networks[
|
|
508
|
+
G = character_networks[slider_i]
|
|
449
509
|
|
|
450
510
|
local_layout = layout
|
|
451
511
|
if not local_layout is None:
|
|
@@ -453,7 +513,14 @@ class PipelineState:
|
|
|
453
513
|
G = graph_with_names(G, name_style)
|
|
454
514
|
|
|
455
515
|
ax.clear()
|
|
456
|
-
plot_nx_graph_reasonably(
|
|
516
|
+
plot_nx_graph_reasonably(
|
|
517
|
+
G,
|
|
518
|
+
ax=ax,
|
|
519
|
+
layout=local_layout,
|
|
520
|
+
node_kwargs=node_kwargs[slider_i],
|
|
521
|
+
edge_kwargs=edge_kwargs[slider_i],
|
|
522
|
+
label_kwargs=label_kwargs[slider_i],
|
|
523
|
+
)
|
|
457
524
|
ax.set_xlim(-1.2, 1.2)
|
|
458
525
|
ax.set_ylim(-1.2, 1.2)
|
|
459
526
|
|
|
@@ -474,6 +541,10 @@ class PipelineState:
|
|
|
474
541
|
class Pipeline:
|
|
475
542
|
"""A flexible NLP pipeline"""
|
|
476
543
|
|
|
544
|
+
#: all the possible parameters of the whole pipeline, that are
|
|
545
|
+
#: shared between steps
|
|
546
|
+
PipelineParameter = Literal["lang", "progress_reporter", "character_ner_tag"]
|
|
547
|
+
|
|
477
548
|
def __init__(
|
|
478
549
|
self,
|
|
479
550
|
steps: List[PipelineStep],
|
|
@@ -496,17 +567,27 @@ class Pipeline:
|
|
|
496
567
|
self.progress_reporter = get_progress_reporter(progress_report)
|
|
497
568
|
|
|
498
569
|
self.lang = lang
|
|
570
|
+
self.character_ner_tag = "PER"
|
|
499
571
|
self.warn = warn
|
|
500
572
|
|
|
501
|
-
def
|
|
502
|
-
"""
|
|
573
|
+
def _pipeline_init_steps_(self, ignored_steps: Optional[List[str]] = None):
|
|
574
|
+
"""Initialise steps with global pipeline parameters.
|
|
575
|
+
|
|
503
576
|
:param ignored_steps: a list of step productions. All steps
|
|
504
577
|
with a production in ``ignored_steps`` will be ignored.
|
|
505
578
|
"""
|
|
506
|
-
steps_progress_reporter =
|
|
579
|
+
steps_progress_reporter = self.progress_reporter.get_subreporter()
|
|
507
580
|
steps = self._non_ignored_steps(ignored_steps)
|
|
581
|
+
pipeline_params = {
|
|
582
|
+
"progress_reporter": steps_progress_reporter,
|
|
583
|
+
"character_ner_tag": self.character_ner_tag,
|
|
584
|
+
}
|
|
508
585
|
for step in steps:
|
|
509
|
-
step._pipeline_init_(self.lang,
|
|
586
|
+
step_additional_params = step._pipeline_init_(self.lang, **pipeline_params)
|
|
587
|
+
if not step_additional_params is None:
|
|
588
|
+
for key, value in step_additional_params.items():
|
|
589
|
+
setattr(self, key, value)
|
|
590
|
+
pipeline_params[key] = value
|
|
510
591
|
|
|
511
592
|
def _non_ignored_steps(
|
|
512
593
|
self, ignored_steps: Optional[List[str]]
|
|
@@ -549,13 +630,27 @@ class Pipeline:
|
|
|
549
630
|
return (
|
|
550
631
|
False,
|
|
551
632
|
[
|
|
552
|
-
|
|
633
|
+
"".join(
|
|
634
|
+
[
|
|
635
|
+
f"step {i + 1} ({step.__class__.__name__}) has unsatisfied needs. "
|
|
636
|
+
+ f"needs: {step.needs()}. "
|
|
637
|
+
+ f"available: {pipeline_state}). "
|
|
638
|
+
+ f"missing: {step.needs() - pipeline_state}."
|
|
639
|
+
]
|
|
640
|
+
),
|
|
553
641
|
],
|
|
554
642
|
)
|
|
555
643
|
|
|
556
644
|
if not step.optional_needs().issubset(pipeline_state):
|
|
557
645
|
warnings.append(
|
|
558
|
-
|
|
646
|
+
"".join(
|
|
647
|
+
[
|
|
648
|
+
f"step {i + 1} ({step.__class__.__name__}) has unsatisfied optional needs. "
|
|
649
|
+
+ f"needs: {step.optional_needs()}. "
|
|
650
|
+
+ f"available: {pipeline_state}). "
|
|
651
|
+
+ f"missing: {step.optional_needs() - pipeline_state}."
|
|
652
|
+
]
|
|
653
|
+
)
|
|
559
654
|
)
|
|
560
655
|
|
|
561
656
|
pipeline_state = pipeline_state.union(step.production())
|
|
@@ -582,9 +677,9 @@ class Pipeline:
|
|
|
582
677
|
raise ValueError(warnings_or_errors)
|
|
583
678
|
if self.warn:
|
|
584
679
|
for warning in warnings_or_errors:
|
|
585
|
-
print(f"[warning] : {warning}")
|
|
680
|
+
print(f"[warning] : {warning}", file=sys.stderr)
|
|
586
681
|
|
|
587
|
-
self.
|
|
682
|
+
self._pipeline_init_steps_(ignored_steps)
|
|
588
683
|
|
|
589
684
|
state = PipelineState(text)
|
|
590
685
|
# sets attributes to PipelineState dynamically. This ensures
|
renard/pipeline/corefs/corefs.py
CHANGED
|
@@ -25,6 +25,7 @@ class BertCoreferenceResolver(PipelineStep):
|
|
|
25
25
|
device: Literal["auto", "cuda", "cpu"] = "auto",
|
|
26
26
|
tokenizer: Optional[PreTrainedTokenizerFast] = None,
|
|
27
27
|
block_size: int = 512,
|
|
28
|
+
hierarchical_merging: bool = False,
|
|
28
29
|
) -> None:
|
|
29
30
|
"""
|
|
30
31
|
.. note::
|
|
@@ -40,6 +41,10 @@ class BertCoreferenceResolver(PipelineStep):
|
|
|
40
41
|
:param device: computation device
|
|
41
42
|
:param block_size: size of blocks to pass to the coreference
|
|
42
43
|
model
|
|
44
|
+
:param hierarchical_merging: if ``True``, attempts to use
|
|
45
|
+
tibert's hierarchical merging feature. In that case,
|
|
46
|
+
blocks of size ``block_size`` are merged to perform
|
|
47
|
+
inference on the whole document.
|
|
43
48
|
"""
|
|
44
49
|
if isinstance(model, str):
|
|
45
50
|
self.hugginface_model_id = hugginface_model_id
|
|
@@ -58,15 +63,15 @@ class BertCoreferenceResolver(PipelineStep):
|
|
|
58
63
|
self.device = torch.device(device)
|
|
59
64
|
|
|
60
65
|
self.block_size = block_size
|
|
66
|
+
self.hierarchical_merging = hierarchical_merging
|
|
61
67
|
|
|
62
68
|
super().__init__()
|
|
63
69
|
|
|
64
|
-
def _pipeline_init_(self, lang: str,
|
|
70
|
+
def _pipeline_init_(self, lang: str, **kwargs):
|
|
65
71
|
from tibert import BertForCoreferenceResolution
|
|
66
72
|
from transformers import BertTokenizerFast, AutoTokenizer
|
|
67
73
|
|
|
68
74
|
if self.model is None:
|
|
69
|
-
|
|
70
75
|
# the user supplied a huggingface ID: load model from the HUB
|
|
71
76
|
if not self.hugginface_model_id is None:
|
|
72
77
|
self.model = BertForCoreferenceResolution.from_pretrained(
|
|
@@ -87,16 +92,29 @@ class BertCoreferenceResolver(PipelineStep):
|
|
|
87
92
|
|
|
88
93
|
assert not self.tokenizer is None
|
|
89
94
|
|
|
90
|
-
super()._pipeline_init_(lang,
|
|
95
|
+
super()._pipeline_init_(lang, **kwargs)
|
|
91
96
|
|
|
92
97
|
def __call__(self, tokens: List[str], **kwargs) -> Dict[str, Any]:
|
|
93
|
-
from tibert import stream_predict_coref
|
|
98
|
+
from tibert import stream_predict_coref, predict_coref
|
|
99
|
+
from tibert.bertcoref import CoreferenceDocument
|
|
94
100
|
|
|
95
101
|
blocks = [
|
|
96
102
|
tokens[block_start : block_start + self.block_size]
|
|
97
103
|
for block_start in range(0, len(tokens), self.block_size)
|
|
98
104
|
]
|
|
99
105
|
|
|
106
|
+
if self.hierarchical_merging:
|
|
107
|
+
doc = predict_coref(
|
|
108
|
+
blocks,
|
|
109
|
+
self.model,
|
|
110
|
+
self.tokenizer,
|
|
111
|
+
batch_size=self.batch_size,
|
|
112
|
+
quiet=True,
|
|
113
|
+
device_str=self.device,
|
|
114
|
+
hierarchical_merging=True,
|
|
115
|
+
)
|
|
116
|
+
return {"corefs": doc.coref_chains}
|
|
117
|
+
|
|
100
118
|
coref_docs = []
|
|
101
119
|
for doc in self._progress_(
|
|
102
120
|
stream_predict_coref(
|
|
@@ -111,26 +129,7 @@ class BertCoreferenceResolver(PipelineStep):
|
|
|
111
129
|
):
|
|
112
130
|
coref_docs.append(doc)
|
|
113
131
|
|
|
114
|
-
|
|
115
|
-
# blocks. The following code adjusts their start and end index
|
|
116
|
-
# to match their global coordinate in the text.
|
|
117
|
-
coref_chains = []
|
|
118
|
-
cur_doc_start = 0
|
|
119
|
-
for doc in coref_docs:
|
|
120
|
-
for chain in doc.coref_chains:
|
|
121
|
-
adjusted_chain = []
|
|
122
|
-
for mention in chain:
|
|
123
|
-
# FIXME: It seems that a rare bug in Tibert can
|
|
124
|
-
# ----- sometimes produce this unwanted state.
|
|
125
|
-
if mention.start_idx is None:
|
|
126
|
-
mention.start_idx = 0
|
|
127
|
-
start_idx = mention.start_idx + cur_doc_start
|
|
128
|
-
end_idx = mention.end_idx + cur_doc_start
|
|
129
|
-
adjusted_chain.append(Mention(mention.tokens, start_idx, end_idx))
|
|
130
|
-
coref_chains.append(adjusted_chain)
|
|
131
|
-
cur_doc_start += len(doc)
|
|
132
|
-
|
|
133
|
-
return {"corefs": coref_chains}
|
|
132
|
+
return {"corefs": CoreferenceDocument.concatenated(coref_docs).coref_chains}
|
|
134
133
|
|
|
135
134
|
def needs(self) -> Set[str]:
|
|
136
135
|
return {"tokens"}
|
|
@@ -239,19 +238,19 @@ class SpacyCorefereeCoreferenceResolver(PipelineStep):
|
|
|
239
238
|
self,
|
|
240
239
|
text: str,
|
|
241
240
|
tokens: List[str],
|
|
242
|
-
|
|
241
|
+
dynamic_blocks_tokens: Optional[List[List[str]]] = None,
|
|
243
242
|
**kwargs,
|
|
244
243
|
) -> Dict[str, Any]:
|
|
245
244
|
from spacy.tokens import Doc
|
|
246
245
|
from coreferee.manager import CorefereeBroker
|
|
247
246
|
|
|
248
|
-
if
|
|
249
|
-
|
|
247
|
+
if dynamic_blocks_tokens is None:
|
|
248
|
+
dynamic_blocks_tokens = [tokens]
|
|
250
249
|
|
|
251
|
-
if len(
|
|
250
|
+
if len(dynamic_blocks_tokens) > 1:
|
|
252
251
|
chunks = []
|
|
253
|
-
for
|
|
254
|
-
chunks += self._cut_into_chunks(
|
|
252
|
+
for block in dynamic_blocks_tokens:
|
|
253
|
+
chunks += self._cut_into_chunks(block)
|
|
255
254
|
else:
|
|
256
255
|
chunks = self._cut_into_chunks(tokens)
|
|
257
256
|
|
|
@@ -317,7 +316,7 @@ class SpacyCorefereeCoreferenceResolver(PipelineStep):
|
|
|
317
316
|
return {"tokens"}
|
|
318
317
|
|
|
319
318
|
def optional_needs(self) -> Set[str]:
|
|
320
|
-
return {"
|
|
319
|
+
return {"dynamic_blocks_tokens"}
|
|
321
320
|
|
|
322
321
|
def production(self) -> Set[str]:
|
|
323
322
|
return {"corefs"}
|