apple-foundation-models 0.2.2__cp312-cp312-macosx_26_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,659 @@
+"""
+Base Session implementation for applefoundationmodels Python bindings.
+
+Provides shared logic for both sync and async sessions.
+"""
+
+import asyncio
+import platform
+import threading
+import logging
+from abc import ABC, abstractmethod
+from contextlib import contextmanager, asynccontextmanager
+from dataclasses import dataclass
+from functools import lru_cache
+from typing import (
+    Optional,
+    Dict,
+    Any,
+    Callable,
+    List,
+    Union,
+    cast,
+    Generator,
+    AsyncGenerator,
+    Awaitable,
+    ClassVar,
+    Literal,
+)
+
+from .base import ContextManagedResource
+from .types import (
+    GenerationResponse,
+    StreamChunk,
+    ToolCall,
+    Function,
+    Availability,
+)
+from .constants import DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS
+from .exceptions import NotAvailableError
+
+logger = logging.getLogger(__name__)
+
+StreamQueueItem = Union[str, None, Exception]
+
+
+@lru_cache(maxsize=1)
+def get_foundationmodels():
+    """Return the cached _foundationmodels module."""
+    from . import _foundationmodels
+
+    return _foundationmodels
+
+
+class BaseSession(ContextManagedResource, ABC):
+    """
+    Base class for Session and AsyncSession with shared logic.
+
+    This class contains all the common functionality between the sync
+    and async session implementations to avoid duplication.
+    """
+
+    # Class-level flag to track if library has been initialized
+    _initialized: ClassVar[bool] = False
+
+    @dataclass
+    class _GenerationPlan:
+        mode: Literal["stream", "structured", "text"]
+        temperature: float
+        max_tokens: int
+
+    @dataclass
+    class _StreamQueueAdapter:
+        push: Callable[[StreamQueueItem], None]
+        get_sync: Callable[[], StreamQueueItem]
+        get_async: Optional[Callable[[], Awaitable[StreamQueueItem]]] = None
+
+    def __init__(
+        self,
+        instructions: Optional[str] = None,
+        tools: Optional[List[Callable]] = None,
+    ):
+        """
+        Create a base session instance.
+
+        Args:
+            instructions: Optional system instructions to guide AI behavior
+            tools: Optional list of tool functions to make available to the model
+
+        Raises:
+            InitializationError: If library initialization fails
+            NotAvailableError: If Apple Intelligence is not available
+            RuntimeError: If platform is not supported
+        """
+        # Validate platform and initialize library on first session creation
+        self._validate_platform()
+        self._initialize_library()
+
+        self._ffi = get_foundationmodels()
+        config = self._build_session_config(instructions, tools)
+        self._session_id = self._ffi.create_session(config)
+        self._closed = False
+        self._config = config
+        # Initialize to current transcript length to exclude any initial instructions
+        self._last_transcript_length = len(self.transcript)
+
+    @abstractmethod
+    def _call_ffi(self, func: Callable, *args, **kwargs) -> Any:
+        """
+        Execute an FFI call (sync or async depending on implementation).
+
+        This is the adapter method that subclasses must implement to handle
+        sync vs async execution of FFI calls.
+
+        Args:
+            func: The FFI function to call
+            *args: Positional arguments to pass to the function
+            **kwargs: Keyword arguments to pass to the function
+
+        Returns:
+            The result from the FFI call
+        """
+        pass
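
The concrete implementations of this adapter are not included in this diff. A minimal sketch of how the two subclasses could plausibly satisfy it, assuming the sync session invokes the FFI inline and the async session offloads to a worker thread (the mixin names here are hypothetical):

    import asyncio
    from typing import Any, Callable

    class _SyncFFIMixin:
        def _call_ffi(self, func: Callable, *args, **kwargs) -> Any:
            # Sync path: the blocking FFI call runs on the caller's thread.
            return func(*args, **kwargs)

    class _AsyncFFIMixin:
        async def _call_ffi(self, func: Callable, *args, **kwargs) -> Any:
            # Async path: run the blocking call in a worker thread so the
            # event loop stays responsive.
            return await asyncio.to_thread(func, *args, **kwargs)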
+
+    def _check_closed(self) -> None:
+        """
+        Raise error if session is closed.
+
+        Raises:
+            RuntimeError: If session has been closed
+        """
+        if self._closed:
+            raise RuntimeError("Session is closed")
+
+    def _apply_defaults(
+        self, temperature: Optional[float], max_tokens: Optional[int]
+    ) -> tuple[float, int]:
+        """
+        Apply default values to generation parameters.
+
+        Args:
+            temperature: Temperature value or None to use default
+            max_tokens: Max tokens value or None to use default
+
+        Returns:
+            Tuple of (temperature, max_tokens) with defaults applied
+        """
+        return (
+            temperature if temperature is not None else DEFAULT_TEMPERATURE,
+            max_tokens if max_tokens is not None else DEFAULT_MAX_TOKENS,
+        )
+
+    def _plan_generate_call(
+        self,
+        stream: bool,
+        schema: Optional[Union[Dict[str, Any], type]],
+        temperature: Optional[float],
+        max_tokens: Optional[int],
+    ) -> "BaseSession._GenerationPlan":
+        """Return normalized plan shared by Session and AsyncSession generate()."""
+        self._check_closed()
+        self._validate_generate_params(stream, schema)
+        temp, max_tok = self._apply_defaults(temperature, max_tokens)
+
+        if stream:
+            mode: Literal["stream", "structured", "text"] = "stream"
+        elif schema is not None:
+            mode = "structured"
+        else:
+            mode = "text"
+
+        return BaseSession._GenerationPlan(
+            mode=mode,
+            temperature=temp,
+            max_tokens=max_tok,
+        )
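
The plan centralizes closed-state checking, parameter validation, and defaulting so both frontends branch exactly once on `mode`. A sketch of the dispatch a generate() wrapper might perform; the `_generate_*` helpers are hypothetical names, not part of this diff:

    def generate(self, prompt, stream=False, schema=None,
                 temperature=None, max_tokens=None):
        plan = self._plan_generate_call(stream, schema, temperature, max_tokens)
        if plan.mode == "stream":
            return self._generate_stream(prompt, plan.temperature, plan.max_tokens)
        if plan.mode == "structured":
            return self._generate_structured(prompt, schema, plan.temperature, plan.max_tokens)
        return self._generate_text(prompt, plan.temperature, plan.max_tokens)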
+
+    def _begin_generation(self) -> int:
+        """
+        Mark the beginning of a generation call.
+
+        Returns:
+            The current transcript length (boundary marker for this generation)
+        """
+        return len(self.transcript)
+
+    def _end_generation(self, start_length: int) -> None:
+        """
+        Mark the end of a generation call.
+
+        Args:
+            start_length: The transcript length captured at generation start
+        """
+        self._last_transcript_length = start_length
+
+    @contextmanager
+    def _generation_context(self) -> Generator[int, None, None]:
+        """
+        Context manager for synchronous generation calls.
+
+        Handles:
+        - Marking generation start/end boundaries
+        - Automatic cleanup on exception
+        - Transcript length tracking
+
+        Yields:
+            start_length: Transcript length at generation start
+
+        Example:
+            >>> with self._generation_context() as start_length:
+            ...     text = _foundationmodels.generate(prompt, temp, max_tok)
+            ...     return self._build_generation_response(text, False, start_length)
+        """
+        start_length = self._begin_generation()
+        try:
+            yield start_length
+        except Exception:
+            self._end_generation(start_length)
+            raise
+
+    @asynccontextmanager
+    async def _async_generation_context(self) -> AsyncGenerator[int, None]:
+        """
+        Context manager for asynchronous generation calls.
+
+        Handles:
+        - Marking generation start/end boundaries
+        - Automatic cleanup on exception
+        - Transcript length tracking
+
+        Yields:
+            start_length: Transcript length at generation start
+
+        Example:
+            >>> async with self._async_generation_context() as start_length:
+            ...     text = await asyncio.to_thread(fm.generate, prompt, temp, max_tok)
+            ...     return self._build_generation_response(text, False, start_length)
+        """
+        start_length = self._begin_generation()
+        try:
+            yield start_length
+        except Exception:
+            self._end_generation(start_length)
+            raise
+
+    @abstractmethod
+    def _create_stream_queue_adapter(self) -> "BaseSession._StreamQueueAdapter":
+        """Create a stream queue adapter for the current session implementation."""
+        pass
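
A plausible shape for these adapters (a sketch, not the package's actual factories): a sync session can back the adapter with `queue.Queue`, while an async session needs a `push` that is safe to call from the background FFI thread:

    import asyncio
    import queue

    def _sync_adapter() -> "BaseSession._StreamQueueAdapter":
        q: "queue.Queue[StreamQueueItem]" = queue.Queue()
        return BaseSession._StreamQueueAdapter(push=q.put, get_sync=q.get)

    def _async_adapter(loop: asyncio.AbstractEventLoop) -> "BaseSession._StreamQueueAdapter":
        q: "asyncio.Queue[StreamQueueItem]" = asyncio.Queue()

        def push(item: StreamQueueItem) -> None:
            # Called from the streaming thread, so hand the item to the
            # event loop thread-safely.
            loop.call_soon_threadsafe(q.put_nowait, item)

        return BaseSession._StreamQueueAdapter(
            push=push,
            get_sync=q.get_nowait,  # unused on the async path
            get_async=q.get,
        )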
+
+    def _stream_chunks_impl(
+        self,
+        prompt: str,
+        temperature: float,
+        max_tokens: int,
+        adapter: "BaseSession._StreamQueueAdapter",
+    ) -> Generator[StreamChunk, None, None]:
+        """Shared synchronous streaming implementation."""
+        thread = self._start_stream_thread(prompt, temperature, max_tokens, adapter)
+        try:
+            yield from self._drain_stream_queue_sync(adapter)
+        finally:
+            self._wait_for_stream_thread(thread)
+
+    async def _stream_chunks_async_impl(
+        self,
+        prompt: str,
+        temperature: float,
+        max_tokens: int,
+        adapter: "BaseSession._StreamQueueAdapter",
+    ) -> AsyncGenerator[StreamChunk, None]:
+        """Shared asynchronous streaming implementation."""
+        if adapter.get_async is None:
+            raise RuntimeError("Async streaming requires an adapter with async support")
+
+        thread = self._start_stream_thread(prompt, temperature, max_tokens, adapter)
+        try:
+            async for chunk in self._drain_stream_queue_async(adapter.get_async):
+                yield chunk
+        finally:
+            await self._await_stream_thread(thread)
+
+    def _start_stream_thread(
+        self,
+        prompt: str,
+        temperature: float,
+        max_tokens: int,
+        adapter: "BaseSession._StreamQueueAdapter",
+    ) -> threading.Thread:
+        """Start background thread that drives the FFI stream."""
+
+        def run_stream() -> None:
+            try:
+                self._ffi.generate_stream(prompt, adapter.push, temperature, max_tokens)
+            except Exception as exc:  # pragma: no cover - defensive
+                try:
+                    adapter.push(exc)
+                except Exception:  # pragma: no cover - defensive
+                    logger.error(
+                        "Failed to propagate streaming exception", exc_info=True
+                    )
+
+        thread = threading.Thread(target=run_stream, daemon=True)
+        thread.start()
+        return thread
+
+    def _drain_stream_queue_sync(
+        self, adapter: "BaseSession._StreamQueueAdapter"
+    ) -> Generator[StreamChunk, None, None]:
+        """Yield chunks synchronously from the adapter."""
+        chunk_index = 0
+        while True:
+            item = adapter.get_sync()
+            chunk, done = self._convert_stream_item(item, chunk_index)
+            yield chunk
+            if done:
+                break
+            chunk_index += 1
+
+    async def _drain_stream_queue_async(
+        self, get_next: Callable[[], Awaitable[StreamQueueItem]]
+    ) -> AsyncGenerator[StreamChunk, None]:
+        """Yield chunks asynchronously using the provided getter."""
+        chunk_index = 0
+        while True:
+            item = await get_next()
+            chunk, done = self._convert_stream_item(item, chunk_index)
+            yield chunk
+            if done:
+                break
+            chunk_index += 1
+
+    def _convert_stream_item(
+        self, item: StreamQueueItem, chunk_index: int
+    ) -> tuple[StreamChunk, bool]:
+        """Convert a raw queue item into a StreamChunk plus completion flag."""
+        if isinstance(item, Exception):
+            raise item
+
+        if item is None:
+            return (
+                StreamChunk(content="", finish_reason="stop", index=chunk_index),
+                True,
+            )
+
+        return (
+            StreamChunk(content=item, finish_reason=None, index=chunk_index),
+            False,
+        )
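
Worked through the three queue item kinds, the mapping behaves as follows (illustrative session value):

    chunk, done = session._convert_stream_item("Hello", 0)
    # StreamChunk(content="Hello", finish_reason=None, index=0), done=False

    chunk, done = session._convert_stream_item(None, 1)
    # StreamChunk(content="", finish_reason="stop", index=1), done=True

    session._convert_stream_item(RuntimeError("stream failed"), 2)
    # raises RuntimeError into the consuming generator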
+
+    def _wait_for_stream_thread(
+        self, thread: threading.Thread, timeout: float = 5.0
+    ) -> None:
+        """Wait for the background streaming thread to finish."""
+        thread.join(timeout=timeout)
+        if thread.is_alive():
+            logger.warning(
+                "Streaming thread did not complete within %.1f seconds after stream end."
+                " Thread will be cleaned up as a daemon.",
+                timeout,
+            )
+
+    async def _await_stream_thread(
+        self, thread: threading.Thread, timeout: float = 5.0
+    ) -> None:
+        """Async variant of _wait_for_stream_thread."""
+        await asyncio.to_thread(thread.join, timeout)
+        if thread.is_alive():
+            logger.warning(
+                "Streaming thread did not complete within %.1f seconds after stream end."
+                " Thread will be cleaned up as a daemon.",
+                timeout,
+            )
+
+    def _extract_tool_calls_from_transcript(
+        self, transcript_entries: List[Dict[str, Any]]
+    ) -> Optional[List[ToolCall]]:
+        """
+        Extract tool calls from transcript entries.
+
+        Args:
+            transcript_entries: List of transcript entries to search
+
+        Returns:
+            List of ToolCall objects if any tool calls found, None otherwise
+        """
+        tool_calls = []
+        for entry in transcript_entries:
+            if entry.get("type") == "tool_call":
+                tool_call = ToolCall(
+                    id=entry.get("tool_id", ""),
+                    type="function",
+                    function=Function(
+                        name=entry.get("tool_name", ""),
+                        arguments=entry.get("arguments", "{}"),
+                    ),
+                )
+                tool_calls.append(tool_call)
+
+        return tool_calls if tool_calls else None
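
For example, given the entry shapes this method expects, a tool-using turn maps to ToolCall objects like so (illustrative values):

    entries = [
        {"type": "prompt", "content": "What's the weather in Paris?"},
        {"type": "tool_call", "tool_id": "call_1", "tool_name": "get_weather",
         "arguments": '{"city": "Paris"}'},
    ]
    calls = session._extract_tool_calls_from_transcript(entries)
    # [ToolCall(id="call_1", type="function",
    #           function=Function(name="get_weather",
    #                             arguments='{"city": "Paris"}'))]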
+
+    def _build_generation_response(
+        self,
+        content: Union[str, Dict[str, Any]],
+        is_structured: bool,
+        start_length: int,
+    ) -> GenerationResponse:
+        """
+        Build a GenerationResponse with tool call extraction.
+
+        This method centralizes the response building logic to avoid duplication
+        between text and structured generation.
+
+        Args:
+            content: The generated content (str for text, dict for structured)
+            is_structured: Whether this is structured output
+            start_length: The transcript length at generation start
+
+        Returns:
+            GenerationResponse with tool_calls and finish_reason populated
+        """
+        # Update the generation boundary marker
+        self._end_generation(start_length)
+
+        # Structured generation does not support tool calls
+        if is_structured:
+            return GenerationResponse(
+                content=content,
+                is_structured=True,
+                tool_calls=None,
+                finish_reason="stop",
+            )
+
+        # Extract tool calls from the generation transcript
+        tool_calls = self._extract_tool_calls_from_transcript(
+            self.last_generation_transcript
+        )
+
+        # Set finish reason based on whether tools were called
+        finish_reason = "tool_calls" if tool_calls else "stop"
+
+        return GenerationResponse(
+            content=content,
+            is_structured=False,
+            tool_calls=tool_calls,
+            finish_reason=finish_reason,
+        )
+
+    @property
+    def transcript(self) -> List[Dict[str, Any]]:
+        """
+        Get the session transcript including tool calls.
+
+        Returns a list of transcript entries showing the full conversation
+        history including instructions, prompts, tool calls, tool outputs,
+        and responses.
+
+        Returns:
+            List of transcript entry dictionaries with keys:
+            - type: Entry type ('instructions', 'prompt', 'response', 'tool_call', 'tool_output')
+            - content: Entry content (for text entries)
+            - tool_name: Tool name (for tool_call entries)
+            - tool_id: Tool call ID (for tool_call and tool_output entries)
+            - arguments: Tool arguments as JSON string (for tool_call entries)
+
+        Example:
+            >>> transcript = session.transcript
+            >>> for entry in transcript:
+            ...     print(f"{entry['type']}: {entry.get('content', '')}")
+        """
+        self._check_closed()
+        # Explicit cast to ensure type checkers see the correct return type
+        return cast(List[Dict[str, Any]], self._ffi.get_transcript())
+
+    @property
+    def last_generation_transcript(self) -> List[Dict[str, Any]]:
+        """
+        Get transcript entries from the most recent generate() call only.
+
+        Unlike the `transcript` property, which returns the full accumulated history,
+        this returns only the entries added during the last generation call
+        (generate(), generate_structured(), or generate_stream()).
+
+        This is useful when you need to inspect what happened during a specific
+        generation without worrying about accumulated history from previous calls.
+
+        Returns:
+            List of transcript entries from the last generate() call.
+            Returns an empty list if no generation has been performed yet.
+
+        Example:
+            >>> # First generation
+            >>> response1 = session.generate("What is 2 + 2?")
+            >>> entries1 = session.last_generation_transcript
+            >>> print(f"First call: {len(entries1)} entries")
+
+            >>> # Second generation on same session
+            >>> response2 = session.generate("What is 5 + 7?")
+            >>> entries2 = session.last_generation_transcript
+            >>> print(f"Second call: {len(entries2)} entries (only from second call)")
+        """
+        self._check_closed()
+        full_transcript = self.transcript
+        return full_transcript[self._last_transcript_length :]
+
+    def _validate_generate_params(
+        self,
+        stream: bool,
+        schema: Optional[Union[Dict[str, Any], type]],
+    ) -> None:
+        """
+        Validate generation parameters.
+
+        Args:
+            stream: Whether streaming is requested
+            schema: Schema if structured output is requested
+
+        Raises:
+            ValueError: If the parameter combination is invalid
+        """
+        if stream and schema is not None:
+            raise ValueError(
+                "Streaming is not supported with structured output (schema parameter)"
+            )
+
+    def _mark_closed(self) -> None:
+        """
+        Mark the session as closed.
+
+        This is used by both Session.close() and AsyncSession.close() to
+        set the closed flag.
+        """
+        self._closed = True
+
+    @staticmethod
+    def _validate_platform() -> None:
+        """
+        Validate platform requirements for Apple Intelligence.
+
+        Raises:
+            NotAvailableError: If the platform is not supported or the version is insufficient
+        """
+        # Check platform requirements
+        if platform.system() != "Darwin":
+            raise NotAvailableError(
+                "Apple Intelligence is only available on macOS. "
+                f"Current platform: {platform.system()}"
+            )
+
+        # Check macOS version
+        mac_ver = platform.mac_ver()[0]
+        if mac_ver:
+            try:
+                major_version = int(mac_ver.split(".")[0])
+                if major_version < 26:
+                    raise NotAvailableError(
+                        "Apple Intelligence requires macOS 26.0 or later. "
+                        f"Current version: {mac_ver}"
+                    )
+            except (ValueError, IndexError):
+                # If we can't parse the version, let it try anyway
+                pass
+
+    @staticmethod
+    def _initialize_library() -> None:
+        """
+        Initialize the FoundationModels library if not already initialized.
+
+        This is called automatically on first session creation.
+        """
+        if not BaseSession._initialized:
+            get_foundationmodels().init()
+            BaseSession._initialized = True
+
+    @staticmethod
+    def _build_session_config(
+        instructions: Optional[str],
+        tools: Optional[List[Callable]],
+    ) -> Optional[Dict[str, Any]]:
+        """
+        Build the session configuration dictionary and register tools.
+
+        Args:
+            instructions: Optional system instructions
+            tools: Optional list of tool functions to register
+
+        Returns:
+            Configuration dictionary, or None if empty
+        """
+        # Register tools if provided
+        if tools:
+            from .tools import register_tool_for_function
+
+            # Build tool dictionary with function objects
+            tool_dict = {}
+            for func in tools:
+                schema = register_tool_for_function(func)
+                tool_name = schema["name"]
+                tool_dict[tool_name] = func
+
+            # Register with FFI
+            get_foundationmodels().register_tools(tool_dict)
+
+        config = {}
+        if instructions is not None:
+            config["instructions"] = instructions
+        return config if config else None
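
Tying this together, a session created with tools registers each function under the name derived by register_tool_for_function. A hedged usage sketch (Session is the sync subclass implied by this module; the tool body is invented for illustration):

    def get_weather(city: str) -> str:
        """Return the current weather for a city."""
        return f"Sunny in {city}"

    session = Session(
        instructions="You are a helpful assistant.",
        tools=[get_weather],
    )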
+
+    @staticmethod
+    def check_availability() -> Availability:
+        """
+        Check Apple Intelligence availability on this device.
+
+        This is a static method that can be called without creating a session.
+
+        Returns:
+            Availability status enum value
+
+        Example:
+            >>> from applefoundationmodels import Session, Availability
+            >>> status = Session.check_availability()
+            >>> if status == Availability.AVAILABLE:
+            ...     print("Apple Intelligence is available!")
+        """
+        status = cast(int, get_foundationmodels().check_availability())
+        return Availability(status)
+
+    @staticmethod
+    def get_availability_reason() -> Optional[str]:
+        """
+        Get detailed availability status message.
+
+        Returns:
+            Detailed status description with actionable guidance,
+            or None if library not initialized
+        """
+        reason = cast(Optional[str], get_foundationmodels().get_availability_reason())
+        return reason
+
+    @staticmethod
+    def is_ready() -> bool:
+        """
+        Check if Apple Intelligence is ready for immediate use.
+
+        Returns:
+            True if ready for use, False otherwise
+        """
+        ready = cast(bool, get_foundationmodels().is_ready())
+        return ready
+
+    @staticmethod
+    def get_version() -> str:
+        """
+        Get library version string.
+
+        Returns:
+            Version string in format "major.minor.patch"
+        """
+        version = cast(str, get_foundationmodels().get_version())
+        return version
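
End-to-end, the static helpers and the session API compose as below. A sketch based on the docstring examples above; the exact generate() signature is assumed, and the context-manager usage follows from ContextManagedResource:

    from applefoundationmodels import Session, Availability

    if Session.check_availability() == Availability.AVAILABLE:
        with Session(instructions="Answer briefly.") as session:
            response = session.generate("What is 2 + 2?")
            print(response.content, response.finish_reason)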
@@ -0,0 +1,30 @@
+"""
+Constants for applefoundationmodels.
+
+Provides default values and presets for generation parameters.
+"""
+
+# Generation defaults
+DEFAULT_TEMPERATURE = 1.0
+DEFAULT_MAX_TOKENS = 1024
+
+# Temperature bounds
+MIN_TEMPERATURE = 0.0
+MAX_TEMPERATURE = 2.0
+
+
+class TemperaturePreset:
+    """
+    Common temperature presets for different use cases.
+
+    Temperature controls randomness in generation:
+    - Lower values (0.1-0.3): More deterministic, good for facts and precision
+    - Medium values (0.5-0.7): Balanced creativity and consistency
+    - Higher values (1.0-1.5): More creative and varied outputs
+    """
+
+    DETERMINISTIC = 0.1  # Very low randomness, highly consistent
+    FACTUAL = 0.3  # Low randomness, good for factual responses
+    BALANCED = 0.7  # Balanced creativity and consistency
+    CREATIVE = 1.0  # More creative responses
+    VERY_CREATIVE = 1.5  # High creativity and variety
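
A usage sketch for the presets (the generate() keyword arguments mirror the defaults handling in the session module; the import path is assumed):

    from applefoundationmodels import Session
    from applefoundationmodels.constants import TemperaturePreset

    with Session() as session:
        response = session.generate(
            "Summarize the plot of Hamlet.",
            temperature=TemperaturePreset.FACTUAL,
            max_tokens=256,
        )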