not-again-ai 0.14.0__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. not_again_ai/llm/chat_completion/__init__.py +4 -0
  2. not_again_ai/llm/chat_completion/interface.py +32 -0
  3. not_again_ai/llm/chat_completion/providers/ollama_api.py +227 -0
  4. not_again_ai/llm/chat_completion/providers/openai_api.py +290 -0
  5. not_again_ai/llm/chat_completion/types.py +145 -0
  6. not_again_ai/llm/embedding/__init__.py +4 -0
  7. not_again_ai/llm/embedding/interface.py +28 -0
  8. not_again_ai/llm/embedding/providers/ollama_api.py +87 -0
  9. not_again_ai/llm/embedding/providers/openai_api.py +126 -0
  10. not_again_ai/llm/embedding/types.py +23 -0
  11. not_again_ai/llm/prompting/__init__.py +3 -0
  12. not_again_ai/llm/prompting/compile_prompt.py +125 -0
  13. not_again_ai/llm/prompting/interface.py +46 -0
  14. not_again_ai/llm/prompting/providers/openai_tiktoken.py +122 -0
  15. not_again_ai/llm/prompting/types.py +43 -0
  16. {not_again_ai-0.14.0.dist-info → not_again_ai-0.16.0.dist-info}/METADATA +24 -40
  17. not_again_ai-0.16.0.dist-info/RECORD +38 -0
  18. {not_again_ai-0.14.0.dist-info → not_again_ai-0.16.0.dist-info}/WHEEL +1 -1
  19. not_again_ai/llm/gh_models/azure_ai_client.py +0 -20
  20. not_again_ai/llm/gh_models/chat_completion.py +0 -81
  21. not_again_ai/llm/openai_api/chat_completion.py +0 -339
  22. not_again_ai/llm/openai_api/context_management.py +0 -70
  23. not_again_ai/llm/openai_api/embeddings.py +0 -62
  24. not_again_ai/llm/openai_api/openai_client.py +0 -78
  25. not_again_ai/llm/openai_api/prompts.py +0 -191
  26. not_again_ai/llm/openai_api/tokens.py +0 -184
  27. not_again_ai/local_llm/__init__.py +0 -27
  28. not_again_ai/local_llm/chat_completion.py +0 -105
  29. not_again_ai/local_llm/huggingface/chat_completion.py +0 -59
  30. not_again_ai/local_llm/huggingface/helpers.py +0 -23
  31. not_again_ai/local_llm/ollama/__init__.py +0 -0
  32. not_again_ai/local_llm/ollama/chat_completion.py +0 -111
  33. not_again_ai/local_llm/ollama/model_mapping.py +0 -17
  34. not_again_ai/local_llm/ollama/ollama_client.py +0 -24
  35. not_again_ai/local_llm/ollama/service.py +0 -81
  36. not_again_ai/local_llm/ollama/tokens.py +0 -104
  37. not_again_ai/local_llm/prompts.py +0 -38
  38. not_again_ai/local_llm/tokens.py +0 -90
  39. not_again_ai-0.14.0.dist-info/RECORD +0 -44
  40. not_again_ai-0.14.0.dist-info/entry_points.txt +0 -3
  41. /not_again_ai/llm/{gh_models → chat_completion/providers}/__init__.py +0 -0
  42. /not_again_ai/llm/{openai_api → embedding/providers}/__init__.py +0 -0
  43. /not_again_ai/{local_llm/huggingface → llm/prompting/providers}/__init__.py +0 -0
  44. {not_again_ai-0.14.0.dist-info → not_again_ai-0.16.0.dist-info}/LICENSE +0 -0
not_again_ai/llm/chat_completion/types.py
@@ -0,0 +1,145 @@
+from enum import Enum
+from typing import Any, Generic, Literal, TypeVar
+
+from pydantic import BaseModel, Field
+
+
+class Role(str, Enum):
+    ASSISTANT = "assistant"
+    DEVELOPER = "developer"
+    SYSTEM = "system"
+    TOOL = "tool"
+    USER = "user"
+
+
+class ContentPartType(str, Enum):
+    TEXT = "text"
+    IMAGE = "image_url"
+
+
+class TextContent(BaseModel):
+    type: Literal[ContentPartType.TEXT] = ContentPartType.TEXT
+    text: str
+
+
+class ImageDetail(str, Enum):
+    AUTO = "auto"
+    LOW = "low"
+    HIGH = "high"
+
+
+class ImageUrl(BaseModel):
+    url: str
+    detail: ImageDetail = ImageDetail.AUTO
+
+
+class ImageContent(BaseModel):
+    type: Literal[ContentPartType.IMAGE] = ContentPartType.IMAGE
+    image_url: ImageUrl
+
+
+ContentT = TypeVar("ContentT", bound=str | list[TextContent | ImageContent])
+
+
+class BaseMessage(BaseModel, Generic[ContentT]):
+    content: ContentT
+    role: Role
+    name: str | None = None
+
+
+class Function(BaseModel):
+    name: str
+    arguments: dict[str, Any]
+
+
+class ToolCall(BaseModel):
+    id: str
+    function: Function
+    type: Literal["function"] = "function"
+
+
+class DeveloperMessage(BaseMessage[str]):
+    role: Literal[Role.DEVELOPER] = Role.DEVELOPER
+
+
+class SystemMessage(BaseMessage[str]):
+    role: Literal[Role.SYSTEM] = Role.SYSTEM
+
+
+class UserMessage(BaseMessage[str | list[TextContent | ImageContent]]):
+    role: Literal[Role.USER] = Role.USER
+
+
+class AssistantMessage(BaseMessage[str]):
+    role: Literal[Role.ASSISTANT] = Role.ASSISTANT
+    refusal: str | None = None
+    tool_calls: list[ToolCall] | None = None
+
+
+class ToolMessage(BaseMessage[str]):
+    # A tool message's name field will be interpreted as "tool_call_id"
+    role: Literal[Role.TOOL] = Role.TOOL
+
+
+MessageT = AssistantMessage | DeveloperMessage | SystemMessage | ToolMessage | UserMessage
+
+
+class ChatCompletionRequest(BaseModel):
+    messages: list[MessageT]
+    model: str
+
+    max_completion_tokens: int | None = Field(default=None)
+    context_window: int | None = Field(default=None)
+    logprobs: bool | None = Field(default=None)
+    n: int | None = Field(default=None)
+
+    tools: list[dict[str, Any]] | None = Field(default=None)
+    tool_choice: str | None = Field(default=None)
+    parallel_tool_calls: bool | None = Field(default=None)
+    json_mode: bool | None = Field(default=None)
+    structured_outputs: dict[str, Any] | None = Field(default=None)
+
+    temperature: float | None = Field(default=None)
+    reasoning_effort: Literal["low", "medium", "high"] | None = Field(default=None)
+    top_p: float | None = Field(default=None)
+    logit_bias: dict[str, float] | None = Field(default=None)
+    top_logprobs: int | None = Field(default=None)
+    frequency_penalty: float | None = Field(default=None)
+    presence_penalty: float | None = Field(default=None)
+    stop: str | list[str] | None = Field(default=None)
+
+    seed: int | None = Field(default=None)
+
+    mirostat: int | None = Field(default=None)
+    mirostat_eta: float | None = Field(default=None)
+    mirostat_tau: float | None = Field(default=None)
+    repeat_last_n: int | None = Field(default=None)
+    tfs_z: float | None = Field(default=None)
+    top_k: int | None = Field(default=None)
+    min_p: float | None = Field(default=None)
+
+
+class ChatCompletionChoice(BaseModel):
+    message: AssistantMessage
+    finish_reason: Literal["stop", "length", "tool_calls", "content_filter"]
+
+    json_message: dict[str, Any] | None = Field(default=None)
+    logprobs: list[dict[str, Any] | list[dict[str, Any]]] | None = Field(default=None)
+
+    extras: Any | None = Field(default=None)
+
+
+class ChatCompletionResponse(BaseModel):
+    choices: list[ChatCompletionChoice]
+
+    errors: str = Field(default="")
+
+    completion_tokens: int
+    prompt_tokens: int
+    completion_detailed_tokens: dict[str, int] | None = Field(default=None)
+    prompt_detailed_tokens: dict[str, int] | None = Field(default=None)
+    response_duration: float
+
+    system_fingerprint: str | None = Field(default=None)
+
+    extras: Any | None = Field(default=None)
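For orientation, here is a minimal sketch (not part of the diff) of how the new request types compose. The model name, message text, and base64 payload are illustrative only:

```python
from not_again_ai.llm.chat_completion.types import (
    ChatCompletionRequest,
    ImageContent,
    ImageUrl,
    SystemMessage,
    TextContent,
    UserMessage,
)

# Multimodal request: a system prompt plus a user turn mixing text and an image.
request = ChatCompletionRequest(
    model="gpt-4o",  # illustrative model name
    messages=[
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(
            content=[
                TextContent(text="What is in this image?"),
                ImageContent(image_url=ImageUrl(url="data:image/png;base64,...")),  # placeholder payload
            ]
        ),
    ],
    max_completion_tokens=200,
    temperature=0.5,
)
```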
not_again_ai/llm/embedding/__init__.py
@@ -0,0 +1,4 @@
+from not_again_ai.llm.embedding.interface import create_embeddings
+from not_again_ai.llm.embedding.types import EmbeddingRequest
+
+__all__ = ["EmbeddingRequest", "create_embeddings"]
not_again_ai/llm/embedding/interface.py
@@ -0,0 +1,28 @@
+from collections.abc import Callable
+from typing import Any
+
+from not_again_ai.llm.embedding.providers.ollama_api import ollama_create_embeddings
+from not_again_ai.llm.embedding.providers.openai_api import openai_create_embeddings
+from not_again_ai.llm.embedding.types import EmbeddingRequest, EmbeddingResponse
+
+
+def create_embeddings(request: EmbeddingRequest, provider: str, client: Callable[..., Any]) -> EmbeddingResponse:
+    """Get an embedding response from the given provider. Currently supported providers:
+    - `openai` - OpenAI
+    - `azure_openai` - Azure OpenAI
+    - `ollama` - Ollama
+
+    Args:
+        request: Request parameter object
+        provider: The supported provider name
+        client: Client information, see the provider's implementation for what can be provided
+
+    Returns:
+        EmbeddingResponse: The embedding response.
+    """
+    if provider == "openai" or provider == "azure_openai":
+        return openai_create_embeddings(request, client)
+    elif provider == "ollama":
+        return ollama_create_embeddings(request, client)
+    else:
+        raise ValueError(f"Provider {provider} not supported")
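A minimal sketch of the dispatch (not part of the diff), assuming the `ollama_client` helper from the provider module shown below; the model name is illustrative:

```python
from not_again_ai.llm.embedding import EmbeddingRequest, create_embeddings
from not_again_ai.llm.embedding.providers.ollama_api import ollama_client

# create_embeddings routes the request to the provider-specific implementation.
request = EmbeddingRequest(input=["first text", "second text"], model="nomic-embed-text")
response = create_embeddings(request, provider="ollama", client=ollama_client())
print(len(response.embeddings), response.response_duration)
```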
not_again_ai/llm/embedding/providers/ollama_api.py
@@ -0,0 +1,87 @@
+from collections.abc import Callable
+import os
+import re
+import time
+from typing import Any
+
+from loguru import logger
+from ollama import Client, EmbedResponse, ResponseError
+
+from not_again_ai.llm.embedding.types import EmbeddingObject, EmbeddingRequest, EmbeddingResponse
+
+OLLAMA_PARAMETER_MAP = {
+    "dimensions": None,
+}
+
+
+def validate(request: EmbeddingRequest) -> None:
+    # Warn if any parameter that maps to None in OLLAMA_PARAMETER_MAP is set on the request
+    for key, value in OLLAMA_PARAMETER_MAP.items():
+        if value is None and getattr(request, key) is not None:
+            logger.warning(f"Parameter {key} is not supported by Ollama and will be ignored.")
+
+
+def ollama_create_embeddings(request: EmbeddingRequest, client: Callable[..., Any]) -> EmbeddingResponse:
+    validate(request)
+    kwargs = request.model_dump(mode="json", exclude_none=True)
+
+    # For each key in OLLAMA_PARAMETER_MAP:
+    # if the mapped value is not None, rename that key in kwargs to the mapped value;
+    # if the mapped value is None, remove the key from kwargs.
+    for key, value in OLLAMA_PARAMETER_MAP.items():
+        if value is not None and key in kwargs:
+            kwargs[value] = kwargs.pop(key)
+        elif value is None and key in kwargs:
+            del kwargs[key]
+
+    # Explicitly set truncate to True (it is the default)
+    kwargs["truncate"] = True
+
+    try:
+        start_time = time.time()
+        response: EmbedResponse = client(**kwargs)
+        end_time = time.time()
+        response_duration = round(end_time - start_time, 4)
+    except ResponseError as e:
+        # If the error says "model 'model' not found", raise a more specific error
+        expected_pattern = f"model '{request.model}' not found"
+        if re.search(expected_pattern, e.error):
+            raise ResponseError(f"Model '{request.model}' not found.") from e
+        else:
+            raise ResponseError(e.error) from e
+
+    embeddings: list[EmbeddingObject] = []
+    for index, embedding in enumerate(response.embeddings):
+        embeddings.append(EmbeddingObject(embedding=list(embedding), index=index))
+
+    return EmbeddingResponse(
+        embeddings=embeddings,
+        response_duration=response_duration,
+        total_tokens=response.prompt_eval_count,
+    )
+
+
+def ollama_client(host: str | None = None, timeout: float | None = None) -> Callable[..., Any]:
+    """Create an Ollama client callable for the specified host, falling back to the OLLAMA_HOST environment variable.
+
+    Args:
+        host (str, optional): The host URL of the Ollama server.
+        timeout (float, optional): The timeout for requests.
+
+    Returns:
+        Callable[..., Any]: A callable that sends an embedding request to the Ollama server.
+
+    Examples:
+        >>> client = ollama_client(host="http://localhost:11434")
+    """
+    if host is None:
+        host = os.getenv("OLLAMA_HOST")
+        if host is None:
+            logger.warning("OLLAMA_HOST environment variable not set, using default host: http://localhost:11434")
+            host = "http://localhost:11434"
+
+    def client_callable(**kwargs: Any) -> Any:
+        client = Client(host=host, timeout=timeout)
+        return client.embed(**kwargs)
+
+    return client_callable
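A sketch of the parameter mapping above: `dimensions` maps to `None` in `OLLAMA_PARAMETER_MAP`, so it is logged as unsupported and dropped before the request is sent. The model name is illustrative:

```python
from not_again_ai.llm.embedding.providers.ollama_api import (
    ollama_client,
    ollama_create_embeddings,
)
from not_again_ai.llm.embedding.types import EmbeddingRequest

# validate() logs: "Parameter dimensions is not supported by Ollama and will be ignored."
request = EmbeddingRequest(input="hello world", model="nomic-embed-text", dimensions=256)
response = ollama_create_embeddings(request, ollama_client())
```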
not_again_ai/llm/embedding/providers/openai_api.py
@@ -0,0 +1,126 @@
+from collections.abc import Callable
+import time
+from typing import Any, Literal
+
+from azure.identity import DefaultAzureCredential, get_bearer_token_provider
+from openai import AzureOpenAI, OpenAI
+
+from not_again_ai.llm.embedding.types import EmbeddingObject, EmbeddingRequest, EmbeddingResponse
+
+
+def openai_create_embeddings(request: EmbeddingRequest, client: Callable[..., Any]) -> EmbeddingResponse:
+    kwargs = request.model_dump(mode="json", exclude_none=True)
+
+    start_time = time.time()
+    response = client(**kwargs)
+    end_time = time.time()
+    response_duration = round(end_time - start_time, 4)
+
+    embeddings: list[EmbeddingObject] = []
+    for data in response["data"]:
+        embeddings.append(EmbeddingObject(embedding=data["embedding"], index=data["index"]))
+
+    return EmbeddingResponse(
+        embeddings=embeddings,
+        response_duration=response_duration,
+        total_tokens=response["usage"]["total_tokens"],
+    )
+
+
+def create_client_callable(client_class: type[OpenAI | AzureOpenAI], **client_args: Any) -> Callable[..., Any]:
+    """Creates a callable that instantiates and uses an OpenAI client.
+
+    Args:
+        client_class: The OpenAI client class to instantiate (OpenAI or AzureOpenAI)
+        **client_args: Arguments to pass to the client constructor
+
+    Returns:
+        A callable that creates a client and returns embedding results
+    """
+    filtered_args = {k: v for k, v in client_args.items() if v is not None}
+
+    def client_callable(**kwargs: Any) -> Any:
+        client = client_class(**filtered_args)
+        completion = client.embeddings.create(**kwargs)
+        return completion.to_dict()
+
+    return client_callable
+
+
+class InvalidOAIAPITypeError(Exception):
+    """Raised when an invalid OAIAPIType string is provided."""
+
+
+def openai_client(
+    api_type: Literal["openai", "azure_openai"] = "openai",
+    api_key: str | None = None,
+    organization: str | None = None,
+    aoai_api_version: str = "2024-06-01",
+    azure_endpoint: str | None = None,
+    timeout: float | None = None,
+    max_retries: int | None = None,
+) -> Callable[..., Any]:
+    """Create an OpenAI or Azure OpenAI client instance based on the specified API type and other provided parameters.
+
+    It is preferred to use RBAC authentication for Azure OpenAI. You must be signed in with the Azure CLI and have the correct role assigned.
+    See https://techcommunity.microsoft.com/t5/microsoft-developer-community/using-keyless-authentication-with-azure-openai/ba-p/4111521
+
+    Args:
+        api_type (str, optional): Type of the API to be used. Accepted values are 'openai' or 'azure_openai'.
+            Defaults to 'openai'.
+        api_key (str, optional): The API key to authenticate the client. If not provided,
+            OpenAI automatically uses `OPENAI_API_KEY` from the environment.
+            If provided for Azure OpenAI, it will be used for authentication instead of the Azure AD token provider.
+        organization (str, optional): The ID of the organization. If not provided,
+            OpenAI automatically uses `OPENAI_ORG_ID` from the environment.
+        aoai_api_version (str, optional): Only applicable if using Azure OpenAI https://learn.microsoft.com/azure/ai-services/openai/reference#rest-api-versioning
+        azure_endpoint (str, optional): The endpoint to use for Azure OpenAI.
+        timeout (float, optional): By default requests time out after 10 minutes.
+        max_retries (int, optional): Certain errors are automatically retried 2 times by default,
+            with a short exponential backoff. Connection errors (for example, due to a network connectivity problem),
+            408 Request Timeout, 409 Conflict, 429 Rate Limit, and >=500 Internal errors are all retried by default.
+
+    Returns:
+        Callable[..., Any]: A callable that creates a client and returns embedding results
+
+
+    Raises:
+        InvalidOAIAPITypeError: If an invalid API type string is provided.
+        NotImplementedError: If the specified API type is recognized but not yet supported.
+    """
+    if api_type not in ["openai", "azure_openai"]:
+        raise InvalidOAIAPITypeError(f"Invalid OAIAPIType: {api_type}. Must be 'openai' or 'azure_openai'.")
+
+    if api_type == "openai":
+        return create_client_callable(
+            OpenAI,
+            api_key=api_key,
+            organization=organization,
+            timeout=timeout,
+            max_retries=max_retries,
+        )
+    elif api_type == "azure_openai":
+        if api_key:
+            return create_client_callable(
+                AzureOpenAI,
+                api_version=aoai_api_version,
+                azure_endpoint=azure_endpoint,
+                api_key=api_key,
+                timeout=timeout,
+                max_retries=max_retries,
+            )
+        else:
+            azure_credential = DefaultAzureCredential()
+            ad_token_provider = get_bearer_token_provider(
+                azure_credential, "https://cognitiveservices.azure.com/.default"
+            )
+            return create_client_callable(
+                AzureOpenAI,
+                api_version=aoai_api_version,
+                azure_endpoint=azure_endpoint,
+                azure_ad_token_provider=ad_token_provider,
+                timeout=timeout,
+                max_retries=max_retries,
+            )
+    else:
+        raise NotImplementedError(f"API type '{api_type}' is invalid.")
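A sketch of the keyless Azure path above, assuming you are signed in with the Azure CLI and hold the correct role; the endpoint and model/deployment name are placeholders:

```python
from not_again_ai.llm.embedding import EmbeddingRequest, create_embeddings
from not_again_ai.llm.embedding.providers.openai_api import openai_client

# No api_key is passed, so the client falls back to DefaultAzureCredential.
client = openai_client(
    api_type="azure_openai",
    azure_endpoint="https://my-resource.openai.azure.com",  # placeholder endpoint
)
request = EmbeddingRequest(input="hello world", model="text-embedding-3-small", dimensions=256)
response = create_embeddings(request, provider="azure_openai", client=client)
```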
not_again_ai/llm/embedding/types.py
@@ -0,0 +1,23 @@
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+
+class EmbeddingRequest(BaseModel):
+    input: str | list[str]
+    model: str
+    dimensions: int | None = Field(default=None)
+
+
+class EmbeddingObject(BaseModel):
+    embedding: list[float]
+    index: int
+
+
+class EmbeddingResponse(BaseModel):
+    embeddings: list[EmbeddingObject]
+    total_tokens: int | None = Field(default=None)
+    response_duration: float
+
+    errors: str = Field(default="")
+    extras: Any | None = Field(default=None)
not_again_ai/llm/prompting/__init__.py
@@ -0,0 +1,3 @@
+from not_again_ai.llm.prompting.interface import Tokenizer
+
+__all__ = ["Tokenizer"]
not_again_ai/llm/prompting/compile_prompt.py
@@ -0,0 +1,125 @@
+import base64
+from collections.abc import Sequence
+from copy import deepcopy
+import mimetypes
+from pathlib import Path
+from typing import Any
+
+from liquid import Template
+from openai.lib._pydantic import to_strict_json_schema
+from pydantic import BaseModel
+
+from not_again_ai.llm.chat_completion.types import MessageT
+
+
+def _apply_templates(value: Any, variables: dict[str, str]) -> Any:
+    """Recursively applies Liquid templating to all string fields within the given value."""
+    if isinstance(value, str):
+        return Template(value).render(**variables)
+    elif isinstance(value, list):
+        return [_apply_templates(item, variables) for item in value]
+    elif isinstance(value, dict):
+        return {key: _apply_templates(val, variables) for key, val in value.items()}
+    elif isinstance(value, BaseModel):
+        # Process each field in the BaseModel by converting it to a dict,
+        # applying templating to its values, and then re-instantiating the model.
+        processed_data = {key: _apply_templates(val, variables) for key, val in value.model_dump().items()}
+        return value.__class__(**processed_data)
+    else:
+        return value
+
+
+def compile_messages(messages: Sequence[MessageT], variables: dict[str, str]) -> Sequence[MessageT]:
+    """Compiles messages using Liquid templating and the provided variables.
+    Calls Template(content_part).render(**variables) on each text content part.
+
+    Args:
+        messages: List of MessageT where content can contain Liquid templates.
+        variables: The variables to inject into the templates.
+
+    Returns:
+        The same list of messages with the content parts injected with the variables.
+    """
+    messages_formatted = deepcopy(messages)
+    messages_formatted = [_apply_templates(message, variables) for message in messages_formatted]
+    return messages_formatted
+
+
+def compile_tools(tools: Sequence[dict[str, Any]], variables: dict[str, str]) -> Sequence[dict[str, Any]]:
+    """Compiles a list of tool argument dictionaries using Liquid templating and provided variables.
+
+    Each dictionary in the list is deep copied and processed recursively to substitute any Liquid
+    templates present in its data structure.
+
+    Args:
+        tools: A list of dictionaries representing tool arguments, where values can include Liquid templates.
+        variables: A dictionary of variables to substitute into the Liquid templates.
+
+    Returns:
+        A new list of dictionaries with the Liquid templates replaced by their corresponding variable values.
+    """
+    tools_formatted = deepcopy(tools)
+    tools_formatted = [_apply_templates(tool, variables) for tool in tools_formatted]
+    return tools_formatted
+
+
+def encode_image(image_path: Path) -> str:
+    """Encodes an image file at the given Path to base64.
+
+    Args:
+        image_path: The path to the image file to encode.
+
+    Returns:
+        The base64 encoded image as a string.
+    """
+    with Path.open(image_path, "rb") as image_file:
+        return base64.b64encode(image_file.read()).decode("utf-8")
+
+
+def create_image_url(image_path: Path) -> str:
+    """Creates a data URL for an image file at the given Path.
+
+    Args:
+        image_path: The path to the image file to encode.
+
+    Returns:
+        The data URL for the image.
+    """
+    image_data = encode_image(image_path)
+
+    valid_mime_types = ["image/jpeg", "image/png", "image/webp", "image/gif"]
+
+    # Get the MIME type from the image file extension
+    mime_type = mimetypes.guess_type(image_path)[0]
+
+    # Check if the MIME type is valid
+    # List of valid types is here: https://platform.openai.com/docs/guides/vision/what-type-of-files-can-i-upload
+    if mime_type not in valid_mime_types:
+        raise ValueError(f"Invalid MIME type for image: {mime_type}")
+
+    return f"data:{mime_type};base64,{image_data}"
+
+
+def pydantic_to_json_schema(
+    pydantic_model: type[BaseModel], schema_name: str, description: str | None = None
+) -> dict[str, Any]:
+    """Converts a Pydantic model to a JSON schema expected by Structured Outputs.
+    Must adhere to the supported schemas: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas
+
+    Args:
+        pydantic_model: The Pydantic model to convert.
+        schema_name: The name of the schema.
+        description: An optional description of the schema.
+
+    Returns:
+        A JSON schema dictionary representing the Pydantic model.
+    """
+    converted_pydantic = to_strict_json_schema(pydantic_model)
+    schema = {
+        "name": schema_name,
+        "strict": True,
+        "schema": converted_pydantic,
+    }
+    if description:
+        schema["description"] = description
+    return schema
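A sketch of the templating flow: Liquid placeholders in message content are rendered with the supplied variables. The persona and document values are illustrative:

```python
from not_again_ai.llm.chat_completion.types import SystemMessage, UserMessage
from not_again_ai.llm.prompting.compile_prompt import compile_messages

messages = [
    SystemMessage(content="You are a {{ persona }}."),
    UserMessage(content="Summarize this document: {{ document }}"),
]
compiled = compile_messages(messages, {"persona": "technical editor", "document": "..."})
# compiled[0].content == "You are a technical editor."
```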
not_again_ai/llm/prompting/interface.py
@@ -0,0 +1,46 @@
+from collections.abc import Collection, Set
+from typing import Literal
+
+from loguru import logger
+
+from not_again_ai.llm.chat_completion.types import MessageT
+from not_again_ai.llm.prompting.providers.openai_tiktoken import TokenizerOpenAI
+from not_again_ai.llm.prompting.types import BaseTokenizer
+
+
+class Tokenizer(BaseTokenizer):
+    def __init__(
+        self,
+        model: str,
+        provider: str,
+        allowed_special: Literal["all"] | Set[str] | None = None,
+        disallowed_special: Literal["all"] | Collection[str] | None = None,
+    ):
+        self.model = model
+        self.provider = provider
+        self.allowed_special = allowed_special
+        self.disallowed_special = disallowed_special
+
+        self.init_tokenizer(model, provider, allowed_special, disallowed_special)
+
+    def init_tokenizer(
+        self,
+        model: str,
+        provider: str,
+        allowed_special: Literal["all"] | Set[str] | None = None,
+        disallowed_special: Literal["all"] | Collection[str] | None = None,
+    ) -> None:
+        if provider == "openai" or provider == "azure_openai":
+            self.tokenizer = TokenizerOpenAI(model, provider, allowed_special, disallowed_special)
+        else:
+            logger.warning(f"Provider {provider} not supported. Initializing using tiktoken and gpt-4o.")
+            self.tokenizer = TokenizerOpenAI("gpt-4o", "openai", allowed_special, disallowed_special)
+
+    def truncate_str(self, text: str, max_len: int) -> str:
+        return self.tokenizer.truncate_str(text, max_len)
+
+    def num_tokens_in_str(self, text: str) -> int:
+        return self.tokenizer.num_tokens_in_str(text)
+
+    def num_tokens_in_messages(self, messages: list[MessageT]) -> int:
+        return self.tokenizer.num_tokens_in_messages(messages)
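A sketch of the facade above: known providers delegate to `TokenizerOpenAI`, while any other provider falls back to tiktoken with gpt-4o and logs a warning. Inputs are illustrative:

```python
from not_again_ai.llm.prompting import Tokenizer

tokenizer = Tokenizer(model="gpt-4o", provider="openai")
num_tokens = tokenizer.num_tokens_in_str("Hello, world!")
truncated = tokenizer.truncate_str("A long string to trim down.", max_len=5)

# An unrecognized provider logs a warning and uses the gpt-4o encoding instead.
fallback = Tokenizer(model="some-model", provider="other")
```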