janus-llm 1.0.0__py3-none-any.whl → 2.0.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (74) hide show
  1. janus/__init__.py +9 -1
  2. janus/__main__.py +4 -0
  3. janus/_tests/test_cli.py +128 -0
  4. janus/_tests/test_translate.py +49 -7
  5. janus/cli.py +530 -46
  6. janus/converter.py +50 -19
  7. janus/embedding/_tests/test_collections.py +2 -8
  8. janus/embedding/_tests/test_database.py +32 -0
  9. janus/embedding/_tests/test_vectorize.py +9 -4
  10. janus/embedding/collections.py +49 -6
  11. janus/embedding/embedding_models_info.py +120 -0
  12. janus/embedding/vectorize.py +53 -62
  13. janus/language/_tests/__init__.py +0 -0
  14. janus/language/_tests/test_combine.py +62 -0
  15. janus/language/_tests/test_splitter.py +16 -0
  16. janus/language/binary/_tests/test_binary.py +16 -1
  17. janus/language/binary/binary.py +10 -3
  18. janus/language/block.py +31 -30
  19. janus/language/combine.py +26 -34
  20. janus/language/mumps/_tests/test_mumps.py +2 -2
  21. janus/language/mumps/mumps.py +93 -9
  22. janus/language/naive/__init__.py +4 -0
  23. janus/language/naive/basic_splitter.py +14 -0
  24. janus/language/naive/chunk_splitter.py +26 -0
  25. janus/language/naive/registry.py +13 -0
  26. janus/language/naive/simple_ast.py +18 -0
  27. janus/language/naive/tag_splitter.py +61 -0
  28. janus/language/splitter.py +168 -74
  29. janus/language/treesitter/_tests/test_treesitter.py +9 -6
  30. janus/language/treesitter/treesitter.py +37 -13
  31. janus/llm/model_callbacks.py +177 -0
  32. janus/llm/models_info.py +134 -70
  33. janus/metrics/__init__.py +8 -0
  34. janus/metrics/_tests/__init__.py +0 -0
  35. janus/metrics/_tests/reference.py +2 -0
  36. janus/metrics/_tests/target.py +2 -0
  37. janus/metrics/_tests/test_bleu.py +56 -0
  38. janus/metrics/_tests/test_chrf.py +67 -0
  39. janus/metrics/_tests/test_file_pairing.py +59 -0
  40. janus/metrics/_tests/test_llm.py +91 -0
  41. janus/metrics/_tests/test_reading.py +28 -0
  42. janus/metrics/_tests/test_rouge_score.py +65 -0
  43. janus/metrics/_tests/test_similarity_score.py +23 -0
  44. janus/metrics/_tests/test_treesitter_metrics.py +110 -0
  45. janus/metrics/bleu.py +66 -0
  46. janus/metrics/chrf.py +55 -0
  47. janus/metrics/cli.py +7 -0
  48. janus/metrics/complexity_metrics.py +208 -0
  49. janus/metrics/file_pairing.py +113 -0
  50. janus/metrics/llm_metrics.py +202 -0
  51. janus/metrics/metric.py +466 -0
  52. janus/metrics/reading.py +70 -0
  53. janus/metrics/rouge_score.py +96 -0
  54. janus/metrics/similarity.py +53 -0
  55. janus/metrics/splitting.py +38 -0
  56. janus/parsers/_tests/__init__.py +0 -0
  57. janus/parsers/_tests/test_code_parser.py +32 -0
  58. janus/parsers/code_parser.py +24 -253
  59. janus/parsers/doc_parser.py +169 -0
  60. janus/parsers/eval_parser.py +80 -0
  61. janus/parsers/reqs_parser.py +72 -0
  62. janus/prompts/prompt.py +103 -30
  63. janus/translate.py +636 -111
  64. janus/utils/_tests/__init__.py +0 -0
  65. janus/utils/_tests/test_logger.py +67 -0
  66. janus/utils/_tests/test_progress.py +20 -0
  67. janus/utils/enums.py +56 -3
  68. janus/utils/progress.py +56 -0
  69. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/METADATA +23 -10
  70. janus_llm-2.0.0.dist-info/RECORD +94 -0
  71. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/WHEEL +1 -1
  72. janus_llm-1.0.0.dist-info/RECORD +0 -48
  73. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/LICENSE +0 -0
  74. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/entry_points.txt +0 -0
