scitex 2.14.0__py3-none-any.whl → 2.15.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (300) hide show
  1. scitex/__init__.py +71 -17
  2. scitex/_env_loader.py +156 -0
  3. scitex/_mcp_resources/__init__.py +37 -0
  4. scitex/_mcp_resources/_cheatsheet.py +135 -0
  5. scitex/_mcp_resources/_figrecipe.py +138 -0
  6. scitex/_mcp_resources/_formats.py +102 -0
  7. scitex/_mcp_resources/_modules.py +337 -0
  8. scitex/_mcp_resources/_session.py +149 -0
  9. scitex/_mcp_tools/__init__.py +4 -0
  10. scitex/_mcp_tools/audio.py +66 -0
  11. scitex/_mcp_tools/diagram.py +11 -95
  12. scitex/_mcp_tools/introspect.py +210 -0
  13. scitex/_mcp_tools/plt.py +260 -305
  14. scitex/_mcp_tools/scholar.py +74 -0
  15. scitex/_mcp_tools/social.py +244 -0
  16. scitex/_mcp_tools/template.py +24 -0
  17. scitex/_mcp_tools/writer.py +21 -204
  18. scitex/ai/_gen_ai/_PARAMS.py +10 -7
  19. scitex/ai/classification/reporters/_SingleClassificationReporter.py +45 -1603
  20. scitex/ai/classification/reporters/_mixins/__init__.py +36 -0
  21. scitex/ai/classification/reporters/_mixins/_constants.py +67 -0
  22. scitex/ai/classification/reporters/_mixins/_cv_summary.py +387 -0
  23. scitex/ai/classification/reporters/_mixins/_feature_importance.py +119 -0
  24. scitex/ai/classification/reporters/_mixins/_metrics.py +275 -0
  25. scitex/ai/classification/reporters/_mixins/_plotting.py +179 -0
  26. scitex/ai/classification/reporters/_mixins/_reports.py +153 -0
  27. scitex/ai/classification/reporters/_mixins/_storage.py +160 -0
  28. scitex/ai/classification/timeseries/_TimeSeriesSlidingWindowSplit.py +30 -1550
  29. scitex/ai/classification/timeseries/_sliding_window_core.py +467 -0
  30. scitex/ai/classification/timeseries/_sliding_window_plotting.py +369 -0
  31. scitex/audio/README.md +40 -36
  32. scitex/audio/__init__.py +129 -61
  33. scitex/audio/_branding.py +185 -0
  34. scitex/audio/_mcp/__init__.py +32 -0
  35. scitex/audio/_mcp/handlers.py +59 -6
  36. scitex/audio/_mcp/speak_handlers.py +238 -0
  37. scitex/audio/_relay.py +225 -0
  38. scitex/audio/_tts.py +18 -10
  39. scitex/audio/engines/base.py +17 -10
  40. scitex/audio/engines/elevenlabs_engine.py +7 -2
  41. scitex/audio/mcp_server.py +228 -75
  42. scitex/canvas/README.md +1 -1
  43. scitex/canvas/editor/_dearpygui/__init__.py +25 -0
  44. scitex/canvas/editor/_dearpygui/_editor.py +147 -0
  45. scitex/canvas/editor/_dearpygui/_handlers.py +476 -0
  46. scitex/canvas/editor/_dearpygui/_panels/__init__.py +17 -0
  47. scitex/canvas/editor/_dearpygui/_panels/_control.py +119 -0
  48. scitex/canvas/editor/_dearpygui/_panels/_element_controls.py +190 -0
  49. scitex/canvas/editor/_dearpygui/_panels/_preview.py +43 -0
  50. scitex/canvas/editor/_dearpygui/_panels/_sections.py +390 -0
  51. scitex/canvas/editor/_dearpygui/_plotting.py +187 -0
  52. scitex/canvas/editor/_dearpygui/_rendering.py +504 -0
  53. scitex/canvas/editor/_dearpygui/_selection.py +295 -0
  54. scitex/canvas/editor/_dearpygui/_state.py +93 -0
  55. scitex/canvas/editor/_dearpygui/_utils.py +61 -0
  56. scitex/canvas/editor/flask_editor/_core/__init__.py +27 -0
  57. scitex/canvas/editor/flask_editor/_core/_bbox_extraction.py +200 -0
  58. scitex/canvas/editor/flask_editor/_core/_editor.py +173 -0
  59. scitex/canvas/editor/flask_editor/_core/_export_helpers.py +353 -0
  60. scitex/canvas/editor/flask_editor/_core/_routes_basic.py +190 -0
  61. scitex/canvas/editor/flask_editor/_core/_routes_export.py +332 -0
  62. scitex/canvas/editor/flask_editor/_core/_routes_panels.py +252 -0
  63. scitex/canvas/editor/flask_editor/_core/_routes_save.py +218 -0
  64. scitex/canvas/editor/flask_editor/_core.py +25 -1684
  65. scitex/canvas/editor/flask_editor/templates/__init__.py +32 -70
  66. scitex/cli/__init__.py +38 -43
  67. scitex/cli/audio.py +76 -27
  68. scitex/cli/capture.py +13 -20
  69. scitex/cli/introspect.py +481 -0
  70. scitex/cli/main.py +200 -109
  71. scitex/cli/mcp.py +60 -34
  72. scitex/cli/plt.py +357 -0
  73. scitex/cli/repro.py +15 -8
  74. scitex/cli/resource.py +15 -8
  75. scitex/cli/scholar/__init__.py +23 -8
  76. scitex/cli/scholar/_crossref_scitex.py +296 -0
  77. scitex/cli/scholar/_fetch.py +25 -3
  78. scitex/cli/social.py +314 -0
  79. scitex/cli/stats.py +15 -8
  80. scitex/cli/template.py +129 -12
  81. scitex/cli/tex.py +15 -8
  82. scitex/cli/writer.py +132 -8
  83. scitex/cloud/__init__.py +41 -2
  84. scitex/config/README.md +1 -1
  85. scitex/config/__init__.py +16 -2
  86. scitex/config/_env_registry.py +256 -0
  87. scitex/context/__init__.py +22 -0
  88. scitex/dev/__init__.py +20 -1
  89. scitex/diagram/__init__.py +42 -19
  90. scitex/diagram/mcp_server.py +13 -125
  91. scitex/gen/__init__.py +50 -14
  92. scitex/gen/_list_packages.py +4 -4
  93. scitex/introspect/__init__.py +82 -0
  94. scitex/introspect/_call_graph.py +303 -0
  95. scitex/introspect/_class_hierarchy.py +163 -0
  96. scitex/introspect/_core.py +41 -0
  97. scitex/introspect/_docstring.py +131 -0
  98. scitex/introspect/_examples.py +113 -0
  99. scitex/introspect/_imports.py +271 -0
  100. scitex/{gen/_inspect_module.py → introspect/_list_api.py} +43 -54
  101. scitex/introspect/_mcp/__init__.py +41 -0
  102. scitex/introspect/_mcp/handlers.py +233 -0
  103. scitex/introspect/_members.py +155 -0
  104. scitex/introspect/_resolve.py +89 -0
  105. scitex/introspect/_signature.py +131 -0
  106. scitex/introspect/_source.py +80 -0
  107. scitex/introspect/_type_hints.py +172 -0
  108. scitex/io/_save.py +1 -2
  109. scitex/io/bundle/README.md +1 -1
  110. scitex/logging/_formatters.py +19 -9
  111. scitex/mcp_server.py +98 -5
  112. scitex/os/__init__.py +4 -0
  113. scitex/{gen → os}/_check_host.py +4 -5
  114. scitex/plt/__init__.py +245 -550
  115. scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/_wrappers.py +5 -10
  116. scitex/plt/docs/EXTERNAL_PACKAGE_BRANDING.md +149 -0
  117. scitex/plt/gallery/README.md +1 -1
  118. scitex/plt/utils/_hitmap/__init__.py +82 -0
  119. scitex/plt/utils/_hitmap/_artist_extraction.py +343 -0
  120. scitex/plt/utils/_hitmap/_color_application.py +346 -0
  121. scitex/plt/utils/_hitmap/_color_conversion.py +121 -0
  122. scitex/plt/utils/_hitmap/_constants.py +40 -0
  123. scitex/plt/utils/_hitmap/_hitmap_core.py +334 -0
  124. scitex/plt/utils/_hitmap/_path_extraction.py +357 -0
  125. scitex/plt/utils/_hitmap/_query.py +113 -0
  126. scitex/plt/utils/_hitmap.py +46 -1616
  127. scitex/plt/utils/_metadata/__init__.py +80 -0
  128. scitex/plt/utils/_metadata/_artists/__init__.py +25 -0
  129. scitex/plt/utils/_metadata/_artists/_base.py +195 -0
  130. scitex/plt/utils/_metadata/_artists/_collections.py +356 -0
  131. scitex/plt/utils/_metadata/_artists/_extract.py +57 -0
  132. scitex/plt/utils/_metadata/_artists/_images.py +80 -0
  133. scitex/plt/utils/_metadata/_artists/_lines.py +261 -0
  134. scitex/plt/utils/_metadata/_artists/_patches.py +247 -0
  135. scitex/plt/utils/_metadata/_artists/_text.py +106 -0
  136. scitex/plt/utils/_metadata/_csv.py +416 -0
  137. scitex/plt/utils/_metadata/_detect.py +225 -0
  138. scitex/plt/utils/_metadata/_legend.py +127 -0
  139. scitex/plt/utils/_metadata/_rounding.py +117 -0
  140. scitex/plt/utils/_metadata/_verification.py +202 -0
  141. scitex/schema/README.md +1 -1
  142. scitex/scholar/__init__.py +8 -0
  143. scitex/scholar/_mcp/crossref_handlers.py +265 -0
  144. scitex/scholar/core/Scholar.py +63 -1700
  145. scitex/scholar/core/_mixins/__init__.py +36 -0
  146. scitex/scholar/core/_mixins/_enrichers.py +270 -0
  147. scitex/scholar/core/_mixins/_library_handlers.py +100 -0
  148. scitex/scholar/core/_mixins/_loaders.py +103 -0
  149. scitex/scholar/core/_mixins/_pdf_download.py +375 -0
  150. scitex/scholar/core/_mixins/_pipeline.py +312 -0
  151. scitex/scholar/core/_mixins/_project_handlers.py +125 -0
  152. scitex/scholar/core/_mixins/_savers.py +69 -0
  153. scitex/scholar/core/_mixins/_search.py +103 -0
  154. scitex/scholar/core/_mixins/_services.py +88 -0
  155. scitex/scholar/core/_mixins/_url_finding.py +105 -0
  156. scitex/scholar/crossref_scitex.py +367 -0
  157. scitex/scholar/docs/EXTERNAL_PACKAGE_BRANDING.md +149 -0
  158. scitex/scholar/examples/00_run_all.sh +120 -0
  159. scitex/scholar/jobs/_executors.py +27 -3
  160. scitex/scholar/pdf_download/ScholarPDFDownloader.py +38 -416
  161. scitex/scholar/pdf_download/_cli.py +154 -0
  162. scitex/scholar/pdf_download/strategies/__init__.py +11 -8
  163. scitex/scholar/pdf_download/strategies/manual_download_fallback.py +80 -3
  164. scitex/scholar/pipelines/ScholarPipelineBibTeX.py +73 -121
  165. scitex/scholar/pipelines/ScholarPipelineParallel.py +80 -138
  166. scitex/scholar/pipelines/ScholarPipelineSingle.py +43 -63
  167. scitex/scholar/pipelines/_single_steps.py +71 -36
  168. scitex/scholar/storage/_LibraryManager.py +97 -1695
  169. scitex/scholar/storage/_mixins/__init__.py +30 -0
  170. scitex/scholar/storage/_mixins/_bibtex_handlers.py +128 -0
  171. scitex/scholar/storage/_mixins/_library_operations.py +218 -0
  172. scitex/scholar/storage/_mixins/_metadata_conversion.py +226 -0
  173. scitex/scholar/storage/_mixins/_paper_saving.py +456 -0
  174. scitex/scholar/storage/_mixins/_resolution.py +376 -0
  175. scitex/scholar/storage/_mixins/_storage_helpers.py +121 -0
  176. scitex/scholar/storage/_mixins/_symlink_handlers.py +226 -0
  177. scitex/scholar/url_finder/.tmp/open_url/KNOWN_RESOLVERS.py +462 -0
  178. scitex/scholar/url_finder/.tmp/open_url/README.md +223 -0
  179. scitex/scholar/url_finder/.tmp/open_url/_DOIToURLResolver.py +694 -0
  180. scitex/scholar/url_finder/.tmp/open_url/_OpenURLResolver.py +1160 -0
  181. scitex/scholar/url_finder/.tmp/open_url/_ResolverLinkFinder.py +344 -0
  182. scitex/scholar/url_finder/.tmp/open_url/__init__.py +24 -0
  183. scitex/security/README.md +3 -3
  184. scitex/session/README.md +1 -1
  185. scitex/session/__init__.py +26 -7
  186. scitex/session/_decorator.py +1 -1
  187. scitex/sh/README.md +1 -1
  188. scitex/sh/__init__.py +7 -4
  189. scitex/social/__init__.py +155 -0
  190. scitex/social/docs/EXTERNAL_PACKAGE_BRANDING.md +149 -0
  191. scitex/stats/_mcp/_handlers/__init__.py +31 -0
  192. scitex/stats/_mcp/_handlers/_corrections.py +113 -0
  193. scitex/stats/_mcp/_handlers/_descriptive.py +78 -0
  194. scitex/stats/_mcp/_handlers/_effect_size.py +106 -0
  195. scitex/stats/_mcp/_handlers/_format.py +94 -0
  196. scitex/stats/_mcp/_handlers/_normality.py +110 -0
  197. scitex/stats/_mcp/_handlers/_posthoc.py +224 -0
  198. scitex/stats/_mcp/_handlers/_power.py +247 -0
  199. scitex/stats/_mcp/_handlers/_recommend.py +102 -0
  200. scitex/stats/_mcp/_handlers/_run_test.py +279 -0
  201. scitex/stats/_mcp/_handlers/_stars.py +48 -0
  202. scitex/stats/_mcp/handlers.py +19 -1171
  203. scitex/stats/auto/_stat_style.py +175 -0
  204. scitex/stats/auto/_style_definitions.py +411 -0
  205. scitex/stats/auto/_styles.py +22 -620
  206. scitex/stats/descriptive/__init__.py +11 -8
  207. scitex/stats/descriptive/_ci.py +39 -0
  208. scitex/stats/power/_power.py +15 -4
  209. scitex/str/__init__.py +2 -1
  210. scitex/str/_title_case.py +63 -0
  211. scitex/template/README.md +1 -1
  212. scitex/template/__init__.py +25 -10
  213. scitex/template/_code_templates.py +147 -0
  214. scitex/template/_mcp/handlers.py +81 -0
  215. scitex/template/_mcp/tool_schemas.py +55 -0
  216. scitex/template/_templates/__init__.py +51 -0
  217. scitex/template/_templates/audio.py +233 -0
  218. scitex/template/_templates/canvas.py +312 -0
  219. scitex/template/_templates/capture.py +268 -0
  220. scitex/template/_templates/config.py +43 -0
  221. scitex/template/_templates/diagram.py +294 -0
  222. scitex/template/_templates/io.py +107 -0
  223. scitex/template/_templates/module.py +53 -0
  224. scitex/template/_templates/plt.py +202 -0
  225. scitex/template/_templates/scholar.py +267 -0
  226. scitex/template/_templates/session.py +130 -0
  227. scitex/template/_templates/session_minimal.py +43 -0
  228. scitex/template/_templates/session_plot.py +67 -0
  229. scitex/template/_templates/session_stats.py +77 -0
  230. scitex/template/_templates/stats.py +323 -0
  231. scitex/template/_templates/writer.py +296 -0
  232. scitex/template/clone_writer_directory.py +5 -5
  233. scitex/ui/_backends/_email.py +10 -2
  234. scitex/ui/_backends/_webhook.py +5 -1
  235. scitex/web/_search_pubmed.py +10 -6
  236. scitex/writer/README.md +1 -1
  237. scitex/writer/_mcp/handlers.py +11 -744
  238. scitex/writer/_mcp/tool_schemas.py +5 -335
  239. scitex-2.15.2.dist-info/METADATA +648 -0
  240. {scitex-2.14.0.dist-info → scitex-2.15.2.dist-info}/RECORD +246 -150
  241. scitex/canvas/editor/flask_editor/templates/_scripts.py +0 -4933
  242. scitex/canvas/editor/flask_editor/templates/_styles.py +0 -1658
  243. scitex/dev/plt/data/mpl/PLOTTING_FUNCTIONS.yaml +0 -90
  244. scitex/dev/plt/data/mpl/PLOTTING_SIGNATURES.yaml +0 -1571
  245. scitex/dev/plt/data/mpl/PLOTTING_SIGNATURES_DETAILED.yaml +0 -6262
  246. scitex/dev/plt/data/mpl/SIGNATURES_FLATTENED.yaml +0 -1274
  247. scitex/dev/plt/data/mpl/dir_ax.txt +0 -459
  248. scitex/diagram/_compile.py +0 -312
  249. scitex/diagram/_diagram.py +0 -355
  250. scitex/diagram/_mcp/__init__.py +0 -4
  251. scitex/diagram/_mcp/handlers.py +0 -400
  252. scitex/diagram/_mcp/tool_schemas.py +0 -157
  253. scitex/diagram/_presets.py +0 -173
  254. scitex/diagram/_schema.py +0 -182
  255. scitex/diagram/_split.py +0 -278
  256. scitex/gen/_ci.py +0 -12
  257. scitex/gen/_title_case.py +0 -89
  258. scitex/plt/_mcp/__init__.py +0 -4
  259. scitex/plt/_mcp/_handlers_annotation.py +0 -102
  260. scitex/plt/_mcp/_handlers_figure.py +0 -195
  261. scitex/plt/_mcp/_handlers_plot.py +0 -252
  262. scitex/plt/_mcp/_handlers_style.py +0 -219
  263. scitex/plt/_mcp/handlers.py +0 -74
  264. scitex/plt/_mcp/tool_schemas.py +0 -497
  265. scitex/plt/mcp_server.py +0 -231
  266. scitex/scholar/data/.gitkeep +0 -0
  267. scitex/scholar/data/README.md +0 -44
  268. scitex/scholar/data/bib_files/bibliography.bib +0 -1952
  269. scitex/scholar/data/bib_files/neurovista.bib +0 -277
  270. scitex/scholar/data/bib_files/neurovista_enriched.bib +0 -441
  271. scitex/scholar/data/bib_files/neurovista_enriched_enriched.bib +0 -441
  272. scitex/scholar/data/bib_files/neurovista_processed.bib +0 -338
  273. scitex/scholar/data/bib_files/openaccess.bib +0 -89
  274. scitex/scholar/data/bib_files/pac-seizure_prediction_enriched.bib +0 -2178
  275. scitex/scholar/data/bib_files/pac.bib +0 -698
  276. scitex/scholar/data/bib_files/pac_enriched.bib +0 -1061
  277. scitex/scholar/data/bib_files/pac_processed.bib +0 -0
  278. scitex/scholar/data/bib_files/pac_titles.txt +0 -75
  279. scitex/scholar/data/bib_files/paywalled.bib +0 -98
  280. scitex/scholar/data/bib_files/related-papers-by-coauthors.bib +0 -58
  281. scitex/scholar/data/bib_files/related-papers-by-coauthors_enriched.bib +0 -87
  282. scitex/scholar/data/bib_files/seizure_prediction.bib +0 -694
  283. scitex/scholar/data/bib_files/seizure_prediction_processed.bib +0 -0
  284. scitex/scholar/data/bib_files/test_complete_enriched.bib +0 -437
  285. scitex/scholar/data/bib_files/test_final_enriched.bib +0 -437
  286. scitex/scholar/data/bib_files/test_seizure.bib +0 -46
  287. scitex/scholar/data/impact_factor/JCR_IF_2022.xlsx +0 -0
  288. scitex/scholar/data/impact_factor/JCR_IF_2024.db +0 -0
  289. scitex/scholar/data/impact_factor/JCR_IF_2024.xlsx +0 -0
  290. scitex/scholar/data/impact_factor/JCR_IF_2024_v01.db +0 -0
  291. scitex/scholar/data/impact_factor.db +0 -0
  292. scitex/scholar/examples/SUGGESTIONS.md +0 -865
  293. scitex/scholar/examples/dev.py +0 -38
  294. scitex-2.14.0.dist-info/METADATA +0 -1238
  295. /scitex/{gen → context}/_detect_environment.py +0 -0
  296. /scitex/{gen → context}/_get_notebook_path.py +0 -0
  297. /scitex/{gen/_shell.py → sh/_shell_legacy.py} +0 -0
  298. {scitex-2.14.0.dist-info → scitex-2.15.2.dist-info}/WHEEL +0 -0
  299. {scitex-2.14.0.dist-info → scitex-2.15.2.dist-info}/entry_points.txt +0 -0
  300. {scitex-2.14.0.dist-info → scitex-2.15.2.dist-info}/licenses/LICENSE +0 -0
