janus-llm 4.1.0__py3-none-any.whl → 4.3.1__py3-none-any.whl

@@ -0,0 +1,112 @@
+ import json
+ import re
+ from typing import Any
+
+ from langchain.output_parsers import PydanticOutputParser
+ from langchain_core.exceptions import OutputParserException
+ from langchain_core.messages import BaseMessage
+ from langchain_core.pydantic_v1 import BaseModel, Field, conint
+
+ from janus.language.block import CodeBlock
+ from janus.parsers.parser import JanusParser
+ from janus.utils.logger import create_logger
+
+ log = create_logger(__name__)
+
+
+ class Criteria(BaseModel):
+     reasoning: str = Field(description="A short explanation for the given score")
+     # Constrained to an integer between 1 and 4
+     score: conint(ge=1, le=4) = Field(  # type: ignore
+         description="An integer score between 1 and 4 (inclusive), 4 being the best"
+     )
+
+
+ class Comment(BaseModel):
+     comment_id: str = Field(description="The 8-character comment ID")
+     completeness: Criteria = Field(description="The completeness of the comment")
+     hallucination: Criteria = Field(description="The factualness of the comment")
+     readability: Criteria = Field(description="The readability of the comment")
+     usefulness: Criteria = Field(description="The usefulness of the comment")
+
+
+ class CommentList(BaseModel):
+     __root__: list[Comment] = Field(
+         description=(
+             "A list of inline comment evaluations. Each element should include"
+             " the comment's 8-character ID in the `comment_id` field, and four"
+             " score objects corresponding to each metric (`completeness`,"
+             " `hallucination`, `readability`, and `usefulness`)."
+         )
+     )
+
+
+ class InlineCommentParser(JanusParser, PydanticOutputParser):
+     comments: dict[str, str]
+
+     def __init__(self):
+         PydanticOutputParser.__init__(
+             self,
+             pydantic_object=CommentList,
+             comments=[],
+         )
+
+     def parse_input(self, block: CodeBlock) -> str:
+         # TODO: Perform comment stripping/placeholding here rather than in script
+         text = super().parse_input(block)
+         self.comments = dict(
+             re.findall(
+                 r"<(?:BLOCK|INLINE)_COMMENT (\w{8})> (.*)$",
+                 text,
+                 flags=re.MULTILINE,
+             )
+         )
+         return text
+
+     def parse(self, text: str | BaseMessage) -> str:
+         if isinstance(text, BaseMessage):
+             text = str(text.content)
+
+         # Strip everything outside the JSON object
+         begin, end = text.find("["), text.rfind("]")
+         text = text[begin : end + 1]
+
+         try:
+             out: CommentList = super().parse(text)
+         except json.JSONDecodeError as e:
+             log.debug(f"Invalid JSON object. Output:\n{text}")
+             raise OutputParserException(f"Got invalid JSON object. Error: {e}")
+
+         evals: dict[str, Any] = {c.comment_id: c.dict() for c in out.__root__}
+
+         seen_keys = set(evals.keys())
+         expected_keys = set(self.comments.keys())
+         missing_keys = expected_keys.difference(seen_keys)
+         invalid_keys = seen_keys.difference(expected_keys)
+         if missing_keys:
+             log.debug(f"Missing keys: {missing_keys}")
+         if invalid_keys:
+             log.debug(f"Invalid keys: {invalid_keys}")
+         if missing_keys:
+             raise OutputParserException(
+                 f"Got invalid return object. Missing the following expected "
+                 f"keys: {missing_keys}"
+             )
+
+         for key in invalid_keys:
+             del evals[key]
+
+         for cid in evals.keys():
+             evals[cid]["comment"] = self.comments[cid]
+             evals[cid].pop("comment_id")
+
+         return json.dumps(evals)
+
+     def parse_combined_output(self, text: str) -> str:
+         if not text.strip():
+             return str({})
+         objs = [json.loads(line.strip()) for line in text.split("\n") if line.strip()]
+         output_obj = {}
+         for obj in objs:
+             output_obj.update(obj)
+         return json.dumps(output_obj)
@@ -0,0 +1,168 @@
+ import json
+ import random
+ import uuid
+
+ from langchain.output_parsers import PydanticOutputParser
+ from langchain_core.exceptions import OutputParserException
+ from langchain_core.language_models import BaseLanguageModel
+ from langchain_core.messages import BaseMessage
+ from langchain_core.pydantic_v1 import BaseModel, Field
+
+ from janus.language.block import CodeBlock
+ from janus.parsers.parser import JanusParser
+ from janus.utils.logger import create_logger
+
+ log = create_logger(__name__)
+ RNG = random.Random()
+
+
+ class PartitionObject(BaseModel):
+     reasoning: str = Field(
+         description="An explanation for why the code should be split at this point"
+     )
+     location: str = Field(
+         description="The 8-character line label which should start a new chunk"
+     )
+
+
+ class PartitionList(BaseModel):
+     __root__: list[PartitionObject] = Field(
+         description=(
+             "A list of appropriate split points, each with a `reasoning` field "
+             "that explains a justification for splitting the code at that point, "
+             "and a `location` field which is simply the 8-character line ID. "
+             "The `reasoning` field should always be included first."
+         )
+     )
+
+
+ # The following IDs appear in the prompt example. If the LLM produces them,
+ # they should be ignored
+ EXAMPLE_IDS = {
+     "0d2f4f8d",
+     "def2a953",
+     "75315253",
+     "e7f928da",
+     "1781b2a9",
+     "2fe21e27",
+     "9aef6179",
+     "6061bd82",
+     "22bd0c30",
+     "5d85e19e",
+     "06027969",
+     "91b722fb",
+     "4b3f79be",
+     "k57w964a",
+     "51638s96",
+     "065o6q32",
+     "j5q6p852",
+ }
+
+
+ class PartitionParser(JanusParser, PydanticOutputParser):
+     token_limit: int
+     model: BaseLanguageModel
+     lines: list[str] = []
+     line_id_to_index: dict[str, int] = {}
+
+     def __init__(self, token_limit: int, model: BaseLanguageModel):
+         PydanticOutputParser.__init__(
+             self,
+             pydantic_object=PartitionList,
+             model=model,
+             token_limit=token_limit,
+         )
+
+     def parse_input(self, block: CodeBlock) -> str:
+         code = str(block.text)
+         RNG.seed(code)
+
+         self.lines = code.split("\n")
+
+         # Generate a unique ID for each line (ensure they are unique)
+         line_ids = set()
+         while len(line_ids) < len(self.lines):
+             line_id = str(uuid.UUID(int=RNG.getrandbits(128), version=4))[:8]
+             if line_id in EXAMPLE_IDS:
+                 continue
+             line_ids.add(line_id)
+
+         # Prepend each line with the corresponding ID, save the mapping
+         self.line_id_to_index = {lid: i for i, lid in enumerate(line_ids)}
+         processed = "\n".join(
+             f"{line_id}\t{self.lines[i]}" for line_id, i in self.line_id_to_index.items()
+         )
+         return processed
+
+     def parse(self, text: str | BaseMessage) -> str:
+         if isinstance(text, BaseMessage):
+             text = str(text.content)
+
+         # Strip everything outside the JSON object
+         begin, end = text.find("["), text.rfind("]")
+         text = text[begin : end + 1]
+
+         try:
+             out: PartitionList = super().parse(text)
+         except (OutputParserException, json.JSONDecodeError):
+             log.debug(f"Invalid JSON object. Output:\n{text}")
+             raise
+
+         # Get partition locations, discard reasoning
+         partition_locations = {partition.location for partition in out.__root__}
+
+         # Ignore IDs from the example input
+         partition_locations.difference_update(EXAMPLE_IDS)
+
+         # Locate any invalid line IDs, raise exception if any found
+         invalid_splits = partition_locations.difference(self.line_id_to_index)
+         if invalid_splits:
+             err_msg = (
+                 f"{len(invalid_splits)} line ID(s) not found in input: "
+                 + ", ".join(invalid_splits)
+             )
+             log.warning(err_msg)
+             raise OutputParserException(err_msg)
+
+         # Map line IDs to indices (so they can be sorted and lines indexed)
+         index_to_line_id = {0: "START", None: "END"}
+         split_points = {0}
+         for partition in partition_locations:
+             index = self.line_id_to_index[partition]
+             index_to_line_id[index] = partition
+             split_points.add(index)
+
+         # Get partition start/ends, chunks, chunk lengths
+         split_points = sorted(split_points) + [None]
+         partition_indices = list(zip(split_points, split_points[1:]))
+         partition_points = [
+             (index_to_line_id[i0], index_to_line_id[i1]) for i0, i1 in partition_indices
+         ]
+         chunks = ["\n".join(self.lines[i0:i1]) for i0, i1 in partition_indices]
+         chunk_tokens = list(map(self.model.get_num_tokens, chunks))
+
+         # Collect any chunks that exceed token limit
+         oversized_indices: list[int] = [
+             i for i, n in enumerate(chunk_tokens) if n > self.token_limit
+         ]
+         if oversized_indices:
+             data = list(zip(partition_points, chunks, chunk_tokens))
+             data = [data[i] for i in oversized_indices]
+
+             problem_points = "\n".join(
+                 [
+                     f"{i0} to {i1} ({t / self.token_limit:.1f}x maximum length)"
+                     for (i0, i1), _, t in data
+                 ]
+             )
+             log.warning(f"Found {len(data)} oversized chunks:\n{problem_points}")
+             log.debug(
+                 "Oversized chunks:\n"
+                 + "\n#############\n".join(chunk for _, chunk, _ in data)
+             )
+             raise OutputParserException(
+                 f"The following segments are too long and must be "
+                 f"further subdivided:\n{problem_points}"
+             )
+
+         return "\n<JANUS_PARTITION>\n".join(chunks)
janus/refiners/refiner.py CHANGED
@@ -1,6 +1,8 @@
+ import re
  from typing import Any

  from langchain.output_parsers import RetryWithErrorOutputParser
