apple-foundation-models 0.1.9__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,144 @@
1
+ """
2
+ Pydantic compatibility utilities for applefoundationmodels.
3
+
4
+ Provides optional Pydantic integration for structured output generation.
5
+ """
6
+
7
+ from typing import Any, Dict, Union, TYPE_CHECKING
8
+
9
+ if TYPE_CHECKING:
10
+ from pydantic import BaseModel
11
+
12
+ # Try to import Pydantic, but don't fail if it's not installed
13
+ try:
14
+ from pydantic import BaseModel
15
+
16
+ PYDANTIC_AVAILABLE = True
17
+ except ImportError:
18
+ PYDANTIC_AVAILABLE = False
19
+ BaseModel = None # type: ignore
20
+
21
+
22
def require_pydantic() -> None:
    """
    Ensure Pydantic can be used, failing fast otherwise.

    Raises:
        ImportError: If Pydantic is not installed
    """
    if PYDANTIC_AVAILABLE:
        return
    raise ImportError(
        "Pydantic is not installed. Install it with: pip install pydantic>=2.0"
    )
33
+
34
+
35
def model_to_schema(model: "BaseModel") -> Dict[str, Any]:
    """
    Convert a Pydantic model to a JSON Schema dictionary.

    Args:
        model: Pydantic BaseModel class or instance

    Returns:
        JSON Schema dictionary

    Raises:
        ImportError: If Pydantic is not installed
        ValueError: If model is not a valid Pydantic model

    Example:
        >>> from pydantic import BaseModel
        >>> class Person(BaseModel):
        ...     name: str
        ...     age: int
        >>> schema = model_to_schema(Person)
        >>> print(schema)
        {'type': 'object', 'properties': {...}, 'required': [...]}
    """
    require_pydantic()

    # Duck-type on the Pydantic v2 schema export; present on both the
    # model class and its instances.
    schema_export = getattr(model, "model_json_schema", None)
    if schema_export is None:
        raise ValueError(
            f"Expected Pydantic BaseModel, got {type(model).__name__}. "
            "Make sure your model inherits from pydantic.BaseModel"
        )

    # Strip generator noise (title, empty $defs) before returning.
    return _clean_schema(schema_export())
73
+
74
+
75
+ def _clean_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
76
+ """
77
+ Clean up Pydantic-generated schema for better compatibility.
78
+
79
+ Removes unnecessary fields like 'title' and simplifies nested definitions.
80
+ """
81
+ cleaned = schema.copy()
82
+
83
+ # Remove title if present
84
+ cleaned.pop("title", None)
85
+
86
+ # Remove $defs if empty or inline them if simple
87
+ if "$defs" in cleaned:
88
+ defs = cleaned["$defs"]
89
+ if not defs:
90
+ del cleaned["$defs"]
91
+
92
+ return cleaned
93
+
94
+
95
def is_pydantic_model(obj: Any) -> bool:
    """
    Tell whether *obj* is a Pydantic model class or instance.

    Args:
        obj: Object to check

    Returns:
        True if obj is a Pydantic BaseModel class or instance
    """
    # Without Pydantic installed, nothing qualifies. Otherwise duck-type on
    # model_json_schema, which both model classes and instances expose.
    return PYDANTIC_AVAILABLE and hasattr(obj, "model_json_schema")
114
+
115
+
116
def normalize_schema(schema: Union[Dict[str, Any], "BaseModel"]) -> Dict[str, Any]:
    """
    Normalize schema input to a JSON Schema dict.

    Accepts either a JSON Schema dictionary (returned unchanged) or a
    Pydantic model (converted via its JSON Schema export).

    Args:
        schema: JSON Schema dict or Pydantic BaseModel

    Returns:
        JSON Schema dictionary

    Raises:
        TypeError: If schema is neither dict nor Pydantic model
        ImportError: If Pydantic is needed but not installed
    """
    # A plain dict is assumed to already be JSON Schema; pass through as-is.
    if isinstance(schema, dict):
        return schema

    # Anything else must be a Pydantic model, or the input is unusable.
    if not is_pydantic_model(schema):
        raise TypeError(
            f"Expected JSON Schema dict or Pydantic BaseModel, got {type(schema).__name__}"
        )

    return model_to_schema(schema)
