signalflow-trading 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. signalflow/__init__.py +21 -0
  2. signalflow/analytics/__init__.py +0 -0
  3. signalflow/core/__init__.py +46 -0
  4. signalflow/core/base_mixin.py +232 -0
  5. signalflow/core/containers/__init__.py +21 -0
  6. signalflow/core/containers/order.py +216 -0
  7. signalflow/core/containers/portfolio.py +211 -0
  8. signalflow/core/containers/position.py +296 -0
  9. signalflow/core/containers/raw_data.py +167 -0
  10. signalflow/core/containers/raw_data_view.py +169 -0
  11. signalflow/core/containers/signals.py +198 -0
  12. signalflow/core/containers/strategy_state.py +147 -0
  13. signalflow/core/containers/trade.py +112 -0
  14. signalflow/core/decorators.py +103 -0
  15. signalflow/core/enums.py +270 -0
  16. signalflow/core/registry.py +322 -0
  17. signalflow/core/rolling_aggregator.py +362 -0
  18. signalflow/core/signal_transforms/__init__.py +5 -0
  19. signalflow/core/signal_transforms/base_signal_transform.py +186 -0
  20. signalflow/data/__init__.py +11 -0
  21. signalflow/data/raw_data_factory.py +225 -0
  22. signalflow/data/raw_store/__init__.py +7 -0
  23. signalflow/data/raw_store/base.py +271 -0
  24. signalflow/data/raw_store/duckdb_stores.py +696 -0
  25. signalflow/data/source/__init__.py +10 -0
  26. signalflow/data/source/base.py +300 -0
  27. signalflow/data/source/binance.py +442 -0
  28. signalflow/data/strategy_store/__init__.py +8 -0
  29. signalflow/data/strategy_store/base.py +278 -0
  30. signalflow/data/strategy_store/duckdb.py +409 -0
  31. signalflow/data/strategy_store/schema.py +36 -0
  32. signalflow/detector/__init__.py +7 -0
  33. signalflow/detector/adapter/__init__.py +5 -0
  34. signalflow/detector/adapter/pandas_detector.py +46 -0
  35. signalflow/detector/base.py +390 -0
  36. signalflow/detector/sma_cross.py +105 -0
  37. signalflow/feature/__init__.py +16 -0
  38. signalflow/feature/adapter/__init__.py +5 -0
  39. signalflow/feature/adapter/pandas_feature_extractor.py +54 -0
  40. signalflow/feature/base.py +330 -0
  41. signalflow/feature/feature_set.py +286 -0
  42. signalflow/feature/oscillator/__init__.py +5 -0
  43. signalflow/feature/oscillator/rsi_extractor.py +42 -0
  44. signalflow/feature/pandasta/__init__.py +10 -0
  45. signalflow/feature/pandasta/pandas_ta_extractor.py +141 -0
  46. signalflow/feature/pandasta/top_pandasta_extractors.py +64 -0
  47. signalflow/feature/smoother/__init__.py +5 -0
  48. signalflow/feature/smoother/sma_extractor.py +46 -0
  49. signalflow/strategy/__init__.py +9 -0
  50. signalflow/strategy/broker/__init__.py +15 -0
  51. signalflow/strategy/broker/backtest.py +172 -0
  52. signalflow/strategy/broker/base.py +186 -0
  53. signalflow/strategy/broker/executor/__init__.py +9 -0
  54. signalflow/strategy/broker/executor/base.py +35 -0
  55. signalflow/strategy/broker/executor/binance_spot.py +12 -0
  56. signalflow/strategy/broker/executor/virtual_spot.py +81 -0
  57. signalflow/strategy/broker/realtime_spot.py +12 -0
  58. signalflow/strategy/component/__init__.py +9 -0
  59. signalflow/strategy/component/base.py +65 -0
  60. signalflow/strategy/component/entry/__init__.py +7 -0
  61. signalflow/strategy/component/entry/fixed_size.py +57 -0
  62. signalflow/strategy/component/entry/signal.py +127 -0
  63. signalflow/strategy/component/exit/__init__.py +5 -0
  64. signalflow/strategy/component/exit/time_based.py +47 -0
  65. signalflow/strategy/component/exit/tp_sl.py +80 -0
  66. signalflow/strategy/component/metric/__init__.py +8 -0
  67. signalflow/strategy/component/metric/main_metrics.py +181 -0
  68. signalflow/strategy/runner/__init__.py +8 -0
  69. signalflow/strategy/runner/backtest_runner.py +208 -0
  70. signalflow/strategy/runner/base.py +19 -0
  71. signalflow/strategy/runner/optimized_backtest_runner.py +178 -0
  72. signalflow/strategy/runner/realtime_runner.py +0 -0
  73. signalflow/target/__init__.py +14 -0
  74. signalflow/target/adapter/__init__.py +5 -0
  75. signalflow/target/adapter/pandas_labeler.py +45 -0
  76. signalflow/target/base.py +409 -0
  77. signalflow/target/fixed_horizon_labeler.py +93 -0
  78. signalflow/target/static_triple_barrier.py +162 -0
  79. signalflow/target/triple_barrier.py +188 -0
  80. signalflow/utils/__init__.py +7 -0
  81. signalflow/utils/import_utils.py +11 -0
  82. signalflow/utils/tune_utils.py +19 -0
  83. signalflow/validator/__init__.py +6 -0
  84. signalflow/validator/base.py +139 -0
  85. signalflow/validator/sklearn_validator.py +527 -0
  86. signalflow_trading-0.2.1.dist-info/METADATA +149 -0
  87. signalflow_trading-0.2.1.dist-info/RECORD +90 -0
  88. signalflow_trading-0.2.1.dist-info/WHEEL +5 -0
  89. signalflow_trading-0.2.1.dist-info/licenses/LICENSE +21 -0
  90. signalflow_trading-0.2.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,409 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from dataclasses import dataclass