+ from langchain_core.exceptions import OutputParserException
  from langchain_core.output_parsers import StrOutputParser
  from langchain_core.prompt_values import PromptValue
  from langchain_core.runnables import RunnableSerializable
@@ -25,9 +27,38 @@ class JanusRefiner(JanusParser):
          raise NotImplementedError


+ class SimpleRetry(JanusRefiner):
+     max_retries: int
+     retry_chain: RunnableSerializable
+
+     def __init__(
+         self,
+         llm: JanusModel,
+         parser: JanusParser,
+         max_retries: int,
+     ):
+         retry_chain = llm | StrOutputParser()
+         super().__init__(
+             retry_chain=retry_chain,
+             parser=parser,
+             max_retries=max_retries,
+         )
+
+     def parse_completion(
+         self, completion: str, prompt_value: PromptValue, **kwargs
+     ) -> Any:
+         for retry_number in range(self.max_retries):
+             try:
+                 return self.parser.parse(completion)
+             except OutputParserException:
+                 completion = self.retry_chain.invoke(prompt_value)
+
+         return self.parser.parse(completion)
+
+
  class FixParserExceptions(JanusRefiner, RetryWithErrorOutputParser):
      def __init__(self, llm: JanusModel, parser: JanusParser, max_retries: int):
-         retry_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
+         retry_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
              source_language="text",
              prompt_template="refinement/fix_exceptions",
          ).prompt
