GLDF 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,754 @@
1
+ import numpy as np
2
+ from scipy.stats import binom, norm
3
+ from .data_management import CIT_DataPatterned, BlockView
4
+ from dataclasses import dataclass
5
+ from typing import Literal, Callable
6
+ from warnings import warn
7
+
8
class ITestCI:
    """
    Interface through which (custom) conditional-independence-test (CIT) implementations are exposed.
    It is shaped so that large collections of blocks can be dispatched in a single, efficient call.

    .. seealso::
        The current mCIT implementation additionally expects the interfaces
        :py:class:`IProvideVarianceForCIT` and, depending on the homogeneity-test in use,
        :py:class:`IProvideAnalyticQuantilesForCIT` to be exposed.
        Details: :ref:`label-interfaces-custom-data-proc`.
    """

    @dataclass
    class Result:
        """
        Structured result that (custom) CIT-implementations hand back.
        """

        global_score : float                   #: Dependence score computed over the whole data-set.
        dependent : bool                       #: Whether the whole data-set was judged dependent.
        block_scores : np.ndarray|None = None  #: One score per block, if the test produces them.

    def run_single(self, data: CIT_DataPatterned) -> Result:
        """Perform one CIT over the complete data-set.

        Must be overridden by implementations; the base version only raises.

        :param data: Data-set to test.
        :type data: CIT_DataPatterned
        :return: The structured test outcome.
        :rtype: ITestCI.Result
        :raises NotImplementedError: unless overridden.
        """
        raise NotImplementedError

    def run_many(self, data: BlockView) -> Result:
        """Perform CITs on every block of a (possibly large) block collection at once.

        Must be overridden by implementations; the base version only raises.

        :param data: Blocks to test.
        :type data: BlockView
        :return: The structured test outcome.
        :rtype: ITestCI.Result
        :raises NotImplementedError: unless overridden.
        """
        raise NotImplementedError

    @staticmethod
    def _extract_cache_id(fname: str, data: CIT_DataPatterned|BlockView) -> tuple:
        """Derive the cache-id belonging to a query.

        This default simply forwards :py:obj:`data.cache_id` and ignores ``fname``; overriding it
        is possible but rarely needed.

        .. seealso::
            Cache-IDs are discussed in :ref:`label-cache-ids`.

        :param fname: name of the method whose result is to be cached
        :type fname: str
        :param data: the data carrying the cache-id
        :type data: CIT_DataPatterned | BlockView
        :return: the cache-id, i.e. :py:obj:`data.cache_id`.
        :rtype: tuple
        """
        cache_id = data.cache_id
        return cache_id
66
+
67
+
68
class ITestHomogeneity:
    """
    Interface through which (custom) homogeneity-test implementations are exposed.
    """

    def is_homogeneous(self, data: CIT_DataPatterned) -> bool:
        """Decide whether the data of the query is homogeneous.

        Must be overridden by implementations; the base version only raises.

        :param data: The data to examine.
        :type data: CIT_DataPatterned
        :return: True if the data-set is accepted as homogeneous.
        :rtype: bool
        :raises NotImplementedError: unless overridden.
        """
        raise NotImplementedError
82
+
83
class ITestWeakRegime:
    """
    Interface through which (custom) weak-regime test implementations are exposed.
    """

    def has_true_regime(self, data: CIT_DataPatterned) -> bool:
        """Decide whether non-homogeneous data contains a true regime or merely a weak one.

        Must be overridden by implementations; the base version only raises.

        :param data: The non-homogeneous data to examine.
        :type data: CIT_DataPatterned
        :return: True if the data-set is believed to feature a true regime.
        :rtype: bool
        :raises NotImplementedError: unless overridden.
        """
        raise NotImplementedError
