pygpt-net 2.5.13__py3-none-any.whl → 2.5.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pygpt_net/CHANGELOG.txt +12 -1
- pygpt_net/__init__.py +3 -3
- pygpt_net/controller/chat/input.py +9 -2
- pygpt_net/controller/lang/mapping.py +4 -2
- pygpt_net/controller/model/__init__.py +3 -1
- pygpt_net/controller/model/importer.py +337 -0
- pygpt_net/controller/settings/editor.py +3 -0
- pygpt_net/core/camera/__init__.py +1 -1
- pygpt_net/core/models/__init__.py +6 -3
- pygpt_net/core/models/ollama.py +7 -2
- pygpt_net/data/config/config.json +9 -4
- pygpt_net/data/config/models.json +22 -22
- pygpt_net/data/config/modes.json +3 -3
- pygpt_net/data/locale/locale.de.ini +18 -0
- pygpt_net/data/locale/locale.en.ini +19 -2
- pygpt_net/data/locale/locale.es.ini +18 -0
- pygpt_net/data/locale/locale.fr.ini +18 -0
- pygpt_net/data/locale/locale.it.ini +18 -0
- pygpt_net/data/locale/locale.pl.ini +19 -1
- pygpt_net/data/locale/locale.uk.ini +18 -0
- pygpt_net/data/locale/locale.zh.ini +17 -0
- pygpt_net/item/__init__.py +0 -0
- pygpt_net/item/assistant.py +1 -1
- pygpt_net/item/attachment.py +2 -2
- pygpt_net/item/calendar_note.py +1 -1
- pygpt_net/item/ctx.py +1 -1
- pygpt_net/item/index.py +0 -0
- pygpt_net/item/mode.py +0 -0
- pygpt_net/item/model.py +5 -1
- pygpt_net/item/notepad.py +2 -2
- pygpt_net/item/preset.py +0 -0
- pygpt_net/item/prompt.py +2 -2
- pygpt_net/provider/core/ctx/db_sqlite/patch.py +2 -1
- pygpt_net/provider/core/model/json_file.py +3 -0
- pygpt_net/provider/core/model/patch.py +24 -1
- pygpt_net/provider/core/notepad/db_sqlite/patch.py +1 -0
- pygpt_net/provider/llms/ollama.py +7 -2
- pygpt_net/provider/llms/ollama_custom.py +693 -0
- pygpt_net/ui/dialog/models_importer.py +82 -0
- pygpt_net/ui/dialogs.py +3 -1
- pygpt_net/ui/main.py +4 -1
- pygpt_net/ui/menu/config.py +18 -7
- pygpt_net/ui/widget/dialog/model_importer.py +55 -0
- pygpt_net/ui/widget/lists/model_importer.py +151 -0
- {pygpt_net-2.5.13.dist-info → pygpt_net-2.5.15.dist-info}/METADATA +74 -9
- {pygpt_net-2.5.13.dist-info → pygpt_net-2.5.15.dist-info}/RECORD +45 -40
- {pygpt_net-2.5.13.dist-info → pygpt_net-2.5.15.dist-info}/LICENSE +0 -0
- {pygpt_net-2.5.13.dist-info → pygpt_net-2.5.15.dist-info}/WHEEL +0 -0
- {pygpt_net-2.5.13.dist-info → pygpt_net-2.5.15.dist-info}/entry_points.txt +0 -0
pygpt_net/provider/llms/ollama_custom.py (new file)
@@ -0,0 +1,693 @@
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncGenerator,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+)
+
+from ollama import AsyncClient, Client
+
+from llama_index.core.base.llms.generic_utils import (
+    achat_to_completion_decorator,
+    astream_chat_to_completion_decorator,
+    chat_to_completion_decorator,
+    stream_chat_to_completion_decorator,
+)
+from llama_index.core.base.llms.types import (
+    ChatMessage,
+    ChatResponse,
+    ChatResponseAsyncGen,
+    ChatResponseGen,
+    CompletionResponse,
+    CompletionResponseAsyncGen,
+    CompletionResponseGen,
+    ImageBlock,
+    LLMMetadata,
+    MessageRole,
+    TextBlock,
+)
+from llama_index.core.bridge.pydantic import Field, PrivateAttr
+from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
+from llama_index.core.instrumentation import get_dispatcher
+from llama_index.core.llms.callbacks import llm_chat_callback, llm_completion_callback
+from llama_index.core.llms.function_calling import FunctionCallingLLM
+from llama_index.core.llms.llm import ToolSelection, Model
+from llama_index.core.program.utils import process_streaming_objects, FlexibleModel
+from llama_index.core.prompts import PromptTemplate
+from llama_index.core.types import PydanticProgramMode
+
+if TYPE_CHECKING:
+    from llama_index.core.tools.types import BaseTool
+
+DEFAULT_REQUEST_TIMEOUT = 30.0
+dispatcher = get_dispatcher(__name__)
+
+
+def get_additional_kwargs(
+    response: Dict[str, Any], exclude: Tuple[str, ...]
+) -> Dict[str, Any]:
+    return {k: v for k, v in response.items() if k not in exclude}
+
+
+def force_single_tool_call(response: ChatResponse) -> None:
+    tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+    if len(tool_calls) > 1:
+        response.message.additional_kwargs["tool_calls"] = [tool_calls[0]]
+
+
+class Ollama(FunctionCallingLLM):
+    """
+    Ollama LLM.
+
+    Visit https://ollama.com/ to download and install Ollama.
+
+    Run `ollama serve` to start a server.
+
+    Run `ollama pull <name>` to download a model to run.
+
+    Examples:
+        `pip install llama-index-llms-ollama`
+
+        ```python
+        from llama_index.llms.ollama import Ollama
+
+        llm = Ollama(model="llama2", request_timeout=60.0)
+
+        response = llm.complete("What is the capital of France?")
+        print(response)
+        ```
+
+    """
+
+    base_url: str = Field(
+        default="http://localhost:11434",
+        description="Base url the model is hosted under.",
+    )
+    model: str = Field(description="The Ollama model to use.")
+    temperature: Optional[float] = Field(
+        default=None,
+        description="The temperature to use for sampling.",
+    )
+    context_window: int = Field(
+        default=-1,
+        description="The maximum number of context tokens for the model.",
+    )
+    request_timeout: float = Field(
+        default=DEFAULT_REQUEST_TIMEOUT,
+        description="The timeout for making http request to Ollama API server",
+    )
+    prompt_key: str = Field(
+        default="prompt", description="The key to use for the prompt in API calls."
+    )
+    json_mode: bool = Field(
+        default=False,
+        description="Whether to use JSON mode for the Ollama API.",
+    )
+    additional_kwargs: Dict[str, Any] = Field(
+        default_factory=dict,
+        description="Additional model parameters for the Ollama API.",
+    )
+    is_function_calling_model: bool = Field(
+        default=True,
+        description="Whether the model is a function calling model.",
+    )
+    keep_alive: Optional[Union[float, str]] = Field(
+        default="5m",
+        description="controls how long the model will stay loaded into memory following the request(default: 5m)",
+    )
+
+    _client: Optional[Client] = PrivateAttr()
+    _async_client: Optional[AsyncClient] = PrivateAttr()
+
+    def __init__(
+        self,
+        model: str,
+        base_url: str = "http://localhost:11434",
+        temperature: Optional[float] = None,
+        context_window: int = -1,
+        request_timeout: Optional[float] = DEFAULT_REQUEST_TIMEOUT,
+        prompt_key: str = "prompt",
+        json_mode: bool = False,
+        additional_kwargs: Optional[Dict[str, Any]] = None,
+        client: Optional[Client] = None,
+        async_client: Optional[AsyncClient] = None,
+        is_function_calling_model: bool = True,
+        keep_alive: Optional[Union[float, str]] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            model=model,
+            base_url=base_url,
+            temperature=temperature,
+            context_window=context_window,
+            request_timeout=request_timeout,
+            prompt_key=prompt_key,
+            json_mode=json_mode,
+            additional_kwargs=additional_kwargs or {},
+            is_function_calling_model=is_function_calling_model,
+            keep_alive=keep_alive,
+            **kwargs,
+        )
+
+        self._client = client
+        self._async_client = async_client
+
+    @classmethod
+    def class_name(cls) -> str:
+        return "Ollama_llm"
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        """LLM metadata."""
+        return LLMMetadata(
+            context_window=self.get_context_window(),
+            num_output=DEFAULT_NUM_OUTPUTS,
+            model_name=self.model,
+            is_chat_model=True,  # Ollama supports chat API for all models
+            # TODO: Detect if selected model is a function calling model?
+            is_function_calling_model=self.is_function_calling_model,
+        )
+
+    @property
+    def client(self) -> Client:
+        if self._client is None:
+            self._client = Client(host=self.base_url, timeout=self.request_timeout)
+        return self._client
+
+    @property
+    def async_client(self) -> AsyncClient:
+        if self._async_client is None:
+            self._async_client = AsyncClient(
+                host=self.base_url, timeout=self.request_timeout
+            )
+        return self._async_client
+
+    @property
+    def _model_kwargs(self) -> Dict[str, Any]:
+        base_kwargs = {
+            "temperature": self.temperature,
+            "num_ctx": self.get_context_window(),
+        }
+        return {
+            **base_kwargs,
+            **self.additional_kwargs,
+        }
+
+    def get_context_window(self) -> int:
+        if self.context_window == -1:
+            # Try to get the context window from the model info if not set
+            info = self.client.show(self.model).modelinfo
+            for key, value in info.items():
+                if "context_length" in key:
+                    self.context_window = int(value)
+                    break
+
+        # If the context window is still -1, use the default context window
+        return self.context_window if self.context_window != -1 else DEFAULT_CONTEXT_WINDOW
+
+    def _convert_to_ollama_messages(self, messages: Sequence[ChatMessage]) -> Dict:
+        ollama_messages = []
+        for message in messages:
+            cur_ollama_message = {
+                "role": message.role.value,
+                "content": "",
+            }
+            for block in message.blocks:
+                if isinstance(block, TextBlock):
+                    cur_ollama_message["content"] += block.text
+                elif isinstance(block, ImageBlock):
+                    if "images" not in cur_ollama_message:
+                        cur_ollama_message["images"] = []
+                    cur_ollama_message["images"].append(
+                        block.resolve_image(as_base64=True).read().decode("utf-8")
+                    )
+                else:
+                    raise ValueError(f"Unsupported block type: {type(block)}")
+
+            if "tool_calls" in message.additional_kwargs:
+                cur_ollama_message["tool_calls"] = message.additional_kwargs[
+                    "tool_calls"
+                ]
+
+            ollama_messages.append(cur_ollama_message)
+
+        return ollama_messages
+
+    def _get_response_token_counts(self, raw_response: dict) -> dict:
+        """Get the token usage reported by the response."""
+        try:
+            prompt_tokens = raw_response["prompt_eval_count"]
+            completion_tokens = raw_response["eval_count"]
+            total_tokens = prompt_tokens + completion_tokens
+        except KeyError:
+            return {}
+        except TypeError:
+            return {}
+        return {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+        }
+
+    def _prepare_chat_with_tools(
+        self,
+        tools: List["BaseTool"],
+        user_msg: Optional[Union[str, ChatMessage]] = None,
+        chat_history: Optional[List[ChatMessage]] = None,
+        verbose: bool = False,
+        allow_parallel_tool_calls: bool = False,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        tool_specs = [
+            tool.metadata.to_openai_tool(skip_length_check=True) for tool in tools
+        ]
+
+        if isinstance(user_msg, str):
+            user_msg = ChatMessage(role=MessageRole.USER, content=user_msg)
+
+        messages = chat_history or []
+        if user_msg:
+            messages.append(user_msg)
+
+        return {
+            "messages": messages,
+            "tools": tool_specs or None,
+        }
+
+    def _validate_chat_with_tools_response(
+        self,
+        response: ChatResponse,
+        tools: List["BaseTool"],
+        allow_parallel_tool_calls: bool = False,
+        **kwargs: Any,
+    ) -> ChatResponse:
+        """Validate the response from chat_with_tools."""
+        if not allow_parallel_tool_calls:
+            force_single_tool_call(response)
+        return response
+
+    def get_tool_calls_from_response(
+        self,
+        response: "ChatResponse",
+        error_on_no_tool_call: bool = True,
+    ) -> List[ToolSelection]:
+        """Predict and call the tool."""
+        tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+        if len(tool_calls) < 1:
+            if error_on_no_tool_call:
+                raise ValueError(
+                    f"Expected at least one tool call, but got {len(tool_calls)} tool calls."
+                )
+            else:
+                return []
+
+        tool_selections = []
+        for tool_call in tool_calls:
+            argument_dict = tool_call["function"]["arguments"]
+
+            tool_selections.append(
+                ToolSelection(
+                    # tool ids not provided by Ollama
+                    tool_id=tool_call["function"]["name"],
+                    tool_name=tool_call["function"]["name"],
+                    tool_kwargs=argument_dict,
+                )
+            )
+
+        return tool_selections
+
+    @llm_chat_callback()
+    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
+        ollama_messages = self._convert_to_ollama_messages(messages)
+
+        tools = kwargs.pop("tools", None)
+        format = kwargs.pop("format", "json" if self.json_mode else None)
+
+        response = self.client.chat(
+            model=self.model,
+            messages=ollama_messages,
+            stream=False,
+            format=format,
+            tools=tools,
+            options=self._model_kwargs,
+            keep_alive=self.keep_alive,
+        )
+
+        response = dict(response)
+
+        tool_calls = response["message"].get("tool_calls", [])
+        token_counts = self._get_response_token_counts(response)
+        if token_counts:
+            response["usage"] = token_counts
+
+        return ChatResponse(
+            message=ChatMessage(
+                content=response["message"]["content"],
+                role=response["message"]["role"],
+                additional_kwargs={"tool_calls": tool_calls},
+            ),
+            raw=response,
+        )
+
+    @llm_chat_callback()
+    def stream_chat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseGen:
+        ollama_messages = self._convert_to_ollama_messages(messages)
+
+        tools = kwargs.pop("tools", None)
+        format = kwargs.pop("format", "json" if self.json_mode else None)
+
+        def gen() -> ChatResponseGen:
+            response = self.client.chat(
+                model=self.model,
+                messages=ollama_messages,
+                stream=True,
+                format=format,
+                tools=tools,
+                options=self._model_kwargs,
+                keep_alive=self.keep_alive,
+            )
+
+            response_txt = ""
+            seen_tool_calls = set()
+            all_tool_calls = []
+
+            for r in response:
+                if r["message"]["content"] is None:
+                    continue
+
+                r = dict(r)
+
+                response_txt += r["message"]["content"]
+
+                # FIX:
+                if r["message"].get("tool_calls", []) is None:
+                    r["message"]["tool_calls"] = []
+
+                new_tool_calls = [dict(t) for t in r["message"].get("tool_calls", [])]
+                for tool_call in new_tool_calls:
+                    if (
+                        str(tool_call["function"]["name"]),
+                        str(tool_call["function"]["arguments"]),
+                    ) in seen_tool_calls:
+                        continue
+                    seen_tool_calls.add(
+                        (
+                            str(tool_call["function"]["name"]),
+                            str(tool_call["function"]["arguments"]),
+                        )
+                    )
+                    all_tool_calls.append(tool_call)
+                token_counts = self._get_response_token_counts(r)
+                if token_counts:
+                    r["usage"] = token_counts
+
+                yield ChatResponse(
+                    message=ChatMessage(
+                        content=response_txt,
+                        role=r["message"]["role"],
+                        additional_kwargs={"tool_calls": list(set(all_tool_calls))},
+                    ),
+                    delta=r["message"]["content"],
+                    raw=r,
+                )
+
+        return gen()
+
+    @llm_chat_callback()
+    async def astream_chat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseAsyncGen:
+        ollama_messages = self._convert_to_ollama_messages(messages)
+
+        tools = kwargs.pop("tools", None)
+        format = kwargs.pop("format", "json" if self.json_mode else None)
+
+        async def gen() -> ChatResponseAsyncGen:
+            response = await self.async_client.chat(
+                model=self.model,
+                messages=ollama_messages,
+                stream=True,
+                format=format,
+                tools=tools,
+                options=self._model_kwargs,
+                keep_alive=self.keep_alive,
+            )
+
+            response_txt = ""
+            seen_tool_calls = set()
+            all_tool_calls = []
+
+            async for r in response:
+                if r["message"]["content"] is None:
+                    continue
+
+                r = dict(r)
+
+                response_txt += r["message"]["content"]
+
+                new_tool_calls = [dict(t) for t in r["message"].get("tool_calls", [])]
+                for tool_call in new_tool_calls:
+                    if (
+                        str(tool_call["function"]["name"]),
+                        str(tool_call["function"]["arguments"]),
+                    ) in seen_tool_calls:
+                        continue
+                    seen_tool_calls.add(
+                        (
+                            str(tool_call["function"]["name"]),
+                            str(tool_call["function"]["arguments"]),
+                        )
+                    )
+                    all_tool_calls.append(tool_call)
+                token_counts = self._get_response_token_counts(r)
+                if token_counts:
+                    r["usage"] = token_counts
+
+                yield ChatResponse(
+                    message=ChatMessage(
+                        content=response_txt,
+                        role=r["message"]["role"],
+                        additional_kwargs={"tool_calls": all_tool_calls},
+                    ),
+                    delta=r["message"]["content"],
+                    raw=r,
+                )
+
+        return gen()
+
+    @llm_chat_callback()
+    async def achat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponse:
+        ollama_messages = self._convert_to_ollama_messages(messages)
+
+        tools = kwargs.pop("tools", None)
+        format = kwargs.pop("format", "json" if self.json_mode else None)
+
+        response = await self.async_client.chat(
+            model=self.model,
+            messages=ollama_messages,
+            stream=False,
+            format=format,
+            tools=tools,
+            options=self._model_kwargs,
+            keep_alive=self.keep_alive,
+        )
+
+        response = dict(response)
+
+        tool_calls = response["message"].get("tool_calls", [])
+        token_counts = self._get_response_token_counts(response)
+        if token_counts:
+            response["usage"] = token_counts
+
+        return ChatResponse(
+            message=ChatMessage(
+                content=response["message"]["content"],
+                role=response["message"]["role"],
+                additional_kwargs={"tool_calls": tool_calls},
+            ),
+            raw=response,
+        )
+
+    @llm_completion_callback()
+    def complete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponse:
+        return chat_to_completion_decorator(self.chat)(prompt, **kwargs)
+
+    @llm_completion_callback()
+    async def acomplete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponse:
+        return await achat_to_completion_decorator(self.achat)(prompt, **kwargs)
+
+    @llm_completion_callback()
+    def stream_complete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponseGen:
+        return stream_chat_to_completion_decorator(self.stream_chat)(prompt, **kwargs)
+
+    @llm_completion_callback()
+    async def astream_complete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponseAsyncGen:
+        return await astream_chat_to_completion_decorator(self.astream_chat)(
+            prompt, **kwargs
+        )
+
+    @dispatcher.span
+    def structured_predict(
+        self,
+        output_cls: Type[Model],
+        prompt: PromptTemplate,
+        llm_kwargs: Optional[Dict[str, Any]] = None,
+        **prompt_args: Any,
+    ) -> Model:
+        if self.pydantic_program_mode == PydanticProgramMode.DEFAULT:
+            llm_kwargs = llm_kwargs or {}
+            llm_kwargs["format"] = output_cls.model_json_schema()
+
+            messages = prompt.format_messages(**prompt_args)
+            response = self.chat(messages, **llm_kwargs)
+
+            return output_cls.model_validate_json(response.message.content or "")
+        else:
+            return super().structured_predict(
+                output_cls, prompt, llm_kwargs, **prompt_args
+            )
+
+    @dispatcher.span
+    async def astructured_predict(
+        self,
+        output_cls: Type[Model],
+        prompt: PromptTemplate,
+        llm_kwargs: Optional[Dict[str, Any]] = None,
+        **prompt_args: Any,
+    ) -> Model:
+        if self.pydantic_program_mode == PydanticProgramMode.DEFAULT:
+            llm_kwargs = llm_kwargs or {}
+            llm_kwargs["format"] = output_cls.model_json_schema()
+
+            messages = prompt.format_messages(**prompt_args)
+            response = await self.achat(messages, **llm_kwargs)
+
+            return output_cls.model_validate_json(response.message.content or "")
+        else:
+            return await super().astructured_predict(
+                output_cls, prompt, llm_kwargs, **prompt_args
+            )
+
+    @dispatcher.span
+    def stream_structured_predict(
+        self,
+        output_cls: Type[Model],
+        prompt: PromptTemplate,
+        llm_kwargs: Optional[Dict[str, Any]] = None,
+        **prompt_args: Any,
+    ) -> Generator[Union[Model, FlexibleModel], None, None]:
+        """
+        Stream structured predictions as they are generated.
+
+        Args:
+            output_cls: The Pydantic class to parse responses into
+            prompt: The prompt template to use
+            llm_kwargs: Optional kwargs for the LLM
+            **prompt_args: Args to format the prompt with
+
+        Returns:
+            Generator yielding partial objects as they are generated
+
+        """
+        if self.pydantic_program_mode == PydanticProgramMode.DEFAULT:
+
+            def gen(
+                output_cls: Type[Model],
+                prompt: PromptTemplate,
+                llm_kwargs: Dict[str, Any],
+                prompt_args: Dict[str, Any],
+            ) -> Generator[Union[Model, FlexibleModel], None, None]:
+                llm_kwargs = llm_kwargs or {}
+                llm_kwargs["format"] = output_cls.model_json_schema()
+
+                messages = prompt.format_messages(**prompt_args)
+                response_gen = self.stream_chat(messages, **llm_kwargs)
+
+                cur_objects = None
+                for response in response_gen:
+                    try:
+                        objects = process_streaming_objects(
+                            response,
+                            output_cls,
+                            cur_objects=cur_objects,
+                            allow_parallel_tool_calls=False,
+                            flexible_mode=True,
+                        )
+                        cur_objects = (
+                            objects if isinstance(objects, list) else [objects]
+                        )
+                        yield objects
+                    except Exception:
+                        continue
+
+            return gen(output_cls, prompt, llm_kwargs, prompt_args)
+        else:
+            return super().stream_structured_predict(
+                output_cls, prompt, llm_kwargs, **prompt_args
+            )
+
+    @dispatcher.span
+    async def astream_structured_predict(
+        self,
+        output_cls: Type[Model],
+        prompt: PromptTemplate,
+        llm_kwargs: Optional[Dict[str, Any]] = None,
+        **prompt_args: Any,
+    ) -> AsyncGenerator[Union[Model, FlexibleModel], None]:
+        """Async version of stream_structured_predict."""
+        if self.pydantic_program_mode == PydanticProgramMode.DEFAULT:
+
+            async def gen(
+                output_cls: Type[Model],
+                prompt: PromptTemplate,
+                llm_kwargs: Dict[str, Any],
+                prompt_args: Dict[str, Any],
+            ) -> AsyncGenerator[Union[Model, FlexibleModel], None]:
+                llm_kwargs = llm_kwargs or {}
+                llm_kwargs["format"] = output_cls.model_json_schema()
+
+                messages = prompt.format_messages(**prompt_args)
+                response_gen = await self.astream_chat(messages, **llm_kwargs)
+
+                cur_objects = None
+                async for response in response_gen:
+                    try:
+                        objects = process_streaming_objects(
+                            response,
+                            output_cls,
+                            cur_objects=cur_objects,
+                            allow_parallel_tool_calls=False,
+                            flexible_mode=True,
+                        )
+                        cur_objects = (
+                            objects if isinstance(objects, list) else [objects]
+                        )
+                        yield objects
+                    except Exception:
+                        continue
+
+            return gen(output_cls, prompt, llm_kwargs, prompt_args)
+        else:
+            # Fall back to non-streaming structured predict
+            return await super().astream_structured_predict(
+                output_cls, prompt, llm_kwargs, **prompt_args
+            )