janus/converter.py CHANGED
@@ -4,7 +4,6 @@ from typing import Any
4
4
  from langchain.schema.language_model import BaseLanguageModel
5
5
 
6
6
  from .language.binary import BinarySplitter
7
- from .language.combine import Combiner
8
7
  from .language.mumps import MumpsSplitter
9
8
  from .language.splitter import Splitter
10
9
  from .language.treesitter import TreeSitterSplitter
@@ -45,6 +44,8 @@ class Converter:
45
44
  self,
46
45
  source_language: str = "fortran",
47
46
  max_tokens: None | int = None,
47
+ protected_node_types: set[str] | list[str] | tuple[str] = (),
48
+ prune_node_types: set[str] | list[str] | tuple[str] = (),
48
49
  ) -> None:
49
50
  """Initialize a Converter instance.
50
51
 
@@ -59,13 +60,15 @@ class Converter:
59
60
 
60
61
  self._source_language: None | str
61
62
  self._source_glob: None | str
63
+ self._protected_node_types: tuple[str] = ()
64
+ self._prune_node_types: tuple[str] = ()
62
65
  self._splitter: None | Splitter
63
66
  self._llm: None | BaseLanguageModel = None
64
67
  self._max_tokens: None | int = max_tokens
65
68
 
66
- self._combiner: Combiner = Combiner()
67
-
68
- self.set_source_language(source_language=source_language)
69
+ self.set_source_language(source_language)
70
+ self.set_protected_node_types(protected_node_types)
71
+ self.set_prune_node_types(prune_node_types)
69
72
 
70
73
  # Child class must call this. Should we enforce somehow?
71
74
  # self._load_parameters()
@@ -86,7 +89,7 @@ class Converter:
86
89
  def set_source_language(self, source_language: str) -> None:
87
90
  """Validate and set the source language.
88
91
 
89
- The affected objects will not be updated until translate() is called.
92
+ The affected objects will not be updated until _load_parameters() is called.
90
93
 
91
94
  Arguments:
92
95
  source_language: The source programming language.
@@ -101,27 +104,55 @@ class Converter:
101
104
  self._source_glob = f"**/*.{LANGUAGES[source_language]['suffix']}"
102
105
  self._source_language = source_language
103
106
 
104
- @run_if_changed("_source_language", "_max_tokens", "_llm")
107
+ def set_protected_node_types(
108
+ self, protected_node_types: set[str] | list[str] | tuple[str]
109
+ ) -> None:
110
+ """Set the protected (non-mergeable) node types. This will often be structures
111
+ like functions, classes, or modules which you might want to keep separate
112
+
113
+ The affected objects will not be updated until _load_parameters() is called.
114
+
115
+ Arguments:
116
+ protected_node_types: A set of node types that aren't to be merged
117
+ """
118
+ self._protected_node_types = tuple(set(protected_node_types or []))
119
+
120
+ def set_prune_node_types(
121
+ self, prune_node_types: set[str] | list[str] | tuple[str]
122
+ ) -> None:
123
+ """Set the node types to prune. This will often be structures
124
+ like comments or whitespace which you might want to keep out of the LLM
125
+
126
+ The affected objects will not be updated until _load_parameters() is called.
127
+
128
+ Arguments:
129
+ prune_node_types: A set of node types which should be pruned
130
+ """
131
+ self._prune_node_types = tuple(set(prune_node_types or []))
132
+
133
+ @run_if_changed(
134
+ "_source_language",
135
+ "_max_tokens",
136
+ "_llm",
137
+ "_protected_node_types",
138
+ "_prune_node_types",
139
+ )
105
140
  def _load_splitter(self) -> None:
106
141
  """Load the splitter according to this instance's attributes.
107
142
 
108
143
  If the relevant fields have not been changed since the last time this method was
109
144
  called, nothing happens.
110
145
  """
146
+ kwargs = dict(
147
+ max_tokens=self._max_tokens,
148
+ model=self._llm,
149
+ protected_node_types=self._protected_node_types,
150
+ prune_node_types=self._prune_node_types,
151
+ )
111
152
  if self._source_language in CUSTOM_SPLITTERS:
