deeplotx 0.9.3__py3-none-any.whl → 0.9.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
deeplotx/ner/bert_ner.py CHANGED
@@ -11,6 +11,8 @@ from deeplotx.ner.base_ner import BaseNER
  from deeplotx.ner.named_entity import NamedEntity, NamedPerson

  CACHE_PATH = os.path.join(__ROOT__, '.cache')
+ NEW_LINE, BLANK = '\n', ' '
+ DEFAULT_LENGTH_THRESHOLD = 384
  DEFAULT_BERT_NER = 'Davlan/xlm-roberta-base-ner-hrl'
  N2G_MODEL: list[Name2Gender] = []
  logger = logging.getLogger('deeplotx.ner')
@@ -44,11 +46,11 @@ class BertNER(BaseNER):
  trust_remote_code=True, local_files_only=True).to(self.device)
  self.embed_dim = self.encoder.config.max_position_embeddings
  self._ner_pipeline = pipeline(task='ner', model=self.encoder, tokenizer=self.tokenizer, trust_remote_code=True)
- logger.debug(f'{BaseNER.__name__} initialized on device: {self.device}.')
+ logger.debug(f'{BertNER.__name__} initialized on device: {self.device}.')

  def _fast_extract(self, s: str, with_gender: bool = True, prob_threshold: float = .0) -> list[NamedEntity]:
  assert prob_threshold <= 1., f'prob_threshold ({prob_threshold}) cannot be larger than 1.'
- s = f' {s} '
+ s = f' {s.replace(NEW_LINE, BLANK)} '
  raw_entities = self._ner_pipeline(s)
  entities = []
  for ent in raw_entities:
@@ -69,8 +71,21 @@ class BertNER(BaseNER):
  break
  for ent in entities:
  ent[0] = ent[0].strip()
+ # stripping
+ while not ent[0][0].isalpha():
+ if len(ent[0]) < 2:
+ break
+ if not ent[0][0].isnumeric():
+ ent[0] = ent[0][1:]
+ while not ent[0][-1].isalpha():
+ if len(ent[0]) < 2:
+ break
+ if not ent[0][-1].isnumeric():
+ ent[0] = ent[0][:-1]
  if ent[1].upper().startswith('B'):
  ent[1] = ent[1].upper()[1:].strip('-')
+ if len(entities) > 0:
+ logger.debug(f'Entities: {[_[0] for _ in entities]}, extracted from: "{s.strip()}".')
  entities = [NamedEntity(*_) for _ in entities if _[2] >= prob_threshold]
  if not with_gender:
  return entities
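The new stripping loop trims stray non-alphanumeric characters (sub-word markers, punctuation) from both ends of each extracted entity before the label prefix is normalized. A rough standalone sketch of the same trimming rule; the helper name is illustrative and not part of deeplotx:

    def trim_entity_text(text: str) -> str:
        # Drop leading characters that are neither letters nor digits.
        while len(text) >= 2 and not text[0].isalpha() and not text[0].isnumeric():
            text = text[1:]
        # Drop trailing characters that are neither letters nor digits.
        while len(text) >= 2 and not text[-1].isalpha() and not text[-1].isnumeric():
            text = text[:-1]
        return text

    print(trim_entity_text('##Obama).'))  # -> 'Obama'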
@@ -88,14 +103,19 @@ class BertNER(BaseNER):
  return entities

  def _slow_extract(self, s: str, with_gender: bool = True, prob_threshold: float = .0, deduplicate: bool = True) -> list[NamedEntity]:
- _entities = self._fast_extract(s, with_gender=with_gender, prob_threshold=prob_threshold) if len(s) < 512 else []
- if len(s) >= 512:
- window_size: int = 512
- offset = window_size // 6
- for _offset in [- offset, offset]:
- _window_size = window_size + _offset
- for i in range(0, len(s) + _window_size, _window_size):
- _entities.extend(self._fast_extract(s[i: i + _window_size], with_gender=with_gender, prob_threshold=prob_threshold))
+ _length_threshold = DEFAULT_LENGTH_THRESHOLD
+ _s_seq = self.tokenizer.encode(s, add_special_tokens=False)
+ _entities = self._fast_extract(self.tokenizer.decode(_s_seq, skip_special_tokens=True),
+ with_gender=with_gender,
+ prob_threshold=prob_threshold) if len(_s_seq) < _length_threshold else []
+ # sliding window extracting
+ if len(_s_seq) >= _length_threshold:
+ _window_size = _length_threshold
+ _stride = _length_threshold // 4
+ for i in range(0, len(_s_seq) + _stride, _stride):
+ _window_text = self.tokenizer.decode(_s_seq[i: i + _window_size], skip_special_tokens=True)
+ _entities.extend(self._fast_extract(_window_text, with_gender=with_gender, prob_threshold=prob_threshold))
+ # entity combination
  _tmp_entities = sorted(_entities, key=lambda x: len(x.text), reverse=True)
  for _ent_i in _tmp_entities:
  for _ent_j in _entities:
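In 0.9.5, _slow_extract drops the old character-based windowing (two passes over 512-character windows with offset sizes) in favour of token-based windows: the text is encoded once, and 384-token slices (DEFAULT_LENGTH_THRESHOLD) are decoded back to text and fed to _fast_extract with a stride of a quarter window, so an entity cut at one boundary reappears whole in an overlapping window. A minimal sketch of that windowing, assuming a Hugging Face tokenizer loaded from the default checkpoint; the function name is illustrative, not the package's API:

    from transformers import AutoTokenizer

    def sliding_token_windows(text: str, window: int = 384, stride: int = 96):
        # A stride of 96 mirrors _length_threshold // 4 in _slow_extract.
        tokenizer = AutoTokenizer.from_pretrained('Davlan/xlm-roberta-base-ner-hrl')
        ids = tokenizer.encode(text, add_special_tokens=False)
        if len(ids) < window:
            # Short inputs are handled in a single pass.
            yield tokenizer.decode(ids, skip_special_tokens=True)
            return
        # Overlapping windows over the token sequence, each decoded back to plain text.
        for start in range(0, len(ids) + stride, stride):
            yield tokenizer.decode(ids[start:start + window], skip_special_tokens=True)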
@@ -103,6 +123,7 @@ class BertNER(BaseNER):
  and len(_ent_j.text) != len(_ent_i.text)
  and _ent_j in _tmp_entities):
  _tmp_entities.remove(_ent_j)
+ # entity cleaning
  while True:
  for _ent in _tmp_entities:
  if _ent.text not in s or len(_ent.text) < 2:
@@ -115,7 +136,8 @@ class BertNER(BaseNER):
  if not _continue:
  break
  if not deduplicate:
- return _tmp_entities
+ return sorted(_tmp_entities, key=lambda _: _.text[0], reverse=False)
+ # entity deduplication
  _fin_entities = dict()
  texts = set([text.text for text in _tmp_entities])
  for text in texts:
@@ -126,10 +148,11 @@ class BertNER(BaseNER):
  else:
  if _ent.base_probability > _fin_entities[_ent.text].base_probability:
  _fin_entities[_ent.text] = _ent
- return [v for k, v in _fin_entities.items()]
+ return sorted([v for k, v in _fin_entities.items()], key=lambda _: _.text[0], reverse=False)

  def __call__(self, s: str, with_gender: bool = True, prob_threshold: float = .0, fast_mode: bool = False, *args, **kwargs):
  if fast_mode:
  return self._fast_extract(s=s, with_gender=with_gender, prob_threshold=prob_threshold)
  else:
- return self._slow_extract(s=s, with_gender=with_gender, prob_threshold=prob_threshold, deduplicate=True)
+ return self._slow_extract(s=s, with_gender=with_gender, prob_threshold=prob_threshold,
+ deduplicate=kwargs.get('deduplicate', True))
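With __call__ now forwarding kwargs, callers can disable deduplication in slow mode without calling _slow_extract directly. A possible usage, assuming BertNER can still be constructed with its default arguments (the constructor is not shown in this diff); only the text and base_probability attributes used below appear in the diff:

    from deeplotx.ner.bert_ner import BertNER

    ner = BertNER()  # assumed default construction
    text = 'Barack Obama met Angela Merkel in Berlin.\nObama later flew home.'

    # Keep every occurrence instead of collapsing duplicates;
    # before 0.9.5 this flag was hard-coded to True.
    entities = ner(text, with_gender=False, deduplicate=False)
    for entity in entities:
        print(entity.text, entity.base_probability)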
deeplotx-0.9.3.dist-info/METADATA → deeplotx-0.9.5.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: deeplotx
- Version: 0.9.3
+ Version: 0.9.5
  Summary: An out-of-the-box long-text NLP framework.
  Requires-Python: >=3.10
  Description-Content-Type: text/markdown
deeplotx-0.9.3.dist-info/RECORD → deeplotx-0.9.5.dist-info/RECORD RENAMED
@@ -5,7 +5,7 @@ deeplotx/encoder/long_text_encoder.py,sha256=4oRa9FqfGNZ8-gq14UKuhDkZC0A1Xi-wKmb
  deeplotx/encoder/longformer_encoder.py,sha256=7Lm65AUD3qwbrzrhJ3dPZkyHeNRSapga3f-5QJCxV5A,3538
  deeplotx/ner/__init__.py,sha256=Rss1pup9HzHZCG8U9ub8niWa9zRjWCy3Z7zg378KZQg,114
  deeplotx/ner/base_ner.py,sha256=pZTl50OrHH_FJm4rKp9iuixeOE6FX_AzgDXD32aXsN0,204
- deeplotx/ner/bert_ner.py,sha256=I8yFsarsLEQv0vcnNU2JIc0-LuPJcxaO-mLhDFCh1PI,7704
+ deeplotx/ner/bert_ner.py,sha256=Jz76QrLy5MabkwWRdOq20lIDvxYtcvoU5Nf0Vnz4l4g,8924
  deeplotx/ner/named_entity.py,sha256=c6XufIwH6yloJ-ccUjagf4mBl1XbbYDT8xyEJJ_-ZNs,269
  deeplotx/ner/n2g/__init__.py,sha256=b6fOWJVLaOCtoz8Qlp8NWQbL5lUSbn6H3-8fnVNIPi0,3940
  deeplotx/nn/__init__.py,sha256=YILwbxb-NHdiJjfOwBKH8F7PuZSDZSrGpTznPDucTro,710
@@ -33,8 +33,8 @@ deeplotx/trainer/text_binary_classification_trainer.py,sha256=TFxOX8rWU_zKliI9zm
  deeplotx/util/__init__.py,sha256=5CH4MTeSgsmCe3LPMfvKoSBpwh6jDSBuHVElJvzQzgs,90
  deeplotx/util/hash.py,sha256=qbNU3RLBWGQYFVte9WZBAkZ1BkdjCXiKLDaKPN54KFk,662
  deeplotx/util/read_file.py,sha256=ptzouvEQeeW8KU5BrWNJlXw-vFXVrpS9SkAUxsu6A8A,612
- deeplotx-0.9.3.dist-info/licenses/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
- deeplotx-0.9.3.dist-info/METADATA,sha256=Fg0KzWIxFcMtuTfmuQ9BBJDFXjNTWtl9l3Cuuc1sX3I,13472
- deeplotx-0.9.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- deeplotx-0.9.3.dist-info/top_level.txt,sha256=hKg4pVDXZ-WWxkRfJFczRIll1Sv7VyfKCmzHLXbuh1U,9
- deeplotx-0.9.3.dist-info/RECORD,,
+ deeplotx-0.9.5.dist-info/licenses/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
+ deeplotx-0.9.5.dist-info/METADATA,sha256=IoALoZ2i1T1AbNBjjAML1MXDABrAj8NTiKAMeHwJQpQ,13472
+ deeplotx-0.9.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ deeplotx-0.9.5.dist-info/top_level.txt,sha256=hKg4pVDXZ-WWxkRfJFczRIll1Sv7VyfKCmzHLXbuh1U,9
+ deeplotx-0.9.5.dist-info/RECORD,,