janus-llm 4.0.0__py3-none-any.whl → 4.2.0__py3-none-any.whl

janus/llm/models_info.py CHANGED
@@ -1,15 +1,14 @@
  import json
  import os
- import time
  from pathlib import Path
- from typing import Protocol, TypeVar
+ from typing import Callable, Protocol, TypeVar

  from dotenv import load_dotenv
  from langchain_community.llms import HuggingFaceTextGenInference
  from langchain_core.runnables import Runnable
- from langchain_openai import ChatOpenAI
+ from langchain_openai import AzureChatOpenAI

- from janus.llm.model_callbacks import COST_PER_1K_TOKENS, openai_model_reroutes
+ from janus.llm.model_callbacks import COST_PER_1K_TOKENS, azure_model_reroutes
  from janus.prompts.prompt import (
      ChatGptPromptEngine,
      ClaudePromptEngine,
@@ -46,7 +45,7 @@ except ImportError:

  ModelType = TypeVar(
      "ModelType",
-     ChatOpenAI,
+     AzureChatOpenAI,
      HuggingFaceTextGenInference,
      Bedrock,
      BedrockChat,
@@ -72,7 +71,6 @@ class JanusModel(Runnable, JanusModelProtocol):

  load_dotenv()

-
  openai_models = [
      "gpt-4o",
      "gpt-4o-mini",
@@ -82,11 +80,17 @@ openai_models = [
      "gpt-3.5-turbo",
      "gpt-3.5-turbo-16k",
  ]
+ azure_models = [
+     "gpt-4o",
+     "gpt-4o-mini",
+     "gpt-3.5-turbo-16k",
+ ]
  claude_models = [
      "bedrock-claude-v2",
      "bedrock-claude-instant-v1",
      "bedrock-claude-haiku",
      "bedrock-claude-sonnet",
+     "bedrock-claude-sonnet-3.5",
  ]
  llama2_models = [
      "bedrock-llama2-70b",
@@ -120,18 +124,21 @@ bedrock_models = [
      *cohere_models,
      *mistral_models,
  ]
- all_models = [*openai_models, *bedrock_models]
+ all_models = [*azure_models, *bedrock_models]

  MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
-     "OpenAI": ChatOpenAI,
+     # "OpenAI": ChatOpenAI,
      "HuggingFace": HuggingFaceTextGenInference,
+     "Azure": AzureChatOpenAI,
      "Bedrock": Bedrock,
      "BedrockChat": BedrockChat,
      "HuggingFaceLocal": HuggingFacePipeline,
  }

- MODEL_PROMPT_ENGINES: dict[str, type[PromptEngine]] = {
-     **{m: ChatGptPromptEngine for m in openai_models},
+
+ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
+     # **{m: ChatGptPromptEngine for m in openai_models},
+     **{m: ChatGptPromptEngine for m in azure_models},
      **{m: ClaudePromptEngine for m in claude_models},
      **{m: Llama2PromptEngine for m in llama2_models},
      **{m: Llama3PromptEngine for m in llama3_models},
@@ -141,11 +148,13 @@ MODEL_PROMPT_ENGINES: dict[str, type[PromptEngine]] = {
  }

  MODEL_ID_TO_LONG_ID = {
-     **{m: mr for m, mr in openai_model_reroutes.items()},
+     # **{m: mr for m, mr in openai_model_reroutes.items()},
+     **{m: mr for m, mr in azure_model_reroutes.items()},
      "bedrock-claude-v2": "anthropic.claude-v2",
      "bedrock-claude-instant-v1": "anthropic.claude-instant-v1",
      "bedrock-claude-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
      "bedrock-claude-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
+     "bedrock-claude-sonnet-3.5": "anthropic.claude-3-5-sonnet-20240620-v1:0",
      "bedrock-llama2-70b": "meta.llama2-70b-v1",
      "bedrock-llama2-70b-chat": "meta.llama2-70b-chat-v1",
      "bedrock-llama2-13b": "meta.llama2-13b-chat-v1",
@@ -171,8 +180,9 @@ DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())

  MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"

- MODEL_TYPES: dict[str, str] = {
-     **{m: "OpenAI" for m in openai_models},
+ MODEL_TYPES: dict[str, PromptEngine] = {
+     # **{m: "OpenAI" for m in openai_models},
+     **{m: "Azure" for m in azure_models},
      **{m: "BedrockChat" for m in bedrock_models},
  }

@@ -182,13 +192,17 @@ TOKEN_LIMITS: dict[str, int] = {
      "gpt-4-1106-preview": 128_000,
      "gpt-4-0125-preview": 128_000,
      "gpt-4o-2024-05-13": 128_000,
+     "gpt-4o-2024-08-06": 128_000,
+     "gpt-4o-mini": 128_000,
      "gpt-3.5-turbo-0125": 16_384,
+     "gpt35-turbo-16k": 16_384,
      "text-embedding-ada-002": 8191,
      "gpt4all": 16_384,
      "anthropic.claude-v2": 100_000,
      "anthropic.claude-instant-v1": 100_000,
      "anthropic.claude-3-haiku-20240307-v1:0": 248_000,
      "anthropic.claude-3-sonnet-20240229-v1:0": 248_000,
+     "anthropic.claude-3-5-sonnet-20240620-v1:0": 200_000,
      "meta.llama2-70b-v1": 4096,
      "meta.llama2-70b-chat-v1": 4096,
      "meta.llama2-13b-chat-v1": 4096,
@@ -270,11 +284,21 @@ def load_model(model_id) -> JanusModel:
              openai_api_key=str(os.getenv("OPENAI_API_KEY")),
              openai_organization=str(os.getenv("OPENAI_ORG_ID")),
          )
-         log.warning("Do NOT use this model in sensitive environments!")
-         log.warning("If you would like to cancel, please press Ctrl+C.")
-         log.warning("Waiting 10 seconds...")
+         # log.warning("Do NOT use this model in sensitive environments!")
+         # log.warning("If you would like to cancel, please press Ctrl+C.")
+         # log.warning("Waiting 10 seconds...")
          # Give enough time for the user to read the warnings and cancel
-         time.sleep(10)
+         # time.sleep(10)
+         raise DeprecationWarning("OpenAI models are no longer supported.")
+
+     elif model_type_name == "Azure":
+         model_args.update(
+             {
+                 "api_key": os.getenv("AZURE_OPENAI_API_KEY"),
+                 "azure_endpoint": os.getenv("AZURE_OPENAI_ENDPOINT"),
+                 "api_version": os.getenv("OPENAI_API_VERSION", "2024-02-01"),
+             }
+         )

      model_type = MODEL_TYPE_CONSTRUCTORS[model_type_name]
      prompt_engine = MODEL_PROMPT_ENGINES[model_id]
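For orientation, a minimal sketch of loading one of the new Azure-routed models after this change. The `load_model` function, the model IDs, and the environment variable names all appear in the diff above; the driver itself is illustrative, not part of the package docs:

```python
import os

# Assumed to be set in the environment (placeholder values):
os.environ["AZURE_OPENAI_API_KEY"] = "<your-key>"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<resource>.openai.azure.com"
os.environ["OPENAI_API_VERSION"] = "2024-02-01"  # default used if unset

from janus.llm.models_info import load_model

# "gpt-4o" is listed in azure_models, so MODEL_TYPES maps it to "Azure"
# and load_model fills in the Azure credentials shown above.
model = load_model("gpt-4o")
```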
@@ -0,0 +1,136 @@
+ import json
+ import random
+ import uuid
+
+ from langchain.output_parsers import PydanticOutputParser
+ from langchain_core.exceptions import OutputParserException
+ from langchain_core.language_models import BaseLanguageModel
+ from langchain_core.messages import BaseMessage
+ from langchain_core.pydantic_v1 import BaseModel, Field
+
+ from janus.language.block import CodeBlock
+ from janus.parsers.parser import JanusParser
+ from janus.utils.logger import create_logger
+
+ log = create_logger(__name__)
+ RNG = random.Random()
+
+
+ class PartitionObject(BaseModel):
+     reasoning: str = Field(
+         description="An explanation for why the code should be split at this point"
+     )
+     location: str = Field(
+         description="The 8-character line label which should start a new chunk"
+     )
+
+
+ class PartitionList(BaseModel):
+     __root__: list[PartitionObject] = Field(
+         description=(
+             "A list of appropriate split points, each with a `reasoning` field "
+             "that explains a justification for splitting the code at that point, "
+             "and a `location` field which is simply the 8-character line ID. "
+             "The `reasoning` field should always be included first."
+         )
+     )
+
+
+ class PartitionParser(JanusParser, PydanticOutputParser):
+     token_limit: int
+     model: BaseLanguageModel
+     lines: list[str] = []
+     line_id_to_index: dict[str, int] = {}
+
+     def __init__(self, token_limit: int, model: BaseLanguageModel):
+         PydanticOutputParser.__init__(
+             self,
+             pydantic_object=PartitionList,
+             model=model,
+             token_limit=token_limit,
+         )
+
+     def parse_input(self, block: CodeBlock) -> str:
+         code = str(block.text)
+         RNG.seed(code)
+
+         self.lines = code.split("\n")
+
+         # Generate a unique ID for each line (ensure they are unique)
+         line_ids = set()
+         while len(line_ids) < len(self.lines):
+             line_ids.add(str(uuid.UUID(int=RNG.getrandbits(128), version=4))[:8])
+
+         # Prepend each line with the corresponding ID, save the mapping
+         self.line_id_to_index = {lid: i for i, lid in enumerate(line_ids)}
+         processed = "\n".join(
+             f"{line_id}\t{self.lines[i]}" for line_id, i in self.line_id_to_index.items()
+         )
+         return processed
+
+     def parse(self, text: str | BaseMessage) -> str:
+         if isinstance(text, BaseMessage):
+             text = str(text.content)
+
+         try:
+             out: PartitionList = super().parse(text)
+         except (OutputParserException, json.JSONDecodeError):
+             log.debug(f"Invalid JSON object. Output:\n{text}")
+             raise
+
+         # Locate any invalid line IDs, raise exception if any found
+         invalid_splits = [
+             partition.location
+             for partition in out.__root__
+             if partition.location not in self.line_id_to_index
+         ]
+         if invalid_splits:
+             err_msg = (
+                 f"{len(invalid_splits)} line ID(s) not found in input: "
+                 + ", ".join(invalid_splits)
+             )
+             log.warning(err_msg)
+             raise OutputParserException(err_msg)
+
+         # Map line IDs to indices (so they can be sorted and lines indexed)
+         index_to_line_id = {0: "START", None: "END"}
+         split_points = {0}
+         for partition in out.__root__:
+             index = self.line_id_to_index[partition.location]
+             index_to_line_id[index] = partition.location
+             split_points.add(index)
+
+         # Get partition start/ends, chunks, chunk lengths
+         split_points = sorted(split_points) + [None]
+         partition_indices = list(zip(split_points, split_points[1:]))
+         partition_points = [
+             (index_to_line_id[i0], index_to_line_id[i1]) for i0, i1 in partition_indices
+         ]
+         chunks = ["\n".join(self.lines[i0:i1]) for i0, i1 in partition_indices]
+         chunk_tokens = list(map(self.model.get_num_tokens, chunks))
+
+         # Collect any chunks that exceed token limit
+         oversized_indices: list[int] = [
+             i for i, n in enumerate(chunk_tokens) if n > self.token_limit
+         ]
+         if oversized_indices:
+             data = list(zip(partition_points, chunks, chunk_tokens))
+             data = [data[i] for i in oversized_indices]
+
+             problem_points = "\n".join(
+                 [
+                     f"{i0} to {i1} ({t / self.token_limit:.1f}x maximum length)"
+                     for (i0, i1), _, t in data
+                 ]
+             )
+             log.warning(f"Found {len(data)} oversized chunks:\n{problem_points}")
+             log.debug(
+                 "Oversized chunks:\n"
+                 + "\n#############\n".join(chunk for _, chunk, _ in data)
+             )
+             raise OutputParserException(
+                 f"The following segments are too long and must be "
+                 f"further subdivided:\n{problem_points}"
+             )
+
+         return "\n<JANUS_PARTITION>\n".join(chunks)
janus/refiners/refiner.py CHANGED
@@ -1,3 +1,4 @@
+ import re
  from typing import Any

  from langchain.output_parsers import RetryWithErrorOutputParser
@@ -27,7 +28,7 @@ class JanusRefiner(JanusParser):

  class FixParserExceptions(JanusRefiner, RetryWithErrorOutputParser):
      def __init__(self, llm: JanusModel, parser: JanusParser, max_retries: int):
-         retry_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
+         retry_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
              source_language="text",
              prompt_template="refinement/fix_exceptions",
          ).prompt
@@ -46,6 +47,7 @@ class ReflectionRefiner(JanusRefiner):
      max_retries: int
      reflection_chain: RunnableSerializable
      revision_chain: RunnableSerializable
+     reflection_prompt_name: str

      def __init__(
          self,
@@ -54,11 +56,11 @@ class ReflectionRefiner(JanusRefiner):
          max_retries: int,
          prompt_template_name: str = "refinement/reflection",
      ):
-         reflection_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
+         reflection_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
              source_language="text",
              prompt_template=prompt_template_name,
          ).prompt
-         revision_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
+         revision_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
              source_language="text",
              prompt_template="refinement/revision",
          ).prompt
@@ -66,6 +68,7 @@ class ReflectionRefiner(JanusRefiner):
          reflection_chain = reflection_prompt | llm | StrOutputParser()
          revision_chain = revision_prompt | llm | StrOutputParser()
          super().__init__(
+             reflection_prompt_name=prompt_template_name,
              reflection_chain=reflection_chain,
              revision_chain=revision_chain,
              parser=parser,
@@ -75,6 +78,7 @@ class ReflectionRefiner(JanusRefiner):
      def parse_completion(
          self, completion: str, prompt_value: PromptValue, **kwargs
      ) -> Any:
+         log.info(f"Reflection Prompt: {self.reflection_prompt_name}")
          for retry_number in range(self.max_retries):
              reflection = self.reflection_chain.invoke(
                  dict(
@@ -82,7 +86,7 @@ class ReflectionRefiner(JanusRefiner):
                      completion=completion,
                  )
              )
-             if reflection.strip() == "LGTM":
+             if re.search(r"\bLGTM\b", reflection) is not None:
                  return self.parser.parse(completion)
              if not retry_number:
                  log.info(f"Completion:\n{completion}")
@@ -105,11 +109,3 @@ class HallucinationRefiner(ReflectionRefiner):
              prompt_template_name="refinement/hallucination",
              **kwargs,
          )
-
-
- REFINERS = dict(
-     none=JanusRefiner,
-     parser=FixParserExceptions,
-     reflection=ReflectionRefiner,
-     hallucination=HallucinationRefiner,
- )
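One behavioral note on the `LGTM` change above: the old exact comparison only accepted a bare `LGTM` reply, so a reflection like "LGTM, no changes needed." triggered a pointless revision pass; the new word-boundary search accepts `LGTM` anywhere in the response. A quick illustration with hypothetical strings:

```python
import re

def old_check(reflection: str) -> bool:  # pre-change behavior
    return reflection.strip() == "LGTM"

def new_check(reflection: str) -> bool:  # post-change behavior
    return re.search(r"\bLGTM\b", reflection) is not None

print(old_check("LGTM"), new_check("LGTM"))                      # True True
print(old_check("LGTM, ship it."), new_check("LGTM, ship it."))  # False True
```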
janus/refiners/uml.py ADDED
@@ -0,0 +1,33 @@
+ from janus.llm.models_info import JanusModel
+ from janus.parsers.parser import JanusParser
+ from janus.refiners.refiner import ReflectionRefiner
+
+
+ class ALCFixUMLVariablesRefiner(ReflectionRefiner):
+     def __init__(
+         self,
+         llm: JanusModel,
+         parser: JanusParser,
+         max_retries: int,
+     ):
+         super().__init__(
+             llm=llm,
+             parser=parser,
+             max_retries=max_retries,
+             prompt_template_name="refinement/uml/alc_fix_variables",
+         )
+
+
+ class FixUMLConnectionsRefiner(ReflectionRefiner):
+     def __init__(
+         self,
+         llm: JanusModel,
+         parser: JanusParser,
+         max_retries: int,
+     ):
+         super().__init__(
+             llm=llm,
+             parser=parser,
+             max_retries=max_retries,
+             prompt_template_name="refinement/uml/fix_connections",
+         )
@@ -1,7 +1,16 @@
+ from typing import List
+
+ from langchain_core.documents import Document
+ from langchain_core.output_parsers import StrOutputParser
  from langchain_core.retrievers import BaseRetriever
  from langchain_core.runnables import Runnable, RunnableConfig

  from janus.language.block import CodeBlock
+ from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel
+ from janus.utils.logger import create_logger
+ from janus.utils.pdf_docs_reader import PDFDocsReader
+
+ log = create_logger(__name__)


  class JanusRetriever(Runnable):
@@ -40,3 +49,54 @@ class TextSearchRetriever(JanusRetriever):
          docs = self.retriever.invoke(code_block.text)
          context = "\n\n".join(doc.page_content for doc in docs)
          return f"You may use the following additional context: {context}"
+
+
+ class LanguageDocsRetriever(JanusRetriever):
+     def __init__(
+         self,
+         llm: JanusModel,
+         language_name: str,
+         prompt_template_name: str = "retrieval/language_docs",
+     ):
+         super().__init__()
+         self.llm: JanusModel = llm
+         self.language: str = language_name
+
+         self.PDF_reader = PDFDocsReader(
+             language=self.language,
+         )
+
+         language_docs_prompt = MODEL_PROMPT_ENGINES[self.llm.short_model_id](
+             source_language=self.language,
+             prompt_template=prompt_template_name,
+         ).prompt
+
+         parser: StrOutputParser = StrOutputParser()
+         self.chain = language_docs_prompt | self.llm | parser
+
+     def get_context(self, code_block: CodeBlock) -> str:
+         functionality_to_reference: str = self.chain.invoke(
+             dict({"SOURCE_CODE": code_block.text, "SOURCE_LANGUAGE": self.language})
+         )
+         if functionality_to_reference == "NODOCS":
+             log.debug("No Opcodes requested from language docs retriever.")
+             return ""
+         else:
+             functionality_to_reference: List = functionality_to_reference.split(", ")
+             log.debug(
+                 f"List of opcodes requested by language docs retriever "
+                 f"to search the {self.language} "
+                 f"docs for: {functionality_to_reference}"
+             )
+
+             docs: List[Document] = self.PDF_reader.search_language_reference(
+                 functionality_to_reference
+             )
+             context = "\n\n".join(doc.page_content for doc in docs)
+             if context:
+                 return (
+                     f"You may reference the following excerpts from the {self.language} "
+                     f"language documentation: {context}"
+                 )
+             else:
+                 return ""
janus/utils/pdf_docs_reader.py ADDED
@@ -0,0 +1,134 @@
+ import os
+ import time
+ from pathlib import Path
+ from typing import List, Optional
+
+ import joblib
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_core.documents import Document
+ from langchain_unstructured import UnstructuredLoader
+ from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
+ from sklearn.metrics.pairwise import cosine_similarity
+
+ from janus.utils.logger import create_logger
+
+ log = create_logger(__name__)
+
+
+ class PDFDocsReader:
+     def __init__(
+         self,
+         language: str,
+         chunk_size: int = 1000,
+         chunk_overlap: int = 100,
+         start_page: Optional[int] = None,
+         end_page: Optional[int] = None,
+         vectorizer: CountVectorizer = TfidfVectorizer(),
+     ):
+         self.retrieval_docs_dir: Path = Path(
+             os.getenv("RETRIEVAL_DOCS_DIR", "retrieval_docs")
+         )
+         self.language = language
+         self.chunk_size = chunk_size
+         self.chunk_overlap = chunk_overlap
+         self.start_page = start_page
+         self.end_page = end_page
+         self.vectorizer = vectorizer
+         self.documents = self.load_and_chunk_pdf()
+         self.doc_vectors = self.vectorize_documents()
+
+     def load_and_chunk_pdf(self) -> List[str]:
+         pdf_path = self.retrieval_docs_dir / f"{self.language}.pdf"
+         pickled_documents_path = (
+             self.retrieval_docs_dir / f"{self.language}_documents.pkl"
+         )
+
+         if pickled_documents_path.exists():
+             log.debug(
+                 f"Loading pre-chunked PDF from {pickled_documents_path}. "
+                 f"If you want to regenerate retrieval docs for {self.language}, "
+                 f"delete the file at {pickled_documents_path}, "
+                 f"then add a new {self.language}.pdf."
+             )
+             documents = joblib.load(pickled_documents_path)
+         else:
+             if not pdf_path.exists():
+                 raise FileNotFoundError(
+                     f"Language docs retrieval is enabled, but no PDF for language "
+                     f"'{self.language}' was found. Move a "
+                     f"{self.language} reference manual to "
+                     f"{pdf_path.absolute()} "
+                     f"(the path to the directory of PDF docs can be "
+                     f"set with the env variable 'RETRIEVAL_DOCS_DIR')."
+                 )
+             log.info(
+                 f"Chunking reference PDF for {self.language} using unstructured - "
+                 f"if your PDF has many pages, this could take a while..."
+             )
+             start_time = time.time()
+             loader = UnstructuredLoader(
+                 pdf_path,
+                 chunking_strategy="basic",
+                 max_characters=1000000,
+                 include_orig_elements=False,
+                 start_page=self.start_page,
+                 end_page=self.end_page,
+             )
+             docs = loader.load()
+             text = "\n\n".join([doc.page_content for doc in docs])
+             text_splitter = RecursiveCharacterTextSplitter(
+                 chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+             )
+             documents = text_splitter.split_text(text)
+             log.info(f"Document store created for language: {self.language}")
+             end_time = time.time()
+             log.info(
+                 f"Processing time for {self.language} PDF: "
+                 f"{end_time - start_time} seconds"
+             )
+
+             joblib.dump(documents, pickled_documents_path)
+             log.debug(f"Documents saved to {pickled_documents_path}")
+
+         return documents
+
+     def vectorize_documents(self) -> (TfidfVectorizer, any):
+         doc_vectors = self.vectorizer.fit_transform(self.documents)
+         return doc_vectors
+
+     def search_language_reference(
+         self,
+         query: List[str],
+         top_k: int = 1,
+         min_similarity: float = 0.1,
+     ) -> List[Document]:
+         """Searches through the vectorized PDF for the query using
+         tf-idf and returns a list of langchain Documents."""
+
+         docs: List[Document] = []
+
+         for item in query:
+             # Transform the query using the TF-IDF vectorizer
+             query_vector = self.vectorizer.transform([item])
+
+             # Calculate cosine similarities between the query and document vectors
+             similarities = cosine_similarity(query_vector, self.doc_vectors).flatten()
+
+             # Get the indices of documents with similarity above the threshold
+             valid_indices = [
+                 i for i, sim in enumerate(similarities) if sim >= min_similarity
+             ]
+
+             # Sort the valid indices by similarity score in descending order
+             sorted_indices = sorted(
+                 valid_indices, key=lambda i: similarities[i], reverse=True
+             )
+
+             # Limit to top-k results
+             top_indices = sorted_indices[:top_k]
+
+             # Retrieve the top-k most relevant documents
+             docs += [Document(page_content=self.documents[i]) for i in top_indices]
+             log.debug(f"Language documentation search result: {docs}")
+
+         return docs
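A standalone usage sketch of `PDFDocsReader` (hypothetical paths): the first run chunks `$RETRIEVAL_DOCS_DIR/Fortran.pdf` with unstructured and pickles the chunks; later runs load `Fortran_documents.pkl` directly, as implemented above:

```python
import os

os.environ["RETRIEVAL_DOCS_DIR"] = "retrieval_docs"  # default if unset

reader = PDFDocsReader(language="Fortran")  # loads, chunks, and vectorizes

# TF-IDF search over the chunks; returns up to top_k Documents per query term.
docs = reader.search_language_reference(["WRITE statement"], top_k=2)
for doc in docs:
    print(doc.page_content[:200])
```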
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: janus-llm
- Version: 4.0.0
+ Version: 4.2.0
  Summary: A transcoding library using LLMs.
  Home-page: https://github.com/janus-llm/janus-llm
  License: Apache 2.0
@@ -23,20 +23,28 @@ Requires-Dist: langchain-anthropic (>=0.1.15,<0.2.0)
  Requires-Dist: langchain-community (>=0.2.0,<0.3.0)
  Requires-Dist: langchain-core (>=0.2.0,<0.3.0)
  Requires-Dist: langchain-openai (>=0.1.8,<0.2.0)
+ Requires-Dist: langchain-unstructured (>=0.1.2,<0.2.0)
  Requires-Dist: nltk (>=3.8.1,<4.0.0)
  Requires-Dist: numpy (>=1.24.3,<2.0.0)
  Requires-Dist: openai (>=1.14.0,<2.0.0)
+ Requires-Dist: pi-heif (>=0.20.0,<0.21.0)
  Requires-Dist: py-readability-metrics (>=1.4.5,<2.0.0)
  Requires-Dist: py-rouge (>=1.1,<2.0)
+ Requires-Dist: pytesseract (>=0.3.13,<0.4.0)
  Requires-Dist: python-dotenv (>=1.0.0,<2.0.0)
  Requires-Dist: rich (>=13.7.1,<14.0.0)
  Requires-Dist: sacrebleu (>=2.4.1,<3.0.0)
+ Requires-Dist: scikit-learn (>=1.5.2,<2.0.0)
  Requires-Dist: sentence-transformers (>=2.6.1,<3.0.0) ; extra == "hf-local" or extra == "all"
+ Requires-Dist: tesseract (>=0.1.3,<0.2.0)
  Requires-Dist: text-generation (>=0.6.0,<0.7.0)
  Requires-Dist: tiktoken (>=0.7.0,<0.8.0)
  Requires-Dist: transformers (>=4.31.0,<5.0.0)
  Requires-Dist: tree-sitter (>=0.21.0,<0.22.0)
  Requires-Dist: typer (>=0.9.0,<0.10.0)
+ Requires-Dist: unstructured (>=0.15.9,<0.16.0)
+ Requires-Dist: unstructured-inference (>=0.7.36,<0.8.0)
+ Requires-Dist: unstructured-pytesseract (>=0.3.13,<0.4.0)
  Project-URL: Documentation, https://janus-llm.github.io/janus-llm
  Project-URL: Repository, https://github.com/janus-llm/janus-llm
  Description-Content-Type: text/markdown