112
153
  if self._source_language == "mumps":
113
- self._splitter = MumpsSplitter(
114
- max_tokens=self._max_tokens,
115
- model=self._llm,
116
- )
154
+ self._splitter = MumpsSplitter(**kwargs)
117
155
  elif self._source_language == "binary":
118
- self._splitter = BinarySplitter(
119
- max_tokens=self._max_tokens,
120
- model=self._llm,
121
- )
156
+ self._splitter = BinarySplitter(**kwargs)
122
157
  else:
123
- self._splitter = TreeSitterSplitter(
124
- language=self._source_language,
125
- max_tokens=self._max_tokens,
126
- model=self._llm,
127
- )
158
+ self._splitter = TreeSitterSplitter(language=self._source_language, **kwargs)
@@ -14,9 +14,7 @@ class TestCollections(unittest.TestCase):
14
14
  self.collections = Collections(self._db)
15
15
 
16
16
  def test_creation(self):
17
- self._db.create_collection.return_value = "foo"
18
-
19
- result = self.collections.create(EmbeddingType.PSEUDO)
17
+ self.collections.create(EmbeddingType.PSEUDO)
20
18
 
21
19
  metadata = {
22
20
  "date_updated": datetime.datetime.now().date().isoformat(),
@@ -24,12 +22,9 @@ class TestCollections(unittest.TestCase):
24
22
  }
25
23
 
26
24
  self._db.create_collection.assert_called_with("pseudo_1", metadata=metadata)
27
- self.assertEqual(result, "foo")
28
25
 
29
26
  def test_creation_triangulation(self):
30
- self._db.create_collection.return_value = []
31
-
32
- result = self.collections.create(EmbeddingType.REQUIREMENT)
27
+ self.collections.create(EmbeddingType.REQUIREMENT)
33
28
 
34
29
  metadata = {
35
30
  "date_updated": datetime.datetime.now().date().isoformat(),
@@ -37,7 +32,6 @@ class TestCollections(unittest.TestCase):
37
32
  }
38
33
 
39
34
  self._db.create_collection.assert_called_with("requirement_1", metadata=metadata)
40
- self.assertEqual(result, [])
41
35
 
42
36
  def test_creation_of_existing_type(self):
43
37
  mock_collection = MagicMock()
@@ -0,0 +1,32 @@
1
+ import unittest
2
+ from pathlib import Path
3
+ from unittest.mock import patch
4
+
5
+ from ..database import ChromaEmbeddingDatabase, uri_to_path
6
+
7
+
8
+ class TestDatabase(unittest.TestCase):
9
+ def test_uri_to_path(self):
10
+ uri = (Path.home().expanduser() / "Documents" / "testfile.txt").as_uri()
11
+ expected_path = Path.home().expanduser() / "Documents" / "testfile.txt"
12
+ self.assertEqual(uri_to_path(uri), expected_path)
13
+
14
+ @patch("chromadb.PersistentClient", autospec=True)
15
+ def test_ChromaEmbeddingDatabase(self, mock_client):
16
+ # Test with default path
17
+ _ = ChromaEmbeddingDatabase()
18
+ mock_client.assert_called_once()
19
+
20
+ # Test with custom path
21
+ custom_path = "/custom/path/to/chroma-data"
22
+ _ = ChromaEmbeddingDatabase(custom_path)
23
+ mock_client.assert_called()
24
+
25
+ # Test with URL
26
+ url = "http://example.com/chroma-data"
27
+ _ = ChromaEmbeddingDatabase(url)
28
+ mock_client.assert_called()
29
+
30
+
31
+ if __name__ == "__main__":
32
+ unittest.main()
@@ -5,6 +5,7 @@ from unittest.mock import MagicMock
5
5
 
6
6
  from chromadb.api.client import Client
7
7
 
8
+ from ...language.treesitter import TreeSitterSplitter
8
9
  from ...utils.enums import EmbeddingType
9
10
  from ..vectorize import Vectorizer, VectorizerFactory
10
11
 
@@ -22,7 +23,7 @@ class MockDBVectorizer(VectorizerFactory):
22
23
  model: None | str = "gpt4all",
