renard-pipeline 0.4.2__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of renard-pipeline might be problematic. Click here for more details.

@@ -0,0 +1 @@
1
+ from renard.pipeline.ner.ner import *
@@ -1,21 +1,32 @@
1
1
  from __future__ import annotations
2
- import random, itertools
3
- from typing import TYPE_CHECKING, List, Dict, Any, Set, Tuple, Optional, Union, Literal
2
+ from typing import (
3
+ TYPE_CHECKING,
4
+ List,
5
+ Dict,
6
+ Any,
7
+ Set,
8
+ Tuple,
9
+ Optional,
10
+ Union,
11
+ Literal,
12
+ )
4
13
  from dataclasses import dataclass
5
14
  import torch
6
- from seqeval.metrics import precision_score, recall_score, f1_score
7
15
  from renard.nltk_utils import nltk_fix_bio_tags
8
16
  from renard.ner_utils import (
9
17
  DataCollatorForTokenClassificationWithBatchEncoding,
10
18
  NERDataset,
11
19
  )
12
20
  from renard.pipeline.core import PipelineStep, Mention
13
- from renard.pipeline.progress import ProgressReporter
14
21
  from renard.ner_utils import ner_entities
15
22
 
16
23
  if TYPE_CHECKING:
17
24
  from transformers.tokenization_utils_base import BatchEncoding
18
- from transformers import PreTrainedModel, PreTrainedTokenizerFast
25
+ from transformers import (
26
+ PreTrainedModel,
27
+ PreTrainedTokenizerFast,
28
+ )
29
+ from renard.pipeline.ner.retrieval import NERContextRetriever
19
30
 
20
31
 
21
32
  @dataclass
@@ -27,7 +38,7 @@ class NEREntity(Mention):
27
38
  """
28
39
  .. note::
29
40
 
30
- This method is implemtented here to avoid type issues. Since
41
+ This method is implemented here to avoid type issues. Since
31
42
  :meth:`.Mention.shifted` cannot be annotated as returning
32
43
  ``Self``, this method annotate the correct return type when
33
44
  using :meth:`.NEREntity.shifted`.
@@ -41,18 +52,21 @@ class NEREntity(Mention):
41
52
  def score_ner(
42
53
  pred_bio_tags: List[str], ref_bio_tags: List[str]
43
54
  ) -> Tuple[float, float, float]:
44
- """Score NER as in CoNLL-2003 shared task using ``seqeval``
55
+ """Score NER as in CoNLL-2003 shared task using the ``seqeval``
56
+ library, if installed.
45
57
 
46
58
  Precision is the percentage of named entities in ``ref_bio_tags``
47
- that are correct. Recall is the percentage of named entities in
48
- pred_bio_tags that are in ref_bio_tags. F1 is the harmonic mean of
49
- both.
59
+ that are correct. Recall is the percentage of named entities in
60
+ pred_bio_tags that are in ref_bio_tags. F1 is the harmonic mean
61
+ of both.
50
62
 
51
63
  :param pred_bio_tags:
52
64
  :param ref_bio_tags:
53
- :return: ``(precision, recall, F1 score)``
54
65
 
66
+ :return: ``(precision, recall, F1 score)``
55
67
  """
68
+ from seqeval.metrics import precision_score, recall_score, f1_score
69
+
56
70
  assert len(pred_bio_tags) == len(ref_bio_tags)
