qnty 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- qnty/__init__.py +140 -58
- qnty/_backup/problem_original.py +1251 -0
- qnty/_backup/quantity.py +63 -0
- qnty/codegen/cli.py +125 -0
- qnty/codegen/generators/data/unit_data.json +8807 -0
- qnty/codegen/generators/data_processor.py +345 -0
- qnty/codegen/generators/dimensions_gen.py +434 -0
- qnty/codegen/generators/doc_generator.py +141 -0
- qnty/codegen/generators/out/dimension_mapping.json +974 -0
- qnty/codegen/generators/out/dimension_metadata.json +123 -0
- qnty/codegen/generators/out/units_metadata.json +223 -0
- qnty/codegen/generators/quantities_gen.py +159 -0
- qnty/codegen/generators/setters_gen.py +178 -0
- qnty/codegen/generators/stubs_gen.py +167 -0
- qnty/codegen/generators/units_gen.py +295 -0
- qnty/codegen/generators/utils/__init__.py +0 -0
- qnty/equations/__init__.py +4 -0
- qnty/equations/equation.py +257 -0
- qnty/equations/system.py +127 -0
- qnty/expressions/__init__.py +61 -0
- qnty/expressions/cache.py +94 -0
- qnty/expressions/functions.py +96 -0
- qnty/expressions/nodes.py +546 -0
- qnty/generated/__init__.py +0 -0
- qnty/generated/dimensions.py +514 -0
- qnty/generated/quantities.py +6003 -0
- qnty/generated/quantities.pyi +4192 -0
- qnty/generated/setters.py +12210 -0
- qnty/generated/units.py +9798 -0
- qnty/problem/__init__.py +91 -0
- qnty/problem/base.py +142 -0
- qnty/problem/composition.py +385 -0
- qnty/problem/composition_mixin.py +382 -0
- qnty/problem/equations.py +413 -0
- qnty/problem/metaclass.py +302 -0
- qnty/problem/reconstruction.py +1016 -0
- qnty/problem/solving.py +180 -0
- qnty/problem/validation.py +64 -0
- qnty/problem/variables.py +239 -0
- qnty/quantities/__init__.py +6 -0
- qnty/quantities/expression_quantity.py +314 -0
- qnty/quantities/quantity.py +428 -0
- qnty/quantities/typed_quantity.py +215 -0
- qnty/solving/__init__.py +0 -0
- qnty/solving/manager.py +90 -0
- qnty/solving/order.py +355 -0
- qnty/solving/solvers/__init__.py +20 -0
- qnty/solving/solvers/base.py +92 -0
- qnty/solving/solvers/iterative.py +185 -0
- qnty/solving/solvers/simultaneous.py +547 -0
- qnty/units/__init__.py +0 -0
- qnty/{prefixes.py → units/prefixes.py} +54 -33
- qnty/{unit.py → units/registry.py} +73 -32
- qnty/utils/__init__.py +0 -0
- qnty/utils/logging.py +40 -0
- qnty/validation/__init__.py +0 -0
- qnty/validation/registry.py +0 -0
- qnty/validation/rules.py +167 -0
- qnty-0.0.9.dist-info/METADATA +199 -0
- qnty-0.0.9.dist-info/RECORD +63 -0
- qnty/dimension.py +0 -186
- qnty/equation.py +0 -216
- qnty/expression.py +0 -492
- qnty/unit_types/base.py +0 -47
- qnty/units.py +0 -8113
- qnty/variable.py +0 -263
- qnty/variable_types/base.py +0 -58
- qnty/variable_types/expression_variable.py +0 -68
- qnty/variable_types/typed_variable.py +0 -87
- qnty/variables.py +0 -2298
- qnty/variables.pyi +0 -6148
- qnty-0.0.7.dist-info/METADATA +0 -355
- qnty-0.0.7.dist-info/RECORD +0 -19
- /qnty/{unit_types → codegen}/__init__.py +0 -0
- /qnty/{variable_types → codegen/generators}/__init__.py +0 -0
- {qnty-0.0.7.dist-info → qnty-0.0.9.dist-info}/WHEEL +0 -0
@@ -0,0 +1,546 @@
|
|
1
|
+
"""
|
2
|
+
Expression AST Nodes
|
3
|
+
===================
|
4
|
+
|
5
|
+
Core abstract syntax tree nodes for mathematical expressions.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import math
|
9
|
+
from abc import ABC, abstractmethod
|
10
|
+
from typing import TYPE_CHECKING, Union
|
11
|
+
|
12
|
+
if TYPE_CHECKING:
|
13
|
+
from ..quantities.quantity import Quantity, TypeSafeVariable
|
14
|
+
|
15
|
+
from ..generated.units import DimensionlessUnits
|
16
|
+
from ..quantities.quantity import Quantity, TypeSafeVariable
|
17
|
+
from .cache import _EXPRESSION_RESULT_CACHE, _MAX_EXPRESSION_CACHE_SIZE, wrap_operand
|
18
|
+
|
19
|
+
|
20
|
+
class Expression(ABC):
    """Abstract base class for mathematical expression AST nodes.

    Concrete nodes implement ``evaluate``, ``get_variables``, ``simplify`` and
    ``__str__``. The arithmetic, power and ordering operators defined here do
    not compute anything: they build ``BinaryOperation`` nodes (and ``abs()``
    builds a ``UnaryFunction`` node), wrapping non-expression operands with
    ``wrap_operand``.
    """

    # Class-level optimization settings, shared by every subclass.
    # _scope_cache maps id(expression) -> (expression, discovered variables).
    # The expression object is stored next to its result so that its id cannot
    # be handed out to a different object while the entry exists.
    _scope_cache = {}
    _auto_eval_enabled = False  # Disabled by default for performance
    _max_scope_cache_size = 100  # Limit scope cache size

    @abstractmethod
    def evaluate(self, variable_values: dict[str, 'TypeSafeVariable']) -> 'Quantity':
        """Evaluate the expression given variable values keyed by variable name."""
        pass

    @abstractmethod
    def get_variables(self) -> set[str]:
        """Get all variable symbols used in this expression."""
        pass

    @abstractmethod
    def simplify(self) -> 'Expression':
        """Return a simplified (possibly constant-folded) expression."""
        pass

    @abstractmethod
    def __str__(self) -> str:
        pass

    @staticmethod
    def _is_internal_frame(frame) -> bool:
        """Return True if *frame* belongs to the expression machinery.

        A frame counts as internal when its code is defined in this module,
        when it comes from a file named ``expression.py`` (the name of the
        module this code appears to have been moved from), or when it is any
        ``__str__``/``__repr__`` method.
        """
        code = frame.f_code
        return (
            frame.f_globals.get('__name__') == __name__
            or code.co_filename.endswith('expression.py')
            or code.co_name in ('__str__', '__repr__')
        )

    def _discover_variables_from_scope(self) -> dict[str, 'TypeSafeVariable']:
        """Find this expression's variables in the scope that asked for it.

        Walks up the call stack past internal frames (see
        ``_is_internal_frame``), visiting at most ``max_depth`` frames. It then
        looks up every name from ``get_variables()`` in that frame's locals
        and, failing that, in its globals. Only ``TypeSafeVariable`` instances
        are collected.

        Returns:
            Mapping of variable name to the discovered variable. The mapping is
            empty when auto-evaluation is disabled, when no frame outside the
            expression machinery is reached within the depth limit, or when the
            expression has no variables. Results of a completed search are
            cached per expression object.
        """
        # Skip if auto-evaluation is disabled
        if not self._auto_eval_enabled:
            return {}

        cache_key = id(self)
        if cache_key in self._scope_cache:
            cached = self._scope_cache[cache_key]
            # id() values are only unique among live objects, so trust an
            # entry only if it was stored for this very expression.
            if isinstance(cached, tuple) and len(cached) == 2 and cached[0] is self:
                return cached[1]

        # Clean cache if it gets too large
        if len(self._scope_cache) >= self._max_scope_cache_size:
            self._scope_cache.clear()

        import inspect

        frame = inspect.currentframe()
        try:
            depth = 0
            max_depth = 6  # bounded walk for performance
            while frame is not None and depth < max_depth and self._is_internal_frame(frame):
                frame = frame.f_back
                depth += 1

            # Give up if the stack ran out, or if the depth limit was hit while
            # still inside expression code: that frame's namespaces belong to
            # the library, not to the user who asked for the string.
            if frame is None or self._is_internal_frame(frame):
                return {}

            required_vars = self.get_variables()
            if not required_vars:
                return {}

            local_vars = frame.f_locals
            global_vars = frame.f_globals
            discovered = {}
            for var_name in required_vars:
                # Locals shadow globals, mirroring normal Python name lookup;
                # a non-variable local falls through to the globals.
                for namespace in (local_vars, global_vars):
                    if var_name in namespace:
                        obj = namespace[var_name]
                        if isinstance(obj, TypeSafeVariable):
                            discovered[var_name] = obj
                            break

            self._scope_cache[cache_key] = (self, discovered)
            return discovered
        finally:
            # Drop the frame reference promptly to avoid reference cycles.
            del frame

    def _can_auto_evaluate(self) -> tuple[bool, dict[str, 'TypeSafeVariable']]:
        """Check whether the expression can be evaluated from the caller's scope.

        Returns ``(True, variables)`` when every required variable was
        discovered and has a non-None ``quantity``; otherwise ``(False, {})``.
        An expression with no variables yields ``(True, {})`` even when
        auto-evaluation is disabled, because there is nothing to discover.
        Any exception is treated as "cannot evaluate" (best effort for __str__).
        """
        try:
            discovered = self._discover_variables_from_scope()
            required_vars = self.get_variables()

            # Check if all required variables are available and have values
            for var_name in required_vars:
                if var_name not in discovered:
                    return False, {}
                var = discovered[var_name]
                if not hasattr(var, 'quantity') or var.quantity is None:
                    return False, {}

            return True, discovered

        except Exception:
            return False, {}

    def __add__(self, other: Union['Expression', 'TypeSafeVariable', 'Quantity', int, float]) -> 'Expression':
        return BinaryOperation('+', self, wrap_operand(other))

    def __radd__(self, other: Union['TypeSafeVariable', 'Quantity', int, float]) -> 'Expression':
        return BinaryOperation('+', wrap_operand(other), self)

    def __sub__(self, other: Union['Expression', 'TypeSafeVariable', 'Quantity', int, float]) -> 'Expression':
        return BinaryOperation('-', self, wrap_operand(other))

    def __rsub__(self, other: Union['TypeSafeVariable', 'Quantity', int, float]) -> 'Expression':
        return BinaryOperation('-', wrap_operand(other), self)

    def __mul__(self, other: Union['Expression', 'TypeSafeVariable', 'Quantity', int, float]) -> 'Expression':
        return BinaryOperation('*', self, wrap_operand(other))

    def __rmul__(self, other: Union['TypeSafeVariable', 'Quantity', int, float]) -> 'Expression':
        return BinaryOperation('*', wrap_operand(other), self)

    def __truediv__(self, other: Union['Expression', 'TypeSafeVariable', 'Quantity', int, float]) -> 'Expression':
        return BinaryOperation('/', self, wrap_operand(other))

    def __rtruediv__(self, other: Union['TypeSafeVariable', 'Quantity', int, float]) -> 'Expression':
        return BinaryOperation('/', wrap_operand(other), self)

    def __pow__(self, other: Union['Expression', 'TypeSafeVariable', 'Quantity', int, float]) -> 'Expression':
        return BinaryOperation('**', self, wrap_operand(other))

    def __rpow__(self, other: Union['TypeSafeVariable', 'Quantity', int, float]) -> 'Expression':
        return BinaryOperation('**', wrap_operand(other), self)

    def __abs__(self) -> 'Expression':
        """Absolute value of the expression."""
        return UnaryFunction('abs', self)

    # Comparison operators build symbolic comparison nodes (no __eq__/__ne__
    # override, so expressions keep identity-based hashing and equality).
    def _make_comparison(self, operator: str, other) -> 'BinaryOperation':
        """Helper method to create comparison operations."""
        return BinaryOperation(operator, self, wrap_operand(other))

    def __lt__(self, other: Union['Expression', 'TypeSafeVariable', 'Quantity', int, float]) -> 'BinaryOperation':
        return self._make_comparison('<', other)

    def __le__(self, other: Union['Expression', 'TypeSafeVariable', 'Quantity', int, float]) -> 'BinaryOperation':
        return self._make_comparison('<=', other)

    def __gt__(self, other: Union['Expression', 'TypeSafeVariable', 'Quantity', int, float]) -> 'BinaryOperation':
        return self._make_comparison('>', other)

    def __ge__(self, other: Union['Expression', 'TypeSafeVariable', 'Quantity', int, float]) -> 'BinaryOperation':
        return self._make_comparison('>=', other)

    @staticmethod
    def _wrap_operand(operand: Union['Expression', 'TypeSafeVariable', 'Quantity', int, float]) -> 'Expression':
        """Wrap non-Expression operands in appropriate Expression subclasses."""
        return wrap_operand(operand)
|
188
|
+
|
189
|
+
|
190
|
+
class VariableReference(Expression):
    """Leaf node that stands for a ``TypeSafeVariable`` inside an expression."""

    __slots__ = ('variable', '_cached_name', '_last_symbol')

    def __init__(self, variable: 'TypeSafeVariable'):
        self.variable = variable
        # Memoized identifier plus the symbol it was derived from; the memo is
        # refreshed whenever the variable's symbol changes.
        self._cached_name = None
        self._last_symbol = None

    @property
    def name(self) -> str:
        """Identifier for this variable: its ``symbol`` if truthy, else its ``name``.

        The value is memoized and recomputed when ``variable.symbol`` differs
        from the symbol seen on the previous call.
        """
        symbol = self.variable.symbol
        needs_refresh = self._cached_name is None or self._last_symbol != symbol
        if needs_refresh:
            # Symbol first (optinova compatibility); plain name as fallback.
            self._cached_name = symbol or self.variable.name
            self._last_symbol = symbol
        return self._cached_name

    def evaluate(self, variable_values: dict[str, 'TypeSafeVariable']) -> 'Quantity':
        """Return the quantity bound to this variable.

        If ``variable_values`` has an entry under ``self.name``, that entry's
        quantity is used (and only that entry). Otherwise the wrapped
        variable's own quantity is used.

        Raises:
            ValueError: when no non-None quantity is found; any other error is
                re-raised as ValueError with the original chained.
        """
        try:
            key = self.name
            if key in variable_values:
                supplied = variable_values[key].quantity
                if supplied is not None:
                    return supplied
            else:
                own = self.variable.quantity
                if own is not None:
                    return own

            known = list(variable_values.keys()) if variable_values else []
            raise ValueError(
                f"Cannot evaluate variable '{key}' without value. "
                f"Available variables: {known}"
            )
        except ValueError:
            raise
        except Exception as exc:
            raise ValueError(f"Error evaluating variable '{self.name}': {exc}") from exc

    def get_variables(self) -> set[str]:
        """A reference depends on exactly one variable: itself."""
        return {self.name}

    def simplify(self) -> 'Expression':
        """References are already in simplest form."""
        return self

    def __str__(self) -> str:
        return self.name
|
238
|
+
|
239
|
+
|
240
|
+
class Constant(Expression):
    """Leaf node wrapping a fixed ``Quantity``."""

    __slots__ = ('value',)

    def __init__(self, value: 'Quantity'):
        self.value = value

    def evaluate(self, variable_values: dict[str, 'TypeSafeVariable']) -> 'Quantity':
        """Return the stored quantity; ``variable_values`` is ignored."""
        _ = variable_values  # accepted only to satisfy the Expression interface
        return self.value

    def get_variables(self) -> set[str]:
        """Constants depend on no variables."""
        no_variables: set[str] = set()
        return no_variables

    def simplify(self) -> 'Expression':
        """A constant cannot be simplified further."""
        return self

    def __str__(self) -> str:
        magnitude = self.value.value
        return str(magnitude)
|
259
|
+
|
260
|
+
|
261
|
+
class BinaryOperation(Expression):
    """Binary operation (arithmetic or comparison) between two expressions.

    Arithmetic operators delegate to ``Quantity`` arithmetic, with a few
    unit-safe shortcuts. Comparison operators evaluate to a dimensionless
    1.0 (true) or 0.0 (false).
    """

    __slots__ = ('operator', 'left', 'right')

    # Operator groups used to dispatch evaluation.
    _ARITHMETIC_OPS = {'+', '-', '*', '/', '**'}
    _COMPARISON_OPS = {'<', '<=', '>', '>=', '==', '!='}

    # Comparisons on raw magnitudes; (in)equality uses an absolute tolerance
    # of 1e-10. Built once instead of on every comparison.
    _COMPARATORS = {
        '<': lambda left, right: left < right,
        '<=': lambda left, right: left <= right,
        '>': lambda left, right: left > right,
        '>=': lambda left, right: left >= right,
        '==': lambda left, right: abs(left - right) < 1e-10,
        '!=': lambda left, right: abs(left - right) >= 1e-10,
    }

    # Binding strength used by __str__ when deciding where parentheses go.
    _PRECEDENCE = {
        '+': 1, '-': 1, '*': 2, '/': 2, '**': 3,
        '<': 0, '<=': 0, '>': 0, '>=': 0, '==': 0, '!=': 0,
    }

    def __init__(self, operator: str, left: Expression, right: Expression):
        self.operator = operator
        self.left = left
        self.right = right

    def evaluate(self, variable_values: dict[str, 'TypeSafeVariable']) -> 'Quantity':
        """Evaluate both operands and apply the operator.

        Results for Constant-op-Constant nodes are memoized in the shared
        ``_EXPRESSION_RESULT_CACHE``. Each entry is stored as
        ``(left_quantity, right_quantity, result)``. Holding the operand
        quantities keeps their ids from being reused, and a hit is accepted
        only if the operands are the very same objects.
        NOTE(review): confirm nothing else in the package reads this cache's values.

        Raises:
            ValueError: for unknown operators, evaluation failures (other
                exceptions are wrapped, original chained) and domain errors.
        """
        try:
            cache_key = None
            left_const = right_const = None
            if isinstance(self.left, Constant) and isinstance(self.right, Constant):
                left_const = self.left.value
                right_const = self.right.value
                cache_key = (id(self), self.operator, id(left_const), id(right_const))
                if cache_key in _EXPRESSION_RESULT_CACHE:
                    cached = _EXPRESSION_RESULT_CACHE[cache_key]
                    if (isinstance(cached, tuple) and len(cached) == 3
                            and cached[0] is left_const and cached[1] is right_const):
                        return cached[2]

                # Clean cache if it gets too large
                if len(_EXPRESSION_RESULT_CACHE) >= _MAX_EXPRESSION_CACHE_SIZE:
                    _EXPRESSION_RESULT_CACHE.clear()

            left_val = self.left.evaluate(variable_values)
            right_val = self.right.evaluate(variable_values)

            if self.operator in self._ARITHMETIC_OPS:
                result = self._evaluate_arithmetic(left_val, right_val)
            elif self.operator in self._COMPARISON_OPS:
                result = self._evaluate_comparison(left_val, right_val)
            else:
                raise ValueError(f"Unknown operator: {self.operator}")

            if cache_key is not None:
                _EXPRESSION_RESULT_CACHE[cache_key] = (left_const, right_const, result)

            return result
        except ValueError:
            raise
        except Exception as e:
            raise ValueError(f"Error evaluating binary operation '{self}': {e}") from e

    @staticmethod
    def _is_dimensionless(quantity: 'Quantity') -> bool:
        """True if *quantity* is expressed in the plain dimensionless unit.

        Unit equality (not just dimension) is checked on purpose: a scaled
        dimensionless unit with magnitude 1 is not a multiplicative identity.
        """
        return quantity.unit == DimensionlessUnits.dimensionless

    @staticmethod
    def _integer_power(base: 'Quantity', exponent: int) -> 'Quantity':
        """Raise *base* to a non-zero integer power using Quantity ``*`` and ``/``.

        Uses exponentiation by squaring so units compose through the normal
        Quantity operators. A negative exponent is the reciprocal of the
        positive power.
        """
        remaining = abs(exponent)
        result = None
        factor = base
        while remaining:
            if remaining & 1:
                result = factor if result is None else result * factor
            remaining >>= 1
            if remaining:
                factor = factor * factor
        if exponent < 0:
            return Quantity(1.0, DimensionlessUnits.dimensionless) / result
        return result

    def _evaluate_arithmetic(self, left_val: 'Quantity', right_val: 'Quantity') -> 'Quantity':
        """Evaluate + - * / ** with shortcuts that never change the result's units."""
        op = self.operator
        if op == '*':
            # Only a *dimensionless* one is the identity; 1 m still changes the
            # dimension of the product. There is no zero shortcut, because
            # 0 * x must carry the product's units.
            if right_val.value == 1.0 and self._is_dimensionless(right_val):
                return left_val
            if left_val.value == 1.0 and self._is_dimensionless(left_val):
                return right_val
            return left_val * right_val
        if op == '+':
            # Adding zero returns the other operand unchanged. This shortcut is
            # lenient: the zero operand's units are not checked.
            if right_val.value == 0.0:
                return left_val
            if left_val.value == 0.0:
                return right_val
            return left_val + right_val
        if op == '-':
            # Subtracting zero: same lenient shortcut as for '+'.
            if right_val.value == 0.0:
                return left_val
            return left_val - right_val
        if op == '/':
            if abs(right_val.value) < 1e-15:
                raise ValueError(f"Division by zero in expression: {self}")
            if right_val.value == 1.0 and self._is_dimensionless(right_val):
                return left_val
            return left_val / right_val
        if op == '**':
            return self._evaluate_power(left_val, right_val)
        # Unknown operator - should not happen
        raise ValueError(f"Unknown arithmetic operator: {self.operator}")

    def _evaluate_power(self, left_val: 'Quantity', right_val: 'Quantity') -> 'Quantity':
        """Evaluate ``left ** right``, using the exponent's magnitude only.

        Raises:
            ValueError: if the exponent magnitude is not an int/float, if a
                negative base is raised to a non-integer power (the result
                would be complex), or on division by zero.
        """
        exponent = right_val.value
        if not isinstance(exponent, int | float):
            raise ValueError("Exponent must be dimensionless number")

        # Fast paths for common exponents
        if exponent == 1.0:
            return left_val
        if exponent == 0.0:
            return Quantity(1.0, DimensionlessUnits.dimensionless)
        if exponent == 2.0:
            return left_val * left_val

        integral = float(exponent).is_integer()
        if left_val.value < 0 and not integral:
            raise ValueError(f"Negative base with fractional exponent: {left_val.value}^{exponent}")

        if integral and not self._is_dimensionless(left_val):
            # Integer powers of a dimensioned base: compose units properly.
            if exponent < 0 and left_val.value == 0:
                raise ValueError(f"Division by zero in expression: {self}")
            return self._integer_power(left_val, int(exponent))

        # Dimensionless base, or a non-integer exponent: raise the magnitude and
        # keep the base's unit. For a dimensioned base with a fractional
        # exponent, the unit is NOT correct (simplified behavior).
        return Quantity(left_val.value ** exponent, left_val.unit)

    def _evaluate_comparison(self, left_val: 'Quantity', right_val: 'Quantity') -> 'Quantity':
        """Compare magnitudes, returning dimensionless 1.0 (true) or 0.0 (false).

        When both sides share a dimension but differ in unit, the right side
        is first converted to the left side's unit.
        NOTE(review): operands of different dimensions are compared on raw
        magnitudes without raising.
        """
        try:
            if left_val._dimension_sig == right_val._dimension_sig and left_val.unit != right_val.unit:
                right_val = right_val.to(left_val.unit)
        except (ValueError, TypeError, AttributeError):
            pass

        result = self._COMPARATORS[self.operator](left_val.value, right_val.value)
        return Quantity(1.0 if result else 0.0, DimensionlessUnits.dimensionless)

    def get_variables(self) -> set[str]:
        return self.left.get_variables() | self.right.get_variables()

    def simplify(self) -> Expression:
        """Simplify both operands; fold the node if both become constants.

        If folding fails with ValueError/TypeError/ArithmeticError, the
        simplified but unfolded node is returned.
        """
        left_simplified = self.left.simplify()
        right_simplified = self.right.simplify()

        if isinstance(left_simplified, Constant) and isinstance(right_simplified, Constant):
            try:
                result = BinaryOperation(self.operator, left_simplified, right_simplified).evaluate({})
                return Constant(result)
            except (ValueError, TypeError, ArithmeticError):
                pass

        return BinaryOperation(self.operator, left_simplified, right_simplified)

    def __str__(self) -> str:
        """Render the evaluated value if possible, else a minimally parenthesized form."""
        # Try to auto-evaluate if all variables are available
        can_eval, variables = self._can_auto_evaluate()
        if can_eval:
            try:
                return str(self.evaluate(variables))
            except Exception:
                pass  # Fall back to symbolic representation

        precedence = self._PRECEDENCE
        curr_prec = precedence.get(self.operator, 0)
        left_str = str(self.left)
        right_str = str(self.right)

        if isinstance(self.left, BinaryOperation):
            left_prec = precedence.get(self.left.operator, 0)
            # '**' is right-associative, so an equal-precedence left operand
            # such as (a ** b) ** c needs explicit parentheses.
            if left_prec < curr_prec or (left_prec == curr_prec and self.operator == '**'):
                left_str = f"({left_str})"

        if isinstance(self.right, BinaryOperation):
            right_prec = precedence.get(self.right.operator, 0)
            # '-' and '/' are left-associative: a - (b - c) needs parentheses.
            if right_prec < curr_prec or (right_prec == curr_prec and self.operator in ('-', '/')):
                right_str = f"({right_str})"

        return f"{left_str} {self.operator} {right_str}"
