janus-llm 1.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janus/__init__.py +9 -1
- janus/__main__.py +4 -0
- janus/_tests/test_cli.py +128 -0
- janus/_tests/test_translate.py +49 -7
- janus/cli.py +530 -46
- janus/converter.py +50 -19
- janus/embedding/_tests/test_collections.py +2 -8
- janus/embedding/_tests/test_database.py +32 -0
- janus/embedding/_tests/test_vectorize.py +9 -4
- janus/embedding/collections.py +49 -6
- janus/embedding/embedding_models_info.py +130 -0
- janus/embedding/vectorize.py +53 -62
- janus/language/_tests/__init__.py +0 -0
- janus/language/_tests/test_combine.py +62 -0
- janus/language/_tests/test_splitter.py +16 -0
- janus/language/binary/_tests/test_binary.py +16 -1
- janus/language/binary/binary.py +10 -3
- janus/language/block.py +31 -30
- janus/language/combine.py +26 -34
- janus/language/mumps/_tests/test_mumps.py +2 -2
- janus/language/mumps/mumps.py +93 -9
- janus/language/naive/__init__.py +4 -0
- janus/language/naive/basic_splitter.py +14 -0
- janus/language/naive/chunk_splitter.py +26 -0
- janus/language/naive/registry.py +13 -0
- janus/language/naive/simple_ast.py +18 -0
- janus/language/naive/tag_splitter.py +61 -0
- janus/language/splitter.py +168 -74
- janus/language/treesitter/_tests/test_treesitter.py +19 -14
- janus/language/treesitter/treesitter.py +37 -13
- janus/llm/model_callbacks.py +177 -0
- janus/llm/models_info.py +165 -72
- janus/metrics/__init__.py +8 -0
- janus/metrics/_tests/__init__.py +0 -0
- janus/metrics/_tests/reference.py +2 -0
- janus/metrics/_tests/target.py +2 -0
- janus/metrics/_tests/test_bleu.py +56 -0
- janus/metrics/_tests/test_chrf.py +67 -0
- janus/metrics/_tests/test_file_pairing.py +59 -0
- janus/metrics/_tests/test_llm.py +91 -0
- janus/metrics/_tests/test_reading.py +28 -0
- janus/metrics/_tests/test_rouge_score.py +65 -0
- janus/metrics/_tests/test_similarity_score.py +23 -0
- janus/metrics/_tests/test_treesitter_metrics.py +110 -0
- janus/metrics/bleu.py +66 -0
- janus/metrics/chrf.py +55 -0
- janus/metrics/cli.py +7 -0
- janus/metrics/complexity_metrics.py +208 -0
- janus/metrics/file_pairing.py +113 -0
- janus/metrics/llm_metrics.py +202 -0
- janus/metrics/metric.py +466 -0
- janus/metrics/reading.py +70 -0
- janus/metrics/rouge_score.py +96 -0
- janus/metrics/similarity.py +53 -0
- janus/metrics/splitting.py +38 -0
- janus/parsers/_tests/__init__.py +0 -0
- janus/parsers/_tests/test_code_parser.py +32 -0
- janus/parsers/code_parser.py +24 -253
- janus/parsers/doc_parser.py +169 -0
- janus/parsers/eval_parser.py +80 -0
- janus/parsers/reqs_parser.py +72 -0
- janus/prompts/prompt.py +103 -30
- janus/translate.py +636 -111
- janus/utils/_tests/__init__.py +0 -0
- janus/utils/_tests/test_logger.py +67 -0
- janus/utils/_tests/test_progress.py +20 -0
- janus/utils/enums.py +56 -3
- janus/utils/progress.py +56 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/METADATA +27 -11
- janus_llm-2.0.1.dist-info/RECORD +94 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/WHEEL +1 -1
- janus_llm-1.0.0.dist-info/RECORD +0 -48
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/LICENSE +0 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/entry_points.txt +0 -0
janus/converter.py
CHANGED
@@ -4,7 +4,6 @@ from typing import Any
 from langchain.schema.language_model import BaseLanguageModel
 
 from .language.binary import BinarySplitter
-from .language.combine import Combiner
 from .language.mumps import MumpsSplitter
 from .language.splitter import Splitter
 from .language.treesitter import TreeSitterSplitter
@@ -45,6 +44,8 @@ class Converter:
         self,
         source_language: str = "fortran",
         max_tokens: None | int = None,
+        protected_node_types: set[str] | list[str] | tuple[str] = (),
+        prune_node_types: set[str] | list[str] | tuple[str] = (),
     ) -> None:
         """Initialize a Converter instance.
 
@@ -59,13 +60,15 @@ class Converter:
 
         self._source_language: None | str
         self._source_glob: None | str
+        self._protected_node_types: tuple[str] = ()
+        self._prune_node_types: tuple[str] = ()
         self._splitter: None | Splitter
         self._llm: None | BaseLanguageModel = None
         self._max_tokens: None | int = max_tokens
 
-        self.
-
-        self.
+        self.set_source_language(source_language)
+        self.set_protected_node_types(protected_node_types)
+        self.set_prune_node_types(prune_node_types)
 
         # Child class must call this. Should we enforce somehow?
         # self._load_parameters()
@@ -86,7 +89,7 @@ class Converter:
     def set_source_language(self, source_language: str) -> None:
         """Validate and set the source language.
 
-        The affected objects will not be updated until
+        The affected objects will not be updated until _load_parameters() is called.
 
         Arguments:
             source_language: The source programming language.
@@ -101,27 +104,55 @@ class Converter:
         self._source_glob = f"**/*.{LANGUAGES[source_language]['suffix']}"
         self._source_language = source_language
 
-
+    def set_protected_node_types(
+        self, protected_node_types: set[str] | list[str] | tuple[str]
+    ) -> None:
+        """Set the protected (non-mergeable) node types. This will often be structures
+        like functions, classes, or modules which you might want to keep separate
+
+        The affected objects will not be updated until _load_parameters() is called.
+
+        Arguments:
+            protected_node_types: A set of node types that aren't to be merged
+        """
+        self._protected_node_types = tuple(set(protected_node_types or []))
+
+    def set_prune_node_types(
+        self, prune_node_types: set[str] | list[str] | tuple[str]
+    ) -> None:
+        """Set the node types to prune. This will often be structures
+        like comments or whitespace which you might want to keep out of the LLM
+
+        The affected objects will not be updated until _load_parameters() is called.
+
+        Arguments:
+            prune_node_types: A set of node types which should be pruned
+        """
+        self._prune_node_types = tuple(set(prune_node_types or []))
+
+    @run_if_changed(
+        "_source_language",
+        "_max_tokens",
+        "_llm",
+        "_protected_node_types",
+        "_prune_node_types",
+    )
     def _load_splitter(self) -> None:
         """Load the splitter according to this instance's attributes.
 
         If the relevant fields have not been changed since the last time this method was
         called, nothing happens.
         """
+        kwargs = dict(
+            max_tokens=self._max_tokens,
+            model=self._llm,
+            protected_node_types=self._protected_node_types,
+            prune_node_types=self._prune_node_types,
+        )
         if self._source_language in CUSTOM_SPLITTERS:
             if self._source_language == "mumps":
-                self._splitter = MumpsSplitter(
-                    max_tokens=self._max_tokens,
-                    model=self._llm,
-                )
+                self._splitter = MumpsSplitter(**kwargs)
             elif self._source_language == "binary":
-                self._splitter = BinarySplitter(
-                    max_tokens=self._max_tokens,
-                    model=self._llm,
-                )
+                self._splitter = BinarySplitter(**kwargs)
         else:
-            self._splitter = TreeSitterSplitter(
-                language=self._source_language,
-                max_tokens=self._max_tokens,
-                model=self._llm,
-            )
+            self._splitter = TreeSitterSplitter(language=self._source_language, **kwargs)
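Taken together, these changes let callers mark node types that must never be merged away (protected) and node types to strip before anything is sent to the LLM (pruned), with all three splitters now fed from one shared kwargs dict. A minimal usage sketch, assuming Converter is constructed directly (in practice it is subclassed, and the node-type names below are hypothetical, grammar-dependent examples):

    # Sketch only: valid node-type names depend on the tree-sitter grammar in use.
    from janus.converter import Converter

    converter = Converter(
        source_language="fortran",
        max_tokens=4096,
        protected_node_types=("function", "subroutine"),  # never merged away
        prune_node_types=("comment",),  # dropped before prompting
    )
    # Later reconfiguration goes through the setters; per the docstrings, the
    # change only takes effect once _load_parameters() runs again.
    converter.set_prune_node_types({"comment"})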
janus/embedding/_tests/test_collections.py
CHANGED
@@ -14,9 +14,7 @@ class TestCollections(unittest.TestCase):
         self.collections = Collections(self._db)
 
     def test_creation(self):
-        self.
-
-        result = self.collections.create(EmbeddingType.PSEUDO)
+        self.collections.create(EmbeddingType.PSEUDO)
 
         metadata = {
             "date_updated": datetime.datetime.now().date().isoformat(),
@@ -24,12 +22,9 @@ class TestCollections(unittest.TestCase):
         }
 
         self._db.create_collection.assert_called_with("pseudo_1", metadata=metadata)
-        self.assertEqual(result, "foo")
 
     def test_creation_triangulation(self):
-        self.
-
-        result = self.collections.create(EmbeddingType.REQUIREMENT)
+        self.collections.create(EmbeddingType.REQUIREMENT)
 
         metadata = {
             "date_updated": datetime.datetime.now().date().isoformat(),
@@ -37,7 +32,6 @@ class TestCollections(unittest.TestCase):
         }
 
         self._db.create_collection.assert_called_with("requirement_1", metadata=metadata)
-        self.assertEqual(result, [])
 
     def test_creation_of_existing_type(self):
         mock_collection = MagicMock()
janus/embedding/_tests/test_database.py
ADDED
@@ -0,0 +1,32 @@
+import unittest
+from pathlib import Path
+from unittest.mock import patch
+
+from ..database import ChromaEmbeddingDatabase, uri_to_path
+
+
+class TestDatabase(unittest.TestCase):
+    def test_uri_to_path(self):
+        uri = (Path.home().expanduser() / "Documents" / "testfile.txt").as_uri()
+        expected_path = Path.home().expanduser() / "Documents" / "testfile.txt"
+        self.assertEqual(uri_to_path(uri), expected_path)
+
+    @patch("chromadb.PersistentClient", autospec=True)
+    def test_ChromaEmbeddingDatabase(self, mock_client):
+        # Test with default path
+        _ = ChromaEmbeddingDatabase()
+        mock_client.assert_called_once()
+
+        # Test with custom path
+        custom_path = "/custom/path/to/chroma-data"
+        _ = ChromaEmbeddingDatabase(custom_path)
+        mock_client.assert_called()
+
+        # Test with URL
+        url = "http://example.com/chroma-data"
+        _ = ChromaEmbeddingDatabase(url)
+        mock_client.assert_called()
+
+
+if __name__ == "__main__":
+    unittest.main()
janus/embedding/_tests/test_vectorize.py
CHANGED
@@ -5,6 +5,7 @@ from unittest.mock import MagicMock
 
 from chromadb.api.client import Client
 
+from ...language.treesitter import TreeSitterSplitter
 from ...utils.enums import EmbeddingType
 from ..vectorize import Vectorizer, VectorizerFactory
 
@@ -22,7 +23,7 @@ class MockDBVectorizer(VectorizerFactory):
         model: None | str = "gpt4all",
         path: str | Path = None,
     ) -> Vectorizer:
-        return Vectorizer(self._db
+        return Vectorizer(self._db)
 
 
 class TestVectorize(unittest.TestCase):
@@ -35,7 +36,11 @@ class TestVectorize(unittest.TestCase):
         self.database.list_collections = list_collections
         self.vectorizer = MockDBVectorizer(self.database).create_vectorizer()
         self.test_file = Path("janus/language/treesitter/_tests/languages/fortran.f90")
-
+        splitter = TreeSitterSplitter(
+            language="fortran",
+            max_tokens=16_384,
+        )
+        self.test_block = splitter.split(self.test_file)
 
     def test_add_nodes_recursively(self):
         embedding_type = EmbeddingType.SOURCE
@@ -45,9 +50,9 @@ class TestVectorize(unittest.TestCase):
             "time_updated": datetime.datetime.now().time().isoformat("minutes"),
         }
         self.database.create_collection.assert_called_with("source_1", metadata=metadata)
-        self.vectorizer.
+        self.vectorizer.add_nodes_recursively(
             self.test_block, embedding_type, self.test_file.name
         )
         self.database.get_or_create_collection.assert_called_with(
-            "source_1", metadata=metadata
+            name="source_1", embedding_function=None, metadata=metadata
         )
janus/embedding/collections.py
CHANGED
@@ -1,10 +1,12 @@
 import datetime
 import os
-from typing import Sequence
+from typing import Dict, Optional, Sequence
 
 from chromadb import Client, Collection
+from langchain_community.vectorstores import Chroma
 
 from ..utils.enums import EmbeddingType
+from .embedding_models_info import load_embedding_model
 
 # See https://docs.trychroma.com/telemetry#in-chromas-backend-using-environment-variables
 os.environ["ANONYMIZED_TELEMETRY"] = "False"
@@ -13,10 +15,16 @@ os.environ["ANONYMIZED_TELEMETRY"] = "False"
 class Collections:
     """Manage embedding collections"""
 
-    def __init__(self, client: Client):
+    def __init__(self, client: Client, config: Optional[Dict[str, str]] = None):
         self._client = client
+        if config is not None:
+            self._config = config
+        else:
+            self._config = {}
 
-    def create(
+    def create(
+        self, name: EmbeddingType | str, model_name: Optional[str] = None
+    ) -> Chroma:
         """Create a Chroma collection for the given embedding type.
 
         Arguments:
@@ -27,9 +35,23 @@ class Collections:
             "date_updated": datetime.datetime.now().date().isoformat(),
             "time_updated": datetime.datetime.now().time().isoformat("minutes"),
         }
-
+        if model_name is not None:
+            metadata["embedding_model"] = model_name
+            self._client.create_collection(collection_name, metadata=metadata)
+            self._config[collection_name] = model_name
+            model, _, _ = load_embedding_model(model_name)
+            return Chroma(
+                client=self._client,
+                collection_name=collection_name,
+                embedding_function=model,
+            )
+        else:
+            self._client.create_collection(collection_name, metadata=metadata)
+            return Chroma(client=self._client, collection_name=collection_name)
 
-    def get_or_create(
+    def get_or_create(
+        self, name: EmbeddingType | str, model_name: Optional[str] = None
+    ) -> Chroma:
         """Create a Chroma collection for the given embedding type.
 
         Arguments:
@@ -40,7 +62,26 @@ class Collections:
             "date_updated": datetime.datetime.now().date().isoformat(),
             "time_updated": datetime.datetime.now().time().isoformat("minutes"),
         }
-
+        if collection_name in self._config:
+            model_name = self._config[collection_name]
+        if model_name is not None:
+            metadata["embedding_model"] = model_name
+            self._config[collection_name] = model_name
+            model, _, _ = load_embedding_model(model_name)
+            self._client.get_or_create_collection(collection_name, metadata=metadata)
+            return Chroma(
+                client=self._client,
+                collection_name=collection_name,
+                embedding_function=model,
+                collection_metadata=metadata,
+            )
+        else:
+            self._client.get_or_create_collection(collection_name, metadata=metadata)
+            return Chroma(
+                client=self._client,
+                collection_name=collection_name,
+                collection_metadata=metadata,
+            )
 
     def get(self, name: None | EmbeddingType | str = None) -> Sequence[Collection]:
         """Get the Chroma collections.
@@ -61,6 +102,8 @@ class Collections:
             collection_name = name.name.lower()
         else:
             collection_name = name
+        if collection_name in self._config:
+            del self._config[collection_name]
         self._client.delete_collection(collection_name)
 
     def _set_collection_name(self, name: EmbeddingType | str) -> str:
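With these changes a collection can be pinned to an embedding model: the model name is recorded both in the collection metadata and in the Collections config map, and create()/get_or_create() now return a LangChain Chroma vector store rather than a bare chromadb Collection. A hedged sketch of the resulting call pattern (the model name is one of the defaults defined in embedding_models_info.py below):

    import chromadb

    from janus.embedding.collections import Collections
    from janus.utils.enums import EmbeddingType

    collections = Collections(chromadb.Client())

    # Records "text-embedding-3-small" in the collection metadata and in the
    # config map, then returns a Chroma store wired to that embedding function.
    store = collections.create(EmbeddingType.SOURCE, model_name="text-embedding-3-small")

    # Later calls may omit the model name; it is looked up from the config map.
    same_store = collections.get_or_create(EmbeddingType.SOURCE)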
janus/embedding/embedding_models_info.py
ADDED
@@ -0,0 +1,130 @@
+import json
+from pathlib import Path
+from typing import Any, Callable, Dict, Tuple
+
+from aenum import MultiValueEnum
+from dotenv import load_dotenv
+from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings
+from langchain_core.embeddings import Embeddings
+from langchain_openai import OpenAIEmbeddings
+
+from ..utils.logger import create_logger
+
+load_dotenv()
+
+log = create_logger(__name__)
+
+try:
+    from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
+except ImportError:
+    log.warning(
+        "Could not import LangChain's HuggingFace Embeddings Client. If you would like "
+        "to use HuggingFace models, please install LangChain's HuggingFace Embeddings "
+        "Client by running 'pip install janus-embedding[hf-local]' or poetry install "
+        "-E hf-local."
+    )
+
+
+class EmbeddingModelType(MultiValueEnum):
+    OpenAI = "OpenAI", "openai", "open-ai", "oai"
+    HuggingFaceLocal = "HuggingFaceLocal", "huggingfacelocal", "huggingface-local", "hfl"
+    HuggingFaceInferenceAPI = (
+        "HuggingFaceInferenceAPI",
+        "huggingfaceinferenceapi",
+        "huggingface-inference-api",
+        "hfia",
+    )
+
+
+EMBEDDING_MODEL_TYPE_CONSTRUCTORS: Dict[
+    EmbeddingModelType, Callable[[Any], Embeddings]
+] = {}
+
+for model_type in EmbeddingModelType:
+    for value in model_type.values:
+        if model_type == EmbeddingModelType.OpenAI:
+            EMBEDDING_MODEL_TYPE_CONSTRUCTORS[value] = OpenAIEmbeddings
+        elif model_type == EmbeddingModelType.HuggingFaceLocal:
+            try:
+                EMBEDDING_MODEL_TYPE_CONSTRUCTORS[value] = HuggingFaceEmbeddings
+            except NameError:
+                pass
+        elif model_type == EmbeddingModelType.HuggingFaceInferenceAPI:
+            EMBEDDING_MODEL_TYPE_CONSTRUCTORS[value] = HuggingFaceInferenceAPIEmbeddings
+
+EMBEDDING_MODEL_TYPE_DEFAULT_IDS: Dict[EmbeddingModelType, Dict[str, Any]] = {
+    EmbeddingModelType.OpenAI.value: "text-embedding-3-small",
+    EmbeddingModelType.HuggingFaceLocal.value: "all-MiniLM-L6-v2",
+    EmbeddingModelType.HuggingFaceInferenceAPI.value: "",
+}
+
+EMBEDDING_MODEL_DEFAULT_ARGUMENTS: Dict[str, Dict[str, Any]] = {
+    "text-embedding-3-small": dict(model="text-embedding-3-small"),
+    "text-embedding-3-large": dict(model="text-embedding-3-large"),
+    "text-embedding-ada-002": dict(model="text-embedding-ada-002"),
+}
+
+EMBEDDING_MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "embeddings"
+
+EMBEDDING_TOKEN_LIMITS: Dict[str, int] = {
+    "text-embedding-3-small": 8191,
+    "text-embedding-3-large": 8191,
+    "text-embedding-ada-002": 8191,
+}
+
+EMBEDDING_COST_PER_MODEL: Dict[str, float] = {
+    "text-embedding-3-small": {"input": 0.02 / 1e6, "output": 0.0},
+    "text-embedding-3-large": {"input": 0.13 / 1e6, "output": 0.0},
+    "text-embedding-ada-002": {"input": 0.10 / 1e6, "output": 0.0},
+}
+
+
+def load_embedding_model(
+    model_name: str,
+) -> Tuple[Embeddings, int, Dict[str, float]]:
+    """Load an embedding model from the configuration file or create a new one
+
+    Arguments:
+        model_name: The user-given name of the model to load.
+        model_type: The type of the model to load.
+        identifier: The identifier for the model (e.g. the name, URL, or HuggingFace
+            path).
+    """
+    if not EMBEDDING_MODEL_CONFIG_DIR.exists():
+        EMBEDDING_MODEL_CONFIG_DIR.mkdir(parents=True)
+    model_config_file = EMBEDDING_MODEL_CONFIG_DIR / f"{model_name}.json"
+
+    if not model_config_file.exists():
+        # The default model type is HuggingFaceLocal because that's the default for Chroma
+        model_type = EmbeddingModelType.HuggingFaceLocal.value
+        identifier = EMBEDDING_MODEL_TYPE_DEFAULT_IDS[model_type]
+        model_config = {
+            "model_type": model_type,
+            "model_identifier": identifier,
+            "model_args": EMBEDDING_MODEL_DEFAULT_ARGUMENTS.get(identifier, {}),
+            "token_limit": EMBEDDING_TOKEN_LIMITS.get(identifier, 8191),
+            "model_cost": EMBEDDING_COST_PER_MODEL.get(
+                identifier, {"input": 0, "output": 0}
+            ),
+        }
+        log.info(
+            f"WARNING: Creating new model config file: \
+            {model_config_file} with default config"
+        )
+        with open(model_config_file, "w") as f:
+            json.dump(model_config, f, indent=2)
+    else:
+        with open(model_config_file, "r") as f:
+            model_config = json.load(f)
+    model_constructor = EMBEDDING_MODEL_TYPE_CONSTRUCTORS[model_config["model_type"]]
+    model_args = model_config["model_args"]
+    if model_config["model_type"] in EmbeddingModelType.HuggingFaceInferenceAPI.values:
+        model_args.update({"api_url": model_config["model_identifier"]})
+    elif model_config["model_type"] in EmbeddingModelType.HuggingFaceLocal.values:
+        model_args.update({"model_name": model_config["model_identifier"]})
+    model = model_constructor(**model_args)
+    return (
+        model,
+        model_config["token_limit"],
+        model_config["model_cost"],
+    )
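load_embedding_model() is driven by per-model JSON config files: the first lookup of a name writes a config with HuggingFaceLocal defaults under ~/.janus/embeddings/, and later calls read it back and construct the matching LangChain embeddings client. A small sketch (the model name "my-embedder" is a hypothetical user-chosen label, not a HuggingFace ID):

    from janus.embedding.embedding_models_info import load_embedding_model

    # First call creates ~/.janus/embeddings/my-embedder.json with the
    # HuggingFaceLocal default identifier (all-MiniLM-L6-v2); later calls reuse it.
    model, token_limit, cost = load_embedding_model("my-embedder")
    print(token_limit)  # 8191, the fallback limit for models not in EMBEDDING_TOKEN_LIMITS
    print(cost)  # {"input": 0, "output": 0}, the fallback cost for local models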
janus/embedding/vectorize.py
CHANGED
@@ -1,67 +1,52 @@
 import uuid
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Sequence
+from typing import Any, Dict, Optional, Sequence
 
 from chromadb import Client, Collection
+from langchain_community.vectorstores import Chroma
 
-from ..
-from ..language.block import CodeBlock
-from ..llm.models_info import TOKEN_LIMITS
+from ..language.block import CodeBlock, TranslatedCodeBlock
 from ..utils.enums import EmbeddingType
 from .collections import Collections
 from .database import ChromaEmbeddingDatabase
 
 
-class Vectorizer(Converter):
+class Vectorizer(object):
     """Class for creating embeddings/vectors in a specified ChromaDB"""
 
-    def __init__(
-        self,
-        client: Client,
-        source_language: str,
-        max_tokens: None | int,
-        model: None | str,
-    ) -> None:
+    def __init__(self, client: Client, config: Optional[Dict[str, Any]] = None) -> None:
         """Initializes the Vectorizer class
 
         Arguments:
             client: ChromaDB client instance
-            source_language: The source programming language.
-            max_tokens: The maximum number of tokens to send to the embedding model at
-                once. If `None`, the `Vectorizer` will use the default value for the
-                `model`.
-            model: The name of the model to use. This will also determine the `max_tokens`
-                if that variable is not set.
         """
-        if max_tokens is None:
-            max_tokens = TOKEN_LIMITS[model]
-
-        super().__init__(
-            source_language=source_language,
-            max_tokens=max_tokens,
-        )
         self._db = client
-        self._collections = Collections(self._db)
+        self._collections = Collections(self._db, config)
 
-
+    def get_or_create_collection(
+        self, name: EmbeddingType | str, model_name: Optional[str] = None
+    ) -> Chroma:
+        return self._collections.get_or_create(name, model_name=model_name)
 
-    def create_collection(
-
+    def create_collection(
+        self, embedding_type: EmbeddingType, model_name: Optional[str] = None
+    ) -> Chroma:
+        return self._collections.create(embedding_type, model_name=model_name)
 
     def collections(
         self, name: None | EmbeddingType | str = None
     ) -> Sequence[Collection]:
         return self._collections.get(name)
 
-    def
+    def add_nodes_recursively(
         self, code_block: CodeBlock, collection_name: EmbeddingType | str, file_name: str
     ) -> None:
         """Embed all nodes in the tree rooted at `code_block`
 
         Arguments:
             code_block: CodeBlock to embed
-
+            collection_name: Collection to add to
             file_name: Name of file containing `code_block`
         """
         nodes = [code_block]
@@ -74,41 +59,55 @@ class Vectorizer(Converter):
         self,
         code_block: CodeBlock,
         collection_name: EmbeddingType | str,
-
+        filename: str  # perhaps this should be a relative path from the source, but for
        # now we're all in 1 directory
-    ) ->
+    ) -> None:
         """Calculate `code_block` embedding, returning success & storing in `embedding_id`
 
         Arguments:
             code_block: CodeBlock to embed
-
-
-
-        Returns:
-            True if embedding was successful, False otherwise
+            collection_name: Collection to add to
+            filename: Name of file containing `code_block`
         """
         if code_block.text:
             metadatas = [
                 {
-                    "type": code_block.
-                    "
+                    "type": code_block.node_type,
+                    "id": code_block.id,
+                    "name": code_block.name,
+                    "language": code_block.language,
+                    "filename": filename,
                     "tokens": code_block.tokens,
                     "cost": 0,  # TranslatedCodeBlock has cost
                 },
             ]
+            if collection_name in self.config:
+                metadatas[0]["embedding_model"] = self.config[collection_name]
             # for now, dealing with missing metadata by skipping it
+            if isinstance(code_block, TranslatedCodeBlock):
+                self._add(
+                    code_block=code_block.original,
+                    collection_name=collection_name,
+                    filename=filename,
+                )
+                if code_block.original.embedding_id is not None:
+                    metadatas[0][
+                        "original_embedding_id"
+                    ] = code_block.original.embedding_id
+                metadatas[0]["cost"] = code_block.cost
             if code_block.text is not None:
                 metadatas[0]["hash"] = hash(code_block.text)
             if code_block.start_point is not None:
                 metadatas[0]["start_line"] = code_block.start_point[0]
             if code_block.end_point is not None:
                 metadatas[0]["end_line"] = code_block.end_point[0]
+            # TODO: Add metadata about translation parameters (e.g. model)
             the_text = [code_block.text]
-            code_block.embedding_id = self.add_text(
-
-
-
-
+            code_block.embedding_id = self.add_text(
+                collection_name,
+                the_text,
+                metadatas,
+            )[0]
 
     def add_text(
         self,
@@ -121,7 +120,7 @@
         metadatas, returning the embedding id
 
         Arguments:
-
+            collection_name: Collection to add to
             texts: list of texts to store
             metadatas: list of metadatas to store
             ids: list of embedding ids (must match lengh of texts),
@@ -137,20 +136,20 @@
         # based on the text.
         ids = [str(uuid.uuid3(uuid.NAMESPACE_DNS, text)) for text in texts]
         collection = self._collections.get_or_create(collection_name)
-        collection.
+        collection.add_texts(ids=ids, texts=texts, metadatas=metadatas)
         return ids
 
+    @property
+    def config(self):
+        return self._collections._config
+
 
 class VectorizerFactory(ABC):
     """Interface for creating a Vectorizer independent of type of ChromaDB client"""
 
     @abstractmethod
     def create_vectorizer(
-        self,
-        source_language: str,
-        max_tokens: None | int,
-        model: None | str,
-        path: str | Path,
+        self, path: str | Path, config: Dict[str, Any] = {}
     ) -> Vectorizer:
         """Factory method"""
 
@@ -160,19 +159,11 @@ class ChromaDBVectorizer(VectorizerFactory):
 
     def create_vectorizer(
         self,
-        source_language: str = "fortran",
-        max_tokens: None | int = None,
-        model: None | str = "gpt4all",
         path: str | Path = Path.home() / ".janus" / "chroma" / "chroma-data",
+        config: Optional[Dict[str, Any]] = None,
     ) -> Vectorizer:
         """
         Arguments:
-            source_language: The source programming language.
-            max_tokens: The maximum number of tokens to send to the embedding model at
-                once. If `None`, the `Vectorizer` will use the default value for the
-                `model`.
-            model: The name of the model to use. This will also determine the `max_tokens`
-                if that variable is not set.
             path: The path to the ChromaDB. Can be either a string of a URL or path or a
                 Path object
 
@@ -180,4 +171,4 @@
             Vectorizer
         """
         database = ChromaEmbeddingDatabase(path)
-        return Vectorizer(database,
+        return Vectorizer(database, config)
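The net effect is that Vectorizer no longer inherits from Converter: splitting now happens outside (as in the updated test_vectorize.py above), and the vectorizer only embeds and stores blocks. A rough end-to-end sketch under those assumptions (example.f90 is a placeholder path):

    from pathlib import Path

    from janus.embedding.vectorize import ChromaDBVectorizer
    from janus.language.treesitter import TreeSitterSplitter
    from janus.utils.enums import EmbeddingType

    # Uses the default ~/.janus/chroma/chroma-data database path.
    vectorizer = ChromaDBVectorizer().create_vectorizer()
    vectorizer.create_collection(EmbeddingType.SOURCE)

    splitter = TreeSitterSplitter(language="fortran", max_tokens=16_384)
    block = splitter.split(Path("example.f90"))  # placeholder source file
    vectorizer.add_nodes_recursively(block, EmbeddingType.SOURCE, "example.f90")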
File without changes