57
71
  return (
58
72
  precision_score([ref_bio_tags], [pred_bio_tags]),
@@ -70,12 +84,19 @@ class NLTKNamedEntityRecognizer(PipelineStep):
70
84
  """
71
85
  import nltk
72
86
 
73
- nltk.download("averaged_perceptron_tagger", quiet=True)
87
+ nltk.download(f"averaged_perceptron_tagger", quiet=True)
74
88
  nltk.download("maxent_ne_chunker", quiet=True)
89
+ nltk.download("maxent_ne_chunker_tab", quiet=True)
75
90
  nltk.download("words", quiet=True)
76
91
 
77
92
  super().__init__()
78
93
 
94
+ def _pipeline_init_(self, lang: str, **kwargs):
95
+ import nltk
96
+
97
+ nltk.download(f"averaged_perceptron_tagger_{lang}", quiet=True)
98
+ super()._pipeline_init_(lang, **kwargs)
99
+
79
100
  def __call__(self, tokens: List[str], **kwargs) -> Dict[str, Any]:
80
101
  """
81
102
  :param text:
@@ -101,64 +122,6 @@ class NLTKNamedEntityRecognizer(PipelineStep):
101
122
  return {"entities"}
102
123
 
103
124
 
104
- class NERContextRetriever:
105
- def __call__(self, dataset: NERDataset) -> NERDataset:
106
- raise NotImplementedError
107
-
108
-
109
- class NERSamenounContextRetriever(NERContextRetriever):
110
- """
111
- Retrieve relevant context using the samenoun strategy as in
112
- Amalvy et al. 2023.
113
- """
114
-
115
- def __init__(self, k: int) -> None:
116
- """
117
- :param k: the number of sentences to retrieve
118
- """
119
- self.k = k
120
-
121
- def __call__(self, dataset: NERDataset) -> NERDataset:
122
- import nltk
123
-
124
- # NOTE: POS tagging is not incorporated in the pipeline yet,
125
- # so we manually compute it here.
126
- elements_names = [
127
- {t[0] for t in nltk.pos_tag(element) if t[1].startswith("NN")}
128
- for element in dataset.elements
129
- ]
130
-
131
- elements_with_context = []
132
-
133
- for elt_i, elt in enumerate(dataset.elements):
134
- retrieved_elts = [
135
- other_elt
136
- for other_elt_i, other_elt in enumerate(dataset.elements)
137
- if not other_elt_i == elt_i
138
- and len(elements_names[elt_i].intersection(elements_names[other_elt_i]))
139
- > 0
140
- ]
141
- retrieved_elts = random.sample(
142
- retrieved_elts, k=min(self.k, len(retrieved_elts))
143
- )
144
- elements_with_context.append(
145
- (
146
- elt,
147
- [dataset.tokenizer.sep_token]
148
- + list(itertools.chain.from_iterable(retrieved_elts)),
149
- )
150
- )
151
-
152
- return NERDataset(
153
- [element + context for element, context in elements_with_context],
154
- dataset.tokenizer,
155
- [
156
- [0] * len(element) + [1] * len(context)
157
- for element, context in elements_with_context
158
- ],
159
- )
160
-
161
-
162
125
  class BertNamedEntityRecognizer(PipelineStep):
163
126
  """An entity recognizer based on BERT"""
164
127
 
@@ -214,10 +177,10 @@ class BertNamedEntityRecognizer(PipelineStep):
214
177
 
215
178
  super().__init__()
216
179
 
217
- def _pipeline_init_(self, lang: str, progress_reporter: ProgressReporter):
180
+ def _pipeline_init_(self, lang: str, **kwargs):
218
181
  from transformers import AutoModelForTokenClassification, AutoTokenizer # type: ignore
219
182
 
220
- super()._pipeline_init_(lang, progress_reporter)
183
+ super()._pipeline_init_(lang, **kwargs)
221
184
 
222
185
  # init model if needed (this happens if the user did not pass
223
186
  # the instance of a model)
@@ -306,7 +269,7 @@ class BertNamedEntityRecognizer(PipelineStep):
306
269
  batch_i: int,
307
270
  wp_labels: List[str],
308
271
  tokens: List[str],
309
- context_mask: torch.Tensor,
272
+ ctxmask: torch.Tensor,
310
273
  ) -> List[str]:
311
274
  """Align labels to tokens rather than wordpiece tokens.
312
275
 
@@ -317,13 +280,21 @@ class BertNamedEntityRecognizer(PipelineStep):
317
280
  """
318
281
  batch_labels = ["O"] * len(tokens)
319
282
 
283
+ try:
284
+ inference_start = ctxmask[batch_i].tolist().index(1)
285
+ except ValueError:
286
+ inference_start = 0
287
+
320
288
  for wplabel_j, wp_label in enumerate(wp_labels):
321
- if context_mask[batch_i][wplabel_j] == 1:
322
- continue
289
+
323
290
  token_i = batchs.token_to_word(batch_i, wplabel_j)
324
291
  if token_i is None:
325
292
  continue
326
- batch_labels[token_i] = wp_label
293
+
294
+ if ctxmask[batch_i][token_i] == 0:
295
+ continue
296
+
297
+ batch_labels[token_i - inference_start] = wp_label
327
298
 
