scitex 2.14.0__py3-none-any.whl → 2.15.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scitex/__init__.py +71 -17
- scitex/_env_loader.py +156 -0
- scitex/_mcp_resources/__init__.py +37 -0
- scitex/_mcp_resources/_cheatsheet.py +135 -0
- scitex/_mcp_resources/_figrecipe.py +138 -0
- scitex/_mcp_resources/_formats.py +102 -0
- scitex/_mcp_resources/_modules.py +337 -0
- scitex/_mcp_resources/_session.py +149 -0
- scitex/_mcp_tools/__init__.py +4 -0
- scitex/_mcp_tools/audio.py +66 -0
- scitex/_mcp_tools/diagram.py +11 -95
- scitex/_mcp_tools/introspect.py +210 -0
- scitex/_mcp_tools/plt.py +260 -305
- scitex/_mcp_tools/scholar.py +74 -0
- scitex/_mcp_tools/social.py +244 -0
- scitex/_mcp_tools/template.py +24 -0
- scitex/_mcp_tools/writer.py +21 -204
- scitex/ai/_gen_ai/_PARAMS.py +10 -7
- scitex/ai/classification/reporters/_SingleClassificationReporter.py +45 -1603
- scitex/ai/classification/reporters/_mixins/__init__.py +36 -0
- scitex/ai/classification/reporters/_mixins/_constants.py +67 -0
- scitex/ai/classification/reporters/_mixins/_cv_summary.py +387 -0
- scitex/ai/classification/reporters/_mixins/_feature_importance.py +119 -0
- scitex/ai/classification/reporters/_mixins/_metrics.py +275 -0
- scitex/ai/classification/reporters/_mixins/_plotting.py +179 -0
- scitex/ai/classification/reporters/_mixins/_reports.py +153 -0
- scitex/ai/classification/reporters/_mixins/_storage.py +160 -0
- scitex/ai/classification/timeseries/_TimeSeriesSlidingWindowSplit.py +30 -1550
- scitex/ai/classification/timeseries/_sliding_window_core.py +467 -0
- scitex/ai/classification/timeseries/_sliding_window_plotting.py +369 -0
- scitex/audio/README.md +40 -36
- scitex/audio/__init__.py +129 -61
- scitex/audio/_branding.py +185 -0
- scitex/audio/_mcp/__init__.py +32 -0
- scitex/audio/_mcp/handlers.py +59 -6
- scitex/audio/_mcp/speak_handlers.py +238 -0
- scitex/audio/_relay.py +225 -0
- scitex/audio/_tts.py +18 -10
- scitex/audio/engines/base.py +17 -10
- scitex/audio/engines/elevenlabs_engine.py +7 -2
- scitex/audio/mcp_server.py +228 -75
- scitex/canvas/README.md +1 -1
- scitex/canvas/editor/_dearpygui/__init__.py +25 -0
- scitex/canvas/editor/_dearpygui/_editor.py +147 -0
- scitex/canvas/editor/_dearpygui/_handlers.py +476 -0
- scitex/canvas/editor/_dearpygui/_panels/__init__.py +17 -0
- scitex/canvas/editor/_dearpygui/_panels/_control.py +119 -0
- scitex/canvas/editor/_dearpygui/_panels/_element_controls.py +190 -0
- scitex/canvas/editor/_dearpygui/_panels/_preview.py +43 -0
- scitex/canvas/editor/_dearpygui/_panels/_sections.py +390 -0
- scitex/canvas/editor/_dearpygui/_plotting.py +187 -0
- scitex/canvas/editor/_dearpygui/_rendering.py +504 -0
- scitex/canvas/editor/_dearpygui/_selection.py +295 -0
- scitex/canvas/editor/_dearpygui/_state.py +93 -0
- scitex/canvas/editor/_dearpygui/_utils.py +61 -0
- scitex/canvas/editor/flask_editor/_core/__init__.py +27 -0
- scitex/canvas/editor/flask_editor/_core/_bbox_extraction.py +200 -0
- scitex/canvas/editor/flask_editor/_core/_editor.py +173 -0
- scitex/canvas/editor/flask_editor/_core/_export_helpers.py +353 -0
- scitex/canvas/editor/flask_editor/_core/_routes_basic.py +190 -0
- scitex/canvas/editor/flask_editor/_core/_routes_export.py +332 -0
- scitex/canvas/editor/flask_editor/_core/_routes_panels.py +252 -0
- scitex/canvas/editor/flask_editor/_core/_routes_save.py +218 -0
- scitex/canvas/editor/flask_editor/_core.py +25 -1684
- scitex/canvas/editor/flask_editor/templates/__init__.py +32 -70
- scitex/cli/__init__.py +38 -43
- scitex/cli/audio.py +76 -27
- scitex/cli/capture.py +13 -20
- scitex/cli/introspect.py +481 -0
- scitex/cli/main.py +200 -109
- scitex/cli/mcp.py +60 -34
- scitex/cli/plt.py +357 -0
- scitex/cli/repro.py +15 -8
- scitex/cli/resource.py +15 -8
- scitex/cli/scholar/__init__.py +23 -8
- scitex/cli/scholar/_crossref_scitex.py +296 -0
- scitex/cli/scholar/_fetch.py +25 -3
- scitex/cli/social.py +314 -0
- scitex/cli/stats.py +15 -8
- scitex/cli/template.py +129 -12
- scitex/cli/tex.py +15 -8
- scitex/cli/writer.py +132 -8
- scitex/cloud/__init__.py +41 -2
- scitex/config/README.md +1 -1
- scitex/config/__init__.py +16 -2
- scitex/config/_env_registry.py +256 -0
- scitex/context/__init__.py +22 -0
- scitex/dev/__init__.py +20 -1
- scitex/diagram/__init__.py +42 -19
- scitex/diagram/mcp_server.py +13 -125
- scitex/gen/__init__.py +50 -14
- scitex/gen/_list_packages.py +4 -4
- scitex/introspect/__init__.py +82 -0
- scitex/introspect/_call_graph.py +303 -0
- scitex/introspect/_class_hierarchy.py +163 -0
- scitex/introspect/_core.py +41 -0
- scitex/introspect/_docstring.py +131 -0
- scitex/introspect/_examples.py +113 -0
- scitex/introspect/_imports.py +271 -0
- scitex/{gen/_inspect_module.py → introspect/_list_api.py} +43 -54
- scitex/introspect/_mcp/__init__.py +41 -0
- scitex/introspect/_mcp/handlers.py +233 -0
- scitex/introspect/_members.py +155 -0
- scitex/introspect/_resolve.py +89 -0
- scitex/introspect/_signature.py +131 -0
- scitex/introspect/_source.py +80 -0
- scitex/introspect/_type_hints.py +172 -0
- scitex/io/_save.py +1 -2
- scitex/io/bundle/README.md +1 -1
- scitex/logging/_formatters.py +19 -9
- scitex/mcp_server.py +98 -5
- scitex/os/__init__.py +4 -0
- scitex/{gen → os}/_check_host.py +4 -5
- scitex/plt/__init__.py +245 -550
- scitex/plt/_subplots/_AxisWrapperMixins/_SeabornMixin/_wrappers.py +5 -10
- scitex/plt/docs/EXTERNAL_PACKAGE_BRANDING.md +149 -0
- scitex/plt/gallery/README.md +1 -1
- scitex/plt/utils/_hitmap/__init__.py +82 -0
- scitex/plt/utils/_hitmap/_artist_extraction.py +343 -0
- scitex/plt/utils/_hitmap/_color_application.py +346 -0
- scitex/plt/utils/_hitmap/_color_conversion.py +121 -0
- scitex/plt/utils/_hitmap/_constants.py +40 -0
- scitex/plt/utils/_hitmap/_hitmap_core.py +334 -0
- scitex/plt/utils/_hitmap/_path_extraction.py +357 -0
- scitex/plt/utils/_hitmap/_query.py +113 -0
- scitex/plt/utils/_hitmap.py +46 -1616
- scitex/plt/utils/_metadata/__init__.py +80 -0
- scitex/plt/utils/_metadata/_artists/__init__.py +25 -0
- scitex/plt/utils/_metadata/_artists/_base.py +195 -0
- scitex/plt/utils/_metadata/_artists/_collections.py +356 -0
- scitex/plt/utils/_metadata/_artists/_extract.py +57 -0
- scitex/plt/utils/_metadata/_artists/_images.py +80 -0
- scitex/plt/utils/_metadata/_artists/_lines.py +261 -0
- scitex/plt/utils/_metadata/_artists/_patches.py +247 -0
- scitex/plt/utils/_metadata/_artists/_text.py +106 -0
- scitex/plt/utils/_metadata/_csv.py +416 -0
- scitex/plt/utils/_metadata/_detect.py +225 -0
- scitex/plt/utils/_metadata/_legend.py +127 -0
- scitex/plt/utils/_metadata/_rounding.py +117 -0
- scitex/plt/utils/_metadata/_verification.py +202 -0
- scitex/schema/README.md +1 -1
- scitex/scholar/__init__.py +8 -0
- scitex/scholar/_mcp/crossref_handlers.py +265 -0
- scitex/scholar/core/Scholar.py +63 -1700
- scitex/scholar/core/_mixins/__init__.py +36 -0
- scitex/scholar/core/_mixins/_enrichers.py +270 -0
- scitex/scholar/core/_mixins/_library_handlers.py +100 -0
- scitex/scholar/core/_mixins/_loaders.py +103 -0
- scitex/scholar/core/_mixins/_pdf_download.py +375 -0
- scitex/scholar/core/_mixins/_pipeline.py +312 -0
- scitex/scholar/core/_mixins/_project_handlers.py +125 -0
- scitex/scholar/core/_mixins/_savers.py +69 -0
- scitex/scholar/core/_mixins/_search.py +103 -0
- scitex/scholar/core/_mixins/_services.py +88 -0
- scitex/scholar/core/_mixins/_url_finding.py +105 -0
- scitex/scholar/crossref_scitex.py +367 -0
- scitex/scholar/docs/EXTERNAL_PACKAGE_BRANDING.md +149 -0
- scitex/scholar/examples/00_run_all.sh +120 -0
- scitex/scholar/jobs/_executors.py +27 -3
- scitex/scholar/pdf_download/ScholarPDFDownloader.py +38 -416
- scitex/scholar/pdf_download/_cli.py +154 -0
- scitex/scholar/pdf_download/strategies/__init__.py +11 -8
- scitex/scholar/pdf_download/strategies/manual_download_fallback.py +80 -3
- scitex/scholar/pipelines/ScholarPipelineBibTeX.py +73 -121
- scitex/scholar/pipelines/ScholarPipelineParallel.py +80 -138
- scitex/scholar/pipelines/ScholarPipelineSingle.py +43 -63
- scitex/scholar/pipelines/_single_steps.py +71 -36
- scitex/scholar/storage/_LibraryManager.py +97 -1695
- scitex/scholar/storage/_mixins/__init__.py +30 -0
- scitex/scholar/storage/_mixins/_bibtex_handlers.py +128 -0
- scitex/scholar/storage/_mixins/_library_operations.py +218 -0
- scitex/scholar/storage/_mixins/_metadata_conversion.py +226 -0
- scitex/scholar/storage/_mixins/_paper_saving.py +456 -0
- scitex/scholar/storage/_mixins/_resolution.py +376 -0
- scitex/scholar/storage/_mixins/_storage_helpers.py +121 -0
- scitex/scholar/storage/_mixins/_symlink_handlers.py +226 -0
- scitex/scholar/url_finder/.tmp/open_url/KNOWN_RESOLVERS.py +462 -0
- scitex/scholar/url_finder/.tmp/open_url/README.md +223 -0
- scitex/scholar/url_finder/.tmp/open_url/_DOIToURLResolver.py +694 -0
- scitex/scholar/url_finder/.tmp/open_url/_OpenURLResolver.py +1160 -0
- scitex/scholar/url_finder/.tmp/open_url/_ResolverLinkFinder.py +344 -0
- scitex/scholar/url_finder/.tmp/open_url/__init__.py +24 -0
- scitex/security/README.md +3 -3
- scitex/session/README.md +1 -1
- scitex/session/__init__.py +26 -7
- scitex/session/_decorator.py +1 -1
- scitex/sh/README.md +1 -1
- scitex/sh/__init__.py +7 -4
- scitex/social/__init__.py +155 -0
- scitex/social/docs/EXTERNAL_PACKAGE_BRANDING.md +149 -0
- scitex/stats/_mcp/_handlers/__init__.py +31 -0
- scitex/stats/_mcp/_handlers/_corrections.py +113 -0
- scitex/stats/_mcp/_handlers/_descriptive.py +78 -0
- scitex/stats/_mcp/_handlers/_effect_size.py +106 -0
- scitex/stats/_mcp/_handlers/_format.py +94 -0
- scitex/stats/_mcp/_handlers/_normality.py +110 -0
- scitex/stats/_mcp/_handlers/_posthoc.py +224 -0
- scitex/stats/_mcp/_handlers/_power.py +247 -0
- scitex/stats/_mcp/_handlers/_recommend.py +102 -0
- scitex/stats/_mcp/_handlers/_run_test.py +279 -0
- scitex/stats/_mcp/_handlers/_stars.py +48 -0
- scitex/stats/_mcp/handlers.py +19 -1171
- scitex/stats/auto/_stat_style.py +175 -0
- scitex/stats/auto/_style_definitions.py +411 -0
- scitex/stats/auto/_styles.py +22 -620
- scitex/stats/descriptive/__init__.py +11 -8
- scitex/stats/descriptive/_ci.py +39 -0
- scitex/stats/power/_power.py +15 -4
- scitex/str/__init__.py +2 -1
- scitex/str/_title_case.py +63 -0
- scitex/template/README.md +1 -1
- scitex/template/__init__.py +25 -10
- scitex/template/_code_templates.py +147 -0
- scitex/template/_mcp/handlers.py +81 -0
- scitex/template/_mcp/tool_schemas.py +55 -0
- scitex/template/_templates/__init__.py +51 -0
- scitex/template/_templates/audio.py +233 -0
- scitex/template/_templates/canvas.py +312 -0
- scitex/template/_templates/capture.py +268 -0
- scitex/template/_templates/config.py +43 -0
- scitex/template/_templates/diagram.py +294 -0
- scitex/template/_templates/io.py +107 -0
- scitex/template/_templates/module.py +53 -0
- scitex/template/_templates/plt.py +202 -0
- scitex/template/_templates/scholar.py +267 -0
- scitex/template/_templates/session.py +130 -0
- scitex/template/_templates/session_minimal.py +43 -0
- scitex/template/_templates/session_plot.py +67 -0
- scitex/template/_templates/session_stats.py +77 -0
- scitex/template/_templates/stats.py +323 -0
- scitex/template/_templates/writer.py +296 -0
- scitex/template/clone_writer_directory.py +5 -5
- scitex/ui/_backends/_email.py +10 -2
- scitex/ui/_backends/_webhook.py +5 -1
- scitex/web/_search_pubmed.py +10 -6
- scitex/writer/README.md +1 -1
- scitex/writer/_mcp/handlers.py +11 -744
- scitex/writer/_mcp/tool_schemas.py +5 -335
- scitex-2.15.2.dist-info/METADATA +648 -0
- {scitex-2.14.0.dist-info → scitex-2.15.2.dist-info}/RECORD +246 -150
- scitex/canvas/editor/flask_editor/templates/_scripts.py +0 -4933
- scitex/canvas/editor/flask_editor/templates/_styles.py +0 -1658
- scitex/dev/plt/data/mpl/PLOTTING_FUNCTIONS.yaml +0 -90
- scitex/dev/plt/data/mpl/PLOTTING_SIGNATURES.yaml +0 -1571
- scitex/dev/plt/data/mpl/PLOTTING_SIGNATURES_DETAILED.yaml +0 -6262
- scitex/dev/plt/data/mpl/SIGNATURES_FLATTENED.yaml +0 -1274
- scitex/dev/plt/data/mpl/dir_ax.txt +0 -459
- scitex/diagram/_compile.py +0 -312
- scitex/diagram/_diagram.py +0 -355
- scitex/diagram/_mcp/__init__.py +0 -4
- scitex/diagram/_mcp/handlers.py +0 -400
- scitex/diagram/_mcp/tool_schemas.py +0 -157
- scitex/diagram/_presets.py +0 -173
- scitex/diagram/_schema.py +0 -182
- scitex/diagram/_split.py +0 -278
- scitex/gen/_ci.py +0 -12
- scitex/gen/_title_case.py +0 -89
- scitex/plt/_mcp/__init__.py +0 -4
- scitex/plt/_mcp/_handlers_annotation.py +0 -102
- scitex/plt/_mcp/_handlers_figure.py +0 -195
- scitex/plt/_mcp/_handlers_plot.py +0 -252
- scitex/plt/_mcp/_handlers_style.py +0 -219
- scitex/plt/_mcp/handlers.py +0 -74
- scitex/plt/_mcp/tool_schemas.py +0 -497
- scitex/plt/mcp_server.py +0 -231
- scitex/scholar/data/.gitkeep +0 -0
- scitex/scholar/data/README.md +0 -44
- scitex/scholar/data/bib_files/bibliography.bib +0 -1952
- scitex/scholar/data/bib_files/neurovista.bib +0 -277
- scitex/scholar/data/bib_files/neurovista_enriched.bib +0 -441
- scitex/scholar/data/bib_files/neurovista_enriched_enriched.bib +0 -441
- scitex/scholar/data/bib_files/neurovista_processed.bib +0 -338
- scitex/scholar/data/bib_files/openaccess.bib +0 -89
- scitex/scholar/data/bib_files/pac-seizure_prediction_enriched.bib +0 -2178
- scitex/scholar/data/bib_files/pac.bib +0 -698
- scitex/scholar/data/bib_files/pac_enriched.bib +0 -1061
- scitex/scholar/data/bib_files/pac_processed.bib +0 -0
- scitex/scholar/data/bib_files/pac_titles.txt +0 -75
- scitex/scholar/data/bib_files/paywalled.bib +0 -98
- scitex/scholar/data/bib_files/related-papers-by-coauthors.bib +0 -58
- scitex/scholar/data/bib_files/related-papers-by-coauthors_enriched.bib +0 -87
- scitex/scholar/data/bib_files/seizure_prediction.bib +0 -694
- scitex/scholar/data/bib_files/seizure_prediction_processed.bib +0 -0
- scitex/scholar/data/bib_files/test_complete_enriched.bib +0 -437
- scitex/scholar/data/bib_files/test_final_enriched.bib +0 -437
- scitex/scholar/data/bib_files/test_seizure.bib +0 -46
- scitex/scholar/data/impact_factor/JCR_IF_2022.xlsx +0 -0
- scitex/scholar/data/impact_factor/JCR_IF_2024.db +0 -0
- scitex/scholar/data/impact_factor/JCR_IF_2024.xlsx +0 -0
- scitex/scholar/data/impact_factor/JCR_IF_2024_v01.db +0 -0
- scitex/scholar/data/impact_factor.db +0 -0
- scitex/scholar/examples/SUGGESTIONS.md +0 -865
- scitex/scholar/examples/dev.py +0 -38
- scitex-2.14.0.dist-info/METADATA +0 -1238
- /scitex/{gen → context}/_detect_environment.py +0 -0
- /scitex/{gen → context}/_get_notebook_path.py +0 -0
- /scitex/{gen/_shell.py → sh/_shell_legacy.py} +0 -0
- {scitex-2.14.0.dist-info → scitex-2.15.2.dist-info}/WHEEL +0 -0
- {scitex-2.14.0.dist-info → scitex-2.15.2.dist-info}/entry_points.txt +0 -0
- {scitex-2.14.0.dist-info → scitex-2.15.2.dist-info}/licenses/LICENSE +0 -0
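
The one file-level diff reproduced below is for scitex/ai/classification/timeseries/_TimeSeriesSlidingWindowSplit.py, whose monolithic implementation moves into _sliding_window_core and _sliding_window_plotting and is recombined through mixins. For orientation, here is a minimal usage sketch assembled only from the constructor keywords and the split()/plot_splits() calls visible in that diff; the public import path is an assumption based on the new file location, and the data below is synthetic and illustrative.

```python
# Hypothetical usage sketch; import path assumed from the new module location,
# keyword values are arbitrary illustrative choices.
import numpy as np

from scitex.ai.classification.timeseries import TimeSeriesSlidingWindowSplit

# Synthetic data: 200 samples, 3 features, binary labels, sequential timestamps
X = np.random.randn(200, 3)
y = np.random.randint(0, 2, size=200)
timestamps = np.arange(200)

swcv = TimeSeriesSlidingWindowSplit(
    window_size=50,            # fixed training-window length
    test_size=20,              # samples per test window
    gap=5,                     # temporal gap between train and test
    overlapping_tests=False,   # default: non-overlapping test windows
    expanding_window=False,    # default: window slides rather than grows
)

for train_idx, test_idx in swcv.split(X, y, timestamps):
    print(f"Train: {len(train_idx)}, Test: {len(test_idx)}")

# Visualization now comes from SlidingWindowPlottingMixin
fig = swcv.plot_splits(X, y, timestamps)
```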
--- a/scitex/ai/classification/timeseries/_TimeSeriesSlidingWindowSplit.py
+++ b/scitex/ai/classification/timeseries/_TimeSeriesSlidingWindowSplit.py
@@ -1,59 +1,30 @@
 #!/usr/bin/env python3
-#
-#
-# File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/src/scitex/ml/classification/timeseries/_TimeSeriesSlidingWindowSplit.py
-# ----------------------------------------
-from __future__ import annotations
-import os
-
-__FILE__ = "./src/scitex/ml/classification/timeseries/_TimeSeriesSlidingWindowSplit.py"
-__DIR__ = os.path.dirname(__FILE__)
-# ----------------------------------------
+# Timestamp: "2026-01-24 (ywatanabe)"
+# File: /home/ywatanabe/proj/scitex-python/src/scitex/ai/classification/timeseries/_TimeSeriesSlidingWindowSplit.py
 
-"""
-Functionalities:
-  - Implements sliding window cross-validation for time series
-  - Creates overlapping train/test windows that slide through time
-  - Supports temporal gaps between train and test sets
-  - Provides visualization with scatter plots showing actual data points
-  - Validates temporal order in all windows
-  - Ensures no data leakage between train and test sets
+"""Sliding window cross-validation for time series.
 
-
-
-
-  - sklearn
-  - matplotlib
-  - scitex
+This module provides the TimeSeriesSlidingWindowSplit class which combines:
+- Core splitting functionality from _sliding_window_core
+- Visualization support from _sliding_window_plotting
 
-
-  - input-files:
-    - None (generates synthetic data for demonstration)
-  - output-files:
-    - ./sliding_window_demo.png (visualization with scatter plots)
+For demo/example usage, see examples/ai/classification/sliding_window_demo.py
 """
 
-
-import argparse
-from typing import Iterator, Optional, Tuple
+from __future__ import annotations
 
-
-import matplotlib.pyplot as plt
-import numpy as np
-import scitex as stx
-from scitex import logging
-from sklearn.model_selection import BaseCrossValidator
-from sklearn.utils.validation import _num_samples
+from typing import Optional
 
-
+from ._sliding_window_core import TimeSeriesSlidingWindowSplitCore
+from ._sliding_window_plotting import SlidingWindowPlottingMixin
 
-
-COLORS["RGBA_NORM"]
+__all__ = ["TimeSeriesSlidingWindowSplit"]
 
 
-class TimeSeriesSlidingWindowSplit(
-
-
+class TimeSeriesSlidingWindowSplit(
+    SlidingWindowPlottingMixin, TimeSeriesSlidingWindowSplitCore
+):
+    """Sliding window cross-validation for time series.
 
     Creates train/test windows that slide through time with configurable behavior.
 
@@ -114,6 +85,9 @@ class TimeSeriesSlidingWindowSplit(BaseCrossValidator):
     ... )
     >>> for train_idx, test_idx in swcv.split(X, y, timestamps):
     ...     print(f"Train: {len(train_idx)}, Test: {len(test_idx)}")
+    >>>
+    >>> # Visualize splits
+    >>> fig = swcv.plot_splits(X, y, timestamps)
     """
 
     def __init__(
@@ -129,1512 +103,18 @@ class TimeSeriesSlidingWindowSplit(BaseCrossValidator):
|
|
|
129
103
|
undersample: bool = False,
|
|
130
104
|
n_splits: Optional[int] = None,
|
|
131
105
|
):
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
raise ValueError(
|
|
144
|
-
"Either n_splits OR (window_size AND test_size) must be specified"
|
|
145
|
-
)
|
|
146
|
-
self.n_splits_mode = False
|
|
147
|
-
self._n_splits = None
|
|
148
|
-
self.window_size = window_size
|
|
149
|
-
self.test_size = test_size
|
|
150
|
-
|
|
151
|
-
self.gap = gap
|
|
152
|
-
self.val_ratio = val_ratio
|
|
153
|
-
self.random_state = random_state
|
|
154
|
-
self.rng = np.random.default_rng(random_state)
|
|
155
|
-
self.overlapping_tests = overlapping_tests
|
|
156
|
-
self.expanding_window = expanding_window
|
|
157
|
-
self.undersample = undersample
|
|
158
|
-
|
|
159
|
-
# Handle step_size logic
|
|
160
|
-
if not overlapping_tests:
|
|
161
|
-
# overlapping_tests=False: ensure non-overlapping tests
|
|
162
|
-
if step_size is not None and step_size < test_size:
|
|
163
|
-
logger.warning(
|
|
164
|
-
f"overlapping_tests=False but step_size={step_size} < test_size={test_size}. "
|
|
165
|
-
f"This would cause test overlap. Setting step_size=test_size={test_size}."
|
|
166
|
-
)
|
|
167
|
-
self.step_size = test_size
|
|
168
|
-
elif step_size is None:
|
|
169
|
-
# Default: non-overlapping tests
|
|
170
|
-
self.step_size = test_size
|
|
171
|
-
logger.info(
|
|
172
|
-
f"step_size not specified with overlapping_tests=False. "
|
|
173
|
-
f"Using step_size=test_size={test_size} for non-overlapping tests."
|
|
174
|
-
)
|
|
175
|
-
else:
|
|
176
|
-
# step_size >= test_size: acceptable, no overlap
|
|
177
|
-
self.step_size = step_size
|
|
178
|
-
else:
|
|
179
|
-
# overlapping_tests=True: allow any step_size
|
|
180
|
-
if step_size is None:
|
|
181
|
-
# Default for overlapping: half the test size for 50% overlap
|
|
182
|
-
self.step_size = max(1, test_size // 2)
|
|
183
|
-
logger.info(
|
|
184
|
-
f"step_size not specified with overlapping_tests=True. "
|
|
185
|
-
f"Using step_size={self.step_size} (50% overlap)."
|
|
186
|
-
)
|
|
187
|
-
else:
|
|
188
|
-
self.step_size = step_size
|
|
189
|
-
|
|
190
|
-
def _undersample_indices(
|
|
191
|
-
self, train_indices: np.ndarray, y: np.ndarray, timestamps: np.ndarray
|
|
192
|
-
) -> np.ndarray:
|
|
193
|
-
"""
|
|
194
|
-
Undersample majority class to balance training set.
|
|
195
|
-
|
|
196
|
-
Maintains temporal order of samples.
|
|
197
|
-
|
|
198
|
-
Parameters
|
|
199
|
-
----------
|
|
200
|
-
train_indices : ndarray
|
|
201
|
-
Original training indices
|
|
202
|
-
y : ndarray
|
|
203
|
-
Full label array
|
|
204
|
-
timestamps : ndarray
|
|
205
|
-
Full timestamp array
|
|
206
|
-
|
|
207
|
-
Returns
|
|
208
|
-
-------
|
|
209
|
-
ndarray
|
|
210
|
-
Undersampled training indices (sorted by timestamp)
|
|
211
|
-
"""
|
|
212
|
-
# Get labels for training indices
|
|
213
|
-
train_labels = y[train_indices]
|
|
214
|
-
|
|
215
|
-
# Find unique classes and their counts
|
|
216
|
-
unique_classes, class_counts = np.unique(train_labels, return_counts=True)
|
|
217
|
-
|
|
218
|
-
if len(unique_classes) < 2:
|
|
219
|
-
# Only one class, no undersampling needed
|
|
220
|
-
return train_indices
|
|
221
|
-
|
|
222
|
-
# Find minority class count
|
|
223
|
-
min_count = class_counts.min()
|
|
224
|
-
|
|
225
|
-
# Undersample each class to match minority class count
|
|
226
|
-
undersampled_indices = []
|
|
227
|
-
for cls in unique_classes:
|
|
228
|
-
# Find indices of this class within train_indices
|
|
229
|
-
cls_mask = train_labels == cls
|
|
230
|
-
cls_train_indices = train_indices[cls_mask]
|
|
231
|
-
|
|
232
|
-
if len(cls_train_indices) > min_count:
|
|
233
|
-
# Randomly select min_count samples
|
|
234
|
-
selected = self.rng.choice(
|
|
235
|
-
cls_train_indices, size=min_count, replace=False
|
|
236
|
-
)
|
|
237
|
-
undersampled_indices.extend(selected)
|
|
238
|
-
else:
|
|
239
|
-
# Keep all samples from minority class
|
|
240
|
-
undersampled_indices.extend(cls_train_indices)
|
|
241
|
-
|
|
242
|
-
# Convert to array and sort by timestamp to maintain temporal order
|
|
243
|
-
undersampled_indices = np.array(undersampled_indices)
|
|
244
|
-
temporal_order = np.argsort(timestamps[undersampled_indices])
|
|
245
|
-
undersampled_indices = undersampled_indices[temporal_order]
|
|
246
|
-
|
|
247
|
-
return undersampled_indices
|
|
248
|
-
|
|
249
|
-
def split(
|
|
250
|
-
self,
|
|
251
|
-
X: np.ndarray,
|
|
252
|
-
y: Optional[np.ndarray] = None,
|
|
253
|
-
timestamps: Optional[np.ndarray] = None,
|
|
254
|
-
groups: Optional[np.ndarray] = None,
|
|
255
|
-
) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
|
|
256
|
-
"""
|
|
257
|
-
Generate sliding window splits.
|
|
258
|
-
|
|
259
|
-
Parameters
|
|
260
|
-
----------
|
|
261
|
-
X : array-like, shape (n_samples, n_features)
|
|
262
|
-
Training data
|
|
263
|
-
y : array-like, shape (n_samples,), optional
|
|
264
|
-
Target variable
|
|
265
|
-
timestamps : array-like, shape (n_samples,), optional
|
|
266
|
-
Timestamps for temporal ordering. If None, uses sequential order
|
|
267
|
-
groups : array-like, shape (n_samples,), optional
|
|
268
|
-
Group labels (not used in this splitter)
|
|
269
|
-
|
|
270
|
-
Yields
|
|
271
|
-
------
|
|
272
|
-
train : ndarray
|
|
273
|
-
Training set indices
|
|
274
|
-
test : ndarray
|
|
275
|
-
Test set indices
|
|
276
|
-
"""
|
|
277
|
-
if timestamps is None:
|
|
278
|
-
timestamps = np.arange(len(X))
|
|
279
|
-
|
|
280
|
-
n_samples = _num_samples(X)
|
|
281
|
-
indices = np.arange(n_samples)
|
|
282
|
-
|
|
283
|
-
# Sort by timestamp to get temporal order
|
|
284
|
-
time_order = np.argsort(timestamps)
|
|
285
|
-
sorted_indices = indices[time_order]
|
|
286
|
-
|
|
287
|
-
# Auto-calculate sizes if using n_splits mode
|
|
288
|
-
if self.n_splits_mode:
|
|
289
|
-
# Calculate test_size to create exactly n_splits folds
|
|
290
|
-
# Formula: n_samples = window_size + (n_splits * (test_size + gap))
|
|
291
|
-
# For expanding window, window_size is minimum training size
|
|
292
|
-
# We want non-overlapping tests by default
|
|
293
|
-
|
|
294
|
-
if self.expanding_window:
|
|
295
|
-
# Expanding window: start with minimum window, test slides forward
|
|
296
|
-
# Let's use 20% of data as initial window (similar to sklearn)
|
|
297
|
-
min_window_size = max(1, n_samples // (self._n_splits + 1))
|
|
298
|
-
available_for_test = (
|
|
299
|
-
n_samples - min_window_size - (self._n_splits * self.gap)
|
|
300
|
-
)
|
|
301
|
-
calculated_test_size = max(1, available_for_test // self._n_splits)
|
|
302
|
-
|
|
303
|
-
# Set calculated values
|
|
304
|
-
self.window_size = min_window_size
|
|
305
|
-
self.test_size = calculated_test_size
|
|
306
|
-
self.step_size = calculated_test_size # Non-overlapping by default
|
|
307
|
-
|
|
308
|
-
logger.info(
|
|
309
|
-
f"n_splits={self._n_splits} with expanding_window: "
|
|
310
|
-
f"Calculated window_size={self.window_size}, test_size={self.test_size}"
|
|
311
|
-
)
|
|
312
|
-
else:
|
|
313
|
-
# Fixed window: calculate window and test size
|
|
314
|
-
# We want: n_samples = window_size + (n_splits * (test_size + gap))
|
|
315
|
-
# Let's make window_size same as test_size for simplicity
|
|
316
|
-
available = n_samples - (self._n_splits * self.gap)
|
|
317
|
-
calculated_test_size = max(1, available // (self._n_splits + 1))
|
|
318
|
-
calculated_window_size = calculated_test_size
|
|
319
|
-
|
|
320
|
-
# Set calculated values
|
|
321
|
-
self.window_size = calculated_window_size
|
|
322
|
-
self.test_size = calculated_test_size
|
|
323
|
-
self.step_size = calculated_test_size # Non-overlapping by default
|
|
324
|
-
|
|
325
|
-
logger.info(
|
|
326
|
-
f"n_splits={self._n_splits} with fixed window: "
|
|
327
|
-
f"Calculated window_size={self.window_size}, test_size={self.test_size}"
|
|
328
|
-
)
|
|
329
|
-
|
|
330
|
-
if self.expanding_window:
|
|
331
|
-
# Expanding window: training set grows to include all past data
|
|
332
|
-
# Start with minimum window_size, test slides forward
|
|
333
|
-
min_train_size = self.window_size
|
|
334
|
-
total_min = min_train_size + self.gap + self.test_size
|
|
335
|
-
|
|
336
|
-
if n_samples < total_min:
|
|
337
|
-
logger.warning(
|
|
338
|
-
f"Not enough samples ({n_samples}) for even one split. "
|
|
339
|
-
f"Need at least {total_min} samples."
|
|
340
|
-
)
|
|
341
|
-
return
|
|
342
|
-
|
|
343
|
-
# First fold starts at window_size
|
|
344
|
-
test_start_pos = min_train_size + self.gap
|
|
345
|
-
|
|
346
|
-
while test_start_pos + self.test_size <= n_samples:
|
|
347
|
-
test_end_pos = test_start_pos + self.test_size
|
|
348
|
-
|
|
349
|
-
# Training includes all data from start to before gap
|
|
350
|
-
train_end_pos = test_start_pos - self.gap
|
|
351
|
-
train_indices = sorted_indices[0:train_end_pos]
|
|
352
|
-
test_indices = sorted_indices[test_start_pos:test_end_pos]
|
|
353
|
-
|
|
354
|
-
# Apply undersampling if enabled and y is provided
|
|
355
|
-
if self.undersample and y is not None:
|
|
356
|
-
train_indices = self._undersample_indices(
|
|
357
|
-
train_indices, y, timestamps
|
|
358
|
-
)
|
|
359
|
-
|
|
360
|
-
assert len(train_indices) > 0 and len(test_indices) > 0, "Empty window"
|
|
361
|
-
|
|
362
|
-
yield train_indices, test_indices
|
|
363
|
-
|
|
364
|
-
# Move test window forward by step_size
|
|
365
|
-
test_start_pos += self.step_size
|
|
366
|
-
|
|
367
|
-
else:
|
|
368
|
-
# Fixed sliding window: window slides through data
|
|
369
|
-
total_window = self.window_size + self.gap + self.test_size
|
|
370
|
-
|
|
371
|
-
for start in range(0, n_samples - total_window + 1, self.step_size):
|
|
372
|
-
# These positions are in the sorted (temporal) domain
|
|
373
|
-
train_end = start + self.window_size
|
|
374
|
-
test_start = train_end + self.gap
|
|
375
|
-
test_end = test_start + self.test_size
|
|
376
|
-
|
|
377
|
-
if test_end > n_samples:
|
|
378
|
-
break
|
|
379
|
-
|
|
380
|
-
# Extract indices from the temporally sorted sequence
|
|
381
|
-
train_indices = sorted_indices[start:train_end]
|
|
382
|
-
test_indices = sorted_indices[test_start:test_end]
|
|
383
|
-
|
|
384
|
-
# Apply undersampling if enabled and y is provided
|
|
385
|
-
if self.undersample and y is not None:
|
|
386
|
-
train_indices = self._undersample_indices(
|
|
387
|
-
train_indices, y, timestamps
|
|
388
|
-
)
|
|
389
|
-
|
|
390
|
-
assert len(train_indices) > 0 and len(test_indices) > 0, "Empty window"
|
|
391
|
-
|
|
392
|
-
yield train_indices, test_indices
|
|
393
|
-
|
|
394
|
-
def split_with_val(
|
|
395
|
-
self,
|
|
396
|
-
X: np.ndarray,
|
|
397
|
-
y: Optional[np.ndarray] = None,
|
|
398
|
-
timestamps: Optional[np.ndarray] = None,
|
|
399
|
-
groups: Optional[np.ndarray] = None,
|
|
400
|
-
) -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
|
|
401
|
-
"""
|
|
402
|
-
Generate sliding window splits with validation set.
|
|
403
|
-
|
|
404
|
-
The validation set comes after training but before test, maintaining
|
|
405
|
-
temporal order: train < val < test.
|
|
406
|
-
|
|
407
|
-
Parameters
|
|
408
|
-
----------
|
|
409
|
-
X : array-like, shape (n_samples, n_features)
|
|
410
|
-
Training data
|
|
411
|
-
y : array-like, shape (n_samples,), optional
|
|
412
|
-
Target variable
|
|
413
|
-
timestamps : array-like, shape (n_samples,), optional
|
|
414
|
-
Timestamps for temporal ordering. If None, uses sequential order
|
|
415
|
-
groups : array-like, shape (n_samples,), optional
|
|
416
|
-
Group labels (not used in this splitter)
|
|
417
|
-
|
|
418
|
-
Yields
|
|
419
|
-
------
|
|
420
|
-
train : ndarray
|
|
421
|
-
Training set indices
|
|
422
|
-
val : ndarray
|
|
423
|
-
Validation set indices
|
|
424
|
-
test : ndarray
|
|
425
|
-
Test set indices
|
|
426
|
-
"""
|
|
427
|
-
if timestamps is None:
|
|
428
|
-
timestamps = np.arange(len(X))
|
|
429
|
-
|
|
430
|
-
n_samples = _num_samples(X)
|
|
431
|
-
indices = np.arange(n_samples)
|
|
432
|
-
|
|
433
|
-
# Sort by timestamp to get temporal order
|
|
434
|
-
time_order = np.argsort(timestamps)
|
|
435
|
-
sorted_indices = indices[time_order]
|
|
436
|
-
|
|
437
|
-
# Auto-calculate sizes if using n_splits mode
|
|
438
|
-
if self.n_splits_mode:
|
|
439
|
-
if self.expanding_window:
|
|
440
|
-
min_window_size = max(1, n_samples // (self._n_splits + 1))
|
|
441
|
-
available_for_test = (
|
|
442
|
-
n_samples - min_window_size - (self._n_splits * self.gap)
|
|
443
|
-
)
|
|
444
|
-
calculated_test_size = max(1, available_for_test // self._n_splits)
|
|
445
|
-
self.window_size = min_window_size
|
|
446
|
-
self.test_size = calculated_test_size
|
|
447
|
-
self.step_size = calculated_test_size
|
|
448
|
-
else:
|
|
449
|
-
available = n_samples - (self._n_splits * self.gap)
|
|
450
|
-
calculated_test_size = max(1, available // (self._n_splits + 1))
|
|
451
|
-
calculated_window_size = calculated_test_size
|
|
452
|
-
self.window_size = calculated_window_size
|
|
453
|
-
self.test_size = calculated_test_size
|
|
454
|
-
self.step_size = calculated_test_size
|
|
455
|
-
|
|
456
|
-
# Calculate validation size from training window
|
|
457
|
-
val_size = int(self.window_size * self.val_ratio) if self.val_ratio > 0 else 0
|
|
458
|
-
actual_train_size = self.window_size - val_size
|
|
459
|
-
|
|
460
|
-
if self.expanding_window:
|
|
461
|
-
# Expanding window with validation
|
|
462
|
-
min_train_size = self.window_size
|
|
463
|
-
total_min = min_train_size + self.gap + self.test_size
|
|
464
|
-
|
|
465
|
-
if n_samples < total_min:
|
|
466
|
-
logger.warning(
|
|
467
|
-
f"Not enough samples ({n_samples}) for even one split. "
|
|
468
|
-
f"Need at least {total_min} samples."
|
|
469
|
-
)
|
|
470
|
-
return
|
|
471
|
-
|
|
472
|
-
# Calculate positions for validation and test
|
|
473
|
-
test_start_pos = min_train_size + self.gap
|
|
474
|
-
|
|
475
|
-
while test_start_pos + self.test_size <= n_samples:
|
|
476
|
-
test_end_pos = test_start_pos + self.test_size
|
|
477
|
-
|
|
478
|
-
# Training + validation comes before gap
|
|
479
|
-
train_val_end_pos = test_start_pos - self.gap
|
|
480
|
-
|
|
481
|
-
# Split train/val from the expanding window
|
|
482
|
-
if val_size > 0:
|
|
483
|
-
# Calculate validation size dynamically based on current expanding window
|
|
484
|
-
# This ensures val_ratio is respected across all folds as window expands
|
|
485
|
-
current_val_size = int(train_val_end_pos * self.val_ratio)
|
|
486
|
-
train_end_pos = train_val_end_pos - current_val_size
|
|
487
|
-
train_indices = sorted_indices[0:train_end_pos]
|
|
488
|
-
val_indices = sorted_indices[train_end_pos:train_val_end_pos]
|
|
489
|
-
else:
|
|
490
|
-
train_indices = sorted_indices[0:train_val_end_pos]
|
|
491
|
-
val_indices = np.array([])
|
|
492
|
-
|
|
493
|
-
test_indices = sorted_indices[test_start_pos:test_end_pos]
|
|
494
|
-
|
|
495
|
-
# Apply undersampling if enabled and y is provided
|
|
496
|
-
if self.undersample and y is not None:
|
|
497
|
-
train_indices = self._undersample_indices(
|
|
498
|
-
train_indices, y, timestamps
|
|
499
|
-
)
|
|
500
|
-
# Also undersample validation set if it exists
|
|
501
|
-
if len(val_indices) > 0:
|
|
502
|
-
val_indices = self._undersample_indices(
|
|
503
|
-
val_indices, y, timestamps
|
|
504
|
-
)
|
|
505
|
-
|
|
506
|
-
assert len(train_indices) > 0 and len(test_indices) > 0, "Empty window"
|
|
507
|
-
|
|
508
|
-
yield train_indices, val_indices, test_indices
|
|
509
|
-
|
|
510
|
-
# Move test window forward by step_size
|
|
511
|
-
test_start_pos += self.step_size
|
|
512
|
-
|
|
513
|
-
else:
|
|
514
|
-
# Fixed sliding window with validation
|
|
515
|
-
total_window = self.window_size + self.gap + self.test_size
|
|
516
|
-
|
|
517
|
-
for start in range(0, n_samples - total_window + 1, self.step_size):
|
|
518
|
-
# These positions are in the sorted (temporal) domain
|
|
519
|
-
train_end = start + actual_train_size
|
|
520
|
-
|
|
521
|
-
# Validation comes after train with optional gap
|
|
522
|
-
val_start = train_end + (self.gap if val_size > 0 else 0)
|
|
523
|
-
val_end = val_start + val_size
|
|
524
|
-
|
|
525
|
-
# Test comes after validation with gap
|
|
526
|
-
test_start = (
|
|
527
|
-
val_end + self.gap if val_size > 0 else train_end + self.gap
|
|
528
|
-
)
|
|
529
|
-
test_end = test_start + self.test_size
|
|
530
|
-
|
|
531
|
-
if test_end > n_samples:
|
|
532
|
-
break
|
|
533
|
-
|
|
534
|
-
# Extract indices from the temporally sorted sequence
|
|
535
|
-
train_indices = sorted_indices[start:train_end]
|
|
536
|
-
val_indices = (
|
|
537
|
-
sorted_indices[val_start:val_end] if val_size > 0 else np.array([])
|
|
538
|
-
)
|
|
539
|
-
test_indices = sorted_indices[test_start:test_end]
|
|
540
|
-
|
|
541
|
-
# Apply undersampling if enabled and y is provided
|
|
542
|
-
if self.undersample and y is not None:
|
|
543
|
-
train_indices = self._undersample_indices(
|
|
544
|
-
train_indices, y, timestamps
|
|
545
|
-
)
|
|
546
|
-
# Also undersample validation set if it exists
|
|
547
|
-
if len(val_indices) > 0:
|
|
548
|
-
val_indices = self._undersample_indices(
|
|
549
|
-
val_indices, y, timestamps
|
|
550
|
-
)
|
|
551
|
-
|
|
552
|
-
# Ensure temporal order is preserved
|
|
553
|
-
assert len(train_indices) > 0 and len(test_indices) > 0, "Empty window"
|
|
554
|
-
|
|
555
|
-
yield train_indices, val_indices, test_indices
|
|
556
|
-
|
|
557
|
-
def get_n_splits(self, X=None, y=None, groups=None):
|
|
558
|
-
"""
|
|
559
|
-
Calculate number of splits.
|
|
560
|
-
|
|
561
|
-
Parameters
|
|
562
|
-
----------
|
|
563
|
-
X : array-like, shape (n_samples, n_features), optional
|
|
564
|
-
Training data (required to determine number of splits in manual mode)
|
|
565
|
-
y : array-like, optional
|
|
566
|
-
Not used
|
|
567
|
-
groups : array-like, optional
|
|
568
|
-
Not used
|
|
569
|
-
|
|
570
|
-
Returns
|
|
571
|
-
-------
|
|
572
|
-
n_splits : int
|
|
573
|
-
Number of splits. Returns -1 if X is None and not in n_splits mode.
|
|
574
|
-
"""
|
|
575
|
-
# If using n_splits mode, return the specified n_splits
|
|
576
|
-
if self.n_splits_mode:
|
|
577
|
-
return self._n_splits
|
|
578
|
-
|
|
579
|
-
# Manual mode: need data to calculate
|
|
580
|
-
if X is None:
|
|
581
|
-
return -1 # Can't determine without data
|
|
582
|
-
|
|
583
|
-
n_samples = _num_samples(X)
|
|
584
|
-
total_window = self.window_size + self.gap + self.test_size
|
|
585
|
-
n_windows = (n_samples - total_window) // self.step_size + 1
|
|
586
|
-
return max(0, n_windows)
|
|
587
|
-
|
|
588
|
-
def plot_splits(self, X, y=None, timestamps=None, figsize=(12, 6), save_path=None):
|
|
589
|
-
"""
|
|
590
|
-
Visualize the sliding window splits as rectangles.
|
|
591
|
-
|
|
592
|
-
Shows train (blue), validation (green), and test (red) sets.
|
|
593
|
-
When val_ratio=0, only shows train and test.
|
|
594
|
-
When undersampling is enabled, shows dropped samples in gray.
|
|
595
|
-
|
|
596
|
-
Parameters
|
|
597
|
-
----------
|
|
598
|
-
X : array-like
|
|
599
|
-
Training data
|
|
600
|
-
y : array-like, optional
|
|
601
|
-
Target variable (required for undersampling visualization)
|
|
602
|
-
timestamps : array-like, optional
|
|
603
|
-
Timestamps (if None, uses sample indices)
|
|
604
|
-
figsize : tuple, default (12, 6)
|
|
605
|
-
Figure size
|
|
606
|
-
save_path : str, optional
|
|
607
|
-
Path to save the plot
|
|
608
|
-
|
|
609
|
-
Returns
|
|
610
|
-
-------
|
|
611
|
-
fig : matplotlib.figure.Figure
|
|
612
|
-
The created figure
|
|
613
|
-
"""
|
|
614
|
-
# Use sample indices if no timestamps provided
|
|
615
|
-
if timestamps is None:
|
|
616
|
-
timestamps = np.arange(len(X))
|
|
617
|
-
|
|
618
|
-
# Get temporal ordering
|
|
619
|
-
time_order = np.argsort(timestamps)
|
|
620
|
-
sorted_timestamps = timestamps[time_order]
|
|
621
|
-
|
|
622
|
-
# Get splits WITH undersampling (if enabled)
|
|
623
|
-
if self.val_ratio > 0:
|
|
624
|
-
splits = list(self.split_with_val(X, y, timestamps))[:10]
|
|
625
|
-
split_type = "train-val-test"
|
|
626
|
-
else:
|
|
627
|
-
splits = list(self.split(X, y, timestamps))[:10]
|
|
628
|
-
split_type = "train-test"
|
|
629
|
-
|
|
630
|
-
if not splits:
|
|
631
|
-
raise ValueError("No splits generated")
|
|
632
|
-
|
|
633
|
-
# If undersampling is enabled, also get splits WITHOUT undersampling to show dropped samples
|
|
634
|
-
splits_no_undersample = None
|
|
635
|
-
if self.undersample and y is not None:
|
|
636
|
-
original_undersample = self.undersample
|
|
637
|
-
self.undersample = False # Temporarily disable
|
|
638
|
-
if self.val_ratio > 0:
|
|
639
|
-
splits_no_undersample = list(self.split_with_val(X, y, timestamps))[:10]
|
|
640
|
-
else:
|
|
641
|
-
splits_no_undersample = list(self.split(X, y, timestamps))[:10]
|
|
642
|
-
self.undersample = original_undersample # Restore
|
|
643
|
-
|
|
644
|
-
# Create figure
|
|
645
|
-
fig, ax = stx.plt.subplots(figsize=figsize)
|
|
646
|
-
|
|
647
|
-
# Plot each fold based on temporal position
|
|
648
|
-
for fold, split_indices in enumerate(splits):
|
|
649
|
-
y_pos = fold
|
|
650
|
-
|
|
651
|
-
if len(split_indices) == 3: # train, val, test
|
|
652
|
-
train_idx, val_idx, test_idx = split_indices
|
|
653
|
-
|
|
654
|
-
# Find temporal positions of train indices
|
|
655
|
-
train_positions = []
|
|
656
|
-
for idx in train_idx:
|
|
657
|
-
temp_pos = np.where(time_order == idx)[0][
|
|
658
|
-
0
|
|
659
|
-
] # Find position in sorted order
|
|
660
|
-
train_positions.append(temp_pos)
|
|
661
|
-
|
|
662
|
-
# Plot train window based on temporal positions
|
|
663
|
-
if train_positions:
|
|
664
|
-
train_start = min(train_positions)
|
|
665
|
-
train_end = max(train_positions)
|
|
666
|
-
train_rect = patches.Rectangle(
|
|
667
|
-
(train_start, y_pos - 0.3),
|
|
668
|
-
train_end - train_start + 1,
|
|
669
|
-
0.6,
|
|
670
|
-
linewidth=1,
|
|
671
|
-
edgecolor="blue",
|
|
672
|
-
facecolor="lightblue",
|
|
673
|
-
alpha=0.7,
|
|
674
|
-
label="Train" if fold == 0 else "",
|
|
675
|
-
)
|
|
676
|
-
ax.add_patch(train_rect)
|
|
677
|
-
|
|
678
|
-
# Find temporal positions of validation indices
|
|
679
|
-
if len(val_idx) > 0:
|
|
680
|
-
val_positions = []
|
|
681
|
-
for idx in val_idx:
|
|
682
|
-
temp_pos = np.where(time_order == idx)[0][0]
|
|
683
|
-
val_positions.append(temp_pos)
|
|
684
|
-
|
|
685
|
-
# Plot validation window
|
|
686
|
-
if val_positions:
|
|
687
|
-
val_start = min(val_positions)
|
|
688
|
-
val_end = max(val_positions)
|
|
689
|
-
val_rect = patches.Rectangle(
|
|
690
|
-
(val_start, y_pos - 0.3),
|
|
691
|
-
val_end - val_start + 1,
|
|
692
|
-
0.6,
|
|
693
|
-
linewidth=1,
|
|
694
|
-
edgecolor="green",
|
|
695
|
-
facecolor="lightgreen",
|
|
696
|
-
alpha=0.7,
|
|
697
|
-
label="Validation" if fold == 0 else "",
|
|
698
|
-
)
|
|
699
|
-
ax.add_patch(val_rect)
|
|
700
|
-
|
|
701
|
-
# Find temporal positions of test indices
|
|
702
|
-
test_positions = []
|
|
703
|
-
for idx in test_idx:
|
|
704
|
-
temp_pos = np.where(time_order == idx)[0][
|
|
705
|
-
0
|
|
706
|
-
] # Find position in sorted order
|
|
707
|
-
test_positions.append(temp_pos)
|
|
708
|
-
|
|
709
|
-
# Plot test window based on temporal positions
|
|
710
|
-
if test_positions:
|
|
711
|
-
test_start = min(test_positions)
|
|
712
|
-
test_end = max(test_positions)
|
|
713
|
-
test_rect = patches.Rectangle(
|
|
714
|
-
(test_start, y_pos - 0.3),
|
|
715
|
-
test_end - test_start + 1,
|
|
716
|
-
0.6,
|
|
717
|
-
linewidth=1,
|
|
718
|
-
edgecolor=COLORS["RGBA_NORM"]["red"],
|
|
719
|
-
facecolor=COLORS["RGBA_NORM"]["red"],
|
|
720
|
-
alpha=0.7,
|
|
721
|
-
label="Test" if fold == 0 else "",
|
|
722
|
-
)
|
|
723
|
-
ax.add_patch(test_rect)
|
|
724
|
-
|
|
725
|
-
else: # train, test (2-way split)
|
|
726
|
-
train_idx, test_idx = split_indices
|
|
727
|
-
|
|
728
|
-
# Find temporal positions of train indices
|
|
729
|
-
train_positions = []
|
|
730
|
-
for idx in train_idx:
|
|
731
|
-
temp_pos = np.where(time_order == idx)[0][
|
|
732
|
-
0
|
|
733
|
-
] # Find position in sorted order
|
|
734
|
-
train_positions.append(temp_pos)
|
|
735
|
-
|
|
736
|
-
# Plot train window based on temporal positions
|
|
737
|
-
if train_positions:
|
|
738
|
-
train_start = min(train_positions)
|
|
739
|
-
train_end = max(train_positions)
|
|
740
|
-
train_rect = patches.Rectangle(
|
|
741
|
-
(train_start, y_pos - 0.3),
|
|
742
|
-
train_end - train_start + 1,
|
|
743
|
-
0.6,
|
|
744
|
-
linewidth=1,
|
|
745
|
-
edgecolor=COLORS["RGBA_NORM"]["lightblue"],
|
|
746
|
-
facecolor=COLORS["RGBA_NORM"]["lightblue"],
|
|
747
|
-
alpha=0.7,
|
|
748
|
-
label="Train" if fold == 0 else "",
|
|
749
|
-
)
|
|
750
|
-
ax.add_patch(train_rect)
|
|
751
|
-
|
|
752
|
-
# Find temporal positions of test indices
|
|
753
|
-
test_positions = []
|
|
754
|
-
for idx in test_idx:
|
|
755
|
-
temp_pos = np.where(time_order == idx)[0][
|
|
756
|
-
0
|
|
757
|
-
] # Find position in sorted order
|
|
758
|
-
test_positions.append(temp_pos)
|
|
759
|
-
|
|
760
|
-
# Plot test window based on temporal positions
|
|
761
|
-
if test_positions:
|
|
762
|
-
test_start = min(test_positions)
|
|
763
|
-
test_end = max(test_positions)
|
|
764
|
-
test_rect = patches.Rectangle(
|
|
765
|
-
(test_start, y_pos - 0.3),
|
|
766
|
-
test_end - test_start + 1,
|
|
767
|
-
0.6,
|
|
768
|
-
linewidth=1,
|
|
769
|
-
edgecolor="red",
|
|
770
|
-
facecolor="lightcoral",
|
|
771
|
-
alpha=0.7,
|
|
772
|
-
label="Test" if fold == 0 else "",
|
|
773
|
-
)
|
|
774
|
-
ax.add_patch(test_rect)
|
|
775
|
-
|
|
776
|
-
# Add scatter plots of actual data points with jittering
|
|
777
|
-
np.random.seed(42) # For reproducible jittering
|
|
778
|
-
jitter_strength = 0.15 # Amount of vertical jittering
|
|
779
|
-
|
|
780
|
-
# First, plot dropped samples in gray if undersampling is enabled
|
|
781
|
-
if splits_no_undersample is not None:
|
|
782
|
-
for fold, split_indices_no_us in enumerate(splits_no_undersample):
|
|
783
|
-
y_pos = fold
|
|
784
|
-
split_indices_us = splits[fold]
|
|
785
|
-
|
|
786
|
-
if len(split_indices_no_us) == 3: # train, val, test
|
|
787
|
-
train_idx_no_us, val_idx_no_us, test_idx_no_us = split_indices_no_us
|
|
788
|
-
train_idx_us, val_idx_us, test_idx_us = split_indices_us
|
|
789
|
-
|
|
790
|
-
# Find dropped train samples
|
|
791
|
-
dropped_train = np.setdiff1d(train_idx_no_us, train_idx_us)
|
|
792
|
-
if len(dropped_train) > 0:
|
|
793
|
-
dropped_train_positions = [
|
|
794
|
-
np.where(time_order == idx)[0][0] for idx in dropped_train
|
|
795
|
-
]
|
|
796
|
-
dropped_train_jitter = np.random.normal(
|
|
797
|
-
0, jitter_strength, len(dropped_train_positions)
|
|
798
|
-
)
|
|
799
|
-
ax.plot_scatter(
|
|
800
|
-
dropped_train_positions,
|
|
801
|
-
y_pos + dropped_train_jitter,
|
|
802
|
-
c="gray",
|
|
803
|
-
s=15,
|
|
804
|
-
alpha=0.3,
|
|
805
|
-
marker="x",
|
|
806
|
-
label="Dropped (train)" if fold == 0 else "",
|
|
807
|
-
zorder=2,
|
|
808
|
-
)
|
|
809
|
-
|
|
810
|
-
# Find dropped validation samples
|
|
811
|
-
dropped_val = np.setdiff1d(val_idx_no_us, val_idx_us)
|
|
812
|
-
if len(dropped_val) > 0:
|
|
813
|
-
dropped_val_positions = [
|
|
814
|
-
np.where(time_order == idx)[0][0] for idx in dropped_val
|
|
815
|
-
]
|
|
816
|
-
dropped_val_jitter = np.random.normal(
|
|
817
|
-
0, jitter_strength, len(dropped_val_positions)
|
|
818
|
-
)
|
|
819
|
-
ax.plot_scatter(
|
|
820
|
-
dropped_val_positions,
|
|
821
|
-
y_pos + dropped_val_jitter,
|
|
822
|
-
c="gray",
|
|
823
|
-
s=15,
|
|
824
|
-
alpha=0.3,
|
|
825
|
-
marker="x",
|
|
826
|
-
label="Dropped (val)" if fold == 0 else "",
|
|
827
|
-
zorder=2,
|
|
828
|
-
)
|
|
829
|
-
|
|
830
|
-
else: # train, test (2-way split)
|
|
831
|
-
train_idx_no_us, test_idx_no_us = split_indices_no_us
|
|
832
|
-
train_idx_us, test_idx_us = split_indices_us
|
|
833
|
-
|
|
834
|
-
# Find dropped train samples
|
|
835
|
-
dropped_train = np.setdiff1d(train_idx_no_us, train_idx_us)
|
|
836
|
-
if len(dropped_train) > 0:
|
|
837
|
-
dropped_train_positions = [
|
|
838
|
-
np.where(time_order == idx)[0][0] for idx in dropped_train
|
|
839
|
-
]
|
|
840
|
-
dropped_train_jitter = np.random.normal(
|
|
841
|
-
0, jitter_strength, len(dropped_train_positions)
|
|
842
|
-
)
|
|
843
|
-
ax.plot_scatter(
|
|
844
|
-
dropped_train_positions,
|
|
845
|
-
y_pos + dropped_train_jitter,
|
|
846
|
-
c="gray",
|
|
847
|
-
s=15,
|
|
848
|
-
alpha=0.3,
|
|
849
|
-
marker="x",
|
|
850
|
-
label="Dropped samples" if fold == 0 else "",
|
|
851
|
-
zorder=2,
|
|
852
|
-
)
|
|
853
|
-
|
|
854
|
-
# Then, plot kept samples in color
|
|
855
|
-
for fold, split_indices in enumerate(splits):
|
|
856
|
-
y_pos = fold
|
|
857
|
-
|
|
858
|
-
if len(split_indices) == 3: # train, val, test
|
|
859
|
-
train_idx, val_idx, test_idx = split_indices
|
|
860
|
-
|
|
861
|
-
# Find temporal positions for scatter plot
|
|
862
|
-
train_positions = []
|
|
863
|
-
for idx in train_idx:
|
|
864
|
-
temp_pos = np.where(time_order == idx)[0][0]
|
|
865
|
-
train_positions.append(temp_pos)
|
|
866
|
-
|
|
867
|
-
val_positions = []
|
|
868
|
-
if len(val_idx) > 0:
|
|
869
|
-
for idx in val_idx:
|
|
870
|
-
temp_pos = np.where(time_order == idx)[0][0]
|
|
871
|
-
val_positions.append(temp_pos)
|
|
872
|
-
|
|
873
|
-
test_positions = []
|
|
874
|
-
for idx in test_idx:
|
|
875
|
-
temp_pos = np.where(time_order == idx)[0][0]
|
|
876
|
-
test_positions.append(temp_pos)
|
|
877
|
-
|
|
878
|
-
# Add jittered scatter plots for 3-way split
|
|
879
|
-
if train_positions:
|
|
880
|
-
train_jitter = np.random.normal(
|
|
881
|
-
0, jitter_strength, len(train_positions)
|
|
882
|
-
)
|
|
883
|
-
# Color by class if y is provided
|
|
884
|
-
if y is not None:
|
|
885
|
-
train_colors = [
|
|
886
|
-
stx.plt.color.PARAMS["RGBA_NORM"]["blue"]
|
|
887
|
-
if y[idx] == 0
|
|
888
|
-
else stx.plt.color.PARAMS["RGBA_NORM"]["lightblue"]
|
|
889
|
-
for idx in train_idx
|
|
890
|
-
]
|
|
891
|
-
ax.plot_scatter(
|
|
892
|
-
train_positions,
|
|
893
|
-
y_pos + train_jitter,
|
|
894
|
-
c=train_colors,
|
|
895
|
-
s=20,
|
|
896
|
-
alpha=0.7,
|
|
897
|
-
marker="o",
|
|
898
|
-
label="Train (class 0)" if fold == 0 else "",
|
|
899
|
-
zorder=3,
|
|
900
|
-
)
|
|
901
|
-
else:
|
|
902
|
-
ax.plot_scatter(
|
|
903
|
-
train_positions,
|
|
904
|
-
y_pos + train_jitter,
|
|
905
|
-
c="darkblue",
|
|
906
|
-
s=20,
|
|
907
|
-
alpha=0.7,
|
|
908
|
-
marker="o",
|
|
909
|
-
label="Train points" if fold == 0 else "",
|
|
910
|
-
zorder=3,
|
|
911
|
-
)
|
|
912
|
-
|
|
913
|
-
if val_positions:
|
|
914
|
-
val_jitter = np.random.normal(
|
|
915
|
-
0, jitter_strength, len(val_positions)
|
|
916
|
-
)
|
|
917
|
-
# Color by class if y is provided
|
|
918
|
-
if y is not None:
|
|
919
|
-
val_colors = [
|
|
920
|
-
stx.plt.color.PARAMS["RGBA_NORM"]["yellow"]
|
|
921
|
-
if y[idx] == 0
|
|
922
|
-
else stx.plt.color.PARAMS["RGBA_NORM"]["orange"]
|
|
923
|
-
for idx in val_idx
|
|
924
|
-
]
|
|
925
|
-
ax.plot_scatter(
|
|
926
|
-
val_positions,
|
|
927
|
-
y_pos + val_jitter,
|
|
928
|
-
c=val_colors,
|
|
929
|
-
s=20,
|
|
930
|
-
alpha=0.7,
|
|
931
|
-
marker="^",
|
|
932
|
-
label="Val (class 0)" if fold == 0 else "",
|
|
933
|
-
zorder=3,
|
|
934
|
-
)
|
|
935
|
-
else:
|
|
936
|
-
ax.plot_scatter(
|
|
937
|
-
val_positions,
|
|
938
|
-
y_pos + val_jitter,
|
|
939
|
-
c="darkgreen",
|
|
940
|
-
s=20,
|
|
941
|
-
alpha=0.7,
|
|
942
|
-
marker="^",
|
|
943
|
-
label="Val points" if fold == 0 else "",
|
|
944
|
-
zorder=3,
|
|
945
|
-
)
|
|
946
|
-
|
|
947
|
-
if test_positions:
|
|
948
|
-
test_jitter = np.random.normal(
|
|
949
|
-
0, jitter_strength, len(test_positions)
|
|
950
|
-
)
|
|
951
|
-
# Color by class if y is provided
|
|
952
|
-
if y is not None:
|
|
953
|
-
test_colors = [
|
|
954
|
-
stx.plt.color.PARAMS["RGBA_NORM"]["red"]
|
|
955
|
-
if y[idx] == 0
|
|
956
|
-
else stx.plt.color.PARAMS["RGBA_NORM"]["brown"]
|
|
957
|
-
for idx in test_idx
|
|
958
|
-
]
|
|
959
|
-
ax.plot_scatter(
|
|
960
|
-
test_positions,
|
|
961
|
-
y_pos + test_jitter,
|
|
962
|
-
c=test_colors,
|
|
963
|
-
s=20,
|
|
964
|
-
alpha=0.7,
|
|
965
|
-
marker="s",
|
|
966
|
-
label="Test (class 0)" if fold == 0 else "",
|
|
967
|
-
zorder=3,
|
|
968
|
-
)
|
|
969
|
-
else:
|
|
970
|
-
ax.plot_scatter(
|
|
971
|
-
test_positions,
|
|
972
|
-
y_pos + test_jitter,
|
|
973
|
-
c="darkred",
|
|
974
|
-
s=20,
|
|
975
|
-
alpha=0.7,
|
|
976
|
-
marker="s",
|
|
977
|
-
label="Test points" if fold == 0 else "",
|
|
978
|
-
zorder=3,
|
|
979
|
-
)
|
|
980
|
-
|
|
981
|
-
else: # train, test (2-way split)
|
|
982
|
-
train_idx, test_idx = split_indices
|
|
983
|
-
|
|
984
|
-
# Get actual timestamps for train and test indices
|
|
985
|
-
train_times = (
|
|
986
|
-
timestamps[train_idx] if timestamps is not None else train_idx
|
|
987
|
-
)
|
|
988
|
-
test_times = (
|
|
989
|
-
timestamps[test_idx] if timestamps is not None else test_idx
|
|
990
|
-
)
|
|
991
|
-
|
|
992
|
-
# Find temporal positions for scatter plot
|
|
993
|
-
train_positions = []
|
|
994
|
-
for idx in train_idx:
|
|
995
|
-
temp_pos = np.where(time_order == idx)[0][0]
|
|
996
|
-
train_positions.append(temp_pos)
|
|
997
|
-
|
|
998
|
-
test_positions = []
|
|
999
|
-
for idx in test_idx:
|
|
1000
|
-
temp_pos = np.where(time_order == idx)[0][0]
|
|
1001
|
-
test_positions.append(temp_pos)
|
|
1002
|
-
|
|
1003
|
-
# Add jittered scatter plots for 2-way split
|
|
1004
|
-
if train_positions:
|
|
1005
|
-
train_jitter = np.random.normal(
|
|
1006
|
-
0, jitter_strength, len(train_positions)
|
|
1007
|
-
)
|
|
1008
|
-
# Color by class if y is provided
|
|
1009
|
-
if y is not None:
|
|
1010
|
-
train_colors = [
|
|
1011
|
-
stx.plt.color.PARAMS["RGBA_NORM"]["blue"]
|
|
1012
|
-
if y[idx] == 0
|
|
1013
|
-
else stx.plt.color.PARAMS["RGBA_NORM"]["lightblue"]
|
|
1014
|
-
for idx in train_idx
|
|
1015
|
-
]
|
|
1016
|
-
ax.plot_scatter(
|
|
1017
|
-
train_positions,
|
|
1018
|
-
y_pos + train_jitter,
|
|
1019
|
-
c=train_colors,
|
|
1020
|
-
s=20,
|
|
1021
|
-
alpha=0.7,
|
|
1022
|
-
marker="o",
|
|
1023
|
-
label="Train (class 0)" if fold == 0 else "",
|
|
1024
|
-
zorder=3,
|
|
1025
|
-
)
|
|
1026
|
-
else:
|
|
1027
|
-
ax.plot_scatter(
|
|
1028
|
-
train_positions,
|
|
1029
|
-
y_pos + train_jitter,
|
|
1030
|
-
c="darkblue",
|
|
1031
|
-
s=20,
|
|
1032
|
-
alpha=0.7,
|
|
1033
|
-
marker="o",
|
|
1034
|
-
label="Train points" if fold == 0 else "",
|
|
1035
|
-
zorder=3,
|
|
1036
|
-
)
|
|
1037
|
-
|
|
1038
|
-
if test_positions:
|
|
1039
|
-
test_jitter = np.random.normal(
|
|
1040
|
-
0, jitter_strength, len(test_positions)
|
|
1041
|
-
)
|
|
1042
|
-
# Color by class if y is provided
|
|
1043
|
-
if y is not None:
|
|
1044
|
-
test_colors = [
|
|
1045
|
-
stx.plt.color.PARAMS["RGBA_NORM"]["red"]
|
|
1046
|
-
if y[idx] == 0
|
|
1047
|
-
else stx.plt.color.PARAMS["RGBA_NORM"]["brown"]
|
|
1048
|
-
for idx in test_idx
|
|
1049
|
-
]
|
|
1050
|
-
ax.plot_scatter(
|
|
1051
|
-
test_positions,
|
|
1052
|
-
y_pos + test_jitter,
|
|
1053
|
-
c=test_colors,
|
|
1054
|
-
s=20,
|
|
1055
|
-
alpha=0.7,
|
|
1056
|
-
marker="s",
|
|
1057
|
-
label="Test (class 0)" if fold == 0 else "",
|
|
1058
|
-
zorder=3,
|
|
1059
|
-
)
|
|
1060
|
-
else:
|
|
1061
|
-
ax.plot_scatter(
|
|
1062
|
-
test_positions,
|
|
1063
|
-
y_pos + test_jitter,
|
|
1064
|
-
c="darkred",
|
|
1065
|
-
s=20,
|
|
1066
|
-
alpha=0.7,
|
|
1067
|
-
marker="s",
|
|
1068
|
-
label="Test points" if fold == 0 else "",
|
|
1069
|
-
zorder=3,
|
|
1070
|
-
)
|
|
1071
|
-
|
|
1072
|
-
# Format plot
|
|
1073
|
-
ax.set_ylim(-0.5, len(splits) - 0.5)
|
|
1074
|
-
ax.set_xlim(0, len(X))
|
|
1075
|
-
ax.set_xlabel("Temporal Position (sorted by timestamp)")
|
|
1076
|
-
ax.set_ylabel("Fold")
|
|
1077
|
-
gap_text = f", Gap: {self.gap}" if self.gap > 0 else ""
|
|
1078
|
-
val_text = f", Val ratio: {self.val_ratio:.1%}" if self.val_ratio > 0 else ""
|
|
1079
|
-
ax.set_title(
|
|
1080
|
-
f"Sliding Window Split Visualization ({split_type})\\n"
|
|
1081
|
-
f"Window: {self.window_size}, Step: {self.step_size}, Test: {self.test_size}{gap_text}{val_text}\\n"
|
|
1082
|
-
f"Rectangles show windows, dots show actual data points"
|
|
1083
|
-
)
|
|
1084
|
-
|
|
1085
|
-
# Set y-ticks
|
|
1086
|
-
ax.set_yticks(range(len(splits)))
|
|
1087
|
-
ax.set_yticklabels([f"Fold {i}" for i in range(len(splits))])
|
|
1088
|
-
|
|
1089
|
-
# Add enhanced legend with class and sample information
|
|
1090
|
-
if y is not None:
|
|
1091
|
-
# Count samples per class in total dataset
|
|
1092
|
-
unique_classes, class_counts = np.unique(y, return_counts=True)
|
|
1093
|
-
total_class_info = ", ".join(
|
|
1094
|
-
[
|
|
1095
|
-
f"Class {cls}: n={count}"
|
|
1096
|
-
for cls, count in zip(unique_classes, class_counts)
|
|
1097
|
-
]
|
|
1098
|
-
)
|
|
1099
|
-
|
|
1100
|
-
# Count samples in first fold to show per-fold distribution
|
|
1101
|
-
first_split = splits[0]
|
|
1102
|
-
if len(first_split) == 3: # train, val, test
|
|
1103
|
-
train_idx, val_idx, test_idx = first_split
|
|
1104
|
-
fold_info = f"Fold 0: Train n={len(train_idx)}, Val n={len(val_idx)}, Test n={len(test_idx)}"
|
|
1105
|
-
else: # train, test
|
|
1106
|
-
train_idx, test_idx = first_split
|
|
1107
|
-
fold_info = f"Fold 0: Train n={len(train_idx)}, Test n={len(test_idx)}"
|
|
1108
|
-
|
|
1109
|
-
# Add legend with class information
|
|
1110
|
-
handles, labels = ax.get_legend_handles_labels()
|
|
1111
|
-
# Add title to legend showing class distribution
|
|
1112
|
-
legend_title = f"Total: {total_class_info}\\n{fold_info}"
|
|
1113
|
-
ax.legend(handles, labels, loc="upper right", title=legend_title)
|
|
1114
|
-
else:
|
|
1115
|
-
ax.legend(loc="upper right")
|
|
1116
|
-
|
|
1117
|
-
plt.tight_layout()
|
|
1118
|
-
|
|
1119
|
-
if save_path:
|
|
1120
|
-
fig.savefig(save_path, dpi=150, bbox_inches="tight")
|
|
1121
|
-
|
|
1122
|
-
return fig
|
|
1123
|
-
|
|
1124
|
-
|
|
```diff
@@ removed (old lines 1125-1170): module demo entry point and demo 1 @@
-"""Functions & Classes"""
-
-
-def main(args) -> int:
-    """Demonstrate TimeSeriesSlidingWindowSplit functionality.
-
-    Args:
-        args: Command line arguments
-
-    Returns:
-        int: Exit status
-    """
-
-    def demo_01_fixed_window_non_overlapping_tests(X, y, timestamps):
-        """Demo 1: Fixed window size with non-overlapping test sets (DEFAULT).
-
-        Best for: Testing model on consistent recent history.
-        Each sample tested exactly once (like K-fold for time series).
-        """
-        logger.info("=" * 70)
-        logger.info("DEMO 1: Fixed Window + Non-overlapping Tests (DEFAULT)")
-        logger.info("=" * 70)
-        logger.info("Best for: Testing model on consistent recent history")
-
-        splitter = TimeSeriesSlidingWindowSplit(
-            window_size=args.window_size,
-            test_size=args.test_size,
-            gap=args.gap,
-            overlapping_tests=False,  # Default
-            expanding_window=False,  # Default
-        )
-
-        splits = list(splitter.split(X, y, timestamps))[:5]
-        logger.info(f"Generated {len(splits)} splits")
-
-        for fold, (train_idx, test_idx) in enumerate(splits):
-            logger.info(
-                f" Fold {fold}: Train={len(train_idx)} (fixed), Test={len(test_idx)}"
-            )
-
-        fig = splitter.plot_splits(X, y, timestamps)
-        stx.io.save(fig, "./01_sliding_window_fixed.jpg", symlink_from_cwd=True)
-        logger.info("")
-
-        return splits
-
```
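Demo 1's default scheme (fixed training window, optional gap, test windows that tile the series) reduces to simple index arithmetic. The sketch below is an illustrative re-derivation with a hypothetical `fixed_window_splits` helper, not the package's implementation:

```python
import numpy as np

def fixed_window_splits(n_samples, window_size, test_size, gap):
    """Yield (train_idx, test_idx) pairs; test windows never overlap."""
    start = 0
    while start + window_size + gap + test_size <= n_samples:
        train_idx = np.arange(start, start + window_size)
        test_start = start + window_size + gap
        test_idx = np.arange(test_start, test_start + test_size)
        yield train_idx, test_idx
        start += test_size  # advance by test_size: each sample tested once

for fold, (tr, te) in enumerate(fixed_window_splits(200, 50, 20, 5)):
    print(f"Fold {fold}: Train={len(tr)} (fixed), Test={te[0]}..{te[-1]}")
```

Stepping by `test_size` is what makes the test sets tile the series without overlap, which is the K-fold-like property the docstring claims.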
```diff
@@ removed (old lines 1171-1205): demo 2 @@
-    def demo_02_expanding_window_non_overlapping_tests(X, y, timestamps):
-        """Demo 2: Expanding window with non-overlapping test sets.
-
-        Best for: Using all available past data (like sklearn TimeSeriesSplit).
-        Training set grows to include all historical data.
-        """
-        logger.info("=" * 70)
-        logger.info("DEMO 2: Expanding Window + Non-overlapping Tests")
-        logger.info("=" * 70)
-        logger.info(
-            "Best for: Using all available past data (like sklearn TimeSeriesSplit)"
-        )
-
-        splitter = TimeSeriesSlidingWindowSplit(
-            window_size=args.window_size,
-            test_size=args.test_size,
-            gap=args.gap,
-            overlapping_tests=False,
-            expanding_window=True,  # Use all past data!
-        )
-
-        splits = list(splitter.split(X, y, timestamps))[:5]
-        logger.info(f"Generated {len(splits)} splits")
-
-        for fold, (train_idx, test_idx) in enumerate(splits):
-            logger.info(
-                f" Fold {fold}: Train={len(train_idx)} (growing!), Test={len(test_idx)}"
-            )
-
-        fig = splitter.plot_splits(X, y, timestamps)
-        stx.io.save(fig, "./02_sliding_window_expanding.jpg", symlink_from_cwd=True)
-        logger.info("")
-
-        return splits
-
```
|
-
def demo_03_fixed_window_overlapping_tests(X, y, timestamps):
|
|
1207
|
-
"""Demo 3: Fixed window with overlapping test sets.
|
|
1208
|
-
|
|
1209
|
-
Best for: Maximum evaluation points (like K-fold training reuse).
|
|
1210
|
-
Test sets can overlap for more frequent model evaluation.
|
|
1211
|
-
"""
|
|
1212
|
-
logger.info("=" * 70)
|
|
1213
|
-
logger.info("DEMO 3: Fixed Window + Overlapping Tests")
|
|
1214
|
-
logger.info("=" * 70)
|
|
1215
|
-
logger.info("Best for: Maximum evaluation points (like K-fold for training)")
|
|
1216
|
-
|
|
1217
|
-
splitter = TimeSeriesSlidingWindowSplit(
|
|
1218
|
-
window_size=args.window_size,
|
|
1219
|
-
test_size=args.test_size,
|
|
1220
|
-
gap=args.gap,
|
|
1221
|
-
overlapping_tests=True, # Allow test overlap
|
|
1222
|
-
expanding_window=False,
|
|
1223
|
-
# step_size will default to test_size // 2 for 50% overlap
|
|
1224
|
-
)
|
|
1225
|
-
|
|
1226
|
-
splits = list(splitter.split(X, y, timestamps))[:5]
|
|
1227
|
-
logger.info(f"Generated {len(splits)} splits")
|
|
1228
|
-
|
|
1229
|
-
for fold, (train_idx, test_idx) in enumerate(splits):
|
|
1230
|
-
logger.info(f" Fold {fold}: Train={len(train_idx)}, Test={len(test_idx)}")
|
|
1231
|
-
|
|
1232
|
-
fig = splitter.plot_splits(X, y, timestamps)
|
|
1233
|
-
stx.io.save(fig, "./03_sliding_window_overlapping.jpg", symlink_from_cwd=True)
|
|
1234
|
-
logger.info("")
|
|
1235
|
-
|
|
1236
|
-
return splits
|
|
1237
|
-
|
|
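Demo 3's inline comment says `step_size` defaults to `test_size // 2`, so consecutive test windows share half their samples. A numpy sketch of that stride, with the same caveats as the demo 1 sketch (hypothetical helper, not the library code):

```python
import numpy as np

def overlapping_test_splits(n_samples, window_size, test_size, gap):
    step = test_size // 2  # 50% overlap between consecutive test windows
    start = 0
    while start + window_size + gap + test_size <= n_samples:
        train_idx = np.arange(start, start + window_size)
        test_start = start + window_size + gap
        yield train_idx, np.arange(test_start, test_start + test_size)
        start += step

folds = list(overlapping_test_splits(200, 50, 20, 5))
# With test_size=20 and step=10, adjacent folds share 10 test samples.
print(f"{len(folds)} folds")
```

Halving the stride roughly doubles the number of evaluation points for the same data, at the cost of correlated test scores between adjacent folds.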
```diff
@@ removed (old lines 1238-1300): demo 4 @@
-    def demo_04_undersample_imbalanced_data(X, y_imbalanced, timestamps):
-        """Demo 4: Undersampling for imbalanced time series data.
-
-        Best for: Handling class imbalance in training sets.
-        Balances classes by randomly undersampling majority class.
-        """
-        logger.info("=" * 70)
-        logger.info("DEMO 4: Undersampling for Imbalanced Data")
-        logger.info("=" * 70)
-        logger.info("Best for: Handling class imbalance in time series")
-
-        # Show data imbalance
-        unique, counts = np.unique(y_imbalanced, return_counts=True)
-        logger.info(f"Class distribution: {dict(zip(unique, counts))}")
-        logger.info("")
-
-        # Without undersampling
-        splitter_no_undersample = TimeSeriesSlidingWindowSplit(
-            window_size=args.window_size,
-            test_size=args.test_size,
-            gap=args.gap,
-            undersample=False,
-        )
-
-        splits_no_us = list(splitter_no_undersample.split(X, y_imbalanced, timestamps))[
-            :3
-        ]
-        logger.info(f"WITHOUT undersampling: {len(splits_no_us)} splits")
-        for fold, (train_idx, test_idx) in enumerate(splits_no_us):
-            train_labels = y_imbalanced[train_idx]
-            train_unique, train_counts = np.unique(train_labels, return_counts=True)
-            logger.info(
-                f" Fold {fold}: Train size={len(train_idx)}, "
-                f"Class dist={dict(zip(train_unique, train_counts))}"
-            )
-        logger.info("")
-
-        # With undersampling
-        splitter_undersample = TimeSeriesSlidingWindowSplit(
-            window_size=args.window_size,
-            test_size=args.test_size,
-            gap=args.gap,
-            undersample=True,  # Enable undersampling!
-            random_state=42,
-        )
-
-        splits_us = list(splitter_undersample.split(X, y_imbalanced, timestamps))[:3]
-        logger.info(f"WITH undersampling: {len(splits_us)} splits")
-        for fold, (train_idx, test_idx) in enumerate(splits_us):
-            train_labels = y_imbalanced[train_idx]
-            train_unique, train_counts = np.unique(train_labels, return_counts=True)
-            logger.info(
-                f" Fold {fold}: Train size={len(train_idx)} (balanced!), "
-                f"Class dist={dict(zip(train_unique, train_counts))}"
-            )
-
-        # Save visualization for undersampling
-        fig = splitter_undersample.plot_splits(X, y_imbalanced, timestamps)
-        stx.io.save(fig, "./04_sliding_window_undersample.jpg", symlink_from_cwd=True)
-        logger.info("")
-
-        return splits_us
-
```
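Demo 4's `undersample=True` balances each fold's training set by randomly dropping majority-class samples. A simplified stand-in follows; `undersample_train` is a hypothetical helper, and the real splitter applies this per fold under its own `random_state`:

```python
import numpy as np

def undersample_train(train_idx, y, rng):
    """Keep min-class-count samples per class, preserving temporal order."""
    classes, counts = np.unique(y[train_idx], return_counts=True)
    n_keep = counts.min()
    kept = [rng.choice(train_idx[y[train_idx] == c], size=n_keep, replace=False)
            for c in classes]
    return np.sort(np.concatenate(kept))  # sort to keep indices chronological

rng = np.random.default_rng(42)
y = (rng.random(200) < 0.2).astype(int)  # roughly 80/20 imbalance
balanced_idx = undersample_train(np.arange(100), y, rng)
print(dict(zip(*np.unique(y[balanced_idx], return_counts=True))))
```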
```diff
@@ removed (old lines 1301-1334): demo 5 @@
-    def demo_05_validation_dataset(X, y, timestamps):
-        """Demo 5: Using validation dataset with train-val-test splits.
-
-        Best for: Model selection and hyperparameter tuning.
-        Creates train/validation/test splits maintaining temporal order.
-        """
-        logger.info("=" * 70)
-        logger.info("DEMO 5: Validation Dataset (Train-Val-Test Splits)")
-        logger.info("=" * 70)
-        logger.info("Best for: Model selection and hyperparameter tuning")
-
-        splitter = TimeSeriesSlidingWindowSplit(
-            window_size=args.window_size,
-            test_size=args.test_size,
-            gap=args.gap,
-            val_ratio=0.2,  # 20% of training window for validation
-            overlapping_tests=False,
-            expanding_window=False,
-        )
-
-        splits = list(splitter.split_with_val(X, y, timestamps))[:3]
-        logger.info(f"Generated {len(splits)} splits")
-
-        for fold, (train_idx, val_idx, test_idx) in enumerate(splits):
-            logger.info(
-                f" Fold {fold}: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}"
-            )
-
-        fig = splitter.plot_splits(X, y, timestamps)
-        stx.io.save(fig, "./05_sliding_window_validation.jpg", symlink_from_cwd=True)
-        logger.info("")
-
-        return splits
-
```
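Demo 5's `val_ratio=0.2` carves a fifth of each training window off as a validation block while preserving temporal order (train before val before test). A sketch of that carve-out with a hypothetical helper; demos 6 through 8 below only combine this with the expanding-window and undersampling behaviours already sketched above:

```python
import numpy as np

def split_train_val(train_idx, val_ratio=0.2):
    """Reserve the most recent fraction of the window for validation."""
    n_val = max(1, int(round(len(train_idx) * val_ratio)))
    return train_idx[:-n_val], train_idx[-n_val:]

window = np.arange(100, 150)  # a 50-sample training window
train_idx, val_idx = split_train_val(window)
print(len(train_idx), len(val_idx))  # 40 10; validation is the latest block
```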
```diff
@@ removed (old lines 1335-1475): demos 6-8 and the summary helper @@
-    def demo_06_expanding_with_validation(X, y, timestamps):
-        """Demo 6: Expanding window with validation dataset.
-
-        Best for: Using all historical data with model selection.
-        Combines expanding window and validation split.
-        """
-        logger.info("=" * 70)
-        logger.info("DEMO 6: Expanding Window + Validation Dataset")
-        logger.info("=" * 70)
-        logger.info("Best for: Using all historical data with model selection")
-
-        splitter = TimeSeriesSlidingWindowSplit(
-            window_size=args.window_size,
-            test_size=args.test_size,
-            gap=args.gap,
-            val_ratio=0.2,
-            overlapping_tests=False,
-            expanding_window=True,  # Expanding + validation!
-        )
-
-        splits = list(splitter.split_with_val(X, y, timestamps))[:3]
-        logger.info(f"Generated {len(splits)} splits")
-
-        for fold, (train_idx, val_idx, test_idx) in enumerate(splits):
-            logger.info(
-                f" Fold {fold}: Train={len(train_idx)} (growing!), Val={len(val_idx)}, Test={len(test_idx)}"
-            )
-
-        fig = splitter.plot_splits(X, y, timestamps)
-        stx.io.save(
-            fig,
-            "./06_sliding_window_expanding_validation.jpg",
-            symlink_from_cwd=True,
-        )
-        logger.info("")
-
-        return splits
-
-    def demo_07_undersample_with_validation(X, y_imbalanced, timestamps):
-        """Demo 7: Undersampling with validation dataset.
-
-        Best for: Handling imbalanced data with hyperparameter tuning.
-        Combines undersampling and validation split.
-        """
-
-        logger.info("=" * 70)
-        logger.info("DEMO 7: Undersampling + Validation Dataset")
-        logger.info("=" * 70)
-        logger.info("Best for: Imbalanced data with hyperparameter tuning")
-
-        splitter = TimeSeriesSlidingWindowSplit(
-            window_size=args.window_size,
-            test_size=args.test_size,
-            gap=args.gap,
-            val_ratio=0.2,
-            undersample=True,  # Undersample + validation!
-            random_state=42,
-        )
-
-        splits = list(splitter.split_with_val(X, y_imbalanced, timestamps))[:3]
-        logger.info(f"Generated {len(splits)} splits")
-
-        for fold, (train_idx, val_idx, test_idx) in enumerate(splits):
-            train_labels = y_imbalanced[train_idx]
-            train_unique, train_counts = np.unique(train_labels, return_counts=True)
-            logger.info(
-                f" Fold {fold}: Train={len(train_idx)} (balanced!), Val={len(val_idx)}, Test={len(test_idx)}, "
-                f"Class dist={dict(zip(train_unique, train_counts))}"
-            )
-
-        fig = splitter.plot_splits(X, y_imbalanced, timestamps)
-        stx.io.save(
-            fig,
-            "./07_sliding_window_undersample_validation.jpg",
-            symlink_from_cwd=True,
-        )
-        logger.info("")
-
-        return splits
-
-    def demo_08_all_options_combined(X, y_imbalanced, timestamps):
-        """Demo 8: All options combined.
-
-        Best for: Maximum flexibility - expanding window, undersampling, and validation.
-        Shows all features working together.
-        """
-        logger.info("=" * 70)
-        logger.info("DEMO 8: Expanding + Undersampling + Validation (ALL OPTIONS)")
-        logger.info("=" * 70)
-        logger.info("Best for: Comprehensive time series CV with all features")
-
-        splitter = TimeSeriesSlidingWindowSplit(
-            window_size=args.window_size,
-            test_size=args.test_size,
-            gap=args.gap,
-            val_ratio=0.2,
-            overlapping_tests=False,
-            expanding_window=True,  # All three!
-            undersample=True,
-            random_state=42,
-        )
-
-        splits = list(splitter.split_with_val(X, y_imbalanced, timestamps))[:3]
-        logger.info(f"Generated {len(splits)} splits")
-
-        for fold, (train_idx, val_idx, test_idx) in enumerate(splits):
-            train_labels = y_imbalanced[train_idx]
-            train_unique, train_counts = np.unique(train_labels, return_counts=True)
-            logger.info(
-                f" Fold {fold}: Train={len(train_idx)} (growing & balanced!), Val={len(val_idx)}, Test={len(test_idx)}, "
-                f"Class dist={dict(zip(train_unique, train_counts))}"
-            )
-
-        fig = splitter.plot_splits(X, y_imbalanced, timestamps)
-        stx.io.save(fig, "./08_sliding_window_all_options.jpg", symlink_from_cwd=True)
-        logger.info("")
-
-        return splits
-
-    def print_summary(
-        splits_fixed,
-        splits_expanding,
-        splits_overlap,
-        splits_undersample=None,
-        splits_validation=None,
-        splits_expanding_val=None,
-        splits_undersample_val=None,
-        splits_all_options=None,
-    ):
-        """Print comparison summary of all modes."""
-        logger.info("=" * 70)
-        logger.info("SUMMARY COMPARISON")
-        logger.info("=" * 70)
-        logger.info(
-            f"01. Fixed window (non-overlap): {len(splits_fixed)} folds, train size constant"
-        )
-        logger.info(
-            f"02. Expanding window (non-overlap): {len(splits_expanding)} folds, train size grows"
-        )
-        logger.info(
-            f"03. Fixed window (overlapping): {len(splits_overlap)} folds, more eval points"
-        )
```
```diff
@@ added (new lines 106-116): the new __init__ forwards all options to the base class; the closing ")" (old 1476 / new 117) is shared context @@
+        super().__init__(
+            window_size=window_size,
+            step_size=step_size,
+            test_size=test_size,
+            gap=gap,
+            val_ratio=val_ratio,
+            random_state=random_state,
+            overlapping_tests=overlapping_tests,
+            expanding_window=expanding_window,
+            undersample=undersample,
+            n_splits=n_splits,
         )
```
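Set against the several hundred removed lines, the replacement constructor simply forwards every option to a base class; this release moves the sliding-window logic into a shared core module. A minimal sketch of that delegation pattern (class names below are illustrative, not scitex internals):

```python
# Illustrative delegation sketch; names are hypothetical, not scitex internals.
class _SlidingWindowCore:
    def __init__(self, window_size, step_size=None, test_size=20, gap=0,
                 val_ratio=0.0, random_state=None, overlapping_tests=False,
                 expanding_window=False, undersample=False, n_splits=None):
        # The core owns all configuration; subclasses stay declarative.
        self.window_size = window_size
        self.step_size = step_size
        self.test_size = test_size
        self.gap = gap
        self.val_ratio = val_ratio
        self.random_state = random_state
        self.overlapping_tests = overlapping_tests
        self.expanding_window = expanding_window
        self.undersample = undersample
        self.n_splits = n_splits

class TimeSeriesSlidingWindowSplitSketch(_SlidingWindowCore):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)  # forward everything, as the new code does

s = TimeSeriesSlidingWindowSplitSketch(window_size=50, test_size=20, gap=5)
print(s.window_size, s.test_size, s.gap)
```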
```diff
@@ removed (old lines 1477-1638): rest of the demo script; trailing blank lines and "# EOF" (old 1639-1640 / new 119-120) remain @@
-        if splits_undersample is not None:
-            logger.info(
-                f"04. With undersampling: {len(splits_undersample)} folds, balanced classes"
-            )
-        if splits_validation is not None:
-            logger.info(
-                f"05. With validation set: {len(splits_validation)} folds, train-val-test"
-            )
-        if splits_expanding_val is not None:
-            logger.info(
-                f"06. Expanding + validation: {len(splits_expanding_val)} folds, growing train with val"
-            )
-        if splits_undersample_val is not None:
-            logger.info(
-                f"07. Undersample + validation: {len(splits_undersample_val)} folds, balanced with val"
-            )
-        if splits_all_options is not None:
-            logger.info(
-                f"08. All options combined: {len(splits_all_options)} folds, expanding + balanced + val"
-            )
-        logger.info("")
-        logger.info("Key Insights:")
-        logger.info(
-            " - Non-overlapping tests (default): Each sample tested exactly once"
-        )
-        logger.info(
-            " - Expanding window: Maximizes training data, like sklearn TimeSeriesSplit"
-        )
-        logger.info(
-            " - Overlapping tests: More evaluation points, like K-fold training reuse"
-        )
-        if splits_undersample is not None:
-            logger.info(
-                " - Undersampling: Balances imbalanced classes in training sets"
-            )
-        if splits_validation is not None:
-            logger.info(
-                " - Validation set: Enables hyperparameter tuning with temporal order"
-            )
-        if splits_all_options is not None:
-            logger.info(
-                " - Combined options: Maximum flexibility for complex time series CV"
-            )
-        logger.info("=" * 70)
-
-    # Main execution
-    logger.info("=" * 70)
-    logger.info("Demonstrating TimeSeriesSlidingWindowSplit with New Options")
-    logger.info("=" * 70)
-
-    # Generate test data
-    np.random.seed(42)
-    n_samples = args.n_samples
-    X = np.random.randn(n_samples, 5)
-    y = np.random.randint(0, 2, n_samples)  # Balanced
-    timestamps = np.arange(n_samples) + np.random.normal(0, 0.1, n_samples)
-
-    # Create imbalanced labels (80% class 0, 20% class 1)
-    y_imbalanced = np.zeros(n_samples, dtype=int)
-    n_minority = int(n_samples * 0.2)
-    minority_indices = np.random.choice(n_samples, size=n_minority, replace=False)
-    y_imbalanced[minority_indices] = 1
-
-    logger.info(f"Generated test data: {n_samples} samples, {X.shape[1]} features")
-    logger.info("")
-
-    # Run demos
-    splits_fixed = demo_01_fixed_window_non_overlapping_tests(X, y, timestamps)
-    splits_expanding = demo_02_expanding_window_non_overlapping_tests(X, y, timestamps)
-    splits_overlap = demo_03_fixed_window_overlapping_tests(X, y, timestamps)
-    splits_undersample = demo_04_undersample_imbalanced_data(
-        X, y_imbalanced, timestamps
-    )
-    splits_validation = demo_05_validation_dataset(X, y, timestamps)
-    splits_expanding_val = demo_06_expanding_with_validation(X, y, timestamps)
-    splits_undersample_val = demo_07_undersample_with_validation(
-        X, y_imbalanced, timestamps
-    )
-    splits_all_options = demo_08_all_options_combined(X, y_imbalanced, timestamps)
-
-    # Print summary
-    print_summary(
-        splits_fixed,
-        splits_expanding,
-        splits_overlap,
-        splits_undersample,
-        splits_validation,
-        splits_expanding_val,
-        splits_undersample_val,
-        splits_all_options,
-    )
-
-    return 0
-
-
-def parse_args() -> argparse.Namespace:
-    """Parse command line arguments."""
-    parser = argparse.ArgumentParser(
-        description="Demonstrate TimeSeriesSlidingWindowSplit with overlapping_tests and expanding_window options"
-    )
-    parser.add_argument(
-        "--n-samples",
-        type=int,
-        default=200,
-        help="Number of samples to generate (default: %(default)s)",
-    )
-    parser.add_argument(
-        "--window-size",
-        type=int,
-        default=50,
-        help="Size of training window (default: %(default)s)",
-    )
-    parser.add_argument(
-        "--test-size",
-        type=int,
-        default=20,
-        help="Size of test window (default: %(default)s)",
-    )
-    parser.add_argument(
-        "--gap",
-        type=int,
-        default=5,
-        help="Gap between train and test (default: %(default)s)",
-    )
-    args = parser.parse_args()
-    return args
-
-
-def run_main() -> None:
-    """Initialize scitex framework, run main function, and cleanup."""
-    global CONFIG, CC, sys, plt, rng
-
-    import sys
-
-    import matplotlib.pyplot as plt
-    import scitex as stx
-
-    args = parse_args()
-
-    CONFIG, sys.stdout, sys.stderr, plt, CC, rng = stx.session.start(
-        sys,
-        plt,
-        args=args,
-        file=__FILE__,
-        sdir_suffix=None,
-        verbose=False,
-        agg=True,
-    )
-
-    exit_status = main(args)
-
-    stx.session.close(
-        CONFIG,
-        verbose=False,
-        notify=False,
-        message="",
-        exit_status=exit_status,
-    )
-
 
-if __name__ == "__main__":
-    run_main()
 
 # EOF
```