nighthawk-python 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,279 @@
1
+ from __future__ import annotations
2
+
3
+ import ast
4
+ import re
5
+ import textwrap
6
+ from dataclasses import dataclass
7
+ from typing import Any, Literal
8
+
9
+ import yaml
10
+
11
+ from ..errors import NaturalParseError
12
+
13
# Matches a whole string that is an ASCII identifier: a letter or underscore
# followed by letters, digits or underscores.
_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
# Matches a binding marker of the form ``<name>`` or ``<:name>``.
# Group 1 is the optional leading colon -- ``(:?)`` is a capturing group around
# an optional ":", not a mistyped non-capturing ``(?:`` group. Group 2 is the name.
_BINDING_PATTERN = re.compile(r"<(:?)([A-Za-z_][A-Za-z0-9_]*)>")
15
+
16
+
17
@dataclass(frozen=True)
class NaturalBlock:
    """A Natural block located in a function's source (immutable record)."""

    # "docstring" when the block is the function docstring, "inline" when it is
    # a standalone string or f-string expression statement in the function body.
    kind: Literal["docstring", "inline"]
    # Program text with the "natural" sentinel line removed and dedented.
    text: str
    # Names taken from ``<name>`` markers, de-duplicated in first-seen order.
    input_bindings: tuple[str, ...]
    # Names taken from ``<:name>`` markers, de-duplicated in first-seen order.
    output_bindings: tuple[str, ...]
    # Source line number associated with the block (the function definition for
    # docstring blocks, the expression statement for inline blocks).
    lineno: int
24
+
25
+
26
def is_natural_sentinel(text: str) -> bool:
    """Return True when *text* opens with the word ``natural`` followed by a newline."""
    sentinel = "natural\n"
    return text[: len(sentinel)] == sentinel
28
+
29
+
30
def extract_program(text: str) -> str:
    """Return the Natural program carried by *text*.

    The leading ``natural`` sentinel line is removed and the remainder is
    dedented with ``textwrap.dedent``.

    Raises:
        NaturalParseError: If *text* does not start with the sentinel line.
    """
    if is_natural_sentinel(text):
        body_after_sentinel = text[len("natural\n") :]
        return textwrap.dedent(body_after_sentinel)
    raise NaturalParseError("Missing natural sentinel")
35
+
36
+
37
def extract_bindings(program: str) -> tuple[tuple[str, ...], tuple[str, ...]]:
    """Collect the binding names referenced by markers in *program*.

    ``<name>`` markers are input bindings and ``<:name>`` markers are output
    bindings.

    Returns:
        ``(input_names, output_names)``; each tuple lists every name once, in
        order of first appearance.

    Raises:
        NaturalParseError: If a matched name is not a valid identifier.
    """
    # Dicts serve as insertion-ordered sets: the first occurrence fixes the order.
    input_names: dict[str, None] = {}
    output_names: dict[str, None] = {}

    for marker in _BINDING_PATTERN.finditer(program):
        colon, binding_name = marker.group(1, 2)
        if _IDENTIFIER_PATTERN.match(binding_name) is None:
            raise NaturalParseError(f"Invalid binding name: {binding_name!r}")
        destination = output_names if colon == ":" else input_names
        destination.setdefault(binding_name, None)

    return tuple(input_names), tuple(output_names)
61
+
62
+
63
# Stand-in text substituted for each non-literal part of an f-string when
# checking whether a binding marker straddles an interpolation.
_JOINED_STRING_FORMATTED_VALUE_PLACEHOLDER = "\x00"
64
+
65
+
66
+ def _joined_string_first_literal_or_none(joined_string: ast.JoinedStr) -> str | None:
67
+ if not joined_string.values:
68
+ return None
69
+ first = joined_string.values[0]
70
+ if not isinstance(first, ast.Constant) or not isinstance(first.value, str):
71
+ return None
72
+ return first.value
73
+
74
+
75
def _joined_string_is_natural_sentinel(joined_string: ast.JoinedStr) -> bool:
    """Return True when the f-string's opening literal starts with the natural sentinel."""
    leading_text = _joined_string_first_literal_or_none(joined_string)
    return leading_text is not None and is_natural_sentinel(leading_text)