328
299
  return batch_labels
329
300
 
@@ -0,0 +1,375 @@
1
+ from collections.abc import Set
2
+ import sys
3
+ from typing import Union, List, cast, Literal, Optional
4
+ import random
5
+ from dataclasses import dataclass
6
+ from more_itertools import flatten
7
+ from renard.ner_utils import NERDataset
8
+ import nltk
9
+ from rank_bm25 import BM25Okapi
10
+ from transformers import (
11
+ BertForSequenceClassification,
12
+ BertTokenizerFast,
13
+ DataCollatorWithPadding,
14
+ )
15
+ from transformers.tokenization_utils_base import BatchEncoding
16
+ import torch
17
+ from torch.utils.data import Dataset, DataLoader
18
+
19
+
20
@dataclass
class NERContextRetrievalMatch:
    """An element selected as context for another element."""

    #: tokens of the retrieved element
    element: List[str]
    #: index of the retrieved element in the list of all elements
    element_i: int
    #: whether the retrieved element is placed before ("left") or
    #: after ("right") the element it is context for
    side: Literal["left", "right"]
    #: retrieval score, or ``None`` when the retriever does not score
    score: Optional[float]

    def __hash__(self) -> int:
        # ``element`` is a list and therefore unhashable: its tokens
        # are unpacked into the tuple that gets hashed.
        key = (*self.element, self.element_i, self.side, self.score)
        return hash(key)
29
+
30
+
31
class NERContextRetriever:
    """Base class of context retrievers.

    Calling a retriever on a :class:`NERDataset` returns a new dataset
    where each element is surrounded by the elements selected by
    :meth:`retrieve`.  Subclasses must implement :meth:`retrieve` and
    may override :meth:`compute_global_features`.
    """

    def __init__(self, k: int) -> None:
        """
        :param k: the max number of elements to retrieve
        """
        self.k = k

    def compute_global_features(self, elements: List[List[str]]) -> dict:
        """Compute features that are shared by every :meth:`retrieve` call.

        :param elements: all the elements of the dataset

        :return: a dict, passed as keyword arguments to
            :meth:`retrieve`.  Empty by default.
        """
        return {}

    def retrieve(
        self, element_i: int, elements: List[List[str]], **kwargs
    ) -> List[NERContextRetrievalMatch]:
        """Select context for the element at index ``element_i``.

        :param element_i: index of the element to retrieve context for
        :param elements: all the elements of the dataset
        :param kwargs: the output of :meth:`compute_global_features`

        :return: at most ``k`` matches
        """
        raise NotImplementedError

    def __call__(self, dataset: NERDataset) -> NERDataset:
        """Surround each element of ``dataset`` with its retrieved context.

        :param dataset: the dataset to enrich

        :return: a new dataset, whose third constructor argument marks
            context tokens with 1 and the tokens of the original
            element with 0.
        """
        features = self.compute_global_features(dataset.elements)

        def side_tokens(candidates, side):
            # context elements are laid out in their original order
            picked = [m for m in candidates if m.side == side]
            picked.sort(key=lambda m: m.element_i)
            return list(flatten(m.element for m in picked))

        augmented_elements = []
        context_masks = []
        for idx, tokens in enumerate(dataset.elements):
            found = self.retrieve(idx, dataset.elements, **features)
            assert len(found) <= self.k

            left = side_tokens(found, "left")
            right = side_tokens(found, "right")

            augmented_elements.append(left + tokens + right)
            context_masks.append(
                [1] * len(left) + [0] * len(tokens) + [1] * len(right)
            )

        return NERDataset(augmented_elements, dataset.tokenizer, context_masks)
