intensify 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. intensify/__init__.py +96 -0
  2. intensify/_config.py +41 -0
  3. intensify/backends/__init__.py +10 -0
  4. intensify/backends/_backend.py +90 -0
  5. intensify/backends/jax_backend.py +132 -0
  6. intensify/backends/numpy_backend.py +144 -0
  7. intensify/core/__init__.py +57 -0
  8. intensify/core/base.py +122 -0
  9. intensify/core/diagnostics/__init__.py +34 -0
  10. intensify/core/diagnostics/goodness_of_fit.py +197 -0
  11. intensify/core/diagnostics/metrics.py +39 -0
  12. intensify/core/diagnostics/residuals.py +79 -0
  13. intensify/core/inference/__init__.py +317 -0
  14. intensify/core/inference/bayesian.py +162 -0
  15. intensify/core/inference/em.py +176 -0
  16. intensify/core/inference/mle.py +1300 -0
  17. intensify/core/inference/multivariate_hawkes_mle_params.py +76 -0
  18. intensify/core/inference/online.py +131 -0
  19. intensify/core/inference/univariate_hawkes_mle_params.py +154 -0
  20. intensify/core/kernels/__init__.py +34 -0
  21. intensify/core/kernels/approx_power_law.py +135 -0
  22. intensify/core/kernels/base.py +197 -0
  23. intensify/core/kernels/exponential.py +100 -0
  24. intensify/core/kernels/nonparametric.py +191 -0
  25. intensify/core/kernels/power_law.py +79 -0
  26. intensify/core/kernels/sum_exponential.py +122 -0
  27. intensify/core/processes/__init__.py +31 -0
  28. intensify/core/processes/cox.py +304 -0
  29. intensify/core/processes/hawkes.py +372 -0
  30. intensify/core/processes/marked_hawkes.py +233 -0
  31. intensify/core/processes/nonlinear_hawkes.py +279 -0
  32. intensify/core/processes/poisson.py +373 -0
  33. intensify/core/regularizers.py +55 -0
  34. intensify/core/simulation/__init__.py +17 -0
  35. intensify/core/simulation/cluster.py +142 -0
  36. intensify/core/simulation/thinning.py +249 -0
  37. intensify/py.typed +0 -0
  38. intensify/visualization/__init__.py +14 -0
  39. intensify/visualization/connectivity.py +84 -0
  40. intensify/visualization/event_histograms.py +106 -0
  41. intensify/visualization/intensity.py +102 -0
  42. intensify/visualization/kernels.py +70 -0
  43. intensify-0.2.0.dist-info/METADATA +170 -0
  44. intensify-0.2.0.dist-info/RECORD +46 -0
  45. intensify-0.2.0.dist-info/WHEEL +4 -0
  46. intensify-0.2.0.dist-info/licenses/LICENSE +21 -0
intensify/__init__.py ADDED
@@ -0,0 +1,96 @@
1
+ """
2
+ Intensify — A modern library for point process modeling with Hawkes specialization.
3
+ """
4
+
5
# Package version string.
# NOTE(review): presumably must match the version recorded in the wheel's
# dist-info METADATA — confirm how releases keep the two in sync.
__version__ = "0.2.0"
6
+
7
+ from .backends import get_backend, get_backend_name, set_backend
8
+
9
+ # Core abstractions
10
+ # Config API
11
+ from ._config import config_get, config_reset, config_set
12
+ from .core.base import PointProcess
13
+ from .core.inference import (
14
+ BayesianInference,
15
+ FitResult,
16
+ MLEInference,
17
+ OnlineInference,
18
+ get_inference_engine,
19
+ )
20
+ from .core.kernels import (
21
+ ApproxPowerLawKernel,
22
+ ExponentialKernel,
23
+ Kernel,
24
+ NonparametricKernel,
25
+ PowerLawKernel,
26
+ SumExponentialKernel,
27
+ )
28
+ from .core.regularizers import ElasticNet, L1
29
+ from .core.processes import (
30
+ Hawkes, # alias to UnivariateHawkes
31
+ HomogeneousPoisson,
32
+ InhomogeneousPoisson,
33
+ LogGaussianCoxProcess,
34
+ MarkedHawkes,
35
+ MultivariateHawkes,
36
+ MultivariateNonlinearHawkes,
37
+ NonlinearHawkes,
38
+ Poisson, # alias to HomogeneousPoisson
39
+ ShotNoiseCoxProcess,
40
+ UnivariateHawkes,
41
+ )
42
+ from .visualization import (
43
+ plot_connectivity,
44
+ plot_event_aligned_histogram,
45
+ plot_intensity,
46
+ plot_inter_event_intervals,
47
+ plot_kernel,
48
+ )
49
+
50
# Public API: the names exported by ``from intensify import *``.
# Every entry must be bound at module level in this file (imported or defined
# above); a missing name makes the star-import raise AttributeError.
__all__ = [
    # Metadata
    "__version__",
    # Backend
    "get_backend",
    "set_backend",
    "get_backend_name",
    # Kernels
    "Kernel",
    "ExponentialKernel",
    "SumExponentialKernel",
    "PowerLawKernel",
    "ApproxPowerLawKernel",
    "NonparametricKernel",
    # Processes
    "PointProcess",
    "HomogeneousPoisson",
    "InhomogeneousPoisson",
    "Poisson",  # alias to HomogeneousPoisson
    "UnivariateHawkes",
    "Hawkes",  # alias to UnivariateHawkes
    "MultivariateHawkes",
    "MarkedHawkes",
    "NonlinearHawkes",
    "MultivariateNonlinearHawkes",
    "LogGaussianCoxProcess",
    "ShotNoiseCoxProcess",
    # Regularizers
    "L1",
    "ElasticNet",
    # Inference
    "FitResult",
    "get_inference_engine",
    "MLEInference",
    "OnlineInference",
    "BayesianInference",
    # Visualization
    "plot_intensity",
    "plot_kernel",
    "plot_connectivity",
    "plot_inter_event_intervals",
    "plot_event_aligned_histogram",
    # Config
    "config_get",
    "config_set",
    "config_reset",
]
intensify/_config.py ADDED
@@ -0,0 +1,41 @@
1
+ """Global configuration for intensify."""
2
+
3
+
4
# Built-in defaults for every recognised configuration key. The key set of
# this mapping doubles as the whitelist ``set_config`` validates against.
_DEFAULTS = dict(
    recursive_warning_threshold=50_000,  # warn if N > 50k and using O(N^2) kernel
    float64_auto_enable=True,  # automatically enable x64 for float64 inputs
    warn_on_nonstationary=True,
    performance_warning=True,
)

# Live, mutable configuration; starts life as a shallow copy of the defaults.
_CONFIG: dict[str, object] = dict(_DEFAULTS)
14
+
15
+
16
def get(key: str):
    """Return the current value of configuration option *key*.

    Falls back to the built-in default. A key unknown to both the live
    configuration and the defaults yields ``None`` instead of raising.
    """
    if key in _CONFIG:
        return _CONFIG[key]
    return _DEFAULTS.get(key)


# Public alias re-exported at package level.
config_get = get
23
+
24
+
25
def set_config(key: str, value: object) -> None:
    """Override configuration option *key* with *value*.

    Raises
    ------
    KeyError
        If *key* is not one of the recognised option names.
    """
    if key in _DEFAULTS.keys():
        _CONFIG[key] = value
        return
    raise KeyError(f"Unknown config key '{key}'. Valid keys: {list(_DEFAULTS.keys())}")


# Public alias re-exported at package level.
config_set = set_config
33
+
34
+
35
def reset() -> None:
    """Restore every configuration option to its built-in default."""
    # Mutate in place rather than rebinding, so any code already holding a
    # reference to ``_CONFIG`` observes the reset.
    _CONFIG.clear()
    for option, default in _DEFAULTS.items():
        _CONFIG[option] = default


# Public alias re-exported at package level.
config_reset = reset
intensify/backends/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ """
2
+ Backend abstraction layer.
3
+
4
+ All numerical operations in the library should import from here
5
+ to remain backend-agnostic (JAX vs NumPy).
6
+ """
7
+
8
+ from ._backend import _active_backend, get_backend, get_backend_name, set_backend
9
+
10
# Public names of the ``backends`` package.
# NOTE(review): ``_active_backend`` is bound by value when this package is
# imported (it is still None then, since ``_backend`` resolves lazily), and
# ``_backend`` later rebinds only its own module global. This re-export
# therefore never reflects the active backend; use ``get_backend()`` instead
# and consider dropping it from the public list.
__all__ = ["get_backend", "get_backend_name", "set_backend", "_active_backend"]
intensify/backends/_backend.py ADDED
@@ -0,0 +1,90 @@
1
+ """Backend selection and dispatch."""
2
+
3
+ import warnings
4
+
5
+ from . import jax_backend, numpy_backend
6
+
7
# Concrete backend namespace currently in use. It is None until it is lazily
# resolved by ``_resolve_backend()`` or set explicitly via ``set_backend()``.
_active_backend: object | None = None
# Name of the active backend. It reads "jax" even before resolution, when JAX
# may still turn out to be unavailable.
_backend_name: str = "jax"  # default
9
+
10
+
11
class _BackendProxy:
    """Late-binding stand-in for the active backend namespace.

    Module code may keep ``bt = get_backend()`` from import time. Every
    attribute lookup is forwarded through ``_resolve_backend()`` when it
    happens, so a later ``set_backend()`` call is still honoured.
    """

    # No per-instance state: the selection lives in module globals.
    __slots__ = ()

    def __getattr__(self, name: str):
        # Only invoked for names not found on the proxy itself.
        target = _resolve_backend()
        return getattr(target, name)

    def __repr__(self) -> str:
        active = _backend_name
        return f"<BackendProxy active={active}>"
