janus-llm 4.1.0__py3-none-any.whl → 4.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,112 @@
1
+ import json
2
+ import re
3
+ from typing import Any
4
+
5
+ from langchain.output_parsers import PydanticOutputParser
6
+ from langchain_core.exceptions import OutputParserException
7
+ from langchain_core.messages import BaseMessage
8
+ from langchain_core.pydantic_v1 import BaseModel, Field, conint
9
+
10
+ from janus.language.block import CodeBlock
11
+ from janus.parsers.parser import JanusParser
12
+ from janus.utils.logger import create_logger
13
+
14
# Module-level logger.
log = create_logger(__name__)


# One evaluation metric: a short justification plus a bounded 1-4 score.
# NOTE(review): documented with comments rather than a docstring because
# pydantic copies model docstrings into the JSON schema, which may end up in
# the LLM format instructions — confirm before adding one.
class Criteria(BaseModel):
    # Free-text justification the LLM gives for ``score``.
    reasoning: str = Field(description="A short explanation for the given score")
    # Constrained to an integer between 1 and 4; values outside that range fail
    # pydantic validation when the LLM output is parsed.
    score: conint(ge=1, le=4) = Field(  # type: ignore
        description="An integer score between 1 and 4 (inclusive), 4 being the best"
    )
23
+
24
+
25
# The LLM's evaluation of a single comment, identified by its 8-character ID and
# scored on four metrics (each a ``Criteria``).
class Comment(BaseModel):
    # ID of the evaluated comment; InlineCommentParser.parse discards entries
    # whose ID was not captured from the prompt.
    comment_id: str = Field(description="The 8-character comment ID")
    completeness: Criteria = Field(description="The completeness of the comment")
    hallucination: Criteria = Field(description="The factualness of the comment")
    readability: Criteria = Field(description="The readability of the comment")
    usefulness: Criteria = Field(description="The usefulness of the comment")
31
+
32
+
33
# Custom-root model (pydantic v1 ``__root__``): the expected LLM output is a
# bare JSON array of ``Comment`` objects rather than an object wrapping one.
class CommentList(BaseModel):
    __root__: list[Comment] = Field(
        description=(
            "A list of inline comment evaluations. Each element should include"
            " the comment's 8-character ID in the `comment_id` field, and four"
            " score objects corresponding to each metric (`completeness`,"
            " `hallucination`, `readability`, and `usefulness`)."
        )
    )
