not-again-ai 0.14.0-py3-none-any.whl → 0.16.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. not_again_ai/llm/chat_completion/__init__.py +4 -0
  2. not_again_ai/llm/chat_completion/interface.py +32 -0
  3. not_again_ai/llm/chat_completion/providers/ollama_api.py +227 -0
  4. not_again_ai/llm/chat_completion/providers/openai_api.py +290 -0
  5. not_again_ai/llm/chat_completion/types.py +145 -0
  6. not_again_ai/llm/embedding/__init__.py +4 -0
  7. not_again_ai/llm/embedding/interface.py +28 -0
  8. not_again_ai/llm/embedding/providers/ollama_api.py +87 -0
  9. not_again_ai/llm/embedding/providers/openai_api.py +126 -0
  10. not_again_ai/llm/embedding/types.py +23 -0
  11. not_again_ai/llm/prompting/__init__.py +3 -0
  12. not_again_ai/llm/prompting/compile_prompt.py +125 -0
  13. not_again_ai/llm/prompting/interface.py +46 -0
  14. not_again_ai/llm/prompting/providers/openai_tiktoken.py +122 -0
  15. not_again_ai/llm/prompting/types.py +43 -0
  16. {not_again_ai-0.14.0.dist-info → not_again_ai-0.16.0.dist-info}/METADATA +24 -40
  17. not_again_ai-0.16.0.dist-info/RECORD +38 -0
  18. {not_again_ai-0.14.0.dist-info → not_again_ai-0.16.0.dist-info}/WHEEL +1 -1
  19. not_again_ai/llm/gh_models/azure_ai_client.py +0 -20
  20. not_again_ai/llm/gh_models/chat_completion.py +0 -81
  21. not_again_ai/llm/openai_api/chat_completion.py +0 -339
  22. not_again_ai/llm/openai_api/context_management.py +0 -70
  23. not_again_ai/llm/openai_api/embeddings.py +0 -62
  24. not_again_ai/llm/openai_api/openai_client.py +0 -78
  25. not_again_ai/llm/openai_api/prompts.py +0 -191
  26. not_again_ai/llm/openai_api/tokens.py +0 -184
  27. not_again_ai/local_llm/__init__.py +0 -27
  28. not_again_ai/local_llm/chat_completion.py +0 -105
  29. not_again_ai/local_llm/huggingface/chat_completion.py +0 -59
  30. not_again_ai/local_llm/huggingface/helpers.py +0 -23
  31. not_again_ai/local_llm/ollama/__init__.py +0 -0
  32. not_again_ai/local_llm/ollama/chat_completion.py +0 -111
  33. not_again_ai/local_llm/ollama/model_mapping.py +0 -17
  34. not_again_ai/local_llm/ollama/ollama_client.py +0 -24
  35. not_again_ai/local_llm/ollama/service.py +0 -81
  36. not_again_ai/local_llm/ollama/tokens.py +0 -104
  37. not_again_ai/local_llm/prompts.py +0 -38
  38. not_again_ai/local_llm/tokens.py +0 -90
  39. not_again_ai-0.14.0.dist-info/RECORD +0 -44
  40. not_again_ai-0.14.0.dist-info/entry_points.txt +0 -3
  41. /not_again_ai/llm/{gh_models → chat_completion/providers}/__init__.py +0 -0
  42. /not_again_ai/llm/{openai_api → embedding/providers}/__init__.py +0 -0
  43. /not_again_ai/{local_llm/huggingface → llm/prompting/providers}/__init__.py +0 -0
  44. {not_again_ai-0.14.0.dist-info → not_again_ai-0.16.0.dist-info}/LICENSE +0 -0
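
The listing shows the old 0.14.0 provider modules (not_again_ai/llm/openai_api, not_again_ai/llm/gh_models, not_again_ai/local_llm) being removed and replaced by the new chat_completion, embedding, and prompting packages under not_again_ai/llm, each with an interface, types, and per-provider implementations. As a rough orientation only, a hypothetical sketch of the module moves implied by the file list (the names exported by the new modules are not shown in this diff, so only package paths are used):

# Hypothetical mapping, inferred solely from the file paths above; exported symbols are not visible in this diff.
# 0.14.0 (removed)                              -> 0.16.0 (added)
# not_again_ai.llm.openai_api.chat_completion   -> not_again_ai.llm.chat_completion (providers/openai_api.py)
# not_again_ai.local_llm.ollama.chat_completion -> not_again_ai.llm.chat_completion (providers/ollama_api.py)
# not_again_ai.llm.openai_api.embeddings        -> not_again_ai.llm.embedding (providers/openai_api.py)
# not_again_ai.llm.openai_api.prompts / tokens  -> not_again_ai.llm.prompting (compile_prompt.py, providers/openai_tiktoken.py)
import not_again_ai.llm.chat_completion
import not_again_ai.llm.embedding
import not_again_ai.llm.prompting
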
not_again_ai/llm/openai_api/prompts.py (removed)
@@ -1,191 +0,0 @@
- import base64
- from copy import deepcopy
- import mimetypes
- from pathlib import Path
- from typing import Any
-
- from liquid import Template
- from openai.lib._pydantic import to_strict_json_schema
- from pydantic import BaseModel
-
-
- def _validate_message_vision(message: dict[str, list[dict[str, Path | str]] | str]) -> bool:
-     """Validates that a message for a vision model is valid"""
-     valid_fields = ["role", "content", "name", "tool_call_id", "tool_calls"]
-     if not all(key in valid_fields for key in message):
-         raise ValueError(f"Message contains invalid fields: {message.keys()}")
-
-     valid_roles = ["system", "user", "assistant", "tool"]
-     if message["role"] not in valid_roles:
-         raise ValueError(f"Message contains invalid role: {message['role']}")
-
-     if not isinstance(message["content"], list) and not isinstance(message["content"], str):
-         raise ValueError(f"content must be a list of dictionaries or a string: {message['content']}")
-
-     if isinstance(message["content"], list):
-         for content_part in message["content"]:
-             if isinstance(content_part, dict):
-                 if "image" not in content_part:
-                     raise ValueError(f"Dictionary content part must contain 'image' key: {content_part}")
-                 if "detail" in content_part and content_part["detail"] not in ["low", "high"]:
-                     raise ValueError(f"Optional 'detail' key must be 'low' or 'high': {content_part['detail']}")
-             elif not isinstance(content_part, str):
-                 raise ValueError(f"content_part must be a dictionary or a string: {content_part}")
-
-     return True
-
-
- def encode_image(image_path: Path) -> str:
-     """Encodes an image file at the given Path to base64.
-
-     Args:
-         image_path: The path to the image file to encode.
-
-     Returns:
-         The base64 encoded image as a string.
-     """
-     with Path.open(image_path, "rb") as image_file:
-         return base64.b64encode(image_file.read()).decode("utf-8")
-
-
- def create_image_url(image_path: Path) -> str:
-     """Creates a data URL for an image file at the given Path.
-
-     Args:
-         image_path: The path to the image file to encode.
-
-     Returns:
-         The data URL for the image.
-     """
-     image_data = encode_image(image_path)
-
-     valid_mime_types = ["image/jpeg", "image/png", "image/webp", "image/gif"]
-
-     # Get the MIME type from the image file extension
-     mime_type = mimetypes.guess_type(image_path)[0]
-
-     # Check if the MIME type is valid
-     # List of valid types is here: https://platform.openai.com/docs/guides/vision/what-type-of-files-can-i-upload
-     if mime_type not in valid_mime_types:
-         raise ValueError(f"Invalid MIME type for image: {mime_type}")
-
-     return f"data:{mime_type};base64,{image_data}"
-
-
- def chat_prompt(messages_unformatted: list[dict[str, Any]], variables: dict[str, str]) -> list[dict[str, Any]]:
-     """Formats a list of messages for OpenAI's chat completion API,
-     including special syntax for vision models, using Liquid templating.
-
-     Args:
-         messages_unformatted (list[dict[str, list[dict[str, Path | str]] | str]]):
-             A list of dictionaries where each dictionary represents a message.
-             `content` can be a Liquid template string or a list of dictionaries where each dictionary
-             represents a content part. Each content part can be a string or a dictionary with 'image' and 'detail' keys.
-             The 'image' key must be a Path or a string representing a URL. The 'detail' key is optional and must be 'low' or 'high'.
-         variables: A dictionary where each key-value pair represents a variable
-             name and its value for template rendering.
-
-     Returns:
-         A list which represents messages in the format that OpenAI expects for its chat completions API.
-         See here for details: https://platform.openai.com/docs/api-reference/chat/create
-
-     Example:
-         >>> # Assume cat_image and dog_image are Path objects to image files
-         >>> messages = [
-         ...     {"role": "system", "content": "You are a helpful assistant."},
-         ...     {
-         ...         "role": "user",
-         ...         "content": ["Describe the animal in the image in one word.", {"image": cat_image, "detail": "low"}],
-         ...     }
-         ...     {"role": "assistant", "content": "{{ answer }}"},
-         ...     {
-         ...         "role": "user",
-         ...         "content": ["What about this animal?", {"image": dog_image, "detail": "high"}],
-         ...     }
-         ... ]
-         >>> vars = {"answer": "Cat"}
-         >>> chat_prompt(messages, vars)
-         [
-             {"role": "system", "content": "You are a helpful assistant."},
-             {
-                 "role": "user",
-                 "content": [
-                     {"type": "text", "text": "Describe the animal in the image in one word."},
-                     {
-                         "type": "image_url",
-                         "image_url": {"url": f"data:image/jpeg;base64,<encoding>", "detail": "low"},
-                     },
-                 ],
-             },
-             {"role": "assistant", "content": "Cat"},
-             {
-                 "role": "user",
-                 "content": [
-                     {"type": "text", "text": "What about this animal?"},
-                     {
-                         "type": "image_url",
-                         "image_url": {"url": f"data:image/jpeg;base64,<encoding>", "detail": "high"},
-                     },
-                 ],
-             },
-         ]
-     """
-     messages_formatted = deepcopy(messages_unformatted)
-     for message in messages_formatted:
-         if not _validate_message_vision(message):
-             raise ValueError()
-
-         if isinstance(message["content"], list):
-             for i in range(len(message["content"])):
-                 content_part = message["content"][i]
-                 if isinstance(content_part, dict):
-                     image_path = content_part["image"]
-                     if isinstance(image_path, Path):
-                         temp_content_part: dict[str, Any] = {
-                             "type": "image_url",
-                             "image_url": {
-                                 "url": create_image_url(image_path),
-                             },
-                         }
-                         if "detail" in content_part:
-                             temp_content_part["image_url"]["detail"] = content_part["detail"]
-                     elif isinstance(image_path, str):
-                         # Assume its a valid URL
-                         pass
-                     else:
-                         raise ValueError(f"Image path must be a Path or str: {image_path}")
-                     message["content"][i] = temp_content_part
-                 elif isinstance(content_part, str):
-                     message["content"][i] = {
-                         "type": "text",
-                         "text": Template(content_part).render(**variables),
-                     }
-         elif isinstance(message["content"], str):
-             message["content"] = Template(message["content"]).render(**variables)
-
-     return messages_formatted
-
-
- def pydantic_to_json_schema(
-     pydantic_model: type[BaseModel], schema_name: str, description: str | None = None
- ) -> dict[str, Any]:
-     """Converts a Pydantic model to a JSON schema expected by Structured Outputs.
-     Must adhere to the supported schemas: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas
-
-     Args:
-         pydantic_model: The Pydantic model to convert.
-         schema_name: The name of the schema.
-         description: An optional description of the schema.
-
-     Returns:
-         A JSON schema dictionary representing the Pydantic model.
-     """
-     converted_pydantic = to_strict_json_schema(pydantic_model)
-     schema = {
-         "name": schema_name,
-         "strict": True,
-         "schema": converted_pydantic,
-     }
-     if description:
-         schema["description"] = description
-     return schema
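
For context, a minimal sketch (not part of the package) of how the two removed helpers above, chat_prompt and pydantic_to_json_schema from the 0.14.0 module not_again_ai/llm/openai_api/prompts.py, were typically called; the image path, template variable, and Animal model are illustrative:

from pathlib import Path

from pydantic import BaseModel

from not_again_ai.llm.openai_api.prompts import chat_prompt, pydantic_to_json_schema


# Liquid placeholders in string content are rendered from the variables dict;
# dict content parts with an "image" Path are converted to base64 data URLs.
messages = [
    {"role": "system", "content": "You are a {{ persona }}."},
    {"role": "user", "content": ["Describe this image.", {"image": Path("cat.jpg"), "detail": "low"}]},
]
formatted = chat_prompt(messages, {"persona": "helpful assistant"})


class Animal(BaseModel):
    name: str


# Produces the {"name", "strict", "schema"} dict expected for Structured Outputs.
schema = pydantic_to_json_schema(Animal, schema_name="animal")
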
not_again_ai/llm/openai_api/tokens.py (removed)
@@ -1,184 +0,0 @@
- from collections.abc import Collection, Set
- from typing import Literal
-
- import tiktoken
-
-
- def load_tokenizer(model: str) -> tiktoken.Encoding:
-     """Load the tokenizer for the given model
-
-     Args:
-         model (str): The name of the language model to load the tokenizer for
-
-     Returns:
-         A tiktoken encoding object
-     """
-     try:
-         encoding = tiktoken.encoding_for_model(model)
-     except KeyError:
-         print("Warning: model not found. Using cl100k_base encoding.")
-         encoding = tiktoken.get_encoding("cl100k_base")
-     return encoding
-
-
- def truncate_str(
-     text: str,
-     max_len: int,
-     tokenizer: tiktoken.Encoding,
-     allowed_special: Literal["all"] | Set[str] = set(),
-     disallowed_special: Literal["all"] | Collection[str] = (),
- ) -> str:
-     """Truncates a string to a maximum token length.
-
-     Special tokens are artificial tokens used to unlock capabilities from a model,
-     such as fill-in-the-middle. So we want to be careful about accidentally encoding special
-     tokens, since they can be used to trick a model into doing something we don't want it to do.
-
-     Hence, by default, encode will raise an error if it encounters text that corresponds
-     to a special token. This can be controlled on a per-token level using the `allowed_special`
-     and `disallowed_special` parameters. In particular:
-     - Setting `disallowed_special` to () will prevent this function from raising errors and
-       cause all text corresponding to special tokens to be encoded as natural text.
-     - Setting `allowed_special` to "all" will cause this function to treat all text
-       corresponding to special tokens to be encoded as special tokens.
-
-     Args:
-         text (str): The string to truncate.
-         max_len (int): The maximum number of tokens to keep.
-         tokenizer (tiktoken.Encoding): A tiktoken encoding object
-         allowed_special (str | set[str]):
-         disallowed_special (str | set[str]):
-
-     Returns:
-         str: The truncated string.
-     """
-     tokens = tokenizer.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)
-     if len(tokens) > max_len:
-         tokens = tokens[:max_len]
-         # Decode the tokens back to a string
-         truncated_text = tokenizer.decode(tokens)
-         return truncated_text
-     else:
-         return text
-
-
- def num_tokens_in_string(
-     text: str,
-     tokenizer: tiktoken.Encoding,
-     allowed_special: Literal["all"] | Set[str] = set(),
-     disallowed_special: Literal["all"] | Collection[str] = (),
- ) -> int:
-     """Return the number of tokens in a string.
-
-     Special tokens are artificial tokens used to unlock capabilities from a model,
-     such as fill-in-the-middle. So we want to be careful about accidentally encoding special
-     tokens, since they can be used to trick a model into doing something we don't want it to do.
-
-     Hence, by default, encode will raise an error if it encounters text that corresponds
-     to a special token. This can be controlled on a per-token level using the `allowed_special`
-     and `disallowed_special` parameters. In particular:
-     - Setting `disallowed_special` to () will prevent this function from raising errors and
-       cause all text corresponding to special tokens to be encoded as natural text.
-     - Setting `allowed_special` to "all" will cause this function to treat all text
-       corresponding to special tokens to be encoded as special tokens.
-
-     Args:
-         text (str): The string to count the tokens.
-         tokenizer (tiktoken.Encoding): A tiktoken encoding object
-         allowed_special (str | set[str]):
-         disallowed_special (str | set[str]):
-
-     Returns:
-         int: The number of tokens in the string.
-     """
-     return len(tokenizer.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special))
-
-
- def num_tokens_from_messages(
-     messages: list[dict[str, str]],
-     tokenizer: tiktoken.Encoding,
-     model: str = "gpt-3.5-turbo-0125",
-     allowed_special: Literal["all"] | Set[str] = set(),
-     disallowed_special: Literal["all"] | Collection[str] = (),
- ) -> int:
-     """Return the number of tokens used by a list of messages.
-     NOTE: Does not support counting tokens used by function calling or prompts with images.
-     Reference: # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
-     and https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
-
-     Special tokens are artificial tokens used to unlock capabilities from a model,
-     such as fill-in-the-middle. So we want to be careful about accidentally encoding special
-     tokens, since they can be used to trick a model into doing something we don't want it to do.
-
-     Hence, by default, encode will raise an error if it encounters text that corresponds
-     to a special token. This can be controlled on a per-token level using the `allowed_special`
-     and `disallowed_special` parameters. In particular:
-     - Setting `disallowed_special` to () will prevent this function from raising errors and
-       cause all text corresponding to special tokens to be encoded as natural text.
-     - Setting `allowed_special` to "all" will cause this function to treat all text
-       corresponding to special tokens to be encoded as special tokens.
-
-     Args:
-         messages (list[dict[str, str]]): A list of messages to count the tokens
-             should ideally be the result after calling llm.prompts.chat_prompt.
-         tokenizer (tiktoken.Encoding): A tiktoken encoding object
-         model (str): The model to use for tokenization. Defaults to "gpt-3.5-turbo-0125".
-             See https://platform.openai.com/docs/models for a list of OpenAI models.
-         allowed_special (str | set[str]):
-         disallowed_special (str | set[str]):
-
-     Returns:
-         int: The number of tokens used by the messages.
-     """
-     if model in {
-         "gpt-3.5-turbo-0613",
-         "gpt-3.5-turbo-16k-0613",
-         "gpt-3.5-turbo-1106",
-         "gpt-3.5-turbo-0125",
-         "gpt-4-0314",
-         "gpt-4-32k-0314",
-         "gpt-4-0613",
-         "gpt-4-32k-0613",
-         "gpt-4-1106-preview",
-         "gpt-4-turbo-preview",
-         "gpt-4-0125-preview",
-         "gpt-4-turbo",
-         "gpt-4-turbo-2024-04-09",
-         "gpt-4o",
-         "gpt-4o-2024-05-13",
-         "gpt-4o-2024-08-06",
-         "gpt-4o-mini",
-         "gpt-4o-mini-2024-07-18",
-     }:
-         tokens_per_message = 3  # every message follows <|start|>{role/name}\n{content}<|end|>\n
-         tokens_per_name = 1  # if there's a name, the role is omitted
-     elif model == "gpt-3.5-turbo-0301":
-         tokens_per_message = 4
-         tokens_per_name = -1
-     # Approximate catch-all. Assumes future versions of 3.5 and 4 will have the same token counts as the 0613 versions.
-     elif "gpt-3.5-turbo" in model:
-         return num_tokens_from_messages(messages, tokenizer=tokenizer, model="gpt-3.5-turbo-0613")
-     elif "gpt-4o" in model:
-         return num_tokens_from_messages(messages, tokenizer=tokenizer, model="gpt-4o-2024-05-13")
-     elif "gpt-4" in model:
-         return num_tokens_from_messages(messages, tokenizer=tokenizer, model="gpt-4-0613")
-     else:
-         raise NotImplementedError(
-             f"""num_tokens_from_messages() is not implemented for model {model}.
- See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
-         )
-     num_tokens = 0
-     for message in messages:
-         num_tokens += tokens_per_message
-         for key, value in message.items():
-             num_tokens += len(
-                 tokenizer.encode(
-                     value,
-                     allowed_special=allowed_special,
-                     disallowed_special=disallowed_special,
-                 )
-             )
-             if key == "name":
-                 num_tokens += tokens_per_name
-     num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
-     return num_tokens
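
A short sketch of how the removed token helpers above (not_again_ai/llm/openai_api/tokens.py in 0.14.0) were combined; the model name and text are illustrative:

from not_again_ai.llm.openai_api.tokens import load_tokenizer, num_tokens_from_messages, truncate_str

tokenizer = load_tokenizer("gpt-4o-mini")  # falls back to cl100k_base for unknown models
messages = [{"role": "user", "content": "Summarize the plot of Hamlet in two sentences."}]
prompt_tokens = num_tokens_from_messages(messages, tokenizer=tokenizer, model="gpt-4o-mini")
shortened = truncate_str("A very long document " * 500, max_len=128, tokenizer=tokenizer)
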
not_again_ai/local_llm/__init__.py (removed)
@@ -1,27 +0,0 @@
- import importlib.util
- import os
-
- os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
- os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
-
- if (
-     importlib.util.find_spec("liquid") is None
-     or importlib.util.find_spec("ollama") is None
-     or importlib.util.find_spec("openai") is None
-     or importlib.util.find_spec("tiktoken") is None
-     or importlib.util.find_spec("transformers") is None
- ):
-     raise ImportError(
-         "not_again_ai.local_llm requires the 'llm' and 'local_llm' extra to be installed. "
-         "You can install it using 'pip install not_again_ai[llm,local_llm]'."
-     )
- else:
-     import liquid  # noqa: F401
-     import ollama  # noqa: F401
-     import openai  # noqa: F401
-     import tiktoken  # noqa: F401
-     import transformers  # noqa: F401
-     from transformers.utils import logging
-
-     logging.disable_progress_bar()
-     logging.set_verbosity_error()
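
The removed not_again_ai/local_llm/__init__.py above gated the subpackage on its optional dependencies, so in 0.14.0 importing it without the extras raised ImportError. A small sketch of that behaviour:

try:
    import not_again_ai.local_llm  # 0.14.0: requires the 'llm' and 'local_llm' extras
except ImportError as exc:
    print(f"Install the extras first: {exc}")
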
not_again_ai/local_llm/chat_completion.py (removed)
@@ -1,105 +0,0 @@
- from typing import Any
-
- from azure.ai.inference import ChatCompletionsClient
- from ollama import Client
- from openai import AzureOpenAI, OpenAI
-
- from not_again_ai.llm.gh_models import chat_completion as chat_completion_gh_models
- from not_again_ai.llm.openai_api import chat_completion as chat_completion_openai
- from not_again_ai.local_llm.ollama import chat_completion as chat_completion_ollama
-
-
- def chat_completion(
-     messages: list[dict[str, Any]],
-     model: str,
-     client: OpenAI | AzureOpenAI | Client | ChatCompletionsClient,
-     tools: list[dict[str, Any]] | None = None,
-     max_tokens: int | None = None,
-     temperature: float = 0.7,
-     json_mode: bool = False,
-     seed: int | None = None,
-     **kwargs: Any,
- ) -> dict[str, Any]:
-     """Creates a common wrapper around chat completion models from different providers.
-     Currently supports the OpenAI API and Ollama local models.
-     All input parameters are supported by all providers in similar ways and the output is standardized.
-
-     Args:
-         messages (list[dict[str, Any]]): A list of messages to send to the model.
-         model (str): The model name to use.
-         client (OpenAI | AzureOpenAI | Client | ChatCompletionsClient): The client object to use for chat completion.
-         tools (list[dict[str, Any]], optional):A list of tools the model may call.
-             Use this to provide a list of functions the model may generate JSON inputs for. Defaults to None.
-         max_tokens (int, optional): The maximum number of tokens to generate.
-         temperature (float, optional): The temperature of the model. Increasing the temperature will make the model answer more creatively.
-         json_mode (bool, optional): This will structure the response as a valid JSON object.
-         seed (int, optional): The seed to use for the model for reproducible outputs.
-
-     Returns:
-         dict[str, Any]: A dictionary with the following keys
-             message (str | dict): The content of the generated assistant message.
-                 If json_mode is True, this will be a dictionary.
-             tool_names (list[str], optional): The names of the tools called by the model.
-                 If the model does not support tools, a ResponseError is raised.
-             tool_args_list (list[dict], optional): The arguments of the tools called by the model.
-             prompt_tokens (int): The number of tokens in the messages sent to the model.
-             completion_tokens (int): The number of tokens used by the model to generate the completion.
-             response_duration (float): The time, in seconds, taken to generate the response by using the model.
-             extras (dict): This will contain any additional fields returned by corresponding provider.
-     """
-     # Determine which chat_completion function to call based on the client type
-     if isinstance(client, OpenAI | AzureOpenAI):
-         response = chat_completion_openai.chat_completion(
-             messages=messages,
-             model=model,
-             client=client,
-             tools=tools,
-             max_tokens=max_tokens,
-             temperature=temperature,
-             json_mode=json_mode,
-             seed=seed,
-             **kwargs,
-         )
-     elif isinstance(client, Client):
-         response = chat_completion_ollama.chat_completion(
-             messages=messages,
-             model=model,
-             client=client,
-             tools=tools,
-             max_tokens=max_tokens,
-             temperature=temperature,
-             json_mode=json_mode,
-             seed=seed,
-             **kwargs,
-         )
-     elif isinstance(client, ChatCompletionsClient):
-         response = chat_completion_gh_models.chat_completion(
-             messages=messages,  # type: ignore
-             model=model,
-             client=client,
-             tools=tools,  # type: ignore
-             max_tokens=max_tokens,
-             temperature=temperature,
-             json_mode=json_mode,
-             seed=seed,
-             **kwargs,
-         )
-     else:
-         raise ValueError("Invalid client type")
-
-     # Parse the responses to be consistent
-     response_data = {}
-     response_data["message"] = response.get("message")
-     if response.get("tool_names") and response.get("tool_args_list"):
-         response_data["tool_names"] = response.get("tool_names")
-         response_data["tool_args_list"] = response.get("tool_args_list")
-     response_data["completion_tokens"] = response.get("completion_tokens")
-     response_data["prompt_tokens"] = response.get("prompt_tokens")
-     response_data["response_duration"] = response.get("response_duration")
-
-     # Return any additional fields from the response in an "extras" dictionary
-     extras = {k: v for k, v in response.items() if k not in response_data}
-     if extras:
-         response_data["extras"] = extras
-
-     return response_data
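
A minimal sketch of calling the removed cross-provider chat_completion wrapper above (not_again_ai/local_llm/chat_completion.py in 0.14.0) with an Ollama client; the host and model name are illustrative:

from ollama import Client

from not_again_ai.local_llm.chat_completion import chat_completion

client = Client(host="http://localhost:11434")  # hypothetical local Ollama endpoint
response = chat_completion(
    messages=[{"role": "user", "content": "Say hello in one word."}],
    model="llama3",  # illustrative model name
    client=client,
    max_tokens=16,
    temperature=0.2,
)
print(response["message"], response["response_duration"])
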
not_again_ai/local_llm/huggingface/chat_completion.py (removed)
@@ -1,59 +0,0 @@
- from pathlib import Path
- from typing import Any
-
- from PIL import Image
-
-
- def chat_completion_image(
-     messages: list[dict[str, str]],
-     images: list[Path] | None,
-     model_processor: tuple[Any, Any],
-     max_tokens: int | None = None,
-     temperature: float = 0.7,
- ) -> dict[str, Any]:
-     """A wrapper around ision language model inference for multimodal language models from huggingface.
-
-     Args:
-         messages (list[dict[str, str]]): A list of messages to send to the model.
-         images (list[Path] | None): A list of image paths to send to the model.
-         model_processor (tuple[Any, Any]): A tuple containing the model and processor objects.
-         max_tokens (int, optional): The maximum number of tokens to generate. Defaults to None.
-         temperature (float, optional): The temperature of the model. Increasing the temperature will make the model answer more creatively. Defaults to 0.7.
-
-     Returns:
-         dict[str, Any]: A dictionary with the following keys
-             message (str): The content of the generated assistant message.
-             completion_tokens (int): The number of tokens used by the model to generate the completion.
-     """
-
-     model, processor = model_processor
-
-     prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-
-     if images:
-         image_objects = [Image.open(image) for image in images]
-         inputs = processor(prompt, image_objects, return_tensors="pt").to("cuda:0")
-     else:
-         inputs = processor(prompt, return_tensors="pt").to("cuda:0")
-
-     generation_args = {
-         "max_new_tokens": max_tokens,
-         "temperature": temperature,
-         "num_beams": 1,
-         "do_sample": True,
-     }
-
-     generate_ids = model.generate(**inputs, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
-
-     # Remove input tokens
-     generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
-
-     # Get the number of generated tokens
-     completion_tokens = generate_ids.shape[1]
-
-     response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
-
-     response_data: dict[str, Any] = {}
-     response_data["message"] = response[0]
-     response_data["completion_tokens"] = completion_tokens
-     return response_data
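
A sketch of how the removed Hugging Face wrapper above (not_again_ai/local_llm/huggingface/chat_completion.py in 0.14.0) was invoked, using the loader helpers shown in the next file; it assumes a CUDA device, and the checkpoint, prompt, and image path are illustrative (the exact prompt format depends on the model):

from pathlib import Path

from not_again_ai.local_llm.huggingface.chat_completion import chat_completion_image
from not_again_ai.local_llm.huggingface.helpers import load_model, load_processor

model_id = "microsoft/Phi-3-vision-128k-instruct"  # illustrative multimodal checkpoint
model_processor = (load_model(model_id), load_processor(model_id))
result = chat_completion_image(
    messages=[{"role": "user", "content": "Describe this image in one sentence."}],
    images=[Path("photo.jpg")],  # hypothetical local image
    model_processor=model_processor,
    max_tokens=64,
)
print(result["message"])
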
not_again_ai/local_llm/huggingface/helpers.py (removed)
@@ -1,23 +0,0 @@
- from typing import Any
-
- from transformers import AutoModelForCausalLM, AutoProcessor
-
-
- def load_model(model_id: str, device_map: str = "cuda", trust_remote_code: bool = True) -> Any:
-     """Load a model from Hugging Face."""
-     model = AutoModelForCausalLM.from_pretrained(
-         model_id,
-         device_map=device_map,
-         trust_remote_code=trust_remote_code,
-         torch_dtype="auto",
-     )
-     return model
-
-
- def load_processor(model_id: str, trust_remote_code: bool = True) -> Any:
-     """Load a processor from Hugging Face. This is typically used for multimodal language models."""
-     processor = AutoProcessor.from_pretrained(
-         model_id,
-         trust_remote_code=trust_remote_code,
-     )
-     return processor
File without changes