80
+
81
+
82
def _joined_string_scan_text(joined_string: ast.JoinedStr, *, formatted_value_placeholder: str) -> str:
    """Flatten an f-string into plain text for scanning.

    String-constant parts contribute their text; every other part (the
    interpolations) contributes *formatted_value_placeholder* instead.
    """
    return "".join(
        part.value if isinstance(part, ast.Constant) and isinstance(part.value, str) else formatted_value_placeholder
        for part in joined_string.values
    )
90
+
91
+
92
def _validate_joined_string_bindings_do_not_span_formatted_values(joined_string: ast.JoinedStr) -> None:
    """Validate that no binding marker spans a formatted value boundary.

    Bindings are extracted from the literal parts of the f-string only, so a
    marker that is partly produced by an interpolation -- ``<{x}>``,
    ``<name{suffix}>``, ``<{x}:name>`` -- would collapse into a different marker
    once the interpolation is dropped. Such markers are rejected.

    Only marker-shaped spans are considered: a ``<`` and ``>`` pair whose
    content consists solely of binding characters (letters, digits, ``_``,
    ``:``) and at least one interpolation. Unrelated angle brackets around an
    interpolation, such as a ``<`` comparison before ``{limit}`` followed later
    by a ``->`` arrow, are not binding markers and are accepted.

    Raises:
        NaturalParseError: If a marker-shaped span contains an interpolation.
    """
    boundary_marked_text = _joined_string_scan_text(
        joined_string,
        formatted_value_placeholder=_JOINED_STRING_FORMATTED_VALUE_PLACEHOLDER,
    )

    placeholder = re.escape(_JOINED_STRING_FORMATTED_VALUE_PLACEHOLDER)
    marker_character = "[A-Za-z0-9_:]"
    # "<", binding characters, one placeholder, then any mix of both, ">".
    spanning_marker_pattern = f"<{marker_character}*{placeholder}(?:{marker_character}|{placeholder})*>"

    if re.search(spanning_marker_pattern, boundary_marked_text):
        raise NaturalParseError("Binding marker must not span formatted value boundary in inline f-string Natural block")
101
+
102
+
103
def _extract_program_and_bindings_from_joined_string(joined_string: ast.JoinedStr) -> tuple[str, tuple[str, ...], tuple[str, ...]]:
    """Return ``(program, input_bindings, output_bindings)`` for an inline f-string block.

    The f-string is validated first; the program and its bindings are then
    derived from the literal parts only (interpolations contribute no text).

    Raises:
        NaturalParseError: Propagated from validation, program extraction or
            binding extraction.
    """
    _validate_joined_string_bindings_do_not_span_formatted_values(joined_string)

    literal_only_text = _joined_string_scan_text(joined_string, formatted_value_placeholder="")
    program = extract_program(literal_only_text)
    return (program, *extract_bindings(program))
110
+
111
+
112
def find_natural_blocks(func_source: str) -> tuple[NaturalBlock, ...]:
    """Parse function source text and return its Natural blocks.

    The first function definition at module level is inspected. A docstring
    starting with the natural sentinel yields a ``"docstring"`` block; each
    top-level expression statement in the body that is a string or f-string
    starting with the sentinel yields an ``"inline"`` block, in source order.

    Raises:
        NaturalParseError: If the source does not parse, contains no function
            definition, or a block fails program/binding extraction.
    """
    try:
        parsed = ast.parse(func_source)
    except SyntaxError as e:
        raise NaturalParseError(str(e)) from e

    function_node = next(
        (node for node in parsed.body if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))),
        None,
    )
    if function_node is None:
        raise NaturalParseError("No function definition found")

    found: list[NaturalBlock] = []

    raw_docstring = ast.get_docstring(function_node, clean=False)
    if raw_docstring and is_natural_sentinel(raw_docstring):
        program = extract_program(raw_docstring)
        inputs, outputs = extract_bindings(program)
        found.append(
            NaturalBlock(
                kind="docstring",
                text=program,
                input_bindings=inputs,
                output_bindings=outputs,
                lineno=getattr(function_node, "lineno", 1),
            )
        )

    # A leading plain-string statement occupies the docstring slot and is never
    # treated as an inline block, whether or not it carried the sentinel.
    body = function_node.body
    leading = body[0] if body else None
    has_leading_string = (
        isinstance(leading, ast.Expr)
        and isinstance(leading.value, ast.Constant)
        and isinstance(leading.value.value, str)
    )
    candidates = body[1:] if has_leading_string else body

    for statement in candidates:
        if not isinstance(statement, ast.Expr):
            continue

        value = statement.value
        if isinstance(value, ast.Constant) and isinstance(value.value, str):
            if not is_natural_sentinel(value.value):
                continue
            program = extract_program(value.value)
            inputs, outputs = extract_bindings(program)
        elif isinstance(value, ast.JoinedStr) and _joined_string_is_natural_sentinel(value):
            program, inputs, outputs = _extract_program_and_bindings_from_joined_string(value)
        else:
            continue

        found.append(
            NaturalBlock(
                kind="inline",
                text=program,
                input_bindings=inputs,
                output_bindings=outputs,
                lineno=getattr(statement, "lineno", 1),
            )
        )

    return tuple(found)
