swarmauri_embedding_doc2vec 0.6.0.dev154__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- swarmauri_embedding_doc2vec/Doc2VecEmbedding.py +82 -0
- swarmauri_embedding_doc2vec/__init__.py +12 -0
- swarmauri_embedding_doc2vec-0.6.0.dev154.dist-info/METADATA +20 -0
- swarmauri_embedding_doc2vec-0.6.0.dev154.dist-info/RECORD +6 -0
- swarmauri_embedding_doc2vec-0.6.0.dev154.dist-info/WHEEL +4 -0
- swarmauri_embedding_doc2vec-0.6.0.dev154.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
from typing import List, Any, Literal
|
|
2
|
+
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
|
|
3
|
+
from swarmauri_standard.vectors.Vector import Vector
|
|
4
|
+
from swarmauri_base.embeddings.EmbeddingBase import EmbeddingBase
|
|
5
|
+
from swarmauri_core.ComponentBase import ComponentBase
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@ComponentBase.register_type(EmbeddingBase, "Doc2VecEmbedding")
class Doc2VecEmbedding(EmbeddingBase):
    """Doc2Vec-based embedding model.

    Wraps gensim's ``Doc2Vec`` to train on raw, whitespace-tokenized text
    documents and produce swarmauri ``Vector`` embeddings.
    """

    # Underlying gensim model; replaced wholesale by load_model().
    _model: Doc2Vec
    type: Literal["Doc2VecEmbedding"] = "Doc2VecEmbedding"

    def __init__(
        self,
        vector_size: int = 3000,  # Reduced size for better performance
        window: int = 10,
        min_count: int = 1,
        workers: int = 5,
        **kwargs,
    ):
        """Initialize the embedding model.

        Args:
            vector_size: Dimensionality of the produced embeddings.
            window: Maximum distance between current and predicted word.
            min_count: Ignore tokens with total frequency lower than this.
            workers: Number of worker threads used during training.
            **kwargs: Forwarded to ``EmbeddingBase``.
        """
        super().__init__(**kwargs)
        self._model = Doc2Vec(
            vector_size=vector_size,
            window=window,
            min_count=min_count,
            workers=workers,
        )

    def fit(self, documents: List[str], labels=None) -> None:
        """Train the model on the given documents.

        Documents are whitespace-tokenized and tagged by their index.
        ``labels`` is accepted for interface compatibility but unused.
        """
        tagged_data = [
            TaggedDocument(words=doc.split(), tags=[str(i)])
            for i, doc in enumerate(documents)
        ]

        # build_vocab must be called once before the first train();
        # subsequent fits extend the existing vocabulary with update=True.
        if len(self._model.wv) == 0:
            self._model.build_vocab(tagged_data)
        else:
            self._model.build_vocab(tagged_data, update=True)

        self._model.train(
            tagged_data,
            total_examples=self._model.corpus_count,
            epochs=self._model.epochs,
        )

    def extract_features(self) -> List[Any]:
        """Return the vocabulary tokens currently known to the model."""
        return list(self._model.wv.key_to_index.keys())

    def infer_vector(self, data: str) -> Vector:
        """Infer an embedding ``Vector`` for a single document string.

        Tokens absent from the model vocabulary are dropped before
        inference; if every token is out-of-vocabulary, a zero vector of
        the model's dimensionality is returned.
        """
        words = data.split()
        known_words = [word for word in words if word in self._model.wv]

        if not known_words:
            # All tokens are OOV: fall back to a zero vector.
            vector = [0.0] * self._model.vector_size
        else:
            # FIX: infer_vector returns a numpy array; convert to a plain
            # list so both branches hand the same value type to Vector.
            vector = self._model.infer_vector(known_words).tolist()

        return Vector(value=vector)

    def transform(self, documents: List[str]) -> List[Vector]:
        """Infer a Vector for each document (no training)."""
        return [self.infer_vector(doc) for doc in documents]

    def fit_transform(self, documents: List[str], **kwargs) -> List[Vector]:
        """Train on the documents, then return their embeddings."""
        self.fit(documents, **kwargs)
        return self.transform(documents)

    def save_model(self, path: str) -> None:
        """Persist the underlying Doc2Vec model to *path*."""
        self._model.save(path)

    def load_model(self, path: str) -> None:
        """Load a Doc2Vec model from the specified path, replacing the
        current one."""
        self._model = Doc2Vec.load(path)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from .Doc2VecEmbedding import Doc2VecEmbedding

# Explicit public API of the package.
__all__ = ["Doc2VecEmbedding"]

# FIX: keep the module version in sync with the distribution version
# declared in the package metadata (0.6.0.dev154); it previously lagged
# at 0.6.0.dev26.
__version__ = "0.6.0.dev154"
__long_desc__ = """

# Swarmauri Doc2VecEmbedding Plugin

Visit us at: https://swarmauri.com
Follow us at: https://github.com/swarmauri
Star us at: https://github.com/swarmauri/swarmauri-sdk

"""
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: swarmauri_embedding_doc2vec
|
|
3
|
+
Version: 0.6.0.dev154
|
|
4
|
+
Summary: A Doc2Vec based Embedding Model.
|
|
5
|
+
License: Apache-2.0
|
|
6
|
+
Author: Jacob Stewart
|
|
7
|
+
Author-email: jacob@swarmauri.com
|
|
8
|
+
Requires-Python: >=3.10,<3.13
|
|
9
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
14
|
+
Requires-Dist: gensim (==4.3.3)
|
|
15
|
+
Requires-Dist: swarmauri_base (>=0.6.0.dev154,<0.7.0)
|
|
16
|
+
Requires-Dist: swarmauri_core (>=0.6.0.dev154,<0.7.0)
|
|
17
|
+
Project-URL: Repository, http://github.com/swarmauri/swarmauri-sdk
|
|
18
|
+
Description-Content-Type: text/markdown
|
|
19
|
+
|
|
20
|
+
# Swarmauri Doc2VecEmbedding Plugin
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
swarmauri_embedding_doc2vec/__init__.py,sha256=DXLta3Z3moSMMIveBFK-fVcEfGvLbyUKtMhR6k5KUR4,284
|
|
2
|
+
swarmauri_embedding_doc2vec/Doc2VecEmbedding.py,sha256=9U1EbNWJeF3og9VlK_Ks-x10pdVarrxSxjMTATJk5NA,2802
|
|
3
|
+
swarmauri_embedding_doc2vec-0.6.0.dev154.dist-info/entry_points.txt,sha256=xBXtQFyoxuVM54qQHK_n6b5FnnhRuL7tVXRSBeC47Lc,86
|
|
4
|
+
swarmauri_embedding_doc2vec-0.6.0.dev154.dist-info/METADATA,sha256=nEdSxKb06yLts-5uPqzZTBFFXJli75WyeYGY0n2vBAs,765
|
|
5
|
+
swarmauri_embedding_doc2vec-0.6.0.dev154.dist-info/WHEEL,sha256=IYZQI976HJqqOpQU6PHkJ8fb3tMNBFjg-Cn-pwAbaFM,88
|
|
6
|
+
swarmauri_embedding_doc2vec-0.6.0.dev154.dist-info/RECORD,,
|