|
440
|
+
|
441
|
+
|
442
|
+
class UnaryFunction(Expression):
    """A named single-argument math function applied to an expression."""

    __slots__ = ('function_name', 'operand')

    # Functions whose result is always reported as dimensionless. The operand
    # magnitude is passed straight through; for the trig functions it is
    # treated as radians.
    _DIMENSIONLESS_RESULT = {
        'sin': math.sin,
        'cos': math.cos,
        'tan': math.tan,
        'ln': math.log,
        'log10': math.log10,
        'exp': math.exp,
    }

    def __init__(self, function_name: str, operand: Expression):
        self.function_name = function_name
        self.operand = operand

    def evaluate(self, variable_values: dict[str, 'TypeSafeVariable']) -> 'Quantity':
        """Evaluate the operand, then apply the named function to its magnitude.

        ``abs`` and ``sqrt`` keep the operand's unit. For ``sqrt`` this is a
        simplification: the unit is not square-rooted. Every other supported
        function returns a dimensionless quantity.

        Raises:
            ValueError: for an unknown function name, or from ``math`` on a
                domain error (e.g. log of a non-positive value).
        """
        operand_val = self.operand.evaluate(variable_values)
        name = self.function_name

        func = self._DIMENSIONLESS_RESULT.get(name)
        if func is not None:
            return Quantity(func(operand_val.value), DimensionlessUnits.dimensionless)
        if name == 'abs':
            return Quantity(abs(operand_val.value), operand_val.unit)
        if name == 'sqrt':
            return Quantity(math.sqrt(operand_val.value), operand_val.unit)
        raise ValueError(f"Unknown function: {self.function_name}")

    def get_variables(self) -> set[str]:
        return self.operand.get_variables()

    def simplify(self) -> Expression:
        """Simplify the operand and fold the call if the operand is constant."""
        inner = self.operand.simplify()
        if isinstance(inner, Constant):
            try:
                folded = UnaryFunction(self.function_name, inner).evaluate({})
            except (ValueError, TypeError, ArithmeticError):
                pass  # leave unfolded, e.g. on a math domain error
            else:
                return Constant(folded)
        return UnaryFunction(self.function_name, inner)

    def __str__(self) -> str:
        return f"{self.function_name}({self.operand})"
