wishful 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,333 @@
1
+ """Type registration system for wishful.
2
+
3
+ Allows users to register complex types (Pydantic models, dataclasses, TypedDict)
4
+ that the LLM can use when generating code.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import inspect
10
+ from dataclasses import fields as dataclass_fields
11
+ from dataclasses import is_dataclass
12
+ from typing import Any, Callable, TypeVar, get_type_hints
13
+
14
+ T = TypeVar("T")
15
+
16
+
17
class TypeRegistry:
    """Registry of user-defined types and their LLM-facing schemas.

    Each registered class is rendered once, at registration time, into a
    source-like string (Pydantic model, dataclass, TypedDict, or raw source)
    and stored under the class's ``__name__``. A function name can also be
    mapped to a registered type name to declare that function's output type.

    Note: this module rebinds the name ``type`` to a decorator, so the methods
    below deliberately never call the ``type()`` builtin.
    """

    def __init__(self):
        """Create an empty registry."""
        # Map: type_name -> serialized type definition
        self._types: dict[str, str] = {}
        # Map: function_name -> type_name (for output_for mapping)
        self._function_outputs: dict[str, str] = {}

    def register(
        self, type_class: type, *, output_for: str | list[str] | None = None
    ) -> None:
        """Register a type and optionally associate it with function(s).

        Args:
            type_class: The class to serialize and store under its ``__name__``.
                Registering another class with the same name overwrites it.
            output_for: A function name, or a list of names, whose output type
                should be recorded as ``type_class``.
        """
        type_name = type_class.__name__
        self._types[type_name] = self._serialize_type(type_class)

        if output_for:
            functions = [output_for] if isinstance(output_for, str) else output_for
            for func_name in functions:
                self._function_outputs[func_name] = type_name

    def get_schema(self, type_name: str) -> str | None:
        """Return the serialized schema for ``type_name``, or ``None`` if unknown."""
        return self._types.get(type_name)

    def get_all_schemas(self) -> dict[str, str]:
        """Return a shallow copy of every registered ``name -> schema`` pair."""
        return self._types.copy()

    def get_output_type(self, function_name: str) -> str | None:
        """Return the type name registered as output of ``function_name``, if any."""
        return self._function_outputs.get(function_name)

    def clear(self) -> None:
        """Remove all registered types and function/output associations."""
        self._types.clear()
        self._function_outputs.clear()

    def _serialize_type(self, type_class: type) -> str:
        """Serialize a type to a string representation for the LLM.

        Tries, in order: Pydantic model, dataclass, TypedDict, the class's own
        source code, and finally a one-line stub.
        """
        if self._is_pydantic_model(type_class):
            return self._serialize_pydantic(type_class)

        if is_dataclass(type_class):
            return self._serialize_dataclass(type_class)

        if self._is_typed_dict(type_class):
            return self._serialize_typed_dict(type_class)

        # Fallback: get source code if available
        try:
            return inspect.getsource(type_class)
        except (OSError, TypeError):
            # Last resort: just return the class definition line
            return f"class {type_class.__name__}: ..."

    def _is_pydantic_model(self, type_class: type) -> bool:
        """Return True if the class looks like a Pydantic ``BaseModel``.

        The check is duck-typed (no pydantic import): either the class exposes
        ``model_fields`` or some class in its MRO has "BaseModel" in its name.
        """
        try:
            return hasattr(type_class, "model_fields") or any(
                "BaseModel" in base.__name__ for base in type_class.__mro__
            )
        except (AttributeError, TypeError):
            return False

    def _serialize_pydantic(self, model_class: type) -> str:
        """Serialize a Pydantic model (v2 or v1 field layout) to source code."""
        lines = [f"class {model_class.__name__}(BaseModel):"]

        # Add docstring if present
        if model_class.__doc__:
            lines.append(f'    """{model_class.__doc__.strip()}"""')

        if hasattr(model_class, "model_fields"):
            # Pydantic v2
            for field_name, field_info in model_class.model_fields.items():
                lines.append(self._render_pydantic_v2_field(field_name, field_info))
        elif hasattr(model_class, "__fields__"):
            # Pydantic v1
            for field_name, v1_field in model_class.__fields__.items():
                annotation = self._format_annotation(v1_field.outer_type_)
                if v1_field.required:
                    lines.append(f"    {field_name}: {annotation}")
                else:
                    lines.append(
                        f"    {field_name}: {annotation} = {v1_field.default!r}"
                    )

        if len(lines) == 1:
            # A header with no body is not valid Python; emit a placeholder.
            lines.append("    pass")
        return "\n".join(lines)

    def _render_pydantic_v2_field(self, field_name: str, field_info: Any) -> str:
        """Render one Pydantic v2 field as an indented class-body line."""
        annotation = self._format_annotation(field_info.annotation)
        default_factory = getattr(field_info, "default_factory", None)

        # Prefer the real ``is_required()``; fall back to inspecting defaults
        # for mocks or FieldInfo-like objects that lack it.
        is_required_attr = getattr(field_info, "is_required", None)
        if callable(is_required_attr):
            is_required = is_required_attr()
        else:
            is_required = (
                getattr(field_info, "default", None) is None
                and default_factory is None
            )

        # ``Field(...)`` is only emitted when there is something to put in it,
        # which avoids output such as ``Field()`` or ``Field(default=1, )``.
        field_args = self._build_field_args(field_info)
        prefix = f"    {field_name}: {annotation}"

        if is_required:
            return f"{prefix} = Field({field_args})" if field_args else prefix

        if default_factory is not None:
            # The factory itself cannot be rendered meaningfully; use a placeholder.
            parts = ["default_factory=..."]
            if field_args:
                parts.append(field_args)
            return f"{prefix} = Field({', '.join(parts)})"

        default_repr = repr(field_info.default)
        if field_args:
            return f"{prefix} = Field(default={default_repr}, {field_args})"
        return f"{prefix} = {default_repr}"

    def _build_field_args(self, field_info) -> str:
        """Build ``Field()`` arguments from field_info metadata and constraints.

        Returns a comma-separated ``name=value`` string (possibly empty). Each
        argument name appears at most once; the first source that provides a
        name wins (description, then ``metadata`` items, then direct attributes).
        """
        # Keyed by argument name so duplicates are detected exactly rather than
        # by substring search over the rendered text (a description containing
        # e.g. "default" must not suppress an ``lt=`` constraint).
        rendered: dict[str, str] = {}

        description = getattr(field_info, "description", None)
        if description:
            rendered["description"] = repr(description)

        # Pydantic v2 stores constraints as small marker objects in ``metadata``.
        metadata_attrs = {
            "MinLen": "min_length",
            "MaxLen": "max_length",
            "Gt": "gt",
            "Ge": "ge",
            "Lt": "lt",
            "Le": "le",
        }
        metadata = getattr(field_info, "metadata", None)
        if metadata:
            for meta_item in metadata:
                meta_class = meta_item.__class__.__name__
                attr_name = metadata_attrs.get(meta_class)
                if attr_name is not None:
                    if hasattr(meta_item, attr_name):
                        rendered.setdefault(
                            attr_name, f"{getattr(meta_item, attr_name)}"
                        )
                elif meta_class == "_PydanticGeneralMetadata":
                    # Handle pattern and other general metadata
                    pattern = getattr(meta_item, "pattern", None)
                    if pattern:
                        rendered.setdefault("pattern", repr(pattern))

        # Fallback: check direct attributes (Pydantic v1 or custom)
        for attr_name in (
            "min_length",
            "max_length",
            "gt",
            "ge",
            "lt",
            "le",
            "pattern",
        ):
            value = getattr(field_info, attr_name, None)
            if value is not None and attr_name not in rendered:
                rendered[attr_name] = repr(value)

        return ", ".join(f"{name}={value}" for name, value in rendered.items())

    def _serialize_dataclass(self, dc_class: type) -> str:
        """Serialize a dataclass to source code.

        Plain defaults are rendered with ``repr``; ``default_factory`` fields
        are rendered as ``field(default_factory=...)`` with a literal ellipsis.
        """
        # Local import: ``MISSING`` is the sentinel dataclasses uses for
        # "no default" / "no default_factory".
        from dataclasses import MISSING

        lines = ["@dataclass", f"class {dc_class.__name__}:"]

        # Add docstring if present
        if dc_class.__doc__:
            lines.append(f'    """{dc_class.__doc__.strip()}"""')

        for dc_field in dataclass_fields(dc_class):
            annotation = self._format_annotation(dc_field.type)
            prefix = f"    {dc_field.name}: {annotation}"
            if dc_field.default is not MISSING:
                lines.append(f"{prefix} = {dc_field.default!r}")
            elif dc_field.default_factory is not MISSING:
                lines.append(f"{prefix} = field(default_factory=...)")
            else:
                lines.append(prefix)

        if len(lines) == 2:
            # Decorator + header only: keep the rendered class syntactically valid.
            lines.append("    pass")
        return "\n".join(lines)

    def _is_typed_dict(self, type_class: type) -> bool:
        """Return True if the class has the attributes of a ``TypedDict``."""
        return hasattr(type_class, "__annotations__") and hasattr(
            type_class, "__total__"
        )

    def _serialize_typed_dict(self, td_class: type) -> str:
        """Serialize a TypedDict to source code."""
        lines = [f"class {td_class.__name__}(TypedDict):"]

        if td_class.__doc__:
            lines.append(f'    """{td_class.__doc__.strip()}"""')

        try:
            hints = get_type_hints(td_class)
        except (NameError, TypeError):
            # Unresolvable forward references: fall back to the raw
            # annotations instead of failing the whole registration.
            hints = dict(getattr(td_class, "__annotations__", {}))

        for field_name, field_type in hints.items():
            annotation = self._format_annotation(field_type)
            lines.append(f"    {field_name}: {annotation}")

        if len(lines) == 1:
            lines.append("    pass")
        return "\n".join(lines)

    def _format_annotation(self, annotation: Any) -> str:
        """Format a type annotation as a source-like string.

        Parameterized ``list``/``dict``/``tuple`` and unions are rendered
        recursively (``Optional[int]`` becomes ``int | None``). Other generics
        fall back to ``str(annotation)`` with the ``typing.`` prefix removed.
        """
        import types
        from typing import Union, get_args, get_origin

        # ``None.__class__`` is NoneType; the builtin ``type`` is shadowed here.
        if annotation is None or annotation is None.__class__:
            return "None"
        if annotation is Ellipsis:
            return "..."
        if isinstance(annotation, str):
            # Already-stringified annotation (e.g. postponed evaluation).
            return annotation.replace("typing.", "")

        # Generics must be inspected before ``__name__``: ``list[int].__name__``
        # is just "list", which would silently drop the type arguments.
        origin = get_origin(annotation)
        if origin is None:
            forward = getattr(annotation, "__forward_arg__", None)
            if isinstance(forward, str):
                return forward
            name = getattr(annotation, "__name__", None)
            if isinstance(name, str):
                return name
            return str(annotation).replace("typing.", "")

        args = get_args(annotation)

        # ``typing.Union[...]``/``Optional[...]`` and PEP 604 ``X | Y``.
        pep604_union = getattr(types, "UnionType", None)
        if origin is Union or (pep604_union is not None and origin is pep604_union):
            return " | ".join(self._format_annotation(arg) for arg in args)

        if origin is list:
            return f"list[{self._format_annotation(args[0])}]" if args else "list"
        if origin is dict:
            key_type = self._format_annotation(args[0]) if args else "Any"
            val_type = self._format_annotation(args[1]) if len(args) > 1 else "Any"
            return f"dict[{key_type}, {val_type}]"
        if origin is tuple:
            if not args:
                return "tuple"
            arg_strs = ", ".join(self._format_annotation(arg) for arg in args)
            return f"tuple[{arg_strs}]"

        return str(annotation).replace("typing.", "")
273
+
274
+
275
# Global registry instance: the single module-level TypeRegistry that the
# ``type`` decorator and the module-level helper functions below operate on.
_registry = TypeRegistry()
277
+
278
+
279
def type(
    cls: type[T] | None = None, *, output_for: str | list[str] | None = None
) -> type[T] | Callable[[type[T]], type[T]]:
    """Class decorator that records a type in the wishful type registry.

    Works both bare and with keyword arguments; the decorated class is
    always returned unchanged.

    Usage:
        @wishful.type
        class UserProfile(BaseModel):
            name: str
            email: str

        # Declare the type as the output of one function
        @wishful.type(output_for='create_user')
        class UserProfile(BaseModel):
            name: str
            email: str

        # ...or of several functions
        @wishful.type(output_for=['create_user', 'update_user'])
        class UserProfile(BaseModel):
            name: str
            email: str
    """

    def _register(type_class: type[T]) -> type[T]:
        # Registration is the only side effect; hand the class straight back.
        _registry.register(type_class, output_for=output_for)
        return type_class

    # ``cls`` is None for the parenthesised form (@wishful.type(...)), in which
    # case the caller needs the decorator itself; otherwise apply it right away.
    return _register if cls is None else _register(cls)
314
+
315
+
316
def get_type_schema(type_name: str) -> str | None:
    """Look up the serialized schema stored under ``type_name``.

    Delegates to the module-level registry's ``get_schema``.
    """
    schema = _registry.get_schema(type_name)
    return schema
319
+
320
+
321
def get_all_type_schemas() -> dict[str, str]:
    """Return the ``type name -> schema`` mapping from the module-level registry.

    Delegates to the registry's ``get_all_schemas``.
    """
    schemas = _registry.get_all_schemas()
    return schemas
324
+
325
+
326
def get_output_type_for_function(function_name: str) -> str | None:
    """Look up the output type name registered for ``function_name``.

    Delegates to the module-level registry's ``get_output_type``.
    """
    output_type = _registry.get_output_type(function_name)
    return output_type
329
+
330
+
331
def clear_type_registry() -> None:
    """Reset the module-level registry by calling its ``clear`` (handy in tests)."""
    registry = _registry
    registry.clear()
wishful/ui.py ADDED
@@ -0,0 +1,26 @@
1
+ from __future__ import annotations
2
+
3
+ from contextlib import contextmanager
4
+ from typing import Iterator
5
+
6
+ from rich.console import Console
7
+ from rich.progress import Progress, SpinnerColumn, TextColumn
8
+
9
+ from wishful.config import settings
10
+
11
# Module-level Rich console; ``spinner`` below passes it to ``Progress``.
_console = Console()
12
+
13
+
14
@contextmanager
def spinner(message: str) -> Iterator[None]:
    """Show a transient terminal spinner labelled ``message`` while the body runs.

    When ``settings.spinner`` is falsy the context manager does nothing and
    simply runs the body.

    Args:
        message: Text displayed next to the spinner.
    """
    if not settings.spinner:
        yield
        return

    # Render the label through the task description rather than passing
    # ``message`` as the column's format template: ``TextColumn`` applies
    # ``str.format`` to its template, so a literal "{" or "}" in the message
    # would otherwise raise while the progress display is rendered.
    columns = (SpinnerColumn(), TextColumn("{task.description}"))
    with Progress(*columns, console=_console, transient=True) as progress:
        task_id = progress.add_task(message, total=None)
        try:
            yield
        finally:
            # Runs on both success and failure of the body.
            progress.update(task_id, completed=1)
            progress.stop()