@@ -46,6 +77,7 @@ class ReflectionRefiner(JanusRefiner):
      max_retries: int
      reflection_chain: RunnableSerializable
      revision_chain: RunnableSerializable
+     reflection_prompt_name: str

      def __init__(
          self,
@@ -54,11 +86,11 @@ class ReflectionRefiner(JanusRefiner):
          max_retries: int,
          prompt_template_name: str = "refinement/reflection",
      ):
-         reflection_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
+         reflection_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
              source_language="text",
              prompt_template=prompt_template_name,
          ).prompt
-         revision_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
+         revision_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
              source_language="text",
              prompt_template="refinement/revision",
          ).prompt
@@ -66,6 +98,7 @@ class ReflectionRefiner(JanusRefiner):
          reflection_chain = reflection_prompt | llm | StrOutputParser()
          revision_chain = revision_prompt | llm | StrOutputParser()
          super().__init__(
+             reflection_prompt_name=prompt_template_name,
              reflection_chain=reflection_chain,
              revision_chain=revision_chain,
              parser=parser,
@@ -75,6 +108,7 @@ class ReflectionRefiner(JanusRefiner):
      def parse_completion(
          self, completion: str, prompt_value: PromptValue, **kwargs
      ) -> Any:
+         log.info(f"Reflection Prompt: {self.reflection_prompt_name}")
          for retry_number in range(self.max_retries):
              reflection = self.reflection_chain.invoke(
                  dict(
@@ -82,7 +116,7 @@ class ReflectionRefiner(JanusRefiner):
                      completion=completion,
                  )
              )
-             if reflection.strip() == "LGTM":
+             if re.search(r"\bLGTM\b", reflection) is not None:
                  return self.parser.parse(completion)
              if not retry_number:
                  log.info(f"Completion:\n{completion}")
@@ -105,11 +139,3 @@ class HallucinationRefiner(ReflectionRefiner):
              prompt_template_name="refinement/hallucination",
              **kwargs,
          )
