janus-llm 3.5.3__py3-none-any.whl → 4.0.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,14 +1,6 @@
1
- import math
1
+ from langchain_core.runnables import Runnable, RunnableParallel
2
2
 
3
- from langchain.output_parsers import RetryWithErrorOutputParser
4
- from langchain_core.exceptions import OutputParserException
5
- from langchain_core.runnables import RunnableLambda, RunnableParallel
6
-
7
- from janus.converter.converter import run_if_changed
8
3
  from janus.converter.document import Documenter
9
- from janus.language.block import TranslatedCodeBlock
10
- from janus.llm.models_info import MODEL_PROMPT_ENGINES
11
- from janus.parsers.refiner_parser import RefinerParser
12
4
  from janus.parsers.uml import UMLSyntaxParser
13
5
  from janus.utils.logger import create_logger
14
6
 
@@ -16,10 +8,7 @@ log = create_logger(__name__)
16
8
 
17
9
 
18
10
  class DiagramGenerator(Documenter):
19
- """DiagramGenerator
20
-
21
- A class that translates code from one programming language to a set of diagrams.
22
- """
11
+ """A Converter that translates code into a set of PLANTUML diagrams."""
23
12
 
24
13
  def __init__(
25
14
  self,
@@ -30,110 +19,33 @@ class DiagramGenerator(Documenter):
30
19
  """Initialize the DiagramGenerator class
31
20
 
32
21
  Arguments:
33
- model: The LLM to use for translation. If an OpenAI model, the
34
- `OPENAI_API_KEY` environment variable must be set and the
35
- `OPENAI_ORG_ID` environment variable should be set if needed.
36
- model_arguments: Additional arguments to pass to the LLM constructor.
37
- source_language: The source programming language.
38
- max_prompts: The maximum number of prompts to try before giving up.
39
- db_path: path to chroma database
40
- db_config: database configuraiton
41
22
  diagram_type: type of PLANTUML diagram to generate
23
+ add_documentation: Whether to add a documentation step prior to
24
+ diagram generation.
42
25
  """
43
- super().__init__(**kwargs)
44
26
  self._diagram_type = diagram_type
45
27
  self._add_documentation = add_documentation
46
- self._documenter = None
47
- self._diagram_parser = UMLSyntaxParser(language="plantuml")
48
- if add_documentation:
49
- self._diagram_prompt_template_name = "diagram_with_documentation"
50
- else:
51
- self._diagram_prompt_template_name = "diagram"
52
- self._load_diagram_prompt_engine()
28
+ self._documenter = Documenter(**kwargs)
53
29
 
54
- def _run_chain(self, block: TranslatedCodeBlock) -> str:
55
- input = self._parser.parse_input(block.original)
56
- n1 = round(self.max_prompts ** (1 / 3))
30
+ super().__init__(**kwargs)
31
+
32
+ self.set_prompt("diagram_with_documentation" if add_documentation else "diagram")
33
+ self._parser = UMLSyntaxParser(language="plantuml")
57
34
 
58
- # Retries with the input, output, and error
59
- n2 = round((self.max_prompts // n1) ** (1 / 2))
35
+ self._load_parameters()
60
36
 
61
- # Retries with just the input
62
- n3 = math.ceil(self.max_prompts / (n1 * n2))
37
+ def _load_prompt(self):
38
+ super()._load_prompt()
39
+ self._prompt = self._prompt.partial(DIAGRAM_TYPE=self._diagram_type)
63
40
 
41
+ def _input_runnable(self) -> Runnable:
64
42
  if self._add_documentation:
65
- documentation_text = super()._run_chain(block)
66
- refine_output = RefinerParser(
67
- parser=self._diagram_parser,
68
- initial_prompt=self._diagram_prompt.format(
69
- **{
70
- "SOURCE_CODE": input,
71
- "DOCUMENTATION": documentation_text,
72
- "DIAGRAM_TYPE": self._diagram_type,
73
- }
74
- ),
75
- refiner=self._refiner,
76
- max_retries=n1,
77
- llm=self._llm,
78
- )
79
- else:
80
- refine_output = RefinerParser(
81
- parser=self._diagram_parser,
82
- initial_prompt=self._diagram_prompt.format(
83
- **{
84
- "SOURCE_CODE": input,
85
- "DIAGRAM_TYPE": self._diagram_type,
86
- }
87
- ),
88
- refiner=self._refiner,
89
- max_retries=n1,
90
- llm=self._llm,
43
+ return RunnableParallel(
44
+ SOURCE_CODE=self._parser.parse_input,
45
+ DOCUMENTATION=self._documenter.chain,
46
+ context=self._retriever,
91
47
  )
92
- retry = RetryWithErrorOutputParser.from_llm(
93
- llm=self._llm,
94
- parser=refine_output,
95
- max_retries=n2,
96
- )
97
- completion_chain = self._prompt | self._llm
98
- chain = RunnableParallel(
99
- completion=completion_chain, prompt_value=self._diagram_prompt
100
- ) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))
101
- for _ in range(n3):
102
- try:
103
- if self._add_documentation:
104
- return chain.invoke(
105
- {
106
- "SOURCE_CODE": input,
107
- "DOCUMENTATION": documentation_text,
108
- "DIAGRAM_TYPE": self._diagram_type,
109
- }
110
- )
111
- else:
112
- return chain.invoke(
113
- {
114
- "SOURCE_CODE": input,
115
- "DIAGRAM_TYPE": self._diagram_type,
116
- }
117
- )
118
- except OutputParserException:
119
- pass
120
-
121
- raise OutputParserException(f"Failed to parse after {n1*n2*n3} retries")
122
-
123
- @run_if_changed(
124
- "_diagram_prompt_template_name",
125
- "_source_language",
126
- )
127
- def _load_diagram_prompt_engine(self) -> None:
128
- """Load the prompt engine according to this instance's attributes.
129
-
130
- If the relevant fields have not been changed since the last time this method was
131
- called, nothing happens.
132
- """
133
- self._diagram_prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
134
- source_language=self._source_language,
135
- target_language="text",
136
- target_version=None,
137
- prompt_template=self._diagram_prompt_template_name,
48
+ return RunnableParallel(
49
+ SOURCE_CODE=self._parser.parse_input,
50
+ context=self._retriever,
138
51
  )
139
- self._diagram_prompt = self._diagram_prompt_engine.prompt
@@ -90,7 +90,7 @@ class Translator(Converter):
90
90
  f"({self._source_language} != {self._target_language})"
91
91
  )
92
92
 
93
- prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
93
+ prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
94
94
  source_language=self._source_language,
95
95
  target_language=self._target_language,
96
96
  target_version=self._target_version,
@@ -12,7 +12,7 @@ class TestAlcSplitter(unittest.TestCase):
12
12
  def setUp(self):
13
13
  """Set up the tests."""
14
14
  model_name = "gpt-4o"
15
- llm, _, _, _ = load_model(model_name)
15
+ llm = load_model(model_name)
16
16
  self.splitter = AlcSplitter(model=llm)
17
17
  self.combiner = Combiner(language="ibmhlasm")
18
18
  self.test_file = Path("janus/language/alc/_tests/alc.asm")
janus/language/alc/alc.py CHANGED
@@ -1,12 +1,11 @@
1
1
  import re
2
2
  from typing import Optional
3
3
 
4
- from langchain.schema.language_model import BaseLanguageModel
5
-
6
4
  from janus.language.block import CodeBlock
7
5
  from janus.language.combine import Combiner
8
6
  from janus.language.node import NodeType
9
7
  from janus.language.treesitter import TreeSitterSplitter
8
+ from janus.llm.models_info import JanusModel
10
9
  from janus.utils.logger import create_logger
11
10
 
12
11
  log = create_logger(__name__)
@@ -27,7 +26,7 @@ class AlcSplitter(TreeSitterSplitter):
27
26
 
28
27
  def __init__(
29
28
  self,
30
- model: None | BaseLanguageModel = None,
29
+ model: JanusModel | None = None,
31
30
  max_tokens: int = 4096,
32
31
  protected_node_types: tuple[str, ...] = (),
33
32
  prune_node_types: tuple[str, ...] = (),
@@ -101,7 +100,7 @@ class AlcListingSplitter(AlcSplitter):
101
100
 
102
101
  def __init__(
103
102
  self,
104
- model: None | BaseLanguageModel = None,
103
+ model: JanusModel | None = None,
105
104
  max_tokens: int = 4096,
106
105
  protected_node_types: tuple[str, ...] = (),
107
106
  prune_node_types: tuple[str, ...] = (),
@@ -129,12 +128,18 @@ class AlcListingSplitter(AlcSplitter):
129
128
  prune_unprotected=prune_unprotected,
130
129
  )
131
130
 
132
- def _get_ast(self, code: str) -> CodeBlock:
131
+ def split_string(self, code: str, name: str) -> CodeBlock:
132
+ # Override split_string to use processed code and track active usings
133
133
  active_usings = self.get_active_usings(code)
134
- code = self.preproccess_assembly(code)
135
- ast: CodeBlock = super()._get_ast(code)
136
- ast.context_tags["active_usings"] = active_usings
137
- return ast
134
+ processed_code = self.preproccess_assembly(code)
135
+ root = super().split_string(processed_code, name)
136
+ if active_usings is not None:
137
+ stack = [root]
138
+ while stack:
139
+ block = stack.pop()
140
+ block.context_tags["active_usings"] = active_usings
141
+ stack.extend(block.children)
142
+ return root
138
143
 
139
144
  def preproccess_assembly(self, code: str) -> str:
140
145
  """Remove non-essential lines from an assembly snippet"""
@@ -142,7 +147,7 @@ class AlcListingSplitter(AlcSplitter):
142
147
  lines = code.splitlines()
143
148
  lines = self.strip_header_and_left(lines)
144
149
  lines = self.strip_addresses(lines)
145
- return "".join(str(line) for line in lines)
150
+ return "\n".join(str(line) for line in lines)
146
151
 
147
152
  def get_active_usings(self, code: str) -> Optional[str]:
148
153
  """Look for 'active usings' in the ALC listing header"""
@@ -15,7 +15,7 @@ class TestBinarySplitter(unittest.TestCase):
15
15
  def setUp(self):
16
16
  model_name = "gpt-4o"
17
17
  self.binary_file = Path("janus/language/binary/_tests/hello")
18
- self.llm, _, _, _ = load_model(model_name)
18
+ self.llm = load_model(model_name)
19
19
  self.splitter = BinarySplitter(model=self.llm)
20
20
  os.environ["GHIDRA_INSTALL_PATH"] = "~/programs/ghidra_10.4_PUBLIC"
21
21
 
@@ -5,11 +5,11 @@ import tempfile
5
5
  from pathlib import Path
6
6
 
7
7
  import tree_sitter
8
- from langchain.schema.language_model import BaseLanguageModel
9
8
 
10
9
  from janus.language.block import CodeBlock
11
10
  from janus.language.combine import Combiner
12
11
  from janus.language.treesitter import TreeSitterSplitter
12
+ from janus.llm.models_info import JanusModel
13
13
  from janus.utils.enums import LANGUAGES
14
14
  from janus.utils.logger import create_logger
15
15
 
@@ -31,7 +31,7 @@ class BinarySplitter(TreeSitterSplitter):
31
31
 
32
32
  def __init__(
33
33
  self,
34
- model: None | BaseLanguageModel = None,
34
+ model: JanusModel | None = None,
35
35
  max_tokens: int = 4096,
36
36
  protected_node_types: tuple[str] = (),
37
37
  prune_node_types: tuple[str] = (),
@@ -12,7 +12,7 @@ class TestMumpsSplitter(unittest.TestCase):
12
12
  def setUp(self):
13
13
  """Set up the tests."""
14
14
  model_name = "gpt-4o"
15
- llm, _, _, _ = load_model(model_name)
15
+ llm = load_model(model_name)
16
16
  self.splitter = MumpsSplitter(model=llm)
17
17
  self.combiner = Combiner(language="mumps")
18
18
  self.test_file = Path("janus/language/mumps/_tests/mumps.m")
@@ -1,11 +1,10 @@
1
1
  import re
2
2
 
3
- from langchain.schema.language_model import BaseLanguageModel
4
-
5
3
  from janus.language.block import CodeBlock
6
4
  from janus.language.combine import Combiner
7
5
  from janus.language.node import NodeType
8
6
  from janus.language.splitter import Splitter
7
+ from janus.llm.models_info import JanusModel
9
8
  from janus.utils.logger import create_logger
10
9
 
11
10
  log = create_logger(__name__)
@@ -44,7 +43,7 @@ class MumpsSplitter(Splitter):
44
43
 
45
44
  def __init__(
46
45
  self,
47
- model: None | BaseLanguageModel = None,
46
+ model: JanusModel | None = None,
48
47
  max_tokens: int = 4096,
49
48
  protected_node_types: tuple[str] = ("routine_definition",),
50
49
  prune_node_types: tuple[str] = (),
@@ -4,11 +4,11 @@ from pathlib import Path
4
4
  from typing import List
5
5
 
6
6
  import tiktoken
7
- from langchain.schema.language_model import BaseLanguageModel
8
7
 
9
8
  from janus.language.block import CodeBlock
10
9
  from janus.language.file import FileManager
11
10
  from janus.language.node import NodeType
11
+ from janus.llm.models_info import JanusModel
12
12
  from janus.utils.logger import create_logger
13
13
 
14
14
  log = create_logger(__name__)
@@ -44,7 +44,7 @@ class Splitter(FileManager):
44
44
  def __init__(
45
45
  self,
46
46
  language: str,
47
- model: None | BaseLanguageModel = None,
47
+ model: JanusModel | None = None,
48
48
  max_tokens: int = 4096,
49
49
  skip_merge: bool = False,
50
50
  protected_node_types: tuple[str, ...] = (),
@@ -13,7 +13,7 @@ class TestTreeSitterSplitter(unittest.TestCase):
13
13
  """Set up the tests."""
14
14
  model_name = "gpt-4o"
15
15
  self.maxDiff = None
16
- self.llm, _, _, _ = load_model(model_name)
16
+ self.llm = load_model(model_name)
17
17
 
18
18
  def _split(self):
19
19
  """Split the test file."""
@@ -7,10 +7,10 @@ from typing import Optional
7
7
 
8
8
  import tree_sitter
9
9
  from git import Repo
10
- from langchain.schema.language_model import BaseLanguageModel
11
10
 
12
11
  from janus.language.block import CodeBlock, NodeType
13
12
  from janus.language.splitter import Splitter
13
+ from janus.llm.models_info import JanusModel
14
14
  from janus.utils.enums import LANGUAGES
15
15
  from janus.utils.logger import create_logger
16
16
 
@@ -25,7 +25,7 @@ class TreeSitterSplitter(Splitter):
25
25
  def __init__(
26
26
  self,
27
27
  language: str,
28
- model: None | BaseLanguageModel = None,
28
+ model: JanusModel | None = None,
29
29
  max_tokens: int = 4096,
30
30
  protected_node_types: tuple[str, ...] = (),
31
31
  prune_node_types: tuple[str, ...] = (),
@@ -12,6 +12,17 @@ from janus.utils.logger import create_logger
12
12
 
13
13
  log = create_logger(__name__)
14
14
 
15
+ openai_model_reroutes = {
16
+ "gpt-4o": "gpt-4o-2024-05-13",
17
+ "gpt-4o-mini": "gpt-4o-mini",
18
+ "gpt-4": "gpt-4-0613",
19
+ "gpt-4-turbo": "gpt-4-turbo-2024-04-09",
20
+ "gpt-4-turbo-preview": "gpt-4-0125-preview",
21
+ "gpt-3.5-turbo": "gpt-3.5-turbo-0125",
22
+ "gpt-3.5-turbo-16k": "gpt-3.5-turbo-0125",
23
+ "gpt-3.5-turbo-16k-0613": "gpt-3.5-turbo-0125",
24
+ }
25
+
15
26
 
16
27
  # Updated 2024-06-21
17
28
  COST_PER_1K_TOKENS: dict[str, dict[str, float]] = {
@@ -45,6 +56,8 @@ def _get_token_cost(
45
56
  prompt_tokens: int, completion_tokens: int, model_id: str | None
46
57
  ) -> float:
47
58
  """Get the cost of tokens according to model ID"""
59
+ if model_id in openai_model_reroutes:
60
+ model_id = openai_model_reroutes[model_id]
48
61
  if model_id not in COST_PER_1K_TOKENS:
49
62
  raise ValueError(
50
63
  f"Unknown model: {model_id}. Please provide a valid model name."
janus/llm/models_info.py CHANGED
@@ -2,14 +2,14 @@ import json
2
2
  import os
3
3
  import time
4
4
  from pathlib import Path
5
- from typing import Any, Callable
5
+ from typing import Protocol, TypeVar
6
6
 
7
7
  from dotenv import load_dotenv
8
8
  from langchain_community.llms import HuggingFaceTextGenInference
9
- from langchain_core.language_models import BaseLanguageModel
9
+ from langchain_core.runnables import Runnable
10
10
  from langchain_openai import ChatOpenAI
11
11
 
12
- from janus.llm.model_callbacks import COST_PER_1K_TOKENS
12
+ from janus.llm.model_callbacks import COST_PER_1K_TOKENS, openai_model_reroutes
13
13
  from janus.prompts.prompt import (
14
14
  ChatGptPromptEngine,
15
15
  ClaudePromptEngine,
@@ -44,17 +44,34 @@ except ImportError:
44
44
  )
45
45
 
46
46
 
47
+ ModelType = TypeVar(
48
+ "ModelType",
49
+ ChatOpenAI,
50
+ HuggingFaceTextGenInference,
51
+ Bedrock,
52
+ BedrockChat,
53
+ HuggingFacePipeline,
54
+ )
55
+
56
+
57
+ class JanusModelProtocol(Protocol):
58
+ model_id: str
59
+ model_type_name: str
60
+ token_limit: int
61
+ input_token_cost: float
62
+ output_token_cost: float
63
+ prompt_engine: type[PromptEngine]
64
+
65
+ def get_num_tokens(self, text: str) -> int:
66
+ ...
67
+
68
+
69
+ class JanusModel(Runnable, JanusModelProtocol):
70
+ ...
71
+
72
+
47
73
  load_dotenv()
48
74
 
49
- openai_model_reroutes = {
50
- "gpt-4o": "gpt-4o-2024-05-13",
51
- "gpt-4o-mini": "gpt-4o-mini",
52
- "gpt-4": "gpt-4-0613",
53
- "gpt-4-turbo": "gpt-4-turbo-2024-04-09",
54
- "gpt-4-turbo-preview": "gpt-4-0125-preview",
55
- "gpt-3.5-turbo": "gpt-3.5-turbo-0125",
56
- "gpt-3.5-turbo-16k": "gpt-3.5-turbo-0125",
57
- }
58
75
 
59
76
  openai_models = [
60
77
  "gpt-4o",
@@ -105,24 +122,15 @@ bedrock_models = [
105
122
  ]
106
123
  all_models = [*openai_models, *bedrock_models]
107
124
 
108
- MODEL_TYPE_CONSTRUCTORS: dict[str, Callable[[Any], BaseLanguageModel]] = {
125
+ MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
109
126
  "OpenAI": ChatOpenAI,
110
127
  "HuggingFace": HuggingFaceTextGenInference,
128
+ "Bedrock": Bedrock,
129
+ "BedrockChat": BedrockChat,
130
+ "HuggingFaceLocal": HuggingFacePipeline,
111
131
  }
112
132
 
113
- try:
114
- MODEL_TYPE_CONSTRUCTORS.update(
115
- {
116
- "HuggingFaceLocal": HuggingFacePipeline.from_model_id,
117
- "Bedrock": Bedrock,
118
- "BedrockChat": BedrockChat,
119
- }
120
- )
121
- except NameError:
122
- pass
123
-
124
-
125
- MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
133
+ MODEL_PROMPT_ENGINES: dict[str, type[PromptEngine]] = {
126
134
  **{m: ChatGptPromptEngine for m in openai_models},
127
135
  **{m: ClaudePromptEngine for m in claude_models},
128
136
  **{m: Llama2PromptEngine for m in llama2_models},
@@ -132,11 +140,6 @@ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
132
140
  **{m: MistralPromptEngine for m in mistral_models},
133
141
  }
134
142
 
135
- _open_ai_defaults: dict[str, str] = {
136
- "openai_api_key": os.getenv("OPENAI_API_KEY"),
137
- "openai_organization": os.getenv("OPENAI_ORG_ID"),
138
- }
139
-
140
143
  MODEL_ID_TO_LONG_ID = {
141
144
  **{m: mr for m, mr in openai_model_reroutes.items()},
142
145
  "bedrock-claude-v2": "anthropic.claude-v2",
@@ -168,7 +171,7 @@ DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())
168
171
 
169
172
  MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"
170
173
 
171
- MODEL_TYPES: dict[str, PromptEngine] = {
174
+ MODEL_TYPES: dict[str, str] = {
172
175
  **{m: "OpenAI" for m in openai_models},
173
176
  **{m: "BedrockChat" for m in bedrock_models},
174
177
  }
@@ -211,53 +214,90 @@ def get_available_model_names() -> list[str]:
211
214
  return avaialable_models
212
215
 
213
216
 
214
- def load_model(
215
- user_model_name: str,
216
- ) -> tuple[BaseLanguageModel, str, int, dict[str, float]]:
217
+ def load_model(model_id) -> JanusModel:
217
218
  if not MODEL_CONFIG_DIR.exists():
218
219
  MODEL_CONFIG_DIR.mkdir(parents=True)
219
- model_config_file = MODEL_CONFIG_DIR / f"{user_model_name}.json"
220
- if not model_config_file.exists():
221
- log.warning(
222
- f"Model {user_model_name} not found in user-defined models, searching "
223
- f"default models for {user_model_name}."
224
- )
225
- model_id = user_model_name
226
- if user_model_name not in DEFAULT_MODELS:
227
- message = (
228
- f"Model {user_model_name} not found in default models. Make sure to run "
229
- "`janus llm add` first."
230
- )
231
- log.error(message)
232
- raise ValueError(message)
233
- model_config = {
234
- "model_type": MODEL_TYPES[model_id],
235
- "model_id": model_id,
236
- "model_args": MODEL_DEFAULT_ARGUMENTS[model_id],
237
- "token_limit": TOKEN_LIMITS.get(MODEL_ID_TO_LONG_ID[model_id], 4096),
238
- "model_cost": COST_PER_1K_TOKENS.get(
239
- MODEL_ID_TO_LONG_ID[model_id], {"input": 0, "output": 0}
240
- ),
241
- }
242
- with open(model_config_file, "w") as f:
243
- json.dump(model_config, f)
244
- else:
220
+ model_config_file = MODEL_CONFIG_DIR / f"{model_id}.json"
221
+
222
+ if model_config_file.exists():
223
+ log.info(f"Loading {model_id} from {model_config_file}.")
245
224
  with open(model_config_file, "r") as f:
246
225
  model_config = json.load(f)
247
- model_constructor = MODEL_TYPE_CONSTRUCTORS[model_config["model_type"]]
248
- model_args = model_config["model_args"]
249
- if model_config["model_type"] == "OpenAI":
250
- model_args.update(_open_ai_defaults)
226
+ model_type_name = model_config["model_type"]
227
+ model_id = model_config["model_id"]
228
+ model_args = model_config["model_args"]
229
+ token_limit = model_config["token_limit"]
230
+ input_token_cost = model_config["model_cost"]["input"]
231
+ output_token_cost = model_config["model_cost"]["output"]
232
+
233
+ elif model_id in DEFAULT_MODELS:
234
+ model_id = model_id
235
+ model_long_id = MODEL_ID_TO_LONG_ID[model_id]
236
+ model_type_name = MODEL_TYPES[model_id]
237
+ model_args = MODEL_DEFAULT_ARGUMENTS[model_id]
238
+
239
+ token_limit = 0
240
+ input_token_cost = 0.0
241
+ output_token_cost = 0.0
242
+ if model_long_id in TOKEN_LIMITS:
243
+ token_limit = TOKEN_LIMITS[model_long_id]
244
+ if model_long_id in COST_PER_1K_TOKENS:
245
+ token_limits = COST_PER_1K_TOKENS[model_long_id]
246
+ input_token_cost = token_limits["input"]
247
+ output_token_cost = token_limits["output"]
248
+
249
+ else:
250
+ model_list = "\n\t".join(DEFAULT_MODELS)
251
+ message = (
252
+ f"Model {model_id} not found in user-defined model directory "
253
+ f"({MODEL_CONFIG_DIR}), and is not a default model. Valid default "
254
+ f"models:\n\t{model_list}\n"
255
+ f"To use a custom model, first run `janus llm add`."
256
+ )
257
+ log.error(message)
258
+ raise ValueError(message)
259
+
260
+ if model_type_name == "HuggingFaceLocal":
261
+ model = HuggingFacePipeline.from_model_id(
262
+ model_id=model_id,
263
+ task="text-generation",
264
+ model_kwargs=model_args,
265
+ )
266
+ model_args.update(pipeline=model.pipeline)
267
+
268
+ elif model_type_name == "OpenAI":
269
+ model_args.update(
270
+ openai_api_key=str(os.getenv("OPENAI_API_KEY")),
271
+ openai_organization=str(os.getenv("OPENAI_ORG_ID")),
272
+ )
251
273
  log.warning("Do NOT use this model in sensitive environments!")
252
274
  log.warning("If you would like to cancel, please press Ctrl+C.")
253
275
  log.warning("Waiting 10 seconds...")
254
276
  # Give enough time for the user to read the warnings and cancel
255
277
  time.sleep(10)
256
278
 
257
- model = model_constructor(**model_args)
258
- return (
259
- model,
260
- model_config["model_id"],
261
- model_config["token_limit"],
262
- model_config["model_cost"],
279
+ model_type = MODEL_TYPE_CONSTRUCTORS[model_type_name]
280
+ prompt_engine = MODEL_PROMPT_ENGINES[model_id]
281
+
282
+ class JanusModel(model_type):
283
+ model_id: str
284
+ short_model_id: str
285
+ model_type_name: str
286
+ token_limit: int
287
+ input_token_cost: float
288
+ output_token_cost: float
289
+ prompt_engine: type[PromptEngine]
290
+
291
+ model_args.update(
292
+ model_id=MODEL_ID_TO_LONG_ID[model_id],
293
+ short_model_id=model_id,
294
+ )
295
+
296
+ return JanusModel(
297
+ model_type_name=model_type_name,
298
+ token_limit=token_limit,
299
+ input_token_cost=input_token_cost,
300
+ output_token_cost=output_token_cost,
301
+ prompt_engine=prompt_engine,
302
+ **model_args,
263
303
  )