184
+
185
+
186
# Step kind names that a frontmatter ``deny`` list is allowed to contain.
_FRONTMATTER_STEP_KINDS = ("pass", "return", "break", "continue", "raise")
187
+
188
+
189
def validate_frontmatter_deny(frontmatter: dict[str, object]) -> tuple[str, ...]:
    """Check a parsed frontmatter mapping and return its denied step kinds.

    An empty mapping means nothing is denied. Otherwise the mapping must hold
    exactly one key, ``deny``, whose value is a non-empty list of known step
    kind names.

    Returns:
        The denied step kinds, de-duplicated in first-seen order.

    Raises:
        NaturalParseError: If the frontmatter contains unknown keys, unknown
            step kind names, or has an invalid ``deny`` structure.
    """
    if not frontmatter:
        return ()

    unexpected_keys = sorted(str(key) for key in frontmatter if key not in {"deny"})
    if unexpected_keys:
        raise NaturalParseError(f"Unknown frontmatter keys: {', '.join(unexpected_keys)}")

    if "deny" not in frontmatter:
        raise NaturalParseError("Frontmatter must include 'deny'")

    deny_value = frontmatter["deny"]
    is_string_list = isinstance(deny_value, list) and all(isinstance(entry, str) for entry in deny_value)
    if not is_string_list:
        raise NaturalParseError("Frontmatter 'deny' must be a YAML sequence of strings")

    if not deny_value:
        raise NaturalParseError("Frontmatter 'deny' must not be empty")

    for entry in deny_value:
        if entry not in _FRONTMATTER_STEP_KINDS:
            raise NaturalParseError(f"Unknown denied step kind: {entry}")

    # dict.fromkeys drops repeats while keeping first-seen order.
    return tuple(dict.fromkeys(deny_value))
223
+
224
+
225
def parse_frontmatter(processed_natural_program: str) -> tuple[str, dict[str, Any]]:
    """Parse and strip YAML frontmatter from a Natural program.

    Frontmatter is recognized when the first non-blank line is ``---`` and a
    matching closing ``---`` line follows. Delimiter lines are compared without
    their line terminator, so LF, CRLF and an unterminated final line are all
    accepted. The YAML content between the delimiters must be a mapping.

    Returns:
        A tuple of (program_text_without_frontmatter, parsed_mapping).
        When no frontmatter is present -- no opening delimiter, no closing
        delimiter, or only blank content between the delimiters -- the
        original text is returned unchanged with an empty mapping.

    Raises:
        NaturalParseError: If the frontmatter is syntactically invalid.
    """

    def is_delimiter(line: str) -> bool:
        # Lines keep their terminators (keepends=True); strip it so "---\r\n"
        # is treated the same way as "---\n" and a bare trailing "---".
        return line.rstrip("\r\n") == "---"

    lines = processed_natural_program.splitlines(keepends=True)

    # Index of the first line that is not blank, if any.
    start_index = next(
        (index for index, line in enumerate(lines) if line.strip(" \t\r\n") != ""),
        None,
    )
    if start_index is None or not is_delimiter(lines[start_index]):
        return processed_natural_program, {}

    closing_index = next(
        (index for index in range(start_index + 1, len(lines)) if is_delimiter(lines[index])),
        None,
    )
    if closing_index is None:
        return processed_natural_program, {}

    yaml_text = "".join(lines[start_index + 1 : closing_index])
    if yaml_text.strip() == "":
        return processed_natural_program, {}

    try:
        loaded = yaml.safe_load(yaml_text)
    except yaml.YAMLError as e:
        raise NaturalParseError(f"Frontmatter YAML parsing failed: {e}") from e

    if not isinstance(loaded, dict):
        raise NaturalParseError("Frontmatter YAML must be a mapping")

    instructions_without_frontmatter = "".join(lines[closing_index + 1 :])
    return instructions_without_frontmatter, loaded