-
-
- REFINERS = dict(
-     none=JanusRefiner,
-     parser=FixParserExceptions,
-     reflection=ReflectionRefiner,
-     hallucination=HallucinationRefiner,
- )
janus/refiners/uml.py ADDED
@@ -0,0 +1,33 @@
+ from janus.llm.models_info import JanusModel
+ from janus.parsers.parser import JanusParser
+ from janus.refiners.refiner import ReflectionRefiner
+
+
+ class ALCFixUMLVariablesRefiner(ReflectionRefiner):
+     def __init__(
+         self,
+         llm: JanusModel,
+         parser: JanusParser,
+         max_retries: int,
+     ):
+         super().__init__(
+             llm=llm,
+             parser=parser,
+             max_retries=max_retries,
+             prompt_template_name="refinement/uml/alc_fix_variables",
+         )
+
+
+ class FixUMLConnectionsRefiner(ReflectionRefiner):
+     def __init__(
+         self,
+         llm: JanusModel,
+         parser: JanusParser,
+         max_retries: int,
+     ):
+         super().__init__(
+             llm=llm,
+             parser=parser,
+             max_retries=max_retries,
+             prompt_template_name="refinement/uml/fix_connections",
+         )
@@ -1,7 +1,16 @@
+ from typing import List
+
+ from langchain_core.documents import Document
+ from langchain_core.output_parsers import StrOutputParser
  from langchain_core.retrievers import BaseRetriever
  from langchain_core.runnables import Runnable, RunnableConfig

  from janus.language.block import CodeBlock
+ from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel
+ from janus.utils.logger import create_logger
+ from janus.utils.pdf_docs_reader import PDFDocsReader
+
+ log = create_logger(__name__)


  class JanusRetriever(Runnable):
@@ -40,3 +49,54 @@ class TextSearchRetriever(JanusRetriever):
          docs = self.retriever.invoke(code_block.text)
          context = "\n\n".join(doc.page_content for doc in docs)
          return f"You may use the following additional context: {context}"
+
+
+ class LanguageDocsRetriever(JanusRetriever):
+     def __init__(
+         self,
+         llm: JanusModel,
+         language_name: str,
+         prompt_template_name: str = "retrieval/language_docs",
+     ):
+         super().__init__()
+         self.llm: JanusModel = llm
+         self.language: str = language_name
+
+         self.PDF_reader = PDFDocsReader(
+             language=self.language,
+         )
+
+         language_docs_prompt = MODEL_PROMPT_ENGINES[self.llm.short_model_id](
+             source_language=self.language,
+             prompt_template=prompt_template_name,
+         ).prompt
+
+         parser: StrOutputParser = StrOutputParser()
+         self.chain = language_docs_prompt | self.llm | parser
+
+     def get_context(self, code_block: CodeBlock) -> str:
+         functionality_to_reference: str = self.chain.invoke(
+             dict({"SOURCE_CODE": code_block.text, "SOURCE_LANGUAGE": self.language})
+         )
+         if functionality_to_reference == "NODOCS":
+             log.debug("No Opcodes requested from language docs retriever.")
+             return ""
+         else:
+             functionality_to_reference: List = functionality_to_reference.split(", ")
+             log.debug(
+                 f"List of opcodes requested by language docs retriever "
+                 f"to search the {self.language} "
+                 f"docs for: {functionality_to_reference}"
+             )
+
+             docs: List[Document] = self.PDF_reader.search_language_reference(
+                 functionality_to_reference
+             )
+             context = "\n\n".join(doc.page_content for doc in docs)
+             if context:
+                 return (
+                     f"You may reference the following excerpts from the {self.language} "
+                     f"language documentation: {context}"
+                 )
+             else:
+                 return ""
janus/utils/enums.py CHANGED
@@ -89,6 +89,20 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
          "url": "https://github.com/stsewd/tree-sitter-comment",
          "example": "# This is a comment\n",
      },