23
24
  path: str | Path = None,
24
25
  ) -> Vectorizer:
25
- return Vectorizer(self._db, source_language, max_tokens, model)
26
+ return Vectorizer(self._db)
26
27
 
27
28
 
28
29
  class TestVectorize(unittest.TestCase):
@@ -35,7 +36,11 @@ class TestVectorize(unittest.TestCase):
35
36
  self.database.list_collections = list_collections
36
37
  self.vectorizer = MockDBVectorizer(self.database).create_vectorizer()
37
38
  self.test_file = Path("janus/language/treesitter/_tests/languages/fortran.f90")
38
- self.test_block = self.vectorizer._splitter.split(self.test_file)
39
+ splitter = TreeSitterSplitter(
40
+ language="fortran",
41
+ max_tokens=16_384,
42
+ )
43
+ self.test_block = splitter.split(self.test_file)
39
44
 
40
45
  def test_add_nodes_recursively(self):
41
46
  embedding_type = EmbeddingType.SOURCE
@@ -45,9 +50,9 @@ class TestVectorize(unittest.TestCase):
45
50
  "time_updated": datetime.datetime.now().time().isoformat("minutes"),
46
51
  }
47
52
  self.database.create_collection.assert_called_with("source_1", metadata=metadata)
48
- self.vectorizer._add_nodes_recursively(
53
+ self.vectorizer.add_nodes_recursively(
49
54
  self.test_block, embedding_type, self.test_file.name
50
55
  )
51
56
  self.database.get_or_create_collection.assert_called_with(
52
- "source_1", metadata=metadata
57
+ name="source_1", embedding_function=None, metadata=metadata
53
58
  )
@@ -1,10 +1,12 @@
1
1
  import datetime
2
2
  import os
3
- from typing import Sequence
3
+ from typing import Dict, Optional, Sequence
4
4
 
5
5
  from chromadb import Client, Collection
6
+ from langchain_community.vectorstores import Chroma
6
7
 
7
8
  from ..utils.enums import EmbeddingType
9
+ from .embedding_models_info import load_embedding_model
8
10
 
9
11
  # See https://docs.trychroma.com/telemetry#in-chromas-backend-using-environment-variables
10
12
  os.environ["ANONYMIZED_TELEMETRY"] = "False"
@@ -13,10 +15,16 @@ os.environ["ANONYMIZED_TELEMETRY"] = "False"
13
15
  class Collections:
14
16
  """Manage embedding collections"""
15
17
 
16
- def __init__(self, client: Client):
18
+ def __init__(self, client: Client, config: Optional[Dict[str, str]] = None):
17
19
  self._client = client
20
+ if config is not None:
21
+ self._config = config
22
+ else:
23
+ self._config = {}
18
24
 
19
- def create(self, name: EmbeddingType | str) -> Collection:
25
+ def create(
26
+ self, name: EmbeddingType | str, model_name: Optional[str] = None
27
+ ) -> Chroma:
20
28
  """Create a Chroma collection for the given embedding type.
21
29
 
22
30
  Arguments:
@@ -27,9 +35,23 @@ class Collections:
27
35
  "date_updated": datetime.datetime.now().date().isoformat(),
28
36
  "time_updated": datetime.datetime.now().time().isoformat("minutes"),
29
37
  }
30
- return self._client.create_collection(collection_name, metadata=metadata)
38
+ if model_name is not None:
39
+ metadata["embedding_model"] = model_name
40
+ self._client.create_collection(collection_name, metadata=metadata)
41
+ self._config[collection_name] = model_name
42
+ model, _, _ = load_embedding_model(model_name)
43
+ return Chroma(
44
+ client=self._client,
45
+ collection_name=collection_name,
46
+ embedding_function=model,
47
+ )
48
+ else:
49
+ self._client.create_collection(collection_name, metadata=metadata)
50
+ return Chroma(client=self._client, collection_name=collection_name)
31
51
 
32
- def get_or_create(self, name: EmbeddingType | str) -> Collection:
52
+ def get_or_create(
53
+ self, name: EmbeddingType | str, model_name: Optional[str] = None
54
+ ) -> Chroma:
33
55
  """Create a Chroma collection for the given embedding type.
34
56
 
35
57
  Arguments:
@@ -40,7 +62,26 @@ class Collections:
40
62
  "date_updated": datetime.datetime.now().date().isoformat(),
41
63
  "time_updated": datetime.datetime.now().time().isoformat("minutes"),
42
64
  }