75
+
76
+
77
class NERSamenounContextRetriever(NERContextRetriever):
    """
    Retrieve relevant context using the samenoun strategy as in
    Amalvy et al. 2023: an element is a candidate context when it
    shares at least one token tagged ``NN`` with the current element.
    """

    def __init__(self, k: int) -> None:
        """
        :param k: the max number of sentences to retrieve
        """
        super().__init__(k)

    def compute_global_features(self, elements: List[List[str]]) -> dict:
        """POS-tag every element with NLTK and collect its ``NN`` tokens.

        :param elements: all the elements of the dataset

        :return: ``{"NNs": [set of NN tokens of each element]}``
        """
        nouns_per_element = []
        for element in elements:
            tagged = nltk.pos_tag(element)
            nouns_per_element.append({token for token, tag in tagged if tag == "NN"})
        return {"NNs": nouns_per_element}

    def retrieve(
        self, element_i: int, elements: List[List[str]], NNs: List[Set[str]], **kwargs
    ) -> List[NERContextRetrievalMatch]:
        """Randomly pick up to ``k`` elements sharing a noun with the
        element at index ``element_i``.

        :param element_i: index of the element to retrieve context for
        :param elements: all the elements of the dataset
        :param NNs: ``NN`` tokens of each element, as computed by
            :meth:`compute_global_features`

        :return: at most ``k`` unscored matches
        """
        candidates = []
        for other_i, other in enumerate(elements):
            if other_i == element_i:
                continue
            if len(NNs[element_i].intersection(NNs[other_i])) == 0:  # type: ignore
                continue
            side = "left" if other_i < element_i else "right"
            candidates.append(NERContextRetrievalMatch(other, other_i, side, None))

        sample_size = min(self.k, len(candidates))
        return random.sample(candidates, k=sample_size)
112
+
113
+
114
class NERNeighborsContextRetriever(NERContextRetriever):
    """A context retriever that chooses nearby elements."""

    def __init__(self, k: int):
        """
        :param k: the max number of elements to retrieve.  Must be
            even, since ``k // 2`` elements are taken on each side.
        """
        assert k % 2 == 0, "k must be even"
        super().__init__(k)

    def retrieve(
        self, element_i: int, elements: List[List[str]], **kwargs
    ) -> List[NERContextRetrievalMatch]:
        """Retrieve up to ``k // 2`` elements directly before and
        ``k // 2`` elements directly after the element at ``element_i``.

        :param element_i: index of the element to retrieve context for
        :param elements: all the elements of the dataset

        :return: unscored matches, left neighbors first
        """
        half = self.k // 2

        # Clamp the window at the start of the document: a negative
        # slice start would wrap around to the end of the list, which
        # drops the left context of the first elements.
        left_start = max(0, element_i - half)
        left = [
            NERContextRetrievalMatch(elt, i, "left", None)
            for i, elt in enumerate(elements[left_start:element_i], start=left_start)
        ]

        # slicing past the end of the list is safe: it truncates
        right_start = element_i + 1
        right = [
            NERContextRetrievalMatch(elt, i, "right", None)
            for i, elt in enumerate(
                elements[right_start : right_start + half], start=right_start
            )
        ]

        return left + right
138
+
139
+
140
class NERBM25ContextRetriever(NERContextRetriever):
    """A context retriever that selects elements according to the BM25 ranking formula."""

    def __init__(self, k: int) -> None:
        """
        :param k: the max number of elements to retrieve
        """
        super().__init__(k)

    def compute_global_features(self, elements: List[List[str]]) -> dict:
        """Fit a BM25 model on all elements.

        :param elements: all the elements of the dataset

        :return: ``{"bm25_model": BM25Okapi}``
        """
        return {"bm25_model": BM25Okapi(elements)}

    def retrieve(
        self, element_i: int, elements: List[List[str]], bm25_model: BM25Okapi, **kwargs
    ) -> List[NERContextRetrievalMatch]:
        """Retrieve the elements with the highest BM25 score, using the
        element at ``element_i`` as the query.

        :param element_i: index of the element to retrieve context for
        :param elements: all the elements of the dataset
        :param bm25_model: model computed by
            :meth:`compute_global_features`

        :return: at most ``k`` matches scored with BM25, never
            including the query element itself
        """
        query = elements[element_i]
        sent_scores = bm25_model.get_scores(query)
        sent_scores[element_i] = float("-Inf")  # don't retrieve self

        # The query element is not a candidate, so at most
        # len(sent_scores) - 1 elements can be retrieved.  Without this
        # bound, a short document returns the element as its own context.
        k = min(self.k, len(sent_scores) - 1)
        if k <= 0:
            return []

        topk_values, topk_indexs = torch.topk(torch.tensor(sent_scores), k=k, dim=0)
        return [
            NERContextRetrievalMatch(
                elements[index], index, "left" if index < element_i else "right", value
            )
            for value, index in zip(topk_values.tolist(), topk_indexs.tolist())
            if index != element_i
        ]