5
+ from typing import Any, ClassVar
6
+
7
+ import polars as pl
8
+
9
+ from signalflow.core import RawDataType, SfComponentType, SignalType, Signals
10
+
11
+
12
@dataclass
class Labeler(ABC):
    """Base class for Polars-only signal labeling.

    Assigns forward-looking labels to historical data based on future price
    movement. Labels are computed per-pair with length-preserving operations.

    Key concepts:
        - Forward-looking: Labels depend on future data (not available in live trading)
        - Per-pair processing: Each pair labeled independently
        - Length-preserving: Output has same row count as input
        - Signal masking: Optionally label only at signal timestamps

    Public API:
        - compute(): Main entry point (handles grouping, filtering, projection)
        - compute_group(): Per-pair labeling logic (must implement)

    Common labeling strategies:
        - Fixed horizon: Label based on return over N bars
        - Triple barrier: Label based on first hit of profit/loss/time barrier
        - Quantile-based: Label based on return quantiles

    Attributes:
        component_type (ClassVar[SfComponentType]): Always LABELER for registry.
        raw_data_type (RawDataType): Type of raw data. Default: SPOT.
        pair_col (str): Trading pair column. Default: "pair".
        ts_col (str): Timestamp column. Default: "timestamp".
        keep_input_columns (bool): Keep all input columns. Default: False.
        output_columns (list[str] | None): Specific columns to output. Default: None.
        filter_signal_type (SignalType | None): Filter to specific signal type. Default: None.
        mask_to_signals (bool): Mask labels to signal timestamps only. Default: True.
        out_col (str): Output label column name. Default: "label".
        include_meta (bool): Include metadata columns. Default: False.
        meta_columns (tuple[str, ...]): Metadata column names. Default: ("t_hit", "ret").

    Example:
        ```python
        from signalflow.target import Labeler
        from signalflow.core import SignalType
        import polars as pl

        class FixedHorizonLabeler(Labeler):
            '''Label based on fixed-horizon return'''

            def __init__(self, horizon: int = 10, threshold: float = 0.01):
                super().__init__()
                self.horizon = horizon
                self.threshold = threshold

            def compute_group(self, group_df, data_context=None):
                # Compute forward return
                labels = group_df.with_columns([
                    pl.col("close").shift(-self.horizon).alias("future_close")
                ]).with_columns([
                    ((pl.col("future_close") / pl.col("close")) - 1).alias("return")
                ]).with_columns([
                    pl.when(pl.col("return") > self.threshold)
                    .then(pl.lit(SignalType.RISE.value))
                    .when(pl.col("return") < -self.threshold)
                    .then(pl.lit(SignalType.FALL.value))
                    .otherwise(pl.lit(SignalType.NONE.value))
                    .alias("label")
                ])

                return labels

        # Usage
        labeler = FixedHorizonLabeler(horizon=10, threshold=0.01)
        labeled = labeler.compute(ohlcv_df, signals=signals)
        ```

    Note:
        compute_group() must preserve row count (no filtering).
        All timestamps must be timezone-naive.
        Signal masking requires mask_to_signals=True and signal_keys in context.

    See Also:
        FixedHorizonLabeler: Simple fixed-horizon implementation.
        TripleBarrierLabeler: Three-barrier labeling strategy.
    """

    component_type: ClassVar[SfComponentType] = SfComponentType.LABELER
    raw_data_type: RawDataType = RawDataType.SPOT

    pair_col: str = "pair"
    ts_col: str = "timestamp"

    keep_input_columns: bool = False
    output_columns: list[str] | None = None
    filter_signal_type: SignalType | None = None

    mask_to_signals: bool = True
    out_col: str = "label"
    include_meta: bool = False
    meta_columns: tuple[str, ...] = ("t_hit", "ret")

    def compute(
        self,
        df: pl.DataFrame,
        signals: Signals | None = None,
        data_context: dict[str, Any] | None = None,
    ) -> pl.DataFrame:
        """Compute labels for input DataFrame.

        Main entry point - handles validation, filtering, grouping, and projection.

        Processing steps:
            1. Validate input schema
            2. Sort by (pair, timestamp)
            3. (optional) Filter to specific signal type
            4. Group by pair and apply compute_group()
            5. Validate output (length-preserving)
            6. Project to output columns

        Args:
            df (pl.DataFrame): Input data with OHLCV and required columns.
            signals (Signals | None): Signals for filtering/masking.
            data_context (dict[str, Any] | None): Additional context.

        Returns:
            pl.DataFrame: Labeled data with columns:
                - pair, timestamp (always included)
                - label column(s) (as specified by out_col)
                - (optional) metadata columns

        Raises:
            TypeError: If df not pl.DataFrame or compute_group returns wrong type.
            ValueError: If compute_group changes row count or columns missing.

        Example:
            ```python
            # Basic labeling
            labeled = labeler.compute(ohlcv_df)

            # With signal filtering (filter_signal_type is a dataclass
            # field, not a compute() argument)
            labeler.filter_signal_type = SignalType.RISE
            labeled = labeler.compute(ohlcv_df, signals=signals)

            # With masking context
            labeled = labeler.compute(
                ohlcv_df,
                signals=signals,
                data_context={"signal_keys": signal_timestamps_df}
            )
            ```
        """
        if not isinstance(df, pl.DataFrame):
            raise TypeError(f"{self.__class__.__name__}.compute expects pl.DataFrame, got {type(df)}")
        return self._compute_pl(df=df, signals=signals, data_context=data_context)

    def _compute_pl(
        self,
        df: pl.DataFrame,
        signals: Signals | None,
        data_context: dict[str, Any] | None,
    ) -> pl.DataFrame:
        """Internal Polars-based computation.

        Orchestrates validation, filtering, grouping, and projection.

        Args:
            df (pl.DataFrame): Input data.
            signals (Signals | None): Optional signals.
            data_context (dict[str, Any] | None): Optional context.

        Returns:
            pl.DataFrame: Labeled data.
        """
        self._validate_input_pl(df)
        # Deterministic ordering: compute_group() implementations rely on
        # each group being sorted by timestamp.
        df0 = df.sort([self.pair_col, self.ts_col])

        if signals is not None and self.filter_signal_type is not None:
            s_pl = self._signals_to_pl(signals)
            df0 = self._filter_by_signals_pl(df0, s_pl, self.filter_signal_type)

        # Snapshot of pre-labeling columns; anything new after map_groups is
        # treated as a label column when output_columns is not set explicitly.
        input_cols = set(df0.columns)

        def _wrapped(g: pl.DataFrame) -> pl.DataFrame:
            # Enforce the length-preserving contract on every per-pair group.
            out = self.compute_group(g, data_context=data_context)
            if not isinstance(out, pl.DataFrame):
                raise TypeError(f"{self.__class__.__name__}.compute_group must return pl.DataFrame")
            if out.height != g.height:
                raise ValueError(
                    f"{self.__class__.__name__}: len(output_group)={out.height} != len(input_group)={g.height}"
                )
            return out

        out = (
            df0.group_by(self.pair_col, maintain_order=True)
            .map_groups(_wrapped)
            .sort([self.pair_col, self.ts_col])
        )

        if self.keep_input_columns:
            return out

        label_cols = (
            sorted(set(out.columns) - input_cols)
            if self.output_columns is None
            else list(self.output_columns)
        )

        keep_cols = [self.pair_col, self.ts_col] + label_cols
        missing = [c for c in keep_cols if c not in out.columns]
        if missing:
            raise ValueError(f"Projection error, missing columns: {missing}")

        return out.select(keep_cols)

    def _signals_to_pl(self, signals: Signals) -> pl.DataFrame:
        """Convert Signals to Polars DataFrame.

        Args:
            signals (Signals): Signals container.

        Returns:
            pl.DataFrame: Signals as DataFrame.

        Raises:
            TypeError: If Signals.value is not pl.DataFrame.
        """
        s = signals.value
        if isinstance(s, pl.DataFrame):
            return s
        raise TypeError(f"Unsupported Signals.value type: {type(s)}")

    def _filter_by_signals_pl(
        self, df: pl.DataFrame, s: pl.DataFrame, signal_type: SignalType
    ) -> pl.DataFrame:
        """Filter input to rows matching signal timestamps.

        Inner join with signal timestamps of specific type.

        Args:
            df (pl.DataFrame): Input data.
            s (pl.DataFrame): Signals DataFrame.
            signal_type (SignalType): Signal type to filter.

        Returns:
            pl.DataFrame: Filtered data (only rows at signal timestamps).

        Raises:
            ValueError: If signals missing required columns.
        """
        required = {self.pair_col, self.ts_col, "signal_type"}
        missing = required - set(s.columns)
        if missing:
            raise ValueError(f"Signals missing columns: {sorted(missing)}")

        # Deduplicate (pair, timestamp) keys so the inner join cannot
        # multiply input rows.
        s_f = (
            s.filter(pl.col("signal_type") == signal_type.value)
            .select([self.pair_col, self.ts_col])
            .unique(subset=[self.pair_col, self.ts_col])
        )
        return df.join(s_f, on=[self.pair_col, self.ts_col], how="inner")

    @abstractmethod
    def compute_group(
        self, group_df: pl.DataFrame, data_context: dict[str, Any] | None
    ) -> pl.DataFrame:
        """Compute labels for single pair group.

        Core labeling logic - must be implemented by subclasses.

        CRITICAL: Must preserve row count (len(output) == len(input)).
        No filtering allowed inside compute_group.

        Args:
            group_df (pl.DataFrame): Single pair's data, sorted by timestamp.
            data_context (dict[str, Any] | None): Additional context.

        Returns:
            pl.DataFrame: Same length as input with added label columns.

        Example:
            ```python
            def compute_group(self, group_df, data_context=None):
                # Compute 10-bar forward return
                return group_df.with_columns([
                    pl.col("close").shift(-10).alias("future_close")
                ]).with_columns([
                    ((pl.col("future_close") / pl.col("close")) - 1).alias("return"),
                    pl.when((pl.col("future_close") / pl.col("close") - 1) > 0.01)
                    .then(pl.lit(SignalType.RISE.value))
                    .otherwise(pl.lit(SignalType.NONE.value))
                    .alias("label")
                ])
            ```

        Note:
            Output must have same height as input (length-preserving).
            Use shift(-n) for forward-looking operations.
            Last N bars will have null labels (no future data).
        """
        raise NotImplementedError

    def _validate_input_pl(self, df: pl.DataFrame) -> None:
        """Validate input DataFrame schema.

        Args:
            df (pl.DataFrame): Input to validate.

        Raises:
            ValueError: If required columns missing.
        """
        missing = [c for c in (self.pair_col, self.ts_col) if c not in df.columns]
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

    def _apply_signal_mask(
        self,
        df: pl.DataFrame,
        data_context: dict[str, Any],
        group_df: pl.DataFrame,
    ) -> pl.DataFrame:
        """Mask labels to signal timestamps only.

        Labels are computed for all rows, but only signal timestamps
        get actual labels; others are set to SignalType.NONE.

        Used for meta-labeling: only label at detected signal points,
        not every bar.

        Args:
            df (pl.DataFrame): DataFrame with computed labels.
            data_context (dict[str, Any]): Must contain "signal_keys" DataFrame.
            group_df (pl.DataFrame): Original group data for extracting pair value.

        Returns:
            pl.DataFrame: DataFrame with masked labels.

        Example:
            ```python
            # In compute_group with masking
            def compute_group(self, group_df, data_context=None):
                # Compute labels for all rows
                labeled = group_df.with_columns([...])

                # Mask to signal timestamps only
                if self.mask_to_signals and data_context:
                    labeled = self._apply_signal_mask(
                        labeled, data_context, group_df
                    )

                return labeled
            ```

        Note:
            Requires signal_keys in data_context with (pair, timestamp) columns.
            Non-signal rows get label=SignalType.NONE.
            Metadata columns also masked if include_meta=True.
            Raises KeyError if "signal_keys" absent from data_context;
            callers are expected to check before invoking.
        """
        signal_keys: pl.DataFrame = data_context["signal_keys"]
        # df is a single pair's group, so the pair value is taken from the
        # first row and the later join only needs the timestamp key.
        pair_value = group_df.get_column(self.pair_col)[0]

        signal_ts = (
            signal_keys.filter(pl.col(self.pair_col) == pair_value)
            .select(self.ts_col)
            .unique()
        )

        if signal_ts.height == 0:
            # No signals for this pair: blank out the label (and meta) columns.
            df = df.with_columns(pl.lit(SignalType.NONE.value).alias(self.out_col))
            if self.include_meta:
                df = df.with_columns(
                    [pl.lit(None).alias(col) for col in self.meta_columns]
                )
        else:
            is_signal = pl.col("_is_signal").fill_null(False)
            mask_exprs = [
                pl.when(is_signal)
                .then(pl.col(self.out_col))
                .otherwise(pl.lit(SignalType.NONE.value))
                .alias(self.out_col),
            ]
            if self.include_meta:
                mask_exprs += [
                    pl.when(is_signal)
                    .then(pl.col(col))
                    .otherwise(pl.lit(None))
                    .alias(col)
                    for col in self.meta_columns
                ]

            # Left join marks signal rows with _is_signal=True; non-matches
            # are null and treated as False by fill_null above.
            df = (
                df.join(
                    signal_ts.with_columns(pl.lit(True).alias("_is_signal")),
                    on=self.ts_col,
                    how="left",
                )
                .with_columns(mask_exprs)
                .drop("_is_signal")
            )

        return df