42
+
43
+
44
class InlineCommentParser(JanusParser, PydanticOutputParser):
    """Parse LLM evaluations of inline comments into a JSON string.

    ``parse_input`` records the comment placeholders present in the prompt text
    (keyed by their 8-character ID), so that ``parse`` can verify the LLM scored
    every one of them and attach the original comment text to each evaluation.
    """

    # 8-character comment ID -> comment text, as captured by ``parse_input``.
    comments: dict[str, str]

    def __init__(self):
        PydanticOutputParser.__init__(
            self,
            pydantic_object=CommentList,
            comments={},
        )

    def parse_input(self, block: CodeBlock) -> str:
        """Return the prompt text for *block* and remember its comment IDs.

        Comments are matched as ``<BLOCK_COMMENT xxxxxxxx> text`` or
        ``<INLINE_COMMENT xxxxxxxx> text`` lines; per the TODO below, the
        placeholders are currently inserted outside this parser.
        """
        # TODO: Perform comment stripping/placeholding here rather than in script
        text = super().parse_input(block)
        self.comments = dict(
            re.findall(
                r"<(?:BLOCK|INLINE)_COMMENT (\w{8})> (.*)$",
                text,
                flags=re.MULTILINE,
            )
        )
        return text

    def parse(self, text: str | BaseMessage) -> str:
        """Validate the LLM's evaluations and return them as a JSON string.

        Returns:
            A JSON object mapping each comment ID to its four metric scores plus
            a ``"comment"`` key holding the original comment text. Evaluations
            for IDs that were not in the prompt are dropped.

        Raises:
            OutputParserException: if the output holds no JSON array, the array
                fails validation against ``CommentList``, or any comment captured
                by ``parse_input`` was not evaluated.
        """
        if isinstance(text, BaseMessage):
            text = str(text.content)

        # The model may surround the JSON array with prose or code fences; keep
        # only the outermost [...] span.
        begin, end = text.find("["), text.rfind("]")
        if begin == -1 or end < begin:
            log.debug(f"No JSON array found. Output:\n{text}")
            raise OutputParserException(
                "Got invalid JSON object. No JSON array found in output."
            )
        text = text[begin : end + 1]

        try:
            out: CommentList = super().parse(text)
        except json.JSONDecodeError as e:
            log.debug(f"Invalid JSON object. Output:\n{text}")
            raise OutputParserException(f"Got invalid JSON object. Error: {e}") from e

        evals: dict[str, Any] = {c.comment_id: c.dict() for c in out.__root__}

        seen_keys = set(evals)
        expected_keys = set(self.comments)
        missing_keys = expected_keys - seen_keys
        invalid_keys = seen_keys - expected_keys
        if invalid_keys:
            log.debug(f"Invalid keys: {invalid_keys}")
        if missing_keys:
            log.debug(f"Missing keys: {missing_keys}")
            raise OutputParserException(
                f"Got invalid return object. Missing the following expected "
                f"keys: {missing_keys}"
            )

        # Extra IDs (e.g. invented by the model) are ignored rather than fatal.
        for key in invalid_keys:
            del evals[key]

        for cid, evaluation in evals.items():
            evaluation["comment"] = self.comments[cid]
            evaluation.pop("comment_id")

        return json.dumps(evals)

    def parse_combined_output(self, text: str) -> str:
        """Merge newline-separated JSON objects into a single JSON object.

        Each non-blank line of *text* must be a JSON object (presumably one
        ``parse`` result per chunk); on duplicate keys, later lines win.
        """
        if not text.strip():
            return json.dumps({})
        objs = [json.loads(line.strip()) for line in text.split("\n") if line.strip()]
        output_obj = {}
        for obj in objs:
            output_obj.update(obj)
        return json.dumps(output_obj)
@@ -0,0 +1,168 @@
1
+ import json
2
+ import random
3
+ import uuid
4
+
5
+ from langchain.output_parsers import PydanticOutputParser
6
+ from langchain_core.exceptions import OutputParserException
7
+ from langchain_core.language_models import BaseLanguageModel
8
+ from langchain_core.messages import BaseMessage
9
+ from langchain_core.pydantic_v1 import BaseModel, Field
10
+
11
+ from janus.language.block import CodeBlock
12
+ from janus.parsers.parser import JanusParser
13
+ from janus.utils.logger import create_logger
14
+
15
# Module-level logger.
log = create_logger(__name__)
# Shared random generator; PartitionParser.parse_input reseeds it with the input
# code before drawing line IDs from it.
# NOTE(review): module-global mutable state — concurrent parse_input calls from
# several threads would interleave draws. Confirm parsers are not used that way.
RNG = random.Random()


# One split point proposed by the LLM.
# NOTE(review): documented with comments rather than a docstring because
# pydantic copies model docstrings into the JSON schema, which may end up in
# the LLM format instructions — confirm before adding one.
class PartitionObject(BaseModel):
    # Requested before ``location`` (see the PartitionList field description).
    reasoning: str = Field(
        description="An explanation for why the code should be split at this point"
    )
    # Line ID generated by PartitionParser.parse_input that should begin a chunk.
    location: str = Field(
        description="The 8-character line label which should start a new chunk"
    )
26
+
27
+
28
# Custom-root model (pydantic v1 ``__root__``): the expected LLM output is a
# bare JSON array of ``PartitionObject`` entries. The field description below is
# runtime text that guides the model, so edit it with care.
class PartitionList(BaseModel):
    __root__: list[PartitionObject] = Field(
        description=(
            "A list of appropriate split points, each with a `reasoning` field "
            "that explains a justification for splitting the code at that point, "
            "and a `location` field which is simply the 8-character line ID. "
            "The `reasoning` field should always be included first."
        )
    )
