wishful 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wishful/__init__.py +99 -0
- wishful/__main__.py +72 -0
- wishful/cache/__init__.py +23 -0
- wishful/cache/manager.py +77 -0
- wishful/config.py +175 -0
- wishful/core/__init__.py +6 -0
- wishful/core/discovery.py +264 -0
- wishful/core/finder.py +77 -0
- wishful/core/loader.py +285 -0
- wishful/dynamic/__init__.py +8 -0
- wishful/dynamic/__init__.pyi +7 -0
- wishful/llm/__init__.py +5 -0
- wishful/llm/client.py +98 -0
- wishful/llm/prompts.py +74 -0
- wishful/logging.py +88 -0
- wishful/py.typed +0 -0
- wishful/safety/__init__.py +5 -0
- wishful/safety/validator.py +132 -0
- wishful/static/__init__.py +8 -0
- wishful/static/__init__.pyi +7 -0
- wishful/types/__init__.py +19 -0
- wishful/types/registry.py +333 -0
- wishful/ui.py +26 -0
- wishful-0.2.1.dist-info/METADATA +401 -0
- wishful-0.2.1.dist-info/RECORD +26 -0
- wishful-0.2.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,333 @@
|
|
|
1
|
+
"""Type registration system for wishful.
|
|
2
|
+
|
|
3
|
+
Allows users to register complex types (Pydantic models, dataclasses, TypedDict)
|
|
4
|
+
that the LLM can use when generating code.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import inspect
|
|
10
|
+
from dataclasses import fields as dataclass_fields
|
|
11
|
+
from dataclasses import is_dataclass
|
|
12
|
+
from typing import Any, Callable, TypeVar, get_type_hints
|
|
13
|
+
|
|
14
|
+
# Type variable tying the ``type`` decorator's input class to its return value,
# so decorating a class hands back that same class type.
T = TypeVar("T")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TypeRegistry:
    """Registry mapping user-defined type names to source-like schema text.

    Each registered class is serialized once, at registration time, into a
    Python-source-like snippet (see ``_serialize_type``) intended for the LLM
    prompt. Function names can additionally be mapped to the name of the type
    they should return (``output_for``).
    """

    # annotated-types constraint class name -> Field() keyword it maps to
    # (Pydantic v2 stores constraints as objects in ``FieldInfo.metadata``).
    _METADATA_CONSTRAINTS = {
        "MinLen": "min_length",
        "MaxLen": "max_length",
        "Gt": "gt",
        "Ge": "ge",
        "Lt": "lt",
        "Le": "le",
    }
    # Field() keywords that may also live directly on the field object
    # (Pydantic v1 / custom field objects).
    _CONSTRAINT_ATTRS = ("min_length", "max_length", "gt", "ge", "lt", "le", "pattern")

    def __init__(self):
        # type name (``cls.__name__``) -> serialized type definition
        self._types: dict[str, str] = {}
        # function name -> registered type name (populated via ``output_for``)
        self._function_outputs: dict[str, str] = {}

    def register(
        self, type_class: type, *, output_for: str | list[str] | None = None
    ) -> None:
        """Register ``type_class`` and optionally map function names to it.

        Args:
            type_class: Class to serialize and store, keyed by ``__name__``
                (a later class with the same name replaces the earlier one).
            output_for: A function name, or a list of names, whose output type
                is ``type_class``. Falsy values register no mapping.
        """
        schema = self._serialize_type(type_class)
        self._types[type_class.__name__] = schema

        if output_for:
            functions = [output_for] if isinstance(output_for, str) else output_for
            for func_name in functions:
                self._function_outputs[func_name] = type_class.__name__

    def get_schema(self, type_name: str) -> str | None:
        """Return the serialized schema for ``type_name``, or None if unknown."""
        return self._types.get(type_name)

    def get_all_schemas(self) -> dict[str, str]:
        """Return a shallow copy of all registered schemas, keyed by type name."""
        return self._types.copy()

    def get_output_type(self, function_name: str) -> str | None:
        """Return the type name registered as ``function_name``'s output, or None."""
        return self._function_outputs.get(function_name)

    def clear(self) -> None:
        """Remove all registered types and function-output mappings."""
        self._types.clear()
        self._function_outputs.clear()

    def _serialize_type(self, type_class: type) -> str:
        """Serialize a class into source-like text for the LLM.

        Tries, in order: Pydantic model, dataclass, TypedDict, the class's real
        source via ``inspect.getsource``, and finally a bare ``class X: ...``.
        """
        if self._is_pydantic_model(type_class):
            return self._serialize_pydantic(type_class)

        if is_dataclass(type_class):
            return self._serialize_dataclass(type_class)

        if self._is_typed_dict(type_class):
            return self._serialize_typed_dict(type_class)

        # Fallback: real source code, when it can be located.
        try:
            return inspect.getsource(type_class)
        except (OSError, TypeError):
            # Last resort: just the class definition line.
            return f"class {type_class.__name__}: ..."

    def _is_pydantic_model(self, type_class: type) -> bool:
        """Heuristically detect a Pydantic model.

        True when the class has ``model_fields`` (Pydantic v2) or any class in
        its MRO has ``"BaseModel"`` in its name. This is a name-based check, not
        an ``issubclass`` test, so Pydantic need not be importable.
        """
        try:
            return hasattr(type_class, "model_fields") or any(
                "BaseModel" in base.__name__ for base in type_class.__mro__
            )
        except (AttributeError, TypeError):
            return False

    def _serialize_pydantic(self, model_class: type) -> str:
        """Serialize a Pydantic model (v2 via ``model_fields``, v1 via ``__fields__``)."""
        lines = [f"class {model_class.__name__}(BaseModel):"]

        if model_class.__doc__:
            lines.append(f'    """{model_class.__doc__.strip()}"""')

        if hasattr(model_class, "model_fields"):
            # Pydantic v2
            for field_name, field_info in model_class.model_fields.items():
                lines.append(self._pydantic_v2_field_line(field_name, field_info))
        elif hasattr(model_class, "__fields__"):
            # Pydantic v1
            for field_name, field in model_class.__fields__.items():
                annotation = self._format_annotation(field.outer_type_)
                if field.required:
                    lines.append(f"    {field_name}: {annotation}")
                else:
                    lines.append(f"    {field_name}: {annotation} = {field.default!r}")

        return "\n".join(lines)

    def _pydantic_v2_field_line(self, field_name: str, field_info: Any) -> str:
        """Render one Pydantic v2 field as a single indented declaration line."""
        annotation = self._format_annotation(field_info.annotation)
        declaration = f"    {field_name}: {annotation}"

        # ``is_required()`` is the Pydantic v2 API; the attribute-based fallback
        # keeps this working with mocks / stand-ins that do not provide it.
        if callable(getattr(field_info, "is_required", None)):
            is_required = field_info.is_required()
        else:
            is_required = (
                field_info.default is None and field_info.default_factory is None
            )

        # Field() usage shows up as metadata entries or constraint attributes.
        has_field_metadata = hasattr(field_info, "metadata") and field_info.metadata
        has_constraints = any(
            getattr(field_info, attr, None) is not None
            for attr in ("description",) + self._CONSTRAINT_ATTRS
        )
        wants_field_call = bool(has_field_metadata or has_constraints)

        if is_required:
            field_args = self._build_field_args(field_info) if wants_field_call else ""
            # Only emit Field(...) when there is something to put inside it.
            return f"{declaration} = Field({field_args})" if field_args else declaration

        if field_info.default_factory is not None:
            # The factory is an arbitrary callable with no faithful source form,
            # so it is shown as a ``...`` placeholder.
            factory_args = self._build_field_args(field_info)
            call_args = ", ".join(a for a in ("default_factory=...", factory_args) if a)
            return f"{declaration} = Field({call_args})"

        default_repr = repr(field_info.default)
        field_args = self._build_field_args(field_info) if wants_field_call else ""
        if field_args:
            return f"{declaration} = Field(default={default_repr}, {field_args})"
        return f"{declaration} = {default_repr}"

    def _build_field_args(self, field_info: Any) -> str:
        """Build the keyword-argument text for a ``Field(...)`` call.

        Collects ``description``, constraints found in ``field_info.metadata``
        (Pydantic v2 annotated-types objects, matched by class name) and
        constraint attributes set directly on ``field_info`` (Pydantic v1 /
        custom). Each keyword is emitted at most once; metadata wins.

        Returns:
            Comma-separated ``name=repr(value)`` pairs, or ``""`` if none apply.
        """
        args: list[str] = []
        seen: set[str] = set()

        def add(name: str, value: Any) -> None:
            # Track emitted keyword names exactly; a substring test over the
            # joined text would wrongly match e.g. "gt"/"le" inside "length".
            if name not in seen:
                seen.add(name)
                args.append(f"{name}={value!r}")

        description = getattr(field_info, "description", None)
        if description:
            add("description", description)

        for meta_item in getattr(field_info, "metadata", None) or ():
            meta_class = meta_item.__class__.__name__
            constraint = self._METADATA_CONSTRAINTS.get(meta_class)
            if constraint is not None:
                if hasattr(meta_item, constraint):
                    add(constraint, getattr(meta_item, constraint))
            elif meta_class == "_PydanticGeneralMetadata":
                # Carries ``pattern`` (and other general constraints).
                pattern = getattr(meta_item, "pattern", None)
                if pattern:
                    add("pattern", pattern)

        # Fallback: constraints stored as plain attributes.
        for attr in self._CONSTRAINT_ATTRS:
            value = getattr(field_info, attr, None)
            if value is not None:
                add(attr, value)

        return ", ".join(args)

    def _serialize_dataclass(self, dc_class: type) -> str:
        """Serialize a dataclass to source-like text.

        Fields with a plain default show ``= repr(default)``; fields with a
        ``default_factory`` show a ``field(default_factory=...)`` placeholder.
        """
        from dataclasses import MISSING

        lines = ["@dataclass", f"class {dc_class.__name__}:"]

        if dc_class.__doc__:
            lines.append(f'    """{dc_class.__doc__.strip()}"""')

        for dc_field in dataclass_fields(dc_class):
            annotation = self._format_annotation(dc_field.type)
            # Unset default / default_factory are the ``MISSING`` sentinel.
            if dc_field.default is not MISSING:
                lines.append(f"    {dc_field.name}: {annotation} = {dc_field.default!r}")
            elif dc_field.default_factory is not MISSING:
                lines.append(
                    f"    {dc_field.name}: {annotation} = field(default_factory=...)"
                )
            else:
                lines.append(f"    {dc_field.name}: {annotation}")

        return "\n".join(lines)

    def _is_typed_dict(self, type_class: type) -> bool:
        """Detect a TypedDict by its ``__annotations__`` and ``__total__`` attributes."""
        try:
            return hasattr(type_class, "__annotations__") and hasattr(
                type_class, "__total__"
            )
        except AttributeError:
            return False

    def _serialize_typed_dict(self, td_class: type) -> str:
        """Serialize a TypedDict to source-like text."""
        lines = [f"class {td_class.__name__}(TypedDict):"]

        if td_class.__doc__:
            lines.append(f'    """{td_class.__doc__.strip()}"""')

        try:
            hints = get_type_hints(td_class)
        except (NameError, TypeError):
            # Unresolvable forward references: fall back to the raw annotations
            # rather than failing registration.
            hints = dict(getattr(td_class, "__annotations__", {}))

        for field_name, field_type in hints.items():
            annotation = self._format_annotation(field_type)
            lines.append(f"    {field_name}: {annotation}")

        return "\n".join(lines)

    def _format_annotation(self, annotation: Any) -> str:
        """Render a type annotation as Python source text.

        Parameterized generics are resolved with ``typing.get_origin`` /
        ``typing.get_args`` *before* any ``__name__`` lookup. Aliases such as
        ``list[int]`` or ``typing.List[int]`` expose their origin's
        ``__name__`` (``"list"`` / ``"List"``), so checking the name first
        would silently drop the type arguments.
        """
        import types as _types
        import typing

        if annotation is None or annotation is None.__class__:
            return "None"
        if annotation is Ellipsis:
            # e.g. the ``...`` in ``tuple[int, ...]``
            return "..."
        if isinstance(annotation, str):
            # Postponed / forward-reference annotation text.
            return annotation.replace("typing.", "")
        if isinstance(annotation, list):
            # Parameter list of a Callable, e.g. ``Callable[[int, str], bool]``.
            inner = ", ".join(self._format_annotation(a) for a in annotation)
            return f"[{inner}]"

        origin = typing.get_origin(annotation)
        if origin is None:
            # Plain (unparameterized) type, TypeVar, ForwardRef, or other object.
            forward_arg = getattr(annotation, "__forward_arg__", None)
            if isinstance(forward_arg, str):
                return forward_arg
            name = getattr(annotation, "__name__", None)
            if isinstance(name, str):
                return name
            return str(annotation).replace("typing.", "")

        args = typing.get_args(annotation)

        # typing.Union / Optional and PEP 604 ``X | Y`` (types.UnionType, 3.10+).
        if origin is typing.Union or origin is getattr(_types, "UnionType", None):
            return " | ".join(self._format_annotation(a) for a in args)
        if origin is typing.Literal:
            # Literal arguments are values, not types: keep their repr.
            return f"Literal[{', '.join(repr(a) for a in args)}]"
        if origin is getattr(typing, "Annotated", None):
            # Metadata is irrelevant to the prompt; show the underlying type.
            return self._format_annotation(args[0]) if args else "Any"

        name = getattr(origin, "__name__", None)
        if not isinstance(name, str):
            name = str(origin).replace("typing.", "")
        if not args:
            return name
        return f"{name}[{', '.join(self._format_annotation(a) for a in args)}]"
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
# Global registry instance: the single TypeRegistry that the module-level
# helpers in this file (``type``, ``get_type_schema``, ``get_all_type_schemas``,
# ``get_output_type_for_function``, ``clear_type_registry``) delegate to.
_registry = TypeRegistry()
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def type(
    cls: type[T] | None = None, *, output_for: str | list[str] | None = None
) -> type[T] | Callable[[type[T]], type[T]]:
    """Register a class with wishful's global type registry (decorator).

    Works both bare and with arguments, and always returns the class itself
    unchanged.

    Usage:
        @wishful.type
        class UserProfile(BaseModel):
            name: str
            email: str

        # Declare the class as the output type of a function
        @wishful.type(output_for='create_user')
        class UserProfile(BaseModel):
            name: str
            email: str

        # ...or of several functions
        @wishful.type(output_for=['create_user', 'update_user'])
        class UserProfile(BaseModel):
            name: str
            email: str
    """

    def _register(target: type[T]) -> type[T]:
        _registry.register(target, output_for=output_for)
        return target

    # Bare ``@wishful.type`` receives the class directly; the call form
    # ``@wishful.type(output_for=...)`` receives no class and must return
    # the decorator to be applied next.
    return _register if cls is None else _register(cls)
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def get_type_schema(type_name: str) -> str | None:
    """Return the serialized schema registered under ``type_name``, or None."""
    schema = _registry.get_schema(type_name)
    return schema
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def get_all_type_schemas() -> dict[str, str]:
    """Return every registered schema, keyed by type name."""
    schemas = _registry.get_all_schemas()
    return schemas
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def get_output_type_for_function(function_name: str) -> str | None:
    """Return the type name mapped to ``function_name`` via ``output_for``, or None."""
    output_type = _registry.get_output_type(function_name)
    return output_type
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def clear_type_registry() -> None:
    """Forget every registered type and function-output mapping (handy in tests)."""
    registry = _registry
    registry.clear()
