pygpt-net 2.5.14__py3-none-any.whl → 2.5.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. pygpt_net/CHANGELOG.txt +6 -0
  2. pygpt_net/__init__.py +3 -3
  3. pygpt_net/controller/chat/input.py +9 -2
  4. pygpt_net/controller/lang/mapping.py +4 -2
  5. pygpt_net/controller/model/__init__.py +3 -1
  6. pygpt_net/controller/model/importer.py +337 -0
  7. pygpt_net/controller/settings/editor.py +3 -0
  8. pygpt_net/core/models/__init__.py +6 -3
  9. pygpt_net/core/models/ollama.py +7 -2
  10. pygpt_net/data/config/config.json +9 -4
  11. pygpt_net/data/config/models.json +22 -22
  12. pygpt_net/data/locale/locale.de.ini +18 -0
  13. pygpt_net/data/locale/locale.en.ini +19 -2
  14. pygpt_net/data/locale/locale.es.ini +18 -0
  15. pygpt_net/data/locale/locale.fr.ini +18 -0
  16. pygpt_net/data/locale/locale.it.ini +18 -0
  17. pygpt_net/data/locale/locale.pl.ini +19 -1
  18. pygpt_net/data/locale/locale.uk.ini +18 -0
  19. pygpt_net/data/locale/locale.zh.ini +17 -0
  20. pygpt_net/item/model.py +5 -1
  21. pygpt_net/provider/core/model/json_file.py +3 -0
  22. pygpt_net/provider/core/model/patch.py +24 -1
  23. pygpt_net/provider/llms/ollama.py +7 -2
  24. pygpt_net/provider/llms/ollama_custom.py +693 -0
  25. pygpt_net/ui/dialog/models_importer.py +82 -0
  26. pygpt_net/ui/dialogs.py +3 -1
  27. pygpt_net/ui/menu/config.py +18 -7
  28. pygpt_net/ui/widget/dialog/model_importer.py +55 -0
  29. pygpt_net/ui/widget/lists/model_importer.py +151 -0
  30. {pygpt_net-2.5.14.dist-info → pygpt_net-2.5.15.dist-info}/METADATA +68 -8
  31. {pygpt_net-2.5.14.dist-info → pygpt_net-2.5.15.dist-info}/RECORD +34 -29
  32. {pygpt_net-2.5.14.dist-info → pygpt_net-2.5.15.dist-info}/LICENSE +0 -0
  33. {pygpt_net-2.5.14.dist-info → pygpt_net-2.5.15.dist-info}/WHEEL +0 -0
  34. {pygpt_net-2.5.14.dist-info → pygpt_net-2.5.15.dist-info}/entry_points.txt +0 -0
pygpt_net/provider/llms/ollama_custom.py
@@ -0,0 +1,693 @@
+ from typing import (
+     TYPE_CHECKING,
+     Any,
+     AsyncGenerator,
+     Dict,
+     Generator,
+     List,
+     Optional,
+     Sequence,
+     Tuple,
+     Type,
+     Union,
+ )
+
+ from ollama import AsyncClient, Client
+
+ from llama_index.core.base.llms.generic_utils import (
+     achat_to_completion_decorator,
+     astream_chat_to_completion_decorator,
+     chat_to_completion_decorator,
+     stream_chat_to_completion_decorator,
+ )
+ from llama_index.core.base.llms.types import (
+     ChatMessage,
+     ChatResponse,
+     ChatResponseAsyncGen,
+     ChatResponseGen,
+     CompletionResponse,
+     CompletionResponseAsyncGen,
+     CompletionResponseGen,
+     ImageBlock,
+     LLMMetadata,
+     MessageRole,
+     TextBlock,
+ )
+ from llama_index.core.bridge.pydantic import Field, PrivateAttr
+ from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
+ from llama_index.core.instrumentation import get_dispatcher
+ from llama_index.core.llms.callbacks import llm_chat_callback, llm_completion_callback
+ from llama_index.core.llms.function_calling import FunctionCallingLLM
+ from llama_index.core.llms.llm import ToolSelection, Model
+ from llama_index.core.program.utils import process_streaming_objects, FlexibleModel
+ from llama_index.core.prompts import PromptTemplate
+ from llama_index.core.types import PydanticProgramMode
+
+ if TYPE_CHECKING:
+     from llama_index.core.tools.types import BaseTool
+
+ DEFAULT_REQUEST_TIMEOUT = 30.0
+ dispatcher = get_dispatcher(__name__)
+
+
+ def get_additional_kwargs(
+     response: Dict[str, Any], exclude: Tuple[str, ...]
+ ) -> Dict[str, Any]:
+     return {k: v for k, v in response.items() if k not in exclude}
+
+
+ def force_single_tool_call(response: ChatResponse) -> None:
+     tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+     if len(tool_calls) > 1:
+         response.message.additional_kwargs["tool_calls"] = [tool_calls[0]]
+
+
+ class Ollama(FunctionCallingLLM):
+     """
+     Ollama LLM.
+
+     Visit https://ollama.com/ to download and install Ollama.
+
+     Run `ollama serve` to start a server.
+
+     Run `ollama pull <name>` to download a model to run.
+
+     Examples:
+         `pip install llama-index-llms-ollama`
+
+         ```python
+         from llama_index.llms.ollama import Ollama
+
+         llm = Ollama(model="llama2", request_timeout=60.0)
+
+         response = llm.complete("What is the capital of France?")
+         print(response)
+         ```
+
+     """
+
+     base_url: str = Field(
+         default="http://localhost:11434",
+         description="Base url the model is hosted under.",
+     )
+     model: str = Field(description="The Ollama model to use.")
+     temperature: Optional[float] = Field(
+         default=None,
+         description="The temperature to use for sampling.",
+     )
+     context_window: int = Field(
+         default=-1,
+         description="The maximum number of context tokens for the model.",
+     )
+     request_timeout: float = Field(
+         default=DEFAULT_REQUEST_TIMEOUT,
+         description="The timeout for making http request to Ollama API server",
+     )
+     prompt_key: str = Field(
+         default="prompt", description="The key to use for the prompt in API calls."
+     )
+     json_mode: bool = Field(
+         default=False,
+         description="Whether to use JSON mode for the Ollama API.",
+     )
+     additional_kwargs: Dict[str, Any] = Field(
+         default_factory=dict,
+         description="Additional model parameters for the Ollama API.",
+     )
+     is_function_calling_model: bool = Field(
+         default=True,
+         description="Whether the model is a function calling model.",
+     )
+     keep_alive: Optional[Union[float, str]] = Field(
+         default="5m",
+         description="controls how long the model will stay loaded into memory following the request(default: 5m)",
+     )
+
+     _client: Optional[Client] = PrivateAttr()
+     _async_client: Optional[AsyncClient] = PrivateAttr()
+
+     def __init__(
+         self,
+         model: str,
+         base_url: str = "http://localhost:11434",
+         temperature: Optional[float] = None,
+         context_window: int = -1,
+         request_timeout: Optional[float] = DEFAULT_REQUEST_TIMEOUT,
+         prompt_key: str = "prompt",
+         json_mode: bool = False,
+         additional_kwargs: Optional[Dict[str, Any]] = None,
+         client: Optional[Client] = None,
+         async_client: Optional[AsyncClient] = None,
+         is_function_calling_model: bool = True,
+         keep_alive: Optional[Union[float, str]] = None,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(
+             model=model,
+             base_url=base_url,
+             temperature=temperature,
+             context_window=context_window,
+             request_timeout=request_timeout,
+             prompt_key=prompt_key,
+             json_mode=json_mode,
+             additional_kwargs=additional_kwargs or {},
+             is_function_calling_model=is_function_calling_model,
+             keep_alive=keep_alive,
+             **kwargs,
+         )
+
+         self._client = client
+         self._async_client = async_client
+
+     @classmethod
+     def class_name(cls) -> str:
+         return "Ollama_llm"
+
+     @property
+     def metadata(self) -> LLMMetadata:
+         """LLM metadata."""
+         return LLMMetadata(
+             context_window=self.get_context_window(),
+             num_output=DEFAULT_NUM_OUTPUTS,
+             model_name=self.model,
+             is_chat_model=True,  # Ollama supports chat API for all models
+             # TODO: Detect if selected model is a function calling model?
+             is_function_calling_model=self.is_function_calling_model,
+         )
+
+     @property
+     def client(self) -> Client:
+         if self._client is None:
+             self._client = Client(host=self.base_url, timeout=self.request_timeout)
+         return self._client
+
+     @property
+     def async_client(self) -> AsyncClient:
+         if self._async_client is None:
+             self._async_client = AsyncClient(
+                 host=self.base_url, timeout=self.request_timeout
+             )
+         return self._async_client
+
+     @property
+     def _model_kwargs(self) -> Dict[str, Any]:
+         base_kwargs = {
+             "temperature": self.temperature,
+             "num_ctx": self.get_context_window(),
+         }
+         return {
+             **base_kwargs,
+             **self.additional_kwargs,
+         }
+
+     def get_context_window(self) -> int:
+         if self.context_window == -1:
+             # Try to get the context window from the model info if not set
+             info = self.client.show(self.model).modelinfo
+             for key, value in info.items():
+                 if "context_length" in key:
+                     self.context_window = int(value)
+                     break
+
+         # If the context window is still -1, use the default context window
+         return self.context_window if self.context_window != -1 else DEFAULT_CONTEXT_WINDOW
+
+     def _convert_to_ollama_messages(self, messages: Sequence[ChatMessage]) -> Dict:
+         ollama_messages = []
+         for message in messages:
+             cur_ollama_message = {
+                 "role": message.role.value,
+                 "content": "",
+             }
+             for block in message.blocks:
+                 if isinstance(block, TextBlock):
+                     cur_ollama_message["content"] += block.text
+                 elif isinstance(block, ImageBlock):
+                     if "images" not in cur_ollama_message:
+                         cur_ollama_message["images"] = []
+                     cur_ollama_message["images"].append(
+                         block.resolve_image(as_base64=True).read().decode("utf-8")
+                     )
+                 else:
+                     raise ValueError(f"Unsupported block type: {type(block)}")
+
+             if "tool_calls" in message.additional_kwargs:
+                 cur_ollama_message["tool_calls"] = message.additional_kwargs[
+                     "tool_calls"
+                 ]
+
+             ollama_messages.append(cur_ollama_message)
+
+         return ollama_messages
+
+     def _get_response_token_counts(self, raw_response: dict) -> dict:
+         """Get the token usage reported by the response."""
+         try:
+             prompt_tokens = raw_response["prompt_eval_count"]
+             completion_tokens = raw_response["eval_count"]
+             total_tokens = prompt_tokens + completion_tokens
+         except KeyError:
+             return {}
+         except TypeError:
+             return {}
+         return {
+             "prompt_tokens": prompt_tokens,
+             "completion_tokens": completion_tokens,
+             "total_tokens": total_tokens,
+         }
+
+     def _prepare_chat_with_tools(
+         self,
+         tools: List["BaseTool"],
+         user_msg: Optional[Union[str, ChatMessage]] = None,
+         chat_history: Optional[List[ChatMessage]] = None,
+         verbose: bool = False,
+         allow_parallel_tool_calls: bool = False,
+         **kwargs: Any,
+     ) -> Dict[str, Any]:
+         tool_specs = [
+             tool.metadata.to_openai_tool(skip_length_check=True) for tool in tools
+         ]
+
+         if isinstance(user_msg, str):
+             user_msg = ChatMessage(role=MessageRole.USER, content=user_msg)
+
+         messages = chat_history or []
+         if user_msg:
+             messages.append(user_msg)
+
+         return {
+             "messages": messages,
+             "tools": tool_specs or None,
+         }
+
+     def _validate_chat_with_tools_response(
+         self,
+         response: ChatResponse,
+         tools: List["BaseTool"],
+         allow_parallel_tool_calls: bool = False,
+         **kwargs: Any,
+     ) -> ChatResponse:
+         """Validate the response from chat_with_tools."""
+         if not allow_parallel_tool_calls:
+             force_single_tool_call(response)
+         return response
+
+     def get_tool_calls_from_response(
+         self,
+         response: "ChatResponse",
+         error_on_no_tool_call: bool = True,
+     ) -> List[ToolSelection]:
+         """Predict and call the tool."""
+         tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+         if len(tool_calls) < 1:
+             if error_on_no_tool_call:
+                 raise ValueError(
+                     f"Expected at least one tool call, but got {len(tool_calls)} tool calls."
+                 )
+             else:
+                 return []
+
+         tool_selections = []
+         for tool_call in tool_calls:
+             argument_dict = tool_call["function"]["arguments"]
+
+             tool_selections.append(
+                 ToolSelection(
+                     # tool ids not provided by Ollama
+                     tool_id=tool_call["function"]["name"],
+                     tool_name=tool_call["function"]["name"],
+                     tool_kwargs=argument_dict,
+                 )
+             )
+
+         return tool_selections
+
+     @llm_chat_callback()
+     def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
+         ollama_messages = self._convert_to_ollama_messages(messages)
+
+         tools = kwargs.pop("tools", None)
+         format = kwargs.pop("format", "json" if self.json_mode else None)
+
+         response = self.client.chat(
+             model=self.model,
+             messages=ollama_messages,
+             stream=False,
+             format=format,
+             tools=tools,
+             options=self._model_kwargs,
+             keep_alive=self.keep_alive,
+         )
+
+         response = dict(response)
+
+         tool_calls = response["message"].get("tool_calls", [])
+         token_counts = self._get_response_token_counts(response)
+         if token_counts:
+             response["usage"] = token_counts
+
+         return ChatResponse(
+             message=ChatMessage(
+                 content=response["message"]["content"],
+                 role=response["message"]["role"],
+                 additional_kwargs={"tool_calls": tool_calls},
+             ),
+             raw=response,
+         )
+
+     @llm_chat_callback()
+     def stream_chat(
+         self, messages: Sequence[ChatMessage], **kwargs: Any
+     ) -> ChatResponseGen:
+         ollama_messages = self._convert_to_ollama_messages(messages)
+
+         tools = kwargs.pop("tools", None)
+         format = kwargs.pop("format", "json" if self.json_mode else None)
+
+         def gen() -> ChatResponseGen:
+             response = self.client.chat(
+                 model=self.model,
+                 messages=ollama_messages,
+                 stream=True,
+                 format=format,
+                 tools=tools,
+                 options=self._model_kwargs,
+                 keep_alive=self.keep_alive,
+             )
+
+             response_txt = ""
+             seen_tool_calls = set()
+             all_tool_calls = []
+
+             for r in response:
+                 if r["message"]["content"] is None:
+                     continue
+
+                 r = dict(r)
+
+                 response_txt += r["message"]["content"]
+
+                 # FIX:
+                 if r["message"].get("tool_calls", []) is None:
+                     r["message"]["tool_calls"] = []
+
+                 new_tool_calls = [dict(t) for t in r["message"].get("tool_calls", [])]
+                 for tool_call in new_tool_calls:
+                     if (
+                         str(tool_call["function"]["name"]),
+                         str(tool_call["function"]["arguments"]),
+                     ) in seen_tool_calls:
+                         continue
+                     seen_tool_calls.add(
+                         (
+                             str(tool_call["function"]["name"]),
+                             str(tool_call["function"]["arguments"]),
+                         )
+                     )
+                     all_tool_calls.append(tool_call)
+                 token_counts = self._get_response_token_counts(r)
+                 if token_counts:
+                     r["usage"] = token_counts
+
+                 yield ChatResponse(
+                     message=ChatMessage(
+                         content=response_txt,
+                         role=r["message"]["role"],
+                         additional_kwargs={"tool_calls": list(set(all_tool_calls))},
+                     ),
+                     delta=r["message"]["content"],
+                     raw=r,
+                 )
+
+         return gen()
+
+     @llm_chat_callback()
+     async def astream_chat(
+         self, messages: Sequence[ChatMessage], **kwargs: Any
+     ) -> ChatResponseAsyncGen:
+         ollama_messages = self._convert_to_ollama_messages(messages)
+
+         tools = kwargs.pop("tools", None)
+         format = kwargs.pop("format", "json" if self.json_mode else None)
+
+         async def gen() -> ChatResponseAsyncGen:
+             response = await self.async_client.chat(
+                 model=self.model,
+                 messages=ollama_messages,
+                 stream=True,
+                 format=format,
+                 tools=tools,
+                 options=self._model_kwargs,
+                 keep_alive=self.keep_alive,
+             )
+
+             response_txt = ""
+             seen_tool_calls = set()
+             all_tool_calls = []
+
+             async for r in response:
+                 if r["message"]["content"] is None:
+                     continue
+
+                 r = dict(r)
+
+                 response_txt += r["message"]["content"]
+
+                 new_tool_calls = [dict(t) for t in r["message"].get("tool_calls", [])]
+                 for tool_call in new_tool_calls:
+                     if (
+                         str(tool_call["function"]["name"]),
+                         str(tool_call["function"]["arguments"]),
+                     ) in seen_tool_calls:
+                         continue
+                     seen_tool_calls.add(
+                         (
+                             str(tool_call["function"]["name"]),
+                             str(tool_call["function"]["arguments"]),
+                         )
+                     )
+                     all_tool_calls.append(tool_call)
+                 token_counts = self._get_response_token_counts(r)
+                 if token_counts:
+                     r["usage"] = token_counts
+
+                 yield ChatResponse(
+                     message=ChatMessage(
+                         content=response_txt,
+                         role=r["message"]["role"],
+                         additional_kwargs={"tool_calls": all_tool_calls},
+                     ),
+                     delta=r["message"]["content"],
+                     raw=r,
+                 )
+
+         return gen()
+
+     @llm_chat_callback()
+     async def achat(
+         self, messages: Sequence[ChatMessage], **kwargs: Any
+     ) -> ChatResponse:
+         ollama_messages = self._convert_to_ollama_messages(messages)
+
+         tools = kwargs.pop("tools", None)
+         format = kwargs.pop("format", "json" if self.json_mode else None)
+
+         response = await self.async_client.chat(
+             model=self.model,
+             messages=ollama_messages,
+             stream=False,
+             format=format,
+             tools=tools,
+             options=self._model_kwargs,
+             keep_alive=self.keep_alive,
+         )
+
+         response = dict(response)
+
+         tool_calls = response["message"].get("tool_calls", [])
+         token_counts = self._get_response_token_counts(response)
+         if token_counts:
+             response["usage"] = token_counts
+
+         return ChatResponse(
+             message=ChatMessage(
+                 content=response["message"]["content"],
+                 role=response["message"]["role"],
+                 additional_kwargs={"tool_calls": tool_calls},
+             ),
+             raw=response,
+         )
+
+     @llm_completion_callback()
+     def complete(
+         self, prompt: str, formatted: bool = False, **kwargs: Any
+     ) -> CompletionResponse:
+         return chat_to_completion_decorator(self.chat)(prompt, **kwargs)
+
+     @llm_completion_callback()
+     async def acomplete(
+         self, prompt: str, formatted: bool = False, **kwargs: Any
+     ) -> CompletionResponse:
+         return await achat_to_completion_decorator(self.achat)(prompt, **kwargs)
+
+     @llm_completion_callback()
+     def stream_complete(
+         self, prompt: str, formatted: bool = False, **kwargs: Any
+     ) -> CompletionResponseGen:
+         return stream_chat_to_completion_decorator(self.stream_chat)(prompt, **kwargs)
+
+     @llm_completion_callback()
+     async def astream_complete(
+         self, prompt: str, formatted: bool = False, **kwargs: Any
+     ) -> CompletionResponseAsyncGen:
+         return await astream_chat_to_completion_decorator(self.astream_chat)(
+             prompt, **kwargs
+         )
+
+     @dispatcher.span
+     def structured_predict(
+         self,
+         output_cls: Type[Model],
+         prompt: PromptTemplate,
+         llm_kwargs: Optional[Dict[str, Any]] = None,
+         **prompt_args: Any,
+     ) -> Model:
+         if self.pydantic_program_mode == PydanticProgramMode.DEFAULT:
+             llm_kwargs = llm_kwargs or {}
+             llm_kwargs["format"] = output_cls.model_json_schema()
+
+             messages = prompt.format_messages(**prompt_args)
+             response = self.chat(messages, **llm_kwargs)
+
+             return output_cls.model_validate_json(response.message.content or "")
+         else:
+             return super().structured_predict(
+                 output_cls, prompt, llm_kwargs, **prompt_args
+             )
+
+     @dispatcher.span
+     async def astructured_predict(
+         self,
+         output_cls: Type[Model],
+         prompt: PromptTemplate,
+         llm_kwargs: Optional[Dict[str, Any]] = None,
+         **prompt_args: Any,
+     ) -> Model:
+         if self.pydantic_program_mode == PydanticProgramMode.DEFAULT:
+             llm_kwargs = llm_kwargs or {}
+             llm_kwargs["format"] = output_cls.model_json_schema()
+
+             messages = prompt.format_messages(**prompt_args)
+             response = await self.achat(messages, **llm_kwargs)
+
+             return output_cls.model_validate_json(response.message.content or "")
+         else:
+             return await super().astructured_predict(
+                 output_cls, prompt, llm_kwargs, **prompt_args
+             )
+
+     @dispatcher.span
+     def stream_structured_predict(
+         self,
+         output_cls: Type[Model],
+         prompt: PromptTemplate,
+         llm_kwargs: Optional[Dict[str, Any]] = None,
+         **prompt_args: Any,
+     ) -> Generator[Union[Model, FlexibleModel], None, None]:
+         """
+         Stream structured predictions as they are generated.
+
+         Args:
+             output_cls: The Pydantic class to parse responses into
+             prompt: The prompt template to use
+             llm_kwargs: Optional kwargs for the LLM
+             **prompt_args: Args to format the prompt with
+
+         Returns:
+             Generator yielding partial objects as they are generated
+
+         """
+         if self.pydantic_program_mode == PydanticProgramMode.DEFAULT:
+
+             def gen(
+                 output_cls: Type[Model],
+                 prompt: PromptTemplate,
+                 llm_kwargs: Dict[str, Any],
+                 prompt_args: Dict[str, Any],
+             ) -> Generator[Union[Model, FlexibleModel], None, None]:
+                 llm_kwargs = llm_kwargs or {}
+                 llm_kwargs["format"] = output_cls.model_json_schema()
+
+                 messages = prompt.format_messages(**prompt_args)
+                 response_gen = self.stream_chat(messages, **llm_kwargs)
+
+                 cur_objects = None
+                 for response in response_gen:
+                     try:
+                         objects = process_streaming_objects(
+                             response,
+                             output_cls,
+                             cur_objects=cur_objects,
+                             allow_parallel_tool_calls=False,
+                             flexible_mode=True,
+                         )
+                         cur_objects = (
+                             objects if isinstance(objects, list) else [objects]
+                         )
+                         yield objects
+                     except Exception:
+                         continue
+
+             return gen(output_cls, prompt, llm_kwargs, prompt_args)
+         else:
+             return super().stream_structured_predict(
+                 output_cls, prompt, llm_kwargs, **prompt_args
+             )
+
+     @dispatcher.span
+     async def astream_structured_predict(
+         self,
+         output_cls: Type[Model],
+         prompt: PromptTemplate,
+         llm_kwargs: Optional[Dict[str, Any]] = None,
+         **prompt_args: Any,
+     ) -> AsyncGenerator[Union[Model, FlexibleModel], None]:
+         """Async version of stream_structured_predict."""
+         if self.pydantic_program_mode == PydanticProgramMode.DEFAULT:
+
+             async def gen(
+                 output_cls: Type[Model],
+                 prompt: PromptTemplate,
+                 llm_kwargs: Dict[str, Any],
+                 prompt_args: Dict[str, Any],
+             ) -> AsyncGenerator[Union[Model, FlexibleModel], None]:
+                 llm_kwargs = llm_kwargs or {}
+                 llm_kwargs["format"] = output_cls.model_json_schema()
+
+                 messages = prompt.format_messages(**prompt_args)
+                 response_gen = await self.astream_chat(messages, **llm_kwargs)
+
+                 cur_objects = None
+                 async for response in response_gen:
+                     try:
+                         objects = process_streaming_objects(
+                             response,
+                             output_cls,
+                             cur_objects=cur_objects,
+                             allow_parallel_tool_calls=False,
+                             flexible_mode=True,
+                         )
+                         cur_objects = (
+                             objects if isinstance(objects, list) else [objects]
+                         )
+                         yield objects
+                     except Exception:
+                         continue
+
+             return gen(output_cls, prompt, llm_kwargs, prompt_args)
+         else:
+             # Fall back to non-streaming structured predict
+             return await super().astream_structured_predict(
+                 output_cls, prompt, llm_kwargs, **prompt_args
+             )
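
For reference, a minimal usage sketch of the vendored class above (not part of the wheel): it assumes the `llama-index` core and `ollama` dependencies are installed, a local `ollama serve` instance is reachable on the default port, and that the example model name `llama3.1` has already been pulled with `ollama pull llama3.1`.

```python
# Hypothetical usage sketch for the vendored Ollama wrapper shipped in
# pygpt_net/provider/llms/ollama_custom.py; the model name "llama3.1" is an
# assumption and must already be available to the local Ollama server.
from llama_index.core.base.llms.types import ChatMessage, MessageRole

from pygpt_net.provider.llms.ollama_custom import Ollama

llm = Ollama(model="llama3.1", request_timeout=120.0)

messages = [
    ChatMessage(role=MessageRole.SYSTEM, content="You are a terse assistant."),
    ChatMessage(role=MessageRole.USER, content="Name one EU capital."),
]

# Blocking chat: messages are converted to the Ollama wire format internally.
print(llm.chat(messages).message.content)

# Streaming chat: each ChatResponse carries the accumulated text plus a delta.
for chunk in llm.stream_chat(messages):
    print(chunk.delta, end="", flush=True)
```

Leaving `context_window` at its default of `-1` makes the wrapper query `client.show(model)` for the model's `context_length` and fall back to `DEFAULT_CONTEXT_WINDOW` when none is reported.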