@@ -0,0 +1,93 @@
1
+ # IMPORTANT
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any
5
+
6
+ import polars as pl
7
+
8
+ from signalflow.core import SignalType
9
+ from signalflow.target.base import Labeler
10
+ from signalflow.core import sf_component
11
+
12
+
13
@dataclass
@sf_component(name="fixed_horizon")
class FixedHorizonLabeler(Labeler):
    """Fixed-horizon labeler.

    For each row t0 the label is the sign of the forward price move:

        label[t0] = sign(close[t0 + horizon] - close[t0])

    If signals provided, labels are written only on signal rows,
    while horizon is computed on full series (per pair).
    """

    # Column holding the price series used for labeling.
    price_col: str = "close"
    # Forward offset, in bars, at which the future price is sampled.
    horizon: int = 60

    meta_columns: tuple[str, ...] = ("t1", "ret")

    def __post_init__(self) -> None:
        """Validate parameters and derive the projected output columns."""
        if self.horizon <= 0:
            raise ValueError("horizon must be > 0")

        selected = [self.out_col]
        if self.include_meta:
            selected.extend(self.meta_columns)
        self.output_columns = selected

    def compute_group(
        self, group_df: pl.DataFrame, data_context: dict[str, Any] | None
    ) -> pl.DataFrame:
        """Label one pair's rows by fixed-horizon forward price move.

        Args:
            group_df (pl.DataFrame): Single pair's data, sorted by timestamp.
            data_context (dict[str, Any] | None): May carry "signal_keys"
                for masking labels to signal rows.

        Returns:
            pl.DataFrame: Same height as input, with label (and optional
            t1/ret metadata) columns added.

        Raises:
            ValueError: If the configured price column is absent.
        """
        if self.price_col not in group_df.columns:
            raise ValueError(f"Missing required column '{self.price_col}'")

        if group_df.height == 0:
            return group_df

        steps = int(self.horizon)
        now_px = pl.col(self.price_col)
        fut_px = pl.col("_future_price")

        out = group_df.with_columns(now_px.shift(-steps).alias("_future_price"))

        # Rows without a usable future/current price (trailing window, nulls,
        # non-positive prices) are labeled NONE.
        unusable = fut_px.is_null() | now_px.is_null() | (now_px <= 0) | (fut_px <= 0)
        out = out.with_columns(
            pl.when(unusable)
            .then(pl.lit(SignalType.NONE.value))
            .when(fut_px > now_px)
            .then(pl.lit(SignalType.RISE.value))
            .when(fut_px < now_px)
            .then(pl.lit(SignalType.FALL.value))
            .otherwise(pl.lit(SignalType.NONE.value))
            .alias(self.out_col)
        )

        if self.include_meta:
            # t1 = timestamp of the horizon bar; ret = log-return over the
            # horizon, null wherever the prices make it undefined.
            has_ret = fut_px.is_not_null() & (now_px > 0) & (fut_px > 0)
            out = out.with_columns(
                [
                    pl.col(self.ts_col).shift(-steps).alias("t1"),
                    pl.when(has_ret)
                    .then((fut_px / now_px).log())
                    .otherwise(pl.lit(None))
                    .alias("ret"),
                ]
            )

        out = out.drop("_future_price")

        wants_mask = (
            self.mask_to_signals
            and data_context is not None
            and "signal_keys" in data_context
        )
        if wants_mask:
            out = self._apply_signal_mask(out, data_context, group_df)

        return out