37
+
38
+
39
# The following IDs appear in the prompt example. If the LLM produces them,
# they should be ignored
# They are also skipped when PartitionParser.parse_input generates line IDs.
# Entries containing non-hex letters (e.g. "k57w964a") can never be generated
# there, because generated IDs are UUID hex prefixes; they only matter for
# filtering the LLM's answer.
EXAMPLE_IDS = {
    "0d2f4f8d",
    "def2a953",
    "75315253",
    "e7f928da",
    "1781b2a9",
    "2fe21e27",
    "9aef6179",
    "6061bd82",
    "22bd0c30",
    "5d85e19e",
    "06027969",
    "91b722fb",
    "4b3f79be",
    "k57w964a",
    "51638s96",
    "065o6q32",
    "j5q6p852",
}
60
+
61
+
62
class PartitionParser(JanusParser, PydanticOutputParser):
    """Partition code into chunks at line IDs chosen by an LLM.

    ``parse_input`` prefixes every input line with a pseudo-random 8-character
    ID. The LLM replies with a ``PartitionList`` naming the IDs of lines that
    should start a new chunk. ``parse`` turns that reply into chunks, and rejects
    it when an ID is unknown or when a chunk has more than ``token_limit`` tokens
    (as counted by ``model.get_num_tokens``).
    """

    token_limit: int
    model: BaseLanguageModel
    # Lines of the most recent input and the generated line ID -> line index
    # mapping; both are set by ``parse_input`` and read by ``parse``.
    lines: list[str] = []
    line_id_to_index: dict[str, int] = {}

    def __init__(self, token_limit: int, model: BaseLanguageModel):
        PydanticOutputParser.__init__(
            self,
            pydantic_object=PartitionList,
            model=model,
            token_limit=token_limit,
        )

    def parse_input(self, block: CodeBlock) -> str:
        """Return *block*'s text with each line prefixed by its ID and a tab.

        IDs are drawn from ``RNG`` seeded with the code itself, so the same code
        always receives the same IDs on the same lines. IDs listed in
        ``EXAMPLE_IDS`` are never used.
        """
        code = str(block.text)
        RNG.seed(code)

        self.lines = code.split("\n")

        # Generate one unique ID per line, keeping them in generation order.
        # Iterating a set here would make the ID -> line assignment depend on
        # string hash randomization (PYTHONHASHSEED), defeating the seed above.
        line_ids: list[str] = []
        seen: set[str] = set()
        while len(line_ids) < len(self.lines):
            line_id = str(uuid.UUID(int=RNG.getrandbits(128), version=4))[:8]
            if line_id in EXAMPLE_IDS or line_id in seen:
                continue
            seen.add(line_id)
            line_ids.append(line_id)

        # Prepend each line with the corresponding ID, save the mapping
        self.line_id_to_index = {lid: i for i, lid in enumerate(line_ids)}
        return "\n".join(
            f"{line_id}\t{line}" for line_id, line in zip(line_ids, self.lines)
        )

    def parse(self, text: str | BaseMessage) -> str:
        """Split the lines stored by ``parse_input`` at the LLM-chosen IDs.

        Returns:
            The chunks, joined with a ``<JANUS_PARTITION>`` marker line between
            consecutive chunks. Line 0 always starts the first chunk.

        Raises:
            OutputParserException: if no JSON array is present, the array fails
                validation, a location is not a generated line ID, or any chunk
                exceeds ``token_limit`` tokens.
        """
        if isinstance(text, BaseMessage):
            text = str(text.content)

        # Strip everything outside the outermost JSON array
        begin, end = text.find("["), text.rfind("]")
        if begin == -1 or end < begin:
            log.debug(f"Invalid JSON object. Output:\n{text}")
            raise OutputParserException("No JSON array found in the model output")
        text = text[begin : end + 1]

        try:
            out: PartitionList = super().parse(text)
        except (OutputParserException, json.JSONDecodeError):
            log.debug(f"Invalid JSON object. Output:\n{text}")
            raise

        # Get partition locations, discard reasoning
        partition_locations = {partition.location for partition in out.__root__}

        # Ignore IDs from the example input
        partition_locations.difference_update(EXAMPLE_IDS)

        # Locate any invalid line IDs, raise exception if any found
        invalid_splits = partition_locations.difference(self.line_id_to_index)
        if invalid_splits:
            err_msg = (
                f"{len(invalid_splits)} line ID(s) not found in input: "
                + ", ".join(sorted(invalid_splits))
            )
            log.warning(err_msg)
            raise OutputParserException(err_msg)

        # Map line IDs to indices (so they can be sorted and lines indexed).
        # Index 0 always starts a chunk; None stands for "end of input".
        index_to_line_id: dict[int | None, str] = {0: "START", None: "END"}
        split_points = {0}
        for partition in partition_locations:
            index = self.line_id_to_index[partition]
            index_to_line_id[index] = partition
            split_points.add(index)

        # Get partition start/ends, chunks, chunk lengths
        boundaries: list[int | None] = [*sorted(split_points), None]
        partition_indices = list(zip(boundaries, boundaries[1:]))
        partition_points = [
            (index_to_line_id[i0], index_to_line_id[i1]) for i0, i1 in partition_indices
        ]
        chunks = ["\n".join(self.lines[i0:i1]) for i0, i1 in partition_indices]
        chunk_tokens = list(map(self.model.get_num_tokens, chunks))

        # Collect any chunks that exceed token limit
        oversized_indices: list[int] = [
            i for i, n in enumerate(chunk_tokens) if n > self.token_limit
        ]
        if oversized_indices:
            data = list(zip(partition_points, chunks, chunk_tokens))
            data = [data[i] for i in oversized_indices]

            problem_points = "\n".join(
                [
                    f"{i0} to {i1} ({t / self.token_limit:.1f}x maximum length)"
                    for (i0, i1), _, t in data
                ]
            )
            log.warning(f"Found {len(data)} oversized chunks:\n{problem_points}")
            log.debug(
                "Oversized chunks:\n"
                + "\n#############\n".join(chunk for _, chunk, _ in data)
            )
            raise OutputParserException(
                f"The following segments are too long and must be "
                f"further subdivided:\n{problem_points}"
            )

        return "\n<JANUS_PARTITION>\n".join(chunks)