164
+
165
+
166
@dataclass(frozen=True)
class NERNeuralContextRetrievalExample:
    """An (element, candidate context) pair for the neural retriever."""

    #: tokens of the text on which NER is performed
    element: List[str]
    #: tokens of the candidate context that may assist prediction
    context: List[str]
    #: side of the context: does it come from the left or from the
    #: right of ``element``?
    context_side: Literal["left", "right"]
176
+
177
+
178
class NERNeuralContextRetrievalDataset(Dataset):
    """A torch dataset of context retrieval examples, tokenized on access."""

    def __init__(
        self,
        examples: List[NERNeuralContextRetrievalExample],
        tokenizer: BertTokenizerFast,
    ) -> None:
        """
        :param examples: the examples of the dataset
        :param tokenizer: tokenizer used to encode each example
        """
        self.examples = examples
        self.tokenizer: BertTokenizerFast = tokenizer

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, index: int) -> BatchEncoding:
        """Encode the example at ``index``.

        :param index: index of the example to retrieve

        :return: the ``BatchEncoding`` produced by the tokenizer for
            the context tokens, a ``"[SEP]"`` token, then the element
            tokens.  The input is truncated to 512 wordpieces.
        """
        picked = self.examples[index]

        # the candidate context comes first, the element comes last
        words = picked.context + ["[SEP]"] + picked.element

        encoded: BatchEncoding = self.tokenizer(
            words,
            is_split_into_words=True,
            truncation=True,
            max_length=512,
        )
        return encoded