@@ -1,59 +1,30 @@
1
1
  #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- # Timestamp: "2025-10-03 03:22:45 (ywatanabe)"
4
- # File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/src/scitex/ml/classification/timeseries/_TimeSeriesSlidingWindowSplit.py
5
- # ----------------------------------------
6
- from __future__ import annotations
7
- import os
8
-
9
- __FILE__ = "./src/scitex/ml/classification/timeseries/_TimeSeriesSlidingWindowSplit.py"
10
- __DIR__ = os.path.dirname(__FILE__)
11
- # ----------------------------------------
2
+ # Timestamp: "2026-01-24 (ywatanabe)"
3
+ # File: /home/ywatanabe/proj/scitex-python/src/scitex/ai/classification/timeseries/_TimeSeriesSlidingWindowSplit.py
12
4
 
13
- """
14
- Functionalities:
15
- - Implements sliding window cross-validation for time series
16
- - Creates overlapping train/test windows that slide through time
17
- - Supports temporal gaps between train and test sets
18
- - Provides visualization with scatter plots showing actual data points
19
- - Validates temporal order in all windows
20
- - Ensures no data leakage between train and test sets
5
+ """Sliding window cross-validation for time series.
21
6
 
22
- Dependencies:
23
- - packages:
24
- - numpy
25
- - sklearn
26
- - matplotlib
27
- - scitex
7
+ This module provides the TimeSeriesSlidingWindowSplit class which combines:
8
+ - Core splitting functionality from _sliding_window_core
9
+ - Visualization support from _sliding_window_plotting
28
10
 