janus/refiners/refiner.py CHANGED
@@ -1,6 +1,8 @@
1
+ import re
1
2
  from typing import Any
2
3
 
3
4
  from langchain.output_parsers import RetryWithErrorOutputParser
5
+ from langchain_core.exceptions import OutputParserException
4
6
  from langchain_core.output_parsers import StrOutputParser
5
7
  from langchain_core.prompt_values import PromptValue
6
8
  from langchain_core.runnables import RunnableSerializable
@@ -25,9 +27,38 @@ class JanusRefiner(JanusParser):
25
27
  raise NotImplementedError
26
28
 
27
29
 
30
class SimpleRetry(JanusRefiner):
    """Refiner that re-asks the LLM with the unchanged prompt on parse errors.

    The model gets no error feedback. Each retry invokes ``retry_chain`` (the LLM
    piped into a ``StrOutputParser``) on the same ``prompt_value`` and tries to
    parse the fresh completion.
    """

    max_retries: int
    retry_chain: RunnableSerializable

    def __init__(
        self,
        llm: JanusModel,
        parser: JanusParser,
        max_retries: int,
    ):
        super().__init__(
            retry_chain=llm | StrOutputParser(),
            parser=parser,
            max_retries=max_retries,
        )

    def parse_completion(
        self, completion: str, prompt_value: PromptValue, **kwargs
    ) -> Any:
        """Parse *completion*, regenerating it up to ``max_retries`` times.

        The parser runs at most ``max_retries + 1`` times. An
        ``OutputParserException`` from the final attempt propagates to the caller.
        Extra ``kwargs`` are accepted and ignored.
        """
        attempts_left = self.max_retries
        while attempts_left > 0:
            attempts_left -= 1
            try:
                return self.parser.parse(completion)
            except OutputParserException:
                # Regenerate from the identical prompt and try again.
                completion = self.retry_chain.invoke(prompt_value)
        # Final attempt: any parsing error is left for the caller.
        return self.parser.parse(completion)
57
+
58
+
28
59
  class FixParserExceptions(JanusRefiner, RetryWithErrorOutputParser):
29
60
  def __init__(self, llm: JanusModel, parser: JanusParser, max_retries: int):
30
- retry_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
61
+ retry_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
31
62
  source_language="text",
32
63
  prompt_template="refinement/fix_exceptions",
33
64
  ).prompt
@@ -46,6 +77,7 @@ class ReflectionRefiner(JanusRefiner):
46
77
  max_retries: int
47
78
  reflection_chain: RunnableSerializable
48
79
  revision_chain: RunnableSerializable