@@ -0,0 +1,302 @@
1
+ from __future__ import annotations
2
+
3
+ import ast
4
+ import inspect
5
+ import logging
6
+ import sys
7
+ import textwrap
8
+ from collections.abc import Awaitable, Callable
9
+ from functools import wraps
10
+ from typing import Any, cast
11
+
12
+ from ..runtime.runner import Runner, StepEnvelope
13
+ from ..runtime.scoping import get_step_executor
14
+ from ..runtime.step_context import python_cell_scope, python_name_scope
15
+ from ..tools.registry import call_scope
16
+ from .blocks import find_natural_blocks
17
+ from .transform import transform_module_ast
18
+
19
# Alias for the callables accepted and returned by the ``natural_function`` decorator.
type NaturalFunctionCallable = Callable[..., Any]
20
+
21
+
22
class _RunnerProxy:
    """Entry points that forward a Natural step to a freshly built ``Runner``.

    An instance is placed in the transformed function's globals under
    ``__nighthawk_runner__`` by ``natural_function``.
    """

    @staticmethod
    def run_step(
        natural_program: str,
        input_binding_names: list[str],
        output_binding_names: list[str],
        binding_name_to_type: dict[str, object],
        return_annotation: object,
        is_in_loop: bool,
    ) -> StepEnvelope:
        """Run one Natural step synchronously and return the runner's envelope.

        All arguments are passed through to ``Runner.run_step`` unchanged,
        together with the frame of the code that called this method.
        """
        # Frame of the immediate caller; must be read directly in this function,
        # since adding a call layer would shift the frame depth.
        caller_frame = sys._getframe(1)
        # The step executor is looked up on every call, not cached.
        current_step_executor = get_step_executor()
        runner = Runner(current_step_executor)
        return runner.run_step(
            natural_program,
            input_binding_names,
            output_binding_names,
            binding_name_to_type,
            return_annotation,
            is_in_loop,
            caller_frame=caller_frame,
        )

    @staticmethod
    async def run_step_async(
        natural_program: str,
        input_binding_names: list[str],
        output_binding_names: list[str],
        binding_name_to_type: dict[str, object],
        return_annotation: object,
        is_in_loop: bool,
    ) -> StepEnvelope:
        """Asynchronous counterpart of ``run_step``.

        All arguments are passed through to ``Runner.run_step_async`` unchanged,
        together with the frame one level above this coroutine.
        """
        # Frame one level up from this coroutine's own frame.
        caller_frame = sys._getframe(1)
        # The step executor is looked up on every call, not cached.
        current_step_executor = get_step_executor()
        runner = Runner(current_step_executor)
        return await runner.run_step_async(
            natural_program,
            input_binding_names,
            output_binding_names,
            binding_name_to_type,
            return_annotation,
            is_in_loop,
            caller_frame=caller_frame,
        )
66
+
67
+
68
def _extract_inline_fstring_name_set(function_source: str, *, function_name: str) -> set[str]:
    """Extract names referenced in f-string expressions of inline Natural blocks.

    Only top-level expression statements of the function named *function_name*
    are inspected, and only f-strings whose first literal part starts with the
    natural sentinel. Every ``ast.Name`` inside an interpolation is collected,
    including names used in nested format specs such as ``{value:{width}}``.

    Returns:
        The set of referenced names; empty when the source does not parse or
        the function is not found.
    """
    try:
        module = ast.parse(function_source)
    except SyntaxError:
        return set()

    function_def = next(
        (
            node
            for node in module.body
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == function_name
        ),
        None,
    )
    if function_def is None:
        return set()

    names: set[str] = set()

    for statement in function_def.body:
        if not isinstance(statement, ast.Expr) or not isinstance(statement.value, ast.JoinedStr):
            continue

        parts = statement.value.values
        first_part = parts[0] if parts else None
        if not isinstance(first_part, ast.Constant) or not isinstance(first_part.value, str):
            continue
        if not first_part.value.startswith("natural\n"):
            continue

        for part in parts:
            if isinstance(part, ast.FormattedValue):
                # Walk the whole FormattedValue, not just ``part.value``, so
                # names inside a nested format spec are captured as well.
                names.update(node.id for node in ast.walk(part) if isinstance(node, ast.Name))

    return names