43
- return self._client.get_or_create_collection(collection_name, metadata=metadata)
65
+ if collection_name in self._config:
66
+ model_name = self._config[collection_name]
67
+ if model_name is not None:
68
+ metadata["embedding_model"] = model_name
69
+ self._config[collection_name] = model_name
70
+ model, _, _ = load_embedding_model(model_name)
71
+ self._client.get_or_create_collection(collection_name, metadata=metadata)
72
+ return Chroma(
73
+ client=self._client,
74
+ collection_name=collection_name,
75
+ embedding_function=model,
76
+ collection_metadata=metadata,
77
+ )
78
+ else:
79
+ self._client.get_or_create_collection(collection_name, metadata=metadata)
80
+ return Chroma(
81
+ client=self._client,
82
+ collection_name=collection_name,
83
+ collection_metadata=metadata,
84
+ )
44
85
 
45
86
  def get(self, name: None | EmbeddingType | str = None) -> Sequence[Collection]:
46
87
  """Get the Chroma collections.
@@ -61,6 +102,8 @@ class Collections:
61
102
  collection_name = name.name.lower()
62
103
  else:
63
104
  collection_name = name
105
+ if collection_name in self._config:
106
+ del self._config[collection_name]
64
107
  self._client.delete_collection(collection_name)
65
108
 
66
109
  def _set_collection_name(self, name: EmbeddingType | str) -> str:
@@ -0,0 +1,120 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Any, Callable, Dict, Tuple
4
+
5
+ from aenum import MultiValueEnum
6
+ from dotenv import load_dotenv
7
+ from langchain_community.embeddings.huggingface import (
8
+ HuggingFaceEmbeddings,
9
+ HuggingFaceInferenceAPIEmbeddings,
10
+ )
11
+ from langchain_core.embeddings import Embeddings
12
+ from langchain_openai import OpenAIEmbeddings
13
+
14
+ from janus.utils.logger import create_logger
15
+
16
+ load_dotenv()
17
+
18
+ log = create_logger(__name__)
19
+
20
+
21
+ class EmbeddingModelType(MultiValueEnum):
22
+ OpenAI = "OpenAI", "openai", "open-ai", "oai"
23
+ HuggingFaceLocal = "HuggingFaceLocal", "huggingfacelocal", "huggingface-local", "hfl"
24
+ HuggingFaceInferenceAPI = (
25
+ "HuggingFaceInferenceAPI",
26
+ "huggingfaceinferenceapi",
27
+ "huggingface-inference-api",
28
+ "hfia",
29
+ )
30
+
31
+
32
+ EMBEDDING_MODEL_TYPE_CONSTRUCTORS: Dict[
33
+ EmbeddingModelType, Callable[[Any], Embeddings]
34
+ ] = {}
35
+
36
+ for model_type in EmbeddingModelType:
37
+ for value in model_type.values:
38
+ if model_type == EmbeddingModelType.OpenAI:
39
+ EMBEDDING_MODEL_TYPE_CONSTRUCTORS[value] = OpenAIEmbeddings
40
+ elif model_type == EmbeddingModelType.HuggingFaceLocal:
41
+ EMBEDDING_MODEL_TYPE_CONSTRUCTORS[value] = HuggingFaceEmbeddings
42
+ elif model_type == EmbeddingModelType.HuggingFaceInferenceAPI:
43
+ EMBEDDING_MODEL_TYPE_CONSTRUCTORS[value] = HuggingFaceInferenceAPIEmbeddings
44
+
45
+ EMBEDDING_MODEL_TYPE_DEFAULT_IDS: Dict[EmbeddingModelType, Dict[str, Any]] = {
46
+ EmbeddingModelType.OpenAI.value: "text-embedding-3-small",
47
+ EmbeddingModelType.HuggingFaceLocal.value: "all-MiniLM-L6-v2",
48
+ EmbeddingModelType.HuggingFaceInferenceAPI.value: "",
49
+ }
50
+
51
+ EMBEDDING_MODEL_DEFAULT_ARGUMENTS: Dict[str, Dict[str, Any]] = {
52
+ "text-embedding-3-small": dict(model="text-embedding-3-small"),
53
+ "text-embedding-3-large": dict(model="text-embedding-3-large"),
54
+ "text-embedding-ada-002": dict(model="text-embedding-ada-002"),
55
+ }
56
+
57
+ EMBEDDING_MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "embeddings"
58
+
59
+ EMBEDDING_TOKEN_LIMITS: Dict[str, int] = {
60
+ "text-embedding-3-small": 8191,
61
+ "text-embedding-3-large": 8191,
62
+ "text-embedding-ada-002": 8191,
63
+ }
64
+
65
+ EMBEDDING_COST_PER_MODEL: Dict[str, float] = {
66
+ "text-embedding-3-small": {"input": 0.02 / 1e6, "output": 0.0},
67
+ "text-embedding-3-large": {"input": 0.13 / 1e6, "output": 0.0},
68
+ "text-embedding-ada-002": {"input": 0.10 / 1e6, "output": 0.0},
69
+ }
70
+
71
+
72
+ def load_embedding_model(
73
+ model_name: str,
74
+ ) -> Tuple[Embeddings, int, Dict[str, float]]:
75
+ """Load an embedding model from the configuration file or create a new one
76
+
77
+ Arguments:
78
+ model_name: The user-given name of the model to load.
79
+ model_type: The type of the model to load.
80
+ identifier: The identifier for the model (e.g. the name, URL, or HuggingFace
81
+ path).
82
+ """
83
+ if not EMBEDDING_MODEL_CONFIG_DIR.exists():
84
+ EMBEDDING_MODEL_CONFIG_DIR.mkdir(parents=True)
85
+ model_config_file = EMBEDDING_MODEL_CONFIG_DIR / f"{model_name}.json"
86
+
87
+ if not model_config_file.exists():
88
+ # The default model type is HuggingFaceLocal because that's the default for Chroma
89
+ model_type = EmbeddingModelType.HuggingFaceLocal.value
90
+ identifier = EMBEDDING_MODEL_TYPE_DEFAULT_IDS[model_type]
91
+ model_config = {
92
+ "model_type": model_type,
93
+ "model_identifier": identifier,
94
+ "model_args": EMBEDDING_MODEL_DEFAULT_ARGUMENTS.get(identifier, {}),
95
+ "token_limit": EMBEDDING_TOKEN_LIMITS.get(identifier, 8191),
96
+ "model_cost": EMBEDDING_COST_PER_MODEL.get(
97
+ identifier, {"input": 0, "output": 0}
98
+ ),
99
+ }
100
+ log.info(
101
+ f"WARNING: Creating new model config file: \
102
+ {model_config_file} with default config"
103
+ )
104
+ with open(model_config_file, "w") as f:
105
+ json.dump(model_config, f, indent=2)
106
+ else:
107
+ with open(model_config_file, "r") as f:
108
+ model_config = json.load(f)
109
+ model_constructor = EMBEDDING_MODEL_TYPE_CONSTRUCTORS[model_config["model_type"]]
110
+ model_args = model_config["model_args"]
111
+ if model_config["model_type"] in EmbeddingModelType.HuggingFaceInferenceAPI.values:
112
+ model_args.update({"api_url": model_config["model_identifier"]})
113
+ elif model_config["model_type"] in EmbeddingModelType.HuggingFaceLocal.values:
114
+ model_args.update({"model_name": model_config["model_identifier"]})
115
+ model = model_constructor(**model_args)
116
+ return (
117
+ model,
118
+ model_config["token_limit"],
119
+ model_config["model_cost"],
120
+ )
@@ -1,67 +1,52 @@
1
1
  import uuid
2
2
  from abc import ABC, abstractmethod
3
3
  from pathlib import Path
4
- from typing import Sequence
4
+ from typing import Any, Dict, Optional, Sequence
5
5
 
6
6
  from chromadb import Client, Collection
7
+ from langchain_community.vectorstores import Chroma
7
8
 
8
- from ..converter import Converter
9
- from ..language.block import CodeBlock
10
- from ..llm.models_info import TOKEN_LIMITS
9
+ from ..language.block import CodeBlock, TranslatedCodeBlock
11
10
  from ..utils.enums import EmbeddingType
12
11
  from .collections import Collections
13
12
  from .database import ChromaEmbeddingDatabase
14
13
 
15
14
 
16
- class Vectorizer(Converter):
15
+ class Vectorizer(object):
17
16
  """Class for creating embeddings/vectors in a specified ChromaDB"""
18
17
 
19
- def __init__(
20
- self,
21
- client: Client,
22
- source_language: str,
23
- max_tokens: None | int,
24
- model: None | str,
25
- ) -> None:
18
+ def __init__(self, client: Client, config: Optional[Dict[str, Any]] = None) -> None:
26
19
  """Initializes the Vectorizer class