80
+ reflection_prompt_name: str
49
81
 
50
82
  def __init__(
51
83
  self,
@@ -54,11 +86,11 @@ class ReflectionRefiner(JanusRefiner):
54
86
  max_retries: int,
55
87
  prompt_template_name: str = "refinement/reflection",
56
88
  ):
57
- reflection_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
89
+ reflection_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
58
90
  source_language="text",
59
91
  prompt_template=prompt_template_name,
60
92
  ).prompt
61
- revision_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
93
+ revision_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
62
94
  source_language="text",
63
95
  prompt_template="refinement/revision",
64
96
  ).prompt
@@ -66,6 +98,7 @@ class ReflectionRefiner(JanusRefiner):
66
98
  reflection_chain = reflection_prompt | llm | StrOutputParser()
67
99
  revision_chain = revision_prompt | llm | StrOutputParser()
68
100
  super().__init__(
101
+ reflection_prompt_name=prompt_template_name,
69
102
  reflection_chain=reflection_chain,
70
103
  revision_chain=revision_chain,
71
104
  parser=parser,
@@ -75,6 +108,7 @@ class ReflectionRefiner(JanusRefiner):
75
108
  def parse_completion(
76
109
  self, completion: str, prompt_value: PromptValue, **kwargs
77
110
  ) -> Any:
111
+ log.info(f"Reflection Prompt: {self.reflection_prompt_name}")
78
112
  for retry_number in range(self.max_retries):
79
113
  reflection = self.reflection_chain.invoke(
80
114
  dict(
@@ -82,7 +116,7 @@ class ReflectionRefiner(JanusRefiner):
82
116
  completion=completion,
83
117
  )
84
118
  )
85
- if reflection.strip() == "LGTM":
119
+ if re.search(r"\bLGTM\b", reflection) is not None:
86
120
  return self.parser.parse(completion)
87
121
  if not retry_number:
88
122
  log.info(f"Completion:\n{completion}")
@@ -105,11 +139,3 @@ class HallucinationRefiner(ReflectionRefiner):
105
139
  prompt_template_name="refinement/hallucination",
106
140
  **kwargs,
107
141
  )
108
-
109
-
110
- REFINERS = dict(
111
- none=JanusRefiner,
112
- parser=FixParserExceptions,
113
- reflection=ReflectionRefiner,
114
- hallucination=HallucinationRefiner,
115
- )
janus/refiners/uml.py ADDED
@@ -0,0 +1,33 @@
1
+ from janus.llm.models_info import JanusModel
2
+ from janus.parsers.parser import JanusParser
3
+ from janus.refiners.refiner import ReflectionRefiner
4
+
5
+
6
# Reflection prompt template used by ALCFixUMLVariablesRefiner.
_ALC_FIX_VARIABLES_TEMPLATE = "refinement/uml/alc_fix_variables"


class ALCFixUMLVariablesRefiner(ReflectionRefiner):
    """ReflectionRefiner preconfigured with the ``alc_fix_variables`` UML prompt.

    Only the reflection prompt template differs from the base class. The name
    suggests it targets variables in generated UML; see the template itself.
    """

    def __init__(
        self,
        llm: JanusModel,
        parser: JanusParser,
        max_retries: int,
    ):
        super().__init__(
            prompt_template_name=_ALC_FIX_VARIABLES_TEMPLATE,
            llm=llm,
            parser=parser,
            max_retries=max_retries,
        )
19
+
20
+
21
# Reflection prompt template used by FixUMLConnectionsRefiner.
_FIX_CONNECTIONS_TEMPLATE = "refinement/uml/fix_connections"


class FixUMLConnectionsRefiner(ReflectionRefiner):
    """ReflectionRefiner preconfigured with the ``fix_connections`` UML prompt.

    Only the reflection prompt template differs from the base class. The name
    suggests it targets connections in generated UML; see the template itself.
    """

    def __init__(
        self,
        llm: JanusModel,
        parser: JanusParser,
        max_retries: int,
    ):
        super().__init__(
            prompt_template_name=_FIX_CONNECTIONS_TEMPLATE,
            llm=llm,
            parser=parser,
            max_retries=max_retries,
        )
