janus-llm 1.0.0__py3-none-any.whl → 2.0.0__py3-none-any.whl
- janus/__init__.py +9 -1
- janus/__main__.py +4 -0
- janus/_tests/test_cli.py +128 -0
- janus/_tests/test_translate.py +49 -7
- janus/cli.py +530 -46
- janus/converter.py +50 -19
- janus/embedding/_tests/test_collections.py +2 -8
- janus/embedding/_tests/test_database.py +32 -0
- janus/embedding/_tests/test_vectorize.py +9 -4
- janus/embedding/collections.py +49 -6
- janus/embedding/embedding_models_info.py +120 -0
- janus/embedding/vectorize.py +53 -62
- janus/language/_tests/__init__.py +0 -0
- janus/language/_tests/test_combine.py +62 -0
- janus/language/_tests/test_splitter.py +16 -0
- janus/language/binary/_tests/test_binary.py +16 -1
- janus/language/binary/binary.py +10 -3
- janus/language/block.py +31 -30
- janus/language/combine.py +26 -34
- janus/language/mumps/_tests/test_mumps.py +2 -2
- janus/language/mumps/mumps.py +93 -9
- janus/language/naive/__init__.py +4 -0
- janus/language/naive/basic_splitter.py +14 -0
- janus/language/naive/chunk_splitter.py +26 -0
- janus/language/naive/registry.py +13 -0
- janus/language/naive/simple_ast.py +18 -0
- janus/language/naive/tag_splitter.py +61 -0
- janus/language/splitter.py +168 -74
- janus/language/treesitter/_tests/test_treesitter.py +9 -6
- janus/language/treesitter/treesitter.py +37 -13
- janus/llm/model_callbacks.py +177 -0
- janus/llm/models_info.py +134 -70
- janus/metrics/__init__.py +8 -0
- janus/metrics/_tests/__init__.py +0 -0
- janus/metrics/_tests/reference.py +2 -0
- janus/metrics/_tests/target.py +2 -0
- janus/metrics/_tests/test_bleu.py +56 -0
- janus/metrics/_tests/test_chrf.py +67 -0
- janus/metrics/_tests/test_file_pairing.py +59 -0
- janus/metrics/_tests/test_llm.py +91 -0
- janus/metrics/_tests/test_reading.py +28 -0
- janus/metrics/_tests/test_rouge_score.py +65 -0
- janus/metrics/_tests/test_similarity_score.py +23 -0
- janus/metrics/_tests/test_treesitter_metrics.py +110 -0
- janus/metrics/bleu.py +66 -0
- janus/metrics/chrf.py +55 -0
- janus/metrics/cli.py +7 -0
- janus/metrics/complexity_metrics.py +208 -0
- janus/metrics/file_pairing.py +113 -0
- janus/metrics/llm_metrics.py +202 -0
- janus/metrics/metric.py +466 -0
- janus/metrics/reading.py +70 -0
- janus/metrics/rouge_score.py +96 -0
- janus/metrics/similarity.py +53 -0
- janus/metrics/splitting.py +38 -0
- janus/parsers/_tests/__init__.py +0 -0
- janus/parsers/_tests/test_code_parser.py +32 -0
- janus/parsers/code_parser.py +24 -253
- janus/parsers/doc_parser.py +169 -0
- janus/parsers/eval_parser.py +80 -0
- janus/parsers/reqs_parser.py +72 -0
- janus/prompts/prompt.py +103 -30
- janus/translate.py +636 -111
- janus/utils/_tests/__init__.py +0 -0
- janus/utils/_tests/test_logger.py +67 -0
- janus/utils/_tests/test_progress.py +20 -0
- janus/utils/enums.py +56 -3
- janus/utils/progress.py +56 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/METADATA +23 -10
- janus_llm-2.0.0.dist-info/RECORD +94 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/WHEEL +1 -1
- janus_llm-1.0.0.dist-info/RECORD +0 -48
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/LICENSE +0 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/entry_points.txt +0 -0
janus/converter.py
CHANGED
@@ -4,7 +4,6 @@ from typing import Any
 from langchain.schema.language_model import BaseLanguageModel
 
 from .language.binary import BinarySplitter
-from .language.combine import Combiner
 from .language.mumps import MumpsSplitter
 from .language.splitter import Splitter
 from .language.treesitter import TreeSitterSplitter
@@ -45,6 +44,8 @@ class Converter:
         self,
         source_language: str = "fortran",
         max_tokens: None | int = None,
+        protected_node_types: set[str] | list[str] | tuple[str] = (),
+        prune_node_types: set[str] | list[str] | tuple[str] = (),
     ) -> None:
         """Initialize a Converter instance.
 
@@ -59,13 +60,15 @@ class Converter:
 
         self._source_language: None | str
         self._source_glob: None | str
+        self._protected_node_types: tuple[str] = ()
+        self._prune_node_types: tuple[str] = ()
         self._splitter: None | Splitter
         self._llm: None | BaseLanguageModel = None
         self._max_tokens: None | int = max_tokens
 
-        self.
-
-        self.
+        self.set_source_language(source_language)
+        self.set_protected_node_types(protected_node_types)
+        self.set_prune_node_types(prune_node_types)
 
         # Child class must call this. Should we enforce somehow?
         # self._load_parameters()
@@ -86,7 +89,7 @@ class Converter:
     def set_source_language(self, source_language: str) -> None:
         """Validate and set the source language.
 
-        The affected objects will not be updated until
+        The affected objects will not be updated until _load_parameters() is called.
 
         Arguments:
             source_language: The source programming language.
@@ -101,27 +104,55 @@ class Converter:
         self._source_glob = f"**/*.{LANGUAGES[source_language]['suffix']}"
         self._source_language = source_language
 
-
+    def set_protected_node_types(
+        self, protected_node_types: set[str] | list[str] | tuple[str]
+    ) -> None:
+        """Set the protected (non-mergeable) node types. This will often be structures
+        like functions, classes, or modules which you might want to keep separate
+
+        The affected objects will not be updated until _load_parameters() is called.
+
+        Arguments:
+            protected_node_types: A set of node types that aren't to be merged
+        """
+        self._protected_node_types = tuple(set(protected_node_types or []))
+
+    def set_prune_node_types(
+        self, prune_node_types: set[str] | list[str] | tuple[str]
+    ) -> None:
+        """Set the node types to prune. This will often be structures
+        like comments or whitespace which you might want to keep out of the LLM
+
+        The affected objects will not be updated until _load_parameters() is called.
+
+        Arguments:
+            prune_node_types: A set of node types which should be pruned
+        """
+        self._prune_node_types = tuple(set(prune_node_types or []))
+
+    @run_if_changed(
+        "_source_language",
+        "_max_tokens",
+        "_llm",
+        "_protected_node_types",
+        "_prune_node_types",
+    )
     def _load_splitter(self) -> None:
         """Load the splitter according to this instance's attributes.
 
         If the relevant fields have not been changed since the last time this method was
         called, nothing happens.
         """
+        kwargs = dict(
+            max_tokens=self._max_tokens,
+            model=self._llm,
+            protected_node_types=self._protected_node_types,
+            prune_node_types=self._prune_node_types,
+        )
        if self._source_language in CUSTOM_SPLITTERS:
             if self._source_language == "mumps":
-                self._splitter = MumpsSplitter(
-                    max_tokens=self._max_tokens,
-                    model=self._llm,
-                )
+                self._splitter = MumpsSplitter(**kwargs)
             elif self._source_language == "binary":
-                self._splitter = BinarySplitter(
-                    max_tokens=self._max_tokens,
-                    model=self._llm,
-                )
+                self._splitter = BinarySplitter(**kwargs)
             else:
-                self._splitter = TreeSitterSplitter(
-                    language=self._source_language,
-                    max_tokens=self._max_tokens,
-                    model=self._llm,
-                )
+                self._splitter = TreeSitterSplitter(language=self._source_language, **kwargs)
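Usage note: the two new node-type parameters flow through _load_splitter() into every splitter via the shared kwargs dict. A minimal sketch of how a caller might set them (direct instantiation and the node-type names are illustrative; real values are tree-sitter node kinds and depend on the grammar):

from janus.converter import Converter

# Keep functions intact when merging blocks, and strip comments before
# anything is sent to the LLM. Node-type names below are placeholders.
converter = Converter(
    source_language="fortran",
    max_tokens=4096,
    protected_node_types=("function",),
    prune_node_types=("comment",),
)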
janus/embedding/_tests/test_collections.py
CHANGED
@@ -14,9 +14,7 @@ class TestCollections(unittest.TestCase):
         self.collections = Collections(self._db)
 
     def test_creation(self):
-        self.
-
-        result = self.collections.create(EmbeddingType.PSEUDO)
+        self.collections.create(EmbeddingType.PSEUDO)
 
         metadata = {
             "date_updated": datetime.datetime.now().date().isoformat(),
@@ -24,12 +22,9 @@ class TestCollections(unittest.TestCase):
         }
 
         self._db.create_collection.assert_called_with("pseudo_1", metadata=metadata)
-        self.assertEqual(result, "foo")
 
     def test_creation_triangulation(self):
-        self.
-
-        result = self.collections.create(EmbeddingType.REQUIREMENT)
+        self.collections.create(EmbeddingType.REQUIREMENT)
 
         metadata = {
             "date_updated": datetime.datetime.now().date().isoformat(),
@@ -37,7 +32,6 @@ class TestCollections(unittest.TestCase):
         }
 
         self._db.create_collection.assert_called_with("requirement_1", metadata=metadata)
-        self.assertEqual(result, [])
 
     def test_creation_of_existing_type(self):
         mock_collection = MagicMock()
janus/embedding/_tests/test_database.py
ADDED
@@ -0,0 +1,32 @@
+import unittest
+from pathlib import Path
+from unittest.mock import patch
+
+from ..database import ChromaEmbeddingDatabase, uri_to_path
+
+
+class TestDatabase(unittest.TestCase):
+    def test_uri_to_path(self):
+        uri = (Path.home().expanduser() / "Documents" / "testfile.txt").as_uri()
+        expected_path = Path.home().expanduser() / "Documents" / "testfile.txt"
+        self.assertEqual(uri_to_path(uri), expected_path)
+
+    @patch("chromadb.PersistentClient", autospec=True)
+    def test_ChromaEmbeddingDatabase(self, mock_client):
+        # Test with default path
+        _ = ChromaEmbeddingDatabase()
+        mock_client.assert_called_once()
+
+        # Test with custom path
+        custom_path = "/custom/path/to/chroma-data"
+        _ = ChromaEmbeddingDatabase(custom_path)
+        mock_client.assert_called()
+
+        # Test with URL
+        url = "http://example.com/chroma-data"
+        _ = ChromaEmbeddingDatabase(url)
+        mock_client.assert_called()
+
+
+if __name__ == "__main__":
+    unittest.main()
janus/embedding/_tests/test_vectorize.py
CHANGED
@@ -5,6 +5,7 @@ from unittest.mock import MagicMock
 
 from chromadb.api.client import Client
 
+from ...language.treesitter import TreeSitterSplitter
 from ...utils.enums import EmbeddingType
 from ..vectorize import Vectorizer, VectorizerFactory
 
@@ -22,7 +23,7 @@ class MockDBVectorizer(VectorizerFactory):
         model: None | str = "gpt4all",
         path: str | Path = None,
     ) -> Vectorizer:
-        return Vectorizer(self._db
+        return Vectorizer(self._db)
 
 
 class TestVectorize(unittest.TestCase):
@@ -35,7 +36,11 @@ class TestVectorize(unittest.TestCase):
         self.database.list_collections = list_collections
         self.vectorizer = MockDBVectorizer(self.database).create_vectorizer()
         self.test_file = Path("janus/language/treesitter/_tests/languages/fortran.f90")
-
+        splitter = TreeSitterSplitter(
+            language="fortran",
+            max_tokens=16_384,
+        )
+        self.test_block = splitter.split(self.test_file)
 
     def test_add_nodes_recursively(self):
         embedding_type = EmbeddingType.SOURCE
@@ -45,9 +50,9 @@ class TestVectorize(unittest.TestCase):
             "time_updated": datetime.datetime.now().time().isoformat("minutes"),
         }
         self.database.create_collection.assert_called_with("source_1", metadata=metadata)
-        self.vectorizer.
+        self.vectorizer.add_nodes_recursively(
             self.test_block, embedding_type, self.test_file.name
         )
         self.database.get_or_create_collection.assert_called_with(
-            "source_1", metadata=metadata
+            name="source_1", embedding_function=None, metadata=metadata
         )
janus/embedding/collections.py
CHANGED
@@ -1,10 +1,12 @@
 import datetime
 import os
-from typing import Sequence
+from typing import Dict, Optional, Sequence
 
 from chromadb import Client, Collection
+from langchain_community.vectorstores import Chroma
 
 from ..utils.enums import EmbeddingType
+from .embedding_models_info import load_embedding_model
 
 # See https://docs.trychroma.com/telemetry#in-chromas-backend-using-environment-variables
 os.environ["ANONYMIZED_TELEMETRY"] = "False"
@@ -13,10 +15,16 @@ os.environ["ANONYMIZED_TELEMETRY"] = "False"
 class Collections:
     """Manage embedding collections"""
 
-    def __init__(self, client: Client):
+    def __init__(self, client: Client, config: Optional[Dict[str, str]] = None):
         self._client = client
+        if config is not None:
+            self._config = config
+        else:
+            self._config = {}
 
-    def create(
+    def create(
+        self, name: EmbeddingType | str, model_name: Optional[str] = None
+    ) -> Chroma:
         """Create a Chroma collection for the given embedding type.
 
         Arguments:
@@ -27,9 +35,23 @@ class Collections:
             "date_updated": datetime.datetime.now().date().isoformat(),
             "time_updated": datetime.datetime.now().time().isoformat("minutes"),
         }
-
+        if model_name is not None:
+            metadata["embedding_model"] = model_name
+            self._client.create_collection(collection_name, metadata=metadata)
+            self._config[collection_name] = model_name
+            model, _, _ = load_embedding_model(model_name)
+            return Chroma(
+                client=self._client,
+                collection_name=collection_name,
+                embedding_function=model,
+            )
+        else:
+            self._client.create_collection(collection_name, metadata=metadata)
+            return Chroma(client=self._client, collection_name=collection_name)
 
-    def get_or_create(
+    def get_or_create(
+        self, name: EmbeddingType | str, model_name: Optional[str] = None
+    ) -> Chroma:
         """Create a Chroma collection for the given embedding type.
 
         Arguments:
@@ -40,7 +62,26 @@ class Collections:
             "date_updated": datetime.datetime.now().date().isoformat(),
             "time_updated": datetime.datetime.now().time().isoformat("minutes"),
         }
-
+        if collection_name in self._config:
+            model_name = self._config[collection_name]
+        if model_name is not None:
+            metadata["embedding_model"] = model_name
+            self._config[collection_name] = model_name
+            model, _, _ = load_embedding_model(model_name)
+            self._client.get_or_create_collection(collection_name, metadata=metadata)
+            return Chroma(
+                client=self._client,
+                collection_name=collection_name,
+                embedding_function=model,
+                collection_metadata=metadata,
+            )
+        else:
+            self._client.get_or_create_collection(collection_name, metadata=metadata)
+            return Chroma(
+                client=self._client,
+                collection_name=collection_name,
+                collection_metadata=metadata,
+            )
 
     def get(self, name: None | EmbeddingType | str = None) -> Sequence[Collection]:
         """Get the Chroma collections.
@@ -61,6 +102,8 @@ class Collections:
             collection_name = name.name.lower()
         else:
             collection_name = name
+        if collection_name in self._config:
+            del self._config[collection_name]
         self._client.delete_collection(collection_name)
 
     def _set_collection_name(self, name: EmbeddingType | str) -> str:
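With the new model_name parameter, create() and get_or_create() now hand back a langchain Chroma vector store wired to the chosen embedding function, and remember the model per collection in the shared config. A sketch of the intended call pattern (client setup and the model name are illustrative; if no config file exists for that name, load_embedding_model falls back to a local HuggingFace default under it):

import chromadb

from janus.embedding.collections import Collections
from janus.utils.enums import EmbeddingType

client = chromadb.PersistentClient()  # any chromadb Client works here
collections = Collections(client)

# Returns a langchain Chroma store backed by the named embedding model;
# the choice is recorded in the config so get_or_create reuses it later.
store = collections.create(EmbeddingType.SOURCE, model_name="text-embedding-3-small")
store = collections.get_or_create(EmbeddingType.SOURCE)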
janus/embedding/embedding_models_info.py
ADDED
@@ -0,0 +1,120 @@
+import json
+from pathlib import Path
+from typing import Any, Callable, Dict, Tuple
+
+from aenum import MultiValueEnum
+from dotenv import load_dotenv
+from langchain_community.embeddings.huggingface import (
+    HuggingFaceEmbeddings,
+    HuggingFaceInferenceAPIEmbeddings,
+)
+from langchain_core.embeddings import Embeddings
+from langchain_openai import OpenAIEmbeddings
+
+from janus.utils.logger import create_logger
+
+load_dotenv()
+
+log = create_logger(__name__)
+
+
+class EmbeddingModelType(MultiValueEnum):
+    OpenAI = "OpenAI", "openai", "open-ai", "oai"
+    HuggingFaceLocal = "HuggingFaceLocal", "huggingfacelocal", "huggingface-local", "hfl"
+    HuggingFaceInferenceAPI = (
+        "HuggingFaceInferenceAPI",
+        "huggingfaceinferenceapi",
+        "huggingface-inference-api",
+        "hfia",
+    )
+
+
+EMBEDDING_MODEL_TYPE_CONSTRUCTORS: Dict[
+    EmbeddingModelType, Callable[[Any], Embeddings]
+] = {}
+
+for model_type in EmbeddingModelType:
+    for value in model_type.values:
+        if model_type == EmbeddingModelType.OpenAI:
+            EMBEDDING_MODEL_TYPE_CONSTRUCTORS[value] = OpenAIEmbeddings
+        elif model_type == EmbeddingModelType.HuggingFaceLocal:
+            EMBEDDING_MODEL_TYPE_CONSTRUCTORS[value] = HuggingFaceEmbeddings
+        elif model_type == EmbeddingModelType.HuggingFaceInferenceAPI:
+            EMBEDDING_MODEL_TYPE_CONSTRUCTORS[value] = HuggingFaceInferenceAPIEmbeddings
+
+EMBEDDING_MODEL_TYPE_DEFAULT_IDS: Dict[EmbeddingModelType, Dict[str, Any]] = {
+    EmbeddingModelType.OpenAI.value: "text-embedding-3-small",
+    EmbeddingModelType.HuggingFaceLocal.value: "all-MiniLM-L6-v2",
+    EmbeddingModelType.HuggingFaceInferenceAPI.value: "",
+}
+
+EMBEDDING_MODEL_DEFAULT_ARGUMENTS: Dict[str, Dict[str, Any]] = {
+    "text-embedding-3-small": dict(model="text-embedding-3-small"),
+    "text-embedding-3-large": dict(model="text-embedding-3-large"),
+    "text-embedding-ada-002": dict(model="text-embedding-ada-002"),
+}
+
+EMBEDDING_MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "embeddings"
+
+EMBEDDING_TOKEN_LIMITS: Dict[str, int] = {
+    "text-embedding-3-small": 8191,
+    "text-embedding-3-large": 8191,
+    "text-embedding-ada-002": 8191,
+}
+
+EMBEDDING_COST_PER_MODEL: Dict[str, float] = {
+    "text-embedding-3-small": {"input": 0.02 / 1e6, "output": 0.0},
+    "text-embedding-3-large": {"input": 0.13 / 1e6, "output": 0.0},
+    "text-embedding-ada-002": {"input": 0.10 / 1e6, "output": 0.0},
+}
+
+
+def load_embedding_model(
+    model_name: str,
+) -> Tuple[Embeddings, int, Dict[str, float]]:
+    """Load an embedding model from the configuration file or create a new one
+
+    Arguments:
+        model_name: The user-given name of the model to load.
+        model_type: The type of the model to load.
+        identifier: The identifier for the model (e.g. the name, URL, or HuggingFace
+            path).
+    """
+    if not EMBEDDING_MODEL_CONFIG_DIR.exists():
+        EMBEDDING_MODEL_CONFIG_DIR.mkdir(parents=True)
+    model_config_file = EMBEDDING_MODEL_CONFIG_DIR / f"{model_name}.json"
+
+    if not model_config_file.exists():
+        # The default model type is HuggingFaceLocal because that's the default for Chroma
+        model_type = EmbeddingModelType.HuggingFaceLocal.value
+        identifier = EMBEDDING_MODEL_TYPE_DEFAULT_IDS[model_type]
+        model_config = {
+            "model_type": model_type,
+            "model_identifier": identifier,
+            "model_args": EMBEDDING_MODEL_DEFAULT_ARGUMENTS.get(identifier, {}),
+            "token_limit": EMBEDDING_TOKEN_LIMITS.get(identifier, 8191),
+            "model_cost": EMBEDDING_COST_PER_MODEL.get(
+                identifier, {"input": 0, "output": 0}
+            ),
+        }
+        log.info(
+            f"WARNING: Creating new model config file: \
+            {model_config_file} with default config"
+        )
+        with open(model_config_file, "w") as f:
+            json.dump(model_config, f, indent=2)
+    else:
+        with open(model_config_file, "r") as f:
+            model_config = json.load(f)
+    model_constructor = EMBEDDING_MODEL_TYPE_CONSTRUCTORS[model_config["model_type"]]
+    model_args = model_config["model_args"]
+    if model_config["model_type"] in EmbeddingModelType.HuggingFaceInferenceAPI.values:
+        model_args.update({"api_url": model_config["model_identifier"]})
+    elif model_config["model_type"] in EmbeddingModelType.HuggingFaceLocal.values:
+        model_args.update({"model_name": model_config["model_identifier"]})
+    model = model_constructor(**model_args)
+    return (
+        model,
+        model_config["token_limit"],
+        model_config["model_cost"],
+    )
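load_embedding_model resolves a user-chosen name to a JSON config under ~/.janus/embeddings/, creating a default HuggingFaceLocal entry on first use. A sketch of seeding a config by hand for an OpenAI model (the name "my-openai-embedder" is hypothetical, and OPENAI_API_KEY must be set for the returned model to work):

import json

from janus.embedding.embedding_models_info import (
    EMBEDDING_MODEL_CONFIG_DIR,
    load_embedding_model,
)

# Write a config keyed by a user-chosen name; "model_type" must be one of
# the EmbeddingModelType values, and "model_args" is passed to the constructor.
EMBEDDING_MODEL_CONFIG_DIR.mkdir(parents=True, exist_ok=True)
config = {
    "model_type": "OpenAI",
    "model_identifier": "text-embedding-3-small",
    "model_args": {"model": "text-embedding-3-small"},
    "token_limit": 8191,
    "model_cost": {"input": 0.02 / 1e6, "output": 0.0},
}
with open(EMBEDDING_MODEL_CONFIG_DIR / "my-openai-embedder.json", "w") as f:
    json.dump(config, f, indent=2)

model, token_limit, cost = load_embedding_model("my-openai-embedder")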
janus/embedding/vectorize.py
CHANGED
@@ -1,67 +1,52 @@
 import uuid
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Sequence
+from typing import Any, Dict, Optional, Sequence
 
 from chromadb import Client, Collection
+from langchain_community.vectorstores import Chroma
 
-from ..
-from ..language.block import CodeBlock
-from ..llm.models_info import TOKEN_LIMITS
+from ..language.block import CodeBlock, TranslatedCodeBlock
 from ..utils.enums import EmbeddingType
 from .collections import Collections
 from .database import ChromaEmbeddingDatabase
 
 
-class Vectorizer(
+class Vectorizer(object):
     """Class for creating embeddings/vectors in a specified ChromaDB"""
 
-    def __init__(
-        self,
-        client: Client,
-        source_language: str,
-        max_tokens: None | int,
-        model: None | str,
-    ) -> None:
+    def __init__(self, client: Client, config: Optional[Dict[str, Any]] = None) -> None:
         """Initializes the Vectorizer class
 
         Arguments:
             client: ChromaDB client instance
-            source_language: The source programming language.
-            max_tokens: The maximum number of tokens to send to the embedding model at
-                once. If `None`, the `Vectorizer` will use the default value for the
-                `model`.
-            model: The name of the model to use. This will also determine the `max_tokens`
-                if that variable is not set.
         """
-        if max_tokens is None:
-            max_tokens = TOKEN_LIMITS[model]
-
-        super().__init__(
-            source_language=source_language,
-            max_tokens=max_tokens,
-        )
         self._db = client
-        self._collections = Collections(self._db)
+        self._collections = Collections(self._db, config)
 
-
+    def get_or_create_collection(
+        self, name: EmbeddingType | str, model_name: Optional[str] = None
+    ) -> Chroma:
+        return self._collections.get_or_create(name, model_name=model_name)
 
-    def create_collection(
-
+    def create_collection(
+        self, embedding_type: EmbeddingType, model_name: Optional[str] = None
+    ) -> Chroma:
+        return self._collections.create(embedding_type, model_name=model_name)
 
     def collections(
         self, name: None | EmbeddingType | str = None
     ) -> Sequence[Collection]:
         return self._collections.get(name)
 
-    def
+    def add_nodes_recursively(
         self, code_block: CodeBlock, collection_name: EmbeddingType | str, file_name: str
     ) -> None:
         """Embed all nodes in the tree rooted at `code_block`
 
         Arguments:
             code_block: CodeBlock to embed
-
+            collection_name: Collection to add to
             file_name: Name of file containing `code_block`
         """
         nodes = [code_block]
@@ -74,41 +59,55 @@ class Vectorizer(Converter):
         self,
         code_block: CodeBlock,
         collection_name: EmbeddingType | str,
-
+        filename: str  # perhaps this should be a relative path from the source, but for
         # now we're all in 1 directory
-    ) ->
+    ) -> None:
         """Calculate `code_block` embedding, returning success & storing in `embedding_id`
 
         Arguments:
             code_block: CodeBlock to embed
-
-
-
-        Returns:
-            True if embedding was successful, False otherwise
+            collection_name: Collection to add to
+            filename: Name of file containing `code_block`
         """
         if code_block.text:
             metadatas = [
                 {
-                    "type": code_block.
-                    "
+                    "type": code_block.node_type,
+                    "id": code_block.id,
+                    "name": code_block.name,
+                    "language": code_block.language,
+                    "filename": filename,
                     "tokens": code_block.tokens,
                     "cost": 0,  # TranslatedCodeBlock has cost
                 },
             ]
+            if collection_name in self.config:
+                metadatas[0]["embedding_model"] = self.config[collection_name]
             # for now, dealing with missing metadata by skipping it
+            if isinstance(code_block, TranslatedCodeBlock):
+                self._add(
+                    code_block=code_block.original,
+                    collection_name=collection_name,
+                    filename=filename,
+                )
+                if code_block.original.embedding_id is not None:
+                    metadatas[0][
+                        "original_embedding_id"
+                    ] = code_block.original.embedding_id
+                metadatas[0]["cost"] = code_block.cost
             if code_block.text is not None:
                 metadatas[0]["hash"] = hash(code_block.text)
             if code_block.start_point is not None:
                 metadatas[0]["start_line"] = code_block.start_point[0]
             if code_block.end_point is not None:
                 metadatas[0]["end_line"] = code_block.end_point[0]
+            # TODO: Add metadata about translation parameters (e.g. model)
             the_text = [code_block.text]
-            code_block.embedding_id = self.add_text(
-
-
-
-
+            code_block.embedding_id = self.add_text(
+                collection_name,
+                the_text,
+                metadatas,
+            )[0]
 
     def add_text(
         self,
@@ -121,7 +120,7 @@ class Vectorizer(Converter):
         metadatas, returning the embedding id
 
         Arguments:
-
+            collection_name: Collection to add to
             texts: list of texts to store
             metadatas: list of metadatas to store
             ids: list of embedding ids (must match lengh of texts),
@@ -137,20 +136,20 @@ class Vectorizer(Converter):
         # based on the text.
         ids = [str(uuid.uuid3(uuid.NAMESPACE_DNS, text)) for text in texts]
         collection = self._collections.get_or_create(collection_name)
-        collection.
+        collection.add_texts(ids=ids, texts=texts, metadatas=metadatas)
         return ids
 
+    @property
+    def config(self):
+        return self._collections._config
+
 
 class VectorizerFactory(ABC):
     """Interface for creating a Vectorizer independent of type of ChromaDB client"""
 
     @abstractmethod
     def create_vectorizer(
-        self,
-        source_language: str,
-        max_tokens: None | int,
-        model: None | str,
-        path: str | Path,
+        self, path: str | Path, config: Dict[str, Any] = {}
     ) -> Vectorizer:
         """Factory method"""
 
@@ -160,19 +159,11 @@ class ChromaDBVectorizer(VectorizerFactory):
 
     def create_vectorizer(
         self,
-        source_language: str = "fortran",
-        max_tokens: None | int = None,
-        model: None | str = "gpt4all",
         path: str | Path = Path.home() / ".janus" / "chroma" / "chroma-data",
+        config: Optional[Dict[str, Any]] = None,
     ) -> Vectorizer:
         """
         Arguments:
-            source_language: The source programming language.
-            max_tokens: The maximum number of tokens to send to the embedding model at
-                once. If `None`, the `Vectorizer` will use the default value for the
-                `model`.
-            model: The name of the model to use. This will also determine the `max_tokens`
-                if that variable is not set.
             path: The path to the ChromaDB. Can be either a string of a URL or path or a
                 Path object
 
@@ -180,4 +171,4 @@ class ChromaDBVectorizer(VectorizerFactory):
         Vectorizer
         """
         database = ChromaEmbeddingDatabase(path)
-        return Vectorizer(database,
+        return Vectorizer(database, config)
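Taken together, the slimmed-down Vectorizer no longer inherits from Converter: splitting now happens outside it, mirroring the updated tests. A sketch of the new end-to-end flow (the Fortran file path is a placeholder):

from pathlib import Path

from janus.embedding.vectorize import ChromaDBVectorizer
from janus.language.treesitter import TreeSitterSplitter
from janus.utils.enums import EmbeddingType

# Split the source file first, then hand the block tree to the Vectorizer.
splitter = TreeSitterSplitter(language="fortran", max_tokens=16_384)
block = splitter.split(Path("example.f90"))  # placeholder path

vectorizer = ChromaDBVectorizer().create_vectorizer(
    path=Path.home() / ".janus" / "chroma" / "chroma-data",
)
vectorizer.create_collection(EmbeddingType.SOURCE)
vectorizer.add_nodes_recursively(block, EmbeddingType.SOURCE, "example.f90")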