27
20
 
28
21
  Arguments:
29
22
  client: ChromaDB client instance
30
- source_language: The source programming language.
31
- max_tokens: The maximum number of tokens to send to the embedding model at
32
- once. If `None`, the `Vectorizer` will use the default value for the
33
- `model`.
34
- model: The name of the model to use. This will also determine the `max_tokens`
35
- if that variable is not set.
36
23
  """
37
- if max_tokens is None:
38
- max_tokens = TOKEN_LIMITS[model]
39
-
40
- super().__init__(
41
- source_language=source_language,
42
- max_tokens=max_tokens,
43
- )
44
24
  self._db = client
45
- self._collections = Collections(self._db)
25
+ self._collections = Collections(self._db, config)
46
26
 
47
- super()._load_parameters()
27
+ def get_or_create_collection(
28
+ self, name: EmbeddingType | str, model_name: Optional[str] = None
29
+ ) -> Chroma:
30
+ return self._collections.get_or_create(name, model_name=model_name)
48
31
 
49
- def create_collection(self, embedding_type: EmbeddingType) -> Collection:
50
- return self._collections.create(embedding_type)
32
+ def create_collection(
33
+ self, embedding_type: EmbeddingType, model_name: Optional[str] = None
34
+ ) -> Chroma:
35
+ return self._collections.create(embedding_type, model_name=model_name)
51
36
 
52
37
  def collections(
53
38
  self, name: None | EmbeddingType | str = None
54
39
  ) -> Sequence[Collection]:
55
40
  return self._collections.get(name)
56
41
 
57
- def _add_nodes_recursively(
42
+ def add_nodes_recursively(
58
43
  self, code_block: CodeBlock, collection_name: EmbeddingType | str, file_name: str
59
44
  ) -> None:
60
45
  """Embed all nodes in the tree rooted at `code_block`
