Anchor-annotator 0.7.1__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
anchor/plot.py CHANGED
@@ -11,7 +11,6 @@ import numpy as np
11
11
  import pyqtgraph as pg
12
12
  import sqlalchemy
13
13
  from Bio import pairwise2
14
- from line_profiler_pycharm import profile
15
14
  from montreal_forced_aligner.data import CtmInterval
16
15
  from montreal_forced_aligner.db import Speaker, Utterance
17
16
  from montreal_forced_aligner.dictionary.mixins import (
@@ -682,6 +681,7 @@ class UtteranceView(QtWidgets.QWidget):
682
681
  def set_extra_tiers(self):
683
682
  self.extra_tiers = {}
684
683
  self.extra_tiers["Normalized text"] = "normalized_text"
684
+ self.extra_tiers["Transcription"] = "transcription_text"
685
685
  if self.corpus_model.has_alignments and "Words" not in self.extra_tiers:
686
686
  self.extra_tiers["Words"] = "aligned_word_intervals"
687
687
  self.extra_tiers["Phones"] = "aligned_phone_intervals"
@@ -691,14 +691,12 @@ class UtteranceView(QtWidgets.QWidget):
691
691
  self.corpus_model.has_transcribed_alignments
692
692
  and "Transcription" not in self.extra_tiers
693
693
  ):
694
- self.extra_tiers["Transcription"] = "transcription_text"
695
694
  self.extra_tiers["Transcribed words"] = "transcribed_word_intervals"
696
695
  self.extra_tiers["Transcribed phones"] = "transcribed_phone_intervals"
697
696
  if (
698
697
  self.corpus_model.has_per_speaker_transcribed_alignments
699
698
  and "Transcription" not in self.extra_tiers
700
699
  ):
701
- self.extra_tiers["Transcription"] = "transcription_text"
702
700
  self.extra_tiers["Transcribed words"] = "per_speaker_transcribed_word_intervals"
703
701
  self.extra_tiers["Transcribed phones"] = "per_speaker_transcribed_phone_intervals"
704
702
 
@@ -714,6 +712,10 @@ class UtteranceView(QtWidgets.QWidget):
714
712
  for tier in self.speaker_tiers.values():
715
713
  tier.reset_tier()
716
714
 
715
+ def refresh_text_grid(self):
716
+ for tier in self.speaker_tiers.values():
717
+ tier.refresh(reset_bounds=True)
718
+
717
719
  def draw_text_grid(self):
718
720
  scroll_to = None
719
721
  for i, (key, tier) in enumerate(self.speaker_tiers.items()):
@@ -893,6 +895,7 @@ class TextEdit(QtWidgets.QTextEdit):
893
895
  lostFocus = QtCore.Signal()
894
896
  gainedFocus = QtCore.Signal()
895
897
  menuRequested = QtCore.Signal(object, object)
898
+ doubleClicked = QtCore.Signal()
896
899
 
897
900
  def __init__(self, dictionary_model, speaker_id, *args):
898
901
  super().__init__(*args)
@@ -910,6 +913,10 @@ class TextEdit(QtWidgets.QTextEdit):
910
913
  self.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
911
914
  self.setWordWrapMode(QtGui.QTextOption.WrapMode.WordWrap)
912
915
 
916
+ def mouseDoubleClickEvent(self, e):
917
+ self.doubleClicked.emit()
918
+ super().mouseDoubleClickEvent(e)
919
+
913
920
  def dragMoveEvent(self, e: QtGui.QDragMoveEvent) -> None:
914
921
  e.ignore()
915
922
 
@@ -935,7 +942,6 @@ class TextEdit(QtWidgets.QTextEdit):
935
942
 
936
943
 
937
944
  class UtterancePGTextItem(pg.TextItem):
