janus-llm 4.0.0__py3-none-any.whl → 4.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janus/__init__.py +1 -1
- janus/cli.py +161 -26
- janus/converter/__init__.py +1 -0
- janus/converter/_tests/test_translate.py +2 -2
- janus/converter/converter.py +45 -47
- janus/converter/partition.py +27 -0
- janus/language/combine.py +22 -0
- janus/llm/model_callbacks.py +9 -0
- janus/llm/models_info.py +41 -17
- janus/parsers/partition_parser.py +136 -0
- janus/refiners/refiner.py +8 -12
- janus/refiners/uml.py +33 -0
- janus/retrievers/retriever.py +60 -0
- janus/utils/pdf_docs_reader.py +134 -0
- {janus_llm-4.0.0.dist-info → janus_llm-4.2.0.dist-info}/METADATA +9 -1
- {janus_llm-4.0.0.dist-info → janus_llm-4.2.0.dist-info}/RECORD +19 -15
- {janus_llm-4.0.0.dist-info → janus_llm-4.2.0.dist-info}/WHEEL +1 -1
- {janus_llm-4.0.0.dist-info → janus_llm-4.2.0.dist-info}/LICENSE +0 -0
- {janus_llm-4.0.0.dist-info → janus_llm-4.2.0.dist-info}/entry_points.txt +0 -0
janus/llm/models_info.py
CHANGED
@@ -1,15 +1,14 @@
|
|
1
1
|
import json
|
2
2
|
import os
|
3
|
-
import time
|
4
3
|
from pathlib import Path
|
5
|
-
from typing import Protocol, TypeVar
|
4
|
+
from typing import Callable, Protocol, TypeVar
|
6
5
|
|
7
6
|
from dotenv import load_dotenv
|
8
7
|
from langchain_community.llms import HuggingFaceTextGenInference
|
9
8
|
from langchain_core.runnables import Runnable
|
10
|
-
from langchain_openai import
|
9
|
+
from langchain_openai import AzureChatOpenAI
|
11
10
|
|
12
|
-
from janus.llm.model_callbacks import COST_PER_1K_TOKENS,
|
11
|
+
from janus.llm.model_callbacks import COST_PER_1K_TOKENS, azure_model_reroutes
|
13
12
|
from janus.prompts.prompt import (
|
14
13
|
ChatGptPromptEngine,
|
15
14
|
ClaudePromptEngine,
|
@@ -46,7 +45,7 @@ except ImportError:
|
|
46
45
|
|
47
46
|
ModelType = TypeVar(
|
48
47
|
"ModelType",
|
49
|
-
|
48
|
+
AzureChatOpenAI,
|
50
49
|
HuggingFaceTextGenInference,
|
51
50
|
Bedrock,
|
52
51
|
BedrockChat,
|
@@ -72,7 +71,6 @@ class JanusModel(Runnable, JanusModelProtocol):
|
|
72
71
|
|
73
72
|
load_dotenv()
|
74
73
|
|
75
|
-
|
76
74
|
openai_models = [
|
77
75
|
"gpt-4o",
|
78
76
|
"gpt-4o-mini",
|
@@ -82,11 +80,17 @@ openai_models = [
|
|
82
80
|
"gpt-3.5-turbo",
|
83
81
|
"gpt-3.5-turbo-16k",
|
84
82
|
]
|
83
|
+
azure_models = [
|
84
|
+
"gpt-4o",
|
85
|
+
"gpt-4o-mini",
|
86
|
+
"gpt-3.5-turbo-16k",
|
87
|
+
]
|
85
88
|
claude_models = [
|
86
89
|
"bedrock-claude-v2",
|
87
90
|
"bedrock-claude-instant-v1",
|
88
91
|
"bedrock-claude-haiku",
|
89
92
|
"bedrock-claude-sonnet",
|
93
|
+
"bedrock-claude-sonnet-3.5",
|
90
94
|
]
|
91
95
|
llama2_models = [
|
92
96
|
"bedrock-llama2-70b",
|
@@ -120,18 +124,21 @@ bedrock_models = [
|
|
120
124
|
*cohere_models,
|
121
125
|
*mistral_models,
|
122
126
|
]
|
123
|
-
all_models = [*
|
127
|
+
all_models = [*azure_models, *bedrock_models]
|
124
128
|
|
125
129
|
MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
|
126
|
-
"OpenAI": ChatOpenAI,
|
130
|
+
# "OpenAI": ChatOpenAI,
|
127
131
|
"HuggingFace": HuggingFaceTextGenInference,
|
132
|
+
"Azure": AzureChatOpenAI,
|
128
133
|
"Bedrock": Bedrock,
|
129
134
|
"BedrockChat": BedrockChat,
|
130
135
|
"HuggingFaceLocal": HuggingFacePipeline,
|
131
136
|
}
|
132
137
|
|
133
|
-
|
134
|
-
|
138
|
+
|
139
|
+
MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
|
140
|
+
# **{m: ChatGptPromptEngine for m in openai_models},
|
141
|
+
**{m: ChatGptPromptEngine for m in azure_models},
|
135
142
|
**{m: ClaudePromptEngine for m in claude_models},
|
136
143
|
**{m: Llama2PromptEngine for m in llama2_models},
|
137
144
|
**{m: Llama3PromptEngine for m in llama3_models},
|
@@ -141,11 +148,13 @@ MODEL_PROMPT_ENGINES: dict[str, type[PromptEngine]] = {
|
|
141
148
|
}
|
142
149
|
|
143
150
|
MODEL_ID_TO_LONG_ID = {
|
144
|
-
**{m: mr for m, mr in openai_model_reroutes.items()},
|
151
|
+
# **{m: mr for m, mr in openai_model_reroutes.items()},
|
152
|
+
**{m: mr for m, mr in azure_model_reroutes.items()},
|
145
153
|
"bedrock-claude-v2": "anthropic.claude-v2",
|
146
154
|
"bedrock-claude-instant-v1": "anthropic.claude-instant-v1",
|
147
155
|
"bedrock-claude-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
|
148
156
|
"bedrock-claude-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
|
157
|
+
"bedrock-claude-sonnet-3.5": "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
149
158
|
"bedrock-llama2-70b": "meta.llama2-70b-v1",
|
150
159
|
"bedrock-llama2-70b-chat": "meta.llama2-70b-chat-v1",
|
151
160
|
"bedrock-llama2-13b": "meta.llama2-13b-chat-v1",
|
@@ -171,8 +180,9 @@ DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())
|
|
171
180
|
|
172
181
|
MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"
|
173
182
|
|
174
|
-
MODEL_TYPES: dict[str,
|
175
|
-
**{m: "OpenAI" for m in openai_models},
|
183
|
+
MODEL_TYPES: dict[str, PromptEngine] = {
|
184
|
+
# **{m: "OpenAI" for m in openai_models},
|
185
|
+
**{m: "Azure" for m in azure_models},
|
176
186
|
**{m: "BedrockChat" for m in bedrock_models},
|
177
187
|
}
|
178
188
|
|
@@ -182,13 +192,17 @@ TOKEN_LIMITS: dict[str, int] = {
|
|
182
192
|
"gpt-4-1106-preview": 128_000,
|
183
193
|
"gpt-4-0125-preview": 128_000,
|
184
194
|
"gpt-4o-2024-05-13": 128_000,
|
195
|
+
"gpt-4o-2024-08-06": 128_000,
|
196
|
+
"gpt-4o-mini": 128_000,
|
185
197
|
"gpt-3.5-turbo-0125": 16_384,
|
198
|
+
"gpt35-turbo-16k": 16_384,
|
186
199
|
"text-embedding-ada-002": 8191,
|
187
200
|
"gpt4all": 16_384,
|
188
201
|
"anthropic.claude-v2": 100_000,
|
189
202
|
"anthropic.claude-instant-v1": 100_000,
|
190
203
|
"anthropic.claude-3-haiku-20240307-v1:0": 248_000,
|
191
204
|
"anthropic.claude-3-sonnet-20240229-v1:0": 248_000,
|
205
|
+
"anthropic.claude-3-5-sonnet-20240620-v1:0": 200_000,
|
192
206
|
"meta.llama2-70b-v1": 4096,
|
193
207
|
"meta.llama2-70b-chat-v1": 4096,
|
194
208
|
"meta.llama2-13b-chat-v1": 4096,
|
@@ -270,11 +284,21 @@ def load_model(model_id) -> JanusModel:
|
|
270
284
|
openai_api_key=str(os.getenv("OPENAI_API_KEY")),
|
271
285
|
openai_organization=str(os.getenv("OPENAI_ORG_ID")),
|
272
286
|
)
|
273
|
-
log.warning("Do NOT use this model in sensitive environments!")
|
274
|
-
log.warning("If you would like to cancel, please press Ctrl+C.")
|
275
|
-
log.warning("Waiting 10 seconds...")
|
287
|
+
# log.warning("Do NOT use this model in sensitive environments!")
|
288
|
+
# log.warning("If you would like to cancel, please press Ctrl+C.")
|
289
|
+
# log.warning("Waiting 10 seconds...")
|
276
290
|
# Give enough time for the user to read the warnings and cancel
|
277
|
-
time.sleep(10)
|
291
|
+
# time.sleep(10)
|
292
|
+
raise DeprecationWarning("OpenAI models are no longer supported.")
|
293
|
+
|
294
|
+
elif model_type_name == "Azure":
|
295
|
+
model_args.update(
|
296
|
+
{
|
297
|
+
"api_key": os.getenv("AZURE_OPENAI_API_KEY"),
|
298
|
+
"azure_endpoint": os.getenv("AZURE_OPENAI_ENDPOINT"),
|
299
|
+
"api_version": os.getenv("OPENAI_API_VERSION", "2024-02-01"),
|
300
|
+
}
|
301
|
+
)
|
278
302
|
|
279
303
|
model_type = MODEL_TYPE_CONSTRUCTORS[model_type_name]
|
280
304
|
prompt_engine = MODEL_PROMPT_ENGINES[model_id]
|
@@ -0,0 +1,136 @@
|
|
1
|
+
import json
|
2
|
+
import random
|
3
|
+
import uuid
|
4
|
+
|
5
|
+
from langchain.output_parsers import PydanticOutputParser
|
6
|
+
from langchain_core.exceptions import OutputParserException
|
7
|
+
from langchain_core.language_models import BaseLanguageModel
|
8
|
+
from langchain_core.messages import BaseMessage
|
9
|
+
from langchain_core.pydantic_v1 import BaseModel, Field
|
10
|
+
|
11
|
+
from janus.language.block import CodeBlock
|
12
|
+
from janus.parsers.parser import JanusParser
|
13
|
+
from janus.utils.logger import create_logger
|
14
|
+
|
15
|
+
log = create_logger(__name__)
|
16
|
+
RNG = random.Random()
|
17
|
+
|
18
|
+
|
19
|
+
class PartitionObject(BaseModel):
|
20
|
+
reasoning: str = Field(
|
21
|
+
description="An explanation for why the code should be split at this point"
|
22
|
+
)
|
23
|
+
location: str = Field(
|
24
|
+
description="The 8-character line label which should start a new chunk"
|
25
|
+
)
|
26
|
+
|
27
|
+
|
28
|
+
class PartitionList(BaseModel):
|
29
|
+
__root__: list[PartitionObject] = Field(
|
30
|
+
description=(
|
31
|
+
"A list of appropriate split points, each with a `reasoning` field "
|
32
|
+
"that explains a justification for splitting the code at that point, "
|
33
|
+
"and a `location` field which is simply the 8-character line ID. "
|
34
|
+
"The `reasoning` field should always be included first."
|
35
|
+
)
|
36
|
+
)
|
37
|
+
|
38
|
+
|
39
|
+
class PartitionParser(JanusParser, PydanticOutputParser):
|
40
|
+
token_limit: int
|
41
|
+
model: BaseLanguageModel
|
42
|
+
lines: list[str] = []
|
43
|
+
line_id_to_index: dict[str, int] = {}
|
44
|
+
|
45
|
+
def __init__(self, token_limit: int, model: BaseLanguageModel):
|
46
|
+
PydanticOutputParser.__init__(
|
47
|
+
self,
|
48
|
+
pydantic_object=PartitionList,
|
49
|
+
model=model,
|
50
|
+
token_limit=token_limit,
|
51
|
+
)
|
52
|
+
|
53
|
+
def parse_input(self, block: CodeBlock) -> str:
|
54
|
+
code = str(block.text)
|
55
|
+
RNG.seed(code)
|
56
|
+
|
57
|
+
self.lines = code.split("\n")
|
58
|
+
|
59
|
+
# Generate a unique ID for each line (ensure they are unique)
|
60
|
+
line_ids = set()
|
61
|
+
while len(line_ids) < len(self.lines):
|
62
|
+
line_ids.add(str(uuid.UUID(int=RNG.getrandbits(128), version=4))[:8])
|
63
|
+
|
64
|
+
# Prepend each line with the corresponding ID, save the mapping
|
65
|
+
self.line_id_to_index = {lid: i for i, lid in enumerate(line_ids)}
|
66
|
+
processed = "\n".join(
|
67
|
+
f"{line_id}\t{self.lines[i]}" for line_id, i in self.line_id_to_index.items()
|
68
|
+
)
|
69
|
+
return processed
|
70
|
+
|
71
|
+
def parse(self, text: str | BaseMessage) -> str:
|
72
|
+
if isinstance(text, BaseMessage):
|
73
|
+
text = str(text.content)
|
74
|
+
|
75
|
+
try:
|
76
|
+
out: PartitionList = super().parse(text)
|
77
|
+
except (OutputParserException, json.JSONDecodeError):
|
78
|
+
log.debug(f"Invalid JSON object. Output:\n{text}")
|
79
|
+
raise
|
80
|
+
|
81
|
+
# Locate any invalid line IDs, raise exception if any found
|
82
|
+
invalid_splits = [
|
83
|
+
partition.location
|
84
|
+
for partition in out.__root__
|
85
|
+
if partition.location not in self.line_id_to_index
|
86
|
+
]
|
87
|
+
if invalid_splits:
|
88
|
+
err_msg = (
|
89
|
+
f"{len(invalid_splits)} line ID(s) not found in input: "
|
90
|
+
+ ", ".join(invalid_splits)
|
91
|
+
)
|
92
|
+
log.warning(err_msg)
|
93
|
+
raise OutputParserException(err_msg)
|
94
|
+
|
95
|
+
# Map line IDs to indices (so they can be sorted and lines indexed)
|
96
|
+
index_to_line_id = {0: "START", None: "END"}
|
97
|
+
split_points = {0}
|
98
|
+
for partition in out.__root__:
|
99
|
+
index = self.line_id_to_index[partition.location]
|
100
|
+
index_to_line_id[index] = partition.location
|
101
|
+
split_points.add(index)
|
102
|
+
|
103
|
+
# Get partition start/ends, chunks, chunk lengths
|
104
|
+
split_points = sorted(split_points) + [None]
|
105
|
+
partition_indices = list(zip(split_points, split_points[1:]))
|
106
|
+
partition_points = [
|
107
|
+
(index_to_line_id[i0], index_to_line_id[i1]) for i0, i1 in partition_indices
|
108
|
+
]
|
109
|
+
chunks = ["\n".join(self.lines[i0:i1]) for i0, i1 in partition_indices]
|
110
|
+
chunk_tokens = list(map(self.model.get_num_tokens, chunks))
|
111
|
+
|
112
|
+
# Collect any chunks that exceed token limit
|
113
|
+
oversized_indices: list[int] = [
|
114
|
+
i for i, n in enumerate(chunk_tokens) if n > self.token_limit
|
115
|
+
]
|
116
|
+
if oversized_indices:
|
117
|
+
data = list(zip(partition_points, chunks, chunk_tokens))
|
118
|
+
data = [data[i] for i in oversized_indices]
|
119
|
+
|
120
|
+
problem_points = "\n".join(
|
121
|
+
[
|
122
|
+
f"{i0} to {i1} ({t / self.token_limit:.1f}x maximum length)"
|
123
|
+
for (i0, i1), _, t in data
|
124
|
+
]
|
125
|
+
)
|
126
|
+
log.warning(f"Found {len(data)} oversized chunks:\n{problem_points}")
|
127
|
+
log.debug(
|
128
|
+
"Oversized chunks:\n"
|
129
|
+
+ "\n#############\n".join(chunk for _, chunk, _ in data)
|
130
|
+
)
|
131
|
+
raise OutputParserException(
|
132
|
+
f"The following segments are too long and must be "
|
133
|
+
f"further subdivided:\n{problem_points}"
|
134
|
+
)
|
135
|
+
|
136
|
+
return "\n<JANUS_PARTITION>\n".join(chunks)
|
janus/refiners/refiner.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
+
import re
|
1
2
|
from typing import Any
|
2
3
|
|
3
4
|
from langchain.output_parsers import RetryWithErrorOutputParser
|
@@ -27,7 +28,7 @@ class JanusRefiner(JanusParser):
|
|
27
28
|
|
28
29
|
class FixParserExceptions(JanusRefiner, RetryWithErrorOutputParser):
|
29
30
|
def __init__(self, llm: JanusModel, parser: JanusParser, max_retries: int):
|
30
|
-
retry_prompt = MODEL_PROMPT_ENGINES[llm.
|
31
|
+
retry_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
|
31
32
|
source_language="text",
|
32
33
|
prompt_template="refinement/fix_exceptions",
|
33
34
|
).prompt
|
@@ -46,6 +47,7 @@ class ReflectionRefiner(JanusRefiner):
|
|
46
47
|
max_retries: int
|
47
48
|
reflection_chain: RunnableSerializable
|
48
49
|
revision_chain: RunnableSerializable
|
50
|
+
reflection_prompt_name: str
|
49
51
|
|
50
52
|
def __init__(
|
51
53
|
self,
|
@@ -54,11 +56,11 @@ class ReflectionRefiner(JanusRefiner):
|
|
54
56
|
max_retries: int,
|
55
57
|
prompt_template_name: str = "refinement/reflection",
|
56
58
|
):
|
57
|
-
reflection_prompt = MODEL_PROMPT_ENGINES[llm.
|
59
|
+
reflection_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
|
58
60
|
source_language="text",
|
59
61
|
prompt_template=prompt_template_name,
|
60
62
|
).prompt
|
61
|
-
revision_prompt = MODEL_PROMPT_ENGINES[llm.
|
63
|
+
revision_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
|
62
64
|
source_language="text",
|
63
65
|
prompt_template="refinement/revision",
|
64
66
|
).prompt
|
@@ -66,6 +68,7 @@ class ReflectionRefiner(JanusRefiner):
|
|
66
68
|
reflection_chain = reflection_prompt | llm | StrOutputParser()
|
67
69
|
revision_chain = revision_prompt | llm | StrOutputParser()
|
68
70
|
super().__init__(
|
71
|
+
reflection_prompt_name=prompt_template_name,
|
69
72
|
reflection_chain=reflection_chain,
|
70
73
|
revision_chain=revision_chain,
|
71
74
|
parser=parser,
|
@@ -75,6 +78,7 @@ class ReflectionRefiner(JanusRefiner):
|
|
75
78
|
def parse_completion(
|
76
79
|
self, completion: str, prompt_value: PromptValue, **kwargs
|
77
80
|
) -> Any:
|
81
|
+
log.info(f"Reflection Prompt: {self.reflection_prompt_name}")
|
78
82
|
for retry_number in range(self.max_retries):
|
79
83
|
reflection = self.reflection_chain.invoke(
|
80
84
|
dict(
|
@@ -82,7 +86,7 @@ class ReflectionRefiner(JanusRefiner):
|
|
82
86
|
completion=completion,
|
83
87
|
)
|
84
88
|
)
|
85
|
-
if
|
89
|
+
if re.search(r"\bLGTM\b", reflection) is not None:
|
86
90
|
return self.parser.parse(completion)
|
87
91
|
if not retry_number:
|
88
92
|
log.info(f"Completion:\n{completion}")
|
@@ -105,11 +109,3 @@ class HallucinationRefiner(ReflectionRefiner):
|
|
105
109
|
prompt_template_name="refinement/hallucination",
|
106
110
|
**kwargs,
|
107
111
|
)
|
108
|
-
|
109
|
-
|
110
|
-
REFINERS = dict(
|
111
|
-
none=JanusRefiner,
|
112
|
-
parser=FixParserExceptions,
|
113
|
-
reflection=ReflectionRefiner,
|
114
|
-
hallucination=HallucinationRefiner,
|
115
|
-
)
|
janus/refiners/uml.py
ADDED
@@ -0,0 +1,33 @@
|
|
1
|
+
from janus.llm.models_info import JanusModel
|
2
|
+
from janus.parsers.parser import JanusParser
|
3
|
+
from janus.refiners.refiner import ReflectionRefiner
|
4
|
+
|
5
|
+
|
6
|
+
class ALCFixUMLVariablesRefiner(ReflectionRefiner):
|
7
|
+
def __init__(
|
8
|
+
self,
|
9
|
+
llm: JanusModel,
|
10
|
+
parser: JanusParser,
|
11
|
+
max_retries: int,
|
12
|
+
):
|
13
|
+
super().__init__(
|
14
|
+
llm=llm,
|
15
|
+
parser=parser,
|
16
|
+
max_retries=max_retries,
|
17
|
+
prompt_template_name="refinement/uml/alc_fix_variables",
|
18
|
+
)
|
19
|
+
|
20
|
+
|
21
|
+
class FixUMLConnectionsRefiner(ReflectionRefiner):
|
22
|
+
def __init__(
|
23
|
+
self,
|
24
|
+
llm: JanusModel,
|
25
|
+
parser: JanusParser,
|
26
|
+
max_retries: int,
|
27
|
+
):
|
28
|
+
super().__init__(
|
29
|
+
llm=llm,
|
30
|
+
parser=parser,
|
31
|
+
max_retries=max_retries,
|
32
|
+
prompt_template_name="refinement/uml/fix_connections",
|
33
|
+
)
|
janus/retrievers/retriever.py
CHANGED
@@ -1,7 +1,16 @@
|
|
1
|
+
from typing import List
|
2
|
+
|
3
|
+
from langchain_core.documents import Document
|
4
|
+
from langchain_core.output_parsers import StrOutputParser
|
1
5
|
from langchain_core.retrievers import BaseRetriever
|
2
6
|
from langchain_core.runnables import Runnable, RunnableConfig
|
3
7
|
|
4
8
|
from janus.language.block import CodeBlock
|
9
|
+
from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel
|
10
|
+
from janus.utils.logger import create_logger
|
11
|
+
from janus.utils.pdf_docs_reader import PDFDocsReader
|
12
|
+
|
13
|
+
log = create_logger(__name__)
|
5
14
|
|
6
15
|
|
7
16
|
class JanusRetriever(Runnable):
|
@@ -40,3 +49,54 @@ class TextSearchRetriever(JanusRetriever):
|
|
40
49
|
docs = self.retriever.invoke(code_block.text)
|
41
50
|
context = "\n\n".join(doc.page_content for doc in docs)
|
42
51
|
return f"You may use the following additional context: {context}"
|
52
|
+
|
53
|
+
|
54
|
+
class LanguageDocsRetriever(JanusRetriever):
|
55
|
+
def __init__(
|
56
|
+
self,
|
57
|
+
llm: JanusModel,
|
58
|
+
language_name: str,
|
59
|
+
prompt_template_name: str = "retrieval/language_docs",
|
60
|
+
):
|
61
|
+
super().__init__()
|
62
|
+
self.llm: JanusModel = llm
|
63
|
+
self.language: str = language_name
|
64
|
+
|
65
|
+
self.PDF_reader = PDFDocsReader(
|
66
|
+
language=self.language,
|
67
|
+
)
|
68
|
+
|
69
|
+
language_docs_prompt = MODEL_PROMPT_ENGINES[self.llm.short_model_id](
|
70
|
+
source_language=self.language,
|
71
|
+
prompt_template=prompt_template_name,
|
72
|
+
).prompt
|
73
|
+
|
74
|
+
parser: StrOutputParser = StrOutputParser()
|
75
|
+
self.chain = language_docs_prompt | self.llm | parser
|
76
|
+
|
77
|
+
def get_context(self, code_block: CodeBlock) -> str:
|
78
|
+
functionality_to_reference: str = self.chain.invoke(
|
79
|
+
dict({"SOURCE_CODE": code_block.text, "SOURCE_LANGUAGE": self.language})
|
80
|
+
)
|
81
|
+
if functionality_to_reference == "NODOCS":
|
82
|
+
log.debug("No Opcodes requested from language docs retriever.")
|
83
|
+
return ""
|
84
|
+
else:
|
85
|
+
functionality_to_reference: List = functionality_to_reference.split(", ")
|
86
|
+
log.debug(
|
87
|
+
f"List of opcodes requested by language docs retriever"
|
88
|
+
f"to search the {self.language} "
|
89
|
+
f"docs for: {functionality_to_reference}"
|
90
|
+
)
|
91
|
+
|
92
|
+
docs: List[Document] = self.PDF_reader.search_language_reference(
|
93
|
+
functionality_to_reference
|
94
|
+
)
|
95
|
+
context = "\n\n".join(doc.page_content for doc in docs)
|
96
|
+
if context:
|
97
|
+
return (
|
98
|
+
f"You may reference the following excerpts from the {self.language} "
|
99
|
+
f"language documentation: {context}"
|
100
|
+
)
|
101
|
+
else:
|
102
|
+
return ""
|
@@ -0,0 +1,134 @@
|
|
1
|
+
import os
|
2
|
+
import time
|
3
|
+
from pathlib import Path
|
4
|
+
from typing import List, Optional
|
5
|
+
|
6
|
+
import joblib
|
7
|
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
8
|
+
from langchain_core.documents import Document
|
9
|
+
from langchain_unstructured import UnstructuredLoader
|
10
|
+
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
|
11
|
+
from sklearn.metrics.pairwise import cosine_similarity
|
12
|
+
|
13
|
+
from janus.utils.logger import create_logger
|
14
|
+
|
15
|
+
log = create_logger(__name__)
|
16
|
+
|
17
|
+
|
18
|
+
class PDFDocsReader:
|
19
|
+
def __init__(
|
20
|
+
self,
|
21
|
+
language: str,
|
22
|
+
chunk_size: int = 1000,
|
23
|
+
chunk_overlap: int = 100,
|
24
|
+
start_page: Optional[int] = None,
|
25
|
+
end_page: Optional[int] = None,
|
26
|
+
vectorizer: CountVectorizer = TfidfVectorizer(),
|
27
|
+
):
|
28
|
+
self.retrieval_docs_dir: Path = Path(
|
29
|
+
os.getenv("RETRIEVAL_DOCS_DIR", "retrieval_docs")
|
30
|
+
)
|
31
|
+
self.language = language
|
32
|
+
self.chunk_size = chunk_size
|
33
|
+
self.chunk_overlap = chunk_overlap
|
34
|
+
self.start_page = start_page
|
35
|
+
self.end_page = end_page
|
36
|
+
self.vectorizer = vectorizer
|
37
|
+
self.documents = self.load_and_chunk_pdf()
|
38
|
+
self.doc_vectors = self.vectorize_documents()
|
39
|
+
|
40
|
+
def load_and_chunk_pdf(self) -> List[str]:
|
41
|
+
pdf_path = self.retrieval_docs_dir / f"{self.language}.pdf"
|
42
|
+
pickled_documents_path = (
|
43
|
+
self.retrieval_docs_dir / f"{self.language}_documents.pkl"
|
44
|
+
)
|
45
|
+
|
46
|
+
if pickled_documents_path.exists():
|
47
|
+
log.debug(
|
48
|
+
f"Loading pre-chunked PDF from {pickled_documents_path}. "
|
49
|
+
f"If you want to regenerate retrieval docs for {self.language}, "
|
50
|
+
f"delete the file at {pickled_documents_path}, "
|
51
|
+
f"then add a new {self.language}.pdf."
|
52
|
+
)
|
53
|
+
documents = joblib.load(pickled_documents_path)
|
54
|
+
else:
|
55
|
+
if not pdf_path.exists():
|
56
|
+
raise FileNotFoundError(
|
57
|
+
f"Language docs retrieval is enabled, but no PDF for language "
|
58
|
+
f"'{self.language}' was found. Move a "
|
59
|
+
f"{self.language} reference manual to "
|
60
|
+
f"{pdf_path.absolute()} "
|
61
|
+
f"(the path to the directory of PDF docs can be "
|
62
|
+
f"set with the env variable 'RETRIEVAL_DOCS_DIR')."
|
63
|
+
)
|
64
|
+
log.info(
|
65
|
+
f"Chunking reference PDF for {self.language} using unstructured - "
|
66
|
+
f"if your PDF has many pages, this could take a while..."
|
67
|
+
)
|
68
|
+
start_time = time.time()
|
69
|
+
loader = UnstructuredLoader(
|
70
|
+
pdf_path,
|
71
|
+
chunking_strategy="basic",
|
72
|
+
max_characters=1000000,
|
73
|
+
include_orig_elements=False,
|
74
|
+
start_page=self.start_page,
|
75
|
+
end_page=self.end_page,
|
76
|
+
)
|
77
|
+
docs = loader.load()
|
78
|
+
text = "\n\n".join([doc.page_content for doc in docs])
|
79
|
+
text_splitter = RecursiveCharacterTextSplitter(
|
80
|
+
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
|
81
|
+
)
|
82
|
+
documents = text_splitter.split_text(text)
|
83
|
+
log.info(f"Document store created for language: {self.language}")
|
84
|
+
end_time = time.time()
|
85
|
+
log.info(
|
86
|
+
f"Processing time for {self.language} PDF: "
|
87
|
+
f"{end_time - start_time} seconds"
|
88
|
+
)
|
89
|
+
|
90
|
+
joblib.dump(documents, pickled_documents_path)
|
91
|
+
log.debug(f"Documents saved to {pickled_documents_path}")
|
92
|
+
|
93
|
+
return documents
|
94
|
+
|
95
|
+
def vectorize_documents(self) -> (TfidfVectorizer, any):
|
96
|
+
doc_vectors = self.vectorizer.fit_transform(self.documents)
|
97
|
+
return doc_vectors
|
98
|
+
|
99
|
+
def search_language_reference(
|
100
|
+
self,
|
101
|
+
query: List[str],
|
102
|
+
top_k: int = 1,
|
103
|
+
min_similarity: float = 0.1,
|
104
|
+
) -> List[Document]:
|
105
|
+
"""Searches through the vectorized PDF for the query using
|
106
|
+
tf-idf and returns a list of langchain Documents."""
|
107
|
+
|
108
|
+
docs: List[Document] = []
|
109
|
+
|
110
|
+
for item in query:
|
111
|
+
# Transform the query using the TF-IDF vectorizer
|
112
|
+
query_vector = self.vectorizer.transform([item])
|
113
|
+
|
114
|
+
# Calculate cosine similarities between the query and document vectors
|
115
|
+
similarities = cosine_similarity(query_vector, self.doc_vectors).flatten()
|
116
|
+
|
117
|
+
# Get the indices of documents with similarity above the threshold
|
118
|
+
valid_indices = [
|
119
|
+
i for i, sim in enumerate(similarities) if sim >= min_similarity
|
120
|
+
]
|
121
|
+
|
122
|
+
# Sort the valid indices by similarity score in descending order
|
123
|
+
sorted_indices = sorted(
|
124
|
+
valid_indices, key=lambda i: similarities[i], reverse=True
|
125
|
+
)
|
126
|
+
|
127
|
+
# Limit to top-k results
|
128
|
+
top_indices = sorted_indices[:top_k]
|
129
|
+
|
130
|
+
# Retrieve the top-k most relevant documents
|
131
|
+
docs += [Document(page_content=self.documents[i]) for i in top_indices]
|
132
|
+
log.debug(f"Langauge documentation search result: {docs}")
|
133
|
+
|
134
|
+
return docs
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: janus-llm
|
3
|
-
Version: 4.
|
3
|
+
Version: 4.2.0
|
4
4
|
Summary: A transcoding library using LLMs.
|
5
5
|
Home-page: https://github.com/janus-llm/janus-llm
|
6
6
|
License: Apache 2.0
|
@@ -23,20 +23,28 @@ Requires-Dist: langchain-anthropic (>=0.1.15,<0.2.0)
|
|
23
23
|
Requires-Dist: langchain-community (>=0.2.0,<0.3.0)
|
24
24
|
Requires-Dist: langchain-core (>=0.2.0,<0.3.0)
|
25
25
|
Requires-Dist: langchain-openai (>=0.1.8,<0.2.0)
|
26
|
+
Requires-Dist: langchain-unstructured (>=0.1.2,<0.2.0)
|
26
27
|
Requires-Dist: nltk (>=3.8.1,<4.0.0)
|
27
28
|
Requires-Dist: numpy (>=1.24.3,<2.0.0)
|
28
29
|
Requires-Dist: openai (>=1.14.0,<2.0.0)
|
30
|
+
Requires-Dist: pi-heif (>=0.20.0,<0.21.0)
|
29
31
|
Requires-Dist: py-readability-metrics (>=1.4.5,<2.0.0)
|
30
32
|
Requires-Dist: py-rouge (>=1.1,<2.0)
|
33
|
+
Requires-Dist: pytesseract (>=0.3.13,<0.4.0)
|
31
34
|
Requires-Dist: python-dotenv (>=1.0.0,<2.0.0)
|
32
35
|
Requires-Dist: rich (>=13.7.1,<14.0.0)
|
33
36
|
Requires-Dist: sacrebleu (>=2.4.1,<3.0.0)
|
37
|
+
Requires-Dist: scikit-learn (>=1.5.2,<2.0.0)
|
34
38
|
Requires-Dist: sentence-transformers (>=2.6.1,<3.0.0) ; extra == "hf-local" or extra == "all"
|
39
|
+
Requires-Dist: tesseract (>=0.1.3,<0.2.0)
|
35
40
|
Requires-Dist: text-generation (>=0.6.0,<0.7.0)
|
36
41
|
Requires-Dist: tiktoken (>=0.7.0,<0.8.0)
|
37
42
|
Requires-Dist: transformers (>=4.31.0,<5.0.0)
|
38
43
|
Requires-Dist: tree-sitter (>=0.21.0,<0.22.0)
|
39
44
|
Requires-Dist: typer (>=0.9.0,<0.10.0)
|
45
|
+
Requires-Dist: unstructured (>=0.15.9,<0.16.0)
|
46
|
+
Requires-Dist: unstructured-inference (>=0.7.36,<0.8.0)
|
47
|
+
Requires-Dist: unstructured-pytesseract (>=0.3.13,<0.4.0)
|
40
48
|
Project-URL: Documentation, https://janus-llm.github.io/janus-llm
|
41
49
|
Project-URL: Repository, https://github.com/janus-llm/janus-llm
|
42
50
|
Description-Content-Type: text/markdown
|