26
+
27
+
28
# Single shared proxy instance handed out by ``get_backend()``.
_PROXY = _BackendProxy()
29
+
30
+
31
def _resolve_backend() -> object:
    """Return the concrete backend namespace, choosing one lazily if needed.

    On first use JAX is preferred (with 64-bit mode enabled) and NumPy is the
    fallback. Later calls return the cached choice.

    Raises
    ------
    RuntimeError
        If neither JAX nor NumPy is available.
    """
    global _active_backend, _backend_name
    if _active_backend is not None:
        return _active_backend

    if jax_backend.is_available():
        jax_backend.enable_x64()
        _active_backend, _backend_name = jax_backend.backend, "jax"
    elif numpy_backend.is_available():
        _active_backend, _backend_name = numpy_backend.backend, "numpy"
    else:
        raise RuntimeError(
            "Neither JAX nor NumPy is available. Please install one of them."
        )
    return _active_backend
47
+
48
+
49
def set_backend(name: str) -> None:
    """Select the backend: ``'jax'`` or ``'numpy'`` (case-insensitive).

    Requesting JAX when it is unavailable emits a ``RuntimeWarning`` and
    selects NumPy instead.

    Raises
    ------
    ValueError
        If *name* is neither ``'jax'`` nor ``'numpy'``.
    """
    global _active_backend, _backend_name
    name = name.lower()

    if name not in ("jax", "numpy"):
        raise ValueError(f"Unknown backend '{name}'. Use 'jax' or 'numpy'.")

    if name == "jax" and jax_backend.is_available():
        jax_backend.enable_x64()
        _active_backend, _backend_name = jax_backend.backend, "jax"
        return

    if name == "jax":
        # JAX was requested but cannot be used: warn, then fall through to NumPy.
        warnings.warn(
            "JAX backend requested but JAX is not available. Falling back to NumPy.",
            RuntimeWarning,
        )
    _active_backend, _backend_name = numpy_backend.backend, "numpy"
70
+
71
+
72
def get_backend() -> object:
    """Return the shared proxy for the active backend.

    Attribute access on the proxy is resolved at call time, so a module-level
    ``bt = get_backend()`` keeps working after ``set_backend()`` switches
    backends. Calling this also forces backend initialisation, which raises
    ``RuntimeError`` when no backend can be found.
    """
    proxy = _PROXY
    _resolve_backend()
    return proxy
81
+
82
+
83
def get_backend_name() -> str:
    """Return ``'jax'`` or ``'numpy'`` for the backend currently in use.

    Lazy backend selection runs first, so the name reflects the real choice
    rather than the pre-initialisation default.
    """
    get_backend()  # initialise before reading the name
    name = _backend_name
    return name
87
+
88
+
89
+ # Nothing is re-exported at module level: callers reach backend
90
+ # functions/tools through the proxy returned by get_backend().
intensify/backends/jax_backend.py ADDED
@@ -0,0 +1,132 @@
1
+ """JAX backend implementation."""
2
+
3
+ from types import SimpleNamespace
4
+
5
# Attempt to import JAX. Every name bound in the ``try`` branch is also bound
# (to None) in the ``except`` branch, so the module exposes the same
# attributes whether or not JAX is installed.
try:
    import jax
    import jax.lax
    import jax.numpy as jnp
    import jax.random as random
    from jax import config as jax_config
    from jax import grad, jit, pmap, value_and_grad, vmap
    HAS_JAX = True
    IMPORT_ERROR = None  # nothing to report: the import succeeded