938
- @profile
939
945
  def __init__(
940
946
  self,
941
947
  begin: float,
@@ -989,28 +995,24 @@ class UtterancePGTextItem(pg.TextItem):
989
995
  self.end = end
990
996
  if self.end <= self.view_min or self.begin >= self.view_max:
991
997
  return
992
- self.hide()
993
998
  if (
994
999
  self.view_min <= self.begin < self.view_max
995
1000
  or self.view_max >= self.end > self.view_min
996
1001
  or (self.begin <= self.view_min and self.end >= self.view_max)
997
1002
  ):
998
1003
  self.update_pos()
999
- self.show()
1000
1004
 
1001
1005
  def update_view_times(self, begin, end):
1002
1006
  self.view_min = begin
1003
1007
  self.view_max = end
1004
1008
  if self.end <= self.view_min or self.begin >= self.view_max:
1005
1009
  return
1006
- self.hide()
1007
1010
  if (
1008
1011
  self.view_min <= self.begin < self.view_max
1009
1012
  or self.view_max >= self.end > self.view_min
1010
1013
  or (self.begin <= self.view_min and self.end >= self.view_max)
1011
1014
  ):
1012
1015
  self.update_pos()
1013
- self.show()
1014
1016
 
1015
1017
  def update_pos(self):
1016
1018
  visible_begin = max(self.begin, self.view_min)
@@ -1034,7 +1036,6 @@ class UtterancePGTextItem(pg.TextItem):
1034
1036
  self.text_edit.setFixedWidth(width)
1035
1037
  self.text_edit.setFixedHeight(height)
1036
1038
 
1037
- @profile
1038
1039
  def boundingRect(self):
1039
1040
  br = QtCore.QRectF() # bounds of containing ViewBox mapped to local coords.
1040
1041
  if self._cached_pixel_size is None:
@@ -1117,23 +1118,45 @@ class TranscriberErrorHighlighter(QtGui.QSyntaxHighlighter):
1117
1118
  self.rehighlight()
1118
1119
 
1119
1120
  def set_alignment(self, alignment):
1120
- self.alignment = alignment
1121
+ if alignment != self.alignment:
1122
+ self.alignment = alignment
1123
+ self.rehighlight()
1121
1124
 
1122
1125
  def highlightBlock(self, text):
1123
- if not self.alignment:
1124
- return
1125
- current_align_ind = 0
1126
- for word_object in re.finditer(self.WORDS, text):
1127
- sb = self.alignment.seqB[current_align_ind]
1128
- sa = self.alignment.seqA[current_align_ind]
1129
- if sb == word_object.group():
1130
- if sb != sa:
1126
+ if self.alignment:
1127
+ current_align_ind = 0
1128
+ for word_object in re.finditer(self.WORDS, text.lower()):
1129
+ sb = self.alignment.seqB[current_align_ind]
1130
+ sa = self.alignment.seqA[current_align_ind]
1131
+ if sb == "-":
1132
+ start = word_object.start() - 1
1133
+ if start < 0:
1134
+ start = 0
1135
+ count = 1
1131
1136
  self.setFormat(
1132
- word_object.start(),
1133
- word_object.end() - word_object.start(),
1137
+ start,
1138
+ count,
1134
1139
  self.highlight_format,
1135
1140
  )
1136
- current_align_ind += 1
1141
+ while sb != word_object.group():
1142
+ current_align_ind += 1
1143
+ sb = self.alignment.seqB[current_align_ind]
1144
+ sa = self.alignment.seqA[current_align_ind]
1145
+ if sb == word_object.group():
1146
+ if sb != sa:
1147
+ self.setFormat(
1148
+ word_object.start(),
1149
+ word_object.end() - word_object.start(),
1150
+ self.highlight_format,
1151
+ )
1152
+ current_align_ind += 1
1153
+ if current_align_ind < len(self.alignment.seqB):
1154
+ self.setFormat(
1155
+ len(text) - 1,
1156
+ 1,
1157
+ self.highlight_format,
1158
+ )
1159
+
1137
1160
  if self.search_term:
1138
1161
  if not self.search_term.case_sensitive:
1139
1162
  text = text.lower()
@@ -1358,7 +1381,8 @@ class TextAttributeRegion(pg.GraphicsObject):
1358
1381
 
1359
1382
 
1360
1383
  class TranscriberTextRegion(TextAttributeRegion):
1361
- @profile
1384
+ transcribeRequested = QtCore.Signal(object)
1385
+
1362
1386
  def __init__(
1363
1387
  self,
1364
1388
  parent,
@@ -1385,12 +1409,22 @@ class TranscriberTextRegion(TextAttributeRegion):
1385
1409
  plot_theme,
1386
1410
  )
1387
1411
  self.item = item
1412
+ self.text_edit.setPlaceholderText("Double click to transcribe...")
1413
+ self.text_edit.doubleClicked.connect(self.transcribe_utterance)
1388
1414
  self.highlighter = TranscriberErrorHighlighter(self.text_edit.document())
1389
1415
  if alignment is not None:
1390
1416
  self.highlighter.set_alignment(alignment)
1391
1417
  if search_term:
1392
1418
  self.highlighter.setSearchTerm(search_term)
1393
1419
 
1420
+ def transcribe_utterance(self):
1421
+ if not self.text_edit.toPlainText():
1422
+ self.transcribeRequested.emit(self.item.id)
1423
+
1424
+ def mouseDoubleClickEvent(self, event):
1425
+ self.transcribe_utterance()
1426
+ super().mouseDoubleClickEvent(event)
1427
+
1394
1428
 
1395
1429
  class NormalizedTextRegion(TextAttributeRegion):
1396
1430
  def __init__(
@@ -1588,6 +1622,7 @@ class MfaRegion(pg.LinearRegionItem):
1588
1622
  self._generate_picture()
1589
1623
  self.sigRegionChanged.connect(self.update_bounds)
1590
1624
  self.sigRegionChangeFinished.connect(self.update_bounds)
1625
+ self.setCursor(QtCore.Qt.CursorShape.ArrowCursor)
1591
1626
 
1592
1627
  def update_bounds(self):
1593
1628
  beg, end = self.getRegion()
@@ -1685,7 +1720,9 @@ class IntervalTier(pg.GraphicsObject):
1685
1720
  self.intervals = intervals
1686
1721
  self.word = word
1687
1722
  self.array = pg.Qt.internals.PrimitiveArray(QtCore.QRectF, 4)
1723
+ self.selected_array = pg.Qt.internals.PrimitiveArray(QtCore.QRectF, 4)
1688
1724
  self.array.resize(len(self.intervals))
1725
+ self.selected = []
1689
1726
  self.settings = AnchorSettings()
1690
1727
  self.plot_theme = self.settings.plot_theme
1691
1728
  memory = self.array.ndarray()
@@ -1726,9 +1763,16 @@ class IntervalTier(pg.GraphicsObject):
1726
1763
  self.break_line_color = self.plot_theme.break_line_color
1727
1764
  self.text_color = self.plot_theme.text_color
1728
1765
  self.selected_interval_color = self.plot_theme.selected_interval_color
1766
+ self.highlight_interval_color = self.plot_theme.break_line_color
1767
+ self.highlight_text_color = self.plot_theme.background_color
1729
1768
  self.text_pen = pg.mkPen(self.text_color)
1769
+ self.text_brush = pg.mkBrush(self.text_color)
1770
+ self.highlight_text_pen = pg.mkPen(self.highlight_text_color)
1771
+ self.highlight_text_brush = pg.mkBrush(self.highlight_text_color)
1730
1772
  self.border_pen = pg.mkPen(self.break_line_color, width=1)
1731
1773
  self.border_pen.setCapStyle(QtCore.Qt.PenCapStyle.FlatCap)
1774
+ self.search_term = None
1775
+ self.search_regex = None
1732
1776
 
1733
1777
  def mousePressEvent(self, e: QtGui.QMouseEvent) -> None:
1734
1778
  if e.button() == QtCore.Qt.MouseButton.LeftButton:
@@ -1748,13 +1792,40 @@ class IntervalTier(pg.GraphicsObject):
1748
1792
 
1749
1793
  return super().mousePressEvent(e)
1750
1794
 
1795
+ def set_search_term(self, search_term: TextFilterQuery):
1796
+ self.search_term = search_term
1797
+ self.search_regex = None
1798
+ if self.search_term is not None and self.search_term.text:
1799
+ self.search_regex = re.compile(self.search_term.generate_expression())
1800
+ self.selected = []
1801
+ for i, interval in enumerate(self.intervals):
1802
+ if self.search_regex.search(interval.label):
1803
+ self.selected.append(interval)
1804
+ self.selected_array.resize(len(self.selected))
1805
+ if self.selected:
1806
+ memory = self.selected_array.ndarray()
1807
+ for i, interval in enumerate(self.selected):
1808
+ memory[i, 0] = interval.begin
1809
+ memory[i, 2] = interval.end - interval.begin
1810
+ memory[:, 1] = self.bottom_point
1811
+ memory[:, 3] = self.top_point - self.bottom_point
1812
+
1751
1813
  def paint(self, painter, *args):
1752
- inst = self.array.instances()
1753
1814
  vb = self.getViewBox()
1754
1815
  px = vb.viewPixelSize()
1816
+ inst = self.array.instances()
1817
+ painter.save()
1755
1818
  painter.setPen(self.border_pen)
1756
1819
  painter.drawRects(inst)
1820
+ painter.restore()
1757
1821
  total_time = self.selection_model.max_time - self.selection_model.min_time
1822
+ if self.selected:
1823
+ selected_inst = self.selected_array.instances()
1824
+ painter.save()
1825
+ painter.setPen(self.highlight_text_pen)
1826
+ painter.setBrush(pg.mkBrush(self.highlight_interval_color))
1827
+ painter.drawRects(selected_inst)
1828
+ painter.restore()
1758
1829
  for i, interval in enumerate(self.intervals):
1759
1830
  r = inst[i]
1760
1831
  visible_begin = max(r.left(), self.selection_model.plot_min)
@@ -1767,8 +1838,14 @@ class IntervalTier(pg.GraphicsObject):
1767
1838
  options = QtGui.QTextOption()
1768
1839
  options.setAlignment(QtCore.Qt.AlignmentFlag.AlignCenter)
1769
1840
  painter.setRenderHint(painter.RenderHint.Antialiasing, True)
1770
- painter.setPen(self.text_pen)
1771
- painter.setBrush(self.text_color)
1841
+ text_pen = self.text_pen
1842
+ text_brush = self.text_brush
1843
+ if self.search_regex is not None:
1844
+ if self.search_regex.search(interval.label):
1845
+ text_pen = self.highlight_text_pen
1846
+ text_brush = self.highlight_text_brush
1847
+ painter.setPen(text_pen)
1848
+ painter.setBrush(text_brush)
1772
1849
  painter.translate(x, (self.top_point + self.bottom_point) / 2)
1773
1850
  path = self.parentItem().painter_path_cache[self.word][interval.label]
1774
1851
  painter.scale(px[0], -px[1])
@@ -1787,8 +1864,8 @@ class IntervalTier(pg.GraphicsObject):
1787
1864
  class UtteranceRegion(MfaRegion):
1788
1865
  lookUpWord = QtCore.Signal(object)
1789
1866
  createWord = QtCore.Signal(object)
1867
+ transcribeRequested = QtCore.Signal(object)
1790
1868
 
1791
- @profile
1792
1869
  def __init__(
1793
1870
  self,
1794
1871
  parent,
@@ -1919,35 +1996,30 @@ class UtteranceRegion(MfaRegion):
1919
1996
  self.normalized_text.text_edit.menuRequested.connect(self.generate_text_edit_menu)
1920
1997
  continue
1921
1998
  elif lookup == "transcription_text":
1922
- alignment = None
1923
- if self.item.normalized_text and self.item.transcription_text:
1924
- alignment = pairwise2.align.globalms(
1925
- self.item.normalized_text.lower().split(),
1926
- self.item.transcription_text.lower().split(),
1927
- 0,
1928
- -2,
1929
- -1,
1930
- -1,
1931
- gap_char=["-"],
1932
- one_alignment_only=True,
1933
- )[0]
1934
1999
  self.transcription_text = TranscriberTextRegion(
1935
2000
  self,
1936
2001
  self.item,
1937
2002
  tier_top_point,
1938
2003
  self.per_tier_range,
1939
2004
  self.selection_model,
1940
- alignment=alignment,
1941
2005
  dictionary_model=dictionary_model,
1942
2006
  search_term=search_term,
1943
2007
  speaker_id=utterance.speaker_id,
1944
2008
  plot_theme=self.plot_theme,
1945
2009
  )
2010
+ self.transcription_text.transcribeRequested.connect(self.transcribeRequested.emit)
1946
2011
  self.transcription_text.setParentItem(self)
1947
2012
  self.transcription_text.text_edit.gainedFocus.connect(self.select_self)
1948
2013
  self.transcription_text.text_edit.menuRequested.connect(
1949
2014
  self.generate_text_edit_menu
1950
2015
  )
2016
+ self.normalized_text.text_edit.textChanged.connect(
2017
+ self.update_transcription_highlight
2018
+ )
2019
+ self.transcription_text.text_edit.textChanged.connect(
2020
+ self.update_transcription_highlight
2021
+ )
2022
+ self.update_transcription_highlight()
1951
2023
  continue
1952
2024
  intervals = getattr(self.item, lookup)
1953
2025
 
@@ -1984,6 +2056,7 @@ class UtteranceRegion(MfaRegion):
1984
2056
  bottom_point=tier_bottom_point,
1985
2057
  word=False,
1986
2058
  )
2059
+ self.extra_tier_intervals[tier_name].append(interval_tier)
1987
2060
 
1988
2061
  elif "word_intervals" in lookup:
1989
2062
  interval_tier = IntervalTier(
@@ -1995,7 +2068,7 @@ class UtteranceRegion(MfaRegion):
1995
2068
  bottom_point=tier_bottom_point,
1996
2069
  word=True,
1997
2070
  )
1998
- interval_tier.highlightRequested.connect(self.text_item.highlighter.setSearchTerm)
2071
+ interval_tier.highlightRequested.connect(self.set_search_term)
1999
2072
  if self.transcription_text is not None:
2000
2073
  interval_tier.highlightRequested.connect(
2001
2074
  self.transcription_text.highlighter.setSearchTerm
@@ -2004,6 +2077,7 @@ class UtteranceRegion(MfaRegion):
2004
2077
  interval_tier.highlightRequested.connect(
2005
2078
  self.normalized_text.highlighter.setSearchTerm
2006
2079
  )
2080
+ self.extra_tier_intervals[tier_name].append(interval_tier)
2007
2081
 
2008
2082
  for interval in intervals:
2009
2083
  if "phone_intervals" in lookup or "word_intervals" in lookup:
@@ -2031,6 +2105,30 @@ class UtteranceRegion(MfaRegion):
2031
2105
  self.show()
2032
2106
  self.available_speakers = available_speakers
2033
2107
 
2108
+ def update_transcription_highlight(self):
2109
+ if self.item.normalized_text and self.item.transcription_text:
2110
+ alignment = pairwise2.align.globalms(
2111
+ self.item.normalized_text.lower().split(),
2112
+ self.item.transcription_text.lower().split(),
2113
+ 0,
2114
+ -2,
2115
+ -1,
2116
+ -1,
2117
+ gap_char=["-"],
2118
+ one_alignment_only=True,
2119
+ )[0]
2120
+ self.transcription_text.highlighter.set_alignment(alignment)
2121
+
2122
+ def set_search_term(self, term):
2123
+ self.text_item.highlighter.setSearchTerm(term)
2124
+ if self.transcription_text is not None:
2125
+ self.transcription_text.highlighter.setSearchTerm(term)
2126
+ if self.normalized_text is not None:
2127
+ self.normalized_text.highlighter.setSearchTerm(term)
2128
+ for tier in self.extra_tier_intervals.values():
2129
+ if tier and isinstance(tier[0], IntervalTier):
2130
+ tier[0].set_search_term(term)
2131
+
2034
2132
  @property
2035
2133
  def painter_path_cache(self):
2036
2134
  return self.parentItem().painter_path_cache
@@ -2043,22 +2141,6 @@ class UtteranceRegion(MfaRegion):
2043
2141
  if self.transcription_text is not None:
2044
2142
  self.transcription_text.text_item.update_times(begin, end)
2045
2143
 
2046
- def show(self):
2047
- for intervals in self.extra_tier_intervals.values():
2048
- for interval in intervals:
2049
- if (
2050
- self.selection_model.min_time
2051
- < interval.item.end
2052
- <= self.selection_model.max_time
2053
- or self.selection_model.min_time
2054
- <= interval.item.begin
2055
- < self.selection_model.max_time
2056
- ):
2057
- interval.show()
2058
- else:
2059
- interval.hide()
2060
- super().show()
2061
-
2062
2144
  def change_editing(self, editable: bool):
2063
2145
  self.lines[0].movable = editable
2064
2146
  self.lines[1].movable = editable
@@ -2730,7 +2812,7 @@ class SpeakerTier(pg.GraphicsObject):
2730
2812
  self.bottom_point = bottom_point
2731
2813
  self.annotation_range = self.top_point - self.bottom_point
2732
2814
  self.extra_tiers = {}
2733
- self.visible_utterances: dict[str, UtteranceRegion] = {}
2815
+ self.visible_utterances: dict[int, UtteranceRegion] = {}
2734
2816
  self.background_brush = pg.mkBrush(self.plot_theme.background_color)
2735
2817
  self.border = pg.mkPen(self.plot_theme.break_line_color)
2736
2818
  self.picture = QtGui.QPicture()
@@ -2792,6 +2874,7 @@ class SpeakerTier(pg.GraphicsObject):
2792
2874
  def setSearchTerm(self, term):
2793
2875
  for utt in self.visible_utterances.values():
2794
2876
  utt.text_item.highlighter.setSearchTerm(term)
2877
+ utt.set_search_term(term)
2795
2878
 
2796
2879
  def refreshTexts(self, utt_id, text):
2797
2880
  for reg in self.visible_utterances.values():
@@ -2807,22 +2890,26 @@ class SpeakerTier(pg.GraphicsObject):
2807
2890
  reg.scene().removeItem(reg)
2808
2891
  self.visible_utterances = {}
2809
2892
 
2810
- @profile
2811
- def refresh(self, *args):
2893
+ def refresh(self, *args, reset_bounds=False):
2812
2894
  self.hide()
2813
2895
  if self.selection_model.plot_min is None:
2814
2896
  return
2815
- # self.rect.setLeft(self.selection_model.plot_min)
2816
- # self.rect.setRight(self.selection_model.plot_max)
2817
- # self._generate_picture()
2818
2897
  self.has_visible_utterances = False
2819
2898
  self.has_selected_utterances = False
2820
2899
  self.speaker_label.setPos(self.selection_model.plot_min, self.top_point)
2821
2900
  cleanup_ids = []
2822
2901
  model_visible_utterances = self.selection_model.visible_utterances()
2823
- visible_ids = [x.id for x in model_visible_utterances]
2902
+ visible_ids = {x.id: x for x in model_visible_utterances}
2824
2903
  for reg in self.visible_utterances.values():
2825
2904
  reg.hide()
2905
+ if reset_bounds and reg.item.id in visible_ids:
2906
+ with QtCore.QSignalBlocker(reg):
2907
+ reg.item.begin, reg.item.end = (
2908
+ visible_ids[reg.item.id].begin,
2909
+ visible_ids[reg.item.id].end,
2910
+ )
2911
+ reg.setRegion((reg.item.begin, reg.item.end))
2912
+ reg.update_edit_fields()
2826
2913
 
2827
2914
  item_min, item_max = reg.getRegion()
2828
2915
  if (
@@ -2877,6 +2964,7 @@ class SpeakerTier(pg.GraphicsObject):
2877
2964
  reg.audioSelected.connect(self.selection_model.select_audio)
2878
2965
  reg.viewRequested.connect(self.selection_model.set_view_times)
2879
2966
  reg.textEdited.connect(self.update_utterance_text)
2967
+ reg.transcribeRequested.connect(self.corpus_model.transcribeRequested.emit)
2880
2968
  reg.selectRequested.connect(self.selection_model.update_select)
2881
2969
  self.visible_utterances[u.id] = reg
2882
2970
 
@@ -2918,7 +3006,7 @@ class SpeakerTier(pg.GraphicsObject):
2918
3006
  reg.setRegion([beg, self.selection_model.model().file.duration])
2919
3007
  return
2920
3008
  for r in self.visible_utterances.values():
2921
- if r == reg:
3009
+ if r.item.id == reg.item.id:
2922
3010
  continue
2923
3011
  other_begin, other_end = r.getRegion()
2924
3012
  if other_begin <= beg < other_end or beg <= other_begin < other_end < end: