pydantic-ai-slim 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic; see the package registry's advisory for this release for more details.

@@ -7,16 +7,17 @@ from contextlib import asynccontextmanager
7
7
  from dataclasses import dataclass, field, replace
8
8
  from datetime import datetime
9
9
  from itertools import chain
10
- from typing import Callable, Union, cast
10
+ from typing import Callable, Union
11
11
 
12
12
  from typing_extensions import TypeAlias, assert_never, overload
13
13
 
14
- from .. import _utils, result
14
+ from .. import _utils, usage
15
+ from .._utils import PeekableAsyncStream
15
16
  from ..messages import (
16
17
  ModelMessage,
17
18
  ModelRequest,
18
19
  ModelResponse,
19
- ModelResponsePart,
20
+ ModelResponseStreamEvent,
20
21
  RetryPromptPart,
21
22
  SystemPromptPart,
22
23
  TextPart,
@@ -26,7 +27,7 @@ from ..messages import (
26
27
  )
27
28
  from ..settings import ModelSettings
28
29
  from ..tools import ToolDefinition
29
- from . import AgentModel, EitherStreamedResponse, Model, StreamStructuredResponse, StreamTextResponse
30
+ from . import AgentModel, Model, StreamedResponse
30
31
 
31
32
 
32
33
  @dataclass(init=False)
@@ -142,7 +143,7 @@ class FunctionAgentModel(AgentModel):
142
143
 
143
144
  async def request(
144
145
  self, messages: list[ModelMessage], model_settings: ModelSettings | None
145
- ) -> tuple[ModelResponse, result.Usage]:
146
+ ) -> tuple[ModelResponse, usage.Usage]:
146
147
  agent_info = replace(self.agent_info, model_settings=model_settings)
147
148
 
148
149
  assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
@@ -158,90 +159,55 @@ class FunctionAgentModel(AgentModel):
158
159
  @asynccontextmanager
159
160
  async def request_stream(
160
161
  self, messages: list[ModelMessage], model_settings: ModelSettings | None
161
- ) -> AsyncIterator[EitherStreamedResponse]:
162
+ ) -> AsyncIterator[StreamedResponse]:
162
163
  assert (
163
164
  self.stream_function is not None
164
165
  ), 'FunctionModel must receive a `stream_function` to support streamed requests'