29
- IO:
30
- - input-files:
31
- - None (generates synthetic data for demonstration)
32
- - output-files:
33
- - ./sliding_window_demo.png (visualization with scatter plots)
11
+ For demo/example usage, see examples/ai/classification/sliding_window_demo.py
34
12
  """
35
13
 
36
- """Imports"""
37
- import argparse
38
- from typing import Iterator, Optional, Tuple
14
+ from __future__ import annotations
39
15
 
40
- import matplotlib.patches as patches
41
- import matplotlib.pyplot as plt
42
- import numpy as np
43
- import scitex as stx
44
- from scitex import logging
45
- from sklearn.model_selection import BaseCrossValidator
46
- from sklearn.utils.validation import _num_samples
16
+ from typing import Optional
47
17
 
48
- logger = logging.getLogger(__name__)
18
+ from ._sliding_window_core import TimeSeriesSlidingWindowSplitCore
19
+ from ._sliding_window_plotting import SlidingWindowPlottingMixin
49
20
 
50
- COLORS = stx.plt.color.PARAMS
51
- COLORS["RGBA_NORM"]
21
+ __all__ = ["TimeSeriesSlidingWindowSplit"]
52
22
 
53
23
 
54
- class TimeSeriesSlidingWindowSplit(BaseCrossValidator):
55
- """
56
- Sliding window cross-validation for time series.
24
+ class TimeSeriesSlidingWindowSplit(
25
+ SlidingWindowPlottingMixin, TimeSeriesSlidingWindowSplitCore
26
+ ):
27
+ """Sliding window cross-validation for time series.
57
28
 
58
29
  Creates train/test windows that slide through time with configurable behavior.
59
30
 
@@ -114,6 +85,9 @@ class TimeSeriesSlidingWindowSplit(BaseCrossValidator):
114
85
  ... )
115
86
  >>> for train_idx, test_idx in swcv.split(X, y, timestamps):
116
87
  ... print(f"Train: {len(train_idx)}, Test: {len(test_idx)}")
88
+ >>>
89
+ >>> # Visualize splits
90
+ >>> fig = swcv.plot_splits(X, y, timestamps)
117
91
  """
118
92
 
119
93
  def __init__(
@@ -129,1512 +103,18 @@ class TimeSeriesSlidingWindowSplit(BaseCrossValidator):
129
103
  undersample: bool = False,
130
104
  n_splits: Optional[int] = None,
131
105
  ):