except ImportError as e:  # pragma: no cover
    HAS_JAX = False
    jnp = None
    jax = None
    jax_config = None
    grad = None
    jit = None
    random = None
    value_and_grad = None
    vmap = None
    pmap = None
    # Keep the original exception so callers can report why JAX is missing.
    IMPORT_ERROR = e
25
+
26
+
27
def enable_x64():
    """Switch JAX to 64-bit precision; does nothing when JAX is not installed."""
    if not HAS_JAX:
        return
    jax_config.update("jax_enable_x64", True)


def is_available():
    """Report whether JAX was imported successfully when this module loaded."""
    available = HAS_JAX
    return available
35
+
36
+
37
# Attribute namespace exposing the subset of jax / jax.numpy the library uses.
# Every entry falls back to None when JAX is not importable. It offers the
# same array helpers as the NumPy backend, including ``sort``, so code written
# against ``bt.sort`` works on either backend.
backend = SimpleNamespace(
    # Array types and constructors
    array=lambda *args, **kwargs: jnp.array(*args, **kwargs) if HAS_JAX else None,
    asarray=lambda *args, **kwargs: jnp.asarray(*args, **kwargs) if HAS_JAX else None,
    zeros=lambda *args, **kwargs: jnp.zeros(*args, **kwargs) if HAS_JAX else None,
    ones=lambda *args, **kwargs: jnp.ones(*args, **kwargs) if HAS_JAX else None,
    zeros_like=lambda *args, **kwargs: jnp.zeros_like(*args, **kwargs) if HAS_JAX else None,
    ones_like=lambda *args, **kwargs: jnp.ones_like(*args, **kwargs) if HAS_JAX else None,
    full=lambda *args, **kwargs: jnp.full(*args, **kwargs) if HAS_JAX else None,
    arange=lambda *args, **kwargs: jnp.arange(*args, **kwargs) if HAS_JAX else None,
    linspace=lambda *args, **kwargs: jnp.linspace(*args, **kwargs) if HAS_JAX else None,
    # Basic math ops
    exp=jnp.exp if HAS_JAX else None,
    log=jnp.log if HAS_JAX else None,
    log1p=jnp.log1p if HAS_JAX else None,
    sqrt=jnp.sqrt if HAS_JAX else None,
    square=jnp.square if HAS_JAX else None,
    power=jnp.power if HAS_JAX else None,
    abs=jnp.abs if HAS_JAX else None,
    sign=jnp.sign if HAS_JAX else None,
    sin=jnp.sin if HAS_JAX else None,
    cos=jnp.cos if HAS_JAX else None,
    tan=jnp.tan if HAS_JAX else None,
    # Reduction ops
    sum=jnp.sum if HAS_JAX else None,
    mean=jnp.mean if HAS_JAX else None,
    max=jnp.max if HAS_JAX else None,
    min=jnp.min if HAS_JAX else None,
    std=jnp.std if HAS_JAX else None,
    var=jnp.var if HAS_JAX else None,
    prod=jnp.prod if HAS_JAX else None,
    any=jnp.any if HAS_JAX else None,
    all=jnp.all if HAS_JAX else None,
    # Array manipulation
    concatenate=jnp.concatenate if HAS_JAX else None,
    stack=jnp.stack if HAS_JAX else None,
    reshape=jnp.reshape if HAS_JAX else None,
    transpose=jnp.transpose if HAS_JAX else None,
    swapaxes=jnp.swapaxes if HAS_JAX else None,
    squeeze=jnp.squeeze if HAS_JAX else None,
    expand_dims=jnp.expand_dims if HAS_JAX else None,
    split=jnp.split if HAS_JAX else None,
    roll=jnp.roll if HAS_JAX else None,
    repeat=jnp.repeat if HAS_JAX else None,
    tile=jnp.tile if HAS_JAX else None,
    # Indexing and slicing
    take=jnp.take if HAS_JAX else None,
    where=jnp.where if HAS_JAX else None,
    select=jnp.select if HAS_JAX else None,
    # Sorting (parity with the NumPy backend's ``sort`` entry)
    sort=jnp.sort if HAS_JAX else None,
    # Linear algebra
    dot=jnp.dot if HAS_JAX else None,
    matmul=jnp.matmul if HAS_JAX else None,
    inner=jnp.inner if HAS_JAX else None,
    outer=jnp.outer if HAS_JAX else None,
    cross=jnp.cross if HAS_JAX else None,
    linalg=jax.numpy.linalg if HAS_JAX else None,
    # Statistics
    median=jnp.median if HAS_JAX else None,
    quantile=jnp.quantile if HAS_JAX else None,
    # Array operations
    diff=jnp.diff if HAS_JAX else None,
    cumsum=jnp.cumsum if HAS_JAX else None,
    # Random
    random=random if HAS_JAX else None,
    PRNGKey=jax.random.PRNGKey if HAS_JAX else None,
    # Control flow and transformations
    lax=jax.lax if HAS_JAX else None,
    jit=jit if HAS_JAX else None,
    grad=grad if HAS_JAX else None,
    value_and_grad=value_and_grad if HAS_JAX else None,
    vmap=vmap if HAS_JAX else None,
    pmap=pmap if HAS_JAX else None,
    cond=jax.lax.cond if HAS_JAX else None,
    scan=jax.lax.scan if HAS_JAX else None,
    # Utility functions
    dtype=jnp.dtype if HAS_JAX else None,
    finfo=jnp.finfo if HAS_JAX else None,
    iinfo=jnp.iinfo if HAS_JAX else None,
    isfinite=jnp.isfinite if HAS_JAX else None,
    isnan=jnp.isnan if HAS_JAX else None,
    isinf=jnp.isinf if HAS_JAX else None,
    allclose=jnp.allclose if HAS_JAX else None,
    array_equal=jnp.array_equal if HAS_JAX else None,
    # Configuration
    config=jax_config if HAS_JAX else None,
    enable_x64=enable_x64,
    is_available=is_available,
    scipy_optimize=None,
)
126
+
127
# Names exported by ``from .jax_backend import *``. ``HAS_JAX`` lets callers
# check availability without calling ``is_available()``.
__all__ = [
    "backend",
    "enable_x64",
    "is_available",
    "HAS_JAX",
]
intensify/backends/numpy_backend.py ADDED
@@ -0,0 +1,144 @@
1
+ """NumPy backend implementation (fallback when JAX is unavailable)."""
2
+
3
+ from types import SimpleNamespace
4
+
5
# NumPy is the hard requirement of this backend. ``np`` is None when it cannot
# be imported, and ``is_available()`` reports that through HAS_NUMPY.
try:
    import numpy as np
