renard-pipeline 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of renard-pipeline might be problematic. Click here for more details.

renard/graph_utils.py CHANGED
@@ -70,10 +70,17 @@ def graph_with_names(
70
70
  else:
71
71
  name_style_fn = name_style
72
72
 
73
- return nx.relabel_nodes(
74
- G,
75
- {character: name_style_fn(character) for character in G.nodes()}, # type: ignore
76
- )
73
+ mapping = {}
74
+ for character in G.nodes():
75
+ # NOTE: it is *possible* to have a graph where nodes are not
76
+ # characters (for example, simple strings). Therefore, we are
77
+ # lenient here
78
+ try:
79
+ mapping[character] = name_style_fn(character)
80
+ except AttributeError:
81
+ mapping[character] = character
82
+
83
+ return nx.relabel_nodes(G, mapping)
77
84
 
78
85
 
79
86
  def layout_with_names(
renard/ner_utils.py CHANGED
@@ -110,6 +110,10 @@ class NERDataset(Dataset):
110
110
  elt_context_mask = self._context_mask[index]
111
111
  for i in range(len(element)):
112
112
  w2t = batch.word_to_tokens(0, i)
113
+ # w2t can be None in case of truncation, which can happen
114
+ # if `element' is too long
115
+ if w2t is None:
116
+ continue
113
117
  mask_value = elt_context_mask[i]
114
118
  tokens_mask = [mask_value] * (w2t.end - w2t.start)
115
119
  batch["context_mask"][w2t.start : w2t.end] = tokens_mask
@@ -61,6 +61,8 @@ def _assign_coreference_mentions(
61
61
  # we assign each chain to the character with highest name
62
62
  # occurence in it
63
63
  for chain in corefs:
64
+ if len(char_mentions) == 0:
65
+ break
64
66
  # determine the characters with the highest number of
65
67
  # occurences
66
68
  occ_counter = {}
@@ -98,8 +100,13 @@ class NaiveCharacterUnifier(PipelineStep):
98
100
  character for it to be valid
99
101
  """
100
102
  self.min_appearances = min_appearances
103
+ # a default value, will be est by _pipeline_init_
104
+ self.character_ner_tag = "PER"
101
105
  super().__init__()
102
106
 
107
+ def _pipeline_init_(self, lang: str, character_ner_tag: str, **kwargs):
108
+ self.character_ner_tag = character_ner_tag
109
+
103
110
  def __call__(
104
111
  self,
105
112
  text: str,
@@ -112,7 +119,7 @@ class NaiveCharacterUnifier(PipelineStep):
112
119
  :param tokens:
113
120
  :param entities:
114
121
  """
115
- persons = [e for e in entities if e.tag == "PER"]
122
+ persons = [e for e in entities if e.tag == self.character_ner_tag]
116
123
 
117
124
  characters = defaultdict(list)
118
125
  for entity in persons:
@@ -159,6 +166,7 @@ class GraphRulesCharacterUnifier(PipelineStep):
159
166
  min_appearances: int = 0,
160
167
  additional_hypocorisms: Optional[List[Tuple[str, List[str]]]] = None,
161
168
  link_corefs_mentions: bool = False,
169
+ ignore_lone_titles: Optional[Set[str]] = None,
162
170
  ) -> None:
163
171
  """
164
172
  :param min_appearances: minimum number of appearances of a
@@ -173,20 +181,27 @@ class GraphRulesCharacterUnifier(PipelineStep):
173
181
  extract a lot of spurious links. However, linking by
174
182
  coref is sometimes the only way to resolve a character
175
183
  alias.
184
+ :param ignore_lone_titles: a set of titles to ignore when
185
+ they stand on their own. This avoids extracting false
186
+ positives characters such as 'Mr.' or 'Miss'.
176
187
  """
177
188
  self.min_appearances = min_appearances
178
189
  self.additional_hypocorisms = additional_hypocorisms
179
190
  self.link_corefs_mentions = link_corefs_mentions
191
+ self.ignore_lone_titles = ignore_lone_titles or set()
192
+ self.character_ner_tag = "PER" # a default value, will be set by _pipeline_init
180
193
 
181
194
  super().__init__()
182
195
 
183
- def _pipeline_init_(self, lang: str, progress_reporter: ProgressReporter):
196
+ def _pipeline_init_(self, lang: str, character_ner_tag: str, **kwargs):
184
197
  self.hypocorism_gazetteer = HypocorismGazetteer(lang=lang)
185
198
  if not self.additional_hypocorisms is None:
186
199
  for name, nicknames in self.additional_hypocorisms:
187
200
  self.hypocorism_gazetteer._add_hypocorism_(name, nicknames)
188
201
 
189
- return super()._pipeline_init_(lang, progress_reporter)
202
+ self.character_ner_tag = character_ner_tag
203
+
204
+ return super()._pipeline_init_(lang, **kwargs)
190
205
 
191
206
  def __call__(
192
207
  self,
@@ -196,12 +211,17 @@ class GraphRulesCharacterUnifier(PipelineStep):
196
211
  ) -> Dict[str, Any]:
197
212
  import networkx as nx
198
213
 
199
- mentions = [m for m in entities if m.tag == "PER"]
200
- mentions_str = [" ".join(m.tokens) for m in mentions]
214
+ mentions = [m for m in entities if m.tag == self.character_ner_tag]
215
+ mentions_str = set(
216
+ filter(
217
+ lambda m: not m in self.ignore_lone_titles,
218
+ map(lambda m: " ".join(m.tokens), mentions),
219
+ )
220
+ )
201
221
 
202
222
  # * create a graph where each node is a mention detected by NER
203
223
  G = nx.Graph()
204
- for mention_str in set(mentions_str):
224
+ for mention_str in mentions_str:
205
225
  G.add_node(mention_str)
206
226
 
207
227
  # * HumanName local configuration - dependant on language
@@ -1,7 +1,9 @@
1
+ import sys
1
2
  import renard.pipeline.character_unification as cu
2
3
 
3
4
  print(
4
- "[warning] the characters_extraction module is deprecated. Use character_unification instead."
5
+ "[warning] the characters_extraction module is deprecated. Use character_unification instead.",
6
+ file=sys.stderr,
5
7
  )
6
8
 
7
9
  Character = cu.Character
renard/pipeline/core.py CHANGED
@@ -79,11 +79,18 @@ class PipelineStep:
79
79
  """Initialize the :class:`PipelineStep` with a given configuration."""
80
80
  pass
81
81
 
82
- def _pipeline_init_(self, lang: str, progress_reporter: ProgressReporter):
83
- """Set the step configuration that is common to the whole pipeline.
84
-
85
- :param lang: ISO 639-3 language string
86
- :param progress_report:
82
+ def _pipeline_init_(
83
+ self, lang: str, progress_reporter: ProgressReporter, **kwargs
84
+ ) -> Optional[Dict[Pipeline.PipelineParameter, Any]]:
85
+ """Set the step configuration that is common to the whole
86
+ pipeline.
87
+
88
+ :param lang: the lang of the whole pipeline
89
+ :param progress_reporter:
90
+ :param kwargs: additional pipeline parameters.
91
+
92
+ :return: a step can return a dictionary of pipeline params if
93
+ it wish to modify some of these.
87
94
  """
88
95
  supported_langs = self.supported_langs()
89
96
  if not supported_langs == "any" and not lang in supported_langs:
@@ -150,13 +157,14 @@ class PipelineState:
150
157
  #: input text
151
158
  text: Optional[str]
152
159
 
153
- #: text split into chapters
154
- chapters: Optional[List[str]] = None
160
+ #: text split into blocks of texts. When dynamic blocks are given,
161
+ #: the final network is dynamic, and split according to blocks.
162
+ dynamic_blocks: Optional[List[Tuple[int, int]]] = None
155
163
 
156
164
  #: text splitted in tokens
157
165
  tokens: Optional[List[str]] = None
158
- #: text splitted in tokens, by chapter
159
- chapter_tokens: Optional[List[List[str]]] = None
166
+ #: mapping from a character to its corresponding token
167
+ char2token: Optional[List[int]] = None
160
168
  #: text splitted into sentences, each sentence being a list of
161
169
  #: tokens
162
170
  sentences: Optional[List[List[str]]] = None
@@ -182,14 +190,12 @@ class PipelineState:
182
190
  #: network)
183
191
  character_network: Optional[Union[List[nx.Graph], nx.Graph]] = None
184
192
 
193
+ # aliases of self.character_network
185
194
  def get_characters_graph(self) -> Optional[Union[List[nx.Graph], nx.Graph]]:
186
- print(
187
- "[warning] the characters_graph attribute is deprecated, use character_network instead",
188
- file=sys.stderr,
189
- )
190
195
  return self.character_network
191
196
 
192
197
  characters_graph = property(get_characters_graph)
198
+ character_graph = property(get_characters_graph)
193
199
 
194
200
  def get_character(
195
201
  self, name: str, partial_match: bool = True
@@ -280,6 +286,9 @@ class PipelineState:
280
286
  cumulative: bool = False,
281
287
  stable_layout: bool = False,
282
288
  layout: Optional[CharactersGraphLayout] = None,
289
+ node_kwargs: Optional[List[Dict[str, Any]]] = None,
290
+ edge_kwargs: Optional[List[Dict[str, Any]]] = None,
291
+ label_kwargs: Optional[List[Dict[str, Any]]] = None,
283
292
  ):
284
293
  """Plot ``self.character_graph`` using reasonable default
285
294
  parameters, and save the produced figures in the specified
@@ -294,6 +303,9 @@ class PipelineState:
294
303
  timestep. Characters' positions are based on the final
295
304
  cumulative graph layout.
296
305
  :param layout: pre-computed graph layout
306
+ :param node_kwargs: passed to :func:`nx.draw_networkx_nodes`
307
+ :param edge_kwargs: passed to :func:`nx.draw_networkx_nodes`
308
+ :param label_kwargs: passed to :func:`nx.draw_networkx_labels`
297
309
  """
298
310
  import matplotlib.pyplot as plt
299
311
 
@@ -317,13 +329,24 @@ class PipelineState:
317
329
  )
318
330
  layout = layout_nx_graph_reasonably(layout_graph)
319
331
 
332
+ node_kwargs = node_kwargs or [{} for _ in range(len(self.character_network))]
333
+ edge_kwargs = edge_kwargs or [{} for _ in range(len(self.character_network))]
334
+ label_kwargs = label_kwargs or [{} for _ in range(len(self.character_network))]
335
+
320
336
  for i, G in enumerate(graphs):
321
337
  _, ax = plt.subplots()
322
338
  local_layout = layout
323
339
  if not local_layout is None:
324
340
  local_layout = layout_with_names(G, local_layout, name_style)
325
341
  G = graph_with_names(G, name_style=name_style)
326
- plot_nx_graph_reasonably(G, ax=ax, layout=local_layout)
342
+ plot_nx_graph_reasonably(
343
+ G,
344
+ ax=ax,
345
+ layout=local_layout,
346
+ node_kwargs=node_kwargs[i],
347
+ edge_kwargs=edge_kwargs[i],
348
+ label_kwargs=label_kwargs[i],
349
+ )
327
350
  plt.savefig(f"{directory}/{i}.png")
328
351
  plt.close()
329
352
 
@@ -335,6 +358,9 @@ class PipelineState:
335
358
  ] = "most_frequent",
336
359
  layout: Optional[CharactersGraphLayout] = None,
337
360
  fig: Optional[plt.Figure] = None,
361
+ node_kwargs: Optional[Dict[str, Any]] = None,
362
+ edge_kwargs: Optional[Dict[str, Any]] = None,
363
+ label_kwargs: Optional[Dict[str, Any]] = None,
338
364
  ):
339
365
  """Plot ``self.character_graph`` using reasonable parameters,
340
366
  and save the produced figure to a file
@@ -344,6 +370,9 @@ class PipelineState:
344
370
  :param layout: pre-computed graph layout
345
371
  :param fig: if specified, this matplotlib figure will be used
346
372
  for plotting
373
+ :param node_kwargs: passed to :func:`nx.draw_networkx_nodes`
374
+ :param edge_kwargs: passed to :func:`nx.draw_networkx_nodes`
375
+ :param label_kwargs: passed to :func:`nx.draw_networkx_labels`
347
376
  """
348
377
  import matplotlib.pyplot as plt
349
378
 
@@ -361,7 +390,14 @@ class PipelineState:
361
390
  fig.set_dpi(300)
362
391
  fig.set_size_inches(24, 24)
363
392
  ax = fig.add_subplot(111)
364
- plot_nx_graph_reasonably(G, ax=ax, layout=layout)
393
+ plot_nx_graph_reasonably(
394
+ G,
395
+ ax=ax,
396
+ layout=layout,
397
+ node_kwargs=node_kwargs,
398
+ edge_kwargs=edge_kwargs,
399
+ label_kwargs=label_kwargs,
400
+ )
365
401
  plt.savefig(path)
366
402
  plt.close()
367
403
 
@@ -375,6 +411,9 @@ class PipelineState:
375
411
  graph_start_idx: int = 1,
376
412
  stable_layout: bool = False,
377
413
  layout: Optional[CharactersGraphLayout] = None,
414
+ node_kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
415
+ edge_kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
416
+ label_kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
378
417
  ):
379
418
  """Plot ``self.character_network`` using reasonable default
380
419
  parameters
@@ -400,6 +439,9 @@ class PipelineState:
400
439
  same position in space at each timestep. Characters'
401
440
  positions are based on the final cumulative graph layout.
402
441
  :param layout: pre-computed graph layout
442
+ :param node_kwargs: passed to :func:`nx.draw_networkx_nodes`
443
+ :param edge_kwargs: passed to :func:`nx.draw_networkx_nodes`
444
+ :param label_kwargs: passed to :func:`nx.draw_networkx_labels`
403
445
  """
404
446
  import matplotlib.pyplot as plt
405
447
  from matplotlib.widgets import Slider
@@ -418,13 +460,30 @@ class PipelineState:
418
460
  fig.set_dpi(300)
419
461
  fig.set_size_inches(24, 24)
420
462
  ax = fig.add_subplot(111)
421
- plot_nx_graph_reasonably(G, ax=ax, layout=layout)
463
+ assert not isinstance(node_kwargs, list)
464
+ assert not isinstance(edge_kwargs, list)
465
+ assert not isinstance(label_kwargs, list)
466
+ plot_nx_graph_reasonably(
467
+ G,
468
+ ax=ax,
469
+ layout=layout,
470
+ node_kwargs=node_kwargs,
471
+ edge_kwargs=edge_kwargs,
472
+ label_kwargs=label_kwargs,
473
+ )
422
474
  return
423
475
 
424
476
  if not isinstance(self.character_network, list):
425
477
  raise TypeError
426
478
  # self.character_network is a list: plot a dynamic graph
427
479
 
480
+ node_kwargs = node_kwargs or [{} for _ in range(len(self.character_network))]
481
+ assert isinstance(node_kwargs, list)
482
+ edge_kwargs = edge_kwargs or [{} for _ in range(len(self.character_network))]
483
+ assert isinstance(edge_kwargs, list)
484
+ label_kwargs = label_kwargs or [{} for _ in range(len(self.character_network))]
485
+ assert isinstance(label_kwargs, list)
486
+
428
487
  if fig is None:
429
488
  fig, ax = plt.subplots()
430
489
  assert not fig is None
@@ -440,12 +499,13 @@ class PipelineState:
440
499
 
441
500
  def update(slider_value):
442
501
  assert isinstance(self.character_network, list)
502
+ slider_i = int(slider_value) - 1
443
503
 
444
504
  character_networks = self.character_network
445
505
  if cumulative:
446
506
  character_networks = cumulative_character_networks
447
507
 
448
- G = character_networks[int(slider_value) - 1]
508
+ G = character_networks[slider_i]
449
509
 
450
510
  local_layout = layout
451
511
  if not local_layout is None:
@@ -453,7 +513,14 @@ class PipelineState:
453
513
  G = graph_with_names(G, name_style)
454
514
 
455
515
  ax.clear()
456
- plot_nx_graph_reasonably(G, ax=ax, layout=local_layout)
516
+ plot_nx_graph_reasonably(
517
+ G,
518
+ ax=ax,
519
+ layout=local_layout,
520
+ node_kwargs=node_kwargs[slider_i],
521
+ edge_kwargs=edge_kwargs[slider_i],
522
+ label_kwargs=label_kwargs[slider_i],
523
+ )
457
524
  ax.set_xlim(-1.2, 1.2)
458
525
  ax.set_ylim(-1.2, 1.2)
459
526
 
@@ -474,6 +541,10 @@ class PipelineState:
474
541
  class Pipeline:
475
542
  """A flexible NLP pipeline"""
476
543
 
544
+ #: all the possible parameters of the whole pipeline, that are
545
+ #: shared between steps
546
+ PipelineParameter = Literal["lang", "progress_reporter", "character_ner_tag"]
547
+
477
548
  def __init__(
478
549
  self,
479
550
  steps: List[PipelineStep],
@@ -496,17 +567,27 @@ class Pipeline:
496
567
  self.progress_reporter = get_progress_reporter(progress_report)
497
568
 
498
569
  self.lang = lang
570
+ self.character_ner_tag = "PER"
499
571
  self.warn = warn
500
572
 
501
- def _pipeline_init_steps(self, ignored_steps: Optional[List[str]] = None):
502
- """
573
+ def _pipeline_init_steps_(self, ignored_steps: Optional[List[str]] = None):
574
+ """Initialise steps with global pipeline parameters.
575
+
503
576
  :param ignored_steps: a list of steps production. All steps
504
577
  with a production in ``ignored_steps`` will be ignored.
505
578
  """
506
- steps_progress_reporter = get_progress_reporter(self.progress_report)
579
+ steps_progress_reporter = self.progress_reporter.get_subreporter()
507
580
  steps = self._non_ignored_steps(ignored_steps)
581
+ pipeline_params = {
582
+ "progress_reporter": steps_progress_reporter,
583
+ "character_ner_tag": self.character_ner_tag,
584
+ }
508
585
  for step in steps:
509
- step._pipeline_init_(self.lang, steps_progress_reporter)
586
+ step_additional_params = step._pipeline_init_(self.lang, **pipeline_params)
587
+ if not step_additional_params is None:
588
+ for key, value in step_additional_params.items():
589
+ setattr(self, key, value)
590
+ pipeline_params[key] = value
510
591
 
511
592
  def _non_ignored_steps(
512
593
  self, ignored_steps: Optional[List[str]]
@@ -549,13 +630,27 @@ class Pipeline:
549
630
  return (
550
631
  False,
551
632
  [
552
- f"step {i + 1} ({step.__class__.__name__}) has unsatisfied needs (needs : {step.needs()}, available : {pipeline_state})"
633
+ "".join(
634
+ [
635
+ f"step {i + 1} ({step.__class__.__name__}) has unsatisfied needs. "
636
+ + f"needs: {step.needs()}. "
637
+ + f"available: {pipeline_state}). "
638
+ + f"missing: {step.needs() - pipeline_state}."
639
+ ]
640
+ ),
553
641
  ],
554
642
  )
555
643
 
556
644
  if not step.optional_needs().issubset(pipeline_state):
557
645
  warnings.append(
558
- f"step {i + 1} ({step.__class__.__name__}) has unsatisfied optional needs : (optional needs : {step.optional_needs()}, available : {pipeline_state})"
646
+ "".join(
647
+ [
648
+ f"step {i + 1} ({step.__class__.__name__}) has unsatisfied optional needs. "
649
+ + f"needs: {step.optional_needs()}. "
650
+ + f"available: {pipeline_state}). "
651
+ + f"missing: {step.optional_needs() - pipeline_state}."
652
+ ]
653
+ )
559
654
  )
560
655
 
561
656
  pipeline_state = pipeline_state.union(step.production())
@@ -582,9 +677,9 @@ class Pipeline:
582
677
  raise ValueError(warnings_or_errors)
583
678
  if self.warn:
584
679
  for warning in warnings_or_errors:
585
- print(f"[warning] : {warning}")
680
+ print(f"[warning] : {warning}", file=sys.stderr)
586
681
 
587
- self._pipeline_init_steps(ignored_steps)
682
+ self._pipeline_init_steps_(ignored_steps)
588
683
 
589
684
  state = PipelineState(text)
590
685
  # sets attributes to PipelineState dynamically. This ensures
@@ -25,6 +25,7 @@ class BertCoreferenceResolver(PipelineStep):
25
25
  device: Literal["auto", "cuda", "cpu"] = "auto",
26
26
  tokenizer: Optional[PreTrainedTokenizerFast] = None,
27
27
  block_size: int = 512,
28
+ hierarchical_merging: bool = False,
28
29
  ) -> None:
29
30
  """
30
31
  .. note::
@@ -40,6 +41,10 @@ class BertCoreferenceResolver(PipelineStep):
40
41
  :param device: computation device
41
42
  :param block_size: size of blocks to pass to the coreference
42
43
  model
44
+ :param hierarchical_merging: if ``True``, attempts to use
45
+ tibert's hierarchical merging feature. In that case,
46
+ blocks of size ``block_size`` are merged to perform
47
+ inference on the whole document.
43
48
  """
44
49
  if isinstance(model, str):
45
50
  self.hugginface_model_id = hugginface_model_id
@@ -58,15 +63,15 @@ class BertCoreferenceResolver(PipelineStep):
58
63
  self.device = torch.device(device)
59
64
 
60
65
  self.block_size = block_size
66
+ self.hierarchical_merging = hierarchical_merging
61
67
 
62
68
  super().__init__()
63
69
 
64
- def _pipeline_init_(self, lang: str, progress_reporter: ProgressReporter):
70
+ def _pipeline_init_(self, lang: str, **kwargs):
65
71
  from tibert import BertForCoreferenceResolution
66
72
  from transformers import BertTokenizerFast, AutoTokenizer
67
73
 
68
74
  if self.model is None:
69
-
70
75
  # the user supplied a huggingface ID: load model from the HUB
71
76
  if not self.hugginface_model_id is None:
72
77
  self.model = BertForCoreferenceResolution.from_pretrained(
@@ -87,16 +92,29 @@ class BertCoreferenceResolver(PipelineStep):
87
92
 
88
93
  assert not self.tokenizer is None
89
94
 
90
- super()._pipeline_init_(lang, progress_reporter)
95
+ super()._pipeline_init_(lang, **kwargs)
91
96
 
92
97
  def __call__(self, tokens: List[str], **kwargs) -> Dict[str, Any]:
93
- from tibert import stream_predict_coref
98
+ from tibert import stream_predict_coref, predict_coref
99
+ from tibert.bertcoref import CoreferenceDocument
94
100
 
95
101
  blocks = [
96
102
  tokens[block_start : block_start + self.block_size]
97
103
  for block_start in range(0, len(tokens), self.block_size)
98
104
  ]
99
105
 
106
+ if self.hierarchical_merging:
107
+ doc = predict_coref(
108
+ blocks,
109
+ self.model,
110
+ self.tokenizer,
111
+ batch_size=self.batch_size,
112
+ quiet=True,
113
+ device_str=self.device,
114
+ hierarchical_merging=True,
115
+ )
116
+ return {"corefs": doc.coref_chains}
117
+
100
118
  coref_docs = []
101
119
  for doc in self._progress_(
102
120
  stream_predict_coref(
@@ -111,26 +129,7 @@ class BertCoreferenceResolver(PipelineStep):
111
129
  ):
112
130
  coref_docs.append(doc)
113
131
 
114
- # chains found in coref_docs are each local to their
115
- # blocks. The following code adjusts their start and end index
116
- # to match their global coordinate in the text.
117
- coref_chains = []
118
- cur_doc_start = 0
119
- for doc in coref_docs:
120
- for chain in doc.coref_chains:
121
- adjusted_chain = []
122
- for mention in chain:
123
- # FIXME: It seems that a rare bug in Tibert can
124
- # ----- sometimes produce this unwanted state.
125
- if mention.start_idx is None:
126
- mention.start_idx = 0
127
- start_idx = mention.start_idx + cur_doc_start
128
- end_idx = mention.end_idx + cur_doc_start
129
- adjusted_chain.append(Mention(mention.tokens, start_idx, end_idx))
130
- coref_chains.append(adjusted_chain)
131
- cur_doc_start += len(doc)
132
-
133
- return {"corefs": coref_chains}
132
+ return {"corefs": CoreferenceDocument.concatenated(coref_docs).coref_chains}
134
133
 
135
134
  def needs(self) -> Set[str]:
136
135
  return {"tokens"}
@@ -239,19 +238,19 @@ class SpacyCorefereeCoreferenceResolver(PipelineStep):
239
238
  self,
240
239
  text: str,
241
240
  tokens: List[str],
242
- chapter_tokens: Optional[List[List[str]]] = None,
241
+ dynamic_blocks_tokens: Optional[List[List[str]]] = None,
243
242
  **kwargs,
244
243
  ) -> Dict[str, Any]:
245
244
  from spacy.tokens import Doc
246
245
  from coreferee.manager import CorefereeBroker
247
246
 
248
- if chapter_tokens is None:
249
- chapter_tokens = [tokens]
247
+ if dynamic_blocks_tokens is None:
248
+ dynamic_blocks_tokens = [tokens]
250
249
 
251
- if len(chapter_tokens) > 1:
250
+ if len(dynamic_blocks_tokens) > 1:
252
251
  chunks = []
253
- for chapter in chapter_tokens:
254
- chunks += self._cut_into_chunks(chapter)
252
+ for block in dynamic_blocks_tokens:
253
+ chunks += self._cut_into_chunks(block)
255
254
  else:
256
255
  chunks = self._cut_into_chunks(tokens)
257
256
 
@@ -317,7 +316,7 @@ class SpacyCorefereeCoreferenceResolver(PipelineStep):
317
316
  return {"tokens"}
318
317
 
319
318
  def optional_needs(self) -> Set[str]:
320
- return {"chapter_tokens"}
319
+ return {"dynamic_blocks_tokens"}
321
320
 
322
321
  def production(self) -> Set[str]:
323
322
  return {"corefs"}