97
+
98
+
99
+
100
+ class ITestMarkedCI:
101
+ """
102
+ Interface specifying how to expose (custom) implementations of marked conditional independence tests.
103
+ """
104
+
105
+ type _three_way_result_t = Literal["dependent", "independent", "regime", "weak or regime"] #: encode internal result/category
106
+
107
+ class Result:
108
+ """
109
+ Structured output of mCIT.
110
+ """
111
+ def __init__(self, result_string:'ITestMarkedCI._three_way_result_t'):
112
+ """Initialize from completed query's result.
113
+
114
+ :param result_string: The concluded categorization of data.
115
+ :type result_string: Literal["dependent", "independent", "regime", "weak or regime"]
116
+ """
117
+ self.result_string : 'mCIT._three_way_result_t' = result_string # can be dependent, independent, regime or "weak or regime" (in pc1-phase)
118
+
119
+ def is_regime(self) -> bool:
120
+ """Inspect the result to check if a true regime was found.
121
+
122
+ :return: Truth-value about the presence of a true-regime.
123
+ :rtype: bool
124
+ """
125
+ return self.result_string == "regime"
126
+
127
+ def is_globally_dependent(self) -> bool:
128
+ """Inspect the result to check if any sort of dependence (global, weak-regime or true-regime) was found.
129
+
130
+ :return: Truth-value about the presence of dependence.
131
+ :rtype: bool
132
+ """
133
+ return self.result_string != "independent"
134
+
135
+ def is_globally_independent(self) -> bool:
136
+ """Inspect the result to check if no dependence was found.
137
+
138
+ :return: Truth-value about global independence.
139
+ :rtype: bool
140
+ """
141
+ return self.result_string == "independent"
142
+
143
+ def marked_independence(self, data: CIT_DataPatterned) -> Result:
144
+ """Test marked conditional independence on the supplied data.
145
+
146
+ :param data: The data associated to the test to perform.
147
+ :type data: CIT_DataPatterned
148
+ :return: The structured mCIT output.
149
+ :rtype: ITestMarkedCI.Result
150
+ """
151
+ raise NotImplementedError()
152
+
153
class ITestIndicatorImplications:
    """
    Interface through which (custom) indicator-relation test implementations are exposed.
    """

    def is_implied_regime(self, A_list: list[CIT_DataPatterned], B: CIT_DataPatterned) -> bool:
        """Check the indicator implication "every lhs test is independent :math:`\\Rightarrow`
        the rhs test is independent".

        Must be overridden by implementations; the base version only raises.

        :param A_list: the tests forming the left-hand side
        :type A_list: list[CIT_DataPatterned]
        :param B: the test forming the right-hand side
        :type B: CIT_DataPatterned
        :return: Whether the implication holds.
        :rtype: bool
        :raises NotImplementedError: unless overridden.
        """
        raise NotImplementedError
170
+
171
class IProvideHyperparamsForRobustCIT:
    """
    Interface specifying how to supply (customized) hyper-parameters for robust dependence testing.
    """

    @dataclass
    class Hyperparams:
        B : int #: block-size

    def hyperparams_for_robust_cit(self, N: int, dim_Z: int) -> 'IProvideHyperparamsForRobustCIT.Hyperparams':
        """Supply hyper-parameters for robust CI testing for the given setup.

        Must be overridden by implementations; the base version only raises.

        :param N: sample-size
        :type N: int
        :param dim_Z: size of the conditioning set
        :type dim_Z: int
        :return: Hyper-parameters to use.
        :rtype: IProvideHyperparamsForRobustCIT.Hyperparams
        :raises NotImplementedError: unless overridden.
        """
        raise NotImplementedError()

    @staticmethod
    def from_binomial_homogeneity_test(homogeneity_test_hyperparams: 'IProvideHyperparamsForBinomial') -> 'IProvideHyperparamsForRobustCIT':
        """It is runtime-efficient and simple to obtain hyper-parameters for robust CI testing compatible to
        those used by homogeneity-testing: the returned provider reuses the block-size ``B`` that the
        binomial hyper-parameters yield for the same ``(N, dim_Z)``.

        :param homogeneity_test_hyperparams: Homogeneity-test hyper-parameter provider whose block-size to copy.
        :type homogeneity_test_hyperparams: IProvideHyperparamsForBinomial
        :return: Hyper-parameter Provider.
        :rtype: IProvideHyperparamsForRobustCIT
        """
        class _HyperparamsRobustCIT(IProvideHyperparamsForRobustCIT):
            # Adapter: forwards only the block-size B of the binomial hyper-parameters.
            def __init__(self, homog_hyper: 'IProvideHyperparamsForBinomial'):
                self.homog_hyper = homog_hyper

            def hyperparams_for_robust_cit(self, N: int, dim_Z: int) -> 'IProvideHyperparamsForRobustCIT.Hyperparams':
                binomial_params = self.homog_hyper.hyperparams_for_binomial(N, dim_Z)
                return IProvideHyperparamsForRobustCIT.Hyperparams(B=binomial_params.B)

        return _HyperparamsRobustCIT(homogeneity_test_hyperparams)
