renard-pipeline 0.4.2-py3-none-any.whl → 0.6.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -1,12 +1,12 @@
-import itertools
 from typing import Dict, Any, List, Set, Optional, Tuple, Literal, Union
+import itertools as it
 import operator
-from itertools import accumulate
 
 import networkx as nx
 import numpy as np
 from more_itertools import windowed
 
+from renard.utils import BlockBounds, charbb2tokenbb
 from renard.pipeline.ner import NEREntity
 from renard.pipeline.core import PipelineStep
 from renard.pipeline.character_unification import Character
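
Note: the new `BlockBounds` import is the pivot of this release. Its definition is not part of this diff, but from the call sites below (`block_bounds[0]` is iterated as `(start, end)` pairs, `block_bounds[1]` is compared against "tokens" or "characters") it looks like a (bounds, unit) pair. A minimal sketch of that assumed shape:

    # Sketch only: the real definition lives in renard.utils and may differ.
    from typing import List, Literal, Tuple

    BlockBounds = Tuple[List[Tuple[int, int]], Literal["tokens", "characters"]]

    # two blocks covering tokens [0, 120) and [120, 250)
    blocks: BlockBounds = ([(0, 120), (120, 250)], "tokens")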
@@ -15,65 +15,63 @@ from renard.pipeline.quote_detection import Quote
 
 def sent_index_for_token_index(token_index: int, sentences: List[List[str]]) -> int:
     """Compute the index of the sentence of the token at ``token_index``"""
-    sents_len = accumulate([len(s) for s in sentences], operator.add)
+    sents_len = it.accumulate([len(s) for s in sentences], operator.add)
     return next((i for i, l in enumerate(sents_len) if l > token_index))
 
 
-def sent_indices_for_chapter(
-    chapters: List[List[str]], chapter_idx: int, sentences: List[List[str]]
+def sent_indices_for_block(
+    dynamic_block: Tuple[int, int], sentences: List[List[str]]
 ) -> Tuple[int, int]:
     """Return the indices of the first and the last sentence of a
-    chapter
+    block
 
-    :param chapters: all chapters
-    :param chapter_idx: index of the chapter for which sentence
-        indices are returned
-    :param sentences: all sentences
+    :param dynamic_block: (START, END) in tokens
    :return: ``(first sentence index, last sentence index)``
    """
-    chapter_start_idx = sum([len(c) for i, c in enumerate(chapters) if i < chapter_idx])
-    chapter_end_idx = chapter_start_idx + len(chapters[chapter_idx])
-    sents_start_idx = None
-    sents_end_idx = None
+    block_start, block_end = dynamic_block
+    sents_start = None
+    sents_end = None
     count = 0
     for sent_i, sent in enumerate(sentences):
-        start_idx, end_idx = (count, count + len(sent))
-        count = end_idx
-        if sents_start_idx is None and start_idx >= chapter_start_idx:
-            sents_start_idx = sent_i
-        if sents_end_idx is None and end_idx >= chapter_end_idx:
-            sents_end_idx = sent_i
+        start, end = (count, count + len(sent))
+        count = end
+        if sents_start is None and start >= block_start:
+            sents_start = sent_i
+        if sents_end is None and end >= block_end:
+            # this happens when the block is _smaller_ than the
+            # current sentence. In that case, we return the current
+            # sentence even though it overflows the block.
+            if sents_start is None:
+                sents_start = sent_i
+            sents_end = sent_i
             break
-    assert not sents_start_idx is None and not sents_end_idx is None
-    return (sents_start_idx, sents_end_idx)
+    assert not sents_start is None and not sents_end is None
+    return (sents_start, sents_end)
 
 
-def mentions_for_chapters(
-    chapters: List[List[str]],
-    mentions: List[Tuple[Character, NEREntity]],
-) -> List[List[Tuple[Character, NEREntity]]]:
-    """Return each chapter mentions
+def mentions_for_blocks(
+    block_bounds: BlockBounds,
+    mentions: List[Tuple[Any, NEREntity]],
+) -> List[List[Tuple[Any, NEREntity]]]:
+    """Return each block mentions.
 
-    :param chapters:
+    :param block_bounds: block bounds, in tokens
     :param mentions:
 
-    :return: a list of mentions per chapters. This list has len
-        ``len(chapters)``.
+    :return: a list of mentions per blocks. This list has len
+        ``len(block_bounds)``.
     """
-    chapters_mentions = [[] for _ in range(len(chapters))]
+    assert block_bounds[1] == "tokens"
 
-    start_indices = list(
-        itertools.accumulate([0] + [len(chapter) for chapter in chapters[:-1]])
-    )
-    end_indices = start_indices[1:] + [start_indices[-1] + len(chapters[-1])]
+    blocks_mentions = [[] for _ in range(len(block_bounds[0]))]
 
     for mention in mentions:
-        for chapter_i, (start_i, end_i) in enumerate(zip(start_indices, end_indices)):
+        for block_i, (start_i, end_i) in enumerate(block_bounds[0]):
             if mention[1].start_idx >= start_i and mention[1].end_idx < end_i:
-                chapters_mentions[chapter_i].append(mention)
+                blocks_mentions[block_i].append(mention)
                 break
 
-    return chapters_mentions
+    return blocks_mentions
 
 
 class CoOccurrencesGraphExtractor(PipelineStep):
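
Note: a quick behavioral check of the renamed helper, including the new guard for a block smaller than its enclosing sentence. The import path is assumed from the identifiers in this diff; it is not shown in the excerpt:

    from renard.pipeline.graph_extraction import sent_indices_for_block  # assumed path

    sentences = [["Alice", "met", "Bob", "."], ["They", "talked", "."]]
    # a block spanning tokens 0..6 covers both sentences
    assert sent_indices_for_block((0, 6), sentences) == (0, 1)
    # a block smaller than its sentence still yields that sentence: the
    # new branch sets sents_start before breaking
    assert sent_indices_for_block((5, 6), sentences) == (1, 1)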
@@ -83,13 +81,11 @@ class CoOccurrencesGraphExtractor(PipelineStep):
         self,
         co_occurrences_dist: Optional[
             Union[int, Tuple[int, Literal["tokens", "sentences"]]]
-        ],
+        ] = None,
         dynamic: bool = False,
         dynamic_window: Optional[int] = None,
         dynamic_overlap: int = 0,
-        co_occurences_dist: Optional[
-            Union[int, Tuple[int, Literal["tokens", "sentences"]]]
-        ] = None,
+        additional_ner_classes: Optional[List[str]] = None,
     ) -> None:
         """
         :param co_occurrences_dist: max accepted distance between two
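
Note: under the new signature, `co_occurrences_dist` gains a `None` default and `additional_ner_classes` lets non-character entities into the graph without any unification. A hypothetical construction, with keyword names taken from the hunk above and the import path assumed:

    from renard.pipeline.graph_extraction import CoOccurrencesGraphExtractor  # assumed

    extractor = CoOccurrencesGraphExtractor(
        co_occurrences_dist=(25, "tokens"),
        additional_ner_classes=["LOC"],  # also keep locations as graph nodes
    )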
@@ -111,10 +107,10 @@ class CoOccurrencesGraphExtractor(PipelineStep):
             that case, ``dynamic_window`` and
             ``dynamic_overlap`` *can* be specified. If
             ``dynamic_window`` is not specified, this step is
-            expecting the text to be cut into 'chapters', and a
-            graph will be extracted for each 'chapter'. In that
-            case, ``chapters`` must be passed to the pipeline as
-            a ``List[str]`` at runtime.
+            expecting the text to be cut into 'dynamic blocks',
+            and a graph will be extracted for each block. In
+            that case, ``dynamic_blocks`` must be passed to the
+            pipeline as a ``List[str]`` at runtime.
 
         :param dynamic_window: dynamic window, in number of
             interactions. a dynamic window of `n` means that each
@@ -122,19 +118,15 @@ class CoOccurrencesGraphExtractor(PipelineStep):
 
         :param dynamic_overlap: overlap, in number of interactions.
 
-        :param co_occurences_dist: same as ``co_occurrences_dist``.
-            Included because of retro-compatibility, as it was a
-            previously included typo.
+        :param additional_ner_classes: if specified, will include
+            entities other than characters in the final graph. No
+            attempt will be made at unifying the entities (for example,
+            "New York" will be distinct from "New York City").
         """
-        # typo retrocompatibility
-        if not co_occurences_dist is None:
-            co_occurrences_dist = co_occurences_dist
-        if co_occurrences_dist is None and co_occurences_dist is None:
-            raise ValueError()
-
         if isinstance(co_occurrences_dist, int):
             co_occurrences_dist = (co_occurrences_dist, "tokens")
         self.co_occurrences_dist = co_occurrences_dist
+        self.need_co_occurrences_blocks = co_occurrences_dist is None
 
         if dynamic:
             if not dynamic_window is None:
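
Note: this is the breaking part of the release. The misspelled `co_occurences_dist` alias is removed outright rather than deprecated, so 0.4.2 call sites that used the old spelling will fail with a `TypeError` under 0.6.0 and must switch to the corrected keyword:

    from renard.pipeline.graph_extraction import CoOccurrencesGraphExtractor  # assumed

    # 0.4.2 still accepted the misspelled co_occurences_dist=25;
    # 0.6.0 only accepts the corrected spelling:
    extractor = CoOccurrencesGraphExtractor(co_occurrences_dist=25)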
@@ -143,89 +135,135 @@ class CoOccurrencesGraphExtractor(PipelineStep):
         self.dynamic = dynamic
         self.dynamic_window = dynamic_window
         self.dynamic_overlap = dynamic_overlap
-        self.dynamic_needs_chapter = dynamic == "nx" and dynamic_window is None
+        self.need_dynamic_blocks = dynamic and dynamic_window is None
+
+        self.additional_ner_classes = additional_ner_classes or []
+
         super().__init__()
 
     def __call__(
         self,
         characters: Set[Character],
         sentences: List[List[str]],
-        chapter_tokens: Optional[List[List[str]]] = None,
+        char2token: Optional[List[int]] = None,
+        dynamic_blocks: Optional[BlockBounds] = None,
         sentences_polarities: Optional[List[float]] = None,
+        entities: Optional[List[NEREntity]] = None,
+        co_occurrences_blocks: Optional[BlockBounds] = None,
         **kwargs,
     ) -> Dict[str, Any]:
-        """Extract a characters graph
+        """Extract a co-occurrence character network.
 
-        :param characters:
+        :param co_occurrences_blocks: custom blocks where
+            co-occurrences should be recorded. For example, this can
+            be used to perform chapter level co-occurrences.
 
         :return: a ``dict`` with key ``'character_network'`` and a
-            :class:`nx.Graph` or a list of :class:`nx.Graph` as
-            value.
+                 :class:`nx.Graph` or a list of :class:`nx.Graph` as
+                 value.
         """
         mentions = []
         for character in characters:
             for mention in character.mentions:
                 mentions.append((character, mention))
+
+        if len(self.additional_ner_classes) > 0:
+            assert not entities is None
+            for entity in entities:
+                if entity.tag in self.additional_ner_classes:
+                    mentions.append((" ".join(entity.tokens), entity))
+
         mentions = sorted(mentions, key=lambda cm: cm[1].start_idx)
 
+        # convert from char blocks to token blocks
+        if not dynamic_blocks is None and dynamic_blocks[1] == "characters":
+            assert not char2token is None
+            dynamic_blocks = charbb2tokenbb(dynamic_blocks, char2token)
+        if (
+            not co_occurrences_blocks is None
+            and co_occurrences_blocks[1] == "characters"
+        ):
+            assert not char2token is None
+            co_occurrences_blocks = charbb2tokenbb(co_occurrences_blocks, char2token)
+
         if self.dynamic:
             return {
                 "character_network": self._extract_dynamic_graph(
                     mentions,
                     self.dynamic_window,
                     self.dynamic_overlap,
-                    chapter_tokens,
+                    dynamic_blocks,
                     sentences,
                     sentences_polarities,
+                    co_occurrences_blocks,
                 )
             }
         return {
             "character_network": self._extract_graph(
-                mentions, sentences, sentences_polarities
+                mentions, sentences, sentences_polarities, co_occurrences_blocks
             )
         }
 
-    def _mentions_interact(
-        self,
-        mention_1: NEREntity,
-        mention_2: NEREntity,
-        sentences: Optional[List[List[str]]] = None,
-    ) -> bool:
-        """Check if two mentions are close enough to be in interactions.
-
-        .. note::
-
-            the attribute ``self.co_occurrences_dist`` is used to know wether mentions are in co_occurences
+    def _create_co_occurrences_blocks(
+        self, sentences: List[List[str]], mentions: List[Tuple[Any, NEREntity]]
+    ) -> BlockBounds:
+        """Create co-occurrences blocks using
+        ``self.co_occurrences_dist``. All entities within a block are
+        considered as co-occurring.
 
-        :param mention_1:
-        :param mention_2:
         :param sentences:
-        :return: a boolean indicating wether the two mentions are co-occuring
         """
-        if self.co_occurrences_dist[1] == "tokens":
-            return (
-                abs(mention_2.start_idx - mention_1.start_idx)
-                <= self.co_occurrences_dist[0]
-            )
-        elif self.co_occurrences_dist[1] == "sentences":
-            assert not sentences is None
-            mention_1_sent = sent_index_for_token_index(mention_1.start_idx, sentences)
-            mention_2_sent = sent_index_for_token_index(
-                mention_2.end_idx - 1, sentences
-            )
-            return abs(mention_2_sent - mention_1_sent) <= self.co_occurrences_dist[0]
+        assert not self.co_occurrences_dist is None
+
+        dist_unit = self.co_occurrences_dist[1]
+
+        if dist_unit == "tokens":
+            tokens_dist = self.co_occurrences_dist[0]
+            blocks = []
+            for _, entity in mentions:
+                block_start = entity.start_idx - tokens_dist
+                block_end = entity.end_idx + tokens_dist
+                blocks.append((block_start, block_end))
+            return (blocks, "tokens")
+
+        elif dist_unit == "sentences":
+            blocks_indices = set()
+            sent_dist = self.co_occurrences_dist[0]
+            for _, entity in mentions:
+                start_sent_i = max(
+                    0,
+                    sent_index_for_token_index(entity.start_idx, sentences) - sent_dist,
+                )
+                start_token_i = sum(len(sent) for sent in sentences[:start_sent_i])
+                end_sent_i = min(
+                    len(sentences) - 1,
+                    sent_index_for_token_index(entity.end_idx - 1, sentences)
+                    + sent_dist,
+                )
+                end_token_i = sum(len(sent) for sent in sentences[: end_sent_i + 1])
+                blocks_indices.add((start_token_i, end_token_i))
+            blocks = [
+                (start, end)
+                for start, end in sorted(blocks_indices, key=lambda indices: indices[0])
+            ]
+            return (blocks, "tokens")
+
         else:
-            raise NotImplementedError
+            raise ValueError(
+                f"co_occurrences_dist unit should be one of: 'tokens', 'sentences'"
+            )
 
     def _extract_graph(
         self,
-        mentions: List[Tuple[Character, NEREntity]],
+        mentions: List[Tuple[Any, NEREntity]],
         sentences: List[List[str]],
         sentences_polarities: Optional[List[float]],
-    ):
+        co_occurrences_blocks: Optional[BlockBounds],
+    ) -> nx.Graph:
         """
-        :param mentions: A list of character mentions, ordered by
-            appearance
+        :param mentions: A list of entity mentions, ordered by
+            appearance, each of the form (KEY, MENTION). KEY
+            determines the unicity of the entity.
         :param sentences: if specified, ``sentences_polarities`` must
             be specified as well.
         :param sentences_polarities: if specified, ``sentences`` must
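
Note: the `__call__` rewrite above normalizes character-level blocks to token-level ones through the new `renard.utils.charbb2tokenbb` before any extraction happens. Its implementation is not in this diff; judging from the call sites, `char2token` maps each character offset to its token index. A sketch of what such a conversion could look like, as an assumption rather than the shipped code:

    from typing import List, Tuple

    def charbb2tokenbb_sketch(
        char_blocks: Tuple[List[Tuple[int, int]], str], char2token: List[int]
    ) -> Tuple[List[Tuple[int, int]], str]:
        # map each (start, end) character span to the tokens covering it
        bounds = [
            (char2token[start], char2token[end - 1] + 1)
            for start, end in char_blocks[0]
        ]
        return (bounds, "tokens")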
@@ -234,25 +272,37 @@ class CoOccurrencesGraphExtractor(PipelineStep):
             of the relationship between two characters. Polarity
             between two interactions is computed as the strongest
             sentence polarity between those two mentions.
+        :param co_occurrences_blocks: only unit 'tokens' is accepted.
         """
         compute_polarity = not sentences_polarities is None
 
+        assert co_occurrences_blocks is None or co_occurrences_blocks[1] == "tokens"
+        if co_occurrences_blocks is None:
+            co_occurrences_blocks = self._create_co_occurrences_blocks(
+                sentences, mentions
+            )
+
         # co-occurence matrix, where C[i][j] is 1 when appearance
         # i co-occur with j if i < j, or 0 when it doesn't
         C = np.zeros((len(mentions), len(mentions)))
-        for i, (char1, mention_1) in enumerate(mentions):
-            # check ahead for co-occurences
-            for j, (char2, mention_2) in enumerate(mentions[i + 1 :]):
-                if not self._mentions_interact(mention_1, mention_2, sentences):
-                    # dist between current token and future token is
-                    # too great : we finished co-occurences search for
-                    # the current token
+        for block_start, block_end in co_occurrences_blocks[0]:
+            # collect all mentions in this co-occurrences block
+            block_mentions = []
+            for i, (key, mention) in enumerate(mentions):
+                if mention.start_idx >= block_start and mention.end_idx <= block_end:
+                    block_mentions.append((i, key, mention))
+                # since mentions are ordered, the first mention
+                # outside of the blocks ends the search inside this block
+                if mention.start_idx > block_end:
                    break
-                # ignore co-occurences with self
-                if char1 == char2:
+            # assign mentions in this co-occurrences blocks to C
+            for m1, m2 in it.combinations(block_mentions, 2):
+                i1, key1, mention1 = m1
+                i2, key2, mention2 = m2
+                # ignore co-occurrence with self
+                if key1 == key2:
                    continue
-                # record co_occurence
-                C[i][i + 1 + j] = 1
+                C[i1][i2] = 1
 
         # * Construct graph from co-occurence matrix
         G = nx.Graph()
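
Note: the matrix fill now proceeds block by block instead of probing ahead pair by pair: every pair of distinct entities inside one block co-occurs. A standalone illustration of that pairing rule, with plain tuples standing in for the (index, key, mention) triples:

    import itertools as it

    # (key, start_idx, end_idx) stand-ins for real mentions, ordered
    mentions = [("Alice", 0, 1), ("Bob", 3, 4), ("Alice", 40, 41)]
    block_start, block_end = 0, 10  # one co-occurrences block, in tokens

    in_block = [m for m in mentions if m[1] >= block_start and m[2] <= block_end]
    pairs = [(a[0], b[0]) for a, b in it.combinations(in_block, 2) if a[0] != b[0]]
    print(pairs)  # [('Alice', 'Bob')]: the distant Alice mention is excluded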
@@ -291,25 +341,29 @@ class CoOccurrencesGraphExtractor(PipelineStep):
 
     def _extract_dynamic_graph(
         self,
-        mentions: List[Tuple[Character, NEREntity]],
+        mentions: List[Tuple[Any, NEREntity]],
         window: Optional[int],
         overlap: int,
-        chapter_tokens: Optional[List[List[str]]],
+        dynamic_blocks: Optional[BlockBounds],
         sentences: List[List[str]],
         sentences_polarities: Optional[List[float]],
+        co_occurrences_blocks: Optional[BlockBounds],
     ) -> List[nx.Graph]:
         """
         .. note::
 
-            only one of ``window`` or ``chapter_tokens`` should be specified
+            only one of ``window`` or ``dynamic_blocks_tokens`` should be specified
 
-        :param mentions: A list of character mentions, ordered by appearance
+        :param mentions: A list of entity mentions, ordered by
+            appearance, each of the form (KEY, MENTION). KEY
+            determines the unicity of the entity.
         :param window: dynamic window, in tokens.
         :param overlap: window overlap
-        :param chapter_tokens: list of tokens for each chapter. If
-            given, one graph will be extracted per chapter.
+        :param dynamic_blocks: boundaries of each dynamic block
+        :param co_occurrences_blocks: boundaries of each co-occurrences blocks
         """
-        assert window is None or chapter_tokens is None
+        assert co_occurrences_blocks is None or co_occurrences_blocks[1] == "tokens"
+        assert window is None or dynamic_blocks is None
         compute_polarity = not sentences is None and not sentences_polarities is None
 
         if not window is None:
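
Note: the windowed path keeps its 0.4.2 shape; only the extra `co_occurrences_blocks` argument is threaded through. For reference, `more_itertools.windowed` slices the ordered mentions into overlapping interaction windows:

    from more_itertools import windowed

    mentions = list(range(7))  # stand-ins for ordered mentions
    window, overlap = 3, 1
    print(list(windowed(mentions, window, step=window - overlap)))
    # [(0, 1, 2), (2, 3, 4), (4, 5, 6)]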
@@ -318,104 +372,66 @@ class CoOccurrencesGraphExtractor(PipelineStep):
                     [elt for elt in ct if not elt is None],
                     sentences,
                     sentences_polarities,
+                    co_occurrences_blocks,
                 )
                 for ct in windowed(mentions, window, step=window - overlap)
             ]
 
-        assert not chapter_tokens is None
+        assert not dynamic_blocks is None
 
         graphs = []
 
-        chapters_mentions = mentions_for_chapters(chapter_tokens, mentions)
-        for chapter_i, (_, chapter_mentions) in enumerate(
-            zip(chapter_tokens, chapters_mentions)
-        ):
-            chapter_start_idx = sum(
-                [len(c) for i, c in enumerate(chapter_tokens) if i < chapter_i]
-            )
-            # make mentions coordinates chapter local
-            chapter_mentions = [
-                (c, m.shifted(-chapter_start_idx)) for c, m in chapter_mentions
-            ]
+        blocks_mentions = mentions_for_blocks(dynamic_blocks, mentions)
+        for dynamic_block, block_mentions in zip(dynamic_blocks[0], blocks_mentions):
+            block_start, block_end = dynamic_block
 
-            sent_start_idx, sent_end_idx = sent_indices_for_chapter(
-                chapter_tokens, chapter_i, sentences
-            )
-            chapter_sentences = sentences[sent_start_idx : sent_end_idx + 1]
+            sent_start, sent_end = sent_indices_for_block(dynamic_block, sentences)
+            block_sentences = sentences[sent_start : sent_end + 1]
 
-            chapter_sentences_polarities = None
+            block_sentences_polarities = None
             if compute_polarity:
                 assert not sentences_polarities is None
-                chapter_sentences_polarities = sentences_polarities[
-                    sent_start_idx : sent_end_idx + 1
+                block_sentences_polarities = sentences_polarities[
+                    sent_start : sent_end + 1
+                ]
+
+            if co_occurrences_blocks is None:
+                block_co_occ_bounds = None
+            else:
+                bounds = [
+                    (start, end)
+                    for start, end in co_occurrences_blocks[0]
+                    if start >= block_start and end <= block_end
                 ]
+                block_co_occ_bounds = (bounds, "tokens")
 
             graphs.append(
                 self._extract_graph(
-                    chapter_mentions,
-                    chapter_sentences,
-                    chapter_sentences_polarities,
+                    block_mentions,
+                    block_sentences,
+                    block_sentences_polarities,
+                    block_co_occ_bounds,
                 )
             )
 
         return graphs
 
-    def _extract_gephi_dynamic_graph(
-        self, mentions: List[Tuple[Character, NEREntity]], sentences: List[List[str]]
-    ) -> nx.Graph:
-        """
-        :param mentions: A list of character mentions, ordered by appearance
-        :param sentences:
-        """
-        # keep only longest name in graph node : possible only if it is unique
-        # TODO: might want to try and get shorter names if longest names aren't
-        # unique
-        characters = set([e[0] for e in mentions])
-
-        G = nx.Graph()
-
-        character_to_last_appearance: Dict[Character, Optional[NEREntity]] = {
-            character: None for character in characters
-        }
-
-        for i, (character, mention) in enumerate(mentions):
-            if not character in characters:
-                continue
-            character_to_last_appearance[character] = mention
-            close_characters = [
-                c
-                for c, last_appearance in character_to_last_appearance.items()
-                if not last_appearance is None
-                and self._mentions_interact(mention, last_appearance, sentences)
-                and not c == character
-            ]
-            for close_character in close_characters:
-                if not G.has_edge(character, close_character):
-                    G.add_edge(character, close_character)
-                    G.edges[character, close_character]["start"] = i
-                    G.edges[character, close_character]["dweight"] = []
-                # add a new entry to the weight series according to networkx
-                # source code, each entry must be of the form
-                # [value, start, end]
-                weights = G.edges[character, close_character]["dweight"]
-                if len(weights) != 0:
-                    # end of last weight attribute
-                    weights[-1][-1] = i
-                # value, start and end of current weight attribute
-                last_weight_value = weights[-1][0] if len(weights) > 0 else 0
-                G.edges[character, close_character]["dweight"].append(
-                    [float(last_weight_value) + 1, i, len(mentions)]
-                )
-
-        return G
-
     def supported_langs(self) -> Union[Set[str], Literal["any"]]:
         return "any"
 
     def needs(self) -> Set[str]:
         needs = {"characters", "sentences"}
-        if self.dynamic_needs_chapter:
-            needs.add("chapter_tokens")
+
+        if self.need_dynamic_blocks:
+            needs.add("dynamic_blocks")
+            needs.add("char2token")
+        if self.need_co_occurrences_blocks:
+            needs.add("co_occurrences_blocks")
+            needs.add("char2token")
+
+        if len(self.additional_ner_classes) > 0:
+            needs.add("entities")
+
         return needs
 
     def production(self) -> Set[str]:
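
Note: the reworked `needs()` is what wires the new runtime inputs into the pipeline: a fully block-driven extractor now advertises all three new keys. A check that follows directly from the logic above, import path assumed as before:

    from renard.pipeline.graph_extraction import CoOccurrencesGraphExtractor  # assumed

    # no co_occurrences_dist and no dynamic_window: both kinds of blocks
    # (plus char2token) must then be supplied at runtime
    step = CoOccurrencesGraphExtractor(dynamic=True)
    assert {"dynamic_blocks", "co_occurrences_blocks", "char2token"} <= step.needs()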
@@ -426,26 +442,49 @@ class CoOccurrencesGraphExtractor(PipelineStep):
 
 
 class ConversationalGraphExtractor(PipelineStep):
-    """A graph extractor using conversation between characters
+    """A graph extractor using conversation between characters or
+    mentions.
 
     .. note::
 
-        This is an early version, that only supports static graphs
-        for now.
+        Does not support dynamic networks yet.
     """
 
     def __init__(
-        self, conversation_dist: Union[int, Tuple[int, Literal["tokens", "sentences"]]]
+        self,
+        graph_type: Literal["conversation", "mention"],
+        conversation_dist: Optional[
+            Union[int, Tuple[int, Literal["tokens", "sentences"]]]
+        ] = None,
+        ignore_self_mention: bool = True,
     ):
+        """
+        :param graph_type: either 'conversation' or 'mention'.
+            'conversation' extracts an undirected graph with
+            interactions being extracted from the conversations
+            occurring between characters. 'mention' extracts a
+            directed graph where interactions are character mentions
+            of one another in quoted speech.
+        :param conversation_dist: must be supplied if `graph_type` is
+            'conversation'. The distance between two quotations for
+            them to be considered as interacting.
+        :param ignore_self_mention: if ``True``, self mentions are
+            ignored for ``graph_type=='mention'``
+        """
+        self.graph_type = graph_type
+
         if isinstance(conversation_dist, int):
             conversation_dist = (conversation_dist, "tokens")
         self.conversation_dist = conversation_dist
 
+        self.ignore_self_mention = ignore_self_mention
+
         super().__init__()
 
     def _quotes_interact(
         self, quote_1: Quote, quote_2: Quote, sentences: List[List[str]]
     ) -> bool:
+        assert not self.conversation_dist is None
         ordered = quote_2.start >= quote_1.end
         if self.conversation_dist[1] == "tokens":
             return (
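
Note: `graph_type` is now a mandatory first argument, and `conversation_dist` only matters in 'conversation' mode (the new assert in `_quotes_interact` enforces this). Hypothetical constructions under the new signature, import path assumed:

    from renard.pipeline.graph_extraction import ConversationalGraphExtractor  # assumed

    conv = ConversationalGraphExtractor(
        graph_type="conversation", conversation_dist=(3, "sentences")
    )
    mention = ConversationalGraphExtractor(graph_type="mention")  # no dist needed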
@@ -467,14 +506,13 @@ class ConversationalGraphExtractor(PipelineStep):
         else:
             raise NotImplementedError
 
-    def __call__(
+    def _conversation_extract(
         self,
         sentences: List[List[str]],
         quotes: List[Quote],
         speakers: List[Optional[Character]],
         characters: Set[Character],
-        **kwargs,
-    ) -> Dict[str, Any]:
+    ) -> nx.Graph:
         G = nx.Graph()
         for character in characters:
             G.add_node(character)
@@ -504,6 +542,57 @@ class ConversationalGraphExtractor(PipelineStep):
                     G.add_edge(speaker_1, speaker_2, weight=0)
                 G.edges[speaker_1, speaker_2]["weight"] += 1
 
+        return G
+
+    def _mention_extract(
+        self,
+        quotes: List[Quote],
+        speakers: List[Optional[Character]],
+        characters: Set[Character],
+    ) -> nx.Graph:
+        G = nx.DiGraph()
+        for character in characters:
+            G.add_node(character)
+
+        for quote, speaker in zip(quotes, speakers):
+            # no speaker prediction: ignore
+            if speaker is None:
+                continue
+
+            # TODO: optim
+            # find characters mentioned in quote and add a directed
+            # edge speaker => character
+            for character in characters:
+                if character == speaker and self.ignore_self_mention:
+                    continue
+                for mention in character.mentions:
+                    if (
+                        mention.start_idx >= quote.start
+                        and mention.end_idx <= quote.end
+                    ):
+                        if not G.has_edge(speaker, character):
+                            G.add_edge(speaker, character, weight=0)
+                        G.edges[speaker, character]["weight"] += 1
+                        break
+
+        return G
+
+    def __call__(
+        self,
+        sentences: List[List[str]],
+        quotes: List[Quote],
+        speakers: List[Optional[Character]],
+        characters: Set[Character],
+        **kwargs,
+    ) -> Dict[str, Any]:
+
+        if self.graph_type == "conversation":
+            G = self._conversation_extract(sentences, quotes, speakers, characters)
+        elif self.graph_type == "mention":
+            G = self._mention_extract(quotes, speakers, characters)
+        else:
+            raise ValueError(f"unknown graph_type: {self.graph_type}")
+
         return {"character_network": G}
 
     def needs(self) -> Set[str]:
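
Note on the new mention mode: the extracted graph is directed, and the `break` after the first matching mention caps each (quote, character) pair at one weight increment. A standalone illustration of that counting rule:

    import networkx as nx

    G = nx.DiGraph()
    # each entry: (speaker, characters mentioned within one quote)
    quotes = [("Alice", ["Bob", "Bob"]), ("Alice", ["Bob"])]
    for speaker, mentioned in quotes:
        for character in set(mentioned):  # at most one increment per quote
            if not G.has_edge(speaker, character):
                G.add_edge(speaker, character, weight=0)
            G.edges[speaker, character]["weight"] += 1
    print(G.edges["Alice", "Bob"]["weight"])  # 2, not 3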