211
+
212
+
213
class NERNeuralContextRetriever(NERContextRetriever):
    """
    A neural context retriever as in Amalvy et al. 2024: candidates
    proposed by a heuristic retriever are scored by a BERT sequence
    classifier, and only the best ones are kept.
    """

    def __init__(
        self,
        heuristic_context_selector: NERContextRetriever,
        pretrained_model: Union[
            str, BertForSequenceClassification
        ] = "compnet-renard/bert-base-cased-NER-reranker",
        k: int = 3,
        batch_size: int = 1,
        threshold: float = 0.0,
        device_str: Literal["cuda", "cpu", "auto"] = "auto",
    ) -> None:
        """
        :param heuristic_context_selector: the context retriever whose
            matches are re-ranked by the classifier
        :param pretrained_model: either a pretrained model name, used
            to load a :class:`transformers.BertForSequenceClassification`
            and its tokenizer, or an already loaded model (in which
            case the ``bert-base-cased`` tokenizer is used)
        :param k: max number of sents to retrieve
        :param batch_size: batch size used at inference
        :param threshold: matches whose score is not strictly greater
            than this value are discarded
        :param device_str: device used for inference.  ``"auto"``
            selects ``"cuda"`` when available, ``"cpu"`` otherwise.
        """
        if isinstance(pretrained_model, str):
            self.ctx_classifier = BertForSequenceClassification.from_pretrained(
                pretrained_model
            )  # type: ignore
        else:
            self.ctx_classifier = pretrained_model
        self.ctx_classifier = cast(BertForSequenceClassification, self.ctx_classifier)

        self.tokenizer = BertTokenizerFast.from_pretrained(
            pretrained_model if isinstance(pretrained_model, str) else "bert-base-cased"
        )

        self.heuristic_context_selector = heuristic_context_selector

        self.batch_size = batch_size
        self.threshold = threshold

        if device_str == "auto":
            device_str = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_str)

        super().__init__(k)

    def set_heuristic_k_(self, k: int):
        """Set the max number of matches proposed by the heuristic retriever.

        :param k: new value of ``k`` for the heuristic retriever
        """
        self.heuristic_context_selector.k = k

    def predict(self, examples: List[NERNeuralContextRetrievalExample]) -> torch.Tensor:
        """Score examples with the context classifier.

        :param examples: A list of
            :class:`NERNeuralContextRetrievalExample`

        :return: A tensor of shape ``(len(examples), 2)`` of class
            scores
        """
        dataset = NERNeuralContextRetrievalDataset(examples, self.tokenizer)

        self.ctx_classifier = self.ctx_classifier.to(self.device)

        data_collator = DataCollatorWithPadding(dataset.tokenizer)  # type: ignore
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False, collate_fn=data_collator)  # type: ignore

        # inference using self.ctx_classifier
        self.ctx_classifier = self.ctx_classifier.eval()
        batch_scores = []
        with torch.no_grad():
            for X in dataloader:
                X = X.to(self.device)
                # out.logits is of shape (batch_size, 2)
                out = self.ctx_classifier(
                    X["input_ids"],
                    token_type_ids=X["token_type_ids"],
                    attention_mask=X["attention_mask"],
                )
                # (batch_size, 2)
                batch_scores.append(torch.softmax(out.logits, dim=1))

        if not batch_scores:
            # keep the documented shape, even without examples
            return torch.zeros((0, 2), device=self.device)
        return torch.cat(batch_scores, dim=0)

    def compute_global_features(self, elements: List[List[str]]) -> dict:
        """Run the heuristic retriever on every element.

        :param elements: all the elements of the dataset

        :return: ``{"heuristic_matchs": [matches of each element]}``
        """
        features = self.heuristic_context_selector.compute_global_features(elements)
        return {
            "heuristic_matchs": [
                self.heuristic_context_selector.retrieve(i, elements, **features)
                for i in range(len(elements))
            ]
        }

    def retrieve(
        self,
        element_i: int,
        elements: List[List[str]],
        heuristic_matchs: List[List[NERContextRetrievalMatch]],
        **kwargs,
    ) -> List[NERContextRetrievalMatch]:
        """Re-rank the heuristic matches of the element at ``element_i``.

        The ``score`` of each heuristic match is set in place to the
        score the classifier gives to class 1.

        :param element_i: index of the element to retrieve context for
        :param elements: all the elements of the dataset
        :param heuristic_matchs: matches of each element, as computed
            by :meth:`compute_global_features`

        :return: the (at most) ``k`` best matches, by decreasing
            score, restricted to those scoring above ``threshold``
        """
        # no context retrieved by heuristic : nothing to do
        if len(heuristic_matchs) == 0:
            return []
        matchs = heuristic_matchs[element_i]
        if len(matchs) == 0:
            return []

        element = elements[element_i]

        # prepare datas for inference
        ctx_dataset = [
            NERNeuralContextRetrievalExample(element, m.element, m.side) for m in matchs
        ]

        # (len(dataset), 2).  The model is moved to self.device by
        # predict(): the device configured at init time is honored.
        scores = self.predict(ctx_dataset)
        for i, m in enumerate(matchs):
            m.score = float(scores[i, 1].item())

        ranked = sorted(matchs, key=lambda m: -m.score)  # type: ignore
        return [m for m in ranked[: self.k] if m.score > self.threshold]  # type: ignore
343
+
344
+
345
class NEREnsembleContextRetriever(NERContextRetriever):
    """Combine several context retrievers."""

    def __init__(self, retrievers: List[NERContextRetriever], k: int) -> None:
        """
        :param retrievers: the retrievers to combine
        :param k: the max number of elements to retrieve
        """
        self.retrievers = retrievers
        super().__init__(k)

    def compute_global_features(self, elements: List[List[str]]) -> dict:
        """Merge the global features of all retrievers.

        When several retrievers define the same feature, a warning is
        printed on stderr and the value of the last retriever wins.

        :param elements: all the elements of the dataset

        :return: the merged features dict
        """
        features = {}
        for retriever in self.retrievers:
            for k, v in retriever.compute_global_features(elements).items():
                if k in features:
                    print(
                        f"[warning] NEREnsembleContextRetriver: incompatible global feature '{k}' between multiple retrievers.",
                        file=sys.stderr,
                    )
                features[k] = v
        return features

    def retrieve(
        self, element_i: int, elements: List[List[str]], **kwargs
    ) -> List[NERContextRetrievalMatch]:
        """Retrieve context with each retriever, and keep at most ``k``
        of the resulting matches.

        :param element_i: index of the element to retrieve context for
        :param elements: all the elements of the dataset
        :param kwargs: the output of :meth:`compute_global_features`

        :return: the ``k`` best matches if all of them are scored,
            otherwise up to ``k`` distinct matches chosen at random
        """
        all_matchs = set()
        for retriever in self.retrievers:
            all_matchs.update(retriever.retrieve(element_i, elements, **kwargs))

        # Set iteration order depends on string hash randomization:
        # order candidates so that results are reproducible under a
        # fixed random seed.
        candidates = sorted(all_matchs, key=lambda m: (m.element_i, m.side))

        if all(m.score is not None for m in candidates):
            return sorted(candidates, key=lambda m: -m.score)[: self.k]  # type: ignore

        # Sample without replacement so that the same context is never
        # returned twice, and cap the sample size so that having fewer
        # than k candidates (possibly none) is not an error.
        return random.sample(candidates, k=min(self.k, len(candidates)))
