dstklib-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dstk/__init__.py +12 -0
- dstk/collocations.py +121 -0
- dstk/count_models.py +112 -0
- dstk/geometric_distance.py +107 -0
- dstk/lib_types/__init__.py +9 -0
- dstk/lib_types/dstk_types.py +26 -0
- dstk/lib_types/fasttext_types.py +1 -0
- dstk/lib_types/gensim_types.py +1 -0
- dstk/lib_types/matplotlib_types.py +4 -0
- dstk/lib_types/nltk_types.py +1 -0
- dstk/lib_types/numpy_types.py +2 -0
- dstk/lib_types/pandas_types.py +1 -0
- dstk/lib_types/sklearn_types.py +1 -0
- dstk/lib_types/spacy_types.py +6 -0
- dstk/matrix_base.py +113 -0
- dstk/pipeline_tools.py +27 -0
- dstk/pipelines.py +114 -0
- dstk/plot_embeddings.py +240 -0
- dstk/predict_models.py +189 -0
- dstk/text_matrix_builder.py +87 -0
- dstk/text_processor.py +450 -0
- dstk/weight_matrix.py +71 -0
- dstk/workflow_tools.py +257 -0
- dstklib-1.0.0.dist-info/LICENSE +674 -0
- dstklib-1.0.0.dist-info/METADATA +360 -0
- dstklib-1.0.0.dist-info/RECORD +28 -0
- dstklib-1.0.0.dist-info/WHEEL +5 -0
- dstklib-1.0.0.dist-info/top_level.txt +1 -0
dstk/text_processor.py
ADDED
@@ -0,0 +1,450 @@
from pathlib import Path
import spacy
from typing import Callable, cast, Any, TypeGuard
from .workflow_tools import workflow, requires, WorkflowManager, accepts_generic

from .lib_types.spacy_types import *
from .lib_types.dstk_types import TokenIterator, POSIterator, SentenceIterator, TextIterator, Sentences, Sentence, POSTags

STAGES = [
    "start", # Before any processing
    "model", # Text to nlp Doc
    "token_manipulation", # Manipulation of spaCy tokens
    "text_processing", # After the tokens have been transformed to text
    "end" # End of the workflow. After this stage the user must call processed_text to continue with the analysis
]

UNITS = [
    "sentences", # Process the text by sentences
    "words", # Process the text by words
]

def is_pos_tags(tokens: Any) -> TypeGuard[POSTags]:
    """
    If tokens is of type POSTags, returns True. Else, returns False.

    :param tokens: A list of tokens to check its type.
    """

    if not isinstance(tokens, list) or not tokens:
        return False
    return all(
        isinstance(item, tuple) and
        len(item) == 2 and
        (isinstance(item[0], Token) or
        isinstance(item[0], str)) and
        isinstance(item[1], str)
        for item in tokens
    )

def is_sentence(tokens: Any) -> TypeGuard[Sentences]:
    """
    If tokens is of type Sentences, returns True. Else, returns False.

    :param tokens: A list of tokens to check its type.
    """

    return (
        isinstance(tokens, list) and
        all(
            isinstance(item, Span) or
            (
                isinstance(item, list) and
                (
                    all(isinstance(token, Token) for token in item) or
                    all(isinstance(token, str) for token in item)
                )
            ) or is_pos_tags(item)
            for item in tokens
        )
    )

def accepts_sentences(accepts: bool = True, custom_error_message: str = "", intercept: bool = True) -> Callable:
    """
    Decorator that allows a method to accept sentence-level input (i.e., a list of sentences).

    This decorator uses a shared generic handler to determine whether the `tokens` argument contains sentence-level input.
    If both `accepts` and `intercept` are True, it processes each sentence individually by invoking the decorated method once per sentence,
    then aggregates the results into a single list.

    If `accepts` is True but `intercept` is False, the method is called normally with the full input.
    If the input is sentence-level but `accepts` is False, a ValueError is raised.

    :param accepts: Whether the method should accept sentence-level inputs. Default is True.
    :param custom_error_message: An optional message to include in raised errors.
    :param intercept: If True, intercepts sentence input and splits it into per-sentence processing. If False, the method receives the input unchanged.

    :raises ValueError: If sentence input is provided but the method does not accept it.

    :return: A wrapped method that optionally processes sentence input at the sentence level.
    """

    def intercept_sentence(self, input_value: Sentences, method: Callable, *args, **kwargs) -> Sentences:
        sentences: Sentences = getattr(self, f"_{self._current_stage}") if self._flow else input_value
        processed_sentences: Sentences = []

        for sentence in sentences:
            result: Sentence = cast(Sentence, method(self, *args, tokens=sentence, **kwargs))
            processed_sentences.append(result)

        filtered_sentences = cast(Sentences, [sentence for sentence in processed_sentences if sentence])

        return filtered_sentences

    return accepts_generic(
        type_checker=is_sentence,
        input_arg="tokens",
        accepts=accepts,
        intercept=intercept,
        interceptor=intercept_sentence,
        input_type=Sentences,
        custom_error_message=custom_error_message
    )

def accepts_tags(accepts: bool = True, custom_error_message: str = "", intercept: bool = True) -> Callable:
    """
    Decorator that allows a method to accept POS-tagged input (i.e., a list of (token, tag) tuples).

    This decorator uses a shared generic handler to detect tagged input. If both `accepts` and `intercept` are True, it extracts the tokens (stripping the tags), passes them to the method, and reattaches the tags to the result.

    If the result length differs from the original input, it attempts to realign the tags by matching token text.
    If `intercept` is False, the method is called as-is with the original tagged input.
    If tagged input is provided but `accepts` is False, a ValueError is raised.

    :param accepts: Whether the method supports POS-tagged inputs. Default is True.
    :param custom_error_message: Optional custom error message for input validation failure.
    :param intercept: If True, extracts tokens, passes them to the method, and restores tags after processing.

    :raises ValueError: If input is tagged but cannot be processed or restored properly.

    :return: A wrapped method that conditionally handles tagged input transformation.
    """

    def intercept_tags(self, input_value: POSTags, method: Callable, *args, **kwargs) -> POSTags:
        try:
            raw_word_tokens, raw_pos = zip(*input_value)
            word_tokens = cast(list[Token | str], raw_word_tokens)
            pos = cast(list[str], raw_pos)

            result: list[Token | str] = cast(list[Token | str], method(self, *args, tokens=list(word_tokens), **kwargs))

            if len(result) == len(input_value):
                return list(zip(result, pos))
            else:
                original_pos_map = dict(zip([word_token.text.lower() if isinstance(word_token, Token) else word_token.lower() for word_token in word_tokens], pos))

                result_with_pos: POSTags = [(word, original_pos_map[word.text.lower() if isinstance(word, Token) else word.lower()]) for word in result]

                return result_with_pos if self._current_stage == "text_processing" else [(token, token.pos_) for token in cast(list[Token], result)]
        except Exception:
            raise ValueError(f"Method {method.__name__} does not accept a tagged text with the structure (Token, POS) as input.")

    return accepts_generic(
        type_checker=is_pos_tags,
        input_arg="tokens",
        accepts=accepts,
        intercept=intercept,
        interceptor=intercept_tags,
        input_type=POSTags,
        custom_error_message=custom_error_message
    )

class TextProcessor(WorkflowManager):
    """
    Provides a set of methods for text processing, such as raw tokenizers (which keep punctuation and/or stop words), tokenizers (which remove punctuation and stop words and can lemmatize), and others (vocabulary extractor, sentence extractor, pos_tagger, etc.).

    :param text: The text to be processed. Defaults to None.
    :param model: Either the name of an installed spaCy language model (e.g., 'en_core_web_sm') or an already loaded spaCy model object. Supplied through set_model.
    """

    _start: Doc
    _end: str

    def __init__(self, text: str | None = None):
        """
        Initializes TextProcessor with given attributes.
        """
        super().__init__()

        # Stages

        self._model: Doc
        self._token_manipulation: TokenIterator | SentenceIterator
        self._text_processing: TextIterator

        self._set_workflow(input_arg=text)

@requires(stages=["start"])
|
178
|
+
@workflow(input_arg="text", input_process="_start", output_process="_model", next_stage="model")
|
179
|
+
def set_model(self, text: str, model: str | Language) -> Doc:
|
180
|
+
"""
|
181
|
+
Takes a text and analyzes it using a language model. It returns a processed version of the text that includes helpful information like the words, their meanings, and how they relate to each other.
|
182
|
+
|
183
|
+
:param text: The text to be processed.
|
184
|
+
:param model: The name of the model to be used or its instance.
|
185
|
+
"""
|
186
|
+
|
187
|
+
nlp: Language
|
188
|
+
|
189
|
+
if isinstance(model, str):
|
190
|
+
nlp = spacy.load(model)
|
191
|
+
else:
|
192
|
+
nlp = model
|
193
|
+
|
194
|
+
return nlp(text)
|
195
|
+
|
196
|
+
@requires(stages=["model"])
|
197
|
+
@workflow(input_arg="tokens", input_process="_model", output_process="_token_manipulation", next_stage="token_manipulation", set_unit="words")
|
198
|
+
def get_tokens(self, *, tokens: Doc) -> list[Token]:
|
199
|
+
"""
|
200
|
+
Returns a list of spaCy tokens from a Doc object.
|
201
|
+
|
202
|
+
:param Doc: A spaCy Doc object. Defaults to None.
|
203
|
+
"""
|
204
|
+
|
205
|
+
return [token for token in tokens]
|
206
|
+
|
207
|
+
@requires(stages=["model"])
|
208
|
+
@workflow(input_arg="tokens", input_process="_model", output_process="_token_manipulation", next_stage="token_manipulation", set_unit="sentences")
|
209
|
+
def get_sentences(self, *, tokens: Doc) -> list[Span]:
|
210
|
+
"""
|
211
|
+
Returns a list containing sentences as strings or as spaCy Span objects.
|
212
|
+
|
213
|
+
:param Doc: A spaCy Doc object.
|
214
|
+
"""
|
215
|
+
|
216
|
+
return list(tokens.sents)
|
217
|
+
|
218
|
+
@requires(stages=["token_manipulation"])
|
219
|
+
@workflow(input_arg="tokens", input_process="_token_manipulation", output_process="_token_manipulation")
|
220
|
+
@accepts_sentences()
|
221
|
+
@accepts_tags()
|
222
|
+
def remove_stop_words(self, *, tokens: TokenIterator, custom_stop_words: list[str] | None = None) -> list[Token]:
|
223
|
+
"""
|
224
|
+
Filters tokens, returning only alphanumeric tokens that are not stop words.
|
225
|
+
|
226
|
+
:param tokens: A spaCy Doc or Span object or list of spaCy tokens.
|
227
|
+
:param custom_stop_words: A list of custom stop words.
|
228
|
+
|
229
|
+
Supported inputs:
|
230
|
+
|
231
|
+
This method supports different token forms due to decorator-based preprocessing:
|
232
|
+
- tokens: TokenIterator
|
233
|
+
- sentences: Sentences
|
234
|
+
- taggged_tokens: POSTags
|
235
|
+
"""
|
236
|
+
|
237
|
+
lower_stop_words: list[str]
|
238
|
+
|
239
|
+
if custom_stop_words:
|
240
|
+
lower_stop_words = [word.lower() for word in custom_stop_words]
|
241
|
+
|
242
|
+
return [
|
243
|
+
token for token in tokens
|
244
|
+
if token.is_alpha and not token.is_stop and
|
245
|
+
(custom_stop_words is None or token.text.lower() not in lower_stop_words)
|
246
|
+
]
|
247
|
+
|
248
|
+
@requires(stages=["token_manipulation"])
|
249
|
+
@workflow(input_arg="tokens", input_process="_token_manipulation", output_process="_token_manipulation")
|
250
|
+
@accepts_sentences()
|
251
|
+
@accepts_tags()
|
252
|
+
def raw_tokenizer(self, *, tokens: TokenIterator) -> list[Token]:
|
253
|
+
"""
|
254
|
+
Tokenizes a text including punctuation and stop words.
|
255
|
+
|
256
|
+
:param tokens: A spaCy Doc or Span object or list of spaCy tokens.
|
257
|
+
|
258
|
+
Supported inputs:
|
259
|
+
|
260
|
+
This method supports different token forms due to decorator-based preprocessing:
|
261
|
+
- tokens: TokenIterator
|
262
|
+
- sentences: Sentences
|
263
|
+
- taggged_tokens: POSTags
|
264
|
+
"""
|
265
|
+
|
266
|
+
return [token for token in tokens]
|
267
|
+
|
268
|
+
@requires(stages=["token_manipulation"])
|
269
|
+
@workflow(input_arg="tokens", input_process="_token_manipulation", output_process="_token_manipulation")
|
270
|
+
@accepts_sentences()
|
271
|
+
@accepts_tags()
|
272
|
+
def alphanumeric_raw_tokenizer(self, *, tokens: TokenIterator) -> list[Token]:
|
273
|
+
"""
|
274
|
+
Tokenizes a text including only alphanumeric characters and stop words.
|
275
|
+
|
276
|
+
:param tokens: A spaCy Doc or Span object or list of spaCy tokens.
|
277
|
+
|
278
|
+
Supported inputs:
|
279
|
+
|
280
|
+
This method supports different token forms due to decorator-based preprocessing:
|
281
|
+
- tokens: TokenIterator
|
282
|
+
- sentences: Sentences
|
283
|
+
- taggged_tokens: POSTags
|
284
|
+
"""
|
285
|
+
|
286
|
+
return [
|
287
|
+
token
|
288
|
+
for token in tokens
|
289
|
+
if token.text.isalpha()
|
290
|
+
]
|
291
|
+
|
292
|
+
@requires(stages=["token_manipulation"])
|
293
|
+
@workflow(input_arg="tokens", input_process="_token_manipulation", output_process="_token_manipulation")
|
294
|
+
@accepts_sentences()
|
295
|
+
@accepts_tags()
|
296
|
+
def filter_by_pos(self, *, tokens: TokenIterator, pos: str) -> list[Token]:
|
297
|
+
"""
|
298
|
+
Returns a list of spaCy tokens filtered by a spacific part-of-speech tag.
|
299
|
+
|
300
|
+
:param tokens: A spaCy Doc or Span object or list of spaCy tokens.
|
301
|
+
:param pos: The POS tag to filter by (e.g., 'NOUN', 'VERB', etc.). Case-sensitive.
|
302
|
+
|
303
|
+
Supported inputs:
|
304
|
+
|
305
|
+
This method supports different token forms due to decorator-based preprocessing:
|
306
|
+
- tokens: TokenIterator
|
307
|
+
- sentences: Sentences
|
308
|
+
- taggged_tokens: POSTags
|
309
|
+
"""
|
310
|
+
|
311
|
+
return [token for token in tokens if token.pos_ == pos]
|
312
|
+
|
313
|
+
@requires(stages=["token_manipulation"])
|
314
|
+
@workflow(input_arg="tokens", input_process="_token_manipulation", output_process="_token_manipulation")
|
315
|
+
@accepts_sentences()
|
316
|
+
def pos_tagger(self, *, tokens: TokenIterator) -> list[tuple[Token, str]]:
|
317
|
+
"""
|
318
|
+
Returns a list of (Token, POS) tuples, pairing each token with its part-of-speech tag.
|
319
|
+
|
320
|
+
:param tokens: A spaCy Doc or Span object or list of spaCy tokens.
|
321
|
+
|
322
|
+
Supported inputs:
|
323
|
+
|
324
|
+
This method supports different token forms due to decorator-based preprocessing:
|
325
|
+
- tokens: TokenIterator
|
326
|
+
- sentences: Sentences
|
327
|
+
- taggged_tokens: POSTags
|
328
|
+
"""
|
329
|
+
|
330
|
+
return [(token, token.pos_) for token in tokens]
|
331
|
+
|
332
|
+
@requires(stages=["token_manipulation"])
|
333
|
+
@workflow(input_arg="tokens", input_process="_token_manipulation", output_process="_text_processing", next_stage="text_processing")
|
334
|
+
@accepts_sentences()
|
335
|
+
@accepts_tags()
|
336
|
+
def get_text(self, *, tokens: TokenIterator, lemmatize: bool = False) -> TextIterator:
|
337
|
+
"""
|
338
|
+
Returns the text content from a list of spaCy tokens, Span objects or list of spaCy tokens.
|
339
|
+
|
340
|
+
:param tokens: A spaCy Doc or Span object or list of spaCy tokens.
|
341
|
+
:param lemmatize: If True, lemmatizes the words in the text. Defaults to False.
|
342
|
+
|
343
|
+
Supported inputs:
|
344
|
+
|
345
|
+
This method supports different token forms due to decorator-based preprocessing:
|
346
|
+
- tokens: TokenIterator
|
347
|
+
- sentences: Sentences
|
348
|
+
- taggged_tokens: POSTags
|
349
|
+
"""
|
350
|
+
|
351
|
+
return [token.lemma_.lower() if lemmatize else token.text for token in tokens]
|
352
|
+
|
353
|
+
@requires(stages=["text_processing"])
|
354
|
+
@workflow(input_arg="tokens", input_process="_text_processing", output_process="_text_processing")
|
355
|
+
@accepts_sentences()
|
356
|
+
@accepts_tags()
|
357
|
+
def to_lower(self, *, tokens: list[str]) -> list[str]:
|
358
|
+
"""
|
359
|
+
Returns a list of lower cased words.
|
360
|
+
|
361
|
+
:param tokens: A list containing tokens as strings.
|
362
|
+
|
363
|
+
Supported inputs:
|
364
|
+
|
365
|
+
This method supports different token forms due to decorator-based preprocessing:
|
366
|
+
- tokens: list[str]
|
367
|
+
- sentences: list[list[str]] | list[list[tuple[str, str]]]
|
368
|
+
- taggged_tokens: list[tuple[str, str]]
|
369
|
+
"""
|
370
|
+
|
371
|
+
return [string.lower() for string in tokens]
|
372
|
+
|
373
|
+
@requires(stages=["text_processing"])
|
374
|
+
@workflow(input_arg="tokens", input_process="_text_processing", output_process="_text_processing")
|
375
|
+
@accepts_sentences()
|
376
|
+
@accepts_tags(accepts=False)
|
377
|
+
def corpus_by_context_window(self, *, tokens: list[str], window_size: int) -> list[str]:
|
378
|
+
"""
|
379
|
+
Splits the tokens into groups of window_size consecutive words and joins each group into a string.
|
380
|
+
|
381
|
+
:param tokens: A list containing tokens as strings.
|
382
|
+
:param window_size: size of the context window.
|
383
|
+
|
384
|
+
Supported inputs:
|
385
|
+
|
386
|
+
This method supports different token forms due to decorator-based preprocessing:
|
387
|
+
- tokens: list[str]
|
388
|
+
- sentences: list[list[str]]
|
389
|
+
"""
|
390
|
+
|
391
|
+
ngrams: list[list[str]] = [tokens[index:index + window_size] for index in range(len(tokens) - window_size + 1)]
|
392
|
+
return [' '.join(ngram) for ngram in ngrams]
|
393
|
+
|
394
|
+
@requires(stages=["text_processing"], unit="words")
|
395
|
+
@workflow(input_arg="tokens", input_process="_text_processing", output_process="_text_processing")
|
396
|
+
@accepts_tags(accepts=False)
|
397
|
+
def get_vocabulary(self, *, tokens: list[str]) -> list[str]:
|
398
|
+
"""
|
399
|
+
Returns the vocabulary a text.
|
400
|
+
|
401
|
+
:param tokens: A list containing tokens as strings.
|
402
|
+
|
403
|
+
This method supports different token forms due to decorator-based preprocessing:
|
404
|
+
- tokens: list[str]
|
405
|
+
- sentences: list[list[str]]
|
406
|
+
"""
|
407
|
+
|
408
|
+
return sorted(set(tokens))
|
409
|
+
|
410
|
+
@requires(stages=["text_processing"], unit="sentences")
|
411
|
+
@workflow(input_arg="tokens", input_process="_text_processing", output_process="_text_processing")
|
412
|
+
@accepts_sentences()
|
413
|
+
@accepts_tags(accepts=False)
|
414
|
+
def join(self, *, tokens: list[str]) -> str:
|
415
|
+
"""
|
416
|
+
Joins a list of strings into a single string text.
|
417
|
+
|
418
|
+
:param tokens: A list containing tokens as strings.
|
419
|
+
|
420
|
+
This method supports different token forms due to decorator-based preprocessing:
|
421
|
+
- tokens: list[str]
|
422
|
+
- sentences: list[list[str]]
|
423
|
+
"""
|
424
|
+
|
425
|
+
return " ".join(tokens)
|
426
|
+
|
427
|
+
@requires(stages=["text_processing"])
|
428
|
+
@workflow(input_arg="tokens", input_process="_text_processing", output_process="_end", next_stage="end")
|
429
|
+
@accepts_sentences(accepts=False, custom_error_message="You must first use join before saving them to a file.")
|
430
|
+
@accepts_tags(accepts=True, intercept=False)
|
431
|
+
def save_to_file(self, *, tokens: list[str] | list[tuple[str, str]], path: str) -> str:
|
432
|
+
"""
|
433
|
+
Saves a list of strings or (Token, POS) tuples in the specified path. If tokens is a list of strings, it saves each string in a new line. If it is a list of tuples, it saves each tuple in a new line as a pair or values separated by a comma, in a CSV format.
|
434
|
+
|
435
|
+
:param tokens: A list containing tokens as strings or a list of (Token, POS) tuples.
|
436
|
+
:param path: The path where to save the list of tokens.
|
437
|
+
|
438
|
+
This method supports different token forms due to decorator-based preprocessing:
|
439
|
+
- tokens: list[str]
|
440
|
+
- taggged_tokens: list[tuple[str, str]]
|
441
|
+
"""
|
442
|
+
|
443
|
+
with open(path, "w") as file:
|
444
|
+
for token in tokens:
|
445
|
+
if type(token) == str:
|
446
|
+
file.write(token + "\n")
|
447
|
+
else:
|
448
|
+
file.write(token[0] + "," + token[1] + "\n")
|
449
|
+
|
450
|
+
return str(Path(path).resolve())
|
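
The two type guards above are plain module-level functions, so their behavior can be checked outside the workflow machinery. Below is a minimal sketch of the input shapes that accepts_tags and accepts_sentences detect; the token/tag pairs are invented purely for illustration.

# Sanity check of the type guards defined in dstk/text_processor.py.
# Example data is made up for illustration only.
from dstk.text_processor import is_pos_tags, is_sentence

tagged = [("cats", "NOUN"), ("purr", "VERB")]      # (token, tag) pairs
plain = ["cats", "purr"]                           # flat token list
sentences = [["cats", "purr"], ["dogs", "bark"]]   # one token list per sentence

print(is_pos_tags(tagged))      # True  -> accepts_tags would strip and restore the tags
print(is_pos_tags(plain))       # False -> plain strings carry no tags
print(is_sentence(sentences))   # True  -> accepts_sentences would process each sentence separately
print(is_sentence(plain))       # False -> flat token list, not sentence-level input

The TextProcessor methods themselves are wired together by the workflow and requires decorators; judging from _set_workflow and the stage attributes, they appear meant to be called in sequence (set_model, then a token-manipulation step such as get_tokens or remove_stop_words, then get_text), with each stage consuming the previous stage's output.
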
dstk/weight_matrix.py
ADDED
@@ -0,0 +1,71 @@
import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import TfidfTransformer
from .workflow_tools import requires, workflow, WorkflowManager

from .lib_types import DataFrame, ndarray, Series, csr_matrix

STAGES = [
    "start", # Before any processing
    "end" # Weighted dataframe
]

class WeightMatrix(WorkflowManager):
    """
    Provides a set of methods to weight a Co-occurrence matrix.

    :param dataframe: A Co-occurrence matrix to be weighted.
    """

    _start: DataFrame
    _end: DataFrame

    def __init__(self, dataframe: DataFrame | None = None):
        """
        Initializes WeightMatrix with given attributes.
        """

        super().__init__()

        self._set_workflow(input_arg=dataframe)

    @requires(stages=["start"])
    @workflow(input_arg="dataframe", input_process="_start", output_process="_end", next_stage="end")
    def pmi(self, *, dataframe: DataFrame, positive: bool = False) -> DataFrame:
        """
        Weights a Co-occurrence matrix by PMI or PPMI.

        :param dataframe: A Co-occurrence matrix to be weighted.
        :param positive: If True, weights the Co-occurrence matrix by PPMI. If False, weights it by PMI. Defaults to False.
        """

        df: DataFrame = dataframe

        col_totals: Series = df.sum(axis=0)
        total: float = col_totals.sum()
        row_totals: Series = df.sum(axis=1)
        expected: ndarray = np.outer(row_totals, col_totals) / total
        df = df / expected
        # Silence distracting warnings about log(0):
        with np.errstate(divide='ignore'):
            df = np.log(df)
        df[np.isinf(df)] = 0.0  # log(0) = 0
        if positive:
            df[df < 0] = 0.0

        return df

    @requires(stages=["start"])
    @workflow(input_arg="dataframe", input_process="_start", output_process="_end", next_stage="end")
    def tf_idf(self, *, dataframe: DataFrame, **kwargs) -> DataFrame:
        """
        Weights a Co-occurrence matrix by Tf-idf.

        :param dataframe: A Co-occurrence matrix to be weighted.
        :param kwargs: Additional keyword arguments passed to scikit-learn's TfidfTransformer.
        """

        transformer: TfidfTransformer = TfidfTransformer(**kwargs)
        tf_idf_matrix: csr_matrix = transformer.fit_transform(dataframe)

        return pd.DataFrame(tf_idf_matrix.toarray(), index=dataframe.index, columns=dataframe.columns)
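
As a quick cross-check of the pmi method, the snippet below reproduces its arithmetic on a toy 2x2 co-occurrence matrix with plain pandas and numpy (the counts are invented for illustration): each cell becomes log(observed / expected), where expected is the outer product of the row and column totals divided by the grand total, and PPMI additionally clamps negative values to zero.

# Reproduces the arithmetic of WeightMatrix.pmi on a toy co-occurrence matrix.
# The counts are invented for illustration only.
import numpy as np
import pandas as pd

counts = pd.DataFrame([[4, 0], [2, 2]], index=["cat", "dog"], columns=["purr", "bark"])

col_totals = counts.sum(axis=0)                       # purr: 6, bark: 2
row_totals = counts.sum(axis=1)                       # cat: 4, dog: 4
total = col_totals.sum()                              # 8
expected = np.outer(row_totals, col_totals) / total   # [[3, 1], [3, 1]]

pmi = counts / expected                               # observed / expected
with np.errstate(divide="ignore"):
    pmi = np.log(pmi)
pmi[np.isinf(pmi)] = 0.0                              # log(0) treated as 0, as in pmi()

# PMI: cat/purr = log(4/3) ~ 0.29, dog/purr = log(2/3) ~ -0.41, dog/bark = log(2) ~ 0.69
# With positive=True the negative dog/purr cell would also be clamped to 0 (PPMI).
print(pmi.round(2))

The tf_idf method, by contrast, simply defers to scikit-learn's TfidfTransformer and repackages the sparse result as a DataFrame with the original index and columns.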