109
+
110
+
111
def _build_capture_name_set(source: str, function_name: str) -> set[str]:
    """Build the set of names that need to be captured from the enclosing scope.

    The result combines every input and output binding of the function's
    Natural blocks with the names referenced inside inline f-string
    interpolations. Extraction is best-effort: on any failure a warning is
    logged to the ``nighthawk`` logger and an empty set is returned.
    """
    try:
        bound_names = {
            name
            for block in find_natural_blocks(source)
            for name in (*block.input_bindings, *block.output_bindings)
        }
        return bound_names | _extract_inline_fstring_name_set(source, function_name=function_name)
    except Exception as exception:
        logging.getLogger("nighthawk").warning("Failed to extract capture names for %s: %s", function_name, exception)
        return set()
123
+
124
+
125
def _build_transformed_factory_module(
    *,
    transformed_module: ast.Module,
    function_name: str,
    name_to_value: dict[str, object],
) -> ast.Module:
    """Wrap the transformed function in a factory so captured values become closure cells.

    The generated module defines ``__nh_factory__(__nh_captured_values__)``.
    Its body assigns each captured name (in sorted order) from the mapping
    argument, then defines the transformed function and returns it.

    Raises:
        RuntimeError: If *transformed_module* has no top-level function named
            *function_name*.
    """
    inner_function = next(
        (
            node
            for node in transformed_module.body
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == function_name
        ),
        None,
    )
    if inner_function is None:
        raise RuntimeError("Transformed function not found in transformed module")

    values_parameter = "__nh_captured_values__"

    def unpack_statement(captured_name: str) -> ast.Assign:
        # <captured_name> = __nh_captured_values__["<captured_name>"]
        return ast.Assign(
            targets=[ast.Name(id=captured_name, ctx=ast.Store())],
            value=ast.Subscript(
                value=ast.Name(id=values_parameter, ctx=ast.Load()),
                slice=ast.Constant(captured_name),
                ctx=ast.Load(),
            ),
        )

    body: list[ast.stmt] = [unpack_statement(captured_name) for captured_name in sorted(name_to_value)]
    body.append(inner_function)
    body.append(ast.Return(value=ast.Name(id=function_name, ctx=ast.Load())))

    factory = ast.FunctionDef(
        name="__nh_factory__",
        args=ast.arguments(
            posonlyargs=[],
            args=[ast.arg(arg=values_parameter)],
            kwonlyargs=[],
            kw_defaults=[],
            defaults=[],
        ),
        body=body,
        decorator_list=[],
        returns=None,
        type_comment=None,
    )

    module = ast.Module(body=[factory], type_ignores=[])
    ast.fix_missing_locations(module)
    return module