except ImportError:  # pragma: no cover
    np = None  # type: ignore[assignment]

# SciPy is optional: when missing, the backend's ``scipy_optimize`` slot is None.
try:
    import scipy.optimize as opt
except ImportError:  # pragma: no cover
    opt = None

HAS_NUMPY = np is not None
16
+
17
+
18
def is_available() -> bool:
    """Return True when NumPy was imported, i.e. this backend is usable."""
    available = HAS_NUMPY
    return available
20
+
21
+
22
def _noop_random(name: str):
    """Build a placeholder for ``random.<name>`` that fails loudly when called.

    Fills the random namespace when NumPy is missing, so the attribute exists
    but any call raises ``RuntimeError``.
    """
    message = f"numpy backend random.{name} called but NumPy is not available."

    def _f(*_a, **_k):
        raise RuntimeError(message)

    return _f
29
+
30
+
31
def _shaped_draw(draw, size, shape, dtype):
    """Run ``draw`` with the requested output size and optionally cast it.

    ``size`` (NumPy spelling) takes precedence over ``shape`` (JAX spelling).
    When both are None, a single scalar sample is drawn. The sample passes
    through ``np.asarray`` before ``astype`` because NumPy returns a plain
    Python float for scalar draws, and a Python float has no ``astype`` method.
    """
    sample = draw(size if size is not None else shape)
    if dtype:
        sample = np.asarray(sample).astype(dtype)
    return sample


if HAS_NUMPY:
    # JAX-shaped random API backed by NumPy's module-level RNG functions.
    # Keys are accepted for signature compatibility but ignored (PRNGKey
    # returns None), so seeding through PRNGKey has no effect on this backend.
    _random_ns = SimpleNamespace(
        PRNGKey=lambda seed: None,
        split=lambda key, num: [None] * num,
        uniform=lambda key=None, shape=None, low=0.0, high=1.0, size=None, dtype=None: (
            _shaped_draw(lambda n: np.random.uniform(low, high, n), size, shape, dtype)
        ),
        normal=lambda key=None, shape=None, loc=0.0, scale=1.0, size=None, dtype=None: (
            _shaped_draw(lambda n: np.random.normal(loc, scale, n), size, shape, dtype)
        ),
        exponential=lambda key=None, shape=None, scale=1.0, size=None, dtype=None: (
            _shaped_draw(lambda n: np.random.exponential(scale, n), size, shape, dtype)
        ),
        poisson=lambda key=None, lam=1.0, shape=None: np.random.poisson(lam=lam, size=shape),
    )