@@ -1,7 +1,16 @@
1
+ from typing import List
2
+
3
+ from langchain_core.documents import Document
4
+ from langchain_core.output_parsers import StrOutputParser
1
5
  from langchain_core.retrievers import BaseRetriever
2
6
  from langchain_core.runnables import Runnable, RunnableConfig
3
7
 
4
8
  from janus.language.block import CodeBlock
9
+ from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel
10
+ from janus.utils.logger import create_logger
11
+ from janus.utils.pdf_docs_reader import PDFDocsReader
12
+
13
# Module-level logger.
log = create_logger(__name__)
5
14
 
6
15
 
7
16
  class JanusRetriever(Runnable):
@@ -40,3 +49,54 @@ class TextSearchRetriever(JanusRetriever):
40
49
  docs = self.retriever.invoke(code_block.text)
41
50
  context = "\n\n".join(doc.page_content for doc in docs)
42
51
  return f"You may use the following additional context: {context}"
52
+
53
+
54
class LanguageDocsRetriever(JanusRetriever):
    """Retrieve excerpts from a language reference PDF relevant to a code block.

    An LLM (prompt template ``prompt_template_name``) is asked which language
    functionality the code uses. Its reply is treated as either ``NODOCS`` or a
    comma-separated list of items, and each item is looked up in the
    ``PDFDocsReader`` index for the language.
    """

    def __init__(
        self,
        llm: JanusModel,
        language_name: str,
        prompt_template_name: str = "retrieval/language_docs",
    ):
        super().__init__()
        self.llm: JanusModel = llm
        self.language: str = language_name

        self.PDF_reader = PDFDocsReader(
            language=self.language,
        )

        language_docs_prompt = MODEL_PROMPT_ENGINES[self.llm.short_model_id](
            source_language=self.language,
            prompt_template=prompt_template_name,
        ).prompt

        parser: StrOutputParser = StrOutputParser()
        self.chain = language_docs_prompt | self.llm | parser

    def get_context(self, code_block: CodeBlock) -> str:
        """Return documentation excerpts for *code_block*, or ``""`` if none apply."""
        # LLM replies often carry leading/trailing whitespace or a trailing
        # newline; normalise before interpreting the sentinel or the list.
        response: str = self.chain.invoke(
            {"SOURCE_CODE": code_block.text, "SOURCE_LANGUAGE": self.language}
        ).strip()
        functionality_to_reference: List[str] = [
            item.strip() for item in response.split(",") if item.strip()
        ]
        if response == "NODOCS" or not functionality_to_reference:
            log.debug("No Opcodes requested from language docs retriever.")
            return ""

        log.debug(
            f"List of opcodes requested by language docs retriever "
            f"to search the {self.language} "
            f"docs for: {functionality_to_reference}"
        )

        docs: List[Document] = self.PDF_reader.search_language_reference(
            functionality_to_reference
        )
        context = "\n\n".join(doc.page_content for doc in docs)
        if not context:
            return ""
        return (
            f"You may reference the following excerpts from the {self.language} "
            f"language documentation: {context}"
        )
janus/utils/enums.py CHANGED
@@ -89,6 +89,20 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
89
89
  "url": "https://github.com/stsewd/tree-sitter-comment",
90
90
  "example": "# This is a comment\n",
91
91
  },