208
+
209
+
210
class mCIT(ITestMarkedCI):
    """
    (Regime-)marked independence test (mCIT).

    Combines an underlying CI-test with a homogeneity-test and, optionally, a weak-regime test to
    classify a data-set as "dependent", "independent", "regime" or "weak or regime".
    """

    _internal_result_t = Literal["dependent", "independent", "weak", "regime", "weak or regime"]


    def __init__(self, cit: ITestCI, homogeneity_test: ITestHomogeneity|None, weak_test: ITestWeakRegime|None=None,
                 homogeneity_first: bool=True, robust_conditional_testing: IProvideHyperparamsForRobustCIT|bool|None=True):
        """Constructor of mCIT from underlying tests.

        :param cit: Underlying CI-test.
        :type cit: ITestCI
        :param homogeneity_test: Underlying homogeneity-test. Required if ``homogeneity_first`` is True or
            if ``robust_conditional_testing`` is True.
        :type homogeneity_test: ITestHomogeneity, optional
        :param weak_test: Underlying weak-regime test, defaults to None (weak and true regimes are then
            reported jointly as "weak or regime")
        :type weak_test: ITestWeakRegime, optional
        :param homogeneity_first: Test homogeneity first, then global dependency, defaults to True (recommended)
        :type homogeneity_first: bool, optional
        :param robust_conditional_testing: Configuration of the conditional test (e.g. of regressors) to rely on data
            only locally in the pattern (recommended for simple parametric tests). Defaults to True, which applies
            :py:meth:`IProvideHyperparamsForRobustCIT.from_binomial_homogeneity_test` to the ``hyperparams`` of
            ``homogeneity_test``; False or None disables robust testing.
        :type robust_conditional_testing: IProvideHyperparamsForRobustCIT|bool, optional
        :raises ValueError: if ``homogeneity_first`` is True but ``homogeneity_test`` is None, or if
            ``robust_conditional_testing`` is True but ``homogeneity_test`` provides no ``hyperparams``.
        """
        # Validate before touching homogeneity_test; previously None crashed with AttributeError
        # on `.hyperparams` before the (assert-based) check could ever run.
        if homogeneity_first and homogeneity_test is None:
            raise ValueError("mCIT: homogeneity_first=True requires a homogeneity_test.")

        self.ci_test = cit
        self.homogeneity_test = homogeneity_test
        self.weak_test = weak_test
        self.homogeneity_first = homogeneity_first

        if isinstance(robust_conditional_testing, bool):
            if robust_conditional_testing:
                homogeneity_hyperparams = getattr(homogeneity_test, "hyperparams", None)
                if homogeneity_hyperparams is None:
                    raise ValueError(
                        "mCIT: robust_conditional_testing=True copies block-sizes from the homogeneity_test's "
                        "'hyperparams'; pass a homogeneity_test providing them, an explicit "
                        "IProvideHyperparamsForRobustCIT, or robust_conditional_testing=False.")
                self.robust_conditional_testing = IProvideHyperparamsForRobustCIT.from_binomial_homogeneity_test(homogeneity_hyperparams)
            else:
                self.robust_conditional_testing = None
        else:
            # explicit provider, or None (robust testing disabled)
            self.robust_conditional_testing = robust_conditional_testing

        self.run = self._run_inhom_first if homogeneity_first else self._run_global_first

    def marked_independence(self, data: CIT_DataPatterned) -> ITestMarkedCI.Result:
        """Run the configured mCIT pipeline on ``data`` and wrap the category in a result object.

        Internal "weak" results are reported as "dependent".

        :param data: The data associated to the test to perform.
        :type data: CIT_DataPatterned
        :return: The structured mCIT output.
        :rtype: ITestMarkedCI.Result
        """
        if self.homogeneity_first:
            internal_result = self._run_inhom_first(data)
        else:
            internal_result = self._run_global_first(data)

        three_way_result = self._marked_independence_from_category(internal_result)
        return ITestMarkedCI.Result(three_way_result)


    def _is_globally_dependent(self, data: CIT_DataPatterned) -> bool:
        """Run the CI-test on all data; blockwise (robust) if a robust hyper-parameter provider is configured."""
        if self.robust_conditional_testing is not None:
            # by default use same blocks as homogeneity, can use a different hyper-parameter provider
            robust_params = self.robust_conditional_testing.hyperparams_for_robust_cit(data.sample_count(), data.z_dim())
            cit_result = self.ci_test.run_many(data.view_blocks(robust_params.B))
            return cit_result.dependent
        else:
            return self.ci_test.run_single(data).dependent

    def _weak_or_regime(self, data: CIT_DataPatterned) -> 'mCIT._internal_result_t':
        """Classify non-homogeneous data; without a weak-regime test the two cases cannot be separated."""
        if self.weak_test is None:
            return "weak or regime"
        else:
            if self.weak_test.has_true_regime(data):
                return "regime"
            else:
                return "weak"

    def _marked_independence_from_category(self, internal_result: 'mCIT._internal_result_t') -> 'ITestMarkedCI._three_way_result_t':
        # merge 'weak' and 'dependent' into the single output 'dependent'
        return "dependent" if (internal_result == "weak") else internal_result

    def _run_global_first(self, data: CIT_DataPatterned) -> 'mCIT._internal_result_t':
        """Test global dependence first; only dependent data is checked for homogeneity (if a test is available)."""
        if self._is_globally_dependent(data):
            if self.homogeneity_test is not None:
                if self.homogeneity_test.is_homogeneous(data):
                    return "dependent"
                else:
                    return self._weak_or_regime(data)
            else:
                return "dependent"
        else:
            return "independent"

    def _run_inhom_first(self, data: CIT_DataPatterned) -> 'mCIT._internal_result_t':
        """Test homogeneity first (requires a homogeneity test); only homogeneous data is CI-tested globally."""
        if self.homogeneity_test.is_homogeneous(data):
            if self._is_globally_dependent(data):
                return "dependent"
            else:
                return "independent"
        else:
            return self._weak_or_regime(data)