else:  # pragma: no cover
    # Placeholders that raise RuntimeError when called without NumPy.
    _random_ns = SimpleNamespace(
        PRNGKey=_noop_random("PRNGKey"),
        split=_noop_random("split"),
        uniform=_noop_random("uniform"),
        normal=_noop_random("normal"),
        exponential=_noop_random("exponential"),
        poisson=_noop_random("poisson"),
    )
61
+
62
def _scan(f, init, xs, length=None):
    """Eager stand-in for ``jax.lax.scan`` on the NumPy backend.

    Applies ``carry, y = f(carry, x)`` to each slice of *xs* along its leading
    axis. A tuple/list *xs* is treated as a pytree of arrays sliced in
    lockstep. When *xs* is None, ``f`` is called ``length`` times with
    ``x=None``.

    Returns
    -------
    (carry, ys)
        The final carry and the per-step outputs stacked along a new leading
        axis. ``ys`` is a tuple of stacked arrays when ``f`` returns tuples,
        and None when ``f`` returns None.

    Raises
    ------
    ValueError
        If both *xs* and *length* are None.
    """
    if xs is None:
        if length is None:
            raise ValueError("scan requires `length` when `xs` is None")
        steps = [None] * length
    elif isinstance(xs, (tuple, list)):
        # Pytree of arrays: slice every leaf at the same step index.
        container = list if isinstance(xs, list) else tuple
        steps = [container(leaves) for leaves in zip(*xs)]
    else:
        steps = np.asarray(xs)

    carry = init
    outputs = []
    for x in steps:
        carry, y = f(carry, x)
        outputs.append(y)

    if not outputs:
        return carry, np.asarray(outputs)
    if outputs[0] is None:
        return carry, None
    if isinstance(outputs[0], tuple):
        return carry, tuple(np.stack(column) for column in zip(*outputs))
    return carry, np.stack(outputs)