61
46
 
62
47
  Arguments:
63
48
  code_block: CodeBlock to embed
64
- embedding_type: EmbeddingType to use
49
+ collection_name: Collection to add to
65
50
  file_name: Name of file containing `code_block`
66
51
  """
67
52
  nodes = [code_block]
@@ -74,41 +59,55 @@ class Vectorizer(Converter):
74
59
  self,
75
60
  code_block: CodeBlock,
76
61
  collection_name: EmbeddingType | str,
77
- file_name: str # perhaps this should be a relative path from the source, but for
62
+ filename: str # perhaps this should be a relative path from the source, but for
78
63
  # now we're all in 1 directory
79
- ) -> bool:
64
+ ) -> None:
80
65
  """Calculate `code_block` embedding, returning success & storing in `embedding_id`
81
66
 
82
67
  Arguments:
83
68
  code_block: CodeBlock to embed
84
- embedding_type: EmbeddingType to use
85
- file_name: Name of file containing `code_block`
86
-
87
- Returns:
88
- True if embedding was successful, False otherwise
69
+ collection_name: Collection to add to
70
+ filename: Name of file containing `code_block`
89
71
  """
90
72
  if code_block.text:
91
73
  metadatas = [
92
74
  {
93
- "type": code_block.type,
94
- "original_filename": file_name,
75
+ "type": code_block.node_type,
76
+ "id": code_block.id,
77
+ "name": code_block.name,
78
+ "language": code_block.language,
79
+ "filename": filename,
95
80
  "tokens": code_block.tokens,
96
81
  "cost": 0, # TranslatedCodeBlock has cost
97
82
  },
98
83
  ]
84
+ if collection_name in self.config:
85
+ metadatas[0]["embedding_model"] = self.config[collection_name]
99
86
  # for now, dealing with missing metadata by skipping it
87
+ if isinstance(code_block, TranslatedCodeBlock):
88
+ self._add(
89
+ code_block=code_block.original,
90
+ collection_name=collection_name,
91
+ filename=filename,
92
+ )
93
+ if code_block.original.embedding_id is not None:
94
+ metadatas[0][
95
+ "original_embedding_id"
96
+ ] = code_block.original.embedding_id
97
+ metadatas[0]["cost"] = code_block.cost
100
98
  if code_block.text is not None:
101
99
  metadatas[0]["hash"] = hash(code_block.text)
102
100
  if code_block.start_point is not None:
103
101
  metadatas[0]["start_line"] = code_block.start_point[0]
104
102
  if code_block.end_point is not None:
105
103
  metadatas[0]["end_line"] = code_block.end_point[0]
104
+ # TODO: Add metadata about translation parameters (e.g. model)
106
105
  the_text = [code_block.text]
107
- code_block.embedding_id = self.add_text(collection_name, the_text, metadatas)[
108
- 0
109
- ]
110
- return True
111
- return False
106
+ code_block.embedding_id = self.add_text(
107
+ collection_name,
108
+ the_text,
109
+ metadatas,
110
+ )[0]
112
111
 
113
112
  def add_text(
114
113
  self,
@@ -121,7 +120,7 @@ class Vectorizer(Converter):
121
120
  metadatas, returning the embedding id
122
121
 
123
122
  Arguments:
124
- embedding_type: EmbeddingType to use
123
+ collection_name: Collection to add to
125
124
  texts: list of texts to store
126
125
  metadatas: list of metadatas to store
127
126
  ids: list of embedding ids (must match lengh of texts),
@@ -137,20 +136,20 @@ class Vectorizer(Converter):
137
136
  # based on the text.
138
137
  ids = [str(uuid.uuid3(uuid.NAMESPACE_DNS, text)) for text in texts]
139
138
  collection = self._collections.get_or_create(collection_name)
140
- collection.upsert(ids=ids, documents=texts, metadatas=metadatas)
139
+ collection.add_texts(ids=ids, texts=texts, metadatas=metadatas)
141
140
  return ids
142
141
 
142
+ @property
143
+ def config(self):
144
+ return self._collections._config
145
+
143
146
 
144
147
  class VectorizerFactory(ABC):
145
148
  """Interface for creating a Vectorizer independent of type of ChromaDB client"""
146
149
 
147
150
  @abstractmethod
148
151
  def create_vectorizer(
149
- self,
150
- source_language: str,
151
- max_tokens: None | int,
152
- model: None | str,
153
- path: str | Path,
152
+ self, path: str | Path, config: Dict[str, Any] = {}
154
153
  ) -> Vectorizer:
155
154
  """Factory method"""
156
155
 
@@ -160,19 +159,11 @@ class ChromaDBVectorizer(VectorizerFactory):
160
159
 
161
160
  def create_vectorizer(
162
161
  self,
163
- source_language: str = "fortran",
164
- max_tokens: None | int = None,
165
- model: None | str = "gpt4all",
166
162
  path: str | Path = Path.home() / ".janus" / "chroma" / "chroma-data",
163
+ config: Optional[Dict[str, Any]] = None,
167
164
  ) -> Vectorizer:
168
165
  """
169
166
  Arguments:
170
- source_language: The source programming language.
171
- max_tokens: The maximum number of tokens to send to the embedding model at
172
- once. If `None`, the `Vectorizer` will use the default value for the
173
- `model`.
174
- model: The name of the model to use. This will also determine the `max_tokens`
175
- if that variable is not set.
176
167
  path: The path to the ChromaDB. Can be either a string of a URL or path or a
177
168
  Path object
178
169
 
@@ -180,4 +171,4 @@ class ChromaDBVectorizer(VectorizerFactory):
180
171
  Vectorizer
181
172
  """
182
173
  database = ChromaEmbeddingDatabase(path)
183
- return Vectorizer(database, source_language, max_tokens, model)
174
+ return Vectorizer(database, config)
File without changes