|
502
|
+
|
503
|
+
|
504
|
+
class ConditionalExpression(Expression):
    """Ternary node: evaluates ``true_expr`` if ``condition`` is non-zero, else ``false_expr``."""

    __slots__ = ('condition', 'true_expr', 'false_expr')

    def __init__(self, condition: Expression, true_expr: Expression, false_expr: Expression):
        self.condition = condition
        self.true_expr = true_expr
        self.false_expr = false_expr

    def evaluate(self, variable_values: dict[str, 'TypeSafeVariable']) -> 'Quantity':
        """Evaluate the condition, then only the selected branch.

        A condition magnitude with absolute value above 1e-10 counts as true.
        """
        truthy = abs(self.condition.evaluate(variable_values).value) > 1e-10
        chosen = self.true_expr if truthy else self.false_expr
        return chosen.evaluate(variable_values)

    def get_variables(self) -> set[str]:
        """Union of the variables used by the condition and both branches."""
        cond_vars = self.condition.get_variables()
        true_vars = self.true_expr.get_variables()
        false_vars = self.false_expr.get_variables()
        return cond_vars | true_vars | false_vars

    def simplify(self) -> Expression:
        """Simplify all parts; pick a branch outright when the condition is constant."""
        cond = self.condition.simplify()
        when_true = self.true_expr.simplify()
        when_false = self.false_expr.simplify()

        if not isinstance(cond, Constant):
            return ConditionalExpression(cond, when_true, when_false)

        try:
            truthy = abs(cond.evaluate({}).value) > 1e-10
        except (ValueError, TypeError, ArithmeticError):
            return ConditionalExpression(cond, when_true, when_false)
        return when_true if truthy else when_false

    def __str__(self) -> str:
        return f"if({self.condition}, {self.true_expr}, {self.false_expr})"
|
File without changes
|