165
- response_stream = self.stream_function(messages, self.agent_info)
166
- try:
167
- first = await response_stream.__anext__()
168
- except StopAsyncIteration as e:
169
- raise ValueError('Stream function must return at least one item') from e
170
-
171
- if isinstance(first, str):
172
- text_stream = cast(AsyncIterator[str], response_stream)
173
- yield FunctionStreamTextResponse(first, text_stream)
174
- else:
175
- structured_stream = cast(AsyncIterator[DeltaToolCalls], response_stream)
176
- yield FunctionStreamStructuredResponse(first, structured_stream)
177
-
178
-
179
- @dataclass
180
- class FunctionStreamTextResponse(StreamTextResponse):
181
- """Implementation of `StreamTextResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
182
-
183
- _next: str | None
184
- _iter: AsyncIterator[str]
185
- _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
186
- _buffer: list[str] = field(default_factory=list, init=False)
187
-
188
- async def __anext__(self) -> None:
189
- if self._next is not None:
190
- self._buffer.append(self._next)
191
- self._next = None
192
- else:
193
- self._buffer.append(await self._iter.__anext__())
194
-
195
- def get(self, *, final: bool = False) -> Iterable[str]:
196
- yield from self._buffer
197
- self._buffer.clear()
166
+ response_stream = PeekableAsyncStream(self.stream_function(messages, self.agent_info))
198
167
 
199
- def usage(self) -> result.Usage:
200
- return result.Usage()
168
+ first = await response_stream.peek()
169
+ if isinstance(first, _utils.Unset):
170
+ raise ValueError('Stream function must return at least one item')
201
171
 
202
- def timestamp(self) -> datetime:
203
- return self._timestamp
172
+ yield FunctionStreamedResponse(response_stream)
204
173
 
205
174
 
206
175
  @dataclass
207
- class FunctionStreamStructuredResponse(StreamStructuredResponse):
208
- """Implementation of `StreamStructuredResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
176
+ class FunctionStreamedResponse(StreamedResponse):
177
+ """Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
209
178
 
210
- _next: DeltaToolCalls | None
211
- _iter: AsyncIterator[DeltaToolCalls]
212
- _delta_tool_calls: dict[int, DeltaToolCall] = field(default_factory=dict)
179
+ _iter: AsyncIterator[str | DeltaToolCalls]
213
180
  _timestamp: datetime = field(default_factory=_utils.now_utc)
214
181
 
215
- async def __anext__(self) -> None:
216
- if self._next is not None:
217
- tool_call = self._next
218
- self._next = None
219
- else:
220
- tool_call = await self._iter.__anext__()
182
+ def __post_init__(self):
183
+ self._usage += _estimate_usage([])
221
184
 
222
- for key, new in tool_call.items():
223
- if current := self._delta_tool_calls.get(key):
224
- current.name = _utils.add_optional(current.name, new.name)
225
- current.json_args = _utils.add_optional(current.json_args, new.json_args)
185
+ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
186
+ async for item in self._iter:
187
+ if isinstance(item, str):
188
+ response_tokens = _estimate_string_tokens(item)
189
+ self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
190
+ yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
226
191
  else:
227
- self._delta_tool_calls[key] = new
228
-
229
- def get(self, *, final: bool = False) -> ModelResponse:
230
- calls: list[ModelResponsePart] = []
231
- for c in self._delta_tool_calls.values():
232
- if c.name is not None and c.json_args is not None:
233
- calls.append(ToolCallPart.from_raw_args(c.name, c.json_args))
234
-
235
- return ModelResponse(calls, timestamp=self._timestamp)
236
-
237
- def usage(self) -> result.Usage:
238
- return _estimate_usage([self.get()])
192
+ delta_tool_calls = item
193
+ for dtc_index, delta_tool_call in delta_tool_calls.items():
194
+ if delta_tool_call.json_args:
195
+ response_tokens = _estimate_string_tokens(delta_tool_call.json_args)
196
+ self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
197
+ maybe_event = self._parts_manager.handle_tool_call_delta(
198
+ vendor_part_id=dtc_index,
199
+ tool_name=delta_tool_call.name,
200
+ args=delta_tool_call.json_args,
201
+ tool_call_id=None,
202
+ )
203
+ if maybe_event is not None:
204
+ yield maybe_event
239
205
 
240
206
  def timestamp(self) -> datetime:
241
207
  return self._timestamp
242
208
 
243
209
 
244
- def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
210
+ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
245
211
  """Very rough guesstimate of the token usage associated with a series of messages.
246
212
 
247
213
  This is designed to be used solely to give plausible numbers for testing!
@@ -253,28 +219,30 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
253
219
  if isinstance(message, ModelRequest):
254
220
  for part in message.parts:
255
221
  if isinstance(part, (SystemPromptPart, UserPromptPart)):
256
- request_tokens += _estimate_string_usage(part.content)
222
+ request_tokens += _estimate_string_tokens(part.content)
257
223
  elif isinstance(part, ToolReturnPart):
258
- request_tokens += _estimate_string_usage(part.model_response_str())
224
+ request_tokens += _estimate_string_tokens(part.model_response_str())
259
225
  elif isinstance(part, RetryPromptPart):
260
- request_tokens += _estimate_string_usage(part.model_response())
226
+ request_tokens += _estimate_string_tokens(part.model_response())
261
227
  else:
262
228
  assert_never(part)
263
229
  elif isinstance(message, ModelResponse):
264
230
  for part in message.parts:
265
231
  if isinstance(part, TextPart):
266
- response_tokens += _estimate_string_usage(part.content)
232
+ response_tokens += _estimate_string_tokens(part.content)
267
233
  elif isinstance(part, ToolCallPart):
268
234
  call = part
269
- response_tokens += 1 + _estimate_string_usage(call.args_as_json_str())
235
+ response_tokens += 1 + _estimate_string_tokens(call.args_as_json_str())
270
236
  else:
271
237
  assert_never(part)
272
238
  else:
273
239
  assert_never(message)
274
- return result.Usage(
240
+ return usage.Usage(
275
241
  request_tokens=request_tokens, response_tokens=response_tokens, total_tokens=request_tokens + response_tokens
276
242
  )
277
243
 
278
244
 
279
- def _estimate_string_usage(content: str) -> int:
280
- return len(re.split(r'[\s",.:]+', content))
245
+ def _estimate_string_tokens(content: str) -> int:
246
+ if not content:
247
+ return 0
248
+ return len(re.split(r'[\s",.:]+', content.strip()))
@@ -2,24 +2,25 @@ from __future__ import annotations as _annotations
2
2
 
3
3
  import os
4
4
  import re
5
- from collections.abc import AsyncIterator, Iterable, Sequence
5
+ from collections.abc import AsyncIterator, Sequence
6
6
  from contextlib import asynccontextmanager
7
7
  from copy import deepcopy
8
8
  from dataclasses import dataclass, field
9
9
  from datetime import datetime
10
10
  from typing import Annotated, Any, Literal, Protocol, Union
11
+ from uuid import uuid4
11
12
 
12
13
  import pydantic
13
- import pydantic_core
14
14
  from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
15
- from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never
15
+ from typing_extensions import NotRequired, TypedDict, assert_never
16
16
 
17
- from .. import UnexpectedModelBehavior, _utils, exceptions, result
17
+ from .. import UnexpectedModelBehavior, _utils, exceptions, usage
18
18
  from ..messages import (
19
19
  ModelMessage,
20
20
  ModelRequest,
21
21
  ModelResponse,
22
22
  ModelResponsePart,
23
+ ModelResponseStreamEvent,
23
24
  RetryPromptPart,
24
25
  SystemPromptPart,
25
26
  TextPart,
@@ -31,10 +32,8 @@ from ..settings import ModelSettings
31
32
  from ..tools import ToolDefinition
32
33
  from . import (
33
34
  AgentModel,
34
- EitherStreamedResponse,
35
35
  Model,
36
- StreamStructuredResponse,
37
- StreamTextResponse,
36
+ StreamedResponse,
38
37
  cached_async_http_client,
39
38
  check_allow_model_requests,
40
39
  get_user_agent,
@@ -111,7 +110,7 @@ class GeminiModel(Model):
111
110
  )
112
111
 
113
112
  def name(self) -> str:
114
- return self.model_name
113
+ return f'google-gla:{self.model_name}'
115
114
 
116
115
 
117
116
  class AuthProtocol(Protocol):
@@ -171,7 +170,7 @@ class GeminiAgentModel(AgentModel):
171
170
 
172
171
  async def request(
173
172
  self, messages: list[ModelMessage], model_settings: ModelSettings | None
174
- ) -> tuple[ModelResponse, result.Usage]:
173
+ ) -> tuple[ModelResponse, usage.Usage]:
175
174
  async with self._make_request(messages, False, model_settings) as http_response:
176
175
  response = _gemini_response_ta.validate_json(await http_response.aread())
177
176
  return self._process_response(response), _metadata_as_usage(response)
@@ -179,7 +178,7 @@ class GeminiAgentModel(AgentModel):
179
178
  @asynccontextmanager
180
179
  async def request_stream(
181
180
  self, messages: list[ModelMessage], model_settings: ModelSettings | None
182
- ) -> AsyncIterator[EitherStreamedResponse]:
181
+ ) -> AsyncIterator[StreamedResponse]:
183
182
  async with self._make_request(messages, True, model_settings) as http_response:
184
183
  yield await self._process_streamed_response(http_response)
185
184
 
@@ -238,7 +237,7 @@ class GeminiAgentModel(AgentModel):
238
237
  return _process_response_from_parts(parts)
239
238
 
240
239
  @staticmethod
241
- async def _process_streamed_response(http_response: HTTPResponse) -> EitherStreamedResponse:
240
+ async def _process_streamed_response(http_response: HTTPResponse) -> StreamedResponse:
242
241
  """Process a streamed response, and prepare a streaming response to return."""
243
242
  aiter_bytes = http_response.aiter_bytes()
244
243
  start_response: _GeminiResponse | None = None
@@ -259,11 +258,7 @@ class GeminiAgentModel(AgentModel):
259
258
  if start_response is None:
260
259
  raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
261
260
 
262
- # TODO: Update this once we rework stream responses to be more flexible
263
- if _extract_response_parts(start_response).is_left():
264
- return GeminiStreamStructuredResponse(_content=content, _stream=aiter_bytes)
265
- else:
266
- return GeminiStreamTextResponse(_json_content=content, _stream=aiter_bytes)
261
+ return GeminiStreamedResponse(_content=content, _stream=aiter_bytes)
267
262
 
268
263
  @classmethod
269
264
  def _message_to_gemini_content(
@@ -302,86 +297,69 @@ class GeminiAgentModel(AgentModel):
302
297
 
303
298
 
304
299
  @dataclass
305
- class GeminiStreamTextResponse(StreamTextResponse):
306
- """Implementation of `StreamTextResponse` for the Gemini model."""
307
-
308
- _json_content: bytearray
309
- _stream: AsyncIterator[bytes]
310
- _position: int = 0
311
- _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
312
- _usage: result.Usage = field(default_factory=result.Usage, init=False)
313
-
314
- async def __anext__(self) -> None:
315
- chunk = await self._stream.__anext__()
316
- self._json_content.extend(chunk)
317
-
318
- def get(self, *, final: bool = False) -> Iterable[str]:
319
- if final:
320
- all_items = pydantic_core.from_json(self._json_content)
321
- new_items = all_items[self._position :]
322
- self._position = len(all_items)
323
- new_responses = _gemini_streamed_response_ta.validate_python(new_items)
324
- else:
325
- all_items = pydantic_core.from_json(self._json_content, allow_partial=True)
326
- new_items = all_items[self._position : -1]
327
- self._position = len(all_items) - 1
328
- new_responses = _gemini_streamed_response_ta.validate_python(
329
- new_items, experimental_allow_partial='trailing-strings'
330
- )
331
- for r in new_responses:
332
- self._usage += _metadata_as_usage(r)
333
- parts = r['candidates'][0]['content']['parts']
334
- if _all_text_parts(parts):
335
- for part in parts:
336
- yield part['text']
337
- else:
338
- raise UnexpectedModelBehavior(
339
- 'Streamed response with unexpected content, expected all parts to be text'
340
- )
341
-
342
- def usage(self) -> result.Usage:
343
- return self._usage
344
-
345
- def timestamp(self) -> datetime:
346
- return self._timestamp
347
-
348
-
349
- @dataclass
350
- class GeminiStreamStructuredResponse(StreamStructuredResponse):
351
- """Implementation of `StreamStructuredResponse` for the Gemini model."""
300
+ class GeminiStreamedResponse(StreamedResponse):
301
+ """Implementation of `StreamedResponse` for the Gemini model."""
352
302
 
353
303
  _content: bytearray
354
304
  _stream: AsyncIterator[bytes]
355
305
  _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
356
- _usage: result.Usage = field(default_factory=result.Usage, init=False)
357
-
358
- async def __anext__(self) -> None:
359
- chunk = await self._stream.__anext__()
360
- self._content.extend(chunk)
361
-
362
- def get(self, *, final: bool = False) -> ModelResponse:
363
- """Get the `ModelResponse` at this point.
364
306
 
365
- NOTE: It's not clear how the stream of responses should be combined because Gemini seems to always
366
- reply with a single response, when returning a structured data.
307
+ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
308
+ async for gemini_response in self._get_gemini_responses():
309
+ candidate = gemini_response['candidates'][0]
310
+ gemini_part: _GeminiPartUnion
311
+ for gemini_part in candidate['content']['parts']:
312
+ if 'text' in gemini_part:
313
+ # Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled
314
+ # amongst the tool call deltas
315
+ yield self._parts_manager.handle_text_delta(vendor_part_id=None, content=gemini_part['text'])
316
+
317
+ elif 'function_call' in gemini_part:
318
+ # Here, we assume all function_call parts are complete and don't have deltas.
319
+ # We do this by assigning a unique randomly generated "vendor_part_id".
320
+ # We need to confirm whether this is actually true, but if it isn't, we can still handle it properly
321
+ # it would just be a bit more complicated. And we'd need to confirm the intended semantics.
322
+ maybe_event = self._parts_manager.handle_tool_call_delta(
323
+ vendor_part_id=uuid4(),
324
+ tool_name=gemini_part['function_call']['name'],
325
+ args=gemini_part['function_call']['args'],
326
+ tool_call_id=None,
327
+ )
328
+ if maybe_event is not None:
329
+ yield maybe_event
330
+ else:
331
+ assert 'function_response' in gemini_part, f'Unexpected part: {gemini_part}'
332
+
333
+ async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]:
334
+ # This method exists to ensure we only yield completed items, so we don't need to worry about
335
+ # partial gemini responses, which would make everything more complicated
336
+
337
+ gemini_responses: list[_GeminiResponse] = []
338
+ current_gemini_response_index = 0
339
+ # Right now, there are some circumstances where we will have information that could be yielded sooner than it is
340
+ # But changing that would make things a lot more complicated.
341
+ async for chunk in self._stream:
342
+ self._content.extend(chunk)
343
+
344
+ gemini_responses = _gemini_streamed_response_ta.validate_json(
345
+ self._content,
346
+ experimental_allow_partial='trailing-strings',
347
+ )
367
348
 
368
- I'm therefore assuming that each part contains a complete tool call, and not trying to combine data from
369
- separate parts.
370
- """
371
- responses = _gemini_streamed_response_ta.validate_json(
372
- self._content,
373
- experimental_allow_partial='off' if final else 'trailing-strings',
374
- )
375
- combined_parts: list[_GeminiPartUnion] = []
376
- self._usage = result.Usage()
377
- for r in responses:
349
+ # The idea: yield only up to the latest response, which might still be partial.
350
+ # Note that if the latest response is complete, we could yield it immediately, but there's not a good
351
+ # allow_partial API to determine if the last item in the list is complete.
352
+ responses_to_yield = gemini_responses[:-1]
353
+ for r in responses_to_yield[current_gemini_response_index:]:
354
+ current_gemini_response_index += 1
355
+ self._usage += _metadata_as_usage(r)
356
+ yield r
357
+
358
+ # Now yield the final response, which should be complete
359
+ if gemini_responses:
360
+ r = gemini_responses[-1]
378
361
  self._usage += _metadata_as_usage(r)
379
- candidate = r['candidates'][0]
380
- combined_parts.extend(candidate['content']['parts'])
381
- return _process_response_from_parts(combined_parts, timestamp=self._timestamp)
382
-
383
- def usage(self) -> result.Usage:
384
- return self._usage
362
+ yield r
385
363
 
386
364
  def timestamp(self) -> datetime:
387
365
  return self._timestamp
@@ -458,9 +436,14 @@ def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: d
458
436
  items: list[ModelResponsePart] = []
459
437
  for part in parts:
460
438
  if 'text' in part:
461
- items.append(TextPart(part['text']))
439
+ items.append(TextPart(content=part['text']))
462
440
  elif 'function_call' in part:
463
- items.append(ToolCallPart.from_raw_args(part['function_call']['name'], part['function_call']['args']))
441
+ items.append(
442
+ ToolCallPart.from_raw_args(
443
+ tool_name=part['function_call']['name'],
444
+ args=part['function_call']['args'],
445
+ )
446
+ )
464
447
  elif 'function_response' in part:
465
448
  raise exceptions.UnexpectedModelBehavior(
466
449
  f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
@@ -575,35 +558,6 @@ class _GeminiResponse(TypedDict):
575
558
  prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
576
559
 
577
560
 
578
- # TODO: Delete the next three functions once we've reworked streams to be more flexible
579
- def _extract_response_parts(
580
- response: _GeminiResponse,
581
- ) -> _utils.Either[list[_GeminiFunctionCallPart], list[_GeminiTextPart]]:
582
- """Extract the parts of the response from the Gemini API.
583
-
584
- Returns Either a list of function calls (Either.left) or a list of text parts (Either.right).
585
- """
586
- if len(response['candidates']) != 1:
587
- raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
588
- parts = response['candidates'][0]['content']['parts']
589
- if _all_function_call_parts(parts):
590
- return _utils.Either(left=parts)
591
- elif _all_text_parts(parts):
592
- return _utils.Either(right=parts)
593
- else:
594
- raise exceptions.UnexpectedModelBehavior(
595
- f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {parts!r}'
596
- )
597
-
598
-
599
- def _all_function_call_parts(parts: list[_GeminiPartUnion]) -> TypeGuard[list[_GeminiFunctionCallPart]]:
600
- return all('function_call' in part for part in parts)
601
-
602
-
603
- def _all_text_parts(parts: list[_GeminiPartUnion]) -> TypeGuard[list[_GeminiTextPart]]:
604
- return all('text' in part for part in parts)
605
-
606
-
607
561
  class _GeminiCandidates(TypedDict):
608
562
  """See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""
609
563
 
@@ -630,14 +584,14 @@ class _GeminiUsageMetaData(TypedDict, total=False):
630
584
  cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]
631
585
 
632
586
 
633
- def _metadata_as_usage(response: _GeminiResponse) -> result.Usage:
587
+ def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
634
588
  metadata = response.get('usage_metadata')
635
589
  if metadata is None:
636
- return result.Usage()
590
+ return usage.Usage()
637
591
  details: dict[str, int] = {}
638
592
  if cached_content_token_count := metadata.get('cached_content_token_count'):
639
593
  details['cached_content_token_count'] = cached_content_token_count
640
- return result.Usage(
594
+ return usage.Usage(
641
595
  request_tokens=metadata.get('prompt_token_count', 0),
642
596
  response_tokens=metadata.get('candidates_token_count', 0),
643
597
  total_tokens=metadata.get('total_token_count', 0),
@@ -693,7 +647,7 @@ class _GeminiJsonSchema:
693
647
 
694
648
  def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
695
649
  schema.pop('title', None)
696
- default = schema.pop('default', _utils.UNSET)
650
+ schema.pop('default', None)
697
651
  if ref := schema.pop('$ref', None):
698
652
  # noinspection PyTypeChecker
699
653
  key = re.sub(r'^#/\$defs/', '', ref)
@@ -708,11 +662,12 @@ class _GeminiJsonSchema:
708
662
  if any_of := schema.get('anyOf'):
709
663
  for item_schema in any_of:
710
664
  self._simplify(item_schema, refs_stack)
711
- if len(any_of) == 2 and {'type': 'null'} in any_of and default is None:
665
+ if len(any_of) == 2 and {'type': 'null'} in any_of:
712
666
  for item_schema in any_of:
713
667
  if item_schema != {'type': 'null'}:
714
668
  schema.clear()
715
669
  schema.update(item_schema)
670
+ schema['nullable'] = True
716
671
  return
717
672
 
718
673
  type_ = schema.get('type')
@@ -721,6 +676,12 @@ class _GeminiJsonSchema:
721
676
  self._object(schema, refs_stack)
722
677
  elif type_ == 'array':
723
678
  return self._array(schema, refs_stack)
679
+ elif type_ == 'string' and (fmt := schema.pop('format', None)):
680
+ description = schema.get('description')
681
+ if description:
682
+ schema['description'] = f'{description} (format: {fmt})'
683
+ else:
684
+ schema['description'] = f'Format: {fmt}'
724
685
 
725
686
  def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
726
687
  ad_props = schema.pop('additionalProperties', None)