303
+
304
+
305
+
306
class IProvideHyperparamsForBinomial:
    """
    Interface through which (customized) hyper-parameters for the binomial homogeneity-test are supplied.
    """

    @dataclass
    class Hyperparams:
        B : int #: block-size
        alpha : float #: targeted error-control level :math:`\alpha`
        beta : float #: quantile used by the binomial test
        max_acceptable_count : float #: largest accepted count; numerically, the p-value at this count may exceed :math:`\alpha` by a tolerance (default :math:`10^{-5}`)

    def hyperparams_for_binomial(self, N: int, dim_Z: int) -> Hyperparams:
        """Provide the binomial homogeneity-test hyper-parameters for a given setup.

        Must be overridden by implementations; the base version only raises.

        :param N: sample-size
        :type N: int
        :param dim_Z: number of variables in the conditioning set
        :type dim_Z: int
        :return: Hyper-parameters to apply.
        :rtype: IProvideHyperparamsForBinomial.Hyperparams
        :raises NotImplementedError: unless overridden.
        """
        raise NotImplementedError
329
+
330
class IProvideAnalyticQuantilesForCIT:
    """
    Interface that (custom) CIT-implementations expose when used together with the homogeneity-test
    :py:class:`Homogeneity_Binomial`. If no implementation (:py:data:`None`) is supplied, the
    quantiles can instead be obtained by bootstrapping.
    """

    def cit_quantile_estimate(self, data: BlockView, cit_result: ITestCI.Result, beta: float, cit_obj: ITestCI) -> float:
        """Estimate the :math:`\\beta`-quantile of the test implemented by cit_obj.

        Must be overridden by implementations; the base version only raises.

        :param data: The blocks of data to work on.
        :type data: BlockView
        :param cit_result: The CIT result for this data (at present always computed beforehand).
        :type cit_result: ITestCI.Result
        :param beta: The quantile :math:`\\beta` of interest.
        :type beta: float
        :param cit_obj: The CIT instance whose quantile is requested.
            (This interface :py:class:`IProvideAnalyticQuantilesForCIT` may, but need not,
            be exposed by the CIT-type itself.)
        :type cit_obj: ITestCI
        :return: Estimated dependence-value at the quantile :math:`\\beta`
        :rtype: float
        :raises NotImplementedError: unless overridden.
        """
        raise NotImplementedError