@@ -0,0 +1,448 @@
1
+ """
2
+ Session API for applefoundationmodels Python bindings.
3
+
4
+ Provides session management, text generation, and async streaming support.
5
+ """
6
+
7
+ import asyncio
8
+ import json
9
+ from typing import (
10
+ Optional,
11
+ Dict,
12
+ Any,
13
+ AsyncIterator,
14
+ Callable,
15
+ Union,
16
+ TYPE_CHECKING,
17
+ List,
18
+ cast,
19
+ )
20
+ from queue import Queue, Empty
21
+ import threading
22
+
23
+ from . import _foundationmodels
24
+ from .base import ContextManagedResource
25
+ from .types import GenerationParams, NormalizedGenerationParams
26
+ from .pydantic_compat import normalize_schema
27
+ from .tools import extract_function_schema, attach_tool_metadata
28
+
29
+ if TYPE_CHECKING:
30
+ from pydantic import BaseModel
31
+
32
+
33
class Session(ContextManagedResource):
    """
    AI session for maintaining conversation state.

    Sessions maintain conversation history and can be configured with tools
    and instructions. Use as a context manager for automatic cleanup.

    Usage:
        with client.create_session() as session:
            response = session.generate("Hello!")
            print(response)
    """

    def __init__(self, session_id: int, config: Optional[Dict[str, Any]] = None):
        """
        Create a Session instance.

        Note: Users should create sessions via Client.create_session()
        rather than calling this constructor directly.

        Args:
            session_id: The session ID (always 0 in simplified API)
            config: Optional session configuration
        """
        self._session_id = session_id
        self._closed = False
        # Tool callables registered via the @session.tool decorator, by name.
        self._tools: Dict[str, Callable] = {}
        self._tools_registered = False
        # Kept so _register_tools() can recreate the session with the same config.
        self._config = config
        # Transcript length captured at the start of the most recent generation;
        # last_generation_transcript slices the transcript from this index.
        self._last_transcript_length = 0

    def close(self) -> None:
        """
        Close the session.

        Marks the session closed so subsequent API calls raise RuntimeError.
        No native resources need to be released in the simplified API.
        """
        # BUG FIX: this previously assigned False, which made close() a
        # no-op and let a "closed" session keep serving requests, so
        # _check_closed() could never fire.
        self._closed = True

    def _check_closed(self) -> None:
        """Raise RuntimeError if the session has been closed."""
        if self._closed:
            raise RuntimeError("Session is closed")

    def _normalize_generation_params(
        self, temperature: Optional[float], max_tokens: Optional[int]
    ) -> NormalizedGenerationParams:
        """
        Normalize generation parameters with defaults.

        Args:
            temperature: Optional temperature value
            max_tokens: Optional max tokens value

        Returns:
            NormalizedGenerationParams with defaults applied
        """
        return NormalizedGenerationParams.from_optional(temperature, max_tokens)

    def _begin_generation(self) -> int:
        """
        Mark the beginning of a generation call.

        Returns:
            The current transcript length (boundary marker for this generation)
        """
        return len(self.transcript)

    def _end_generation(self, start_length: int) -> None:
        """
        Mark the end of a generation call.

        Args:
            start_length: The transcript length captured at generation start
        """
        self._last_transcript_length = start_length

    def generate(
        self,
        prompt: str,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        include_reasoning: Optional[bool] = None,
        seed: Optional[int] = None,
    ) -> str:
        """
        Generate text response for a prompt.

        Args:
            prompt: Input text prompt
            temperature: Sampling temperature (0.0-2.0, default: DEFAULT_TEMPERATURE)
            max_tokens: Maximum tokens to generate (default: DEFAULT_MAX_TOKENS)
            include_reasoning: Include reasoning steps (not supported)
            seed: Random seed for reproducibility (not supported)

        Returns:
            Generated text response

        Raises:
            RuntimeError: If session is closed
            GenerationError: If generation fails

        Example:
            >>> response = session.generate("What is Python?")
            >>> print(response)
        """
        self._check_closed()
        params = self._normalize_generation_params(temperature, max_tokens)
        start_length = self._begin_generation()

        try:
            return _foundationmodels.generate(
                prompt, params.temperature, params.max_tokens
            )
        finally:
            # Record the boundary even on failure so last_generation_transcript
            # reflects this attempt.
            self._end_generation(start_length)

    def generate_structured(
        self,
        prompt: str,
        schema: Union[Dict[str, Any], "BaseModel"],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Generate structured JSON output matching a schema.

        Args:
            prompt: Input text prompt
            schema: JSON schema dict or Pydantic BaseModel class
            temperature: Sampling temperature (0.0-2.0, default: DEFAULT_TEMPERATURE)
            max_tokens: Maximum tokens to generate (default: DEFAULT_MAX_TOKENS)

        Returns:
            Dictionary containing the parsed JSON matching the schema

        Raises:
            RuntimeError: If session is closed
            GenerationError: If generation fails
            JSONParseError: If schema or response is invalid JSON
            TypeError: If schema is neither dict nor Pydantic model
            ImportError: If Pydantic model provided but Pydantic not installed

        Example (JSON Schema):
            >>> schema = {
            ...     "type": "object",
            ...     "properties": {
            ...         "name": {"type": "string"},
            ...         "age": {"type": "integer"}
            ...     },
            ...     "required": ["name", "age"]
            ... }
            >>> result = session.generate_structured(
            ...     "Extract: Alice is 28",
            ...     schema=schema
            ... )
            >>> print(result)
            {'name': 'Alice', 'age': 28}

        Example (Pydantic):
            >>> from pydantic import BaseModel
            >>> class Person(BaseModel):
            ...     name: str
            ...     age: int
            >>> result = session.generate_structured(
            ...     "Extract: Alice is 28",
            ...     schema=Person
            ... )
            >>> person = Person(**result)  # Parse directly into Pydantic model
            >>> print(person.name, person.age)
            Alice 28
        """
        self._check_closed()
        params = self._normalize_generation_params(temperature, max_tokens)
        # Accepts either a raw JSON Schema dict or a Pydantic model.
        json_schema = normalize_schema(schema)
        start_length = self._begin_generation()

        try:
            return _foundationmodels.generate_structured(
                prompt, json_schema, params.temperature, params.max_tokens
            )
        finally:
            self._end_generation(start_length)

    async def generate_stream(
        self,
        prompt: str,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        include_reasoning: Optional[bool] = None,
        seed: Optional[int] = None,
    ) -> AsyncIterator[str]:
        """
        Generate text response with async streaming.

        Bridges the synchronous native callback API to an async iterator by
        running the native stream on a daemon thread and handing chunks over
        through a thread-safe queue. A None chunk signals end of stream; an
        Exception placed on the queue is re-raised to the consumer.

        Args:
            prompt: Input text prompt
            temperature: Sampling temperature (0.0-2.0, default: DEFAULT_TEMPERATURE)
            max_tokens: Maximum tokens to generate (default: DEFAULT_MAX_TOKENS)
            include_reasoning: Include reasoning steps (not supported)
            seed: Random seed (not supported)

        Yields:
            Text chunks as they are generated

        Example:
            >>> async for chunk in session.generate_stream("Tell me a story"):
            ...     print(chunk, end='', flush=True)
        """
        self._check_closed()
        params = self._normalize_generation_params(temperature, max_tokens)
        start_length = self._begin_generation()

        try:
            # Queue bridges the sync native callback and this async iterator.
            queue: Queue = Queue()

            def callback(chunk: Optional[str]) -> None:
                queue.put(chunk)

            def run_stream():
                try:
                    _foundationmodels.generate_stream(
                        prompt, callback, params.temperature, params.max_tokens
                    )
                except Exception as e:
                    # Ferry the error across the thread boundary; re-raised below.
                    queue.put(e)

            thread = threading.Thread(target=run_stream, daemon=True)
            thread.start()

            while True:
                # Yield control to the event loop between polls.
                await asyncio.sleep(0)

                try:
                    # NOTE(review): this blocking get can stall the event loop
                    # for up to 0.1s between chunks — consider run_in_executor
                    # if other coroutines must stay responsive.
                    chunk = queue.get(timeout=0.1)
                except Empty:
                    continue

                if isinstance(chunk, Exception):
                    raise chunk

                if chunk is None:  # End of stream
                    break

                yield chunk

            thread.join(timeout=1.0)
        finally:
            self._end_generation(start_length)

    def get_history(self) -> List[Dict[str, Any]]:
        """
        Get conversation history.

        Returns:
            List of message dictionaries with 'role' and 'content' keys

        Example:
            >>> history = session.get_history()
            >>> for msg in history:
            ...     print(f"{msg['role']}: {msg['content']}")
        """
        self._check_closed()
        return _foundationmodels.get_history()

    def clear_history(self) -> None:
        """
        Clear conversation history.

        Removes all messages from the session while keeping the session active.
        Also resets the last-generation boundary marker.
        """
        self._check_closed()
        _foundationmodels.clear_history()
        self._last_transcript_length = 0

    def add_message(self, role: str, content: str) -> None:
        """
        Add a message to conversation history.

        Forwards to the native add_message implementation.

        Args:
            role: Message role ('user', 'assistant', 'system')
            content: Message content
        """
        self._check_closed()
        _foundationmodels.add_message(role, content)

    def tool(
        self,
        description: Optional[str] = None,
        name: Optional[str] = None,
    ) -> Callable[[Callable], Callable]:
        """
        Decorator to register a function as a tool for this session.

        The function's signature and docstring are used to automatically
        generate a JSON schema for the tool's parameters.

        Args:
            description: Optional tool description (uses docstring if not provided)
            name: Optional tool name (uses function name if not provided)

        Returns:
            Decorator function

        Note:
            Tool output size limits:
            - Initial buffer: 16KB
            - Maximum size: 1MB (automatically retried with larger buffers)
            - Tools returning outputs larger than 1MB will raise an error
            - For large outputs, consider returning references or summaries

        Example:
            @session.tool(description="Get current weather")
            def get_weather(location: str, units: str = "celsius") -> str:
                '''Get weather for a location.'''
                return f"Weather in {location}: 20°{units[0].upper()}"

            response = session.generate("What's the weather in Paris?")
        """

        def decorator(func: Callable) -> Callable:
            # Extract schema and attach metadata using shared helpers.
            schema = extract_function_schema(func)
            final_schema = attach_tool_metadata(func, schema, description, name)

            # Store the callable under its resolved tool name and (re)register
            # the full tool set so the session picks it up immediately.
            tool_name = final_schema["name"]
            self._tools[tool_name] = func
            self._register_tools()

            return func

        return decorator

    def _register_tools(self) -> None:
        """
        Register all tools with the FFI layer.

        Called automatically when tools are added via decorator.
        Recreates the session with tools enabled.
        """
        if not self._tools:
            return

        # Register tools with C FFI.
        _foundationmodels.register_tools(self._tools)
        self._tools_registered = True

        # Recreate session with tools enabled. This is necessary because the
        # session must be created with tools for FoundationModels to know
        # about them.
        config = self._config or {}
        _foundationmodels.create_session(config)

    @property
    def transcript(self) -> List[Dict[str, Any]]:
        """
        Get the session transcript including tool calls.

        Returns a list of transcript entries showing the full conversation
        history including instructions, prompts, tool calls, tool outputs,
        and responses.

        Returns:
            List of transcript entry dictionaries with keys:
            - type: Entry type ('instructions', 'prompt', 'response', 'tool_call', 'tool_output')
            - content: Entry content (for text entries)
            - tool_name: Tool name (for tool_call entries)
            - tool_id: Tool call ID (for tool_call and tool_output entries)
            - arguments: Tool arguments as JSON string (for tool_call entries)

        Example:
            >>> transcript = session.transcript
            >>> for entry in transcript:
            ...     print(f"{entry['type']}: {entry.get('content', '')}")
        """
        self._check_closed()
        # Explicit cast so type checkers see the concrete return type.
        return cast(List[Dict[str, Any]], _foundationmodels.get_transcript())

    @property
    def last_generation_transcript(self) -> List[Dict[str, Any]]:
        """
        Get transcript entries from the most recent generate() call only.

        Unlike the `transcript` property which returns the full accumulated
        history, this returns only the entries added during the last
        generation call (generate(), generate_structured(), or
        generate_stream()).

        Returns:
            List of transcript entries from the last generate() call.
            Returns empty list if no generation has been performed yet.

        Example:
            >>> response1 = session.generate("What is 2 + 2?")
            >>> entries1 = session.last_generation_transcript
            >>> print(f"First call: {len(entries1)} entries")

            >>> response2 = session.generate("What is 5 + 7?")
            >>> entries2 = session.last_generation_transcript
            >>> print(f"Second call: {len(entries2)} entries (only from second call)")
        """
        self._check_closed()
        full_transcript = self.transcript
        return full_transcript[self._last_transcript_length :]