nltkor-1.2.14-cp311-cp311-macosx_13_0_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nltkor/Kor_char.py +193 -0
- nltkor/__init__.py +16 -0
- nltkor/alignment/__init__.py +1315 -0
- nltkor/cider/__init__.py +2 -0
- nltkor/cider/cider.py +55 -0
- nltkor/cider/cider_scorer.py +207 -0
- nltkor/distance/__init__.py +441 -0
- nltkor/distance/wasserstein.py +126 -0
- nltkor/etc.py +22 -0
- nltkor/lazyimport.py +144 -0
- nltkor/make_requirement.py +11 -0
- nltkor/metrics/__init__.py +63 -0
- nltkor/metrics/bartscore.py +301 -0
- nltkor/metrics/bertscore.py +331 -0
- nltkor/metrics/bleu_tensor.py +20 -0
- nltkor/metrics/classical.py +847 -0
- nltkor/metrics/entment.py +24 -0
- nltkor/metrics/eval.py +517 -0
- nltkor/metrics/mauve.py +273 -0
- nltkor/metrics/mauve_utils.py +131 -0
- nltkor/misc/__init__.py +11 -0
- nltkor/misc/string2string_basic_functions.py +59 -0
- nltkor/misc/string2string_default_tokenizer.py +83 -0
- nltkor/misc/string2string_hash_functions.py +159 -0
- nltkor/misc/string2string_word_embeddings.py +503 -0
- nltkor/search/__init__.py +10 -0
- nltkor/search/classical.py +569 -0
- nltkor/search/faiss_search.py +787 -0
- nltkor/search/kobert_tokenizer.py +181 -0
- nltkor/sejong/__init__.py +3 -0
- nltkor/sejong/__pycache__/__init__.cpython-38.pyc +0 -0
- nltkor/sejong/__pycache__/__init__.cpython-39.pyc +0 -0
- nltkor/sejong/__pycache__/sejong_download.cpython-38.pyc +0 -0
- nltkor/sejong/__pycache__/sejong_download.cpython-39.pyc +0 -0
- nltkor/sejong/__pycache__/ssem.cpython-38.pyc +0 -0
- nltkor/sejong/__pycache__/ssem.cpython-39.pyc +0 -0
- nltkor/sejong/ch.py +12 -0
- nltkor/sejong/dict_semClassNum.txt +491 -0
- nltkor/sejong/layer.txt +630 -0
- nltkor/sejong/sejong_download.py +87 -0
- nltkor/sejong/ssem.py +684 -0
- nltkor/similarity/__init__.py +3 -0
- nltkor/similarity/bartscore____.py +337 -0
- nltkor/similarity/bertscore____.py +339 -0
- nltkor/similarity/classical.py +245 -0
- nltkor/similarity/cosine_similarity.py +175 -0
- nltkor/tag/__init__.py +71 -0
- nltkor/tag/__pycache__/__init__.cpython-38.pyc +0 -0
- nltkor/tag/__pycache__/__init__.cpython-39.pyc +0 -0
- nltkor/tag/__pycache__/espresso_tag.cpython-38.pyc +0 -0
- nltkor/tag/__pycache__/espresso_tag.cpython-39.pyc +0 -0
- nltkor/tag/espresso_tag.py +220 -0
- nltkor/tag/libs/__init__.py +10 -0
- nltkor/tag/libs/__pycache__/__init__.cpython-38.pyc +0 -0
- nltkor/tag/libs/__pycache__/__init__.cpython-39.pyc +0 -0
- nltkor/tag/libs/__pycache__/attributes.cpython-38.pyc +0 -0
- nltkor/tag/libs/__pycache__/attributes.cpython-39.pyc +0 -0
- nltkor/tag/libs/__pycache__/config.cpython-38.pyc +0 -0
- nltkor/tag/libs/__pycache__/config.cpython-39.pyc +0 -0
- nltkor/tag/libs/__pycache__/metadata.cpython-38.pyc +0 -0
- nltkor/tag/libs/__pycache__/metadata.cpython-39.pyc +0 -0
- nltkor/tag/libs/__pycache__/reader.cpython-38.pyc +0 -0
- nltkor/tag/libs/__pycache__/reader.cpython-39.pyc +0 -0
- nltkor/tag/libs/__pycache__/taggers.cpython-38.pyc +0 -0
- nltkor/tag/libs/__pycache__/taggers.cpython-39.pyc +0 -0
- nltkor/tag/libs/__pycache__/utils.cpython-38.pyc +0 -0
- nltkor/tag/libs/__pycache__/utils.cpython-39.pyc +0 -0
- nltkor/tag/libs/__pycache__/word_dictionary.cpython-38.pyc +0 -0
- nltkor/tag/libs/__pycache__/word_dictionary.cpython-39.pyc +0 -0
- nltkor/tag/libs/arguments.py +280 -0
- nltkor/tag/libs/attributes.py +231 -0
- nltkor/tag/libs/config.py +159 -0
- nltkor/tag/libs/metadata.py +129 -0
- nltkor/tag/libs/ner/__init__.py +2 -0
- nltkor/tag/libs/ner/__pycache__/__init__.cpython-38.pyc +0 -0
- nltkor/tag/libs/ner/__pycache__/__init__.cpython-39.pyc +0 -0
- nltkor/tag/libs/ner/__pycache__/ner_reader.cpython-38.pyc +0 -0
- nltkor/tag/libs/ner/__pycache__/ner_reader.cpython-39.pyc +0 -0
- nltkor/tag/libs/ner/macmorphoreader.py +7 -0
- nltkor/tag/libs/ner/ner_reader.py +92 -0
- nltkor/tag/libs/network.c +72325 -0
- nltkor/tag/libs/network.cpython-311-darwin.so +0 -0
- nltkor/tag/libs/network.pyx +878 -0
- nltkor/tag/libs/networkconv.pyx +1028 -0
- nltkor/tag/libs/networkdependencyconv.pyx +451 -0
- nltkor/tag/libs/parse/__init__.py +1 -0
- nltkor/tag/libs/parse/__pycache__/__init__.cpython-38.pyc +0 -0
- nltkor/tag/libs/parse/__pycache__/__init__.cpython-39.pyc +0 -0
- nltkor/tag/libs/parse/__pycache__/parse_reader.cpython-38.pyc +0 -0
- nltkor/tag/libs/parse/__pycache__/parse_reader.cpython-39.pyc +0 -0
- nltkor/tag/libs/parse/parse_reader.py +283 -0
- nltkor/tag/libs/pos/__init__.py +2 -0
- nltkor/tag/libs/pos/__pycache__/__init__.cpython-38.pyc +0 -0
- nltkor/tag/libs/pos/__pycache__/__init__.cpython-39.pyc +0 -0
- nltkor/tag/libs/pos/__pycache__/pos_reader.cpython-38.pyc +0 -0
- nltkor/tag/libs/pos/__pycache__/pos_reader.cpython-39.pyc +0 -0
- nltkor/tag/libs/pos/macmorphoreader.py +7 -0
- nltkor/tag/libs/pos/pos_reader.py +97 -0
- nltkor/tag/libs/reader.py +485 -0
- nltkor/tag/libs/srl/__init__.py +3 -0
- nltkor/tag/libs/srl/__pycache__/__init__.cpython-38.pyc +0 -0
- nltkor/tag/libs/srl/__pycache__/__init__.cpython-39.pyc +0 -0
- nltkor/tag/libs/srl/__pycache__/srl_reader.cpython-38.pyc +0 -0
- nltkor/tag/libs/srl/__pycache__/srl_reader.cpython-39.pyc +0 -0
- nltkor/tag/libs/srl/__pycache__/train_srl.cpython-38.pyc +0 -0
- nltkor/tag/libs/srl/__pycache__/train_srl.cpython-39.pyc +0 -0
- nltkor/tag/libs/srl/__srl_reader_.py +535 -0
- nltkor/tag/libs/srl/srl_reader.py +436 -0
- nltkor/tag/libs/srl/train_srl.py +87 -0
- nltkor/tag/libs/taggers.py +926 -0
- nltkor/tag/libs/utils.py +384 -0
- nltkor/tag/libs/word_dictionary.py +239 -0
- nltkor/tag/libs/wsd/__init__.py +2 -0
- nltkor/tag/libs/wsd/__pycache__/__init__.cpython-38.pyc +0 -0
- nltkor/tag/libs/wsd/__pycache__/__init__.cpython-39.pyc +0 -0
- nltkor/tag/libs/wsd/__pycache__/wsd_reader.cpython-38.pyc +0 -0
- nltkor/tag/libs/wsd/__pycache__/wsd_reader.cpython-39.pyc +0 -0
- nltkor/tag/libs/wsd/macmorphoreader.py +7 -0
- nltkor/tag/libs/wsd/wsd_reader.py +93 -0
- nltkor/tokenize/__init__.py +62 -0
- nltkor/tokenize/ko_tokenize.py +115 -0
- nltkor/trans.py +121 -0
- nltkor-1.2.14.dist-info/LICENSE.txt +1093 -0
- nltkor-1.2.14.dist-info/METADATA +41 -0
- nltkor-1.2.14.dist-info/RECORD +127 -0
- nltkor-1.2.14.dist-info/WHEEL +5 -0
- nltkor-1.2.14.dist-info/top_level.txt +1 -0
nltkor/search/faiss_search.py
@@ -0,0 +1,787 @@
"""
string2string search
src = https://github.com/stanfordnlp/string2string


MIT License

Copyright (c) 2023 Mirac Suzgun

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


"""


"""
This module contains a wrapper for the Faiss library by Facebook AI Research.
"""

from collections import Counter
from typing import List, Union, Optional, Dict, Any
import os
import copy
import logging
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from nltkor.make_requirement import make_requirement
try:
    import torch
    from transformers import AutoTokenizer, AutoModel, XLNetTokenizer
    import pandas as pd
    from datasets import Dataset
    # import protobuf
except ImportError:
    requirment = ['torch', 'transformers>=4.8.2', 'pandas', 'datasets', "protobuf", 'sentencepiece']
    file_path = make_requirement(requirment)
    raise Exception(f"""
    Need to install Libraries, please pip install below libraries
    \t pip install transformers>=4.8.2
    \t pip install torch
    \t pip install pandas
    \t pip install datasets
    \t pip install protobuf
    \t pip install sentencepiece
    Or, use pip install requirement.txt
    \t pip install -r {file_path}
    """)

# from nltk.search.kobert_tokenizer import KoBERTTokenizer


class FaissSearch:
    def __new__(cls,
                mode = None,
                model_name_or_path: str = 'klue/bert-base',
                tokenizer_name_or_path: str = 'klue/bert-base',
                device: str = 'cpu'
                ) -> None:
        if mode == 'sentence':
            return FaissSearch_SenEmbed(model_name_or_path)
        elif mode == 'word':
            return FaissSearch_WordEmbed(model_name_or_path)
        else:
            raise ValueError("choice 'sentence' or 'word'")

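The FaissSearch class above is only a thin factory: its __new__ hands back an instance of one of the two wrapper classes defined below, so a mode must always be given. A minimal sketch of the dispatch (illustrative only; it assumes the default 'klue/bert-base' checkpoint can be downloaded from the Hugging Face hub):

from nltkor.search.faiss_search import FaissSearch

sen_searcher = FaissSearch(mode='sentence')   # equivalent to FaissSearch_SenEmbed('klue/bert-base')
word_searcher = FaissSearch(mode='word')      # equivalent to FaissSearch_WordEmbed('klue/bert-base')
# FaissSearch() with no mode raises ValueError("choice 'sentence' or 'word'")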
# FAISS original library wrapper class
class FaissSearch_SenEmbed:
    def __init__(self,
                 model_name_or_path: str = 'klue/bert-base',
                 tokenizer_name_or_path: str = 'klue/bert-base',
                 device: str = 'cpu',
                 ) -> None:
        r"""
        This function initializes the wrapper for the FAISS library, which is used to perform semantic search.


        .. attention::

            * If you use this class, please make sure to cite the following paper:

                .. code-block:: latex

                    @article{johnson2019billion,
                        title={Billion-scale similarity search with {GPUs}},
                        author={Johnson, Jeff and Douze, Matthijs and J{\'e}gou, Herv{\'e}},
                        journal={IEEE Transactions on Big Data},
                        volume={7},
                        number={3},
                        pages={535--547},
                        year={2019},
                        publisher={IEEE}
                    }

            * The code is based on the following GitHub repository:
                https://github.com/facebookresearch/faiss

        Arguments:
            model_name_or_path (str, optional): The name or path of the model to use. Defaults to 'facebook/bart-large'.
            tokenizer_name_or_path (str, optional): The name or path of the tokenizer to use. Defaults to 'facebook/bart-large'.
            device (str, optional): The device to use. Defaults to 'cpu'.

        Returns:
            None
        """

        # Set the device
        self.device = device

        # If the tokenizer is not specified, use the model name or path
        if tokenizer_name_or_path is None:
            tokenizer_name_or_path = model_name_or_path

        # Load the tokenizer
        if tokenizer_name_or_path == 'skt/kobert-base-v1':
            # self.tokenizer = KoBERTTokenizer.from_pretrained(tokenizer_name_or_path)
            self.tokenizer = XLNetTokenizer.from_pretrained(tokenizer_name_or_path)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

        # Load the model
        self.model = AutoModel.from_pretrained(model_name_or_path).to(self.device)

        # Set the model to evaluation mode (since we do not need the gradients)
        self.model.eval()

        # Initialize the dataset
        self.dataset = None


    # Auxiliary function to get the last hidden state
    def get_last_hidden_state(self,
                              embeddings: torch.Tensor,
                              ) -> torch.Tensor:
        """
        This function returns the last hidden state (e.g., [CLS] token's) of the input embeddings.

        Arguments:
            embeddings (torch.Tensor): The input embeddings.

        Returns:
            torch.Tensor: The last hidden state.
        """

        # Get the last hidden state
        last_hidden_state = embeddings.last_hidden_state

        # Return the last hidden state
        return last_hidden_state[:, 0, :]


    # Auxiliary function to get the mean pooling
    def get_mean_pooling(self,
                         embeddings: torch.Tensor,
                         ) -> torch.Tensor:
        """
        This function returns the mean pooling of the input embeddings.

        Arguments:
            embeddings (torch.Tensor): The input embeddings.

        Returns:
            torch.Tensor: The mean pooling.
        """

        # Get the mean pooling
        mean_pooling = embeddings.last_hidden_state.mean(dim=1)

        # Return the mean pooling
        return mean_pooling


    # Get the embeddings
    def get_embeddings(self,
                       text: Union[str, List[str]],
                       embedding_type: str = 'last_hidden_state',
                       batch_size: int = 8,
                       num_workers: int = 4,
                       ) -> torch.Tensor:
        """
        This function returns the embeddings of the input text.

        Arguments:
            text (Union[str, List[str]]): The input text.
            embedding_type (str, optional): The type of embedding to use. Defaults to 'last_hidden_state'.
            batch_size (int, optional): The batch size to use. Defaults to 8.
            num_workers (int, optional): The number of workers to use. Defaults to 4.

        Returns:
            torch.Tensor: The embeddings.

        Raises:
            ValueError: If the embedding type is invalid.
        """

        # Check if the embedding type is valid
        if embedding_type not in ['last_hidden_state', 'mean_pooling']:
            raise ValueError(f'Invalid embedding type: {embedding_type}. Only "last_hidden_state" and "mean_pooling" are supported.')

        # Tokenize the input text
        encoded_text = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            return_tensors='pt',
        )

        # Move the input text to the device
        encoded_text = encoded_text.to(self.device)

        # encoded_inputs = {k: v.to(self.device) for k, v in encoded_inputs.items()}

        # Get the embeddings
        with torch.no_grad():
            embeddings = self.model(**encoded_text)

        # Get the proper embedding type
        if embedding_type == 'last_hidden_state':
            # Get the last hidden state
            embeddings = self.get_last_hidden_state(embeddings)
        elif embedding_type == 'mean_pooling':
            # Get the mean pooling
            embeddings = self.get_mean_pooling(embeddings)

        # Return the embeddings
        return embeddings

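Both pooling modes of get_embeddings reduce a sentence to a single vector: 'last_hidden_state' keeps only the first ([CLS]) token's hidden state, while 'mean_pooling' averages the hidden states of all tokens. A quick sketch of the difference (a hypothetical sentence; shapes assume the default 'klue/bert-base' model, whose hidden size is 768):

from nltkor.search.faiss_search import FaissSearch_SenEmbed

searcher = FaissSearch_SenEmbed(model_name_or_path='klue/bert-base')
cls_vec = searcher.get_embeddings('자연어 처리는 재미있다.', embedding_type='last_hidden_state')
mean_vec = searcher.get_embeddings('자연어 처리는 재미있다.', embedding_type='mean_pooling')
print(cls_vec.shape, mean_vec.shape)  # both torch.Size([1, 768]): one vector per input sentence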
    # Add FAISS index
    def add_faiss_index(self,
                        column_name: str = 'embeddings',
                        metric_type: Optional[int] = None,
                        batch_size: int = 8,
                        **kwargs,
                        ) -> None:
        """
        This function adds a FAISS index to the dataset.

        Arguments:
            column_name (str, optional): The name of the column containing the embeddings. Defaults to 'embeddings'.
            index_type (str, optional): The index type to use. Defaults to 'Flat'.
            metric_type (str, optional): The metric type to use. Defaults to 'L2'.

        Returns:
            None

        Raises:
            ValueError: If the dataset is not initialized.
        """

        # Check if the dataset is initialized
        if self.dataset is None:
            raise ValueError('The dataset is not initialized. Please initialize the dataset first.')

        print('Adding FAISS index...')
        self.dataset.add_faiss_index(
            column_name,
            # metric_type=metric_type,
            # device=self.device,
            # batch_size=batch_size,
            faiss_verbose=True,
            # **kwargs,
        )


    def save_faiss_index(self,
                         index_name: str,
                         file_path: str,
                         ) -> None:
        """
        This function saves the FAISS index to the specified file path.
        * This is a wrapper function for the `save_faiss_index` function in the `Dataset` class.

        Arguments:
            index_name (str): The name of the FAISS index (e.g., "embeddings")
            file_path (str): The file path to save the FAISS index.

        Returns:
            None

        Raises:
            ValueError: If the dataset is not initialized.
        """

        # Check if the dataset is initialized
        if self.dataset is None:
            raise ValueError('The dataset is not initialized. Please initialize the dataset first.')

        print('Saving FAISS index...')
        self.dataset.save_faiss_index(index_name=index_name, file=file_path)


    def load_faiss_index(self,
                         index_name: str,
                         file_path: str,
                         device: str = 'cpu',
                         ) -> None:
        """
        This function loads the FAISS index from the specified file path.
        * This is a wrapper function for the `load_faiss_index` function in the `Dataset` class.

        Arguments:
            index_name (str): The name of the FAISS index (e.g., "embeddings")
            file_path (str): The file path to load the FAISS index from.
            device (str, optional): The device to use ("cpu" or "cuda") (default: "cpu").

        Returns:
            None

        Raises:
            ValueError: If the dataset is not initialized.
        """

        # Check if the dataset is initialized
        if self.dataset is None:
            raise ValueError('The dataset is not initialized. Please initialize the dataset first.')

        print('Loading FAISS index...')
        self.dataset.load_faiss_index(index_name=index_name, file=file_path, device=device)


    # Initialize the corpus using a dictionary or pandas DataFrame or HuggingFace Datasets object
    def initialize_corpus(self,
                          corpus: Union[Dict[str, List[str]], pd.DataFrame, Dataset],
                          section: str = 'text',
                          index_column_name: str = 'embeddings',
                          embedding_type: str = 'last_hidden_state',
                          batch_size: Optional[int] = None,
                          num_workers: Optional[int] = None,
                          save_path: Optional[str] = None,
                          ) -> Dataset:
        """
        This function initializes a dataset using a dictionary or pandas DataFrame or HuggingFace Datasets object.

        Arguments:
            dataset_dict (Dict[str, List[str]]): The dataset dictionary.
            section (str): The section of the dataset to use whose embeddings will be used for semantic search (e.g., 'text', 'title', etc.) (default: 'text').
            index_column_name (str): The name of the column containing the embeddings (default: 'embeddings')
            embedding_type (str): The type of embedding to use (default: 'last_hidden_state').
            batch_size (int, optional): The batch size to use (default: 8).
            max_length (int, optional): The maximum length of the input sequences.
            num_workers (int, optional): The number of workers to use.
            save_path (Optional[str], optional): The path to save the dataset (default: None).

        Returns:
            Dataset: The dataset object (HuggingFace Datasets).

        Raises:
            ValueError: If the dataset is not a dictionary or pandas DataFrame or HuggingFace Datasets object.
        """

        # Create the dataset
        if isinstance(corpus, dict):
            self.dataset = Dataset.from_dict(corpus)
        elif isinstance(corpus, pd.DataFrame):
            self.dataset = Dataset.from_pandas(corpus)
        elif isinstance(corpus, Dataset):
            self.dataset = corpus
        else:
            raise ValueError('The dataset must be a dictionary or pandas DataFrame.')

        # Set the embedding_type
        self.embedding_type = embedding_type


        # Map the section of the dataset to the embeddings
        self.dataset = self.dataset.map(
            lambda x: {
                index_column_name: self.get_embeddings(x[section], embedding_type=self.embedding_type).detach().cpu().numpy()[0]
            },
            # batched=True,
            batch_size=batch_size,
            num_proc=num_workers,
        )

        # Save the dataset
        if save_path is not None:
            self.dataset.to_json(save_path)

        # Add FAISS index
        self.add_faiss_index(
            column_name=index_column_name,
        )

        # Return the dataset
        return self.dataset


    # Initialize the dataset using a JSON file
    def load_dataset_from_json(self,
                               json_path: str,
                               ) -> Dataset:
        """
        This function loads a dataset from a JSON file.

        Arguments:
            json_path (str): The path to the JSON file.

        Returns:
            Dataset: The dataset.
        """

        # Load the dataset
        self.dataset = Dataset.from_json(json_path)

        # Return the dataset
        return self.dataset


    # Search for the most similar elements in the dataset, given a query
    def search(self,
               query: str,
               k: int = 1,
               index_column_name: str = 'embeddings',
               ) -> pd.DataFrame:
        """
        This function searches for the most similar elements in the dataset, given a query.

        Arguments:
            query (str): The query.
            k (int, optional): The number of elements to return (default: 1).
            index_column_name (str, optional): The name of the column containing the embeddings (default: 'embeddings')

        Returns:
            pd.DataFrame: The most similar elements in the dataset (text, score, etc.), sorted by score.

        Remarks:
            The returned elements are dictionaries containing the text and the score.
        """

        # Get the embeddings of the query
        query_embeddings = self.get_embeddings([query], embedding_type=self.embedding_type).detach().cpu().numpy()

        # Search for the most similar elements in the dataset
        scores, similar_elts = self.dataset.get_nearest_examples(
            index_name=index_column_name,
            query=query_embeddings,
            k=k,
        )

        # Convert the results to a pandas DataFrame
        results_df = pd.DataFrame.from_dict(similar_elts)

        # Add the scores
        results_df['score'] = scores

        # Sort the results by score
        results_df.sort_values("score", ascending=True, inplace=True)

        # Return the most similar elements
        return results_df

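End to end, the sentence-level flow is: initialize_corpus embeds each row of the chosen text column, attaches a FAISS index over the 'embeddings' column, and search embeds the query the same way and returns the k nearest rows together with their distances (the smaller the score, the closer the match, hence the ascending sort). A minimal sketch, assuming network access for the default 'klue/bert-base' checkpoint:

from nltkor.search.faiss_search import FaissSearch

searcher = FaissSearch(mode='sentence')
corpus = {'text': [
    '오늘은 날씨가 매우 좋다.',
    '내일은 비가 올 것 같다.',
    '파이썬으로 자연어 처리를 공부한다.',
]}
searcher.initialize_corpus(corpus, section='text', embedding_type='mean_pooling')

# The index can also be persisted and restored with save_faiss_index / load_faiss_index.
results = searcher.search('요즘 날씨가 어때?', k=2)
print(results[['text', 'score']])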
# FAISS word embedding library wrapper class
class FaissSearch_WordEmbed(FaissSearch_SenEmbed):
    def __init__(self,
                 model_name_or_path: str = 'klue/bert-base',
                 tokenizer_name_or_path: str = 'klue/bert-base',
                 device: str = 'cpu',
                 ) -> None:
        r"""
        This function initializes the wrapper for the FAISS library, which is used to perform semantic search.


        .. attention::

            * If you use this class, please make sure to cite the following paper:

                .. code-block:: latex

                    @article{johnson2019billion,
                        title={Billion-scale similarity search with {GPUs}},
                        author={Johnson, Jeff and Douze, Matthijs and J{\'e}gou, Herv{\'e}},
                        journal={IEEE Transactions on Big Data},
                        volume={7},
                        number={3},
                        pages={535--547},
                        year={2019},
                        publisher={IEEE}
                    }

            * The code is based on the following GitHub repository:
                https://github.com/facebookresearch/faiss

        Arguments:
            model_name_or_path (str, optional): The name or path of the model to use. Defaults to 'facebook/bart-large'.
            tokenizer_name_or_path (str, optional): The name or path of the tokenizer to use. Defaults to 'facebook/bart-large'.
            device (str, optional): The device to use. Defaults to 'cpu'.

        Returns:
            None
        """

        # Set the device
        self.device = device

        # If the tokenizer is not specified, use the model name or path
        if tokenizer_name_or_path is None:
            tokenizer_name_or_path = model_name_or_path

        # Load the tokenizer
        if tokenizer_name_or_path == 'skt/kobert-base-v1':
            # self.tokenizer = KoBERTTokenizer.from_pretrained(tokenizer_name_or_path)
            self.tokenizer = XLNetTokenizer.from_pretrained(tokenizer_name_or_path)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

        # Load the model
        self.model = AutoModel.from_pretrained(model_name_or_path).to(self.device)

        # Set the model to evaluation mode (since we do not need the gradients)
        self.model.eval()

        # Initialize the dataset
        self.dataset = None


    # Get the embeddings (new code)
    def get_doc_embeddings(self,
                           #text: Union[str, List[str]],
                           text=None,
                           embedding_type: str = 'last_hidden_state',
                           batch_size: int = 8,
                           num_workers: int = 4,
                           ) -> torch.Tensor:
        """
        This function returns the embeddings of the input text.

        Arguments:
            text (Union[str, List[str]]): The input text.
            embedding_type (str, optional): The type of embedding to use. Defaults to 'last_hidden_state'.
            batch_size (int, optional): The batch size to use. Defaults to 8.
            num_workers (int, optional): The number of workers to use. Defaults to 4.

        Returns:
            torch.Tensor: The embeddings.

        Raises:
            ValueError: If the embedding type is invalid.
        """

        # Check if the embedding type is valid
        if embedding_type not in ['last_hidden_state', 'mean_pooling']:
            raise ValueError(f'Invalid embedding type: {embedding_type}. Only "last_hidden_state" and "mean_pooling" are supported.')

        ids_dict = {}
        # Tokenize the input text
        for sentence in text['text']:
            encoded_text = self.tokenizer(
                sentence,
                padding=False,
                truncation=True,
                return_tensors='pt',
                add_special_tokens=False,
            )

            # Move the input text to the device
            encoded_text = encoded_text.to(self.device)

            token_ids_list = encoded_text['input_ids'].tolist()
            token_ids_list = token_ids_list[0]
            for ids in token_ids_list:
                if ids not in ids_dict.keys():
                    ids_dict[ids] = [sentence]
                else:
                    if text not in ids_dict[ids]:
                        ids_dict[ids].append(sentence)

        # Get the embeddings
        embedding_dict = {}
        self.model.eval()
        for key, value in ids_dict.items():
            embed = self.model(torch.tensor([[key]]), output_hidden_states=True).hidden_states[-1][:,0,:].detach()
            embedding_dict[embed] = value

        # Return the embeddings
        return embedding_dict

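get_doc_embeddings therefore builds a token-level inverted map over the corpus: every word-piece id that occurs in any sentence is embedded once (a single-token forward pass, taking the last hidden state), and that embedding becomes a key whose value is the list of sentences containing the token. A small sketch of the resulting structure (hypothetical sentences; assumes the default 'klue/bert-base' checkpoint):

from nltkor.search.faiss_search import FaissSearch

word_searcher = FaissSearch(mode='word')
docs = {'text': ['서울은 한국의 수도이다.', '한국의 가을 하늘은 맑다.']}
embedding_dict = word_searcher.get_doc_embeddings(text=docs)

# Keys are (1, hidden_size) tensors, one per distinct token id; values are the
# sentences in which that token appears.
for embed, sentences in list(embedding_dict.items())[:3]:
    print(embed.shape, sentences)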
    # Get the embeddings (new code)
    def get_query_embeddings(self,
                             text: Union[str, List[str]],
                             embedding_type: str = 'last_hidden_state',
                             batch_size: int = 8,
                             num_workers: int = 4,
                             ) -> torch.Tensor:
        """
        This function returns the embeddings of the input text.

        Arguments:
            text (Union[str, List[str]]): The input text.
            embedding_type (str, optional): The type of embedding to use. Defaults to 'last_hidden_state'.
            batch_size (int, optional): The batch size to use. Defaults to 8.
            num_workers (int, optional): The number of workers to use. Defaults to 4.

        Returns:
            torch.Tensor: The embeddings.

        Raises:
            ValueError: If the embedding type is invalid.
        """

        # Check if the embedding type is valid
        if embedding_type not in ['last_hidden_state', 'mean_pooling']:
            raise ValueError(f'Invalid embedding type: {embedding_type}. Only "last_hidden_state" and "mean_pooling" are supported.')

        # Tokenize the input text
        encoded_text = self.tokenizer(
            text,
            padding=False,
            truncation=True,
            return_tensors='pt',
            add_special_tokens=False,
        )

        # Move the input text to the device
        encoded_text = encoded_text.to(self.device)

        token_ids_list = encoded_text['input_ids'].tolist()
        token_ids_list = token_ids_list[0]
        tensor_list = [torch.tensor([[value]]) for value in token_ids_list]

        # Get the embeddings
        embeds = []
        self.model.eval()
        for index, tensor in enumerate(tensor_list):
            embed = self.model(tensor, output_hidden_states=True).hidden_states[-1][:,0,:].detach().cpu().numpy()
            embeds.append(embed)

        # Return the embeddings
        return embeds

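On the query side the same idea applies: the query string is split into word-piece tokens and each token is embedded separately, so get_query_embeddings returns a plain Python list with one (1, hidden_size) NumPy array per token rather than a single tensor. A quick sketch, assuming the default tokenizer:

from nltkor.search.faiss_search import FaissSearch

word_searcher = FaissSearch(mode='word')
token_vecs = word_searcher.get_query_embeddings(['한국의 수도'])
print(len(token_vecs), token_vecs[0].shape)  # one (1, 768) array per word-piece token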
    # Initialize the corpus using a dictionary or pandas DataFrame or HuggingFace Datasets object
    def initialize_corpus(self,
                          corpus: Union[Dict[str, List[str]], pd.DataFrame, Dataset],
                          section: str = 'text',
                          index_column_name: str = 'embeddings',
                          embedding_type: str = 'last_hidden_state',
                          batch_size: Optional[int] = None,
                          num_workers: Optional[int] = None,
                          save_path: Optional[str] = None,
                          ) -> Dataset:
        """
        This function initializes a dataset using a dictionary or pandas DataFrame or HuggingFace Datasets object.

        Arguments:
            dataset_dict (Dict[str, List[str]]): The dataset dictionary.
            section (str): The section of the dataset to use whose embeddings will be used for semantic search (e.g., 'text', 'title', etc.) (default: 'text').
            index_column_name (str): The name of the column containing the embeddings (default: 'embeddings')
            embedding_type (str): The type of embedding to use (default: 'last_hidden_state').
            batch_size (int, optional): The batch size to use (default: 8).
            max_length (int, optional): The maximum length of the input sequences.
            num_workers (int, optional): The number of workers to use.
            save_path (Optional[str], optional): The path to save the dataset (default: None).

        Returns:
            Dataset: The dataset object (HuggingFace Datasets).

        Raises:
            ValueError: If the dataset is not a dictionary or pandas DataFrame or HuggingFace Datasets object.
        """

        # corpus = { 'text': [...] } -> form_dict

        # Set the embedding_type
        self.embedding_type = embedding_type

        # get embedding dict
        embedding_dict = self.get_doc_embeddings(text=corpus, embedding_type=self.embedding_type)

        data = {
            'text': embedding_dict.values(),
            'embeddings': []
        }

        for embed in embedding_dict.keys():
            embed_list = embed.tolist()
            data['embeddings'].append(embed_list[0])


        if isinstance(data, dict):
            self.dataset = Dataset.from_dict(data)
        elif isinstance(data, pd.DataFrame):
            self.dataset = Dataset.from_pandas(data)
        elif isinstance(data, Dataset):
            self.dataset = corpus
        else:
            raise ValueError('The dataset must be a dictionary or pandas DataFrame.')

        # Save the dataset
        if save_path is not None:
            self.dataset.to_json(save_path)

        # Add FAISS index
        self.add_faiss_index(
            column_name=index_column_name,
        )

        # Return the dataset
        return self.dataset


    # Search for the most similar elements in the dataset, given a query
    def search(self,
               query: str,
               k: int = 1,
               index_column_name: str = 'embeddings',
               ) -> pd.DataFrame:
        """
        This function searches for the most similar elements in the dataset, given a query.

        Arguments:
            query (str): The query.
            k (int, optional): The number of elements to return (default: 1).
            index_column_name (str, optional): The name of the column containing the embeddings (default: 'embeddings')

        Returns:
            pd.DataFrame: The most similar elements in the dataset (text, score, etc.), sorted by score.

        Remarks:
            The returned elements are dictionaries containing the text and the score.
        """

        # Get the embeddings of the query
        query_embeddings = self.get_query_embeddings([query], embedding_type=self.embedding_type)

        # Compare each query embedding against the values in self.dataset['embeddings']
        scores = []
        similar_elts = []
        for query in query_embeddings:
            # Search for the most similar elements in the dataset
            score, similar_elt = self.dataset.get_nearest_examples(
                index_name=index_column_name,
                query=query,
                k=k,
            )
            scores.append(score)
            similar_elts.append(similar_elt)

        text_list = []
        for item in similar_elts:
            for text in item['text']:
                text_list.append(text)

        flat_list = [sentence for sublist in text_list for sentence in sublist]
        count = Counter(flat_list)
        count = dict(count.most_common(5))

        sorted_dict = dict(sorted(count.items(), key=lambda x: x[1], reverse=True))

        # Convert the results to a pandas DataFrame
        results_df = pd.DataFrame({'text': sorted_dict.keys(), 'freq': sorted_dict.values()})


        # Return the most similar elements
        return results_df
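Taken together, the word-level flow indexes one vector per distinct corpus token and, at query time, runs one nearest-neighbour lookup per query token; the sentences retrieved across those lookups are then ranked by how often they were hit (the 'freq' column), with at most five rows returned. A minimal end-to-end sketch, assuming the default 'klue/bert-base' checkpoint:

from nltkor.search.faiss_search import FaissSearch

word_searcher = FaissSearch(mode='word')
corpus = {'text': [
    '한국의 수도는 서울이다.',
    '프랑스의 수도는 파리이다.',
    '서울에는 한강이 흐른다.',
]}
word_searcher.initialize_corpus(corpus, embedding_type='last_hidden_state')

results = word_searcher.search('서울의 강', k=3)
print(results)  # columns: 'text' (candidate sentence) and 'freq' (number of query tokens that retrieved it)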