@@ -0,0 +1,162 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any
5
+
6
+ import numpy as np
7
+ import polars as pl
8
+ from numba import njit, prange
9
+
10
+ from signalflow.core import sf_component, SignalType
11
+ from signalflow.target.base import Labeler
12
+
13
+
14
@njit(parallel=True, cache=True)
def _find_first_hit_static(
    prices: np.ndarray,
    pt: np.ndarray,
    sl: np.ndarray,
    lookforward: int,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Finds the first hit for static barriers.

    Scans forward from each bar i over at most `lookforward` bars
    (clamped to the end of the series) and records, as offsets from i,
    the first touch of the profit-take level pt[i] (price >= pt[i]) and
    of the stop-loss level sl[i] (price <= sl[i]).

    Returns:
        up_off: offset of the first PT hit (0 = no hit)
        dn_off: offset of the first SL hit (0 = no hit)
    """
    n = len(prices)
    up_off = np.zeros(n, dtype=np.int32)
    dn_off = np.zeros(n, dtype=np.int32)

    # Each origin bar is independent, so the outer loop parallelizes.
    for i in prange(n):
        last_j = min(i + lookforward, n - 1)

        for j in range(i + 1, last_j + 1):
            px = prices[j]

            if up_off[i] == 0 and px >= pt[i]:
                up_off[i] = j - i

            if dn_off[i] == 0 and px <= sl[i]:
                dn_off[i] = j - i

            # Both barriers located: nothing left to find for this origin.
            if up_off[i] != 0 and dn_off[i] != 0:
                break

    return up_off, dn_off
51
+
52
+
53
@dataclass
@sf_component(name="static_triple_barrier")
class StaticTripleBarrierLabeler(Labeler):
    """
    Triple-Barrier (first-touch) labeling with STATIC horizontal barriers.
    Numba-accelerated version.

    De Prado's framework:
      - Vertical barrier at t1 = t0 + lookforward_window
      - Horizontal barriers defined as % from initial price at t0:
            pt = close[t0] * (1 + profit_pct)
            sl = close[t0] * (1 - stop_loss_pct)
      - Label by first touch within (t0, t1]:
            RISE if PT touched first (ties -> PT)
            FALL if SL touched first
            NONE if none touched by t1
    """

    # Column holding the price series used for barrier placement.
    price_col: str = "close"

    # Vertical-barrier distance in bars.
    lookforward_window: int = 1440
    # Profit-take / stop-loss distances as fractions of the entry price.
    profit_pct: float = 0.01
    stop_loss_pct: float = 0.01

    def __post_init__(self) -> None:
        """Validate parameters and derive the projected output columns."""
        if self.lookforward_window <= 0:
            raise ValueError("lookforward_window must be > 0")
        if self.profit_pct <= 0 or self.stop_loss_pct <= 0:
            raise ValueError("profit_pct/stop_loss_pct must be > 0")

        selected = [self.out_col]
        if self.include_meta:
            selected.extend(self.meta_columns)
        self.output_columns = selected

    def compute_group(
        self, group_df: pl.DataFrame, data_context: dict[str, Any] | None
    ) -> pl.DataFrame:
        """Label one pair's rows by first barrier touch.

        Args:
            group_df (pl.DataFrame): Single pair's data, sorted by timestamp.
            data_context (dict[str, Any] | None): May carry "signal_keys"
                for masking labels to signal rows.

        Returns:
            pl.DataFrame: Same height as input with label (and optional
            t_hit/ret metadata) columns added.

        Raises:
            ValueError: If the configured price column is absent.
        """
        if self.price_col not in group_df.columns:
            raise ValueError(f"Missing required column '{self.price_col}'")

        if group_df.height == 0:
            return group_df

        window = int(self.lookforward_window)
        total = group_df.height

        px = group_df.get_column(self.price_col).to_numpy().astype(np.float64)
        upper = px * (1.0 + self.profit_pct)
        lower = px * (1.0 - self.stop_loss_pct)

        up_raw, dn_raw = _find_first_hit_static(px, upper, lower, window)

        # Offset 0 is the "no hit" sentinel -> null.
        up_col = pl.Series("_up_off", up_raw).replace(0, None).cast(pl.Int32)
        dn_col = pl.Series("_dn_off", dn_raw).replace(0, None).cast(pl.Int32)

        out = group_df.with_columns([up_col, dn_col])

        up_expr = pl.col("_up_off")
        dn_expr = pl.col("_dn_off")
        # PT wins ties (up <= dn); SL wins only on a strictly earlier touch.
        pt_first = up_expr.is_not_null() & (dn_expr.is_null() | (up_expr <= dn_expr))
        sl_first = dn_expr.is_not_null() & (up_expr.is_null() | (dn_expr < up_expr))

        out = out.with_columns(
            pl.when(pt_first)
            .then(pl.lit(SignalType.RISE.value))
            .when(sl_first)
            .then(pl.lit(SignalType.FALL.value))
            .otherwise(pl.lit(SignalType.NONE.value))
            .alias(self.out_col)
        )

        if self.include_meta:
            stamps = group_df.get_column(self.ts_col).to_numpy()

            up_np = up_col.fill_null(0).to_numpy()
            dn_np = dn_col.fill_null(0).to_numpy()
            base_idx = np.arange(total)

            # Offset of the first horizontal touch (0 = none), with the
            # same PT-wins-ties rule as the label expressions above.
            pt_wins = (up_np > 0) & ((dn_np == 0) | (up_np <= dn_np))
            first_off = np.where(pt_wins, up_np, np.where(dn_np > 0, dn_np, 0))

            touch_idx = np.clip(base_idx + first_off, 0, total - 1)
            vertical_idx = np.clip(base_idx + window, 0, total - 1)
            # Fall back to the (clamped) vertical barrier when no touch.
            resolve_idx = np.where(first_off > 0, touch_idx, vertical_idx)

            out = out.with_columns(
                [
                    pl.Series("t_hit", stamps[resolve_idx]),
                    pl.Series("ret", np.log(px[resolve_idx] / px)),
                ]
            )

        wants_mask = (
            self.mask_to_signals
            and data_context is not None
            and "signal_keys" in data_context
        )
        if wants_mask:
            out = self._apply_signal_mask(out, data_context, group_df)

        return out.drop(["_up_off", "_dn_off"])