354
+
355
+
356
class Homogeneity_Binomial(ITestHomogeneity):
    """
    Homogeneity test based on binomial approach via quantile estimator.
    Implements :py:class:`ITestHomogeneity` interface for use with :py:class:`mCIT`.

    The data is split into blocks and the CIT is run on every block. If the global score is positive,
    blocks scoring below the estimated :math:`\\beta`-quantile are counted (otherwise blocks scoring above
    the :math:`(1-\\beta)`-quantile); the data is accepted as homogeneous if this count does not exceed
    ``max_acceptable_count``.
    """

    @staticmethod
    def _get_actual_error_control_raw(alpha_homogeneity_err_control_requested: float, block_count: int, alpha_binom: float) -> float:
        """Tail probability above the discrete binomial cutoff for the requested error-control level."""
        discrete_cutoff = binom.ppf(1.0 - alpha_homogeneity_err_control_requested, n=block_count, p=alpha_binom)
        return binom.sf(discrete_cutoff, n=block_count, p=alpha_binom)

    def get_actual_error_control(self, N:int, dim_Z:int=0) -> float:
        """
        Gets the actual error-control after accounting for counting-statistics. Depending on the internals
        of the used hyper-parameter set, this may be different from :math:`\\alpha` as specified originally.

        :param N: Sample size N.
        :type N: int
        :param dim_Z: Size of the conditioning-set Z, defaults to 0
        :type dim_Z: int, optional
        :return: Effective error-control target :math:`\\alpha`.
        :rtype: float
        """
        params = self.hyperparams.hyperparams_for_binomial(N, dim_Z)
        # number of complete blocks; integer division avoids float rounding for large N
        block_count = int(N // params.B)
        return Homogeneity_Binomial._get_actual_error_control_raw(params.alpha, block_count, params.beta)


    def __init__(self, hyperparams: IProvideHyperparamsForBinomial, cit: ITestCI, cit_analytic_quantile_estimate: IProvideAnalyticQuantilesForCIT|None = None,
                 bootstrap_block_count: int=5000, next_bootstrap_seed: Callable[[], None|int|np.random.SeedSequence]= lambda : None):
        """Construct from hyper-parameter set, and either cit-specific quantile estimate or bootstrap block-count for generic quantile estimation.

        :param hyperparams: Hyper-parameter set to use.
        :type hyperparams: IProvideHyperparamsForBinomial
        :param cit: Underlying CIT, run on the data-blocks (and, if bootstrapping, on bootstrap-blocks).
        :type cit: ITestCI
        :param cit_analytic_quantile_estimate: Cit-specific quantile estimate (if available), defaults to None
        :type cit_analytic_quantile_estimate: IProvideAnalyticQuantilesForCIT | None, optional
        :param bootstrap_block_count: Block-count for bootstrap of quantile (if no cit-specific quantile was provided), defaults to 5000
        :type bootstrap_block_count: int, optional
        :param next_bootstrap_seed: Called once per bootstrap to obtain the seed for
            :py:func:`numpy.random.default_rng`; defaults to returning None (fresh entropy each time).
        :type next_bootstrap_seed: Callable[[], None|int|np.random.SeedSequence], optional
        """
        self.hyperparams = hyperparams
        self.cit = cit
        self.analytic_quantile = cit_analytic_quantile_estimate
        self.bootstrap_block_count = bootstrap_block_count
        self.next_bootstrap_seed = next_bootstrap_seed


    def get_quantile(self, data: BlockView, cit_result: ITestCI.Result, beta: float) -> float:
        """Obtain a quantile for the given dataset.

        Uses the analytic estimate if one was supplied, otherwise a bootstrap.

        :param data: Data-blocks associated to current test.
        :type data: BlockView
        :param cit_result: The CIT result for the data (currently always run previously anyway).
        :type cit_result: ITestCI.Result
        :param beta: Target probability to get a quantile (lower bound) for.
        :type beta: float
        :return: The estimated quantile lower bound.
        :rtype: float
        """
        if self.analytic_quantile is not None:
            return self.analytic_quantile.cit_quantile_estimate(data, cit_result, beta, cit_obj=self.cit)
        else:
            return self._bootstrap_quantile(data, cit_result, beta)

    def _bootstrap_quantile(self, data: BlockView, cit_result: ITestCI.Result, beta: float) -> float:
        """Estimate the quantile from CIT scores on bootstrapped, pattern-unaligned blocks."""
        d1_positive = (cit_result.global_score > 0.0)

        rng = np.random.default_rng(self.next_bootstrap_seed())
        bootstrap_blocks = data.bootstrap_unaligned_blocks(rng, bootstrap_block_count=self.bootstrap_block_count)

        z_scores_unaligned = self.cit.run_many(bootstrap_blocks).block_scores

        # lower tail for positive global score, upper tail for non-positive
        target_quantile = beta if d1_positive else 1.0-beta
        return float(np.quantile(z_scores_unaligned, target_quantile))

    def is_homogeneous(self, data: CIT_DataPatterned) -> bool:
        params = self.hyperparams.hyperparams_for_binomial(data.sample_count(), data.z_dim())
        data_blocks = data.view_blocks(params.B)

        cit_result = self.cit.run_many(data_blocks)

        d1 = cit_result.global_score
        d1_is_positive = (d1 > 0.0)

        # cutoff comes from the analytic estimate if available, otherwise from the bootstrap (see get_quantile)
        cutoff = self.get_quantile(data_blocks, cit_result, params.beta)

        if d1_is_positive:
            binom_count = np.count_nonzero(cit_result.block_scores < cutoff)
        else:
            binom_count = np.count_nonzero(cit_result.block_scores > cutoff)

        # by numerical precision, pvalue at max acceptable count may be within alpha + tolerance (by default 10^-5)
        return binom_count <= params.max_acceptable_count
449
+
450
+
451
class IProvideHyperparamsForAcceptanceInterval:
    """
    Interface through which (customized) hyper-parameters for acceptance-interval tests are supplied.
    """

    @dataclass
    class Hyperparams:
        B : int #: block-size
        alpha : float #: targeted error-control level
        cutoff : float #: score cutoff below which blocks are considered

    def hyperparams_for_acceptance_interval(self, N: int, dim_Z: int) -> Hyperparams:
        """Provide acceptance-interval hyper-parameters for a given setup.

        Must be overridden by implementations; the base version only raises.

        :param N: sample-size
        :type N: int
        :param dim_Z: number of variables in the conditioning set
        :type dim_Z: int
        :return: Hyper-parameters to apply.
        :rtype: IProvideHyperparamsForAcceptanceInterval.Hyperparams
        :raises NotImplementedError: unless overridden.
        """
        raise NotImplementedError
473
+
474
class IProvideVarianceForCIT:
    """
    Interface to expose for (custom) CIT-implementations if one of the acceptance-interval tests
    (:py:class:`WeakRegime_AcceptanceInterval` or :py:class:`IndicatorImplication_AcceptanceInterval`) is used.
    These tests require variance-estimates for the block-wise dependence-scores provided by the CIT.
    """

    def get_variance_estimate(self, N: int, dim_Z: int, cit_obj: ITestCI) -> float:
        """Get an estimate of the block-wise variance of the dependence score implemented by cit_obj.

        Must be overridden by implementations; the base version only raises.

        :param N: The sample count N (for block-wise use: the block-size).
        :type N: int
        :param dim_Z: The dimension of (number of variables in) the conditioning set Z.
        :type dim_Z: int
        :param cit_obj: The underlying cit-instance for which the variance should be computed.
            (The present interface :py:class:`IProvideVarianceForCIT` can,
            but does not have to, be exposed on the CIT-type itself.)
        :type cit_obj: ITestCI
        :return: The estimated value of the variance.
        :rtype: float
        :raises NotImplementedError: unless overridden.
        """
        raise NotImplementedError()

    def get_std_estimate(self, N: int, dim_Z: int, cit_obj: ITestCI) -> float:
        """Get an estimate of the block-wise standard-deviation of the dependence score implemented by cit_obj.
        Implementation is optional; if this method is not overridden, the square-root of the variance is used.

        :param N: The sample count N (for block-wise use: the block-size).
        :type N: int
        :param dim_Z: The dimension of (number of variables in) the conditioning set Z.
        :type dim_Z: int
        :param cit_obj: The underlying cit-instance for which the standard-deviation should be computed.
            (The present interface :py:class:`IProvideVarianceForCIT` can,
            but does not have to, be exposed on the CIT-type itself.)
        :type cit_obj: ITestCI
        :return: The estimated value of the standard-deviation.
        :rtype: float
        """
        return np.sqrt( self.get_variance_estimate(N, dim_Z, cit_obj) )
513
+
514
+
515
class TruncatedNormal:
    """
    Namescope for collection of helpers providing different useful properties of truncated
    normal distributions. Used by :py:class:`WeakRegime_AcceptanceInterval`
    and :py:class:`IndicatorImplication_AcceptanceInterval`.
    """

    @staticmethod
    def mills_ratio(beta: float) -> float:
        """Compute the Mills-ratio :math:`R(\\beta) = (1 - \\Phi(\\beta)) / \\varphi(\\beta)`.

        Evaluated directly as ``norm.sf(beta) / norm.pdf(beta)``, which becomes numerically unstable
        for large :math:`|\\beta|`.

        :param beta: argument :math:`\\beta`
        :type beta: float
        :return: mills-ratio
        :rtype: float
        """
        return norm.sf(beta) / norm.pdf(beta)

    # For approximations see eg A. Gasull, F. Utzet: "Approximating Mills ratio"
    @staticmethod
    def approx_mills_lower_bound(beta: float) -> float:
        """Lower bound for mills-ratio,
        :math:`\\pi / (\\sqrt{\\beta^2 + 2\\pi} + (\\pi - 1)\\beta)`.

        :param beta: argument :math:`\\beta > 0` (checked by ``assert``)
        :type beta: float
        :return: lower bound for mills-ratio
        :rtype: float
        """
        assert beta > 0.0
        return np.pi / (np.sqrt(beta*beta + 2 * np.pi) + (np.pi - 1) * beta)

    @staticmethod
    def approx_mills_upper_bound(beta: float) -> float:
        """Upper bound for mills-ratio,
        :math:`\\pi / (\\sqrt{(\\pi - 2)^2\\beta^2 + 2\\pi} + 2\\beta)`.

        :param beta: argument :math:`\\beta > 0` (checked by ``assert``)
        :type beta: float
        :return: upper bound for mills-ratio
        :rtype: float
        """
        assert beta > 0.0
        return np.pi / (np.sqrt((np.pi - 2.0)**2 *beta*beta + 2 * np.pi) + 2 * beta)

    # "inverse" in this context traditionally means "reciprocal" (not an inverse function)
    @classmethod
    def inv_mills_ratio(cls, beta: float) -> float:
        """Reciprocal value of the mills-ratio, :math:`\\varphi(\\beta) / (1 - \\Phi(\\beta))`,
        with improved numerical stability.

        Accepts any real :math:`\\beta`; negative values occur e.g. via :py:meth:`mean_cutoff_above`
        with a positive cutoff. Evaluated piecewise:

        * :math:`\\beta > 10^9`: returns ``0.0`` as a guard against infinite inputs (this is *not* the
          asymptotic value, which grows like :math:`\\beta`),
        * :math:`5 < \\beta \\leq 10^9`: reciprocal of :py:meth:`approx_mills_lower_bound`,
        * :math:`-8 < \\beta \\leq 5`: exact reciprocal of :py:meth:`mills_ratio`,
        * otherwise: returns ``0.0`` (the exact value is below :math:`10^{-13}` there).

        :param beta: argument :math:`\\beta`
        :type beta: float
        :return: reciprocal value of mills-ratio
        :rtype: float
        """
        if beta > 5.0:
            if beta > 1e9: # for some reason scipy scalar_root sometimes feeds inf into this function ....
                return 0.0
            else:
                return 1.0 / cls.approx_mills_lower_bound(beta)
        elif beta > -8.0:
            return 1.0 / cls.mills_ratio(beta) # this can be numerically unstable, avoid large absolute betas
        else:
            return 0.0

    # compute mean of a truncated normal
    @classmethod
    def mean_cutoff_below(cls, true_mean: float, true_sigma: float, cutoff: float) -> float:
        """Mean-value of a truncated-below normal distribution.

        :param true_mean: the normal-distribution's original mean-value parameter :math:`\\mu`
        :type true_mean: float
        :param true_sigma: the normal-distribution's original standard-deviation parameter :math:`\\sigma`
        :type true_sigma: float
        :param cutoff: the cutoff location :math:`c`
        :type cutoff: float
        :return: :math:`E[X|X\\geq c]`, where :math:`X \\sim \\mathcal{N}(\\mu, \\sigma^2)`.
        :rtype: float
        """
        beta = (cutoff - true_mean) / true_sigma
        return true_mean + true_sigma * cls.inv_mills_ratio(beta)

    @classmethod
    def mean_cutoff_above(cls, true_mean: float, true_sigma: float, cutoff: float) -> float:
        """Mean-value of a truncated-above normal distribution.

        Uses the symmetry :math:`E[X|X\\leq c] = -E[-X|-X\\geq -c]`.

        :param true_mean: the normal-distribution's original mean-value parameter :math:`\\mu`
        :type true_mean: float
        :param true_sigma: the normal-distribution's original standard-deviation parameter :math:`\\sigma`
        :type true_sigma: float
        :param cutoff: the cutoff location :math:`c`
        :type cutoff: float
        :return: :math:`E[X|X\\leq c]`, where :math:`X \\sim \\mathcal{N}(\\mu, \\sigma^2)`.
        :rtype: float
        """
        return -cls.mean_cutoff_below(-true_mean, true_sigma, -cutoff)
608
+
609
class WeakRegime_AcceptanceInterval(ITestWeakRegime):
    """
    Acceptance-Interval test implementation of the weak-regime test :py:class:`ITestWeakRegime` interface as used by :py:class:`mCIT`.

    Blocks whose (sign-oriented) score lies below the cutoff are collected; a true regime is reported if there
    are enough of them and their mean lies within an acceptance interval around the mean expected for a
    zero-dependence normal truncated at the cutoff.
    """

    def __init__(self, hyperparams: IProvideHyperparamsForAcceptanceInterval, cit: ITestCI, cit_variance_estimate: IProvideVarianceForCIT, min_regime_fraction: float=0.15):
        """Construct from hyper-parameter set and cit-specific dependency-score estimator variance.

        :param hyperparams: Hyper-parameter set to use.
        :type hyperparams: IProvideHyperparamsForAcceptanceInterval
        :param cit: Underlying CIT.
        :type cit: ITestCI
        :param cit_variance_estimate: A cit-specific estimate of the dependency-score estimator's variance.
        :type cit_variance_estimate: IProvideVarianceForCIT
        :param min_regime_fraction: Minimum fraction of blocks in a regime to be considered.
        :type min_regime_fraction: float
        """
        self.hyperparams = hyperparams
        self.cit = cit
        self.cit_variance_est = cit_variance_estimate
        self.min_regime_fraction = min_regime_fraction


    def has_true_regime(self, data: CIT_DataPatterned) -> bool:
        params = self.hyperparams.hyperparams_for_acceptance_interval(data.sample_count(), data.z_dim())
        data_blocks = data.view_blocks(params.B)

        sigma = self.cit_variance_est.get_std_estimate(N=params.B, dim_Z=data.z_dim(), cit_obj=self.cit)
        cit_result = self.cit.run_many(data_blocks)

        # Sign such that the higher-dependence regime is positive. Kept in a local array instead of being
        # written back into cit_result: the CIT's result object may be shared or cached, and rebinding its
        # block_scores would silently flip signs for any later consumer.
        if cit_result.global_score > 0.0:
            block_scores = cit_result.block_scores
        else:
            block_scores = -cit_result.block_scores

        data_below_cutoff = block_scores[block_scores < params.cutoff]
        approx_count = len(data_below_cutoff)

        # approx_count and len(block_scores) both count blocks; an empty selection (possible when
        # min_regime_fraction == 0) would otherwise divide by zero and average an empty array.
        if approx_count == 0 or approx_count < self.min_regime_fraction * len(block_scores):
            return False

        tolerance = norm.ppf(1.0-params.alpha, loc=0.0, scale=sigma / np.sqrt(approx_count))
        lower_bound = TruncatedNormal.mean_cutoff_above(true_mean=0.0, true_sigma=sigma, cutoff=params.cutoff)
        m = np.mean(data_below_cutoff)
        return lower_bound - tolerance < m < tolerance
657
+
658
+
659
+
660
+
661
class IndicatorImplication_AcceptanceInterval(ITestIndicatorImplications):
    """
    Indicator Implication test based on acceptance interval.

    Blocks on which the aggregated lhs dependence lies below the cutoff are selected; the rhs block-scores
    on these blocks must then be compatible with independence (acceptance interval as in
    :py:class:`WeakRegime_AcceptanceInterval`, widened by an estimated residual-dependence term).
    """
    def __init__(self, hyperparams: IProvideHyperparamsForAcceptanceInterval, cit: ITestCI, cit_variance_estimate: IProvideVarianceForCIT, min_regime_fraction: float=0.15):
        """Construct from hyper-parameter set and cit-specific dependency-score estimator variance.

        .. seealso::
            This test is based on the :py:class:`WeakRegime_AcceptanceInterval`.

        :param hyperparams: Hyper-parameter set to use.
        :type hyperparams: IProvideHyperparamsForAcceptanceInterval
        :param cit: Underlying CIT.
        :type cit: ITestCI
        :param cit_variance_estimate: A cit-specific estimate of the dependency-score estimator's variance.
        :type cit_variance_estimate: IProvideVarianceForCIT
        :param min_regime_fraction: Minimum fraction of blocks in a regime to be considered.
        :type min_regime_fraction: float
        """
        self.hyperparams = hyperparams
        self.cit = cit
        self.cit_variance_est = cit_variance_estimate
        self.min_regime_fraction = min_regime_fraction

    @staticmethod
    def _common_block_count(count_a: int, count_b: int) -> int:
        """Number of leading blocks shared by two block-sequences (trailing extra blocks are discarded).

        On time-series with different max lag in different conditioning-sets, the total number of blocks
        can be off by one; a larger mismatch is unexpected and triggers a warning.
        """
        if abs(count_a - count_b) > 1:
            warn("implied tests should never have block-count off by more than 1???", stacklevel=3)
        return min(count_a, count_b)

    def is_implied_regime(self, A_list: list[CIT_DataPatterned], B: CIT_DataPatterned) -> bool:
        assert len(A_list) > 0

        max_z_dim = max(max(A.z_dim() for A in A_list), B.z_dim())
        params = self.hyperparams.hyperparams_for_acceptance_interval(B.sample_count(), max_z_dim)

        sigma = self.cit_variance_est.get_std_estimate(N=params.B, dim_Z=max_z_dim, cit_obj=self.cit)

        effective_lhs_d_sqr = None
        for A in A_list:
            # recreate weak query locally for uniform block-size (of max dim(Z))
            cit_result = self.cit.run_many(A.view_blocks(params.B))
            d1_sign = 1.0 if cit_result.global_score > 0.0 else -1.0
            d_sqr_contribution = np.square(cit_result.block_scores)
            # squared scores keep the sign relative to the global score (opposite-sign blocks count negatively)
            relative_sign = np.choose(cit_result.block_scores > 0.0, [-d1_sign, d1_sign])
            eff_d_sqr_contribution = relative_sign * d_sqr_contribution

            if effective_lhs_d_sqr is None:
                # first entry on lhs
                effective_lhs_d_sqr = eff_d_sqr_contribution
            else:
                # add squared contributions ('euclidean' distance with opposite-sign correction),
                # restricted to the blocks present in both sequences
                n = self._common_block_count(len(effective_lhs_d_sqr), len(eff_d_sqr_contribution))
                effective_lhs_d_sqr = effective_lhs_d_sqr[:n] + eff_d_sqr_contribution[:n]

        effective_lhs_d = np.sqrt(np.maximum(effective_lhs_d_sqr, 0.0) / len(A_list)) # max ok if c>0
        below_cutoff = effective_lhs_d < params.cutoff

        # recreate weak query locally for uniform block-size (of max dim(Z)) for B (rhs)
        blocks_B = B.view_blocks(params.B)
        cit_result = self.cit.run_many(blocks_B)
        block_d = cit_result.block_scores if cit_result.global_score > 0.0 else -cit_result.block_scores

        n = self._common_block_count(len(block_d), len(below_cutoff))
        data_below_cutoff = block_d[:n][below_cutoff[:n]]
        approx_count = len(data_below_cutoff)

        block_count_B = blocks_B.block_count()
        # an empty selection (possible when min_regime_fraction == 0) would divide by zero below
        if approx_count == 0 or approx_count < self.min_regime_fraction * block_count_B:
            return False

        tolerance = norm.ppf(1.0-params.alpha, loc=0.0, scale=sigma / np.sqrt(approx_count))
        lower_bound = TruncatedNormal.mean_cutoff_above(true_mean=0.0, true_sigma=sigma, cutoff=params.cutoff)
        m = np.mean(data_below_cutoff)
        # estimate the dependence outside the selected blocks from the global score and the selected mean
        a = np.clip(approx_count / block_count_B, a_min=self.min_regime_fraction, a_max=1.0-self.min_regime_fraction)
        d1_est = ( abs(cit_result.global_score) - a * m ) / ( 1 - a )
        upper_bound = d1_est * norm.sf(params.cutoff, loc=0.0, scale=sigma)
        return lower_bound - tolerance < m < tolerance + upper_bound