pineforge-codegen 0.6.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pineforge_codegen/__init__.py +53 -0
- pineforge_codegen/analyzer/__init__.py +60 -0
- pineforge_codegen/analyzer/base.py +1563 -0
- pineforge_codegen/analyzer/call_handlers.py +895 -0
- pineforge_codegen/analyzer/contracts.py +163 -0
- pineforge_codegen/analyzer/diagnostics.py +118 -0
- pineforge_codegen/analyzer/tables.py +204 -0
- pineforge_codegen/analyzer/types.py +250 -0
- pineforge_codegen/ast_nodes.py +293 -0
- pineforge_codegen/codegen/__init__.py +78 -0
- pineforge_codegen/codegen/base.py +1381 -0
- pineforge_codegen/codegen/emit_top.py +875 -0
- pineforge_codegen/codegen/helpers.py +163 -0
- pineforge_codegen/codegen/helpers_syminfo.py +134 -0
- pineforge_codegen/codegen/input.py +189 -0
- pineforge_codegen/codegen/security.py +1564 -0
- pineforge_codegen/codegen/ta.py +298 -0
- pineforge_codegen/codegen/tables.py +613 -0
- pineforge_codegen/codegen/types.py +573 -0
- pineforge_codegen/codegen/visit_call.py +1305 -0
- pineforge_codegen/codegen/visit_expr.py +701 -0
- pineforge_codegen/codegen/visit_stmt.py +729 -0
- pineforge_codegen/errors.py +98 -0
- pineforge_codegen/lexer.py +531 -0
- pineforge_codegen/parser.py +1198 -0
- pineforge_codegen/pragmas.py +117 -0
- pineforge_codegen/signatures.py +808 -0
- pineforge_codegen/support_checker.py +1111 -0
- pineforge_codegen/symbols.py +118 -0
- pineforge_codegen/tokens.py +406 -0
- pineforge_codegen/tv_input_choices.py +86 -0
- pineforge_codegen-0.6.5.dist-info/METADATA +462 -0
- pineforge_codegen-0.6.5.dist-info/RECORD +35 -0
- pineforge_codegen-0.6.5.dist-info/WHEEL +4 -0
- pineforge_codegen-0.6.5.dist-info/licenses/LICENSE +197 -0
|
@@ -0,0 +1,895 @@
|
|
|
1
|
+
"""Per-callee dispatch + bookkeeping for the analyzer.
|
|
2
|
+
|
|
3
|
+
This is the largest analyzer mixin (~900 lines). It owns the
|
|
4
|
+
``_handle_*_call`` family that routes Pine ``ta.*`` / ``request.*``
|
|
5
|
+
/ ``strategy.*`` / ``input.*`` / ``fixnan(...)`` / user-defined
|
|
6
|
+
function calls into TA-call-site allocation, ``input.*`` defval
|
|
7
|
+
inference, ``request.security()`` recording, ``fixnan`` site
|
|
8
|
+
allocation, and per-call-site cloning of TA + series state for
|
|
9
|
+
user functions called more than once per bar.
|
|
10
|
+
|
|
11
|
+
Mixin contract -- host class must provide the following attributes
|
|
12
|
+
(all set by ``Analyzer.__init__`` unless noted):
|
|
13
|
+
|
|
14
|
+
- ``self._ta_call_sites`` (``list[TACallSite]``): TA call-site
|
|
15
|
+
registry. Appended to by ``_handle_ta_call`` and (per-call-site
|
|
16
|
+
clones only) by ``_handle_user_func_call``.
|
|
17
|
+
- ``self._ta_counter`` (``int``): monotonically increasing TA member
|
|
18
|
+
index used to mint unique ``_ta_<func>_<n>`` member names.
|
|
19
|
+
- ``self._series_bar_fields`` (``set[str]``): bar-field identifiers
|
|
20
|
+
used as TA inputs anywhere in the program.
|
|
21
|
+
- ``self._security_calls`` (``list[SecurityCallInfo]``): created on
|
|
22
|
+
first ``request.security(...)`` via ``getattr(..., "_security_calls",
|
|
23
|
+
[])`` and stored back on ``self`` -- the analyzer also reads this
|
|
24
|
+
attribute via ``getattr`` in ``analyze()``.
|
|
25
|
+
- ``self._fixnan_counter`` / ``self._fixnan_sites``: ``fixnan(...)``
|
|
26
|
+
member-name counter + site list.
|
|
27
|
+
- ``self._symbols`` (``SymbolTable``): consulted by
|
|
28
|
+
``_handle_input_call`` to type a series-defval input.
|
|
29
|
+
- ``self._enum_defs`` (``dict[str, list[str]]``): enum schema, used
|
|
30
|
+
by ``_validate_input_member_tv`` for input.enum() checks.
|
|
31
|
+
- ``self._func_defs`` / ``self._func_return_types`` /
|
|
32
|
+
``self._func_returns_tuple`` / ``self._func_tuple_element_count``:
|
|
33
|
+
user-function metadata captured during initial pass; consumed by
|
|
34
|
+
``_handle_user_func_call``.
|
|
35
|
+
- ``self._func_series_vars`` / ``self._func_var_members`` /
|
|
36
|
+
``self._func_ta_ranges``: per-function state needed for
|
|
37
|
+
call-site cloning.
|
|
38
|
+
- ``self._func_call_site_count`` / ``self._func_call_cs_map``:
|
|
39
|
+
per-call-site indices populated by ``_handle_user_func_call``.
|
|
40
|
+
- ``self._func_infos`` (``list[FuncInfo]``): the function-info list
|
|
41
|
+
surfaced through ``AnalyzerContext``.
|
|
42
|
+
|
|
43
|
+
Sibling-mixin methods consumed via ``self``:
|
|
44
|
+
|
|
45
|
+
- ``self._visit`` -- visitor entry (``Analyzer.base``).
|
|
46
|
+
- ``self._expr_to_str`` -- expression stringifier
|
|
47
|
+
(``DiagnosticsHelper``); used by ``_handle_ta_call`` to render
|
|
48
|
+
ctor args and by ``_handle_user_func_call`` for the param
|
|
49
|
+
substitution map.
|
|
50
|
+
- ``self._warn`` / ``self._error`` (``DiagnosticsHelper``).
|
|
51
|
+
- ``self._warn_if_unknown_source_id`` (``DiagnosticsHelper``).
|
|
52
|
+
- ``self._input_diag_loc`` (``DiagnosticsHelper``).
|
|
53
|
+
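- ``self._extract_literal_value`` (``TypeHelper``).
- ``self._global_scope`` / ``self._is_static_expression``: consulted by
  ``_handle_ta_call`` to decide whether a TA call site is static.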
|
|
54
|
+
- ``self._collect_security_mutable_globals`` (``Analyzer.base``):
|
|
55
|
+
stays on the host class because it walks the AST collecting
|
|
56
|
+
mutable globals -- not an isolated call-handling concern.
|
|
57
|
+
|
|
58
|
+
Output dataclasses (``TACallSite`` / ``FuncInfo`` / ``FixnanCallSite``
|
|
59
|
+
/ ``SecurityCallInfo``) are imported from sibling ``contracts.py`` so
|
|
60
|
+
the analyzer package's import graph stays a strict DAG with no cycle
|
|
61
|
+
back through ``base.py``.
|
|
62
|
+
"""
|
|
63
|
+
|
|
64
|
+
from __future__ import annotations
|
|
65
|
+
|
|
66
|
+
from typing import Any
|
|
67
|
+
|
|
68
|
+
from ..ast_nodes import (
|
|
69
|
+
ASTNode, BoolLiteral, FuncCall, Identifier, MemberAccess,
|
|
70
|
+
NumberLiteral, StringLiteral, TupleLiteral,
|
|
71
|
+
)
|
|
72
|
+
from ..symbols import PineType
|
|
73
|
+
from .. import signatures as sigs
|
|
74
|
+
from .. import tv_input_choices as tv_in
|
|
75
|
+
from .contracts import FixnanCallSite, FuncInfo, SecurityCallInfo, TACallSite
|
|
76
|
+
from .tables import (
|
|
77
|
+
BAR_FIELDS, TA_CLASS_MAP, TA_MULTI_CTOR, TA_NO_CTOR, TA_PERIOD_ARG,
|
|
78
|
+
TA_TUPLE_RETURNS,
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class CallHandlers:
|
|
83
|
+
"""``_handle_*_call`` dispatch + bookkeeping for analyzer call-sites.
|
|
84
|
+
|
|
85
|
+
Mixed into ``Analyzer``; not meant to be instantiated standalone.
|
|
86
|
+
See the module docstring for the host-class state contract."""
|
|
87
|
+
|
|
88
|
+
# ------------------------------------------------------------------
|
|
89
|
+
# TA call handling
|
|
90
|
+
# ------------------------------------------------------------------
|
|
91
|
+
|
|
92
|
+
def _merge_ta_args(self, func_name: str, node: FuncCall) -> list:
|
|
93
|
+
"""Merge positional args and kwargs into a unified positional list."""
|
|
94
|
+
param_names = sigs.get_param_names("ta", func_name)
|
|
95
|
+
if param_names is None and func_name == "sum":
|
|
96
|
+
param_names = sigs.get_param_names("math", "sum")
|
|
97
|
+
if param_names is None or not node.kwargs:
|
|
98
|
+
return list(node.args)
|
|
99
|
+
|
|
100
|
+
# Start with positional args
|
|
101
|
+
merged = list(node.args)
|
|
102
|
+
# Fill in kwargs at their expected positions
|
|
103
|
+
for i, pname in enumerate(param_names):
|
|
104
|
+
if pname in node.kwargs:
|
|
105
|
+
# Extend list if needed
|
|
106
|
+
while len(merged) <= i:
|
|
107
|
+
merged.append(None)
|
|
108
|
+
if merged[i] is None:
|
|
109
|
+
merged[i] = node.kwargs[pname]
|
|
110
|
+
# Remove trailing Nones
|
|
111
|
+
while merged and merged[-1] is None:
|
|
112
|
+
merged.pop()
|
|
113
|
+
return merged
|
|
114
|
+
|
|
115
|
+
def _handle_ta_call(self, func_name: str, node: FuncCall) -> PineType:
|
|
116
|
+
"""Handle ta.* function calls."""
|
|
117
|
+
# Visit all args for side effects (series detection, etc.)
|
|
118
|
+
for arg in node.args:
|
|
119
|
+
self._visit(arg)
|
|
120
|
+
for val in node.kwargs.values():
|
|
121
|
+
self._visit(val)
|
|
122
|
+
|
|
123
|
+
# ta.pivot_point_levels is a free runtime function (not a stateful
|
|
124
|
+
# indicator), but its codegen lowers to use `_s_high[1]`, `_s_low[1]`,
|
|
125
|
+
# `_s_close[1]` so the pivot is calculated from the PREVIOUS bar's
|
|
126
|
+
# HLC (matching Pine v6 semantics where `developing` defaults to
|
|
127
|
+
# false). Register the bar-field history series here so that the
|
|
128
|
+
# codegen emits the corresponding `Series<double> _s_high/...` members
|
|
129
|
+
# and pushes them at the top of every on_bar tick.
|
|
130
|
+
if func_name == "pivot_point_levels":
|
|
131
|
+
self._series_bar_fields.add("high")
|
|
132
|
+
self._series_bar_fields.add("low")
|
|
133
|
+
self._series_bar_fields.add("close")
|
|
134
|
+
return PineType.FLOAT # actual array<float> handled by type inference
|
|
135
|
+
|
|
136
|
+
# ta.vwap(source, anchor, stdev_mult) → 3-arg bands form.
|
|
137
|
+
# When called with 3 args (or anchor/stdev_mult kwargs), remap to the
|
|
138
|
+
# internal "vwap_bands" key which maps to ta::VWAPBands (returns tuple).
|
|
139
|
+
if func_name == "vwap":
|
|
140
|
+
param_names_v = ["source", "anchor", "stdev_mult"]
|
|
141
|
+
merged_v = list(node.args)
|
|
142
|
+
for i, pname in enumerate(param_names_v):
|
|
143
|
+
if pname in node.kwargs:
|
|
144
|
+
while len(merged_v) <= i:
|
|
145
|
+
merged_v.append(None)
|
|
146
|
+
if merged_v[i] is None:
|
|
147
|
+
merged_v[i] = node.kwargs[pname]
|
|
148
|
+
if len(merged_v) >= 3:
|
|
149
|
+
func_name = "vwap_bands"
|
|
150
|
+
|
|
151
|
+
if func_name not in TA_CLASS_MAP:
|
|
152
|
+
return PineType.FLOAT
|
|
153
|
+
|
|
154
|
+
# Merge positional + kwargs into a unified arg list
|
|
155
|
+
all_args = self._merge_ta_args(func_name, node)
|
|
156
|
+
|
|
157
|
+
# ta.tr(handle_na) — TV v6 default for handle_na is false. When the
|
|
158
|
+
# caller omits the arg, inject the explicit ``false`` so the C++
|
|
159
|
+
# TR ctor receives an unambiguous compile-time literal at the
|
|
160
|
+
# initializer-list site (`_ta_tr_1(false)`).
|
|
161
|
+
if func_name == "tr" and not all_args:
|
|
162
|
+
default_arg = BoolLiteral(value=False)
|
|
163
|
+
self._visit(default_arg)
|
|
164
|
+
all_args = [default_arg]
|
|
165
|
+
|
|
166
|
+
if func_name == "vwap" and not all_args:
|
|
167
|
+
default_src = Identifier(name="close")
|
|
168
|
+
self._visit(default_src)
|
|
169
|
+
self._series_bar_fields.add("close")
|
|
170
|
+
all_args = [default_src]
|
|
171
|
+
|
|
172
|
+
# Handle ta.highest(length) / ta.lowest(length) with 1 arg:
|
|
173
|
+
# single arg is the length, source defaults to high/low respectively.
|
|
174
|
+
# Remap so all_args = [default_source, length_arg].
|
|
175
|
+
_DEFAULT_SOURCE = {"highest": "high", "lowest": "low"}
|
|
176
|
+
if func_name in _DEFAULT_SOURCE and len(all_args) == 1:
|
|
177
|
+
default_src = Identifier(name=_DEFAULT_SOURCE[func_name])
|
|
178
|
+
self._visit(default_src)
|
|
179
|
+
self._series_bar_fields.add(_DEFAULT_SOURCE[func_name])
|
|
180
|
+
all_args = [default_src, all_args[0]]
|
|
181
|
+
|
|
182
|
+
self._ta_counter += 1
|
|
183
|
+
class_name = TA_CLASS_MAP[func_name]
|
|
184
|
+
member_name = f"_ta_{func_name}_{self._ta_counter}"
|
|
185
|
+
returns_tuple = func_name in TA_TUPLE_RETURNS
|
|
186
|
+
|
|
187
|
+
# vwap_bands special dispatch: ta.vwap(source, anchor, stdev_mult)
|
|
188
|
+
# ctor receives stdev_mult only; compute receives source only.
|
|
189
|
+
# The anchor arg (index 1) is the Pine-level "when to reset" series;
|
|
190
|
+
# our VWAPBands wrapper uses UTC-day boundaries matching the daily
|
|
191
|
+
# anchor default, so anchor is intentionally ignored in codegen.
|
|
192
|
+
if func_name == "vwap_bands":
|
|
193
|
+
ctor_args: list[str] = []
|
|
194
|
+
if len(all_args) >= 3 and all_args[2] is not None:
|
|
195
|
+
ctor_args = [self._expr_to_str(all_args[2])]
|
|
196
|
+
compute_args: list = []
|
|
197
|
+
if all_args and all_args[0] is not None:
|
|
198
|
+
compute_args = [all_args[0]]
|
|
199
|
+
is_static = self._global_scope and all(self._is_static_expression(arg) for arg in compute_args)
|
|
200
|
+
site = TACallSite(
|
|
201
|
+
member_name=member_name,
|
|
202
|
+
class_name=class_name,
|
|
203
|
+
ctor_args=ctor_args,
|
|
204
|
+
compute_args=compute_args,
|
|
205
|
+
returns_tuple=returns_tuple,
|
|
206
|
+
node=node,
|
|
207
|
+
is_static=is_static,
|
|
208
|
+
)
|
|
209
|
+
self._ta_call_sites.append(site)
|
|
210
|
+
return PineType.FLOAT
|
|
211
|
+
|
|
212
|
+
# Determine constructor args
|
|
213
|
+
ctor_args: list[str] = []
|
|
214
|
+
effective_multi_ctor = TA_MULTI_CTOR.copy()
|
|
215
|
+
if func_name in ("pivothigh", "pivotlow") and len(all_args) == 3:
|
|
216
|
+
effective_multi_ctor[func_name] = [1, 2]
|
|
217
|
+
|
|
218
|
+
if func_name in TA_NO_CTOR:
|
|
219
|
+
pass
|
|
220
|
+
elif func_name in effective_multi_ctor:
|
|
221
|
+
for idx in effective_multi_ctor[func_name]:
|
|
222
|
+
if idx < len(all_args) and all_args[idx] is not None:
|
|
223
|
+
ctor_args.append(self._expr_to_str(all_args[idx]))
|
|
224
|
+
elif func_name in TA_PERIOD_ARG:
|
|
225
|
+
idx = TA_PERIOD_ARG[func_name]
|
|
226
|
+
if idx < len(all_args) and all_args[idx] is not None:
|
|
227
|
+
ctor_args.append(self._expr_to_str(all_args[idx]))
|
|
228
|
+
|
|
229
|
+
# Determine compute args (all args that aren't ctor args)
|
|
230
|
+
compute_args: list = []
|
|
231
|
+
ctor_indices = set()
|
|
232
|
+
if func_name in effective_multi_ctor:
|
|
233
|
+
ctor_indices = set(effective_multi_ctor[func_name])
|
|
234
|
+
elif func_name in TA_PERIOD_ARG:
|
|
235
|
+
ctor_indices = {TA_PERIOD_ARG[func_name]}
|
|
236
|
+
|
|
237
|
+
for i, arg in enumerate(all_args):
|
|
238
|
+
if i not in ctor_indices and arg is not None:
|
|
239
|
+
compute_args.append(arg)
|
|
240
|
+
|
|
241
|
+
is_static = self._global_scope and all(self._is_static_expression(arg) for arg in compute_args)
|
|
242
|
+
site = TACallSite(
|
|
243
|
+
member_name=member_name,
|
|
244
|
+
class_name=class_name,
|
|
245
|
+
ctor_args=ctor_args,
|
|
246
|
+
compute_args=compute_args,
|
|
247
|
+
returns_tuple=returns_tuple,
|
|
248
|
+
node=node,
|
|
249
|
+
is_static=is_static,
|
|
250
|
+
)
|
|
251
|
+
self._ta_call_sites.append(site)
|
|
252
|
+
|
|
253
|
+
return PineType.FLOAT
|
|
254
|
+
|
|
255
|
+
def _handle_request_call(self, func_name: str, node: FuncCall) -> PineType:
|
|
256
|
+
"""Handle request.* function calls."""
|
|
257
|
+
if func_name == "security":
|
|
258
|
+
param_names = ["symbol", "timeframe", "expression", "gaps", "lookahead",
|
|
259
|
+
"ignore_invalid_symbol", "currency"]
|
|
260
|
+
all_args = list(node.args)
|
|
261
|
+
for i, pname in enumerate(param_names):
|
|
262
|
+
if pname in node.kwargs:
|
|
263
|
+
while len(all_args) <= i:
|
|
264
|
+
all_args.append(None)
|
|
265
|
+
all_args[i] = node.kwargs[pname]
|
|
266
|
+
|
|
267
|
+
tf_node = all_args[1] if len(all_args) > 1 else None
|
|
268
|
+
expr_node = all_args[2] if len(all_args) > 2 else None
|
|
269
|
+
|
|
270
|
+
# Visit non-expression args first (symbol, tf, gaps, lookahead)
|
|
271
|
+
for arg in node.args:
|
|
272
|
+
if arg is not None and arg is not expr_node:
|
|
273
|
+
self._visit(arg)
|
|
274
|
+
for k, val in node.kwargs.items():
|
|
275
|
+
if val is not None and val is not expr_node:
|
|
276
|
+
self._visit(val)
|
|
277
|
+
|
|
278
|
+
# Track TA sites created by the expression
|
|
279
|
+
ta_start = len(self._ta_call_sites)
|
|
280
|
+
if expr_node is not None:
|
|
281
|
+
self._visit(expr_node)
|
|
282
|
+
ta_end = len(self._ta_call_sites)
|
|
283
|
+
security_ta_range = (ta_start, ta_end) if ta_end > ta_start else None
|
|
284
|
+
|
|
285
|
+
# Assign ID and record the call
|
|
286
|
+
self._security_calls = getattr(self, "_security_calls", [])
|
|
287
|
+
sec_id = len(self._security_calls)
|
|
288
|
+
|
|
289
|
+
returns_tuple = isinstance(expr_node, TupleLiteral)
|
|
290
|
+
tuple_size = len(expr_node.elements) if returns_tuple else 0
|
|
291
|
+
|
|
292
|
+
gaps_node = all_args[3] if len(all_args) > 3 else None
|
|
293
|
+
lookahead_node = all_args[4] if len(all_args) > 4 else None
|
|
294
|
+
|
|
295
|
+
mutable_globals = tuple(sorted(self._collect_security_mutable_globals(expr_node)))
|
|
296
|
+
self._security_calls.append(SecurityCallInfo(
|
|
297
|
+
sec_id=sec_id,
|
|
298
|
+
timeframe=tf_node,
|
|
299
|
+
expression=expr_node,
|
|
300
|
+
returns_tuple=returns_tuple,
|
|
301
|
+
tuple_size=tuple_size,
|
|
302
|
+
gaps=gaps_node,
|
|
303
|
+
lookahead=lookahead_node,
|
|
304
|
+
ta_range=security_ta_range,
|
|
305
|
+
depends_on_mutable_globals=bool(mutable_globals),
|
|
306
|
+
mutable_globals=mutable_globals,
|
|
307
|
+
))
|
|
308
|
+
|
|
309
|
+
return PineType.FLOAT
|
|
310
|
+
|
|
311
|
+
if func_name == "security_lower_tf":
|
|
312
|
+
return self._handle_request_security_lower_tf(node)
|
|
313
|
+
|
|
314
|
+
# Fallback for other request.*
|
|
315
|
+
for arg in node.args:
|
|
316
|
+
self._visit(arg)
|
|
317
|
+
for val in node.kwargs.values():
|
|
318
|
+
self._visit(val)
|
|
319
|
+
return PineType.FLOAT
|
|
320
|
+
|
|
321
|
+
def _handle_request_security_lower_tf(self, node: FuncCall) -> PineType:
|
|
322
|
+
"""Lower ``request.security_lower_tf(symbol, timeframe, expression, ...)``.
|
|
323
|
+
|
|
324
|
+
TV signature differs from ``request.security``: there is no ``gaps``
|
|
325
|
+
or ``lookahead`` keyword (lower-TF emulation pins both off), and
|
|
326
|
+
the result is an ``array<T>`` with one element per synthesised
|
|
327
|
+
sub-bar of the current chart bar instead of a scalar T.
|
|
328
|
+
|
|
329
|
+
We piggy-back on the existing ``SecurityCallInfo`` plumbing — same
|
|
330
|
+
``sec_id`` allocation, same TA-binding-stack collection, same
|
|
331
|
+
mutable-global discovery — but flip ``is_lower_tf_array=True`` so
|
|
332
|
+
the codegen knows to emit a ``std::vector<T>`` accumulator that
|
|
333
|
+
gets cleared on sub-bar 0 and pushed-to per sub-bar.
|
|
334
|
+
|
|
335
|
+
UDT / color / string element types are deliberately rejected here
|
|
336
|
+
with a precise diagnostic; the runtime path only knows how to
|
|
337
|
+
accumulate ``double`` / ``int`` / ``bool``."""
|
|
338
|
+
param_names = ["symbol", "timeframe", "expression",
|
|
339
|
+
"ignore_invalid_symbol", "currency",
|
|
340
|
+
"ignore_invalid_timeframe", "calc_bars_count"]
|
|
341
|
+
|
|
342
|
+
unknown = set(node.kwargs) - set(param_names)
|
|
343
|
+
if unknown:
|
|
344
|
+
self._error(
|
|
345
|
+
"request.security_lower_tf has unknown parameter(s): "
|
|
346
|
+
+ ", ".join(sorted(unknown))
|
|
347
|
+
+ ". Supported parameters: " + ", ".join(param_names),
|
|
348
|
+
node.loc,
|
|
349
|
+
)
|
|
350
|
+
|
|
351
|
+
all_args = list(node.args)
|
|
352
|
+
for i, pname in enumerate(param_names):
|
|
353
|
+
if pname in node.kwargs:
|
|
354
|
+
while len(all_args) <= i:
|
|
355
|
+
all_args.append(None)
|
|
356
|
+
all_args[i] = node.kwargs[pname]
|
|
357
|
+
|
|
358
|
+
tf_node = all_args[1] if len(all_args) > 1 else None
|
|
359
|
+
expr_node = all_args[2] if len(all_args) > 2 else None
|
|
360
|
+
|
|
361
|
+
if expr_node is None:
|
|
362
|
+
self._error(
|
|
363
|
+
"request.security_lower_tf requires an expression argument",
|
|
364
|
+
node.loc,
|
|
365
|
+
)
|
|
366
|
+
|
|
367
|
+
if isinstance(expr_node, TupleLiteral):
|
|
368
|
+
self._error(
|
|
369
|
+
"request.security_lower_tf does not support tuple expressions yet. "
|
|
370
|
+
"Issue separate request.security_lower_tf calls for each series.",
|
|
371
|
+
node.loc,
|
|
372
|
+
)
|
|
373
|
+
|
|
374
|
+
for arg in node.args:
|
|
375
|
+
if arg is not None and arg is not expr_node:
|
|
376
|
+
self._visit(arg)
|
|
377
|
+
for k, val in node.kwargs.items():
|
|
378
|
+
if val is not None and val is not expr_node:
|
|
379
|
+
self._visit(val)
|
|
380
|
+
|
|
381
|
+
ta_start = len(self._ta_call_sites)
|
|
382
|
+
expr_pine_type = PineType.FLOAT
|
|
383
|
+
if expr_node is not None:
|
|
384
|
+
expr_pine_type = self._visit(expr_node)
|
|
385
|
+
ta_end = len(self._ta_call_sites)
|
|
386
|
+
security_ta_range = (ta_start, ta_end) if ta_end > ta_start else None
|
|
387
|
+
|
|
388
|
+
# Cache the resolved element type on the call node so the
|
|
389
|
+
# ``_type_spec_from_expr`` pass can map ``request.security_lower_tf``
|
|
390
|
+
# to ``array<T>`` without re-visiting the expression (which would
|
|
391
|
+
# double-allocate TA call sites for expressions like ``ta.ema``).
|
|
392
|
+
cached_anns = getattr(node, "annotations", None) or {}
|
|
393
|
+
cached_anns["lower_tf_element_pine_type"] = expr_pine_type
|
|
394
|
+
node.annotations = cached_anns
|
|
395
|
+
|
|
396
|
+
if expr_pine_type not in (PineType.FLOAT, PineType.INT, PineType.BOOL,
|
|
397
|
+
PineType.NA, PineType.UNKNOWN):
|
|
398
|
+
element_label = {
|
|
399
|
+
PineType.STRING: "string",
|
|
400
|
+
PineType.COLOR: "color",
|
|
401
|
+
}.get(expr_pine_type, str(expr_pine_type))
|
|
402
|
+
self._error(
|
|
403
|
+
"request.security_lower_tf element type '" + element_label
|
|
404
|
+
+ "' is not yet supported. Supported element types: float, int, bool. "
|
|
405
|
+
"UDT / color / string element types are out of scope.",
|
|
406
|
+
node.loc,
|
|
407
|
+
)
|
|
408
|
+
|
|
409
|
+
self._security_calls = getattr(self, "_security_calls", [])
|
|
410
|
+
sec_id = len(self._security_calls)
|
|
411
|
+
|
|
412
|
+
mutable_globals = tuple(sorted(self._collect_security_mutable_globals(expr_node)))
|
|
413
|
+
self._security_calls.append(SecurityCallInfo(
|
|
414
|
+
sec_id=sec_id,
|
|
415
|
+
timeframe=tf_node,
|
|
416
|
+
expression=expr_node,
|
|
417
|
+
returns_tuple=False,
|
|
418
|
+
tuple_size=0,
|
|
419
|
+
gaps=None,
|
|
420
|
+
lookahead=None,
|
|
421
|
+
ta_range=security_ta_range,
|
|
422
|
+
depends_on_mutable_globals=bool(mutable_globals),
|
|
423
|
+
mutable_globals=mutable_globals,
|
|
424
|
+
is_lower_tf_array=True,
|
|
425
|
+
))
|
|
426
|
+
|
|
427
|
+
# ``request.security_lower_tf`` returns an array; the value-level
|
|
428
|
+
# PineType remains UNKNOWN here so callers fall through to
|
|
429
|
+
# ``_type_spec_from_expr`` for the structured ``array<T>`` spec.
|
|
430
|
+
return PineType.UNKNOWN
|
|
431
|
+
|
|
432
|
+
def _handle_strategy_call(self, func_name: str, node: FuncCall) -> PineType:
|
|
433
|
+
"""Handle strategy.* function calls."""
|
|
434
|
+
for arg in node.args:
|
|
435
|
+
self._visit(arg)
|
|
436
|
+
for val in node.kwargs.values():
|
|
437
|
+
self._visit(val)
|
|
438
|
+
if func_name in ("convert_to_account", "convert_to_symbol", "default_entry_qty"):
|
|
439
|
+
return PineType.FLOAT
|
|
440
|
+
return PineType.VOID
|
|
441
|
+
|
|
442
|
+
# ------------------------------------------------------------------
|
|
443
|
+
# Input call handling
|
|
444
|
+
# ------------------------------------------------------------------
|
|
445
|
+
|
|
446
|
+
def _handle_input_call(self, node: FuncCall) -> PineType:
|
|
447
|
+
"""Handle input() calls without qualifier."""
|
|
448
|
+
# First arg is defval
|
|
449
|
+
if node.args:
|
|
450
|
+
defval = node.args[0]
|
|
451
|
+
self._visit(defval)
|
|
452
|
+
# Infer type from defval
|
|
453
|
+
if isinstance(defval, NumberLiteral):
|
|
454
|
+
if isinstance(defval.value, float):
|
|
455
|
+
return PineType.FLOAT
|
|
456
|
+
return PineType.INT
|
|
457
|
+
if isinstance(defval, StringLiteral):
|
|
458
|
+
return PineType.STRING
|
|
459
|
+
if isinstance(defval, BoolLiteral):
|
|
460
|
+
return PineType.BOOL
|
|
461
|
+
if isinstance(defval, Identifier):
|
|
462
|
+
# input(close) => source input
|
|
463
|
+
self._validate_plain_input_source(defval, node)
|
|
464
|
+
sym = self._symbols.resolve(defval.name)
|
|
465
|
+
if sym:
|
|
466
|
+
return sym.pine_type
|
|
467
|
+
return PineType.FLOAT
|
|
468
|
+
# Check kwargs for defval
|
|
469
|
+
if "defval" in node.kwargs:
|
|
470
|
+
defval = node.kwargs["defval"]
|
|
471
|
+
self._visit(defval)
|
|
472
|
+
if isinstance(defval, Identifier):
|
|
473
|
+
self._validate_plain_input_source(defval, node)
|
|
474
|
+
if isinstance(defval, NumberLiteral):
|
|
475
|
+
if isinstance(defval.value, float):
|
|
476
|
+
return PineType.FLOAT
|
|
477
|
+
return PineType.INT
|
|
478
|
+
if isinstance(defval, StringLiteral):
|
|
479
|
+
return PineType.STRING
|
|
480
|
+
if isinstance(defval, BoolLiteral):
|
|
481
|
+
return PineType.BOOL
|
|
482
|
+
|
|
483
|
+
# Visit remaining args
|
|
484
|
+
for arg in node.args[1:]:
|
|
485
|
+
self._visit(arg)
|
|
486
|
+
for val in node.kwargs.values():
|
|
487
|
+
self._visit(val)
|
|
488
|
+
|
|
489
|
+
return PineType.FLOAT # default
|
|
490
|
+
|
|
491
|
+
def _merge_input_params(self, member: str | None, node: FuncCall) -> dict[str, Any]:
|
|
492
|
+
"""Positional + kwargs merged like codegen (for input.* validation)."""
|
|
493
|
+
if member is None:
|
|
494
|
+
param_names = sigs.get_param_names(None, "input")
|
|
495
|
+
else:
|
|
496
|
+
param_names = sigs.get_param_names("input", member)
|
|
497
|
+
if not param_names:
|
|
498
|
+
return {}
|
|
499
|
+
merged: list[Any] = list(node.args)
|
|
500
|
+
for i, pname in enumerate(param_names):
|
|
501
|
+
if pname in node.kwargs:
|
|
502
|
+
while len(merged) <= i:
|
|
503
|
+
merged.append(None)
|
|
504
|
+
if i >= len(merged) or merged[i] is None:
|
|
505
|
+
merged[i] = node.kwargs[pname]
|
|
506
|
+
out: dict[str, Any] = {}
|
|
507
|
+
for i, pname in enumerate(param_names):
|
|
508
|
+
if i < len(merged) and merged[i] is not None:
|
|
509
|
+
out[pname] = merged[i]
|
|
510
|
+
for k, v in node.kwargs.items():
|
|
511
|
+
if k not in out:
|
|
512
|
+
out[k] = v
|
|
513
|
+
return out
|
|
514
|
+
|
|
515
|
+
def _input_enum_type_name(self, node: FuncCall) -> str | None:
|
|
516
|
+
"""If this is input.enum(...) with Enum.member defval, return the enum type name."""
|
|
517
|
+
callee = node.callee
|
|
518
|
+
if not isinstance(callee, MemberAccess):
|
|
519
|
+
return None
|
|
520
|
+
if not isinstance(callee.object, Identifier) or callee.object.name != "input":
|
|
521
|
+
return None
|
|
522
|
+
if callee.member != "enum":
|
|
523
|
+
return None
|
|
524
|
+
merged = self._merge_input_params("enum", node)
|
|
525
|
+
dv = merged.get("defval")
|
|
526
|
+
if dv is None and node.args:
|
|
527
|
+
dv = node.args[0]
|
|
528
|
+
if isinstance(dv, MemberAccess) and isinstance(dv.object, Identifier):
|
|
529
|
+
return dv.object.name
|
|
530
|
+
return None
|
|
531
|
+
|
|
532
|
+
def _validate_plain_input_source(self, defval: ASTNode, node: FuncCall) -> None:
|
|
533
|
+
"""Warn when plain input() uses a series defval unlike TV built-ins."""
|
|
534
|
+
if isinstance(defval, Identifier):
|
|
535
|
+
self._warn_if_unknown_source_id(defval.name, defval, node)
|
|
536
|
+
|
|
537
|
+
def _validate_input_member_tv(self, member: str, node: FuncCall) -> None:
|
|
538
|
+
"""TradingView-style const checks for input.* (warnings only)."""
|
|
539
|
+
merged = self._merge_input_params(member, node)
|
|
540
|
+
defval = merged.get("defval")
|
|
541
|
+
if defval is None and node.args:
|
|
542
|
+
defval = node.args[0]
|
|
543
|
+
|
|
544
|
+
if member == "source" and defval is not None:
|
|
545
|
+
if isinstance(defval, Identifier):
|
|
546
|
+
self._warn_if_unknown_source_id(defval.name, defval, node)
|
|
547
|
+
else:
|
|
548
|
+
self._warn(
|
|
549
|
+
"input.source defval is not a native chart series (open, high, low, close, …); "
|
|
550
|
+
"complex indicators or expressions are not supported in PineForge.",
|
|
551
|
+
self._input_diag_loc(node, defval),
|
|
552
|
+
)
|
|
553
|
+
return
|
|
554
|
+
|
|
555
|
+
if member == "timeframe" and isinstance(defval, StringLiteral):
|
|
556
|
+
if not tv_in.is_valid_timeframe_string(defval.value):
|
|
557
|
+
self._warn(
|
|
558
|
+
f"input.timeframe defval {defval.value!r} is not a typical Pine timeframe string.",
|
|
559
|
+
self._input_diag_loc(node, defval),
|
|
560
|
+
)
|
|
561
|
+
return
|
|
562
|
+
|
|
563
|
+
if member == "session" and isinstance(defval, StringLiteral):
|
|
564
|
+
if not tv_in.is_plausible_session_string(defval.value):
|
|
565
|
+
self._warn(
|
|
566
|
+
f"input.session defval {defval.value!r} may be invalid (expected e.g. "
|
|
567
|
+
"'24x7', '0930-1600', or weekday flags).",
|
|
568
|
+
self._input_diag_loc(node, defval),
|
|
569
|
+
)
|
|
570
|
+
return
|
|
571
|
+
|
|
572
|
+
if member == "string":
|
|
573
|
+
opts = merged.get("options")
|
|
574
|
+
if isinstance(opts, TupleLiteral):
|
|
575
|
+
literals: list[str] = []
|
|
576
|
+
non_const = False
|
|
577
|
+
for el in opts.elements:
|
|
578
|
+
if isinstance(el, StringLiteral):
|
|
579
|
+
literals.append(el.value)
|
|
580
|
+
else:
|
|
581
|
+
non_const = True
|
|
582
|
+
break
|
|
583
|
+
if not non_const and literals and isinstance(defval, StringLiteral):
|
|
584
|
+
if defval.value not in literals:
|
|
585
|
+
self._warn(
|
|
586
|
+
"input.string defval is not among the options list values.",
|
|
587
|
+
self._input_diag_loc(node, defval),
|
|
588
|
+
)
|
|
589
|
+
return
|
|
590
|
+
|
|
591
|
+
if member == "enum" and defval is not None:
|
|
592
|
+
if isinstance(defval, MemberAccess) and isinstance(defval.object, Identifier):
|
|
593
|
+
ename = defval.object.name
|
|
594
|
+
emem = defval.member
|
|
595
|
+
members = self._enum_defs.get(ename)
|
|
596
|
+
if members is None:
|
|
597
|
+
self._error(
|
|
598
|
+
f"Enum '{ename}' must be declared above this input.enum() call "
|
|
599
|
+
"(or the name is misspelled).",
|
|
600
|
+
self._input_diag_loc(node, defval),
|
|
601
|
+
)
|
|
602
|
+
if emem not in members:
|
|
603
|
+
self._warn(
|
|
604
|
+
f"input.enum defval {ename}.{emem} is not a member of enum {ename}.",
|
|
605
|
+
self._input_diag_loc(node, defval),
|
|
606
|
+
)
|
|
607
|
+
|
|
608
|
+
def _handle_input_member_call(self, member: str, node: FuncCall) -> PineType:
|
|
609
|
+
"""Handle input.int(), input.float(), etc."""
|
|
610
|
+
for arg in node.args:
|
|
611
|
+
self._visit(arg)
|
|
612
|
+
for val in node.kwargs.values():
|
|
613
|
+
self._visit(val)
|
|
614
|
+
|
|
615
|
+
self._validate_input_member_tv(member, node)
|
|
616
|
+
|
|
617
|
+
type_map = {
|
|
618
|
+
"int": PineType.INT,
|
|
619
|
+
"float": PineType.FLOAT,
|
|
620
|
+
"bool": PineType.BOOL,
|
|
621
|
+
"string": PineType.STRING,
|
|
622
|
+
"source": PineType.FLOAT,
|
|
623
|
+
"color": PineType.COLOR,
|
|
624
|
+
"enum": PineType.INT,
|
|
625
|
+
"session": PineType.STRING,
|
|
626
|
+
"timeframe": PineType.STRING,
|
|
627
|
+
"time": PineType.INT,
|
|
628
|
+
"symbol": PineType.STRING,
|
|
629
|
+
"price": PineType.FLOAT,
|
|
630
|
+
"text_area": PineType.STRING,
|
|
631
|
+
}
|
|
632
|
+
return type_map.get(member, PineType.FLOAT)
|
|
633
|
+
|
|
634
|
+
def _check_input_call(self, node: FuncCall) -> tuple[PineType, bool, Any] | None:
|
|
635
|
+
"""Check if a FuncCall is an input call and extract default value.
|
|
636
|
+
Returns (type, is_const, const_value) or None.
|
|
637
|
+
"""
|
|
638
|
+
callee = node.callee
|
|
639
|
+
|
|
640
|
+
if isinstance(callee, Identifier) and callee.name == "input":
|
|
641
|
+
# input(defval, ...)
|
|
642
|
+
defval = self._extract_defval(node)
|
|
643
|
+
ptype = self._handle_input_call(node)
|
|
644
|
+
return (ptype, True, defval)
|
|
645
|
+
|
|
646
|
+
if isinstance(callee, MemberAccess):
|
|
647
|
+
if isinstance(callee.object, Identifier) and callee.object.name == "input":
|
|
648
|
+
member = callee.member
|
|
649
|
+
defval = self._extract_defval(node)
|
|
650
|
+
ptype = self._handle_input_member_call(member, node)
|
|
651
|
+
return (ptype, True, defval)
|
|
652
|
+
|
|
653
|
+
return None
|
|
654
|
+
|
|
655
|
+
def _extract_defval(self, node: FuncCall) -> Any:
    """Return the literal default value of an input call, if present.

    The first positional argument wins, because that is where ``defval`` sits
    positionally. Otherwise the ``defval`` keyword argument is used. The chosen
    expression is converted with ``_extract_literal_value``.

    Returns:
        The converted value, or ``None`` when neither form is present.
    """
    if node.args:
        source_expr = node.args[0]
    elif "defval" in node.kwargs:
        source_expr = node.kwargs["defval"]
    else:
        return None
    return self._extract_literal_value(source_expr)
|
|
664
|
+
|
|
665
|
+
# ------------------------------------------------------------------
|
|
666
|
+
# matrix method dispatch
|
|
667
|
+
# ------------------------------------------------------------------
|
|
668
|
+
|
|
669
|
+
def _handle_matrix_method(self, member: str, recv_spec) -> PineType:
    """Map a matrix.<member>(receiver, ...) call to its PineType.

    ``recv_spec`` is the receiver's :class:`TypeSpec` (kind ``"matrix"``).
    ``_type_spec_from_expr`` (in ``analyzer/types.py``) already returns
    the correct structured ``TypeSpec`` for matrix-method calls, so codegen
    downstream is unaffected. This helper exists so that the smaller
    :class:`PineType` enum surface used by ``_visit_FuncCall`` and
    ``_visit_VarDecl`` no longer collapses element types to ``VOID``.

    Phase D Task 2: previously the general MemberAccess arm in
    :meth:`_visit_FuncCall` returned ``PineType.VOID`` for matrix-method
    calls, so ``v = m.get(0, 0)`` typed ``v`` as ``VOID`` even on
    ``matrix<int>``.

    Returns:
        The element type for ``get``. FLOAT for the numeric reducers, INT for
        ``rank`` and the size queries, BOOL for the ``is_*`` predicates. VOID
        for a non-matrix receiver, ``row``/``col`` and any other method
        (mutators / matrix-returning), in which case callers should read the
        Symbol's ``type_spec``.
    """
    # Imported only for the local annotation below.
    from ..symbols import TypeSpec

    if recv_spec is None or recv_spec.kind != "matrix":
        return PineType.VOID

    elem: TypeSpec | None = recv_spec.element

    # Element-typed return paths
    if member == "get":
        return self._element_pine_type(elem)
    if member in ("row", "col"):
        # These yield array<elem>, preserved via TypeSpec.array(elem). PineType
        # has no array slot, so VOID tells callers to consult type_spec instead
        # (UNKNOWN is a poor fit because downstream defaults to FLOAT for it).
        return PineType.VOID

    # Scalar-return methods (numeric matrix only — codegen rejects these on
    # non-float matrices via MATRIX_NUMERIC_ONLY). ``median`` is a scalar
    # reducer like ``avg`` and previously fell through to VOID.
    # NOTE(review): Pine documents ``matrix.sum(id1, id2)`` as an element-wise
    # sum returning a matrix; FLOAT is kept here to match existing codegen.
    # Confirm which semantics codegen implements.
    if member in ("det", "trace", "sum", "avg", "median", "min", "max", "mode"):
        return PineType.FLOAT
    # Pine documents ``matrix.rank()`` as returning an int (it was typed FLOAT).
    if member in ("rank", "elements_count", "rows", "columns"):
        return PineType.INT
    if member in (
        "is_square", "is_identity", "is_diagonal", "is_antidiagonal",
        "is_symmetric", "is_antisymmetric", "is_triangular",
        "is_stochastic", "is_binary", "is_zero",
    ):
        return PineType.BOOL

    # Mutators / matrix-returning methods don't carry a usable scalar
    # PineType; for those, codegen reads type_spec on the LHS Symbol.
    return PineType.VOID
|
|
720
|
+
|
|
721
|
+
@staticmethod
def _element_pine_type(elem) -> PineType:
    """Convert an element ``TypeSpec`` to a ``PineType`` for matrix/array get().

    This is the inverse direction of :meth:`TypeHelper._pine_type_to_spec`.
    Only primitive element specs named int/float/bool/string/color have a
    PineType counterpart. ``None``, any other primitive name, UDTs and
    nested collections (array/map/matrix) all yield ``PineType.VOID``. In the
    VOID case, callers should consult ``type_spec`` / ``udt_type_name`` on
    the resulting Symbol instead.
    """
    if elem is None or elem.kind != "primitive":
        return PineType.VOID

    primitive_types = {
        "int": PineType.INT,
        "float": PineType.FLOAT,
        "bool": PineType.BOOL,
        "string": PineType.STRING,
        "color": PineType.COLOR,
    }
    # A missing/None name is looked up as "" and therefore maps to VOID.
    return primitive_types.get(elem.name or "", PineType.VOID)
|
|
744
|
+
|
|
745
|
+
# ------------------------------------------------------------------
|
|
746
|
+
# fixnan handling
|
|
747
|
+
# ------------------------------------------------------------------
|
|
748
|
+
|
|
749
|
+
def _handle_fixnan_call(self, node: FuncCall) -> PineType:
    """Analyze a ``fixnan(...)`` call and register a state slot for it.

    Every positional argument is visited. The type of the last one becomes
    both the slot's ``pine_type`` and the return value, or FLOAT when the call
    has no positional arguments. Each call bumps ``_fixnan_counter`` and
    appends a ``FixnanCallSite`` named ``_prev_fixnan_<counter>`` to
    ``_fixnan_sites``.

    NOTE(review): keyword arguments (e.g. ``source=``) are not visited here.
    Confirm the caller covers them.
    """
    visited_types = [self._visit(arg) for arg in node.args]
    result_type = visited_types[-1] if visited_types else PineType.FLOAT

    self._fixnan_counter += 1
    self._fixnan_sites.append(
        FixnanCallSite(
            member_name=f"_prev_fixnan_{self._fixnan_counter}",
            pine_type=result_type,
        )
    )
    return result_type
|
|
763
|
+
|
|
764
|
+
# ------------------------------------------------------------------
|
|
765
|
+
# User-defined function calls
|
|
766
|
+
# ------------------------------------------------------------------
|
|
767
|
+
|
|
768
|
+
def _handle_user_func_call(self, func_name: str, node: FuncCall) -> PineType:
    """Handle calls to user-defined functions.

    Steps:
      1. Visit every positional argument. The first ``len(params)`` argument
         types become the parameter types, padded with ``PineType.UNKNOWN``.
      2. Take the cached return type from ``_func_return_types`` (default
         FLOAT). When it is UNKNOWN/VOID, infer STRING, then FLOAT, then INT
         from the parameter types.
      3. Register bar-field identifiers passed to series parameters in
         ``_series_bar_fields``.
      4. If the function owns TA call sites or series/var members, give this
         call a call-site index. Then specialize the function's TA constructor
         arguments: in place for call site 0, as ``<member>_cs<N>`` clones for
         later call sites.
      5. Create the function's ``FuncInfo``, or fill in missing type info on
         an existing one.

    Parameters are bound to call-site expressions positionally first. Keyword
    arguments that name a parameter not filled positionally are bound too.
    NOTE(review): keyword-argument expressions are bound but not visited
    here. Confirm the caller visits them.

    Args:
        func_name: key into ``_func_defs``.
        node: the call expression.

    Returns:
        The (possibly inferred) return type of the call.
    """
    func_def = self._func_defs[func_name]
    params = func_def.params

    # Visit the call args
    arg_types = [self._visit(arg) for arg in node.args]

    # Determine param types from call-site args
    param_types = arg_types[:len(params)]
    param_types.extend([PineType.UNKNOWN] * (len(params) - len(param_types)))

    # Determine return type: for now, use the cached return type from the
    # initial analysis rather than re-analyzing the body with these params.
    return_type = self._func_return_types.get(func_name, PineType.FLOAT)

    # If the return type was UNKNOWN or VOID, infer from param types
    if return_type in (PineType.UNKNOWN, PineType.VOID):
        if any(t == PineType.STRING for t in param_types):
            return_type = PineType.STRING
        elif any(t == PineType.FLOAT for t in param_types):
            return_type = PineType.FLOAT
        elif any(t == PineType.INT for t in param_types):
            return_type = PineType.INT

    # Map each parameter to the call-site expression bound to it.
    bound_args = {}
    for p_idx, param_name in enumerate(params):
        if p_idx < len(node.args):
            bound_args[param_name] = node.args[p_idx]
        elif param_name in node.kwargs:
            bound_args[param_name] = node.kwargs[param_name]

    # Bar fields passed to series params are registered as series_bar_fields
    # so that codegen can create Series<double> members for them.
    func_sv = self._func_series_vars.get(func_name, set())
    if func_sv:
        for param_name, arg in bound_args.items():
            if (
                param_name in func_sv
                and isinstance(arg, Identifier)
                and arg.name in BAR_FIELDS
            ):
                self._series_bar_fields.add(arg.name)

    # Per-call-site cloning: if this function has TA calls or series vars,
    # track call sites so codegen can create per-call-site variants. This
    # prevents shared state corruption when the function is called multiple
    # times per bar.
    has_ta = func_name in self._func_ta_ranges
    has_series = func_name in self._func_series_vars or func_name in self._func_var_members
    if has_ta or has_series:
        cs_idx = self._func_call_site_count.get(func_name, 0)
        self._func_call_site_count[func_name] = cs_idx + 1
        self._func_call_cs_map[id(node)] = (func_name, cs_idx)

        # Parameter -> call-site argument string mapping
        param_arg_map: dict[str, str] = {
            name: self._expr_to_str(expr) for name, expr in bound_args.items()
        }

        # Clone TA call sites (only if function has TA ranges)
        if has_ta:
            start, end = self._func_ta_ranges[func_name]
            if cs_idx == 0:
                # First call site: save the original param-based ctor_args for
                # future cloning, then resolve them to this call's values.
                for i in range(start, end):
                    site = self._ta_call_sites[i]
                    if not hasattr(site, "_orig_ctor_args"):
                        site._orig_ctor_args = site.ctor_args[:]
                    site.ctor_args = [
                        self._subst_call_site_params(a, param_arg_map)
                        for a in site._orig_ctor_args
                    ]
            else:
                # Later call sites: clone from the saved original param names,
                # substituted with this call site's arguments. Clones are
                # appended past ``end``, so the range indexes stay valid.
                for i in range(start, end):
                    orig = self._ta_call_sites[i]
                    orig_args = getattr(orig, "_orig_ctor_args", orig.ctor_args)
                    cloned = TACallSite(
                        member_name=f"{orig.member_name}_cs{cs_idx}",
                        class_name=orig.class_name,
                        ctor_args=[
                            self._subst_call_site_params(a, param_arg_map)
                            for a in orig_args
                        ],
                        compute_args=orig.compute_args[:],
                        returns_tuple=orig.returns_tuple,
                        node=orig.node,
                        is_static=orig.is_static,
                    )
                    self._ta_call_sites.append(cloned)

    # Create or update FuncInfo
    is_tuple = self._func_returns_tuple.get(func_name, False)
    tuple_count = self._func_tuple_element_count.get(func_name, 0)
    # Forward UDT-return inference (set in _visit_FuncDef) so codegen can
    # emit the struct return type. Probe: udt-method-probe-20.
    udt_ret = self._func_udt_return_types.get(func_name)
    fi = next((info for info in self._func_infos if info.name == func_name), None)
    if fi is None:
        self._func_infos.append(
            FuncInfo(
                name=func_name,
                param_types=param_types,
                return_type=return_type,
                node=func_def,
                returns_tuple=is_tuple,
                tuple_element_count=tuple_count,
                udt_return_type=udt_ret,
            )
        )
    else:
        # Update with better type info if available
        unresolved = (PineType.UNKNOWN, PineType.VOID)
        if fi.return_type in unresolved and return_type not in unresolved:
            fi.return_type = return_type
        for i, pt in enumerate(param_types):
            if i < len(fi.param_types) and fi.param_types[i] == PineType.UNKNOWN:
                fi.param_types[i] = pt
        if fi.udt_return_type is None and udt_ret is not None:
            fi.udt_return_type = udt_ret

    return return_type

@staticmethod
def _subst_call_site_params(arg: str, pmap: dict[str, str]) -> str:
    """Replace whole-word parameter names in ``arg`` with call-site text.

    Handles both exact matches ('len' -> 'len3') and parameter names inside
    expressions ('len / 2' -> 'len3 / 2'). The substitution is a single regex
    pass, so inserted text is never re-scanned: with ``{"len": "n", "n": "3"}``,
    ``"len + n"`` becomes ``"n + 3"``. Replacement text is inserted verbatim
    through a callable, so backslashes in an argument are not treated as
    regex escapes.
    """
    import re

    if not pmap:
        return arg
    # Longest names first so the alternation prefers the longest candidate.
    alternatives = "|".join(
        re.escape(name) for name in sorted(pmap, key=len, reverse=True)
    )
    pattern = re.compile(rf"\b(?:{alternatives})\b")
    return pattern.sub(lambda match: pmap[match.group(0)], arg)
|