signalflow-trading 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- signalflow/__init__.py +21 -0
- signalflow/analytics/__init__.py +0 -0
- signalflow/core/__init__.py +46 -0
- signalflow/core/base_mixin.py +232 -0
- signalflow/core/containers/__init__.py +21 -0
- signalflow/core/containers/order.py +216 -0
- signalflow/core/containers/portfolio.py +211 -0
- signalflow/core/containers/position.py +296 -0
- signalflow/core/containers/raw_data.py +167 -0
- signalflow/core/containers/raw_data_view.py +169 -0
- signalflow/core/containers/signals.py +198 -0
- signalflow/core/containers/strategy_state.py +147 -0
- signalflow/core/containers/trade.py +112 -0
- signalflow/core/decorators.py +103 -0
- signalflow/core/enums.py +270 -0
- signalflow/core/registry.py +322 -0
- signalflow/core/rolling_aggregator.py +362 -0
- signalflow/core/signal_transforms/__init__.py +5 -0
- signalflow/core/signal_transforms/base_signal_transform.py +186 -0
- signalflow/data/__init__.py +11 -0
- signalflow/data/raw_data_factory.py +225 -0
- signalflow/data/raw_store/__init__.py +7 -0
- signalflow/data/raw_store/base.py +271 -0
- signalflow/data/raw_store/duckdb_stores.py +696 -0
- signalflow/data/source/__init__.py +10 -0
- signalflow/data/source/base.py +300 -0
- signalflow/data/source/binance.py +442 -0
- signalflow/data/strategy_store/__init__.py +8 -0
- signalflow/data/strategy_store/base.py +278 -0
- signalflow/data/strategy_store/duckdb.py +409 -0
- signalflow/data/strategy_store/schema.py +36 -0
- signalflow/detector/__init__.py +7 -0
- signalflow/detector/adapter/__init__.py +5 -0
- signalflow/detector/adapter/pandas_detector.py +46 -0
- signalflow/detector/base.py +390 -0
- signalflow/detector/sma_cross.py +105 -0
- signalflow/feature/__init__.py +16 -0
- signalflow/feature/adapter/__init__.py +5 -0
- signalflow/feature/adapter/pandas_feature_extractor.py +54 -0
- signalflow/feature/base.py +330 -0
- signalflow/feature/feature_set.py +286 -0
- signalflow/feature/oscillator/__init__.py +5 -0
- signalflow/feature/oscillator/rsi_extractor.py +42 -0
- signalflow/feature/pandasta/__init__.py +10 -0
- signalflow/feature/pandasta/pandas_ta_extractor.py +141 -0
- signalflow/feature/pandasta/top_pandasta_extractors.py +64 -0
- signalflow/feature/smoother/__init__.py +5 -0
- signalflow/feature/smoother/sma_extractor.py +46 -0
- signalflow/strategy/__init__.py +9 -0
- signalflow/strategy/broker/__init__.py +15 -0
- signalflow/strategy/broker/backtest.py +172 -0
- signalflow/strategy/broker/base.py +186 -0
- signalflow/strategy/broker/executor/__init__.py +9 -0
- signalflow/strategy/broker/executor/base.py +35 -0
- signalflow/strategy/broker/executor/binance_spot.py +12 -0
- signalflow/strategy/broker/executor/virtual_spot.py +81 -0
- signalflow/strategy/broker/realtime_spot.py +12 -0
- signalflow/strategy/component/__init__.py +9 -0
- signalflow/strategy/component/base.py +65 -0
- signalflow/strategy/component/entry/__init__.py +7 -0
- signalflow/strategy/component/entry/fixed_size.py +57 -0
- signalflow/strategy/component/entry/signal.py +127 -0
- signalflow/strategy/component/exit/__init__.py +5 -0
- signalflow/strategy/component/exit/time_based.py +47 -0
- signalflow/strategy/component/exit/tp_sl.py +80 -0
- signalflow/strategy/component/metric/__init__.py +8 -0
- signalflow/strategy/component/metric/main_metrics.py +181 -0
- signalflow/strategy/runner/__init__.py +8 -0
- signalflow/strategy/runner/backtest_runner.py +208 -0
- signalflow/strategy/runner/base.py +19 -0
- signalflow/strategy/runner/optimized_backtest_runner.py +178 -0
- signalflow/strategy/runner/realtime_runner.py +0 -0
- signalflow/target/__init__.py +14 -0
- signalflow/target/adapter/__init__.py +5 -0
- signalflow/target/adapter/pandas_labeler.py +45 -0
- signalflow/target/base.py +409 -0
- signalflow/target/fixed_horizon_labeler.py +93 -0
- signalflow/target/static_triple_barrier.py +162 -0
- signalflow/target/triple_barrier.py +188 -0
- signalflow/utils/__init__.py +7 -0
- signalflow/utils/import_utils.py +11 -0
- signalflow/utils/tune_utils.py +19 -0
- signalflow/validator/__init__.py +6 -0
- signalflow/validator/base.py +139 -0
- signalflow/validator/sklearn_validator.py +527 -0
- signalflow_trading-0.2.1.dist-info/METADATA +149 -0
- signalflow_trading-0.2.1.dist-info/RECORD +90 -0
- signalflow_trading-0.2.1.dist-info/WHEEL +5 -0
- signalflow_trading-0.2.1.dist-info/licenses/LICENSE +21 -0
- signalflow_trading-0.2.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,409 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Any, ClassVar
|
|
6
|
+
|
|
7
|
+
import polars as pl
|
|
8
|
+
|
|
9
|
+
from signalflow.core import RawDataType, SfComponentType, SignalType, Signals
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class Labeler(ABC):
    """Base class for Polars-only signal labeling.

    Assigns forward-looking labels to historical data based on future price
    movement. Labels are computed per-pair with length-preserving operations.

    Key concepts:
        - Forward-looking: Labels depend on future data (not available in live trading)
        - Per-pair processing: Each pair labeled independently
        - Length-preserving: Output has same row count as input
        - Signal masking: Optionally label only at signal timestamps

    Public API:
        - compute(): Main entry point (handles grouping, filtering, projection)
        - compute_group(): Per-pair labeling logic (must implement)

    Common labeling strategies:
        - Fixed horizon: Label based on return over N bars
        - Triple barrier: Label based on first hit of profit/loss/time barrier
        - Quantile-based: Label based on return quantiles

    Attributes:
        component_type (ClassVar[SfComponentType]): Always LABELER for registry.
        raw_data_type (RawDataType): Type of raw data. Default: SPOT.
        pair_col (str): Trading pair column. Default: "pair".
        ts_col (str): Timestamp column. Default: "timestamp".
        keep_input_columns (bool): Keep all input columns. Default: False.
        output_columns (list[str] | None): Specific columns to output. Default: None.
        filter_signal_type (SignalType | None): Filter to specific signal type. Default: None.
        mask_to_signals (bool): Mask labels to signal timestamps only. Default: True.
        out_col (str): Output label column name. Default: "label".
        include_meta (bool): Include metadata columns. Default: False.
        meta_columns (tuple[str, ...]): Metadata column names. Default: ("t_hit", "ret").

    Example:
        ```python
        from signalflow.target import Labeler
        from signalflow.core import SignalType
        import polars as pl

        class FixedHorizonLabeler(Labeler):
            '''Label based on fixed-horizon return'''

            def __init__(self, horizon: int = 10, threshold: float = 0.01):
                super().__init__()
                self.horizon = horizon
                self.threshold = threshold

            def compute_group(self, group_df, data_context=None):
                # Compute forward return, then threshold it into a label.
                labels = group_df.with_columns([
                    pl.col("close").shift(-self.horizon).alias("future_close")
                ]).with_columns([
                    ((pl.col("future_close") / pl.col("close")) - 1).alias("return")
                ]).with_columns([
                    pl.when(pl.col("return") > self.threshold)
                    .then(pl.lit(SignalType.RISE.value))
                    .when(pl.col("return") < -self.threshold)
                    .then(pl.lit(SignalType.FALL.value))
                    .otherwise(pl.lit(SignalType.NONE.value))
                    .alias("label")
                ])

                return labels

        # Usage
        labeler = FixedHorizonLabeler(horizon=10, threshold=0.01)
        labeled = labeler.compute(ohlcv_df, signals=signals)
        ```

    Note:
        compute_group() must preserve row count (no filtering).
        All timestamps must be timezone-naive.
        Signal masking requires mask_to_signals=True and signal_keys in context.

    See Also:
        FixedHorizonLabeler: Simple fixed-horizon implementation.
        TripleBarrierLabeler: Three-barrier labeling strategy.
    """

    # Registry discriminator; not a dataclass field.
    component_type: ClassVar[SfComponentType] = SfComponentType.LABELER
    raw_data_type: RawDataType = RawDataType.SPOT

    # Key columns used for sorting, grouping and joins.
    pair_col: str = "pair"
    ts_col: str = "timestamp"

    # Output projection controls.
    keep_input_columns: bool = False
    output_columns: list[str] | None = None
    filter_signal_type: SignalType | None = None

    # Signal-masking / metadata controls.
    mask_to_signals: bool = True
    out_col: str = "label"
    include_meta: bool = False
    meta_columns: tuple[str, ...] = ("t_hit", "ret")

    def compute(
        self,
        df: pl.DataFrame,
        signals: Signals | None = None,
        data_context: dict[str, Any] | None = None,
    ) -> pl.DataFrame:
        """Compute labels for input DataFrame.

        Main entry point - handles validation, filtering, grouping, and projection.

        Processing steps:
            1. Validate input schema
            2. Sort by (pair, timestamp)
            3. (optional) Filter to specific signal type
            4. Group by pair and apply compute_group()
            5. Validate output (length-preserving)
            6. Project to output columns

        Args:
            df (pl.DataFrame): Input data with OHLCV and required columns.
            signals (Signals | None): Signals for filtering/masking.
            data_context (dict[str, Any] | None): Additional context.

        Returns:
            pl.DataFrame: Labeled data with columns:
                - pair, timestamp (always included)
                - label column(s) (as specified by out_col)
                - (optional) metadata columns

        Raises:
            TypeError: If df not pl.DataFrame or compute_group returns wrong type.
            ValueError: If compute_group changes row count or columns missing.

        Example:
            ```python
            # Basic labeling
            labeled = labeler.compute(ohlcv_df)

            # Filter to one signal type: ``filter_signal_type`` is a dataclass
            # field configured at construction, NOT a compute() argument.
            labeler = MyLabeler(filter_signal_type=SignalType.RISE)
            labeled = labeler.compute(ohlcv_df, signals=signals)

            # With masking context
            labeled = labeler.compute(
                ohlcv_df,
                signals=signals,
                data_context={"signal_keys": signal_timestamps_df}
            )
            ```
        """
        if not isinstance(df, pl.DataFrame):
            raise TypeError(f"{self.__class__.__name__}.compute expects pl.DataFrame, got {type(df)}")
        return self._compute_pl(df=df, signals=signals, data_context=data_context)

    def _compute_pl(
        self,
        df: pl.DataFrame,
        signals: Signals | None,
        data_context: dict[str, Any] | None,
    ) -> pl.DataFrame:
        """Internal Polars-based computation.

        Orchestrates validation, filtering, grouping, and projection.

        Args:
            df (pl.DataFrame): Input data.
            signals (Signals | None): Optional signals.
            data_context (dict[str, Any] | None): Optional context.

        Returns:
            pl.DataFrame: Labeled data.
        """
        self._validate_input_pl(df)
        df0 = df.sort([self.pair_col, self.ts_col])

        # Optional pre-filter: keep only rows at timestamps of the requested
        # signal type (both the signals object AND the configured type needed).
        if signals is not None and self.filter_signal_type is not None:
            s_pl = self._signals_to_pl(signals)
            df0 = self._filter_by_signals_pl(df0, s_pl, self.filter_signal_type)

        # Snapshot columns BEFORE labeling so new columns can be auto-detected.
        input_cols = set(df0.columns)

        def _wrapped(g: pl.DataFrame) -> pl.DataFrame:
            # Enforce the compute_group contract: pl.DataFrame out,
            # same row count as in (length-preserving).
            out = self.compute_group(g, data_context=data_context)
            if not isinstance(out, pl.DataFrame):
                raise TypeError(f"{self.__class__.__name__}.compute_group must return pl.DataFrame")
            if out.height != g.height:
                raise ValueError(
                    f"{self.__class__.__name__}: len(output_group)={out.height} != len(input_group)={g.height}"
                )
            return out

        out = (
            df0.group_by(self.pair_col, maintain_order=True)
            .map_groups(_wrapped)
            .sort([self.pair_col, self.ts_col])
        )

        if self.keep_input_columns:
            return out

        # Either project explicitly configured columns, or auto-detect the
        # columns that compute_group added on top of the input schema.
        label_cols = (
            sorted(set(out.columns) - input_cols)
            if self.output_columns is None
            else list(self.output_columns)
        )

        keep_cols = [self.pair_col, self.ts_col] + label_cols
        missing = [c for c in keep_cols if c not in out.columns]
        if missing:
            raise ValueError(f"Projection error, missing columns: {missing}")

        return out.select(keep_cols)

    def _signals_to_pl(self, signals: Signals) -> pl.DataFrame:
        """Convert Signals to Polars DataFrame.

        Args:
            signals (Signals): Signals container.

        Returns:
            pl.DataFrame: Signals as DataFrame.

        Raises:
            TypeError: If Signals.value is not pl.DataFrame.
        """
        s = signals.value
        if isinstance(s, pl.DataFrame):
            return s
        raise TypeError(f"Unsupported Signals.value type: {type(s)}")

    def _filter_by_signals_pl(
        self, df: pl.DataFrame, s: pl.DataFrame, signal_type: SignalType
    ) -> pl.DataFrame:
        """Filter input to rows matching signal timestamps.

        Inner join with signal timestamps of specific type.

        Args:
            df (pl.DataFrame): Input data.
            s (pl.DataFrame): Signals DataFrame.
            signal_type (SignalType): Signal type to filter.

        Returns:
            pl.DataFrame: Filtered data (only rows at signal timestamps).

        Raises:
            ValueError: If signals missing required columns.
        """
        required = {self.pair_col, self.ts_col, "signal_type"}
        missing = required - set(s.columns)
        if missing:
            raise ValueError(f"Signals missing columns: {sorted(missing)}")

        # Deduplicate (pair, ts) keys so the inner join cannot multiply rows.
        s_f = (
            s.filter(pl.col("signal_type") == signal_type.value)
            .select([self.pair_col, self.ts_col])
            .unique(subset=[self.pair_col, self.ts_col])
        )
        return df.join(s_f, on=[self.pair_col, self.ts_col], how="inner")

    @abstractmethod
    def compute_group(
        self, group_df: pl.DataFrame, data_context: dict[str, Any] | None
    ) -> pl.DataFrame:
        """Compute labels for single pair group.

        Core labeling logic - must be implemented by subclasses.

        CRITICAL: Must preserve row count (len(output) == len(input)).
        No filtering allowed inside compute_group.

        Args:
            group_df (pl.DataFrame): Single pair's data, sorted by timestamp.
            data_context (dict[str, Any] | None): Additional context.

        Returns:
            pl.DataFrame: Same length as input with added label columns.

        Example:
            ```python
            def compute_group(self, group_df, data_context=None):
                # Compute 10-bar forward return
                return group_df.with_columns([
                    pl.col("close").shift(-10).alias("future_close")
                ]).with_columns([
                    ((pl.col("future_close") / pl.col("close")) - 1).alias("return"),
                    pl.when((pl.col("future_close") / pl.col("close") - 1) > 0.01)
                    .then(pl.lit(SignalType.RISE.value))
                    .otherwise(pl.lit(SignalType.NONE.value))
                    .alias("label")
                ])
            ```

        Note:
            Output must have same height as input (length-preserving).
            Use shift(-n) for forward-looking operations.
            Last N bars will have null labels (no future data).
        """
        raise NotImplementedError

    def _validate_input_pl(self, df: pl.DataFrame) -> None:
        """Validate input DataFrame schema.

        Args:
            df (pl.DataFrame): Input to validate.

        Raises:
            ValueError: If required columns missing.
        """
        missing = [c for c in (self.pair_col, self.ts_col) if c not in df.columns]
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

    def _apply_signal_mask(
        self,
        df: pl.DataFrame,
        data_context: dict[str, Any],
        group_df: pl.DataFrame,
    ) -> pl.DataFrame:
        """Mask labels to signal timestamps only.

        Labels are computed for all rows, but only signal timestamps
        get actual labels; others are set to SignalType.NONE.

        Used for meta-labeling: only label at detected signal points,
        not every bar.

        Args:
            df (pl.DataFrame): DataFrame with computed labels.
            data_context (dict[str, Any]): Must contain "signal_keys" DataFrame.
            group_df (pl.DataFrame): Original group data for extracting pair value.

        Returns:
            pl.DataFrame: DataFrame with masked labels.

        Example:
            ```python
            # In compute_group with masking
            def compute_group(self, group_df, data_context=None):
                # Compute labels for all rows
                labeled = group_df.with_columns([...])

                # Mask to signal timestamps only
                if self.mask_to_signals and data_context:
                    labeled = self._apply_signal_mask(
                        labeled, data_context, group_df
                    )

                return labeled
            ```

        Note:
            Requires signal_keys in data_context with (pair, timestamp) columns.
            Non-signal rows get label=SignalType.NONE.
            Metadata columns also masked if include_meta=True.
        """
        signal_keys: pl.DataFrame = data_context["signal_keys"]
        # Called per group, so the first row's pair identifies the whole group.
        pair_value = group_df.get_column(self.pair_col)[0]

        signal_ts = (
            signal_keys.filter(pl.col(self.pair_col) == pair_value)
            .select(self.ts_col)
            .unique()
        )

        if signal_ts.height == 0:
            # No signals for this pair: everything becomes NONE / null meta.
            df = df.with_columns(pl.lit(SignalType.NONE.value).alias(self.out_col))
            if self.include_meta:
                df = df.with_columns(
                    [pl.lit(None).alias(col) for col in self.meta_columns]
                )
        else:
            # Left-join a marker column on timestamp (pair already fixed),
            # then keep labels only where the marker is present.
            is_signal = pl.col("_is_signal").fill_null(False)
            mask_exprs = [
                pl.when(is_signal)
                .then(pl.col(self.out_col))
                .otherwise(pl.lit(SignalType.NONE.value))
                .alias(self.out_col),
            ]
            if self.include_meta:
                mask_exprs += [
                    pl.when(is_signal)
                    .then(pl.col(col))
                    .otherwise(pl.lit(None))
                    .alias(col)
                    for col in self.meta_columns
                ]

            df = (
                df.join(
                    signal_ts.with_columns(pl.lit(True).alias("_is_signal")),
                    on=self.ts_col,
                    how="left",
                )
                .with_columns(mask_exprs)
                .drop("_is_signal")
            )

        return df
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# IMPORTANT
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import polars as pl
|
|
7
|
+
|
|
8
|
+
from signalflow.core import SignalType
|
|
9
|
+
from signalflow.target.base import Labeler
|
|
10
|
+
from signalflow.core import sf_component
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
@sf_component(name="fixed_horizon")
class FixedHorizonLabeler(Labeler):
    """Fixed-Horizon Labeling.

    label[t0] = sign(close[t0 + horizon] - close[t0])

    If signals provided, labels are written only on signal rows,
    while horizon is computed on full series (per pair).

    Attributes:
        price_col (str): Price column used for the horizon return. Default: "close".
        horizon (int): Forward offset in bars. Must be > 0. Default: 60.
        meta_columns (tuple[str, ...]): Meta output names; "t1" is the
            timestamp at t0 + horizon, "ret" the log-return over the horizon.
    """
    price_col: str = "close"
    horizon: int = 60

    meta_columns: tuple[str, ...] = ("t1", "ret")

    def __post_init__(self) -> None:
        """Validate parameters and derive the default output projection.

        Raises:
            ValueError: If horizon is not positive.
        """
        if self.horizon <= 0:
            raise ValueError("horizon must be > 0")

        # Respect an explicitly configured projection; the previous version
        # unconditionally overwrote a user-supplied ``output_columns``.
        if self.output_columns is None:
            cols = [self.out_col]
            if self.include_meta:
                cols += list(self.meta_columns)
            self.output_columns = cols

    def compute_group(
        self, group_df: pl.DataFrame, data_context: dict[str, Any] | None
    ) -> pl.DataFrame:
        """Label one pair's rows by the sign of the horizon return.

        Args:
            group_df (pl.DataFrame): Single pair's data, sorted by timestamp.
            data_context (dict[str, Any] | None): Optional context; if it
                contains "signal_keys" and mask_to_signals is set, labels are
                masked to signal timestamps.

        Returns:
            pl.DataFrame: Same height as input with label (and optional meta).

        Raises:
            ValueError: If price_col is missing.
        """
        if self.price_col not in group_df.columns:
            raise ValueError(f"Missing required column '{self.price_col}'")

        if group_df.height == 0:
            # Still emit the expected columns so the base-class projection
            # does not fail with "missing columns" on an empty group.
            empty_exprs = [pl.lit(None).alias(self.out_col)]
            if self.include_meta:
                empty_exprs += [pl.lit(None).alias(c) for c in self.meta_columns]
            return group_df.with_columns(empty_exprs)

        h = int(self.horizon)
        price = pl.col(self.price_col)
        future_price = price.shift(-h)

        df = group_df.with_columns(future_price.alias("_future_price"))

        # NONE when either price is missing or non-positive: the tail of the
        # series has no future bar, and the log-return below is undefined
        # for non-positive prices.
        label_expr = (
            pl.when(
                pl.col("_future_price").is_null()
                | pl.col(self.price_col).is_null()
                | (pl.col(self.price_col) <= 0)
                | (pl.col("_future_price") <= 0)
            )
            .then(pl.lit(SignalType.NONE.value))
            .when(pl.col("_future_price") > pl.col(self.price_col))
            .then(pl.lit(SignalType.RISE.value))
            .when(pl.col("_future_price") < pl.col(self.price_col))
            .then(pl.lit(SignalType.FALL.value))
            .otherwise(pl.lit(SignalType.NONE.value))
        )

        df = df.with_columns(label_expr.alias(self.out_col))

        if self.include_meta:
            df = df.with_columns(
                [
                    # t1: timestamp of the bar the label was measured against.
                    pl.col(self.ts_col).shift(-h).alias("t1"),
                    # ret: log-return over the horizon; null where undefined.
                    pl.when(
                        pl.col("_future_price").is_not_null()
                        & (pl.col(self.price_col) > 0)
                        & (pl.col("_future_price") > 0)
                    )
                    .then((pl.col("_future_price") / pl.col(self.price_col)).log())
                    .otherwise(pl.lit(None))
                    .alias("ret"),
                ]
            )

        df = df.drop("_future_price")

        if (
            self.mask_to_signals
            and data_context is not None
            and "signal_keys" in data_context
        ):
            df = self._apply_signal_mask(df, data_context, group_df)

        return df
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import polars as pl
|
|
8
|
+
from numba import njit, prange
|
|
9
|
+
|
|
10
|
+
from signalflow.core import sf_component, SignalType
|
|
11
|
+
from signalflow.target.base import Labeler
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@njit(parallel=True, cache=True)
def _find_first_hit_static(
    prices: np.ndarray,
    pt: np.ndarray,
    sl: np.ndarray,
    lookforward: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Scan forward from every bar for the first barrier touch.

    For each index ``i``, walks bars ``i+1 .. min(i+lookforward, n-1)``
    and records the offset of the first price >= pt[i] and the first
    price <= sl[i]. Rows are independent, so the outer loop runs in
    parallel via ``prange``.

    Returns:
        up_off: offset of the first PT hit (0 = no hit)
        dn_off: offset of the first SL hit (0 = no hit)
    """
    n = len(prices)
    up_off = np.zeros(n, dtype=np.int32)
    dn_off = np.zeros(n, dtype=np.int32)

    for i in prange(n):
        upper = pt[i]
        lower = sl[i]
        # Last index the vertical barrier allows us to inspect.
        last = min(i + lookforward, n - 1)

        step = 1
        while i + step <= last:
            price = prices[i + step]

            if up_off[i] == 0 and price >= upper:
                up_off[i] = step

            if dn_off[i] == 0 and price <= lower:
                dn_off[i] = step

            # Both barriers found: nothing further to record for this row.
            if up_off[i] != 0 and dn_off[i] != 0:
                break
            step += 1

    return up_off, dn_off
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass
@sf_component(name="static_triple_barrier")
class StaticTripleBarrierLabeler(Labeler):
    """Triple-Barrier (first-touch) labeling with STATIC horizontal barriers.

    Numba-accelerated version.

    De Prado's framework:
        - Vertical barrier at t1 = t0 + lookforward_window
        - Horizontal barriers defined as % from initial price at t0:
            pt = close[t0] * (1 + profit_pct)
            sl = close[t0] * (1 - stop_loss_pct)
        - Label by first touch within (t0, t1]:
            RISE if PT touched first (ties -> PT)
            FALL if SL touched first
            NONE if none touched by t1

    Attributes:
        price_col (str): Price column used for barrier checks. Default: "close".
        lookforward_window (int): Vertical barrier in bars. Default: 1440.
        profit_pct (float): Upper barrier distance as a fraction. Default: 0.01.
        stop_loss_pct (float): Lower barrier distance as a fraction. Default: 0.01.
    """
    price_col: str = "close"

    lookforward_window: int = 1440
    profit_pct: float = 0.01
    stop_loss_pct: float = 0.01

    def __post_init__(self) -> None:
        """Validate parameters and derive the default output projection.

        Raises:
            ValueError: If lookforward_window or barrier percentages are
                not positive.
        """
        if self.lookforward_window <= 0:
            raise ValueError("lookforward_window must be > 0")
        if self.profit_pct <= 0 or self.stop_loss_pct <= 0:
            raise ValueError("profit_pct/stop_loss_pct must be > 0")

        # Respect an explicitly configured projection; the previous version
        # unconditionally overwrote a user-supplied ``output_columns``.
        if self.output_columns is None:
            cols = [self.out_col]
            if self.include_meta:
                cols += list(self.meta_columns)
            self.output_columns = cols

    def compute_group(
        self, group_df: pl.DataFrame, data_context: dict[str, Any] | None
    ) -> pl.DataFrame:
        """Label one pair's rows by first barrier touch.

        Args:
            group_df (pl.DataFrame): Single pair's data, sorted by timestamp.
            data_context (dict[str, Any] | None): Optional context; if it
                contains "signal_keys" and mask_to_signals is set, labels
                are masked to signal timestamps.

        Returns:
            pl.DataFrame: Same height as input with label (and optional
                t_hit/ret meta columns).

        Raises:
            ValueError: If price_col is missing.
        """
        if self.price_col not in group_df.columns:
            raise ValueError(f"Missing required column '{self.price_col}'")

        if group_df.height == 0:
            # Still emit the expected columns so the base-class projection
            # does not fail with "missing columns" on an empty group.
            empty_exprs = [pl.lit(None).alias(self.out_col)]
            if self.include_meta:
                empty_exprs += [pl.lit(None).alias(c) for c in self.meta_columns]
            return group_df.with_columns(empty_exprs)

        lf = int(self.lookforward_window)
        n = group_df.height

        prices = group_df.get_column(self.price_col).to_numpy().astype(np.float64)
        pt = prices * (1.0 + self.profit_pct)
        sl = prices * (1.0 - self.stop_loss_pct)

        up_off, dn_off = _find_first_hit_static(prices, pt, sl, lf)

        # 0 means "no hit" in the kernel output; convert to nulls so the
        # expressions below can use null-aware comparisons.
        up_off_series = pl.Series("_up_off", up_off).replace(0, None).cast(pl.Int32)
        dn_off_series = pl.Series("_dn_off", dn_off).replace(0, None).cast(pl.Int32)

        df = group_df.with_columns([up_off_series, dn_off_series])

        # Ties (same offset) resolve in favour of the profit barrier (<=).
        choose_up = pl.col("_up_off").is_not_null() & (
            pl.col("_dn_off").is_null() | (pl.col("_up_off") <= pl.col("_dn_off"))
        )
        choose_dn = pl.col("_dn_off").is_not_null() & (
            pl.col("_up_off").is_null() | (pl.col("_dn_off") < pl.col("_up_off"))
        )

        df = df.with_columns(
            pl.when(choose_up)
            .then(pl.lit(SignalType.RISE.value))
            .when(choose_dn)
            .then(pl.lit(SignalType.FALL.value))
            .otherwise(pl.lit(SignalType.NONE.value))
            .alias(self.out_col)
        )

        if self.include_meta:
            ts_arr = group_df.get_column(self.ts_col).to_numpy()

            up_np = up_off_series.fill_null(0).to_numpy()
            dn_np = dn_off_series.fill_null(0).to_numpy()
            idx = np.arange(n)

            # First-touch offset per row (0 = no horizontal barrier hit).
            hit_off = np.where(
                (up_np > 0) & ((dn_np == 0) | (up_np <= dn_np)),
                up_np,
                np.where(dn_np > 0, dn_np, 0),
            )

            hit_idx = np.clip(idx + hit_off, 0, n - 1)
            # Vertical-barrier exit; clipped to the last available bar.
            vert_idx = np.clip(idx + lf, 0, n - 1)
            final_idx = np.where(hit_off > 0, hit_idx, vert_idx)

            t_hit = ts_arr[final_idx]
            # Log-return from entry to the exit bar (hit or vertical barrier).
            ret = np.log(prices[final_idx] / prices)

            df = df.with_columns(
                [
                    pl.Series("t_hit", t_hit),
                    pl.Series("ret", ret),
                ]
            )

        if (
            self.mask_to_signals
            and data_context is not None
            and "signal_keys" in data_context
        ):
            df = self._apply_signal_mask(df, data_context, group_df)

        df = df.drop(["_up_off", "_dn_off"])

        return df
|