renard-pipeline 0.4.2__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of renard-pipeline might be problematic.

renard/graph_utils.py CHANGED
@@ -70,10 +70,17 @@ def graph_with_names(
     else:
         name_style_fn = name_style
 
-    return nx.relabel_nodes(
-        G,
-        {character: name_style_fn(character) for character in G.nodes()},  # type: ignore
-    )
+    mapping = {}
+    for character in G.nodes():
+        # NOTE: it is *possible* to have a graph where nodes are not
+        # characters (for example, simple strings). Therefore, we are
+        # lenient here
+        try:
+            mapping[character] = name_style_fn(character)
+        except AttributeError:
+            mapping[character] = character
+
+    return nx.relabel_nodes(G, mapping)
 
 
 def layout_with_names(
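
The effect of this change, in short: nodes that the name-style function cannot handle are kept under their original label instead of raising. A minimal sketch of that behaviour (the FakeCharacter class and its longest_name method are invented stand-ins, not renard's actual Character API):

    import networkx as nx

    class FakeCharacter:  # invented stand-in for renard's Character
        def __init__(self, name):
            self.name = name

        def longest_name(self):
            return self.name

    name_style_fn = lambda c: c.longest_name()

    G = nx.Graph()
    G.add_edge(FakeCharacter("Elizabeth Bennet"), "narrator")  # one node is a plain string

    mapping = {}
    for node in G.nodes():
        try:
            mapping[node] = name_style_fn(node)
        except AttributeError:  # strings have no longest_name(): keep them as-is
            mapping[node] = node

    print(list(nx.relabel_nodes(G, mapping).nodes()))  # ['Elizabeth Bennet', 'narrator']
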
renard/ner_utils.py CHANGED
@@ -74,7 +74,7 @@ class DataCollatorForTokenClassificationWithBatchEncoding:
 class NERDataset(Dataset):
     """
     :ivar _context_mask: for each element, a mask indicating which
-        tokens are part of the context (1 for context, 0 for text on
+        tokens are part of the context (0 for context, 1 for text on
         which to perform inference). The mask allows to discard
         predictions made for context at inference time, even though
         the context can still be passed as input to the model.
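
To make the corrected convention concrete, a toy illustration (tokens, mask, and predictions are made up): 0 marks a context token whose prediction gets discarded, 1 marks a token we actually want tagged.

    tokens       = ["Previous", "sentence", ".", "Emma", "walked", "in", "."]
    context_mask = [0, 0, 0, 1, 1, 1, 1]            # 0 = context, 1 = text to tag
    predictions  = ["O", "O", "O", "B-PER", "O", "O", "O"]
    kept = [p for p, m in zip(predictions, context_mask) if m == 1]
    # kept == ["B-PER", "O", "O", "O"]: predictions made on pure context are dropped
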
@@ -92,11 +92,11 @@ class NERDataset(Dataset):
         assert all(
             [len(cm) == len(elt) for elt, cm in zip(self.elements, context_mask)]
         )
-        self._context_mask = context_mask or [[0] * len(elt) for elt in self.elements]
+        self._context_mask = context_mask or [[1] * len(elt) for elt in self.elements]
 
         self.tokenizer = tokenizer
 
-    def __getitem__(self, index: Union[int, List[int]]) -> BatchEncoding:
+    def __getitem__(self, index: int) -> BatchEncoding:
         element = self.elements[index]
 
         batch = self.tokenizer(
@@ -104,15 +104,18 @@ class NERDataset(Dataset):
             truncation=True,
             max_length=512,  # TODO
             is_split_into_words=True,
+            return_length=True,
         )
 
-        batch["context_mask"] = [0] * len(batch["input_ids"])
-        elt_context_mask = self._context_mask[index]
-        for i in range(len(element)):
-            w2t = batch.word_to_tokens(0, i)
-            mask_value = elt_context_mask[i]
-            tokens_mask = [mask_value] * (w2t.end - w2t.start)
-            batch["context_mask"][w2t.start : w2t.end] = tokens_mask
+        length = batch["length"][0]
+        del batch["length"]
+        if self.tokenizer.truncation_side == "right":
+            batch["context_mask"] = self._context_mask[index][:length]
+        else:
+            assert self.tokenizer.truncation_side == "left"
+            batch["context_mask"] = self._context_mask[index][
+                len(batch["input_ids"]) - length :
+            ]
 
         return batch
 
@@ -181,6 +184,7 @@ def load_conll2002_bio(
     path: str,
     tag_conversion_map: Optional[Dict[str, str]] = None,
     separator: str = "\t",
+    max_sent_len: Optional[int] = None,
     **kwargs,
 ) -> Tuple[List[List[str]], List[str], List[NEREntity]]:
     """Load a file under CoNLL-2002 BIO format. Sentences are expected
@@ -192,7 +196,9 @@ def load_conll2002_bio(
     :param separator: separator between token and BIO tags
     :param tag_conversion_map: conversion map for tags found in the
         input file. Example : ``{'B': 'B-PER', 'I': 'I-PER'}``
-    :param kwargs: additional kwargs for ``open`` (such as
+    :param max_sent_len: if specified, maximum length, in tokens, of
+        sentences.
+    :param kwargs: additional kwargs for :func:`open` (such as
         ``encoding`` or ``newline``).
 
     :return: ``(sentences, tokens, entities)``
@@ -207,7 +213,9 @@ def load_conll2002_bio(
     tags = []
     for line in raw_data.split("\n"):
         line = line.strip("\n")
-        if re.fullmatch(r"\s*", line):
+        if re.fullmatch(r"\s*", line) or (
+            not max_sent_len is None and len(sent_tokens) >= max_sent_len
+        ):
             if len(sent_tokens) == 0:
                 continue
             sents.append(sent_tokens)
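
A possible call using the new parameter (the path and encoding are illustrative; the tag_conversion_map value is the one from the docstring example):

    sentences, tokens, entities = load_conll2002_bio(
        "./my_corpus.bio",                                # made-up path
        tag_conversion_map={"B": "B-PER", "I": "I-PER"},  # value from the docstring example
        separator="\t",
        max_sent_len=128,   # sentences longer than 128 tokens are split
        encoding="utf-8",   # forwarded to open()
    )
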
@@ -227,6 +235,7 @@ def hgdataset_from_conll2002(
     path: str,
     tag_conversion_map: Optional[Dict[str, str]] = None,
     separator: str = "\t",
+    max_sent_len: Optional[int] = None,
     **kwargs,
 ) -> HGDataset:
     """Load a CoNLL-2002 file as a Huggingface Dataset.
@@ -234,12 +243,13 @@ def hgdataset_from_conll2002(
     :param path: passed to :func:`.load_conll2002_bio`
     :param tag_conversion_map: passed to :func:`load_conll2002_bio`
     :param separator: passed to :func:`load_conll2002_bio`
-    :param kwargs: passed to :func:`load_conll2002_bio`
+    :param max_sent_len: passed to :func:`load_conll2002_bio`
+    :param kwargs: additional kwargs for :func:`open`
 
     :return: a :class:`datasets.Dataset` with features 'tokens' and 'labels'.
     """
     sentences, tokens, entities = load_conll2002_bio(
-        path, tag_conversion_map, separator, **kwargs
+        path, tag_conversion_map, separator, max_sent_len, **kwargs
     )
 
     # convert entities to labels
renard/pipeline/character_unification.py CHANGED
@@ -1,5 +1,5 @@
 from typing import Any, Dict, List, FrozenSet, Set, Optional, Tuple, Union, Literal
-import copy
+import re, sys
 from itertools import combinations
 from collections import defaultdict, Counter
 from dataclasses import dataclass
@@ -11,6 +11,7 @@ from renard.pipeline.ner import NEREntity
 from renard.pipeline.progress import ProgressReporter
 from renard.resources.hypocorisms import HypocorismGazetteer
 from renard.resources.pronouns import is_a_female_pronoun, is_a_male_pronoun
+from renard.resources.determiners import singular_determiners
 from renard.resources.titles import is_a_male_title, is_a_female_title, all_titles
 
 
@@ -61,6 +62,8 @@ def _assign_coreference_mentions(
     # we assign each chain to the character with the highest name
     # occurrence in it
     for chain in corefs:
+        if len(char_mentions) == 0:
+            break
         # determine the characters with the highest number of
         # occurrences
         occ_counter = {}
@@ -98,8 +101,13 @@ class NaiveCharacterUnifier(PipelineStep):
             character for it to be valid
         """
         self.min_appearances = min_appearances
+        # a default value, will be set by _pipeline_init_
+        self.character_ner_tag = "PER"
         super().__init__()
 
+    def _pipeline_init_(self, lang: str, character_ner_tag: str, **kwargs):
+        self.character_ner_tag = character_ner_tag
+
     def __call__(
         self,
         text: str,
@@ -112,7 +120,7 @@ class NaiveCharacterUnifier(PipelineStep):
         :param tokens:
         :param entities:
         """
-        persons = [e for e in entities if e.tag == "PER"]
+        persons = [e for e in entities if e.tag == self.character_ner_tag]
 
         characters = defaultdict(list)
         for entity in persons:
@@ -160,6 +168,7 @@ class GraphRulesCharacterUnifier(PipelineStep):
         additional_hypocorisms: Optional[List[Tuple[str, List[str]]]] = None,
         link_corefs_mentions: bool = False,
         ignore_lone_titles: Optional[Set[str]] = None,
+        ignore_leading_determiner: bool = False,
     ) -> None:
         """
         :param min_appearances: minimum number of appearances of a
@@ -174,24 +183,32 @@ class GraphRulesCharacterUnifier(PipelineStep):
             extract a lot of spurious links. However, linking by
             coref is sometimes the only way to resolve a character
             alias.
-        :param ignore_lone_titles: a set of titles to ignore when
-            they stand on their own. This avoids extracting false
+        :param ignore_lone_titles: a set of titles to ignore when they
+            stand on their own. This avoids extracting false
             positive characters such as 'Mr.' or 'Miss'.
+        :param ignore_leading_determiner: if ``True``, will ignore the
+            leading determiner when applying unification rules. This
+            is useful if the NER model used in the pipeline adds
+            leading determiners as part of entities.
         """
         self.min_appearances = min_appearances
         self.additional_hypocorisms = additional_hypocorisms
         self.link_corefs_mentions = link_corefs_mentions
         self.ignore_lone_titles = ignore_lone_titles or set()
+        self.character_ner_tag = "PER"  # a default value, will be set by _pipeline_init_
+        self.ignore_leading_determiner = ignore_leading_determiner
 
         super().__init__()
 
-    def _pipeline_init_(self, lang: str, progress_reporter: ProgressReporter):
+    def _pipeline_init_(self, lang: str, character_ner_tag: str, **kwargs):
         self.hypocorism_gazetteer = HypocorismGazetteer(lang=lang)
         if not self.additional_hypocorisms is None:
             for name, nicknames in self.additional_hypocorisms:
                 self.hypocorism_gazetteer._add_hypocorism_(name, nicknames)
 
-        return super()._pipeline_init_(lang, progress_reporter)
+        self.character_ner_tag = character_ner_tag
+
+        return super()._pipeline_init_(lang, **kwargs)
 
     def __call__(
         self,
@@ -201,7 +218,7 @@ class GraphRulesCharacterUnifier(PipelineStep):
     ) -> Dict[str, Any]:
         import networkx as nx
 
-        mentions = [m for m in entities if m.tag == "PER"]
+        mentions = [m for m in entities if m.tag == self.character_ner_tag]
         mentions_str = set(
             filter(
                 lambda m: not m in self.ignore_lone_titles,
@@ -219,23 +236,28 @@ class GraphRulesCharacterUnifier(PipelineStep):
 
         # * link nodes based on several rules
         for name1, name2 in combinations(G.nodes(), 2):
+
+            # preprocess name when needed
+            pname1 = self._preprocess_name(name1)
+            pname2 = self._preprocess_name(name2)
+
             # is one name a known hypocorism of the other ? (also
             # checks if both names are the same)
-            if self.hypocorism_gazetteer.are_related(name1, name2):
+            if self.hypocorism_gazetteer.are_related(pname1, pname2):
                 G.add_edge(name1, name2)
                 continue
 
             # if we remove the title, is one name related to the other
             # ?
             if self.names_are_related_after_title_removal(
-                name1, name2, hname_constants
+                pname1, pname2, hname_constants
             ):
                 G.add_edge(name1, name2)
                 continue
 
             # add an edge if two characters have the same family names
-            human_name1 = HumanName(name1, constants=hname_constants)
-            human_name2 = HumanName(name2, constants=hname_constants)
+            human_name1 = HumanName(pname1, constants=hname_constants)
+            human_name2 = HumanName(pname2, constants=hname_constants)
             if (
                 len(human_name1.last) > 0
                 and human_name1.last.lower() == human_name2.last.lower()
@@ -272,10 +294,15 @@ class GraphRulesCharacterUnifier(PipelineStep):
                 pass
 
         for name1, name2 in combinations(G.nodes(), 2):
+
+            # preprocess names when needed
+            pname1 = self._preprocess_name(name1)
+            pname2 = self._preprocess_name(name2)
+
             # check if characters have the same last name but a
             # different first name.
-            human_name1 = HumanName(name1, constants=hname_constants)
-            human_name2 = HumanName(name2, constants=hname_constants)
+            human_name1 = HumanName(pname1, constants=hname_constants)
+            human_name2 = HumanName(pname2, constants=hname_constants)
             if (
                 len(human_name1.last) > 0
                 and len(human_name2.last) > 0
@@ -327,6 +354,17 @@ class GraphRulesCharacterUnifier(PipelineStep):
 
         return {"characters": characters}
 
+    def _preprocess_name(self, name) -> str:
+        if self.ignore_leading_determiner:
+            if not self.lang in singular_determiners:
+                print(
+                    f"[warning] can't ignore leading determiners for {self.lang}",
+                    file=sys.stderr,
+                )
+            for determiner in singular_determiners.get(self.lang, []):
+                name = re.sub(f"^{determiner} ", " ", name, flags=re.I)
+        return name
+
     def _make_hname_constants(self) -> Constants:
         if self.lang == "eng":
             return Constants()
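
Roughly, with ignore_leading_determiner=True the unification rules above compare mentions with their leading determiner stripped. A sketch of the intent, with a made-up stand-in for renard.resources.determiners.singular_determiners:

    import re

    # assumed content; the real mapping lives in renard.resources.determiners
    singular_determiners = {"fra": ["le", "la", "un", "une"]}

    def strip_leading_determiner(name: str, lang: str = "fra") -> str:
        for det in singular_determiners.get(lang, []):
            name = re.sub(rf"^{det} ", "", name, flags=re.I)
        return name.strip()

    strip_leading_determiner("la Comtesse")  # -> "Comtesse"
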
@@ -355,13 +393,18 @@ class GraphRulesCharacterUnifier(PipelineStep):
             or self.hypocorism_gazetteer.are_related(raw_name1, raw_name2)
         )
 
-    def names_are_in_coref(self, name1: str, name2: str, corefs: List[List[Mention]]):
+    def names_are_in_coref(
+        self, name1: str, name2: str, corefs: List[List[Mention]]
+    ) -> bool:
+        once_together = False
         for coref_chain in corefs:
-            if any([name1 == " ".join(m.tokens) for m in coref_chain]) and any(
-                [name2 == " ".join(m.tokens) for m in coref_chain]
-            ):
-                return True
-        return False
+            name1_in = any([name1 == " ".join(m.tokens) for m in coref_chain])
+            name2_in = any([name2 == " ".join(m.tokens) for m in coref_chain])
+            if name1_in == (not name2_in):
+                return False
+            elif name1_in and name2_in:
+                once_together = True
+        return once_together
 
     def infer_name_gender(
         self,
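
In other words, two names are now linked through coreference only if no chain contains one of them without the other, and at least one chain contains both. A toy check of that rule, with chains given as plain lists of mention tokens instead of Mention objects:

    def names_are_in_coref(name1, name2, chains):
        once_together = False
        for chain in chains:
            name1_in = any(name1 == " ".join(tokens) for tokens in chain)
            name2_in = any(name2 == " ".join(tokens) for tokens in chain)
            if name1_in != name2_in:  # one name appears without the other: reject
                return False
            if name1_in and name2_in:
                once_together = True
        return once_together

    # both forms always co-occur -> linked
    names_are_in_coref("Mr. Darcy", "Darcy", [[["Mr.", "Darcy"], ["Darcy"]]])  # True
    # "Darcy" also appears in a chain without "Mr. Darcy" -> no longer linked
    names_are_in_coref("Mr. Darcy", "Darcy",
                       [[["Mr.", "Darcy"], ["Darcy"]], [["Darcy"], ["he"]]])   # False
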
renard/pipeline/characters_extraction.py CHANGED
@@ -1,7 +1,9 @@
+import sys
 import renard.pipeline.character_unification as cu
 
 print(
-    "[warning] the characters_extraction module is deprecated. Use character_unification instead."
+    "[warning] the characters_extraction module is deprecated. Use character_unification instead.",
+    file=sys.stderr,
 )
 
 Character = cu.Character
renard/pipeline/core.py CHANGED
@@ -79,11 +79,18 @@ class PipelineStep:
         """Initialize the :class:`PipelineStep` with a given configuration."""
         pass
 
-    def _pipeline_init_(self, lang: str, progress_reporter: ProgressReporter):
-        """Set the step configuration that is common to the whole pipeline.
-
-        :param lang: ISO 639-3 language string
-        :param progress_report:
+    def _pipeline_init_(
+        self, lang: str, progress_reporter: ProgressReporter, **kwargs
+    ) -> Optional[Dict[Pipeline.PipelineParameter, Any]]:
+        """Set the step configuration that is common to the whole
+        pipeline.
+
+        :param lang: the lang of the whole pipeline
+        :param progress_reporter:
+        :param kwargs: additional pipeline parameters.
+
+        :return: a step can return a dictionary of pipeline params if
+            it wishes to modify some of these.
         """
         supported_langs = self.supported_langs()
         if not supported_langs == "any" and not lang in supported_langs:
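
As a sketch of the new contract (the step class and tag value below are invented for illustration), a custom step can now push a shared parameter back to the whole pipeline by returning it from _pipeline_init_:

    from renard.pipeline.core import PipelineStep

    class MyCustomNERStep(PipelineStep):  # invented example step
        def _pipeline_init_(self, lang, progress_reporter, **kwargs):
            super()._pipeline_init_(lang, progress_reporter, **kwargs)
            # ask the rest of the pipeline to treat "PERS" as the character NER tag
            return {"character_ner_tag": "PERS"}
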
@@ -150,13 +157,14 @@ class PipelineState:
     #: input text
     text: Optional[str]
 
-    #: text split into chapters
-    chapters: Optional[List[str]] = None
+    #: text split into blocks of text. When dynamic blocks are given,
+    #: the final network is dynamic, and split according to blocks.
+    dynamic_blocks: Optional[List[Tuple[int, int]]] = None
 
     #: text splitted in tokens
     tokens: Optional[List[str]] = None
-    #: text splitted in tokens, by chapter
-    chapter_tokens: Optional[List[List[str]]] = None
+    #: mapping from a character to its corresponding token
+    char2token: Optional[List[int]] = None
     #: text splitted into sentences, each sentence being a list of
     #: tokens
     sentences: Optional[List[List[str]]] = None
@@ -182,14 +190,12 @@ class PipelineState:
     #: network)
     character_network: Optional[Union[List[nx.Graph], nx.Graph]] = None
 
+    # aliases of self.character_network
     def get_characters_graph(self) -> Optional[Union[List[nx.Graph], nx.Graph]]:
-        print(
-            "[warning] the characters_graph attribute is deprecated, use character_network instead",
-            file=sys.stderr,
-        )
         return self.character_network
 
     characters_graph = property(get_characters_graph)
+    character_graph = property(get_characters_graph)
 
     def get_character(
         self, name: str, partial_match: bool = True
@@ -280,6 +286,10 @@ class PipelineState:
         cumulative: bool = False,
         stable_layout: bool = False,
         layout: Optional[CharactersGraphLayout] = None,
+        node_kwargs: Optional[List[Dict[str, Any]]] = None,
+        edge_kwargs: Optional[List[Dict[str, Any]]] = None,
+        label_kwargs: Optional[List[Dict[str, Any]]] = None,
+        legend: bool = False,
     ):
         """Plot ``self.character_graph`` using reasonable default
         parameters, and save the produced figures in the specified
@@ -294,6 +304,10 @@ class PipelineState:
             timestep. Characters' positions are based on the final
             cumulative graph layout.
         :param layout: pre-computed graph layout
+        :param node_kwargs: passed to :func:`nx.draw_networkx_nodes`
+        :param edge_kwargs: passed to :func:`nx.draw_networkx_edges`
+        :param label_kwargs: passed to :func:`nx.draw_networkx_labels`
+        :param legend: passed to :func:`.plot_nx_graph_reasonably`
         """
         import matplotlib.pyplot as plt
 
@@ -317,13 +331,25 @@ class PipelineState:
             )
             layout = layout_nx_graph_reasonably(layout_graph)
 
+        node_kwargs = node_kwargs or [{} for _ in range(len(self.character_network))]
+        edge_kwargs = edge_kwargs or [{} for _ in range(len(self.character_network))]
+        label_kwargs = label_kwargs or [{} for _ in range(len(self.character_network))]
+
         for i, G in enumerate(graphs):
             _, ax = plt.subplots()
             local_layout = layout
             if not local_layout is None:
                 local_layout = layout_with_names(G, local_layout, name_style)
             G = graph_with_names(G, name_style=name_style)
-            plot_nx_graph_reasonably(G, ax=ax, layout=local_layout)
+            plot_nx_graph_reasonably(
+                G,
+                ax=ax,
+                layout=local_layout,
+                node_kwargs=node_kwargs[i],
+                edge_kwargs=edge_kwargs[i],
+                label_kwargs=label_kwargs[i],
+                legend=legend,
+            )
             plt.savefig(f"{directory}/{i}.png")
             plt.close()
 
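
For example, a dynamic network of three graphs could be plotted with per-timestep styling roughly as follows (the plot_graphs_to_dir method name and the directory argument are assumed from the docstring and the savefig call above; the styling values are made up):

    # `state` is a PipelineState whose character_network is a list of 3 graphs
    state.plot_graphs_to_dir(
        directory="./figures",
        node_kwargs=[{"node_color": "tab:blue"}] * 3,
        edge_kwargs=[{"width": 0.5}] * 3,
        label_kwargs=[{"font_size": 6}] * 3,
        legend=True,
    )
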
@@ -335,6 +361,11 @@ class PipelineState:
         ] = "most_frequent",
         layout: Optional[CharactersGraphLayout] = None,
         fig: Optional[plt.Figure] = None,
+        node_kwargs: Optional[Dict[str, Any]] = None,
+        edge_kwargs: Optional[Dict[str, Any]] = None,
+        label_kwargs: Optional[Dict[str, Any]] = None,
+        tight_layout: bool = False,
+        legend: bool = False,
     ):
         """Plot ``self.character_graph`` using reasonable parameters,
         and save the produced figure to a file
@@ -344,6 +375,11 @@ class PipelineState:
         :param layout: pre-computed graph layout
         :param fig: if specified, this matplotlib figure will be used
             for plotting
+        :param node_kwargs: passed to :func:`nx.draw_networkx_nodes`
+        :param edge_kwargs: passed to :func:`nx.draw_networkx_edges`
+        :param label_kwargs: passed to :func:`nx.draw_networkx_labels`
+        :param tight_layout: if ``True``, will use matplotlib's tight_layout
+        :param legend: passed to :func:`.plot_nx_graph_reasonably`
         """
         import matplotlib.pyplot as plt
 
@@ -361,7 +397,17 @@ class PipelineState:
         fig.set_dpi(300)
         fig.set_size_inches(24, 24)
         ax = fig.add_subplot(111)
-        plot_nx_graph_reasonably(G, ax=ax, layout=layout)
+        plot_nx_graph_reasonably(
+            G,
+            ax=ax,
+            layout=layout,
+            node_kwargs=node_kwargs,
+            edge_kwargs=edge_kwargs,
+            label_kwargs=label_kwargs,
+            legend=legend,
+        )
+        if tight_layout:
+            fig.tight_layout()
         plt.savefig(path)
         plt.close()
 
@@ -375,6 +421,11 @@ class PipelineState:
         graph_start_idx: int = 1,
         stable_layout: bool = False,
         layout: Optional[CharactersGraphLayout] = None,
+        node_kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+        edge_kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+        label_kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+        tight_layout: bool = False,
+        legend: bool = False,
     ):
         """Plot ``self.character_network`` using reasonable default
         parameters
@@ -400,6 +451,11 @@ class PipelineState:
             same position in space at each timestep. Characters'
             positions are based on the final cumulative graph layout.
         :param layout: pre-computed graph layout
+        :param node_kwargs: passed to :func:`nx.draw_networkx_nodes`
+        :param edge_kwargs: passed to :func:`nx.draw_networkx_edges`
+        :param label_kwargs: passed to :func:`nx.draw_networkx_labels`
+        :param tight_layout: if ``True``, will use matplotlib's tight_layout
+        :param legend: passed to :func:`.plot_nx_graph_reasonably`
         """
         import matplotlib.pyplot as plt
         from matplotlib.widgets import Slider
@@ -418,13 +474,33 @@ class PipelineState:
             fig.set_dpi(300)
             fig.set_size_inches(24, 24)
             ax = fig.add_subplot(111)
-            plot_nx_graph_reasonably(G, ax=ax, layout=layout)
+            assert not isinstance(node_kwargs, list)
+            assert not isinstance(edge_kwargs, list)
+            assert not isinstance(label_kwargs, list)
+            if tight_layout:
+                fig.tight_layout()
+            plot_nx_graph_reasonably(
+                G,
+                ax=ax,
+                layout=layout,
+                node_kwargs=node_kwargs,
+                edge_kwargs=edge_kwargs,
+                label_kwargs=label_kwargs,
+                legend=legend,
+            )
             return
 
         if not isinstance(self.character_network, list):
             raise TypeError
         # self.character_network is a list: plot a dynamic graph
 
+        node_kwargs = node_kwargs or [{} for _ in range(len(self.character_network))]
+        assert isinstance(node_kwargs, list)
+        edge_kwargs = edge_kwargs or [{} for _ in range(len(self.character_network))]
+        assert isinstance(edge_kwargs, list)
+        label_kwargs = label_kwargs or [{} for _ in range(len(self.character_network))]
+        assert isinstance(label_kwargs, list)
+
         if fig is None:
             fig, ax = plt.subplots()
         assert not fig is None
@@ -440,12 +516,13 @@ class PipelineState:
 
         def update(slider_value):
             assert isinstance(self.character_network, list)
+            slider_i = int(slider_value) - 1
 
             character_networks = self.character_network
             if cumulative:
                 character_networks = cumulative_character_networks
 
-            G = character_networks[int(slider_value) - 1]
+            G = character_networks[slider_i]
 
             local_layout = layout
             if not local_layout is None:
@@ -453,11 +530,21 @@ class PipelineState:
             G = graph_with_names(G, name_style)
 
             ax.clear()
-            plot_nx_graph_reasonably(G, ax=ax, layout=local_layout)
+            plot_nx_graph_reasonably(
+                G,
+                ax=ax,
+                layout=local_layout,
+                node_kwargs=node_kwargs[slider_i],
+                edge_kwargs=edge_kwargs[slider_i],
+                label_kwargs=label_kwargs[slider_i],
+                legend=legend,
+            )
             ax.set_xlim(-1.2, 1.2)
             ax.set_ylim(-1.2, 1.2)
 
         slider_ax = fig.add_axes([0.1, 0.05, 0.8, 0.04])
+        if tight_layout:
+            fig.tight_layout()
         # HACK: we save the slider to the figure. This ensures the
         # slider is still alive at plotting time.
         fig.slider = Slider(  # type: ignore
@@ -474,6 +561,10 @@ class PipelineState:
 class Pipeline:
     """A flexible NLP pipeline"""
 
+    #: all the possible parameters of the whole pipeline, that are
+    #: shared between steps
+    PipelineParameter = Literal["lang", "progress_reporter", "character_ner_tag"]
+
     def __init__(
         self,
         steps: List[PipelineStep],
@@ -496,17 +587,27 @@ class Pipeline:
         self.progress_reporter = get_progress_reporter(progress_report)
 
         self.lang = lang
+        self.character_ner_tag = "PER"
         self.warn = warn
 
-    def _pipeline_init_steps(self, ignored_steps: Optional[List[str]] = None):
-        """
+    def _pipeline_init_steps_(self, ignored_steps: Optional[List[str]] = None):
+        """Initialise steps with global pipeline parameters.
+
         :param ignored_steps: a list of steps production. All steps
             with a production in ``ignored_steps`` will be ignored.
         """
-        steps_progress_reporter = get_progress_reporter(self.progress_report)
+        steps_progress_reporter = self.progress_reporter.get_subreporter()
         steps = self._non_ignored_steps(ignored_steps)
+        pipeline_params = {
+            "progress_reporter": steps_progress_reporter,
+            "character_ner_tag": self.character_ner_tag,
+        }
         for step in steps:
-            step._pipeline_init_(self.lang, steps_progress_reporter)
+            step_additional_params = step._pipeline_init_(self.lang, **pipeline_params)
+            if not step_additional_params is None:
+                for key, value in step_additional_params.items():
+                    setattr(self, key, value)
+                    pipeline_params[key] = value
 
     def _non_ignored_steps(
         self, ignored_steps: Optional[List[str]]
@@ -549,13 +650,27 @@ class Pipeline:
                 return (
                     False,
                     [
-                        f"step {i + 1} ({step.__class__.__name__}) has unsatisfied needs (needs : {step.needs()}, available : {pipeline_state})"
+                        "".join(
+                            [
+                                f"step {i + 1} ({step.__class__.__name__}) has unsatisfied needs. "
+                                + f"needs: {step.needs()}. "
+                                + f"available: {pipeline_state}. "
+                                + f"missing: {step.needs() - pipeline_state}."
+                            ]
+                        ),
                     ],
                 )
 
            if not step.optional_needs().issubset(pipeline_state):
                warnings.append(
-                    f"step {i + 1} ({step.__class__.__name__}) has unsatisfied optional needs : (optional needs : {step.optional_needs()}, available : {pipeline_state})"
+                    "".join(
+                        [
+                            f"step {i + 1} ({step.__class__.__name__}) has unsatisfied optional needs. "
+                            + f"needs: {step.optional_needs()}. "
+                            + f"available: {pipeline_state}. "
+                            + f"missing: {step.optional_needs() - pipeline_state}."
+                        ]
+                    )
                )
 
            pipeline_state = pipeline_state.union(step.production())
@@ -582,9 +697,9 @@ class Pipeline:
             raise ValueError(warnings_or_errors)
         if self.warn:
             for warning in warnings_or_errors:
-                print(f"[warning] : {warning}")
+                print(f"[warning] : {warning}", file=sys.stderr)
 
-        self._pipeline_init_steps(ignored_steps)
+        self._pipeline_init_steps_(ignored_steps)
 
         state = PipelineState(text)
         # sets attributes to PipelineState dynamically. This ensures