92
+ "cobol": {
93
+ "comment": "*",
94
+ "suffix": "cbl",
95
+ "url": "https://github.com/yutaro-sakamoto/tree-sitter-cobol",
96
+ "example": (
97
+ " IDENTIFICATION DIVISION.\n"
98
+ " PROGRAM-ID. HelloWorld.\n"
99
+ " ENVIRONMENT DIVISION.\n"
100
+ " DATA DIVISION.\n"
101
+ " PROCEDURE DIVISION.\n"
102
+ ' DISPLAY "Hello, World!".\n'
103
+ " STOP RUN.\n"
104
+ ),
105
+ },
92
106
  "commonlisp": {
93
107
  "comment": ";;",
94
108
  "suffix": "lisp",
@@ -0,0 +1,134 @@
1
+ import os
2
+ import time
3
+ from pathlib import Path
4
+ from typing import List, Optional
5
+
6
+ import joblib
7
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
8
+ from langchain_core.documents import Document
9
+ from langchain_unstructured import UnstructuredLoader
10
+ from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
11
+ from sklearn.metrics.pairwise import cosine_similarity
12
+
13
+ from janus.utils.logger import create_logger
14
+
15
log = create_logger(__name__)


class PDFDocsReader:
    """Chunk a language reference PDF and search the chunks by vector similarity.

    The PDF is read from ``$RETRIEVAL_DOCS_DIR/<language>.pdf`` (directory
    defaults to ``retrieval_docs``). The resulting text chunks are cached next to
    it as ``<language>_documents.pkl``, so the slow extraction runs only once.
    """

    def __init__(
        self,
        language: str,
        chunk_size: int = 1000,
        chunk_overlap: int = 100,
        start_page: Optional[int] = None,
        end_page: Optional[int] = None,
        vectorizer: Optional[CountVectorizer] = None,
    ):
        """Load (or build) the chunk cache and fit the vectorizer on it.

        Arguments:
            language: Language name; selects ``<language>.pdf`` and its cache.
            chunk_size: Maximum chunk length passed to the text splitter.
            chunk_overlap: Overlap between consecutive chunks.
            start_page: Forwarded to ``UnstructuredLoader``.
            end_page: Forwarded to ``UnstructuredLoader``.
            vectorizer: Vectorizer to fit on the chunks; a fresh
                ``TfidfVectorizer`` is created for each reader when omitted.

        The chunking and page arguments only take effect when the cache file
        does not exist yet; an existing cache is loaded as-is.

        Raises:
            FileNotFoundError: if neither the cache nor the PDF exists.
        """
        self.retrieval_docs_dir: Path = Path(
            os.getenv("RETRIEVAL_DOCS_DIR", "retrieval_docs")
        )
        self.language = language
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.start_page = start_page
        self.end_page = end_page
        # A ``TfidfVectorizer()`` default in the signature would be evaluated
        # once and shared by every reader. Each reader's fit_transform would
        # then refit that shared instance, so earlier readers would transform
        # queries with another language's vocabulary.
        self.vectorizer = TfidfVectorizer() if vectorizer is None else vectorizer
        self.documents = self.load_and_chunk_pdf()
        self.doc_vectors = self.vectorize_documents()

    def load_and_chunk_pdf(self) -> List[str]:
        """Return the text chunks for ``self.language``, building the cache if needed.

        Raises:
            FileNotFoundError: if neither the cached chunks nor the PDF exist.
        """
        pdf_path = self.retrieval_docs_dir / f"{self.language}.pdf"
        pickled_documents_path = (
            self.retrieval_docs_dir / f"{self.language}_documents.pkl"
        )

        if pickled_documents_path.exists():
            log.debug(
                f"Loading pre-chunked PDF from {pickled_documents_path}. "
                f"If you want to regenerate retrieval docs for {self.language}, "
                f"delete the file at {pickled_documents_path}, "
                f"then add a new {self.language}.pdf."
            )
            # SECURITY: joblib.load unpickles the file, which can execute
            # arbitrary code. Only point RETRIEVAL_DOCS_DIR at a trusted directory.
            return joblib.load(pickled_documents_path)

        if not pdf_path.exists():
            raise FileNotFoundError(
                f"Language docs retrieval is enabled, but no PDF for language "
                f"'{self.language}' was found. Move a "
                f"{self.language} reference manual to "
                f"{pdf_path.absolute()} "
                f"(the path to the directory of PDF docs can be "
                f"set with the env variable 'RETRIEVAL_DOCS_DIR')."
            )
        log.info(
            f"Chunking reference PDF for {self.language} using unstructured - "
            f"if your PDF has many pages, this could take a while..."
        )
        start_time = time.time()
        # Extract the PDF text as (nearly) unsplit elements, then re-chunk it
        # with the configured size/overlap.
        loader = UnstructuredLoader(
            pdf_path,
            chunking_strategy="basic",
            max_characters=1000000,
            include_orig_elements=False,
            start_page=self.start_page,
            end_page=self.end_page,
        )
        docs = loader.load()
        text = "\n\n".join(doc.page_content for doc in docs)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
        )
        documents = text_splitter.split_text(text)
        log.info(f"Document store created for language: {self.language}")
        elapsed = time.time() - start_time
        log.info(f"Processing time for {self.language} PDF: {elapsed:.2f} seconds")

        joblib.dump(documents, pickled_documents_path)
        log.debug(f"Documents saved to {pickled_documents_path}")

        return documents

    def vectorize_documents(self) -> "scipy.sparse.spmatrix":
        """Fit ``self.vectorizer`` on the chunks and return the document-term matrix."""
        return self.vectorizer.fit_transform(self.documents)

    def search_language_reference(
        self,
        query: List[str],
        top_k: int = 1,
        min_similarity: float = 0.1,
    ) -> List[Document]:
        """Search the vectorized PDF chunks for each item in *query*.

        For every item, up to *top_k* chunks whose cosine similarity to the item
        is at least *min_similarity* are returned, best first. Results for all
        items are concatenated in query order, so a chunk matching several items
        appears once per matching item.
        """
        docs: List[Document] = []

        for item in query:
            # Transform the query using the fitted vectorizer
            query_vector = self.vectorizer.transform([item])

            # Cosine similarity between the query and every document vector
            similarities = cosine_similarity(query_vector, self.doc_vectors).flatten()

            # Indices above the threshold, most similar first
            ranked = sorted(
                (i for i, sim in enumerate(similarities) if sim >= min_similarity),
                key=lambda i: similarities[i],
                reverse=True,
            )

            # Keep only the top-k documents for this item
            docs.extend(
                Document(page_content=self.documents[i]) for i in ranked[:top_k]
            )

        log.debug(f"Language documentation search result: {docs}")
        return docs
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: janus-llm
3
- Version: 4.1.0
3
+ Version: 4.3.1
4
4
  Summary: A transcoding library using LLMs.
