chatlas 0.10.0__py3-none-any.whl → 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chatlas has been flagged as potentially problematic; see the registry's release page for more details.

@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Optional
7
7
  import orjson
8
8
 
9
9
  from ._chat import Chat
10
- from ._provider_openai import OpenAIProvider
10
+ from ._provider_openai import ModelInfo, OpenAIProvider
11
11
  from ._utils import MISSING_TYPE, is_testing
12
12
 
13
13
  if TYPE_CHECKING:
@@ -90,18 +90,19 @@ def ChatOllama(
90
90
  raise RuntimeError("Can't find locally running ollama.")
91
91
 
92
92
  if model is None:
93
- models = ollama_models(base_url)
93
+ models = ollama_model_info(base_url)
94
+ model_ids = [m["id"] for m in models]
94
95
  raise ValueError(
95
- f"Must specify model. Locally installed models: {', '.join(models)}"
96
+ f"Must specify model. Locally installed models: {', '.join(model_ids)}"
96
97
  )
97
98
  if isinstance(seed, MISSING_TYPE):
98
99
  seed = 1014 if is_testing() else None
99
100
 
100
101
  return Chat(
101
- provider=OpenAIProvider(
102
+ provider=OllamaProvider(
102
103
  api_key="ollama", # ignored
103
104
  model=model,
104
- base_url=f"{base_url}/v1",
105
+ base_url=base_url,
105
106
  seed=seed,
106
107
  name="Ollama",
107
108
  kwargs=kwargs,
@@ -110,10 +111,40 @@ def ChatOllama(
110
111
  )
111
112
 
112
113
 
113
- def ollama_models(base_url: str) -> list[str]:
114
- res = urllib.request.urlopen(url=f"{base_url}/api/tags")
115
- data = orjson.loads(res.read())
116
- return [re.sub(":latest$", "", x["name"]) for x in data["models"]]
114
+ class OllamaProvider(OpenAIProvider):
115
+ def __init__(self, *, api_key, model, base_url, seed, name, kwargs):
116
+ super().__init__(
117
+ api_key=api_key,
118
+ model=model,
119
+ base_url=f"{base_url}/v1",
120
+ seed=seed,
121
+ name=name,
122
+ kwargs=kwargs,
123
+ )
124
+ self.base_url = base_url
125
+
126
+ def list_models(self):
127
+ return ollama_model_info(self.base_url)
128
+
129
+
130
+ def ollama_model_info(base_url: str) -> list[ModelInfo]:
131
+ response = urllib.request.urlopen(url=f"{base_url}/api/tags")
132
+ data = orjson.loads(response.read())
133
+ models = data.get("models", [])
134
+ if not models:
135
+ return []
136
+
137
+ res: list[ModelInfo] = []
138
+ for model in models:
139
+ # TODO: add capabilities
140
+ info: ModelInfo = {
141
+ "id": re.sub(":latest$", "", model["name"]),
142
+ "created_at": model["modified_at"],
143
+ "size": model["size"],
144
+ }
145
+ res.append(info)
146
+
147
+ return res
117
148
 
118
149
 
119
150
  def has_ollama(base_url):
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import base64
4
+ from datetime import datetime
4
5
  from typing import TYPE_CHECKING, Any, Literal, Optional, cast, overload
5
6
 
6
7
  import orjson
@@ -23,8 +24,8 @@ from ._content import (
23
24
  )
24
25
  from ._logging import log_model_default
25
26
  from ._merge import merge_dicts
26
- from ._provider import Provider, StandardModelParamNames, StandardModelParams
27
- from ._tokens import tokens_log
27
+ from ._provider import ModelInfo, Provider, StandardModelParamNames, StandardModelParams
28
+ from ._tokens import get_token_pricing, tokens_log
28
29
  from ._tools import Tool, basemodel_to_param_schema
29
30
  from ._turn import Turn, user_turn
30
31
  from ._utils import MISSING, MISSING_TYPE, is_testing, split_http_client_kwargs
@@ -200,6 +201,32 @@ class OpenAIProvider(
200
201
  self._client = OpenAI(**sync_kwargs) # type: ignore
201
202
  self._async_client = AsyncOpenAI(**async_kwargs)
202
203
 
204
+ def list_models(self):
205
+ models = self._client.models.list()
206
+
207
+ res: list[ModelInfo] = []
208
+ for m in models:
209
+ pricing = get_token_pricing(self.name, m.id) or {}
210
+ info: ModelInfo = {
211
+ "id": m.id,
212
+ "owned_by": m.owned_by,
213
+ "input": pricing.get("input"),
214
+ "output": pricing.get("output"),
215
+ "cached_input": pricing.get("cached_input"),
216
+ }
217
+ # DeepSeek compatibility
218
+ if m.created is not None:
219
+ info["created_at"] = datetime.fromtimestamp(m.created).date()
220
+ res.append(info)
221
+
222
+ # More recent models first
223
+ res.sort(
224
+ key=lambda x: x.get("created_at", 0),
225
+ reverse=True,
226
+ )
227
+
228
+ return res
229
+
203
230
  @overload
204
231
  def chat_perform(
205
232
  self,
@@ -126,7 +126,7 @@ def ChatPerplexity(
126
126
  seed = 1014 if is_testing() else None
127
127
 
128
128
  return Chat(
129
- provider=OpenAIProvider(
129
+ provider=PerplexityProvider(
130
130
  api_key=api_key,
131
131
  model=model,
132
132
  base_url=base_url,
@@ -136,3 +136,11 @@ def ChatPerplexity(
136
136
  ),
137
137
  system_prompt=system_prompt,
138
138
  )
139
+
140
+
141
+ class PerplexityProvider(OpenAIProvider):
142
+ def list_models(self):
143
+ raise NotImplementedError(
144
+ ".list_models() is not yet implemented for Perplexity."
145
+ " To view available models online, see https://docs.perplexity.ai/getting-started/models"
146
+ )
@@ -121,3 +121,11 @@ def add_default_headers(
121
121
  }
122
122
  )
123
123
  return {"default_headers": default_headers, **kwargs}
124
+
125
+
126
+ class PortkeyProvider(OpenAIProvider):
127
+ def list_models(self):
128
+ raise NotImplementedError(
129
+ ".list_models() is not yet implemented for Portkey. "
130
+ "To view model availability online, see https://portkey.ai/docs/product/model-catalog"
131
+ )
@@ -194,6 +194,12 @@ class SnowflakeProvider(
194
194
  session = Session.builder.configs(configs).create()
195
195
  self._cortex_service = Root(session).cortex_inference_service
196
196
 
197
+ def list_models(self):
198
+ raise NotImplementedError(
199
+ ".list_models() is not yet implemented for Snowflake. "
200
+ "To view model availability online, see https://docs.snowflake.com/user-guide/snowflake-cortex/aisql#availability"
201
+ )
202
+
197
203
  @overload
198
204
  def chat_perform(
199
205
  self,
chatlas/_tools.py CHANGED
@@ -2,7 +2,15 @@ from __future__ import annotations
2
2
 
3
3
  import inspect
4
4
  import warnings
5
- from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Optional
5
+ from typing import (
6
+ TYPE_CHECKING,
7
+ Any,
8
+ AsyncGenerator,
9
+ Awaitable,
10
+ Callable,
11
+ Optional,
12
+ cast,
13
+ )
6
14
 
7
15
  import openai
8
16
  from pydantic import BaseModel, Field, create_model
@@ -12,6 +20,7 @@ from ._content import (
12
20
  ContentToolResult,
13
21
  ContentToolResultImage,
14
22
  ContentToolResultResource,
23
+ ToolAnnotations,
15
24
  )
16
25
 
17
26
  __all__ = (
@@ -42,6 +51,8 @@ class Tool:
42
51
  A description of what the tool does.
43
52
  parameters
44
53
  A dictionary describing the input parameters and their types.
54
+ annotations
55
+ Additional properties that describe the tool and its behavior.
45
56
  """
46
57
 
47
58
  func: Callable[..., Any] | Callable[..., Awaitable[Any]]
@@ -53,9 +64,11 @@ class Tool:
53
64
  name: str,
54
65
  description: str,
55
66
  parameters: dict[str, Any],
67
+ annotations: "Optional[ToolAnnotations]" = None,
56
68
  ):
57
69
  self.name = name
58
70
  self.func = func
71
+ self.annotations = annotations
59
72
  self._is_async = _utils.is_async_callable(func)
60
73
  self.schema: "ChatCompletionToolParam" = {
61
74
  "type": "function",
@@ -71,7 +84,9 @@ class Tool:
71
84
  cls: type["Tool"],
72
85
  func: Callable[..., Any] | Callable[..., Awaitable[Any]],
73
86
  *,
87
+ name: Optional[str] = None,
74
88
  model: Optional[type[BaseModel]] = None,
89
+ annotations: "Optional[ToolAnnotations]" = None,
75
90
  ) -> "Tool":
76
91
  """
77
92
  Create a Tool from a Python function
@@ -80,12 +95,17 @@ class Tool:
80
95
  ----------
81
96
  func
82
97
  The function to wrap as a tool.
98
+ name
99
+ The name of the tool. If not provided, the name will be inferred from the
100
+ function's name.
83
101
  model
84
102
  A Pydantic model that describes the input parameters for the function.
85
103
  If not provided, the model will be inferred from the function's type hints.
86
104
  The primary reason why you might want to provide a model in
87
105
  Note that the name and docstring of the model takes precedence over the
88
106
  name and docstring of the function.
107
+ annotations
108
+ Additional properties that describe the tool and its behavior.
89
109
 
90
110
  Returns
91
111
  -------
@@ -104,7 +124,8 @@ class Tool:
104
124
  # Throw if there is a mismatch between the model and the function parameters
105
125
  params = inspect.signature(func).parameters
106
126
  fields = model.model_fields
107
- diff = set(params) ^ set(fields)
127
+ fields_alias = [val.alias if val.alias else key for key, val in fields.items()]
128
+ diff = set(params) ^ set(fields_alias)
108
129
  if diff:
109
130
  raise ValueError(
110
131
  f"`model` fields must match tool function parameters exactly. "
@@ -115,9 +136,10 @@ class Tool:
115
136
 
116
137
  return cls(
117
138
  func=func,
118
- name=model.__name__ or func.__name__,
139
+ name=name or model.__name__ or func.__name__,
119
140
  description=model.__doc__ or func.__doc__ or "",
120
141
  parameters=params,
142
+ annotations=annotations,
121
143
  )
122
144
 
123
145
  @classmethod
@@ -192,11 +214,17 @@ class Tool:
192
214
 
193
215
  params = mcp_tool_input_schema_to_param_schema(mcp_tool.inputSchema)
194
216
 
217
+ # Convert MCP ToolAnnotations to our TypedDict format
218
+ annotations = None
219
+ if mcp_tool.annotations:
220
+ annotations = cast(ToolAnnotations, mcp_tool.annotations.model_dump())
221
+
195
222
  return cls(
196
223
  func=_utils.wrap_async(_call),
197
224
  name=mcp_tool.name,
198
225
  description=mcp_tool.description or "",
199
226
  parameters=params,
227
+ annotations=annotations,
200
228
  )
201
229
 
202
230
 
@@ -13,7 +13,7 @@ else:
13
13
  # Even though TypedDict is available in Python 3.8, because it's used with NotRequired,
14
14
  # they should both come from the same typing module.
15
15
  # https://peps.python.org/pep-0655/#usage-in-python-3-11
16
- if sys.version_info >= (3, 11):
16
+ if sys.version_info >= (3, 12):
17
17
  from typing import NotRequired, Required, TypedDict
18
18
  else:
19
19
  from typing_extensions import NotRequired, Required, TypedDict
chatlas/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.10.0'
32
- __version_tuple__ = version_tuple = (0, 10, 0)
31
+ __version__ = version = '0.11.1'
32
+ __version_tuple__ = version_tuple = (0, 11, 1)
33
33
 
34
34
  __commit_id__ = commit_id = None