178
+
179
+
180
def natural_function(func: NaturalFunctionCallable | None = None) -> NaturalFunctionCallable:
    """Transform a function containing Natural blocks into an executable Natural function.

    Parses the function source to find Natural blocks, rewrites the AST to
    delegate block execution to the active step executor at runtime. Names
    used by the blocks that are local to the scope where the function is
    defined are captured at decoration time and exposed to the transformed
    function as closure variables.

    Args:
        func: The function to transform. Can be omitted for use as a bare
            decorator. ``staticmethod`` and ``classmethod`` objects are
            unwrapped, transformed and re-wrapped.

    Returns:
        A wrapper (sync or async, matching *func*) around the transformed
        function, or a decorator when *func* is omitted.

    Raises:
        RuntimeError: If the transformed function cannot be built or its free
            variables do not match the captured names.

    Example:
        ```python
        @nighthawk.natural_function
        def summarize(text: str) -> str:
            '''natural
            Summarize <text> in one sentence.
            -> <:result>
            '''
            return result
        ```
    """
    if func is None:
        return lambda f: natural_function(f)  # type: ignore[return-value]

    if isinstance(func, staticmethod):
        decorated_static_function = natural_function(func.__func__)
        return cast(NaturalFunctionCallable, staticmethod(decorated_static_function))

    if isinstance(func, classmethod):
        decorated_class_function = natural_function(func.__func__)
        return cast(NaturalFunctionCallable, classmethod(decorated_class_function))

    lines, starting_line_number = inspect.getsourcelines(func)
    source = textwrap.dedent("".join(lines))

    try:
        original_module = ast.parse(source)
        for node in original_module.body:
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == func.__name__:
                # Drop decorators so executing the rebuilt definition does not
                # re-apply this decorator.
                node.decorator_list = []
                break
        # Shift line numbers so they refer to positions in the original file.
        ast.increment_lineno(original_module, starting_line_number - 1)
    except Exception as exception:
        logging.getLogger("nighthawk").warning("Failed to parse original module AST for %s: %s", func.__name__, exception)
        original_module = ast.Module(body=[], type_ignores=[])

    capture_name_set = _build_capture_name_set(source, func.__name__)

    definition_frame = inspect.currentframe()
    name_to_value: dict[str, object] = {}
    if definition_frame is not None:
        own_code = definition_frame.f_code
        caller_frame = definition_frame.f_back
        # This decorator can re-enter itself: through the lambda returned for
        # ``@natural_function()`` and through the staticmethod/classmethod
        # unwrapping above. Those intermediate frames belong to this module,
        # not to the scope that defines ``func``, so step past them to reach
        # the real definition site.
        while (
            caller_frame is not None
            and caller_frame.f_code.co_filename == own_code.co_filename
            and caller_frame.f_code.co_name in (own_code.co_name, "<lambda>")
        ):
            caller_frame = caller_frame.f_back

        # Module-level names are reachable through ``func.__globals__``; only
        # locals of an enclosing function or class body need capturing.
        if caller_frame is not None and caller_frame.f_code.co_name != "<module>":
            for name in capture_name_set:
                if name in caller_frame.f_locals:
                    name_to_value[name] = caller_frame.f_locals[name]

    captured_name_tuple = tuple(sorted(capture_name_set))

    transformed_module = transform_module_ast(original_module, captured_name_tuple=captured_name_tuple)

    filename = inspect.getsourcefile(func) or "<nighthawk>"

    factory_module = _build_transformed_factory_module(
        transformed_module=transformed_module,
        function_name=func.__name__,
        name_to_value=name_to_value,
    )
    code = compile(factory_module, filename, "exec")

    # Execute against a copy of the function's globals, extended with the
    # helper names placed here for the transformed code.
    globals_namespace: dict[str, object] = dict(func.__globals__)
    globals_namespace["__nighthawk_runner__"] = _RunnerProxy()
    from .blocks import extract_program as _nh_extract_program

    globals_namespace["__nh_extract_program__"] = _nh_extract_program
    globals_namespace["__nh_python_cell_scope__"] = python_cell_scope

    module_namespace: dict[str, object] = {}
    exec(code, globals_namespace, module_namespace)

    factory = module_namespace.get("__nh_factory__")
    if not callable(factory):
        raise RuntimeError("Transformed factory not found after compilation")

    transformed = factory(name_to_value)
    if not callable(transformed):
        raise RuntimeError("Transformed function not found after factory execution")

    # Sanity check: the only free variables allowed are the captured names,
    # plus the function's own name (a self-reference resolves to the factory's
    # local definition).
    transformed_freevar_name_set = set(transformed.__code__.co_freevars)
    captured_name_set = set(name_to_value.keys())

    unexpected_freevar_name_set = transformed_freevar_name_set - captured_name_set
    allowed_unexpected_freevar_name_set = {func.__name__}
    if not unexpected_freevar_name_set.issubset(allowed_unexpected_freevar_name_set):
        raise RuntimeError(
            f"Transformed function freevars do not match captured names. freevars={transformed.__code__.co_freevars!r} captured={tuple(sorted(name_to_value.keys()))!r}"
        )

    if transformed.__closure__ is None and name_to_value:
        raise RuntimeError("Transformed function closure is missing for captured names")

    if inspect.iscoroutinefunction(func):
        transformed_async = cast(Callable[..., Awaitable[Any]], transformed)

        @wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            with call_scope():
                if name_to_value:
                    with python_name_scope(name_to_value):
                        return await transformed_async(*args, **kwargs)
                return await transformed_async(*args, **kwargs)

        return cast(NaturalFunctionCallable, async_wrapper)  # type: ignore[return-value]

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        with call_scope():
            if name_to_value:
                with python_name_scope(name_to_value):
                    return transformed(*args, **kwargs)
            return transformed(*args, **kwargs)

    return cast(NaturalFunctionCallable, wrapper)  # type: ignore[return-value]