5
5
  Home-page: https://github.com/janus-llm/janus-llm
6
6
  License: Apache 2.0
@@ -23,20 +23,28 @@ Requires-Dist: langchain-anthropic (>=0.1.15,<0.2.0)
23
23
  Requires-Dist: langchain-community (>=0.2.0,<0.3.0)
24
24
  Requires-Dist: langchain-core (>=0.2.0,<0.3.0)
25
25
  Requires-Dist: langchain-openai (>=0.1.8,<0.2.0)
26
+ Requires-Dist: langchain-unstructured (>=0.1.2,<0.2.0)
26
27
  Requires-Dist: nltk (>=3.8.1,<4.0.0)
27
28
  Requires-Dist: numpy (>=1.24.3,<2.0.0)
28
29
  Requires-Dist: openai (>=1.14.0,<2.0.0)
30
+ Requires-Dist: pi-heif (>=0.20.0,<0.21.0)
29
31
  Requires-Dist: py-readability-metrics (>=1.4.5,<2.0.0)
30
32
  Requires-Dist: py-rouge (>=1.1,<2.0)
33
+ Requires-Dist: pytesseract (>=0.3.13,<0.4.0)
31
34
  Requires-Dist: python-dotenv (>=1.0.0,<2.0.0)
32
35
  Requires-Dist: rich (>=13.7.1,<14.0.0)
33
36
  Requires-Dist: sacrebleu (>=2.4.1,<3.0.0)
37
+ Requires-Dist: scikit-learn (>=1.5.2,<2.0.0)
34
38
  Requires-Dist: sentence-transformers (>=2.6.1,<3.0.0) ; extra == "hf-local" or extra == "all"
39
+ Requires-Dist: tesseract (>=0.1.3,<0.2.0)
35
40
  Requires-Dist: text-generation (>=0.6.0,<0.7.0)
36
41
  Requires-Dist: tiktoken (>=0.7.0,<0.8.0)
37
42
  Requires-Dist: transformers (>=4.31.0,<5.0.0)
38
43
  Requires-Dist: tree-sitter (>=0.21.0,<0.22.0)
39
44
  Requires-Dist: typer (>=0.9.0,<0.10.0)
45
+ Requires-Dist: unstructured (>=0.15.9,<0.16.0)
46
+ Requires-Dist: unstructured-inference (>=0.7.36,<0.8.0)
47
+ Requires-Dist: unstructured-pytesseract (>=0.3.13,<0.4.0)
40
48
  Project-URL: Documentation, https://janus-llm.github.io/janus-llm
41
49
  Project-URL: Repository, https://github.com/janus-llm/janus-llm
42
50
  Description-Content-Type: text/markdown