132
- # Handle n_splits mode vs manual mode
133
- if n_splits is not None:
134
- # n_splits mode: automatically calculate window_size and test_size
135
- self.n_splits_mode = True
136
- self._n_splits = n_splits
137
- # Use placeholder values, will be calculated in split()
138
- self.window_size = window_size if window_size is not None else 50
139
- self.test_size = test_size if test_size is not None else 10
140
- else:
141
- # Manual mode: require window_size and test_size
142
- if window_size is None or test_size is None:
143
- raise ValueError(
144
- "Either n_splits OR (window_size AND test_size) must be specified"
145
- )
146
- self.n_splits_mode = False
147
- self._n_splits = None
148
- self.window_size = window_size
149
- self.test_size = test_size
150
-
151
- self.gap = gap
152
- self.val_ratio = val_ratio
153
- self.random_state = random_state
154
- self.rng = np.random.default_rng(random_state)
155
- self.overlapping_tests = overlapping_tests
156
- self.expanding_window = expanding_window
157
- self.undersample = undersample
158
-
159
- # Handle step_size logic
160
- if not overlapping_tests:
161
- # overlapping_tests=False: ensure non-overlapping tests
162
- if step_size is not None and step_size < test_size:
163
- logger.warning(
164
- f"overlapping_tests=False but step_size={step_size} < test_size={test_size}. "
165
- f"This would cause test overlap. Setting step_size=test_size={test_size}."
166
- )
167
- self.step_size = test_size
168
- elif step_size is None:
169
- # Default: non-overlapping tests
170
- self.step_size = test_size
171
- logger.info(
172
- f"step_size not specified with overlapping_tests=False. "
173
- f"Using step_size=test_size={test_size} for non-overlapping tests."
174
- )
175
- else:
176
- # step_size >= test_size: acceptable, no overlap
177
- self.step_size = step_size
178
- else:
179
- # overlapping_tests=True: allow any step_size
180
- if step_size is None:
181
- # Default for overlapping: half the test size for 50% overlap
182
- self.step_size = max(1, test_size // 2)
183
- logger.info(
184
- f"step_size not specified with overlapping_tests=True. "
185
- f"Using step_size={self.step_size} (50% overlap)."
186
- )
187
- else:
188
- self.step_size = step_size
189
-
190
- def _undersample_indices(
191
- self, train_indices: np.ndarray, y: np.ndarray, timestamps: np.ndarray
192
- ) -> np.ndarray:
193
- """
194
- Undersample majority class to balance training set.
195
-
196
- Maintains temporal order of samples.
197
-
198
- Parameters
199
- ----------
200
- train_indices : ndarray
201
- Original training indices
202
- y : ndarray
203
- Full label array
204
- timestamps : ndarray
205
- Full timestamp array
206
-
207
- Returns
208
- -------
209
- ndarray
210
- Undersampled training indices (sorted by timestamp)
211
- """
212
- # Get labels for training indices
213
- train_labels = y[train_indices]
214
-
215
- # Find unique classes and their counts
216
- unique_classes, class_counts = np.unique(train_labels, return_counts=True)
217
-
218
- if len(unique_classes) < 2:
219
- # Only one class, no undersampling needed
220
- return train_indices
221
-
222
- # Find minority class count
223
- min_count = class_counts.min()
224
-
225
- # Undersample each class to match minority class count
226
- undersampled_indices = []
227
- for cls in unique_classes:
228
- # Find indices of this class within train_indices
229
- cls_mask = train_labels == cls
230
- cls_train_indices = train_indices[cls_mask]
231
-
232
- if len(cls_train_indices) > min_count:
233
- # Randomly select min_count samples
234
- selected = self.rng.choice(
235
- cls_train_indices, size=min_count, replace=False
236
- )
237
- undersampled_indices.extend(selected)
238
- else:
239
- # Keep all samples from minority class
240
- undersampled_indices.extend(cls_train_indices)
241
-
242
- # Convert to array and sort by timestamp to maintain temporal order
243
- undersampled_indices = np.array(undersampled_indices)
244
- temporal_order = np.argsort(timestamps[undersampled_indices])
245
- undersampled_indices = undersampled_indices[temporal_order]
246
-
247
- return undersampled_indices
248
-
249
- def split(
250
- self,
251
- X: np.ndarray,
252
- y: Optional[np.ndarray] = None,
253
- timestamps: Optional[np.ndarray] = None,
254
- groups: Optional[np.ndarray] = None,
255
- ) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
256
- """
257
- Generate sliding window splits.
258
-
259
- Parameters
260
- ----------
261
- X : array-like, shape (n_samples, n_features)
262
- Training data
263
- y : array-like, shape (n_samples,), optional
264
- Target variable
265
- timestamps : array-like, shape (n_samples,), optional
266
- Timestamps for temporal ordering. If None, uses sequential order
267
- groups : array-like, shape (n_samples,), optional
268
- Group labels (not used in this splitter)
269
-
270
- Yields
271
- ------
272
- train : ndarray
273
- Training set indices
274
- test : ndarray
275
- Test set indices
276
- """
277
- if timestamps is None:
278
- timestamps = np.arange(len(X))
279
-
280
- n_samples = _num_samples(X)
281
- indices = np.arange(n_samples)
282
-
283
- # Sort by timestamp to get temporal order
284
- time_order = np.argsort(timestamps)
285
- sorted_indices = indices[time_order]
286
-
287
- # Auto-calculate sizes if using n_splits mode
288
- if self.n_splits_mode:
289
- # Calculate test_size to create exactly n_splits folds
290
- # Formula: n_samples = window_size + (n_splits * (test_size + gap))
291
- # For expanding window, window_size is minimum training size
292
- # We want non-overlapping tests by default
293
-
294
- if self.expanding_window:
295
- # Expanding window: start with minimum window, test slides forward
296
- # Let's use 20% of data as initial window (similar to sklearn)
297
- min_window_size = max(1, n_samples // (self._n_splits + 1))
298
- available_for_test = (
299
- n_samples - min_window_size - (self._n_splits * self.gap)
300
- )
301
- calculated_test_size = max(1, available_for_test // self._n_splits)
302
-
303
- # Set calculated values
304
- self.window_size = min_window_size
305
- self.test_size = calculated_test_size
306
- self.step_size = calculated_test_size # Non-overlapping by default
307
-
308
- logger.info(
309
- f"n_splits={self._n_splits} with expanding_window: "
310
- f"Calculated window_size={self.window_size}, test_size={self.test_size}"
311
- )
312
- else:
313
- # Fixed window: calculate window and test size
314
- # We want: n_samples = window_size + (n_splits * (test_size + gap))
315
- # Let's make window_size same as test_size for simplicity
316
- available = n_samples - (self._n_splits * self.gap)
317
- calculated_test_size = max(1, available // (self._n_splits + 1))
318
- calculated_window_size = calculated_test_size
319
-
320
- # Set calculated values
321
- self.window_size = calculated_window_size
322
- self.test_size = calculated_test_size
323
- self.step_size = calculated_test_size # Non-overlapping by default
324
-
325
- logger.info(
326
- f"n_splits={self._n_splits} with fixed window: "
327
- f"Calculated window_size={self.window_size}, test_size={self.test_size}"
328
- )
329
-
330
- if self.expanding_window:
331
- # Expanding window: training set grows to include all past data
332
- # Start with minimum window_size, test slides forward
333
- min_train_size = self.window_size
334
- total_min = min_train_size + self.gap + self.test_size
335
-
336
- if n_samples < total_min:
337
- logger.warning(
338
- f"Not enough samples ({n_samples}) for even one split. "
339
- f"Need at least {total_min} samples."
340
- )
341
- return
342
-
343
- # First fold starts at window_size
344
- test_start_pos = min_train_size + self.gap
345
-
346
- while test_start_pos + self.test_size <= n_samples:
347
- test_end_pos = test_start_pos + self.test_size
348
-
349
- # Training includes all data from start to before gap
350
- train_end_pos = test_start_pos - self.gap
351
- train_indices = sorted_indices[0:train_end_pos]
352
- test_indices = sorted_indices[test_start_pos:test_end_pos]
353
-
354
- # Apply undersampling if enabled and y is provided
355
- if self.undersample and y is not None:
356
- train_indices = self._undersample_indices(
357
- train_indices, y, timestamps
358
- )
359
-
360
- assert len(train_indices) > 0 and len(test_indices) > 0, "Empty window"
361
-
362
- yield train_indices, test_indices
363
-
364
- # Move test window forward by step_size
365
- test_start_pos += self.step_size
366
-
367
- else:
368
- # Fixed sliding window: window slides through data
369
- total_window = self.window_size + self.gap + self.test_size
370
-
371
- for start in range(0, n_samples - total_window + 1, self.step_size):
372
- # These positions are in the sorted (temporal) domain
373
- train_end = start + self.window_size
374
- test_start = train_end + self.gap
375
- test_end = test_start + self.test_size
376
-
377
- if test_end > n_samples:
378
- break
379
-
380
- # Extract indices from the temporally sorted sequence
381
- train_indices = sorted_indices[start:train_end]
382
- test_indices = sorted_indices[test_start:test_end]
383
-
384
- # Apply undersampling if enabled and y is provided
385
- if self.undersample and y is not None:
386
- train_indices = self._undersample_indices(
387
- train_indices, y, timestamps
388
- )
389
-
390
- assert len(train_indices) > 0 and len(test_indices) > 0, "Empty window"
391
-
392
- yield train_indices, test_indices
393
-
394
- def split_with_val(
395
- self,
396
- X: np.ndarray,
397
- y: Optional[np.ndarray] = None,
398
- timestamps: Optional[np.ndarray] = None,
399
- groups: Optional[np.ndarray] = None,
400
- ) -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
401
- """
402
- Generate sliding window splits with validation set.
403
-
404
- The validation set comes after training but before test, maintaining
405
- temporal order: train < val < test.
406
-
407
- Parameters
408
- ----------
409
- X : array-like, shape (n_samples, n_features)
410
- Training data
411
- y : array-like, shape (n_samples,), optional
412
- Target variable
413
- timestamps : array-like, shape (n_samples,), optional
414
- Timestamps for temporal ordering. If None, uses sequential order
415
- groups : array-like, shape (n_samples,), optional
416
- Group labels (not used in this splitter)
417
-
418
- Yields
419
- ------
420
- train : ndarray
421
- Training set indices
422
- val : ndarray
423
- Validation set indices
424
- test : ndarray
425
- Test set indices
426
- """
427
- if timestamps is None:
428
- timestamps = np.arange(len(X))
429
-
430
- n_samples = _num_samples(X)
431
- indices = np.arange(n_samples)
432
-
433
- # Sort by timestamp to get temporal order
434
- time_order = np.argsort(timestamps)
435
- sorted_indices = indices[time_order]
436
-
437
- # Auto-calculate sizes if using n_splits mode
438
- if self.n_splits_mode:
439
- if self.expanding_window:
440
- min_window_size = max(1, n_samples // (self._n_splits + 1))
441
- available_for_test = (
442
- n_samples - min_window_size - (self._n_splits * self.gap)
443
- )
444
- calculated_test_size = max(1, available_for_test // self._n_splits)
445
- self.window_size = min_window_size
446
- self.test_size = calculated_test_size
447
- self.step_size = calculated_test_size
448
- else:
449
- available = n_samples - (self._n_splits * self.gap)
450
- calculated_test_size = max(1, available // (self._n_splits + 1))
451
- calculated_window_size = calculated_test_size
452
- self.window_size = calculated_window_size
453
- self.test_size = calculated_test_size
454
- self.step_size = calculated_test_size
455
-
456
- # Calculate validation size from training window
457
- val_size = int(self.window_size * self.val_ratio) if self.val_ratio > 0 else 0
458
- actual_train_size = self.window_size - val_size
459
-
460
- if self.expanding_window:
461
- # Expanding window with validation
462
- min_train_size = self.window_size
463
- total_min = min_train_size + self.gap + self.test_size
464
-
465
- if n_samples < total_min:
466
- logger.warning(
467
- f"Not enough samples ({n_samples}) for even one split. "
468
- f"Need at least {total_min} samples."
469
- )
470
- return
471
-
472
- # Calculate positions for validation and test
473
- test_start_pos = min_train_size + self.gap
474
-
475
- while test_start_pos + self.test_size <= n_samples:
476
- test_end_pos = test_start_pos + self.test_size
477
-
478
- # Training + validation comes before gap
479
- train_val_end_pos = test_start_pos - self.gap
480
-
481
- # Split train/val from the expanding window
482
- if val_size > 0:
483
- # Calculate validation size dynamically based on current expanding window
484
- # This ensures val_ratio is respected across all folds as window expands
485
- current_val_size = int(train_val_end_pos * self.val_ratio)
486
- train_end_pos = train_val_end_pos - current_val_size
487
- train_indices = sorted_indices[0:train_end_pos]
488
- val_indices = sorted_indices[train_end_pos:train_val_end_pos]
489
- else:
490
- train_indices = sorted_indices[0:train_val_end_pos]
491
- val_indices = np.array([])
492
-
493
- test_indices = sorted_indices[test_start_pos:test_end_pos]
494
-
495
- # Apply undersampling if enabled and y is provided
496
- if self.undersample and y is not None:
497
- train_indices = self._undersample_indices(
498
- train_indices, y, timestamps
499
- )
500
- # Also undersample validation set if it exists
501
- if len(val_indices) > 0:
502
- val_indices = self._undersample_indices(
503
- val_indices, y, timestamps
504
- )
505
-
506
- assert len(train_indices) > 0 and len(test_indices) > 0, "Empty window"
507
-
508
- yield train_indices, val_indices, test_indices
509
-
510
- # Move test window forward by step_size
511
- test_start_pos += self.step_size
512
-
513
- else:
514
- # Fixed sliding window with validation
515
- total_window = self.window_size + self.gap + self.test_size
516
-
517
- for start in range(0, n_samples - total_window + 1, self.step_size):
518
- # These positions are in the sorted (temporal) domain
519
- train_end = start + actual_train_size
520
-
521
- # Validation comes after train with optional gap
522
- val_start = train_end + (self.gap if val_size > 0 else 0)
523
- val_end = val_start + val_size
524
-
525
- # Test comes after validation with gap
526
- test_start = (
527
- val_end + self.gap if val_size > 0 else train_end + self.gap
528
- )
529
- test_end = test_start + self.test_size
530
-
531
- if test_end > n_samples:
532
- break
533
-
534
- # Extract indices from the temporally sorted sequence
535
- train_indices = sorted_indices[start:train_end]
536
- val_indices = (
537
- sorted_indices[val_start:val_end] if val_size > 0 else np.array([])
538
- )
539
- test_indices = sorted_indices[test_start:test_end]
540
-
541
- # Apply undersampling if enabled and y is provided
542
- if self.undersample and y is not None:
543
- train_indices = self._undersample_indices(
544
- train_indices, y, timestamps
545
- )
546
- # Also undersample validation set if it exists
547
- if len(val_indices) > 0:
548
- val_indices = self._undersample_indices(
549
- val_indices, y, timestamps
550
- )
551
-
552
- # Ensure temporal order is preserved
553
- assert len(train_indices) > 0 and len(test_indices) > 0, "Empty window"
554
-
555
- yield train_indices, val_indices, test_indices
556
-
557
- def get_n_splits(self, X=None, y=None, groups=None):
558
- """
559
- Calculate number of splits.
560
-
561
- Parameters
562
- ----------
563
- X : array-like, shape (n_samples, n_features), optional
564
- Training data (required to determine number of splits in manual mode)
565
- y : array-like, optional
566
- Not used
567
- groups : array-like, optional
568
- Not used
569
-
570
- Returns
571
- -------
572
- n_splits : int
573
- Number of splits. Returns -1 if X is None and not in n_splits mode.
574
- """
575
- # If using n_splits mode, return the specified n_splits
576
- if self.n_splits_mode:
577
- return self._n_splits
578
-
579
- # Manual mode: need data to calculate
580
- if X is None:
581
- return -1 # Can't determine without data
582
-
583
- n_samples = _num_samples(X)
584
- total_window = self.window_size + self.gap + self.test_size
585
- n_windows = (n_samples - total_window) // self.step_size + 1
586
- return max(0, n_windows)
587
-
588
- def plot_splits(self, X, y=None, timestamps=None, figsize=(12, 6), save_path=None):
589
- """
590
- Visualize the sliding window splits as rectangles.
591
-
592
- Shows train (blue), validation (green), and test (red) sets.
593
- When val_ratio=0, only shows train and test.
594
- When undersampling is enabled, shows dropped samples in gray.
595
-
596
- Parameters
597
- ----------
598
- X : array-like
599
- Training data
600
- y : array-like, optional
601
- Target variable (required for undersampling visualization)
602
- timestamps : array-like, optional
603
- Timestamps (if None, uses sample indices)
604
- figsize : tuple, default (12, 6)
605
- Figure size
606
- save_path : str, optional
607
- Path to save the plot
608
-
609
- Returns
610
- -------
611
- fig : matplotlib.figure.Figure
612
- The created figure
613
- """
614
- # Use sample indices if no timestamps provided
615
- if timestamps is None:
616
- timestamps = np.arange(len(X))
617
-
618
- # Get temporal ordering
619
- time_order = np.argsort(timestamps)
620
- sorted_timestamps = timestamps[time_order]
621
-
622
- # Get splits WITH undersampling (if enabled)
623
- if self.val_ratio > 0:
624
- splits = list(self.split_with_val(X, y, timestamps))[:10]
625
- split_type = "train-val-test"
626
- else:
627
- splits = list(self.split(X, y, timestamps))[:10]
628
- split_type = "train-test"
629
-
630
- if not splits:
631
- raise ValueError("No splits generated")
632
-
633
- # If undersampling is enabled, also get splits WITHOUT undersampling to show dropped samples
634
- splits_no_undersample = None
635
- if self.undersample and y is not None:
636
- original_undersample = self.undersample
637
- self.undersample = False # Temporarily disable
638
- if self.val_ratio > 0:
639
- splits_no_undersample = list(self.split_with_val(X, y, timestamps))[:10]
640
- else:
641
- splits_no_undersample = list(self.split(X, y, timestamps))[:10]
642
- self.undersample = original_undersample # Restore
643
-
644
- # Create figure
645
- fig, ax = stx.plt.subplots(figsize=figsize)
646
-
647
- # Plot each fold based on temporal position
648
- for fold, split_indices in enumerate(splits):
649
- y_pos = fold
650
-
651
- if len(split_indices) == 3: # train, val, test
652
- train_idx, val_idx, test_idx = split_indices
653
-
654
- # Find temporal positions of train indices
655
- train_positions = []
656
- for idx in train_idx:
657
- temp_pos = np.where(time_order == idx)[0][
658
- 0
659
- ] # Find position in sorted order
660
- train_positions.append(temp_pos)
661
-
662
- # Plot train window based on temporal positions
663
- if train_positions:
664
- train_start = min(train_positions)
665
- train_end = max(train_positions)
666
- train_rect = patches.Rectangle(
667
- (train_start, y_pos - 0.3),
668
- train_end - train_start + 1,
669
- 0.6,
670
- linewidth=1,
671
- edgecolor="blue",
672
- facecolor="lightblue",
673
- alpha=0.7,
674
- label="Train" if fold == 0 else "",
675
- )
676
- ax.add_patch(train_rect)
677
-
678
- # Find temporal positions of validation indices
679
- if len(val_idx) > 0:
680
- val_positions = []
681
- for idx in val_idx:
682
- temp_pos = np.where(time_order == idx)[0][0]
683
- val_positions.append(temp_pos)
684
-
685
- # Plot validation window
686
- if val_positions:
687
- val_start = min(val_positions)
688
- val_end = max(val_positions)
689
- val_rect = patches.Rectangle(
690
- (val_start, y_pos - 0.3),
691
- val_end - val_start + 1,
692
- 0.6,
693
- linewidth=1,
694
- edgecolor="green",
695
- facecolor="lightgreen",
696
- alpha=0.7,
697
- label="Validation" if fold == 0 else "",
698
- )
699
- ax.add_patch(val_rect)
700
-
701
- # Find temporal positions of test indices
702
- test_positions = []
703
- for idx in test_idx:
704
- temp_pos = np.where(time_order == idx)[0][
705
- 0
706
- ] # Find position in sorted order
707
- test_positions.append(temp_pos)
708
-
709
- # Plot test window based on temporal positions
710
- if test_positions:
711
- test_start = min(test_positions)
712
- test_end = max(test_positions)
713
- test_rect = patches.Rectangle(
714
- (test_start, y_pos - 0.3),
715
- test_end - test_start + 1,
716
- 0.6,
717
- linewidth=1,
718
- edgecolor=COLORS["RGBA_NORM"]["red"],
719
- facecolor=COLORS["RGBA_NORM"]["red"],
720
- alpha=0.7,
721
- label="Test" if fold == 0 else "",
722
- )
723
- ax.add_patch(test_rect)
724
-
725
- else: # train, test (2-way split)
726
- train_idx, test_idx = split_indices
727
-
728
- # Find temporal positions of train indices
729
- train_positions = []
730
- for idx in train_idx:
731
- temp_pos = np.where(time_order == idx)[0][
732
- 0
733
- ] # Find position in sorted order
734
- train_positions.append(temp_pos)
735
-
736
- # Plot train window based on temporal positions
737
- if train_positions:
738
- train_start = min(train_positions)
739
- train_end = max(train_positions)
740
- train_rect = patches.Rectangle(
741
- (train_start, y_pos - 0.3),
742
- train_end - train_start + 1,
743
- 0.6,
744
- linewidth=1,
745
- edgecolor=COLORS["RGBA_NORM"]["lightblue"],
746
- facecolor=COLORS["RGBA_NORM"]["lightblue"],
747
- alpha=0.7,
748
- label="Train" if fold == 0 else "",
749
- )
750
- ax.add_patch(train_rect)
751
-
752
- # Find temporal positions of test indices
753
- test_positions = []
754
- for idx in test_idx:
755
- temp_pos = np.where(time_order == idx)[0][
756
- 0
757
- ] # Find position in sorted order
758
- test_positions.append(temp_pos)
759
-
760
- # Plot test window based on temporal positions
761
- if test_positions:
762
- test_start = min(test_positions)
763
- test_end = max(test_positions)
764
- test_rect = patches.Rectangle(
765
- (test_start, y_pos - 0.3),
766
- test_end - test_start + 1,
767
- 0.6,
768
- linewidth=1,
769
- edgecolor="red",
770
- facecolor="lightcoral",
771
- alpha=0.7,
772
- label="Test" if fold == 0 else "",
773
- )
774
- ax.add_patch(test_rect)
775
-
776
- # Add scatter plots of actual data points with jittering
777
- np.random.seed(42) # For reproducible jittering
778
- jitter_strength = 0.15 # Amount of vertical jittering
779
-
780
- # First, plot dropped samples in gray if undersampling is enabled
781
- if splits_no_undersample is not None:
782
- for fold, split_indices_no_us in enumerate(splits_no_undersample):
783
- y_pos = fold
784
- split_indices_us = splits[fold]
785
-
786
- if len(split_indices_no_us) == 3: # train, val, test
787
- train_idx_no_us, val_idx_no_us, test_idx_no_us = split_indices_no_us
788
- train_idx_us, val_idx_us, test_idx_us = split_indices_us
789
-
790
- # Find dropped train samples
791
- dropped_train = np.setdiff1d(train_idx_no_us, train_idx_us)
792
- if len(dropped_train) > 0:
793
- dropped_train_positions = [
794
- np.where(time_order == idx)[0][0] for idx in dropped_train
795
- ]
796
- dropped_train_jitter = np.random.normal(
797
- 0, jitter_strength, len(dropped_train_positions)
798
- )
799
- ax.plot_scatter(
800
- dropped_train_positions,
801
- y_pos + dropped_train_jitter,
802
- c="gray",
803
- s=15,
804
- alpha=0.3,
805
- marker="x",
806
- label="Dropped (train)" if fold == 0 else "",
807
- zorder=2,
808
- )
809
-
810
- # Find dropped validation samples
811
- dropped_val = np.setdiff1d(val_idx_no_us, val_idx_us)
812
- if len(dropped_val) > 0:
813
- dropped_val_positions = [
814
- np.where(time_order == idx)[0][0] for idx in dropped_val
815
- ]
816
- dropped_val_jitter = np.random.normal(
817
- 0, jitter_strength, len(dropped_val_positions)
818
- )
819
- ax.plot_scatter(
820
- dropped_val_positions,
821
- y_pos + dropped_val_jitter,
822
- c="gray",
823
- s=15,
824
- alpha=0.3,
825
- marker="x",
826
- label="Dropped (val)" if fold == 0 else "",
827
- zorder=2,
828
- )
829
-
830
- else: # train, test (2-way split)
831
- train_idx_no_us, test_idx_no_us = split_indices_no_us
832
- train_idx_us, test_idx_us = split_indices_us
833
-
834
- # Find dropped train samples
835
- dropped_train = np.setdiff1d(train_idx_no_us, train_idx_us)
836
- if len(dropped_train) > 0:
837
- dropped_train_positions = [
838
- np.where(time_order == idx)[0][0] for idx in dropped_train
839
- ]
840
- dropped_train_jitter = np.random.normal(
841
- 0, jitter_strength, len(dropped_train_positions)
842
- )
843
- ax.plot_scatter(
844
- dropped_train_positions,
845
- y_pos + dropped_train_jitter,
846
- c="gray",
847
- s=15,
848
- alpha=0.3,
849
- marker="x",
850
- label="Dropped samples" if fold == 0 else "",
851
- zorder=2,
852
- )
853
-
854
- # Then, plot kept samples in color
855
- for fold, split_indices in enumerate(splits):
856
- y_pos = fold
857
-
858
- if len(split_indices) == 3: # train, val, test
859
- train_idx, val_idx, test_idx = split_indices
860
-
861
- # Find temporal positions for scatter plot
862
- train_positions = []
863
- for idx in train_idx:
864
- temp_pos = np.where(time_order == idx)[0][0]
865
- train_positions.append(temp_pos)
866
-
867
- val_positions = []
868
- if len(val_idx) > 0:
869
- for idx in val_idx:
870
- temp_pos = np.where(time_order == idx)[0][0]
871
- val_positions.append(temp_pos)
872
-
873
- test_positions = []
874
- for idx in test_idx:
875
- temp_pos = np.where(time_order == idx)[0][0]
876
- test_positions.append(temp_pos)
877
-
878
- # Add jittered scatter plots for 3-way split
879
- if train_positions:
880
- train_jitter = np.random.normal(
881
- 0, jitter_strength, len(train_positions)
882
- )
883
- # Color by class if y is provided
884
- if y is not None:
885
- train_colors = [
886
- stx.plt.color.PARAMS["RGBA_NORM"]["blue"]
887
- if y[idx] == 0
888
- else stx.plt.color.PARAMS["RGBA_NORM"]["lightblue"]
889
- for idx in train_idx
890
- ]
891
- ax.plot_scatter(
892
- train_positions,
893
- y_pos + train_jitter,
894
- c=train_colors,
895
- s=20,
896
- alpha=0.7,
897
- marker="o",
898
- label="Train (class 0)" if fold == 0 else "",
899
- zorder=3,
900
- )
901
- else:
902
- ax.plot_scatter(
903
- train_positions,
904
- y_pos + train_jitter,
905
- c="darkblue",
906
- s=20,
907
- alpha=0.7,
908
- marker="o",
909
- label="Train points" if fold == 0 else "",
910
- zorder=3,
911
- )
912
-
913
- if val_positions:
914
- val_jitter = np.random.normal(
915
- 0, jitter_strength, len(val_positions)
916
- )
917
- # Color by class if y is provided
918
- if y is not None:
919
- val_colors = [
920
- stx.plt.color.PARAMS["RGBA_NORM"]["yellow"]
921
- if y[idx] == 0
922
- else stx.plt.color.PARAMS["RGBA_NORM"]["orange"]
923
- for idx in val_idx
924
- ]
925
- ax.plot_scatter(
926
- val_positions,
927
- y_pos + val_jitter,
928
- c=val_colors,
929
- s=20,
930
- alpha=0.7,
931
- marker="^",
932
- label="Val (class 0)" if fold == 0 else "",
933
- zorder=3,
934
- )
935
- else:
936
- ax.plot_scatter(
937
- val_positions,
938
- y_pos + val_jitter,
939
- c="darkgreen",
940
- s=20,
941
- alpha=0.7,
942
- marker="^",
943
- label="Val points" if fold == 0 else "",
944
- zorder=3,
945
- )
946
-
947
- if test_positions:
948
- test_jitter = np.random.normal(
949
- 0, jitter_strength, len(test_positions)
950
- )
951
- # Color by class if y is provided
952
- if y is not None:
953
- test_colors = [
954
- stx.plt.color.PARAMS["RGBA_NORM"]["red"]
955
- if y[idx] == 0
956
- else stx.plt.color.PARAMS["RGBA_NORM"]["brown"]
957
- for idx in test_idx
958
- ]
959
- ax.plot_scatter(
960
- test_positions,
961
- y_pos + test_jitter,
962
- c=test_colors,
963
- s=20,
964
- alpha=0.7,
965
- marker="s",
966
- label="Test (class 0)" if fold == 0 else "",
967
- zorder=3,
968
- )
969
- else:
970
- ax.plot_scatter(
971
- test_positions,
972
- y_pos + test_jitter,
973
- c="darkred",
974
- s=20,
975
- alpha=0.7,
976
- marker="s",
977
- label="Test points" if fold == 0 else "",
978
- zorder=3,
979
- )
980
-
981
- else: # train, test (2-way split)
982
- train_idx, test_idx = split_indices
983
-
984
- # Get actual timestamps for train and test indices
985
- train_times = (
986
- timestamps[train_idx] if timestamps is not None else train_idx
987
- )
988
- test_times = (
989
- timestamps[test_idx] if timestamps is not None else test_idx
990
- )
991
-
992
- # Find temporal positions for scatter plot
993
- train_positions = []
994
- for idx in train_idx:
995
- temp_pos = np.where(time_order == idx)[0][0]
996
- train_positions.append(temp_pos)
997
-
998
- test_positions = []
999
- for idx in test_idx:
1000
- temp_pos = np.where(time_order == idx)[0][0]
1001
- test_positions.append(temp_pos)
1002
-
1003
- # Add jittered scatter plots for 2-way split
1004
- if train_positions:
1005
- train_jitter = np.random.normal(
1006
- 0, jitter_strength, len(train_positions)
1007
- )
1008
- # Color by class if y is provided
1009
- if y is not None:
1010
- train_colors = [
1011
- stx.plt.color.PARAMS["RGBA_NORM"]["blue"]
1012
- if y[idx] == 0
1013
- else stx.plt.color.PARAMS["RGBA_NORM"]["lightblue"]
1014
- for idx in train_idx
1015
- ]
1016
- ax.plot_scatter(
1017
- train_positions,
1018
- y_pos + train_jitter,
1019
- c=train_colors,
1020
- s=20,
1021
- alpha=0.7,
1022
- marker="o",
1023
- label="Train (class 0)" if fold == 0 else "",
1024
- zorder=3,
1025
- )
1026
- else:
1027
- ax.plot_scatter(
1028
- train_positions,
1029
- y_pos + train_jitter,
1030
- c="darkblue",
1031
- s=20,
1032
- alpha=0.7,
1033
- marker="o",
1034
- label="Train points" if fold == 0 else "",
1035
- zorder=3,
1036
- )
1037
-
1038
- if test_positions:
1039
- test_jitter = np.random.normal(
1040
- 0, jitter_strength, len(test_positions)
1041
- )
1042
- # Color by class if y is provided
1043
- if y is not None:
1044
- test_colors = [
1045
- stx.plt.color.PARAMS["RGBA_NORM"]["red"]
1046
- if y[idx] == 0
1047
- else stx.plt.color.PARAMS["RGBA_NORM"]["brown"]
1048
- for idx in test_idx
1049
- ]
1050
- ax.plot_scatter(
1051
- test_positions,
1052
- y_pos + test_jitter,
1053
- c=test_colors,
1054
- s=20,
1055
- alpha=0.7,
1056
- marker="s",
1057
- label="Test (class 0)" if fold == 0 else "",
1058
- zorder=3,
1059
- )
1060
- else:
1061
- ax.plot_scatter(
1062
- test_positions,
1063
- y_pos + test_jitter,
1064
- c="darkred",
1065
- s=20,
1066
- alpha=0.7,
1067
- marker="s",
1068
- label="Test points" if fold == 0 else "",
1069
- zorder=3,
1070
- )
1071
-
1072
- # Format plot
1073
- ax.set_ylim(-0.5, len(splits) - 0.5)
1074
- ax.set_xlim(0, len(X))
1075
- ax.set_xlabel("Temporal Position (sorted by timestamp)")
1076
- ax.set_ylabel("Fold")
1077
- gap_text = f", Gap: {self.gap}" if self.gap > 0 else ""
1078
- val_text = f", Val ratio: {self.val_ratio:.1%}" if self.val_ratio > 0 else ""
1079
- ax.set_title(
1080
- f"Sliding Window Split Visualization ({split_type})\\n"
1081
- f"Window: {self.window_size}, Step: {self.step_size}, Test: {self.test_size}{gap_text}{val_text}\\n"
1082
- f"Rectangles show windows, dots show actual data points"
1083
- )
1084
-
1085
- # Set y-ticks
1086
- ax.set_yticks(range(len(splits)))
1087
- ax.set_yticklabels([f"Fold {i}" for i in range(len(splits))])
1088
-
1089
- # Add enhanced legend with class and sample information
1090
- if y is not None:
1091
- # Count samples per class in total dataset
1092
- unique_classes, class_counts = np.unique(y, return_counts=True)
1093
- total_class_info = ", ".join(
1094
- [
1095
- f"Class {cls}: n={count}"
1096
- for cls, count in zip(unique_classes, class_counts)
1097
- ]
1098
- )
1099
-
1100
- # Count samples in first fold to show per-fold distribution
1101
- first_split = splits[0]
1102
- if len(first_split) == 3: # train, val, test
1103
- train_idx, val_idx, test_idx = first_split
1104
- fold_info = f"Fold 0: Train n={len(train_idx)}, Val n={len(val_idx)}, Test n={len(test_idx)}"
1105
- else: # train, test
1106
- train_idx, test_idx = first_split
1107
- fold_info = f"Fold 0: Train n={len(train_idx)}, Test n={len(test_idx)}"
1108
-
1109
- # Add legend with class information
1110
- handles, labels = ax.get_legend_handles_labels()
1111
- # Add title to legend showing class distribution
1112
- legend_title = f"Total: {total_class_info}\\n{fold_info}"
1113
- ax.legend(handles, labels, loc="upper right", title=legend_title)
1114
- else:
1115
- ax.legend(loc="upper right")
1116
-
1117
- plt.tight_layout()
1118
-
1119
- if save_path:
1120
- fig.savefig(save_path, dpi=150, bbox_inches="tight")
1121
-
1122
- return fig
1123
-
1124
-
1125
- """Functions & Classes"""
1126
-
1127
-
1128
- def main(args) -> int:
1129
- """Demonstrate TimeSeriesSlidingWindowSplit functionality.
1130
-
1131
- Args:
1132
- args: Command line arguments
1133
-
1134
- Returns:
1135
- int: Exit status
1136
- """
1137
-
1138
- def demo_01_fixed_window_non_overlapping_tests(X, y, timestamps):
1139
- """Demo 1: Fixed window size with non-overlapping test sets (DEFAULT).
1140
-
1141
- Best for: Testing model on consistent recent history.
1142
- Each sample tested exactly once (like K-fold for time series).
1143
- """
1144
- logger.info("=" * 70)
1145
- logger.info("DEMO 1: Fixed Window + Non-overlapping Tests (DEFAULT)")
1146
- logger.info("=" * 70)
1147
- logger.info("Best for: Testing model on consistent recent history")
1148
-
1149
- splitter = TimeSeriesSlidingWindowSplit(
1150
- window_size=args.window_size,
1151
- test_size=args.test_size,
1152
- gap=args.gap,
1153
- overlapping_tests=False, # Default
1154
- expanding_window=False, # Default
1155
- )
1156
-
1157
- splits = list(splitter.split(X, y, timestamps))[:5]
1158
- logger.info(f"Generated {len(splits)} splits")
1159
-
1160
- for fold, (train_idx, test_idx) in enumerate(splits):
1161
- logger.info(
1162
- f" Fold {fold}: Train={len(train_idx)} (fixed), Test={len(test_idx)}"
1163
- )
1164
-
1165
- fig = splitter.plot_splits(X, y, timestamps)
1166
- stx.io.save(fig, "./01_sliding_window_fixed.jpg", symlink_from_cwd=True)
1167
- logger.info("")
1168
-
1169
- return splits
1170
-
1171
- def demo_02_expanding_window_non_overlapping_tests(X, y, timestamps):
1172
- """Demo 2: Expanding window with non-overlapping test sets.
1173
-
1174
- Best for: Using all available past data (like sklearn TimeSeriesSplit).
1175
- Training set grows to include all historical data.
1176
- """
1177
- logger.info("=" * 70)
1178
- logger.info("DEMO 2: Expanding Window + Non-overlapping Tests")
1179
- logger.info("=" * 70)
1180
- logger.info(
1181
- "Best for: Using all available past data (like sklearn TimeSeriesSplit)"
1182
- )
1183
-
1184
- splitter = TimeSeriesSlidingWindowSplit(
1185
- window_size=args.window_size,
1186
- test_size=args.test_size,
1187
- gap=args.gap,
1188
- overlapping_tests=False,
1189
- expanding_window=True, # Use all past data!
1190
- )
1191
-
1192
- splits = list(splitter.split(X, y, timestamps))[:5]
1193
- logger.info(f"Generated {len(splits)} splits")
1194
-
1195
- for fold, (train_idx, test_idx) in enumerate(splits):
1196
- logger.info(
1197
- f" Fold {fold}: Train={len(train_idx)} (growing!), Test={len(test_idx)}"
1198
- )
1199
-
1200
- fig = splitter.plot_splits(X, y, timestamps)
1201
- stx.io.save(fig, "./02_sliding_window_expanding.jpg", symlink_from_cwd=True)
1202
- logger.info("")
1203
-
1204
- return splits
1205
-
1206
- def demo_03_fixed_window_overlapping_tests(X, y, timestamps):
1207
- """Demo 3: Fixed window with overlapping test sets.
1208
-
1209
- Best for: Maximum evaluation points (like K-fold training reuse).
1210
- Test sets can overlap for more frequent model evaluation.
1211
- """
1212
- logger.info("=" * 70)
1213
- logger.info("DEMO 3: Fixed Window + Overlapping Tests")
1214
- logger.info("=" * 70)
1215
- logger.info("Best for: Maximum evaluation points (like K-fold for training)")
1216
-
1217
- splitter = TimeSeriesSlidingWindowSplit(
1218
- window_size=args.window_size,
1219
- test_size=args.test_size,
1220
- gap=args.gap,
1221
- overlapping_tests=True, # Allow test overlap
1222
- expanding_window=False,
1223
- # step_size will default to test_size // 2 for 50% overlap
1224
- )
1225
-
1226
- splits = list(splitter.split(X, y, timestamps))[:5]
1227
- logger.info(f"Generated {len(splits)} splits")
1228
-
1229
- for fold, (train_idx, test_idx) in enumerate(splits):
1230
- logger.info(f" Fold {fold}: Train={len(train_idx)}, Test={len(test_idx)}")
1231
-
1232
- fig = splitter.plot_splits(X, y, timestamps)
1233
- stx.io.save(fig, "./03_sliding_window_overlapping.jpg", symlink_from_cwd=True)
1234
- logger.info("")
1235
-
1236
- return splits
1237
-
1238
- def demo_04_undersample_imbalanced_data(X, y_imbalanced, timestamps):
1239
- """Demo 4: Undersampling for imbalanced time series data.
1240
-
1241
- Best for: Handling class imbalance in training sets.
1242
- Balances classes by randomly undersampling majority class.
1243
- """
1244
- logger.info("=" * 70)
1245
- logger.info("DEMO 4: Undersampling for Imbalanced Data")
1246
- logger.info("=" * 70)
1247
- logger.info("Best for: Handling class imbalance in time series")
1248
-
1249
- # Show data imbalance
1250
- unique, counts = np.unique(y_imbalanced, return_counts=True)
1251
- logger.info(f"Class distribution: {dict(zip(unique, counts))}")
1252
- logger.info("")
1253
-
1254
- # Without undersampling
1255
- splitter_no_undersample = TimeSeriesSlidingWindowSplit(
1256
- window_size=args.window_size,
1257
- test_size=args.test_size,
1258
- gap=args.gap,
1259
- undersample=False,
1260
- )
1261
-
1262
- splits_no_us = list(splitter_no_undersample.split(X, y_imbalanced, timestamps))[
1263
- :3
1264
- ]
1265
- logger.info(f"WITHOUT undersampling: {len(splits_no_us)} splits")
1266
- for fold, (train_idx, test_idx) in enumerate(splits_no_us):
1267
- train_labels = y_imbalanced[train_idx]
1268
- train_unique, train_counts = np.unique(train_labels, return_counts=True)
1269
- logger.info(
1270
- f" Fold {fold}: Train size={len(train_idx)}, "
1271
- f"Class dist={dict(zip(train_unique, train_counts))}"
1272
- )
1273
- logger.info("")
1274
-
1275
- # With undersampling
1276
- splitter_undersample = TimeSeriesSlidingWindowSplit(
1277
- window_size=args.window_size,
1278
- test_size=args.test_size,
1279
- gap=args.gap,
1280
- undersample=True, # Enable undersampling!
1281
- random_state=42,
1282
- )
1283
-
1284
- splits_us = list(splitter_undersample.split(X, y_imbalanced, timestamps))[:3]
1285
- logger.info(f"WITH undersampling: {len(splits_us)} splits")
1286
- for fold, (train_idx, test_idx) in enumerate(splits_us):
1287
- train_labels = y_imbalanced[train_idx]
1288
- train_unique, train_counts = np.unique(train_labels, return_counts=True)
1289
- logger.info(
1290
- f" Fold {fold}: Train size={len(train_idx)} (balanced!), "
1291
- f"Class dist={dict(zip(train_unique, train_counts))}"
1292
- )
1293
-
1294
- # Save visualization for undersampling
1295
- fig = splitter_undersample.plot_splits(X, y_imbalanced, timestamps)
1296
- stx.io.save(fig, "./04_sliding_window_undersample.jpg", symlink_from_cwd=True)
1297
- logger.info("")
1298
-
1299
- return splits_us
1300
-
1301
- def demo_05_validation_dataset(X, y, timestamps):
1302
- """Demo 5: Using validation dataset with train-val-test splits.
1303
-
1304
- Best for: Model selection and hyperparameter tuning.
1305
- Creates train/validation/test splits maintaining temporal order.
1306
- """
1307
- logger.info("=" * 70)
1308
- logger.info("DEMO 5: Validation Dataset (Train-Val-Test Splits)")
1309
- logger.info("=" * 70)
1310
- logger.info("Best for: Model selection and hyperparameter tuning")
1311
-
1312
- splitter = TimeSeriesSlidingWindowSplit(
1313
- window_size=args.window_size,
1314
- test_size=args.test_size,
1315
- gap=args.gap,
1316
- val_ratio=0.2, # 20% of training window for validation
1317
- overlapping_tests=False,
1318
- expanding_window=False,
1319
- )
1320
-
1321
- splits = list(splitter.split_with_val(X, y, timestamps))[:3]
1322
- logger.info(f"Generated {len(splits)} splits")
1323
-
1324
- for fold, (train_idx, val_idx, test_idx) in enumerate(splits):
1325
- logger.info(
1326
- f" Fold {fold}: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}"
1327
- )
1328
-
1329
- fig = splitter.plot_splits(X, y, timestamps)
1330
- stx.io.save(fig, "./05_sliding_window_validation.jpg", symlink_from_cwd=True)
1331
- logger.info("")
1332
-
1333
- return splits
1334
-
1335
- def demo_06_expanding_with_validation(X, y, timestamps):
1336
- """Demo 6: Expanding window with validation dataset.
1337
-
1338
- Best for: Using all historical data with model selection.
1339
- Combines expanding window and validation split.
1340
- """
1341
- logger.info("=" * 70)
1342
- logger.info("DEMO 6: Expanding Window + Validation Dataset")
1343
- logger.info("=" * 70)
1344
- logger.info("Best for: Using all historical data with model selection")
1345
-
1346
- splitter = TimeSeriesSlidingWindowSplit(
1347
- window_size=args.window_size,
1348
- test_size=args.test_size,
1349
- gap=args.gap,
1350
- val_ratio=0.2,
1351
- overlapping_tests=False,
1352
- expanding_window=True, # Expanding + validation!
1353
- )
1354
-
1355
- splits = list(splitter.split_with_val(X, y, timestamps))[:3]
1356
- logger.info(f"Generated {len(splits)} splits")
1357
-
1358
- for fold, (train_idx, val_idx, test_idx) in enumerate(splits):
1359
- logger.info(
1360
- f" Fold {fold}: Train={len(train_idx)} (growing!), Val={len(val_idx)}, Test={len(test_idx)}"
1361
- )
1362
-
1363
- fig = splitter.plot_splits(X, y, timestamps)
1364
- stx.io.save(
1365
- fig,
1366
- "./06_sliding_window_expanding_validation.jpg",
1367
- symlink_from_cwd=True,
1368
- )
1369
- logger.info("")
1370
-
1371
- return splits
1372
-
1373
- def demo_07_undersample_with_validation(X, y_imbalanced, timestamps):
1374
- """Demo 7: Undersampling with validation dataset.
1375
-
1376
- Best for: Handling imbalanced data with hyperparameter tuning.
1377
- Combines undersampling and validation split.
1378
- """
1379
-
1380
- logger.info("=" * 70)
1381
- logger.info("DEMO 7: Undersampling + Validation Dataset")
1382
- logger.info("=" * 70)
1383
- logger.info("Best for: Imbalanced data with hyperparameter tuning")
1384
-
1385
- splitter = TimeSeriesSlidingWindowSplit(
1386
- window_size=args.window_size,
1387
- test_size=args.test_size,
1388
- gap=args.gap,
1389
- val_ratio=0.2,
1390
- undersample=True, # Undersample + validation!
1391
- random_state=42,
1392
- )
1393
-
1394
- splits = list(splitter.split_with_val(X, y_imbalanced, timestamps))[:3]
1395
- logger.info(f"Generated {len(splits)} splits")
1396
-
1397
- for fold, (train_idx, val_idx, test_idx) in enumerate(splits):
1398
- train_labels = y_imbalanced[train_idx]
1399
- train_unique, train_counts = np.unique(train_labels, return_counts=True)
1400
- logger.info(
1401
- f" Fold {fold}: Train={len(train_idx)} (balanced!), Val={len(val_idx)}, Test={len(test_idx)}, "
1402
- f"Class dist={dict(zip(train_unique, train_counts))}"
1403
- )
1404
-
1405
- fig = splitter.plot_splits(X, y_imbalanced, timestamps)
1406
- stx.io.save(
1407
- fig,
1408
- "./07_sliding_window_undersample_validation.jpg",
1409
- symlink_from_cwd=True,
1410
- )
1411
- logger.info("")
1412
-
1413
- return splits
1414
-
1415
- def demo_08_all_options_combined(X, y_imbalanced, timestamps):
1416
- """Demo 8: All options combined.
1417
-
1418
- Best for: Maximum flexibility - expanding window, undersampling, and validation.
1419
- Shows all features working together.
1420
- """
1421
- logger.info("=" * 70)
1422
- logger.info("DEMO 8: Expanding + Undersampling + Validation (ALL OPTIONS)")
1423
- logger.info("=" * 70)
1424
- logger.info("Best for: Comprehensive time series CV with all features")
1425
-
1426
- splitter = TimeSeriesSlidingWindowSplit(
1427
- window_size=args.window_size,
1428
- test_size=args.test_size,
1429
- gap=args.gap,
1430
- val_ratio=0.2,
1431
- overlapping_tests=False,
1432
- expanding_window=True, # All three!
1433
- undersample=True,
1434
- random_state=42,
1435
- )
1436
-
1437
- splits = list(splitter.split_with_val(X, y_imbalanced, timestamps))[:3]
1438
- logger.info(f"Generated {len(splits)} splits")
1439
-
1440
- for fold, (train_idx, val_idx, test_idx) in enumerate(splits):
1441
- train_labels = y_imbalanced[train_idx]
1442
- train_unique, train_counts = np.unique(train_labels, return_counts=True)
1443
- logger.info(
1444
- f" Fold {fold}: Train={len(train_idx)} (growing & balanced!), Val={len(val_idx)}, Test={len(test_idx)}, "
1445
- f"Class dist={dict(zip(train_unique, train_counts))}"
1446
- )
1447
-
1448
- fig = splitter.plot_splits(X, y_imbalanced, timestamps)
1449
- stx.io.save(fig, "./08_sliding_window_all_options.jpg", symlink_from_cwd=True)
1450
- logger.info("")
1451
-
1452
- return splits
1453
-
1454
- def print_summary(
1455
- splits_fixed,
1456
- splits_expanding,
1457
- splits_overlap,
1458
- splits_undersample=None,
1459
- splits_validation=None,
1460
- splits_expanding_val=None,
1461
- splits_undersample_val=None,
1462
- splits_all_options=None,
1463
- ):
1464
- """Print comparison summary of all modes."""
1465
- logger.info("=" * 70)
1466
- logger.info("SUMMARY COMPARISON")
1467
- logger.info("=" * 70)
1468
- logger.info(
1469
- f"01. Fixed window (non-overlap): {len(splits_fixed)} folds, train size constant"
1470
- )
1471
- logger.info(
1472
- f"02. Expanding window (non-overlap): {len(splits_expanding)} folds, train size grows"
1473
- )
1474
- logger.info(
1475
- f"03. Fixed window (overlapping): {len(splits_overlap)} folds, more eval points"
106
+ super().__init__(
107
+ window_size=window_size,
108
+ step_size=step_size,
109
+ test_size=test_size,
110
+ gap=gap,
111
+ val_ratio=val_ratio,
112
+ random_state=random_state,
113
+ overlapping_tests=overlapping_tests,
114
+ expanding_window=expanding_window,
115
+ undersample=undersample,
116
+ n_splits=n_splits,
1476
117
  )
1477
- if splits_undersample is not None:
1478
- logger.info(
1479
- f"04. With undersampling: {len(splits_undersample)} folds, balanced classes"
1480
- )
1481
- if splits_validation is not None:
1482
- logger.info(
1483
- f"05. With validation set: {len(splits_validation)} folds, train-val-test"
1484
- )
1485
- if splits_expanding_val is not None:
1486
- logger.info(
1487
- f"06. Expanding + validation: {len(splits_expanding_val)} folds, growing train with val"
1488
- )
1489
- if splits_undersample_val is not None:
1490
- logger.info(
1491
- f"07. Undersample + validation: {len(splits_undersample_val)} folds, balanced with val"
1492
- )
1493
- if splits_all_options is not None:
1494
- logger.info(
1495
- f"08. All options combined: {len(splits_all_options)} folds, expanding + balanced + val"
1496
- )
1497
- logger.info("")
1498
- logger.info("Key Insights:")
1499
- logger.info(
1500
- " - Non-overlapping tests (default): Each sample tested exactly once"
1501
- )
1502
- logger.info(
1503
- " - Expanding window: Maximizes training data, like sklearn TimeSeriesSplit"
1504
- )
1505
- logger.info(
1506
- " - Overlapping tests: More evaluation points, like K-fold training reuse"
1507
- )
1508
- if splits_undersample is not None:
1509
- logger.info(
1510
- " - Undersampling: Balances imbalanced classes in training sets"
1511
- )
1512
- if splits_validation is not None:
1513
- logger.info(
1514
- " - Validation set: Enables hyperparameter tuning with temporal order"
1515
- )
1516
- if splits_all_options is not None:
1517
- logger.info(
1518
- " - Combined options: Maximum flexibility for complex time series CV"
1519
- )
1520
- logger.info("=" * 70)
1521
-
1522
- # Main execution
1523
- logger.info("=" * 70)
1524
- logger.info("Demonstrating TimeSeriesSlidingWindowSplit with New Options")
1525
- logger.info("=" * 70)
1526
-
1527
- # Generate test data
1528
- np.random.seed(42)
1529
- n_samples = args.n_samples
1530
- X = np.random.randn(n_samples, 5)
1531
- y = np.random.randint(0, 2, n_samples) # Balanced
1532
- timestamps = np.arange(n_samples) + np.random.normal(0, 0.1, n_samples)
1533
-
1534
- # Create imbalanced labels (80% class 0, 20% class 1)
1535
- y_imbalanced = np.zeros(n_samples, dtype=int)
1536
- n_minority = int(n_samples * 0.2)
1537
- minority_indices = np.random.choice(n_samples, size=n_minority, replace=False)
1538
- y_imbalanced[minority_indices] = 1
1539
-
1540
- logger.info(f"Generated test data: {n_samples} samples, {X.shape[1]} features")
1541
- logger.info("")
1542
-
1543
- # Run demos
1544
- splits_fixed = demo_01_fixed_window_non_overlapping_tests(X, y, timestamps)
1545
- splits_expanding = demo_02_expanding_window_non_overlapping_tests(X, y, timestamps)
1546
- splits_overlap = demo_03_fixed_window_overlapping_tests(X, y, timestamps)
1547
- splits_undersample = demo_04_undersample_imbalanced_data(
1548
- X, y_imbalanced, timestamps
1549
- )
1550
- splits_validation = demo_05_validation_dataset(X, y, timestamps)
1551
- splits_expanding_val = demo_06_expanding_with_validation(X, y, timestamps)
1552
- splits_undersample_val = demo_07_undersample_with_validation(
1553
- X, y_imbalanced, timestamps
1554
- )
1555
- splits_all_options = demo_08_all_options_combined(X, y_imbalanced, timestamps)
1556
-
1557
- # Print summary
1558
- print_summary(
1559
- splits_fixed,
1560
- splits_expanding,
1561
- splits_overlap,
1562
- splits_undersample,
1563
- splits_validation,
1564
- splits_expanding_val,
1565
- splits_undersample_val,
1566
- splits_all_options,
1567
- )
1568
-
1569
- return 0
1570
-
1571
-
1572
- def parse_args() -> argparse.Namespace:
1573
- """Parse command line arguments."""
1574
- parser = argparse.ArgumentParser(
1575
- description="Demonstrate TimeSeriesSlidingWindowSplit with overlapping_tests and expanding_window options"
1576
- )
1577
- parser.add_argument(
1578
- "--n-samples",
1579
- type=int,
1580
- default=200,
1581
- help="Number of samples to generate (default: %(default)s)",
1582
- )
1583
- parser.add_argument(
1584
- "--window-size",
1585
- type=int,
1586
- default=50,
1587
- help="Size of training window (default: %(default)s)",
1588
- )
1589
- parser.add_argument(
1590
- "--test-size",
1591
- type=int,
1592
- default=20,
1593
- help="Size of test window (default: %(default)s)",
1594
- )
1595
- parser.add_argument(
1596
- "--gap",
1597
- type=int,
1598
- default=5,
1599
- help="Gap between train and test (default: %(default)s)",
1600
- )
1601
- args = parser.parse_args()
1602
- return args
1603
-
1604
-
1605
- def run_main() -> None:
1606
- """Initialize scitex framework, run main function, and cleanup."""
1607
- global CONFIG, CC, sys, plt, rng
1608
-
1609
- import sys
1610
-
1611
- import matplotlib.pyplot as plt
1612
- import scitex as stx
1613
-
1614
- args = parse_args()
1615
-
1616
- CONFIG, sys.stdout, sys.stderr, plt, CC, rng = stx.session.start(
1617
- sys,
1618
- plt,
1619
- args=args,
1620
- file=__FILE__,
1621
- sdir_suffix=None,
1622
- verbose=False,
1623
- agg=True,
1624
- )
1625
-
1626
- exit_status = main(args)
1627
-
1628
- stx.session.close(
1629
- CONFIG,
1630
- verbose=False,
1631
- notify=False,
1632
- message="",
1633
- exit_status=exit_status,
1634
- )
1635
-
1636
118
 
1637
- if __name__ == "__main__":
1638
- run_main()
1639
119
 
1640
120
  # EOF