|
wishful/ui.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from contextlib import contextmanager
|
|
4
|
+
from typing import Iterator
|
|
5
|
+
|
|
6
|
+
from rich.console import Console
|
|
7
|
+
from rich.progress import Progress, SpinnerColumn, TextColumn
|
|
8
|
+
|
|
9
|
+
from wishful.config import settings
|
|
10
|
+
|
|
11
|
+
# Module-level rich Console that ``spinner`` renders its progress display to.
_console = Console()
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@contextmanager
def spinner(message: str) -> Iterator[None]:
    """Show a transient spinner labelled ``message`` while the body runs.

    When ``settings.spinner`` is falsy this simply yields and shows nothing.
    Otherwise a rich ``Progress`` display (spinner + label) is shown on the
    module console and, being ``transient``, removed when the block exits,
    including when the body raises.

    Args:
        message: Label shown next to the spinner (rendered as rich markup).
    """
    if not settings.spinner:
        yield
        return

    # Render the label through the task description rather than passing
    # ``message`` as TextColumn's format template: TextColumn runs its template
    # through ``str.format``, so a ``{`` or ``}`` in the message would raise or
    # be mangled.
    columns = (SpinnerColumn(), TextColumn("{task.description}"))
    with Progress(*columns, console=_console, transient=True) as progress:
        progress.add_task(message, total=None)
        # Exiting the ``with`` stops (and, being transient, clears) the live
        # display even if the body raises, so no explicit stop() is needed.
        yield
|