intensify 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- intensify/__init__.py +96 -0
- intensify/_config.py +41 -0
- intensify/backends/__init__.py +10 -0
- intensify/backends/_backend.py +90 -0
- intensify/backends/jax_backend.py +132 -0
- intensify/backends/numpy_backend.py +144 -0
- intensify/core/__init__.py +57 -0
- intensify/core/base.py +122 -0
- intensify/core/diagnostics/__init__.py +34 -0
- intensify/core/diagnostics/goodness_of_fit.py +197 -0
- intensify/core/diagnostics/metrics.py +39 -0
- intensify/core/diagnostics/residuals.py +79 -0
- intensify/core/inference/__init__.py +317 -0
- intensify/core/inference/bayesian.py +162 -0
- intensify/core/inference/em.py +176 -0
- intensify/core/inference/mle.py +1300 -0
- intensify/core/inference/multivariate_hawkes_mle_params.py +76 -0
- intensify/core/inference/online.py +131 -0
- intensify/core/inference/univariate_hawkes_mle_params.py +154 -0
- intensify/core/kernels/__init__.py +34 -0
- intensify/core/kernels/approx_power_law.py +135 -0
- intensify/core/kernels/base.py +197 -0
- intensify/core/kernels/exponential.py +100 -0
- intensify/core/kernels/nonparametric.py +191 -0
- intensify/core/kernels/power_law.py +79 -0
- intensify/core/kernels/sum_exponential.py +122 -0
- intensify/core/processes/__init__.py +31 -0
- intensify/core/processes/cox.py +304 -0
- intensify/core/processes/hawkes.py +372 -0
- intensify/core/processes/marked_hawkes.py +233 -0
- intensify/core/processes/nonlinear_hawkes.py +279 -0
- intensify/core/processes/poisson.py +373 -0
- intensify/core/regularizers.py +55 -0
- intensify/core/simulation/__init__.py +17 -0
- intensify/core/simulation/cluster.py +142 -0
- intensify/core/simulation/thinning.py +249 -0
- intensify/py.typed +0 -0
- intensify/visualization/__init__.py +14 -0
- intensify/visualization/connectivity.py +84 -0
- intensify/visualization/event_histograms.py +106 -0
- intensify/visualization/intensity.py +102 -0
- intensify/visualization/kernels.py +70 -0
- intensify-0.2.0.dist-info/METADATA +170 -0
- intensify-0.2.0.dist-info/RECORD +46 -0
- intensify-0.2.0.dist-info/WHEEL +4 -0
- intensify-0.2.0.dist-info/licenses/LICENSE +21 -0
intensify/__init__.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Intensify — A modern library for point process modeling with Hawkes specialization.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
__version__ = "0.2.0"
|
|
6
|
+
|
|
7
|
+
from .backends import get_backend, get_backend_name, set_backend
|
|
8
|
+
|
|
9
|
+
# Core abstractions
|
|
10
|
+
# Config API
|
|
11
|
+
from ._config import config_get, config_reset, config_set
|
|
12
|
+
from .core.base import PointProcess
|
|
13
|
+
from .core.inference import (
|
|
14
|
+
BayesianInference,
|
|
15
|
+
FitResult,
|
|
16
|
+
MLEInference,
|
|
17
|
+
OnlineInference,
|
|
18
|
+
get_inference_engine,
|
|
19
|
+
)
|
|
20
|
+
from .core.kernels import (
|
|
21
|
+
ApproxPowerLawKernel,
|
|
22
|
+
ExponentialKernel,
|
|
23
|
+
Kernel,
|
|
24
|
+
NonparametricKernel,
|
|
25
|
+
PowerLawKernel,
|
|
26
|
+
SumExponentialKernel,
|
|
27
|
+
)
|
|
28
|
+
from .core.regularizers import ElasticNet, L1
|
|
29
|
+
from .core.processes import (
|
|
30
|
+
Hawkes, # alias to UnivariateHawkes
|
|
31
|
+
HomogeneousPoisson,
|
|
32
|
+
InhomogeneousPoisson,
|
|
33
|
+
LogGaussianCoxProcess,
|
|
34
|
+
MarkedHawkes,
|
|
35
|
+
MultivariateHawkes,
|
|
36
|
+
MultivariateNonlinearHawkes,
|
|
37
|
+
NonlinearHawkes,
|
|
38
|
+
Poisson, # alias to HomogeneousPoisson
|
|
39
|
+
ShotNoiseCoxProcess,
|
|
40
|
+
UnivariateHawkes,
|
|
41
|
+
)
|
|
42
|
+
from .visualization import (
|
|
43
|
+
plot_connectivity,
|
|
44
|
+
plot_event_aligned_histogram,
|
|
45
|
+
plot_intensity,
|
|
46
|
+
plot_inter_event_intervals,
|
|
47
|
+
plot_kernel,
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
__all__ = [
|
|
51
|
+
# Metadata
|
|
52
|
+
"__version__",
|
|
53
|
+
# Backend
|
|
54
|
+
"get_backend",
|
|
55
|
+
"set_backend",
|
|
56
|
+
"get_backend_name",
|
|
57
|
+
# Kernels
|
|
58
|
+
"Kernel",
|
|
59
|
+
"ExponentialKernel",
|
|
60
|
+
"SumExponentialKernel",
|
|
61
|
+
"PowerLawKernel",
|
|
62
|
+
"ApproxPowerLawKernel",
|
|
63
|
+
"NonparametricKernel",
|
|
64
|
+
# Processes
|
|
65
|
+
"PointProcess",
|
|
66
|
+
"HomogeneousPoisson",
|
|
67
|
+
"InhomogeneousPoisson",
|
|
68
|
+
"Poisson",
|
|
69
|
+
"UnivariateHawkes",
|
|
70
|
+
"Hawkes",
|
|
71
|
+
"MultivariateHawkes",
|
|
72
|
+
"MarkedHawkes",
|
|
73
|
+
"NonlinearHawkes",
|
|
74
|
+
"MultivariateNonlinearHawkes",
|
|
75
|
+
"LogGaussianCoxProcess",
|
|
76
|
+
"ShotNoiseCoxProcess",
|
|
77
|
+
# Regularizers
|
|
78
|
+
"L1",
|
|
79
|
+
"ElasticNet",
|
|
80
|
+
# Inference
|
|
81
|
+
"FitResult",
|
|
82
|
+
"get_inference_engine",
|
|
83
|
+
"MLEInference",
|
|
84
|
+
"OnlineInference",
|
|
85
|
+
"BayesianInference",
|
|
86
|
+
# Visualization
|
|
87
|
+
"plot_intensity",
|
|
88
|
+
"plot_kernel",
|
|
89
|
+
"plot_connectivity",
|
|
90
|
+
"plot_inter_event_intervals",
|
|
91
|
+
"plot_event_aligned_histogram",
|
|
92
|
+
# Config
|
|
93
|
+
"config_get",
|
|
94
|
+
"config_set",
|
|
95
|
+
"config_reset",
|
|
96
|
+
]
|
intensify/_config.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""Global configuration for intensify."""
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
# Configuration defaults
# These are the only valid keys: set_config() rejects anything else.
_DEFAULTS = {
    "recursive_warning_threshold": 50_000,  # warn if N > 50k and using O(N^2) kernel
    "float64_auto_enable": True,  # automatically enable x64 for float64 inputs
    "warn_on_nonstationary": True,
    "performance_warning": True,
}

# Current config (mutable); starts as a shallow copy so _DEFAULTS stays pristine.
_CONFIG: dict[str, object] = _DEFAULTS.copy()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get(key: str):
    """Look up a configuration value.

    Falls back to the built-in default when the key is absent from the
    live config; unknown keys yield None.
    """
    try:
        return _CONFIG[key]
    except KeyError:
        return _DEFAULTS.get(key)


# Public aliases for package exports
config_get = get
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def set_config(key: str, value: object) -> None:
    """Assign a configuration value; unknown keys are rejected with KeyError."""
    if key in _DEFAULTS:
        _CONFIG[key] = value
    else:
        raise KeyError(f"Unknown config key '{key}'. Valid keys: {list(_DEFAULTS.keys())}")


config_set = set_config
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def reset() -> None:
    """Restore every configuration key to its built-in default value."""
    _CONFIG.clear()
    for option, default in _DEFAULTS.items():
        _CONFIG[option] = default


config_reset = reset
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Backend abstraction layer.
|
|
3
|
+
|
|
4
|
+
All numerical operations in the library should import from here
|
|
5
|
+
to remain backend-agnostic (JAX vs NumPy).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from ._backend import _active_backend, get_backend, get_backend_name, set_backend
|
|
9
|
+
|
|
10
|
+
__all__ = ["get_backend", "get_backend_name", "set_backend", "_active_backend"]
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""Backend selection and dispatch."""
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
|
|
5
|
+
from . import jax_backend, numpy_backend
|
|
6
|
+
|
|
7
|
+
_active_backend: object | None = None
|
|
8
|
+
_backend_name: str = "jax" # default
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class _BackendProxy:
    """Forwarding shim for the currently selected backend.

    Modules may cache ``bt = get_backend()`` once at import time; because
    every attribute lookup on this object goes through
    ``_resolve_backend()``, a later ``set_backend()`` call is still
    reflected in all cached references.
    """

    __slots__ = ()

    def __getattr__(self, attr: str):
        target = _resolve_backend()
        return getattr(target, attr)

    def __repr__(self) -> str:
        return f"<BackendProxy active={_backend_name}>"


# Single shared proxy instance handed out by get_backend().
_PROXY = _BackendProxy()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _resolve_backend() -> object:
    """Return the concrete backend module, initializing on first call.

    Preference order: JAX (with x64 enabled) if importable, otherwise
    NumPy. Raises RuntimeError when neither library is installed.
    """
    global _active_backend, _backend_name
    if _active_backend is not None:
        return _active_backend

    if jax_backend.is_available():
        jax_backend.enable_x64()
        _active_backend, _backend_name = jax_backend.backend, "jax"
    elif numpy_backend.is_available():
        _active_backend, _backend_name = numpy_backend.backend, "numpy"
    else:
        raise RuntimeError(
            "Neither JAX nor NumPy is available. Please install one of them."
        )
    return _active_backend
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def set_backend(name: str) -> None:
    """Set the active backend: 'jax' or 'numpy'."""
    global _active_backend, _backend_name
    name = name.lower()

    if name == "numpy":
        _active_backend, _backend_name = numpy_backend.backend, "numpy"
        return
    if name != "jax":
        raise ValueError(f"Unknown backend '{name}'. Use 'jax' or 'numpy'.")

    # JAX requested: fall back gracefully when it is not importable.
    if not jax_backend.is_available():
        warnings.warn(
            "JAX backend requested but JAX is not available. Falling back to NumPy.",
            RuntimeWarning,
        )
        _active_backend, _backend_name = numpy_backend.backend, "numpy"
        return

    jax_backend.enable_x64()
    _active_backend, _backend_name = jax_backend.backend, "jax"
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def get_backend() -> object:
    """Return a proxy to the active backend.

    The proxy forwards each attribute access to the concrete backend at
    the moment of access, so a module-level ``bt = get_backend()`` keeps
    working across subsequent ``set_backend()`` calls.
    """
    _resolve_backend()  # force lazy initialization before handing out the proxy
    return _PROXY
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def get_backend_name() -> str:
    """Return the name of the active backend ('jax' or 'numpy')."""
    _resolve_backend()  # guarantee lazy initialization has happened
    return _backend_name
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# Expose the active backend's namespace at module level for convenience
|
|
90
|
+
# All functions/tools will be accessed via get_backend()
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
"""JAX backend implementation."""
|
|
2
|
+
|
|
3
|
+
from types import SimpleNamespace
|
|
4
|
+
|
|
5
|
+
# Attempt to import JAX
|
|
6
|
+
# Attempt to import JAX; on failure every re-exported name is set to None
# so the `backend` namespace below can still be constructed.
try:
    import jax
    import jax.lax
    import jax.numpy as jnp
    import jax.random as random
    from jax import config as jax_config
    from jax import grad, jit, pmap, value_and_grad, vmap

    HAS_JAX = True
except ImportError as e:  # pragma: no cover
    HAS_JAX = False
    jnp = None
    jax = None
    jax_config = None
    grad = None
    jit = None
    random = None
    vmap = None
    pmap = None
    # BUG FIX: value_and_grad was previously left undefined on this path,
    # so building the backend namespace raised NameError when JAX was absent.
    value_and_grad = None
    IMPORT_ERROR = e
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def enable_x64():
    """Switch JAX to 64-bit (double) precision globally.

    No-op when JAX is not installed.
    """
    if not HAS_JAX:
        return
    jax_config.update("jax_enable_x64", True)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def is_available():
    """Report whether the JAX backend can be used (import succeeded)."""
    return HAS_JAX
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# Backend namespace consumed via get_backend(). Array constructors are
# wrapped in lambdas so the namespace itself can be built when JAX is
# missing (jnp is None then); plain references are instead guarded with
# ``if HAS_JAX else None`` evaluated at construction time.
backend = SimpleNamespace(
    # Array types and constructors
    array=lambda *args, **kwargs: jnp.array(*args, **kwargs) if HAS_JAX else None,
    asarray=lambda *args, **kwargs: jnp.asarray(*args, **kwargs) if HAS_JAX else None,
    zeros=lambda *args, **kwargs: jnp.zeros(*args, **kwargs) if HAS_JAX else None,
    ones=lambda *args, **kwargs: jnp.ones(*args, **kwargs) if HAS_JAX else None,
    zeros_like=lambda *args, **kwargs: jnp.zeros_like(*args, **kwargs) if HAS_JAX else None,
    ones_like=lambda *args, **kwargs: jnp.ones_like(*args, **kwargs) if HAS_JAX else None,
    full=lambda *args, **kwargs: jnp.full(*args, **kwargs) if HAS_JAX else None,
    arange=lambda *args, **kwargs: jnp.arange(*args, **kwargs) if HAS_JAX else None,
    linspace=lambda *args, **kwargs: jnp.linspace(*args, **kwargs) if HAS_JAX else None,
    # Basic math ops
    exp=jnp.exp if HAS_JAX else None,
    log=jnp.log if HAS_JAX else None,
    log1p=jnp.log1p if HAS_JAX else None,
    sqrt=jnp.sqrt if HAS_JAX else None,
    square=jnp.square if HAS_JAX else None,
    power=jnp.power if HAS_JAX else None,
    abs=jnp.abs if HAS_JAX else None,
    sign=jnp.sign if HAS_JAX else None,
    sin=jnp.sin if HAS_JAX else None,
    cos=jnp.cos if HAS_JAX else None,
    tan=jnp.tan if HAS_JAX else None,
    # Reduction ops
    sum=jnp.sum if HAS_JAX else None,
    mean=jnp.mean if HAS_JAX else None,
    max=jnp.max if HAS_JAX else None,
    min=jnp.min if HAS_JAX else None,
    std=jnp.std if HAS_JAX else None,
    var=jnp.var if HAS_JAX else None,
    prod=jnp.prod if HAS_JAX else None,
    any=jnp.any if HAS_JAX else None,
    all=jnp.all if HAS_JAX else None,
    # Array manipulation
    concatenate=jnp.concatenate if HAS_JAX else None,
    stack=jnp.stack if HAS_JAX else None,
    reshape=jnp.reshape if HAS_JAX else None,
    transpose=jnp.transpose if HAS_JAX else None,
    swapaxes=jnp.swapaxes if HAS_JAX else None,
    squeeze=jnp.squeeze if HAS_JAX else None,
    expand_dims=jnp.expand_dims if HAS_JAX else None,
    split=jnp.split if HAS_JAX else None,
    roll=jnp.roll if HAS_JAX else None,
    repeat=jnp.repeat if HAS_JAX else None,
    tile=jnp.tile if HAS_JAX else None,
    # Indexing and slicing
    take=jnp.take if HAS_JAX else None,
    where=jnp.where if HAS_JAX else None,
    select=jnp.select if HAS_JAX else None,
    # Linear algebra
    dot=jnp.dot if HAS_JAX else None,
    matmul=jnp.matmul if HAS_JAX else None,
    inner=jnp.inner if HAS_JAX else None,
    outer=jnp.outer if HAS_JAX else None,
    cross=jnp.cross if HAS_JAX else None,
    linalg=jax.numpy.linalg if HAS_JAX else None,
    # Statistics
    median=jnp.median if HAS_JAX else None,
    quantile=jnp.quantile if HAS_JAX else None,
    # Array operations
    # NOTE(review): the NumPy backend also exposes ``sort``; it is missing
    # here — confirm whether callers rely on backend.sort under JAX.
    diff=jnp.diff if HAS_JAX else None,
    cumsum=jnp.cumsum if HAS_JAX else None,
    # Random
    random=random if HAS_JAX else None,
    PRNGKey=jax.random.PRNGKey if HAS_JAX else None,
    # Control flow and transformations
    lax=jax.lax if HAS_JAX else None,
    jit=jit if HAS_JAX else None,
    grad=grad if HAS_JAX else None,
    value_and_grad=value_and_grad if HAS_JAX else None,
    vmap=vmap if HAS_JAX else None,
    pmap=pmap if HAS_JAX else None,
    cond=jax.lax.cond if HAS_JAX else None,
    scan=jax.lax.scan if HAS_JAX else None,
    # Utility functions
    dtype=jnp.dtype if HAS_JAX else None,
    finfo=jnp.finfo if HAS_JAX else None,
    iinfo=jnp.iinfo if HAS_JAX else None,
    isfinite=jnp.isfinite if HAS_JAX else None,
    isnan=jnp.isnan if HAS_JAX else None,
    isinf=jnp.isinf if HAS_JAX else None,
    allclose=jnp.allclose if HAS_JAX else None,
    array_equal=jnp.array_equal if HAS_JAX else None,
    # Configuration
    config=jax_config if HAS_JAX else None,
    enable_x64=enable_x64,
    is_available=is_available,
    # SciPy optimizers are only provided by the NumPy backend.
    scipy_optimize=None,
)

__all__ = [
    "backend",
    "enable_x64",
    "is_available",
    "HAS_JAX",
]
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
"""NumPy backend implementation (fallback when JAX is unavailable)."""
|
|
2
|
+
|
|
3
|
+
from types import SimpleNamespace
|
|
4
|
+
|
|
5
|
+
# NumPy is the fallback backend; SciPy is optional and only needed for
# the scipy_optimize entry below.
try:
    import numpy as np
except ImportError:  # pragma: no cover
    np = None  # type: ignore[assignment]

try:
    import scipy.optimize as opt
except ImportError:  # pragma: no cover
    opt = None

HAS_NUMPY = np is not None
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def is_available() -> bool:
    """Report whether the NumPy backend can be used (import succeeded)."""
    return HAS_NUMPY
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _noop_random(name: str):
|
|
23
|
+
def _f(*_a, **_k):
|
|
24
|
+
raise RuntimeError(
|
|
25
|
+
f"numpy backend random.{name} called but NumPy is not available."
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
return _f
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# JAX-style random namespace backed by NumPy's global RNG. The signatures
# mirror jax.random (a ``key`` argument is accepted but ignored) so callers
# can stay backend-agnostic.
# NOTE(review): because the global np.random state is used, passing a key
# does NOT make draws reproducible here — confirm callers seed np.random.
if HAS_NUMPY:
    _random_ns = SimpleNamespace(
        # Keys are meaningless for NumPy; return None / a list of Nones.
        PRNGKey=lambda seed: None,
        split=lambda key, num: [None] * num,
        # NOTE(review): with ``dtype`` set and no shape/size, np.random.*
        # returns a scalar float without .astype — verify callers always
        # pass a shape when they pass a dtype.
        uniform=lambda key=None, shape=None, low=0.0, high=1.0, size=None, dtype=None: (
            np.random.uniform(low, high, size if size is not None else shape).astype(dtype)
            if dtype
            else np.random.uniform(low, high, size if size is not None else shape)
        ),
        normal=lambda key=None, shape=None, loc=0.0, scale=1.0, size=None, dtype=None: (
            np.random.normal(loc, scale, size if size is not None else shape).astype(dtype)
            if dtype
            else np.random.normal(loc, scale, size if size is not None else shape)
        ),
        exponential=lambda key=None, shape=None, scale=1.0, size=None, dtype=None: (
            np.random.exponential(scale, size if size is not None else shape).astype(dtype)
            if dtype
            else np.random.exponential(scale, size if size is not None else shape)
        ),
        poisson=lambda key=None, lam=1.0, shape=None: np.random.poisson(lam=lam, size=shape),
    )
else:  # pragma: no cover
    # Without NumPy every draw raises a descriptive RuntimeError at call time.
    _random_ns = SimpleNamespace(
        PRNGKey=_noop_random("PRNGKey"),
        split=_noop_random("split"),
        uniform=_noop_random("uniform"),
        normal=_noop_random("normal"),
        exponential=_noop_random("exponential"),
        poisson=_noop_random("poisson"),
    )
|
|
61
|
+
|
|
62
|
+
# Backend namespace consumed via get_backend(). Mirrors the JAX backend's
# surface; JAX-only transforms (grad/vmap/pmap) are None here and jit is
# the identity.
backend = SimpleNamespace(
    # Array constructors (lambdas so the namespace builds even without NumPy)
    array=lambda *args, **kwargs: np.array(*args, **kwargs) if HAS_NUMPY else None,
    asarray=lambda *args, **kwargs: np.asarray(*args, **kwargs) if HAS_NUMPY else None,
    zeros=lambda *args, **kwargs: np.zeros(*args, **kwargs) if HAS_NUMPY else None,
    ones=lambda *args, **kwargs: np.ones(*args, **kwargs) if HAS_NUMPY else None,
    zeros_like=lambda *args, **kwargs: np.zeros_like(*args, **kwargs) if HAS_NUMPY else None,
    ones_like=lambda *args, **kwargs: np.ones_like(*args, **kwargs) if HAS_NUMPY else None,
    full=lambda *args, **kwargs: np.full(*args, **kwargs) if HAS_NUMPY else None,
    arange=lambda *args, **kwargs: np.arange(*args, **kwargs) if HAS_NUMPY else None,
    linspace=lambda *args, **kwargs: np.linspace(*args, **kwargs) if HAS_NUMPY else None,
    # Basic math ops
    exp=np.exp if HAS_NUMPY else None,
    log=np.log if HAS_NUMPY else None,
    log1p=np.log1p if HAS_NUMPY else None,
    power=np.power if HAS_NUMPY else None,
    sqrt=np.sqrt if HAS_NUMPY else None,
    square=np.square if HAS_NUMPY else None,
    abs=np.abs if HAS_NUMPY else None,
    sign=np.sign if HAS_NUMPY else None,
    sin=np.sin if HAS_NUMPY else None,
    cos=np.cos if HAS_NUMPY else None,
    tan=np.tan if HAS_NUMPY else None,
    # Reductions
    sum=np.sum if HAS_NUMPY else None,
    mean=np.mean if HAS_NUMPY else None,
    max=np.max if HAS_NUMPY else None,
    min=np.min if HAS_NUMPY else None,
    std=np.std if HAS_NUMPY else None,
    var=np.var if HAS_NUMPY else None,
    prod=np.prod if HAS_NUMPY else None,
    any=np.any if HAS_NUMPY else None,
    all=np.all if HAS_NUMPY else None,
    # Array manipulation
    concatenate=np.concatenate if HAS_NUMPY else None,
    stack=np.stack if HAS_NUMPY else None,
    reshape=np.reshape if HAS_NUMPY else None,
    transpose=np.transpose if HAS_NUMPY else None,
    swapaxes=np.swapaxes if HAS_NUMPY else None,
    squeeze=np.squeeze if HAS_NUMPY else None,
    expand_dims=np.expand_dims if HAS_NUMPY else None,
    split=np.split if HAS_NUMPY else None,
    roll=np.roll if HAS_NUMPY else None,
    repeat=np.repeat if HAS_NUMPY else None,
    tile=np.tile if HAS_NUMPY else None,
    # Indexing and selection
    take=np.take if HAS_NUMPY else None,
    where=np.where if HAS_NUMPY else None,
    select=np.select if HAS_NUMPY else None,
    sort=np.sort if HAS_NUMPY else None,
    # Linear algebra
    dot=np.dot if HAS_NUMPY else None,
    matmul=np.matmul if HAS_NUMPY else None,
    inner=np.inner if HAS_NUMPY else None,
    outer=np.outer if HAS_NUMPY else None,
    cross=np.cross if HAS_NUMPY else None,
    linalg=np.linalg if HAS_NUMPY else None,
    # Array operations / statistics
    diff=np.diff if HAS_NUMPY else None,
    cumsum=np.cumsum if HAS_NUMPY else None,
    median=np.median if HAS_NUMPY else None,
    quantile=np.quantile if HAS_NUMPY else None,
    # Random (JAX-signature shim defined above)
    random=_random_ns,
    # Minimal jax.lax emulation.
    lax=SimpleNamespace(
        # cond: plain Python branch on the predicate.
        cond=lambda pred, true_fun, false_fun, operand: (
            true_fun(operand) if pred else false_fun(operand)
        ),
        # NOTE(review): scan is a stub — it ignores f/xs and returns
        # (init, []); any caller relying on a real scan on the NumPy
        # backend will silently get wrong results. Confirm it is unused
        # or replace with an actual loop.
        scan=lambda f, init, xs, length=None: (init, []),
    ),
    # jit is a no-op under NumPy; JAX-only transforms are unavailable.
    jit=lambda f: f,
    grad=None,
    value_and_grad=None,
    vmap=None,
    pmap=None,
    # Utility functions
    dtype=np.dtype if HAS_NUMPY else None,
    finfo=np.finfo if HAS_NUMPY else None,
    iinfo=np.iinfo if HAS_NUMPY else None,
    isfinite=np.isfinite if HAS_NUMPY else None,
    isnan=np.isnan if HAS_NUMPY else None,
    isinf=np.isinf if HAS_NUMPY else None,
    allclose=np.allclose if HAS_NUMPY else None,
    array_equal=np.array_equal if HAS_NUMPY else None,
    # SciPy optimizers (None when SciPy is absent).
    scipy_optimize=opt,
    is_available=is_available,
)

__all__ = [
    "backend",
    "is_available",
]
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""Core abstractions and implementations."""
|
|
2
|
+
|
|
3
|
+
from .base import PointProcess, PointProcessBase
|
|
4
|
+
from .diagnostics import *
|
|
5
|
+
from .inference import *
|
|
6
|
+
from .kernels import *
|
|
7
|
+
from .processes import *
|
|
8
|
+
from .regularizers import L1, ElasticNet, Regularizer
|
|
9
|
+
from .simulation import *
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"PointProcess",
|
|
13
|
+
"PointProcessBase",
|
|
14
|
+
# Kernel subclasses
|
|
15
|
+
"Kernel",
|
|
16
|
+
"ExponentialKernel",
|
|
17
|
+
"SumExponentialKernel",
|
|
18
|
+
"PowerLawKernel",
|
|
19
|
+
"ApproxPowerLawKernel",
|
|
20
|
+
"NonparametricKernel",
|
|
21
|
+
# Process models
|
|
22
|
+
"HomogeneousPoisson",
|
|
23
|
+
"InhomogeneousPoisson",
|
|
24
|
+
"Hawkes",
|
|
25
|
+
"MultivariateHawkes",
|
|
26
|
+
"MarkedHawkes",
|
|
27
|
+
"NonlinearHawkes",
|
|
28
|
+
"MultivariateNonlinearHawkes",
|
|
29
|
+
"LogGaussianCoxProcess",
|
|
30
|
+
"ShotNoiseCoxProcess",
|
|
31
|
+
# Inference
|
|
32
|
+
"InferenceEngine",
|
|
33
|
+
"FitResult",
|
|
34
|
+
"MLEInference",
|
|
35
|
+
"EMInference",
|
|
36
|
+
"OnlineInference",
|
|
37
|
+
"BayesianInference",
|
|
38
|
+
"get_inference_engine",
|
|
39
|
+
"register_engine",
|
|
40
|
+
# Simulation
|
|
41
|
+
"ogata_thinning",
|
|
42
|
+
"ogata_thinning_multivariate",
|
|
43
|
+
"branching_simulation",
|
|
44
|
+
"branching_simulation_multivariate",
|
|
45
|
+
# Diagnostics
|
|
46
|
+
"time_rescaling_test",
|
|
47
|
+
"qq_plot",
|
|
48
|
+
"residual_intensity_plot",
|
|
49
|
+
"raw_residuals",
|
|
50
|
+
"pearson_residuals",
|
|
51
|
+
"compute_information_criteria",
|
|
52
|
+
"branching_ratio",
|
|
53
|
+
"endogeneity_index",
|
|
54
|
+
"L1",
|
|
55
|
+
"ElasticNet",
|
|
56
|
+
"Regularizer",
|
|
57
|
+
]
|
intensify/core/base.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
"""Abstract base classes for point process models."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Protocol, runtime_checkable
|
|
5
|
+
|
|
6
|
+
from ..backends import get_backend
|
|
7
|
+
|
|
8
|
+
bt = get_backend()
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@runtime_checkable
class PointProcess(Protocol):
    """Protocol defining the interface for point process models.

    Structural interface: any object providing simulate/intensity/
    log_likelihood satisfies it. The concrete default implementations
    below (fit, get_params, set_params, project_params) are inherited by
    explicit subclasses such as PointProcessBase.
    """

    @abstractmethod
    def simulate(self, T: float, seed: int | None = None) -> bt.array:
        """Generate event times on interval [0, T].

        Parameters
        ----------
        T : float
            End of observation window.
        seed : int, optional
            Random seed for reproducibility.

        Returns
        -------
        events : jnp.ndarray or np.ndarray
            Sorted array of event timestamps in [0, T].
        """
        pass

    @abstractmethod
    def intensity(self, t: float, history: bt.array) -> float:
        """Evaluate conditional intensity function λ(t | history).

        Parameters
        ----------
        t : float
            Time at which to evaluate intensity.
        history : jnp.ndarray or np.ndarray
            Past event times before t.

        Returns
        -------
        intensity : float
            Conditional intensity value at time t.
        """
        pass

    @abstractmethod
    def log_likelihood(self, events: bt.array, T: float) -> float:
        """Compute log-likelihood of observed event sequence.

        Parameters
        ----------
        events : jnp.ndarray or np.ndarray
            Event timestamps on [0, T].
        T : float
            End of observation window.

        Returns
        -------
        ll : float
            Log-likelihood value.
        """
        pass

    def fit(self, events, T: float | None = None, method: str = "mle", **kwargs):
        """Fit process parameters to observed event data.

        Parameters
        ----------
        events : array-like or domain-specific data object
            Event timestamps. May also accept domain objects (SpikeTrainData, OrderBookStream).
        T : float, optional
            Observation window end time. Inferred from events if not provided.
        method : str
            Inference method: 'mle', 'em', 'bayesian' (bayesian not yet implemented).

        Returns
        -------
        result : FitResult
            Standardized container with fitted parameters and diagnostics.
        """
        # Local import avoids a circular dependency with core.inference.
        from .inference import get_inference_engine

        # Infer T if not provided
        if T is None:
            import warnings
            warnings.warn(
                "T not specified; inferring T = max(events). "
                "This may be incorrect if the observation window extends past the last event.",
                UserWarning,
            )
            events_array = bt.asarray(events)
            T = float(events_array.max())

        engine = get_inference_engine(method)
        return engine.fit(self, events, T, **kwargs)

    def get_params(self) -> dict:
        """Return model parameters as a dict (for optimization)."""
        raise NotImplementedError

    def set_params(self, params: dict) -> None:
        """Set model parameters from a dict."""
        raise NotImplementedError

    def project_params(self) -> None:
        """Project parameters onto feasible set (e.g., enforce stationarity)."""
        pass  # default: nothing to project
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class PointProcessBase(ABC, PointProcess):
    """Concrete-friendly base satisfying the PointProcess protocol.

    Subclasses are expected to override get_params/set_params; the stubs
    here raise so a forgotten override fails loudly instead of silently.
    """

    def get_params(self) -> dict:
        # Must be overridden: expose optimizable parameters as a dict.
        raise NotImplementedError

    def set_params(self, params: dict) -> None:
        # Must be overridden: accept the dict produced by get_params().
        raise NotImplementedError