@@ -1,4 +1,6 @@
1
+ from __future__ import annotations
1
2
  from typing import Iterable, Literal, Optional, TypeVar, Generator
3
+ import sys
2
4
  from tqdm import tqdm
3
5
 
4
6
 
@@ -20,6 +22,10 @@ class ProgressReporter:
20
22
  """Update reporter current message."""
21
23
  pass
22
24
 
25
+ def get_subreporter(self) -> ProgressReporter:
26
+ """Get the subreporter corresponding to that reporter."""
27
+ raise NotImplementedError
28
+
23
29
 
24
30
  class NoopProgressReporter(ProgressReporter):
25
31
  def reset_(self):
@@ -28,6 +34,28 @@ class NoopProgressReporter(ProgressReporter):
28
34
  def update_progress_(self, added_progress: int):
29
35
  pass
30
36
 
37
+ def get_subreporter(self) -> ProgressReporter:
38
+ return NoopProgressReporter()
39
+
40
+
41
class TQDMSubProgressReporter(ProgressReporter):
    """Report the progress of a sub-task in the postfix of the tqdm
    bar of a parent :class:`TQDMProgressReporter`.
    """

    def __init__(self, reporter: TQDMProgressReporter) -> None:
        """
        :param reporter: the parent reporter, whose ``tqdm`` bar is
            used for display
        """
        super().__init__()
        self.reporter = reporter

    def start_(self, total: int):
        super().start_(total)
        self.progress = 0

    def _step_label(self) -> str:
        # NOTE(review): ``self.total`` is presumably set by
        # ProgressReporter.start_ -- confirm against the base class.
        return f"({self.progress}/{self.total})"

    def update_progress_(self, added_progress: int):
        self.progress += added_progress
        self.reporter.tqdm.set_postfix(step=self._step_label())

    def update_message_(self, message: str):
        self.reporter.tqdm.set_postfix(step=self._step_label(), message=message)
58
+
31
59
 
32
60
  class TQDMProgressReporter(ProgressReporter):
33
61
  def start_(self, total: int):
@@ -40,6 +68,9 @@ class TQDMProgressReporter(ProgressReporter):
40
68
  def update_message_(self, message: str):
41
69
  self.tqdm.set_description_str(message)
42
70
 
71
+ def get_subreporter(self) -> ProgressReporter:
72
+ return TQDMSubProgressReporter(self)
73
+
43
74
 
44
75
  T = TypeVar("T")
45
76
 
@@ -62,5 +93,5 @@ def get_progress_reporter(name: Optional[Literal["tqdm"]]) -> ProgressReporter:
62
93
  return NoopProgressReporter()
63
94
  if name == "tqdm":
64
95
  return TQDMProgressReporter()
65
- print(f"[warning] unknown progress reporter: {name}")
96
+ print(f"[warning] unknown progress reporter: {name}", file=sys.stderr)
66
97
  return NoopProgressReporter()
@@ -49,13 +49,12 @@ class BertSpeakerDetector(PipelineStep):
49
49
 
50
50
  super().__init__()
51
51
 
52
- def _pipeline_init_(self, lang: str, progress_reporter: ProgressReporter):
52
+ def _pipeline_init_(self, lang: str, **kwargs):
53
53
  from transformers import AutoTokenizer
54
54
 
55
- super()._pipeline_init_(lang, progress_reporter)
55
+ super()._pipeline_init_(lang, **kwargs)
56
56
 
57
57
  if self.model is None:
58
-
59
58
  # the user supplied a huggingface ID: load model from the HUB
60
59
  if not self.huggingface_model_id is None:
61
60
  self.model = SpeakerAttributionModel.from_pretrained(