renard-pipeline 0.4.2__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release of renard-pipeline has been flagged as potentially problematic.

@@ -20,11 +20,12 @@ class BertCoreferenceResolver(PipelineStep):
     def __init__(
         self,
         model: Optional[Union[BertForCoreferenceResolution]] = None,
-        hugginface_model_id: Optional[str] = None,
+        huggingface_model_id: Optional[str] = None,
         batch_size: int = 1,
         device: Literal["auto", "cuda", "cpu"] = "auto",
         tokenizer: Optional[PreTrainedTokenizerFast] = None,
         block_size: int = 512,
+        hierarchical_merging: bool = False,
     ) -> None:
         """
         .. note::
@@ -40,9 +41,13 @@ class BertCoreferenceResolver(PipelineStep):
         :param device: computation device
         :param block_size: size of blocks to pass to the coreference
             model
+        :param hierarchical_merging: if ``True``, attempts to use
+            tibert's hierarchical merging feature. In that case,
+            blocks of size ``block_size`` are merged to perform
+            inference on the whole document.
         """
         if isinstance(model, str):
-            self.hugginface_model_id = hugginface_model_id
+            self.hugginface_model_id = huggingface_model_id
             self.model = None  # model will be init by _pipeline_init_
         else:
             self.hugginface_model_id = None
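
For orientation, the new constructor flag would be used along these lines. This is a minimal sketch: the module path and the model ID are assumptions on my part, and only the keyword arguments themselves come from this diff.

    # sketch: the module path and model ID below are guesses, not taken
    # from this diff; the keyword arguments mirror the new 0.6.0 signature
    from renard.pipeline.corefs import BertCoreferenceResolver

    resolver = BertCoreferenceResolver(
        huggingface_model_id="some-org/literary-coref-model",  # hypothetical ID
        batch_size=4,
        device="auto",
        block_size=512,
        hierarchical_merging=True,  # new in 0.6.0: whole-document inference
    )
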
@@ -58,15 +63,15 @@ class BertCoreferenceResolver(PipelineStep):
             self.device = torch.device(device)
 
         self.block_size = block_size
+        self.hierarchical_merging = hierarchical_merging
 
         super().__init__()
 
-    def _pipeline_init_(self, lang: str, progress_reporter: ProgressReporter):
+    def _pipeline_init_(self, lang: str, **kwargs):
         from tibert import BertForCoreferenceResolution
         from transformers import BertTokenizerFast, AutoTokenizer
 
         if self.model is None:
-
             # the user supplied a huggingface ID: load model from the HUB
             if not self.hugginface_model_id is None:
                 self.model = BertForCoreferenceResolution.from_pretrained(
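
The `_pipeline_init_` change replaces the explicit `progress_reporter` parameter with `**kwargs`, so subclasses forward whatever the pipeline passes instead of naming it positionally. A sketch of an overriding step under that contract; `PipelineStep` appears in this diff, everything else is illustrative:

    class MyCustomStep(PipelineStep):
        def _pipeline_init_(self, lang: str, **kwargs):
            # accept extra init arguments (e.g. a progress reporter, or
            # whatever later versions add) via **kwargs and forward them,
            # mirroring BertCoreferenceResolver above
            super()._pipeline_init_(lang, **kwargs)
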
@@ -87,16 +92,29 @@ class BertCoreferenceResolver(PipelineStep):
 
         assert not self.tokenizer is None
 
-        super()._pipeline_init_(lang, progress_reporter)
+        super()._pipeline_init_(lang, **kwargs)
 
     def __call__(self, tokens: List[str], **kwargs) -> Dict[str, Any]:
-        from tibert import stream_predict_coref
+        from tibert import stream_predict_coref, predict_coref
+        from tibert.bertcoref import CoreferenceDocument
 
         blocks = [
             tokens[block_start : block_start + self.block_size]
             for block_start in range(0, len(tokens), self.block_size)
         ]
 
+        if self.hierarchical_merging:
+            doc = predict_coref(
+                blocks,
+                self.model,
+                self.tokenizer,
+                batch_size=self.batch_size,
+                quiet=True,
+                device_str=self.device,
+                hierarchical_merging=True,
+            )
+            return {"corefs": doc.coref_chains}
+
         coref_docs = []
         for doc in self._progress_(
             stream_predict_coref(
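
The block splitting in `__call__` is plain list slicing and can be checked in isolation:

    def split_into_blocks(tokens: list, block_size: int) -> list:
        # same slicing as in __call__ above: fixed-size blocks,
        # with the last block allowed to be shorter
        return [
            tokens[start : start + block_size]
            for start in range(0, len(tokens), block_size)
        ]

    assert split_into_blocks(list("abcde"), 2) == [["a", "b"], ["c", "d"], ["e"]]
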
@@ -111,26 +129,7 @@ class BertCoreferenceResolver(PipelineStep):
         ):
             coref_docs.append(doc)
 
-        # chains found in coref_docs are each local to their
-        # blocks. The following code adjusts their start and end index
-        # to match their global coordinate in the text.
-        coref_chains = []
-        cur_doc_start = 0
-        for doc in coref_docs:
-            for chain in doc.coref_chains:
-                adjusted_chain = []
-                for mention in chain:
-                    # FIXME: It seems that a rare bug in Tibert can
-                    # ----- sometimes produce this unwanted state.
-                    if mention.start_idx is None:
-                        mention.start_idx = 0
-                    start_idx = mention.start_idx + cur_doc_start
-                    end_idx = mention.end_idx + cur_doc_start
-                    adjusted_chain.append(Mention(mention.tokens, start_idx, end_idx))
-                coref_chains.append(adjusted_chain)
-            cur_doc_start += len(doc)
-
-        return {"corefs": coref_chains}
+        return {"corefs": CoreferenceDocument.concatenated(coref_docs).coref_chains}
 
     def needs(self) -> Set[str]:
         return {"tokens"}
@@ -239,19 +238,19 @@ class SpacyCorefereeCoreferenceResolver(PipelineStep):
         self,
         text: str,
         tokens: List[str],
-        chapter_tokens: Optional[List[List[str]]] = None,
+        dynamic_blocks_tokens: Optional[List[List[str]]] = None,
         **kwargs,
     ) -> Dict[str, Any]:
         from spacy.tokens import Doc
         from coreferee.manager import CorefereeBroker
 
-        if chapter_tokens is None:
-            chapter_tokens = [tokens]
+        if dynamic_blocks_tokens is None:
+            dynamic_blocks_tokens = [tokens]
 
-        if len(chapter_tokens) > 1:
+        if len(dynamic_blocks_tokens) > 1:
             chunks = []
-            for chapter in chapter_tokens:
-                chunks += self._cut_into_chunks(chapter)
+            for block in dynamic_blocks_tokens:
+                chunks += self._cut_into_chunks(block)
         else:
             chunks = self._cut_into_chunks(tokens)
 
@@ -317,7 +316,7 @@ class SpacyCorefereeCoreferenceResolver(PipelineStep):
         return {"tokens"}
 
     def optional_needs(self) -> Set[str]:
-        return {"chapter_tokens"}
+        return {"dynamic_blocks_tokens"}
 
     def production(self) -> Set[str]:
         return {"corefs"}