# NumPy implementation of the backend interface. Names mirror the JAX backend.
# JAX-only transformations (grad, value_and_grad, vmap, pmap) are None here,
# ``jit`` is the identity, and ``lax`` provides eager Python equivalents.
backend = SimpleNamespace(
    array=lambda *args, **kwargs: np.array(*args, **kwargs) if HAS_NUMPY else None,
    asarray=lambda *args, **kwargs: np.asarray(*args, **kwargs) if HAS_NUMPY else None,
    zeros=lambda *args, **kwargs: np.zeros(*args, **kwargs) if HAS_NUMPY else None,
    ones=lambda *args, **kwargs: np.ones(*args, **kwargs) if HAS_NUMPY else None,
    zeros_like=lambda *args, **kwargs: np.zeros_like(*args, **kwargs) if HAS_NUMPY else None,
    ones_like=lambda *args, **kwargs: np.ones_like(*args, **kwargs) if HAS_NUMPY else None,
    full=lambda *args, **kwargs: np.full(*args, **kwargs) if HAS_NUMPY else None,
    arange=lambda *args, **kwargs: np.arange(*args, **kwargs) if HAS_NUMPY else None,
    linspace=lambda *args, **kwargs: np.linspace(*args, **kwargs) if HAS_NUMPY else None,
    exp=np.exp if HAS_NUMPY else None,
    log=np.log if HAS_NUMPY else None,
    log1p=np.log1p if HAS_NUMPY else None,
    power=np.power if HAS_NUMPY else None,
    sqrt=np.sqrt if HAS_NUMPY else None,
    square=np.square if HAS_NUMPY else None,
    abs=np.abs if HAS_NUMPY else None,
    sign=np.sign if HAS_NUMPY else None,
    sin=np.sin if HAS_NUMPY else None,
    cos=np.cos if HAS_NUMPY else None,
    tan=np.tan if HAS_NUMPY else None,
    sum=np.sum if HAS_NUMPY else None,
    mean=np.mean if HAS_NUMPY else None,
    max=np.max if HAS_NUMPY else None,
    min=np.min if HAS_NUMPY else None,
    std=np.std if HAS_NUMPY else None,
    var=np.var if HAS_NUMPY else None,
    prod=np.prod if HAS_NUMPY else None,
    any=np.any if HAS_NUMPY else None,
    all=np.all if HAS_NUMPY else None,
    concatenate=np.concatenate if HAS_NUMPY else None,
    stack=np.stack if HAS_NUMPY else None,
    reshape=np.reshape if HAS_NUMPY else None,
    transpose=np.transpose if HAS_NUMPY else None,
    swapaxes=np.swapaxes if HAS_NUMPY else None,
    squeeze=np.squeeze if HAS_NUMPY else None,
    expand_dims=np.expand_dims if HAS_NUMPY else None,
    split=np.split if HAS_NUMPY else None,
    roll=np.roll if HAS_NUMPY else None,
    repeat=np.repeat if HAS_NUMPY else None,
    tile=np.tile if HAS_NUMPY else None,
    take=np.take if HAS_NUMPY else None,
    where=np.where if HAS_NUMPY else None,
    select=np.select if HAS_NUMPY else None,
    sort=np.sort if HAS_NUMPY else None,
    dot=np.dot if HAS_NUMPY else None,
    matmul=np.matmul if HAS_NUMPY else None,
    inner=np.inner if HAS_NUMPY else None,
    outer=np.outer if HAS_NUMPY else None,
    cross=np.cross if HAS_NUMPY else None,
    linalg=np.linalg if HAS_NUMPY else None,
    diff=np.diff if HAS_NUMPY else None,
    cumsum=np.cumsum if HAS_NUMPY else None,
    median=np.median if HAS_NUMPY else None,
    quantile=np.quantile if HAS_NUMPY else None,
    random=_random_ns,
    lax=SimpleNamespace(
        # Eager Python branch in place of jax.lax.cond.
        cond=lambda pred, true_fun, false_fun, operand: (
            true_fun(operand) if pred else false_fun(operand)
        ),
        # Executes f step by step, like jax.lax.scan (see _scan).
        scan=_scan,
    ),
    # Identity "compilation". Options jax.jit accepts (e.g. static_argnums via
    # functools.partial) are accepted and ignored.
    jit=lambda f, *args, **kwargs: f,
    grad=None,
    value_and_grad=None,
    vmap=None,
    pmap=None,
    dtype=np.dtype if HAS_NUMPY else None,
    finfo=np.finfo if HAS_NUMPY else None,
    iinfo=np.iinfo if HAS_NUMPY else None,
    isfinite=np.isfinite if HAS_NUMPY else None,
    isnan=np.isnan if HAS_NUMPY else None,
    isinf=np.isinf if HAS_NUMPY else None,
    allclose=np.allclose if HAS_NUMPY else None,
    array_equal=np.array_equal if HAS_NUMPY else None,
    scipy_optimize=opt,
    is_available=is_available,
)
140
+
141
# Names exported by ``from .numpy_backend import *``. Unlike the JAX backend,
# this module defines no module-level ``enable_x64``.
__all__ = [
    "backend",
    "is_available",
]
intensify/core/__init__.py ADDED
@@ -0,0 +1,57 @@
1
+ """Core abstractions and implementations."""
2
+
3
+ from .base import PointProcess, PointProcessBase
4
+ from .diagnostics import *
5
+ from .inference import *
6
+ from .kernels import *
7
+ from .processes import *
8
+ from .regularizers import L1, ElasticNet, Regularizer
9
+ from .simulation import *
10
+
11
# Names exported by ``from intensify.core import *``.
# NOTE(review): most entries are bound only through the star-imports of the
# submodules above (e.g. "EMInference", "InferenceEngine", "register_engine",
# "qq_plot"). Confirm each submodule actually exports them; any missing name
# makes the star-import of this package raise AttributeError.
__all__ = [
    "PointProcess",
    "PointProcessBase",
    # Kernel subclasses
    "Kernel",
    "ExponentialKernel",
    "SumExponentialKernel",
    "PowerLawKernel",
    "ApproxPowerLawKernel",
    "NonparametricKernel",
    # Process models
    "HomogeneousPoisson",
    "InhomogeneousPoisson",
    "Hawkes",
    "MultivariateHawkes",
    "MarkedHawkes",
    "NonlinearHawkes",
    "MultivariateNonlinearHawkes",
    "LogGaussianCoxProcess",
    "ShotNoiseCoxProcess",
    # Inference
    "InferenceEngine",
    "FitResult",
    "MLEInference",
    "EMInference",
    "OnlineInference",
    "BayesianInference",
    "get_inference_engine",
    "register_engine",
    # Simulation
    "ogata_thinning",
    "ogata_thinning_multivariate",
    "branching_simulation",
    "branching_simulation_multivariate",
    # Diagnostics
    "time_rescaling_test",
    "qq_plot",
    "residual_intensity_plot",
    "raw_residuals",
    "pearson_residuals",
    "compute_information_criteria",
    "branching_ratio",
    "endogeneity_index",
    # Regularizers
    "L1",
    "ElasticNet",
    "Regularizer",
]
intensify/core/base.py ADDED
@@ -0,0 +1,122 @@
1
+ """Abstract base classes for point process models."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Protocol, runtime_checkable
5
+
6
+ from ..backends import get_backend
7
+
8
# Module-level backend proxy. Caching it at import time is safe because the
# proxy resolves the active backend on every attribute access.
bt = get_backend()
9
+
10
+
11
@runtime_checkable
class PointProcess(Protocol):
    """Protocol defining the interface for point process models.

    Besides the three abstract methods, the protocol carries default
    implementations that explicit subclasses inherit: ``fit`` dispatches to
    an inference engine, ``get_params``/``set_params`` raise
    ``NotImplementedError``, and ``project_params`` is a no-op.
    """

    @abstractmethod
    def simulate(self, T: float, seed: int | None = None) -> bt.array:
        """Generate event times on interval [0, T].

        Parameters
        ----------
        T : float
            End of observation window.
        seed : int, optional
            Random seed for reproducibility.

        Returns
        -------
        events : jnp.ndarray or np.ndarray
            Sorted array of event timestamps in [0, T].
        """
        pass

    @abstractmethod
    def intensity(self, t: float, history: bt.array) -> float:
        """Evaluate conditional intensity function λ(t | history).

        Parameters
        ----------
        t : float
            Time at which to evaluate intensity.
        history : jnp.ndarray or np.ndarray
            Past event times before t.

        Returns
        -------
        intensity : float
            Conditional intensity value at time t.
        """
        pass

    @abstractmethod
    def log_likelihood(self, events: bt.array, T: float) -> float:
        """Compute log-likelihood of observed event sequence.

        Parameters
        ----------
        events : jnp.ndarray or np.ndarray
            Event timestamps on [0, T].
        T : float
            End of observation window.

        Returns
        -------
        ll : float
            Log-likelihood value.
        """
        pass

    def fit(self, events, T: float | None = None, method: str = "mle", **kwargs):
        """Fit process parameters to observed event data.

        Parameters
        ----------
        events : array-like or domain-specific data object
            Event timestamps. May also accept domain objects (SpikeTrainData, OrderBookStream).
        T : float, optional
            Observation window end time. If omitted, it is inferred as
            ``max(events)`` and a ``UserWarning`` is emitted.
        method : str
            Name of the inference engine, resolved via ``get_inference_engine``
            (e.g. 'mle', 'em', 'bayesian').
        **kwargs
            Forwarded unchanged to the engine's ``fit``.

        Returns
        -------
        result : FitResult
            Standardized container with fitted parameters and diagnostics.

        Raises
        ------
        ValueError
            If ``T`` is omitted and ``events`` is empty, so no window end can
            be inferred.
        """
        from .inference import get_inference_engine

        # Infer T if not provided
        if T is None:
            import warnings

            events_array = bt.asarray(events)
            # ``max()`` of an empty array raises an opaque reduction error;
            # fail with an actionable message instead.
            if events_array.size == 0:
                raise ValueError(
                    "Cannot infer T from an empty event sequence; pass T explicitly."
                )
            warnings.warn(
                "T not specified; inferring T = max(events). "
                "This may be incorrect if the observation window extends past the last event.",
                UserWarning,
            )
            T = float(events_array.max())

        engine = get_inference_engine(method)
        return engine.fit(self, events, T, **kwargs)

    def get_params(self) -> dict:
        """Return model parameters as a dict (for optimization)."""
        raise NotImplementedError

    def set_params(self, params: dict) -> None:
        """Set model parameters from a dict."""
        raise NotImplementedError

    def project_params(self) -> None:
        """Project parameters onto feasible set (e.g., enforce stationarity)."""
        pass  # default: nothing to project
113
+
114
+
115
class PointProcessBase(ABC, PointProcess):
    """Nominal base class for concrete point-process models.

    Explicitly subclasses the ``PointProcess`` protocol, so subclasses inherit
    its default helpers and must still implement the abstract ``simulate``,
    ``intensity`` and ``log_likelihood``. ``get_params``/``set_params`` raise
    ``NotImplementedError`` until a subclass overrides them.
    """

    def get_params(self) -> dict:
        """Return the model parameters as a dict (override in subclasses)."""
        raise NotImplementedError()

    def set_params(self, params: dict) -> None:
        """Load model parameters from a dict (override in subclasses)."""
        raise NotImplementedError()