modelwright 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,153 @@
1
+ """`formulas`-backed workbook oracle."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ from modelwright.oracles import (
9
+ OracleDiagnostic,
10
+ OracleRequest,
11
+ OracleResult,
12
+ WorkbookOracle,
13
+ missing_optional_dependency_diagnostic,
14
+ )
15
+ from modelwright.validation import JsonValue
16
+
17
+
18
class FormulasWorkbookOracle(WorkbookOracle):
    """Evaluate workbook outputs with the optional `formulas` package.

    The workbook named by ``request.source_workbook`` is loaded with
    ``formulas.ExcelModel``, calculated, and the cells listed in
    ``request.outputs`` are read back. The failure paths handled here are
    reported as ``OracleDiagnostic`` entries on the returned ``OracleResult``
    instead of being raised.
    """

    backend_name = "formulas"

    def evaluate(self, request: OracleRequest) -> OracleResult:
        """Calculate ``request.source_workbook`` and collect the requested outputs.

        Args:
            request: Oracle request naming the workbook, optional input
                overrides, and the output cells (``output.cell_ref``) to read.

        Returns:
            An ``OracleResult``. If the request cannot be evaluated at all, the
            result carries a single error diagnostic and no outputs. That
            covers input overrides being supplied, ``formulas`` not being
            importable, the workbook not being a file, and loading or
            calculation raising. Otherwise the result carries every output
            that was found and converted. Each output that was not found, or
            whose value ``_to_json_scalar`` rejected with ``ValueError``, adds
            one error diagnostic.
        """
        if request.inputs:
            # Overriding input cells is not implemented for this backend yet.
            return self._error_result(
                request,
                OracleDiagnostic(
                    diagnostic_code="unsupported_oracle_inputs",
                    message="formulas oracle input overrides are not supported yet",
                    severity="error",
                ),
            )

        try:
            # Imported lazily: `formulas` is an optional dependency.
            import formulas
        except ImportError:
            return self._error_result(
                request,
                missing_optional_dependency_diagnostic(
                    dependency="formulas",
                    extra="oracle",
                    backend=self.backend_name,
                ),
            )

        workbook_path = Path(request.source_workbook)
        # is_file() rather than exists(): a directory is not a loadable workbook
        # and would otherwise surface as an opaque calculation failure below.
        if not workbook_path.is_file():
            return self._error_result(
                request,
                OracleDiagnostic(
                    diagnostic_code="missing_source_workbook",
                    message=(
                        "source workbook is not a file"
                        if workbook_path.exists()
                        else "source workbook does not exist"
                    ),
                    severity="error",
                    location=str(workbook_path),
                ),
            )

        try:
            model = formulas.ExcelModel().loads(str(workbook_path)).finish()
            calculated = model.calculate()
        except Exception as exc:
            # Backend boundary: any load/calculation failure becomes a diagnostic.
            return self._error_result(
                request,
                OracleDiagnostic(
                    diagnostic_code="oracle_calculation_failed",
                    # str(exc) is "" for exceptions raised without arguments;
                    # keep the message informative in that case.
                    message=str(exc) or f"formulas raised {exc.__class__.__name__}",
                    severity="error",
                    location=str(workbook_path),
                    raw_value=exc.__class__.__name__,
                ),
            )

        outputs: dict[str, JsonValue] = {}
        diagnostics: list[OracleDiagnostic] = []
        for output in request.outputs:
            matched_key, value = _find_output_value(calculated, output.cell_ref)
            if matched_key is None:
                diagnostics.append(
                    OracleDiagnostic(
                        diagnostic_code="missing_oracle_output",
                        message="formulas did not return the requested workbook output",
                        severity="error",
                        location=output.cell_ref,
                    )
                )
                continue

            try:
                outputs[output.cell_ref] = _to_json_scalar(value)
            except ValueError as exc:
                # The cell was found but its value is not a single JSON scalar.
                diagnostics.append(
                    OracleDiagnostic(
                        diagnostic_code="unsupported_oracle_value",
                        message=str(exc),
                        severity="error",
                        location=output.cell_ref,
                        raw_value=str(matched_key),
                    )
                )

        return OracleResult(
            backend=self.backend_name,
            source_workbook=request.source_workbook,
            outputs=outputs,
            diagnostics=tuple(diagnostics),
        )

    def _error_result(
        self, request: OracleRequest, diagnostic: OracleDiagnostic
    ) -> OracleResult:
        """Build a result for ``request`` whose only content is ``diagnostic``."""
        return OracleResult(
            backend=self.backend_name,
            source_workbook=request.source_workbook,
            diagnostics=(diagnostic,),
        )
119
+
120
+
121
def _find_output_value(calculated: Any, cell_ref: str) -> tuple[Any | None, Any | None]:
    """Look up ``cell_ref`` among the entries of a ``formulas`` calculation.

    Args:
        calculated: The object returned by ``model.calculate()``; only its
            ``items()`` are used.
        cell_ref: Requested cell reference, compared against each stringified
            key with ``_matches_cell_ref``.

    Returns:
        ``(key, value)`` for the first entry whose key matches, or
        ``(None, None)`` when no entry matches.
    """
    candidates = (
        (key, value)
        for key, value in calculated.items()
        if _matches_cell_ref(str(key), cell_ref)
    )
    return next(candidates, (None, None))
126
+
127
+
128
+ def _matches_cell_ref(formulas_key: str, cell_ref: str) -> bool:
129
+ if "!" not in cell_ref:
130
+ return False
131
+
132
+ sheet_name, coordinate = cell_ref.split("!", 1)
133
+ expected_suffix = f"]{sheet_name.upper()}'!{coordinate.upper().replace('$', '')}"
134
+ return formulas_key.upper().replace("$", "").endswith(expected_suffix)
135
+
136
+
137
+ def _to_json_scalar(value: Any) -> JsonValue:
138
+ if hasattr(value, "value"):
139
+ value = value.value
140
+
141
+ if hasattr(value, "tolist"):
142
+ value = value.tolist()
143
+
144
+ while isinstance(value, list) and len(value) == 1:
145
+ value = value[0]
146
+
147
+ if hasattr(value, "item"):
148
+ value = value.item()
149
+
150
+ if value is None or isinstance(value, str | int | float | bool):
151
+ return value
152
+
153
+ raise ValueError(f"unsupported oracle value type: {type(value).__name__}")