+     "cobol": {
+         "comment": "*",
+         "suffix": "cbl",
+         "url": "https://github.com/yutaro-sakamoto/tree-sitter-cobol",
+         "example": (
+             " IDENTIFICATION DIVISION.\n"
+             " PROGRAM-ID. HelloWorld.\n"
+             " ENVIRONMENT DIVISION.\n"
+             " DATA DIVISION.\n"
+             " PROCEDURE DIVISION.\n"
+             ' DISPLAY "Hello, World!".\n'
+             " STOP RUN.\n"
+         ),
+     },
      "commonlisp": {
          "comment": ";;",
          "suffix": "lisp",
@@ -0,0 +1,134 @@
+ import os
+ import time
+ from pathlib import Path
+ from typing import List, Optional
+
+ import joblib
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_core.documents import Document
+ from langchain_unstructured import UnstructuredLoader
+ from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
+ from sklearn.metrics.pairwise import cosine_similarity
+
+ from janus.utils.logger import create_logger
+
+ log = create_logger(__name__)
+
+
+ class PDFDocsReader:
+     def __init__(
+         self,
+         language: str,
+         chunk_size: int = 1000,
+         chunk_overlap: int = 100,
+         start_page: Optional[int] = None,
+         end_page: Optional[int] = None,
+         vectorizer: CountVectorizer = TfidfVectorizer(),
+     ):
+         self.retrieval_docs_dir: Path = Path(
+             os.getenv("RETRIEVAL_DOCS_DIR", "retrieval_docs")
+         )
+         self.language = language
+         self.chunk_size = chunk_size
+         self.chunk_overlap = chunk_overlap
+         self.start_page = start_page
+         self.end_page = end_page
+         self.vectorizer = vectorizer
+         self.documents = self.load_and_chunk_pdf()
+         self.doc_vectors = self.vectorize_documents()
+
+     def load_and_chunk_pdf(self) -> List[str]:
+         pdf_path = self.retrieval_docs_dir / f"{self.language}.pdf"
+         pickled_documents_path = (
+             self.retrieval_docs_dir / f"{self.language}_documents.pkl"
+         )
+
+         if pickled_documents_path.exists():
+             log.debug(
+                 f"Loading pre-chunked PDF from {pickled_documents_path}. "
+                 f"If you want to regenerate retrieval docs for {self.language}, "
+                 f"delete the file at {pickled_documents_path}, "
+                 f"then add a new {self.language}.pdf."
+             )
+             documents = joblib.load(pickled_documents_path)
+         else:
+             if not pdf_path.exists():
+                 raise FileNotFoundError(
+                     f"Language docs retrieval is enabled, but no PDF for language "
+                     f"'{self.language}' was found. Move a "
+                     f"{self.language} reference manual to "
+                     f"{pdf_path.absolute()} "
+                     f"(the path to the directory of PDF docs can be "
+                     f"set with the env variable 'RETRIEVAL_DOCS_DIR')."
+                 )
+             log.info(
+                 f"Chunking reference PDF for {self.language} using unstructured - "
+                 f"if your PDF has many pages, this could take a while..."
+             )
+             start_time = time.time()
+             loader = UnstructuredLoader(
+                 pdf_path,
+                 chunking_strategy="basic",
+                 max_characters=1000000,
+                 include_orig_elements=False,
+                 start_page=self.start_page,
+                 end_page=self.end_page,
+             )
+             docs = loader.load()
+             text = "\n\n".join([doc.page_content for doc in docs])
+             text_splitter = RecursiveCharacterTextSplitter(
+                 chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+             )
+             documents = text_splitter.split_text(text)
+             log.info(f"Document store created for language: {self.language}")
+             end_time = time.time()
+             log.info(
+                 f"Processing time for {self.language} PDF: "
+                 f"{end_time - start_time} seconds"
+             )
+
+             joblib.dump(documents, pickled_documents_path)
+             log.debug(f"Documents saved to {pickled_documents_path}")
+
+         return documents
+
+     def vectorize_documents(self) -> (TfidfVectorizer, any):
+         doc_vectors = self.vectorizer.fit_transform(self.documents)
+         return doc_vectors
+
+     def search_language_reference(
+         self,
+         query: List[str],
+         top_k: int = 1,
+         min_similarity: float = 0.1,
+     ) -> List[Document]:
+         """Searches through the vectorized PDF for the query using
+         tf-idf and returns a list of langchain Documents."""
+
+         docs: List[Document] = []
+
+         for item in query:
+             # Transform the query using the TF-IDF vectorizer
+             query_vector = self.vectorizer.transform([item])
+
+             # Calculate cosine similarities between the query and document vectors
+             similarities = cosine_similarity(query_vector, self.doc_vectors).flatten()
+
+             # Get the indices of documents with similarity above the threshold
+             valid_indices = [
+                 i for i, sim in enumerate(similarities) if sim >= min_similarity
+             ]
+
+             # Sort the valid indices by similarity score in descending order
+             sorted_indices = sorted(
+                 valid_indices, key=lambda i: similarities[i], reverse=True
+             )
+
+             # Limit to top-k results
+             top_indices = sorted_indices[:top_k]
+
+             # Retrieve the top-k most relevant documents
+             docs += [Document(page_content=self.documents[i]) for i in top_indices]
+             log.debug(f"Language documentation search result: {docs}")
+
+         return docs
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: janus-llm
- Version: 4.1.0
+ Version: 4.3.1
  Summary: A transcoding library using LLMs.
  Home-page: https://github.com/janus-llm/janus-llm
  License: Apache 2.0
@@ -23,20 +23,28 @@ Requires-Dist: langchain-anthropic (>=0.1.15,<0.2.0)
  Requires-Dist: langchain-community (>=0.2.0,<0.3.0)
  Requires-Dist: langchain-core (>=0.2.0,<0.3.0)
  Requires-Dist: langchain-openai (>=0.1.8,<0.2.0)
+ Requires-Dist: langchain-unstructured (>=0.1.2,<0.2.0)
  Requires-Dist: nltk (>=3.8.1,<4.0.0)
  Requires-Dist: numpy (>=1.24.3,<2.0.0)
  Requires-Dist: openai (>=1.14.0,<2.0.0)
+ Requires-Dist: pi-heif (>=0.20.0,<0.21.0)
  Requires-Dist: py-readability-metrics (>=1.4.5,<2.0.0)
  Requires-Dist: py-rouge (>=1.1,<2.0)
+ Requires-Dist: pytesseract (>=0.3.13,<0.4.0)
  Requires-Dist: python-dotenv (>=1.0.0,<2.0.0)
  Requires-Dist: rich (>=13.7.1,<14.0.0)
  Requires-Dist: sacrebleu (>=2.4.1,<3.0.0)
+ Requires-Dist: scikit-learn (>=1.5.2,<2.0.0)
  Requires-Dist: sentence-transformers (>=2.6.1,<3.0.0) ; extra == "hf-local" or extra == "all"
+ Requires-Dist: tesseract (>=0.1.3,<0.2.0)
  Requires-Dist: text-generation (>=0.6.0,<0.7.0)
  Requires-Dist: tiktoken (>=0.7.0,<0.8.0)
  Requires-Dist: transformers (>=4.31.0,<5.0.0)
  Requires-Dist: tree-sitter (>=0.21.0,<0.22.0)
  Requires-Dist: typer (>=0.9.0,<0.10.0)
+ Requires-Dist: unstructured (>=0.15.9,<0.16.0)
+ Requires-Dist: unstructured-inference (>=0.7.36,<0.8.0)
+ Requires-Dist: unstructured-pytesseract (>=0.3.13,<0.4.0)
  Project-URL: Documentation, https://janus-llm.github.io/janus-llm
  Project-URL: Repository, https://github.com/janus-llm/janus-llm
  Description-Content-Type: text/markdown