pulse-engine 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. pulse_engine/__init__.py +0 -0
  2. pulse_engine/adapters/__init__.py +58 -0
  3. pulse_engine/adapters/audio_transcription.py +167 -0
  4. pulse_engine/adapters/batcher.py +36 -0
  5. pulse_engine/adapters/digital_news.py +128 -0
  6. pulse_engine/adapters/digital_news_metadata.py +536 -0
  7. pulse_engine/adapters/exceptions.py +10 -0
  8. pulse_engine/adapters/models.py +134 -0
  9. pulse_engine/adapters/opensearch_storage.py +160 -0
  10. pulse_engine/adapters/speech_content.py +130 -0
  11. pulse_engine/adapters/speech_metadata.py +374 -0
  12. pulse_engine/adapters/twitter.py +423 -0
  13. pulse_engine/adapters/youtube_downloader.py +186 -0
  14. pulse_engine/adapters/youtube_metadata.py +261 -0
  15. pulse_engine/api/__init__.py +0 -0
  16. pulse_engine/api/v1/__init__.py +0 -0
  17. pulse_engine/api/v1/auth.py +91 -0
  18. pulse_engine/api/v1/health.py +62 -0
  19. pulse_engine/api/v1/router.py +16 -0
  20. pulse_engine/chain_recovery.py +131 -0
  21. pulse_engine/cli/__init__.py +0 -0
  22. pulse_engine/cli/main.py +169 -0
  23. pulse_engine/cli/templates/cookiecutter.json +4 -0
  24. pulse_engine/cli/templates/pulse-{{cookiecutter.product_name}}/.gitignore +13 -0
  25. pulse_engine/cli/templates/pulse-{{cookiecutter.product_name}}/Dockerfile +32 -0
  26. pulse_engine/cli/templates/pulse-{{cookiecutter.product_name}}/pipeline.yaml +17 -0
  27. pulse_engine/cli/templates/pulse-{{cookiecutter.product_name}}/pyproject.toml +25 -0
  28. pulse_engine/cli/templates/pulse-{{cookiecutter.product_name}}/src/pulse_{{cookiecutter.product_slug}}/__init__.py +8 -0
  29. pulse_engine/cli/templates/pulse-{{cookiecutter.product_name}}/tests/__init__.py +0 -0
  30. pulse_engine/cli/templates/pulse-{{cookiecutter.product_name}}/tests/unit/__init__.py +0 -0
  31. pulse_engine/cli/templates/pulse-{{cookiecutter.product_name}}/tests/unit/test_manifest.py +15 -0
  32. pulse_engine/client.py +95 -0
  33. pulse_engine/config.py +157 -0
  34. pulse_engine/core/__init__.py +0 -0
  35. pulse_engine/core/error_handlers.py +64 -0
  36. pulse_engine/core/exceptions.py +67 -0
  37. pulse_engine/core/job_token.py +109 -0
  38. pulse_engine/core/logging.py +45 -0
  39. pulse_engine/core/scope.py +23 -0
  40. pulse_engine/core/security.py +130 -0
  41. pulse_engine/database.py +30 -0
  42. pulse_engine/dependencies.py +166 -0
  43. pulse_engine/deployment/__init__.py +0 -0
  44. pulse_engine/deployment/backend_deployment_repository.py +83 -0
  45. pulse_engine/deployment/backends/__init__.py +0 -0
  46. pulse_engine/deployment/backends/base.py +50 -0
  47. pulse_engine/deployment/backends/exceptions.py +20 -0
  48. pulse_engine/deployment/backends/native_lambda.py +125 -0
  49. pulse_engine/deployment/backends/prefect_ecs.py +116 -0
  50. pulse_engine/deployment/backends/prefect_k8s.py +131 -0
  51. pulse_engine/deployment/backends/registry.py +50 -0
  52. pulse_engine/deployment/infra_provisioner.py +285 -0
  53. pulse_engine/deployment/job_launcher.py +178 -0
  54. pulse_engine/deployment/models.py +48 -0
  55. pulse_engine/deployment/repository.py +54 -0
  56. pulse_engine/deployment/router.py +22 -0
  57. pulse_engine/deployment/schemas.py +18 -0
  58. pulse_engine/deployment/service.py +65 -0
  59. pulse_engine/extractor/__init__.py +0 -0
  60. pulse_engine/extractor/adapters/__init__.py +0 -0
  61. pulse_engine/extractor/base.py +48 -0
  62. pulse_engine/extractor/models.py +50 -0
  63. pulse_engine/extractor/orchestrator/__init__.py +15 -0
  64. pulse_engine/extractor/orchestrator/base.py +34 -0
  65. pulse_engine/extractor/orchestrator/noop.py +37 -0
  66. pulse_engine/extractor/orchestrator/prefect.py +163 -0
  67. pulse_engine/extractor/repository.py +163 -0
  68. pulse_engine/extractor/router.py +102 -0
  69. pulse_engine/extractor/schemas.py +93 -0
  70. pulse_engine/extractor/service.py +431 -0
  71. pulse_engine/extractor/stage_models.py +36 -0
  72. pulse_engine/extractor/stage_repository.py +109 -0
  73. pulse_engine/main.py +195 -0
  74. pulse_engine/mcp/__init__.py +0 -0
  75. pulse_engine/mcp/__main__.py +5 -0
  76. pulse_engine/mcp/server.py +108 -0
  77. pulse_engine/mcp/tools_jobs.py +159 -0
  78. pulse_engine/mcp/tools_kb.py +88 -0
  79. pulse_engine/mcp/tools_modules.py +115 -0
  80. pulse_engine/mcp/tools_pipelines.py +215 -0
  81. pulse_engine/mcp/tools_processor.py +208 -0
  82. pulse_engine/middleware/__init__.py +0 -0
  83. pulse_engine/middleware/rate_limit.py +144 -0
  84. pulse_engine/middleware/request_id.py +16 -0
  85. pulse_engine/middleware/security_headers.py +25 -0
  86. pulse_engine/middleware/tenant.py +90 -0
  87. pulse_engine/pipeline/__init__.py +0 -0
  88. pulse_engine/pipeline/config_parser.py +148 -0
  89. pulse_engine/pipeline/expression.py +268 -0
  90. pulse_engine/pipeline/models.py +98 -0
  91. pulse_engine/pipeline/repositories.py +224 -0
  92. pulse_engine/pipeline/router_modules.py +66 -0
  93. pulse_engine/pipeline/router_pipelines.py +198 -0
  94. pulse_engine/pipeline/schemas.py +200 -0
  95. pulse_engine/pipeline/service.py +250 -0
  96. pulse_engine/pipeline/translators/__init__.py +44 -0
  97. pulse_engine/pipeline/translators/airflow_status.py +11 -0
  98. pulse_engine/pipeline/translators/airflow_translator.py +22 -0
  99. pulse_engine/pipeline/translators/base.py +42 -0
  100. pulse_engine/pipeline/translators/prefect_status.py +93 -0
  101. pulse_engine/pipeline/translators/prefect_translator.py +195 -0
  102. pulse_engine/processor/__init__.py +0 -0
  103. pulse_engine/processor/base.py +36 -0
  104. pulse_engine/processor/core/__init__.py +0 -0
  105. pulse_engine/processor/core/analysis.py +148 -0
  106. pulse_engine/processor/core/chunking.py +158 -0
  107. pulse_engine/processor/core/prompts.py +340 -0
  108. pulse_engine/processor/core/topic_splitter.py +105 -0
  109. pulse_engine/processor/defaults/__init__.py +11 -0
  110. pulse_engine/processor/defaults/core_processor.py +12 -0
  111. pulse_engine/processor/defaults/postprocessor.py +12 -0
  112. pulse_engine/processor/defaults/preprocessor.py +12 -0
  113. pulse_engine/processor/llm/__init__.py +0 -0
  114. pulse_engine/processor/llm/provider.py +58 -0
  115. pulse_engine/processor/ocr/gemini.py +52 -0
  116. pulse_engine/processor/pipeline.py +107 -0
  117. pulse_engine/processor/postprocessor/__init__.py +0 -0
  118. pulse_engine/processor/postprocessor/embeddings.py +34 -0
  119. pulse_engine/processor/postprocessor/tasks.py +180 -0
  120. pulse_engine/processor/preprocessor/__init__.py +0 -0
  121. pulse_engine/processor/preprocessor/tasks.py +71 -0
  122. pulse_engine/processor/router.py +192 -0
  123. pulse_engine/processor/schemas.py +167 -0
  124. pulse_engine/registry.py +117 -0
  125. pulse_engine/runners/__init__.py +0 -0
  126. pulse_engine/runners/lambda_runner.py +26 -0
  127. pulse_engine/runners/pipeline_runner.py +43 -0
  128. pulse_engine/runners/prefect_pipeline_flow.py +904 -0
  129. pulse_engine/runners/prefect_runner.py +33 -0
  130. pulse_engine/s3.py +72 -0
  131. pulse_engine/secrets.py +46 -0
  132. pulse_engine/services/__init__.py +0 -0
  133. pulse_engine/services/bootstrap.py +211 -0
  134. pulse_engine/services/opensearch.py +84 -0
  135. pulse_engine/storage/__init__.py +0 -0
  136. pulse_engine/storage/connectors/__init__.py +0 -0
  137. pulse_engine/storage/connectors/athena.py +226 -0
  138. pulse_engine/storage/connectors/base.py +32 -0
  139. pulse_engine/storage/connectors/opensearch.py +344 -0
  140. pulse_engine/storage/knowledge_base.py +68 -0
  141. pulse_engine/storage/router.py +78 -0
  142. pulse_engine/storage/schemas.py +93 -0
  143. pulse_engine/testing/__init__.py +13 -0
  144. pulse_engine/testing/fixtures.py +50 -0
  145. pulse_engine/testing/mocks.py +104 -0
  146. pulse_engine/worker.py +53 -0
  147. pulse_engine-0.2.0.dist-info/METADATA +654 -0
  148. pulse_engine-0.2.0.dist-info/RECORD +150 -0
  149. pulse_engine-0.2.0.dist-info/WHEEL +4 -0
  150. pulse_engine-0.2.0.dist-info/entry_points.txt +4 -0
@@ -0,0 +1,33 @@
1
+ """Prefect flow wrapper for product containers.
2
+
3
+ This module is the ONLY place in the engine that imports Prefect.
4
+ Product containers must NOT import prefect directly — they expose a plain
5
+ `entrypoint:run` function and this runner wraps it with @flow at runtime.
6
+
7
+ All Prefect-based backends (ECS, Kubernetes) use this as the flow entrypoint:
8
+ pulse_engine.runners.prefect_runner:main
9
+ """
10
+
11
+ from typing import Any
12
+
13
+ from prefect import flow
14
+
15
+
16
@flow(name="pulse-stage-runner")  # type: ignore[untyped-decorator]
def main(
    job_id: str,
    chain: bool,
    config: dict[str, Any],
    pulse_api_token: str,
    pulse_engine_url: str,
) -> None:
    """Engine-owned Prefect flow that hands one stage over to the product.

    Product images expose a plain ``entrypoint.run`` and never import Prefect
    themselves; this flow forwards all five parameters to it unchanged, as
    keyword arguments.

    NOTE(review): ``pulse_api_token`` is passed as a flow parameter; confirm the
    orchestrator does not persist or display parameter values in plain text.
    """
    # Resolved at call time: the ``entrypoint`` module ships with the product
    # container, not with the engine package.
    from entrypoint import run as run_product_stage

    stage_kwargs = {
        "job_id": job_id,
        "chain": chain,
        "config": config,
        "pulse_api_token": pulse_api_token,
        "pulse_engine_url": pulse_engine_url,
    }
    run_product_stage(**stage_kwargs)
pulse_engine/s3.py ADDED
@@ -0,0 +1,72 @@
1
+ """S3-backed inter-stage data exchange with NDJSON format and _SUCCESS sentinel."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from typing import Any
7
+
8
+ from botocore.exceptions import ClientError
9
+
10
+
11
class StageDataNotReady(Exception):
    """Signal that an upstream stage has not published its ``_SUCCESS`` marker.

    Raised by the S3 stage reader when the sentinel object for a prefix is
    missing, i.e. the data for that prefix is not (yet) declared complete.
    """
13
+
14
+
15
class S3Stage:
    """Read/write NDJSON data between pipeline stages via S3.

    Each job uses two prefixes in the bucket, ``jobs/<job_id>/raw/`` and
    ``jobs/<job_id>/processed/``. A writer stores ``data.ndjson`` first and an
    empty ``_SUCCESS`` marker second. A reader refuses to read a prefix until
    that marker exists.

    Args:
        bucket: S3 bucket holding the stage data.
        job_id: Job identifier used to namespace the keys.
        s3_client: A boto3-style S3 client exposing ``put_object``,
            ``head_object`` and ``get_object``.
    """

    # Error codes meaning "key does not exist". HEAD responses carry no error
    # body, so real S3 typically reports the bare status "404"; S3-compatible
    # stores and test doubles may report the named codes instead.
    _MISSING_KEY_CODES = frozenset({"404", "NoSuchKey", "NotFound"})

    def __init__(self, bucket: str, job_id: str, s3_client: Any) -> None:
        self._bucket = bucket
        self._job_id = job_id
        self._s3 = s3_client

    @property
    def raw_prefix(self) -> str:
        """Key prefix for the extractor's raw output of this job."""
        return f"jobs/{self._job_id}/raw/"

    @property
    def processed_prefix(self) -> str:
        """Key prefix for the processor's output of this job."""
        return f"jobs/{self._job_id}/processed/"

    def write_raw(self, documents: list[dict[str, Any]]) -> None:
        """Publish *documents* under :attr:`raw_prefix`."""
        self._write(self.raw_prefix, documents)

    def write_processed(self, documents: list[dict[str, Any]]) -> None:
        """Publish *documents* under :attr:`processed_prefix`."""
        self._write(self.processed_prefix, documents)

    def read_raw(self) -> list[dict[str, Any]]:
        """Return the documents published under :attr:`raw_prefix`."""
        return self._read(self.raw_prefix)

    def read_processed(self) -> list[dict[str, Any]]:
        """Return the documents published under :attr:`processed_prefix`."""
        return self._read(self.processed_prefix)

    def _write(self, prefix: str, documents: list[dict[str, Any]]) -> None:
        """Write *documents* as NDJSON, then mark the prefix complete.

        The ``_SUCCESS`` marker is written strictly after the data object, so
        a reader that sees the marker also finds the data.
        """
        # json.dumps escapes embedded newlines, so one document == one line.
        ndjson = "\n".join(json.dumps(doc) for doc in documents)
        self._s3.put_object(
            Bucket=self._bucket,
            Key=f"{prefix}data.ndjson",
            Body=ndjson.encode(),
            ContentType="application/x-ndjson",
        )
        self._s3.put_object(
            Bucket=self._bucket,
            Key=f"{prefix}_SUCCESS",
            Body=b"",
        )

    def _read(self, prefix: str) -> list[dict[str, Any]]:
        """Return the documents under *prefix* once its marker exists.

        Raises:
            StageDataNotReady: The ``_SUCCESS`` marker for *prefix* is missing.
            ClientError: Any other S3 failure (permissions, throttling, ...),
                re-raised unchanged.
        """
        try:
            self._s3.head_object(Bucket=self._bucket, Key=f"{prefix}_SUCCESS")
        except ClientError as e:
            # .get() chain: a malformed error response must not turn into a
            # KeyError that hides the original S3 failure.
            error_code = str(e.response.get("Error", {}).get("Code", ""))
            if error_code in self._MISSING_KEY_CODES:
                msg = (
                    f"Sentinel {prefix}_SUCCESS not found"
                    " — upstream stage may not be complete"
                )
                raise StageDataNotReady(msg) from e
            raise  # re-raise permissions errors, network errors, etc.

        resp = self._s3.get_object(Bucket=self._bucket, Key=f"{prefix}data.ndjson")
        body = resp["Body"].read().decode()
        return [json.loads(line) for line in body.strip().split("\n") if line.strip()]
@@ -0,0 +1,46 @@
1
+ """AWS Secrets Manager utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+
7
+ import structlog
8
+
9
+ logger = structlog.get_logger(__name__)
10
+
11
+
12
async def fetch_secret(secret_id: str, region: str = "ap-south-1") -> str:
    """Fetch a plaintext secret string from AWS Secrets Manager.

    The blocking boto3 work runs in a worker thread via ``asyncio.to_thread``,
    so this is safe to await from async code.

    Args:
        secret_id: The secret name or ARN.
        region: AWS region for the Secrets Manager endpoint.

    Returns:
        The secret value as a plain string.

    Raises:
        RuntimeError: If the secret cannot be retrieved or is empty/binary.
            Retrieval failures cover both AWS API errors (``ClientError``) and
            botocore-level errors (``BotoCoreError``, e.g. missing credentials
            or an unreachable endpoint).
    """
    import boto3
    from botocore.exceptions import BotoCoreError, ClientError

    def _get() -> str:
        try:
            # Client construction is inside the try: it can itself raise
            # botocore errors (e.g. an unusable region configuration).
            client = boto3.client("secretsmanager", region_name=region)
            resp = client.get_secret_value(SecretId=secret_id)
        except (BotoCoreError, ClientError) as exc:
            # ClientError is not a BotoCoreError subclass, so both are listed.
            raise RuntimeError(
                f"Failed to fetch secret '{secret_id}' from Secrets Manager: {exc}"
            ) from exc
        # Binary secrets have no SecretString; treat them like empty ones.
        value: str = resp.get("SecretString") or ""
        if not value:
            raise RuntimeError(
                f"Secret '{secret_id}' is empty or binary — expected a plaintext string"
            )
        return value

    logger.debug("secrets_manager_fetch", secret_id=secret_id, region=region)
    return await asyncio.to_thread(_get)
File without changes
@@ -0,0 +1,211 @@
1
+ """Shared service initialization for FastAPI and MCP entry points."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import AsyncIterator
6
+ from contextlib import asynccontextmanager
7
+ from dataclasses import dataclass, field
8
+ from types import EllipsisType
9
+ from typing import TYPE_CHECKING, Any
10
+
11
+ import boto3
12
+
13
+ from pulse_engine.config import Settings
14
+ from pulse_engine.core.logging import setup_logging
15
+ from pulse_engine.database import build_async_engine, build_session_factory
16
+ from pulse_engine.extractor.orchestrator import get_orchestrator_adapter
17
+ from pulse_engine.processor.base import (
18
+ BaseCoreProcessor,
19
+ BasePostprocessor,
20
+ BasePreprocessor,
21
+ )
22
+ from pulse_engine.processor.pipeline import _SENTINEL, ProcessingPipeline
23
+ from pulse_engine.services.opensearch import OpenSearchService
24
+ from pulse_engine.storage.connectors.athena import AthenaConnector
25
+ from pulse_engine.storage.connectors.opensearch import OpenSearchConnector
26
+ from pulse_engine.storage.knowledge_base import KnowledgeBaseService
27
+
28
+ if TYPE_CHECKING:
29
+ from celery import Celery
30
+ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
31
+
32
+ from pulse_engine.extractor.orchestrator.base import BaseOrchestratorAdapter
33
+ from pulse_engine.registry import ProductManifest
34
+
35
+
36
@dataclass
class ServiceContainer:
    """Bundle of shared services produced by ``bootstrap_services``.

    Field order is the positional constructor order. Underscore-prefixed
    fields are internal handles and are excluded from ``repr``.
    """

    # --- always provided -------------------------------------------------
    settings: Settings
    opensearch: OpenSearchService
    kb_service: KnowledgeBaseService
    # None when no database_url is configured
    db_session_factory: async_sessionmaker[AsyncSession] | None
    orchestrator_adapter: BaseOrchestratorAdapter

    # --- optional integrations (absent unless configured) ----------------
    pipeline: ProcessingPipeline | None = field(default=None)
    celery_app: Celery | None = field(default=None)
    s3_client: Any = field(default=None)
    token_issuer: Any = field(default=None)
    # keyed by orchestrator name; a fresh dict per instance
    pipeline_translators: dict[str, Any] = field(default_factory=dict)
    pipeline_status_providers: dict[str, Any] = field(default_factory=dict)

    # --- internal handles (hidden from repr) -----------------------------
    _engine: AsyncEngine | None = field(repr=False, default=None)
    _opensearch_connector: OpenSearchConnector | None = field(repr=False, default=None)
    _athena_connector: AthenaConnector | None = field(repr=False, default=None)
52
+
53
+
54
+ def _resolve_stage(
55
+ value: Any,
56
+ ) -> Any:
57
+ """Convert manifest Ellipsis convention into pipeline sentinel."""
58
+ if isinstance(value, EllipsisType):
59
+ return _SENTINEL
60
+ return value
61
+
62
+
63
def _build_celery(
    settings: Settings, manifest: ProductManifest | None = None
) -> Celery:
    """Create the Celery application through ``pulse_engine.worker``.

    The worker module is imported lazily, inside this function, rather than
    at module import time.
    """
    from pulse_engine.worker import create_celery_app as make_celery_app

    celery_app = make_celery_app(settings, manifest)
    return celery_app
69
+
70
+
71
def _opensearch_http_auth(settings: Settings) -> tuple[str, str] | None:
    """Return OpenSearch basic-auth credentials, or ``None`` unless both are set."""
    if settings.opensearch_username and settings.opensearch_password:
        return (settings.opensearch_username, settings.opensearch_password)
    return None


def _build_athena_connector(
    settings: Settings, manifest: ProductManifest | None
) -> AthenaConnector | None:
    """Create (but do not initialize) the Athena connector, if configured.

    Athena is enabled only when ``athena_output_location`` is set. Dedicated
    Athena credentials are passed when ``athena_aws_access_key_id`` is set;
    otherwise no explicit credentials are given to boto3.
    """
    if not settings.athena_output_location:
        return None
    athena_kwargs: dict[str, str] = {"region_name": settings.aws_region}
    if settings.athena_aws_access_key_id:
        athena_kwargs["aws_access_key_id"] = settings.athena_aws_access_key_id
        athena_kwargs["aws_secret_access_key"] = (
            settings.athena_aws_secret_access_key
        )
    athena_client = boto3.client("athena", **athena_kwargs)  # type: ignore[call-overload]
    return AthenaConnector(
        athena_client=athena_client,
        # The product names its Athena database; without a manifest no
        # connector-level default database is set.
        database=manifest.athena_database if manifest else "",
        output_location=settings.athena_output_location,
        workgroup=settings.athena_workgroup,
        query_timeout_seconds=settings.athena_query_timeout_seconds,
    )


def _build_s3_client(settings: Settings) -> Any:
    """Return an S3 client for inter-stage data, or ``None`` without a bucket."""
    if not settings.pulse_s3_bucket:
        return None
    return boto3.client(
        "s3",
        region_name=settings.aws_region,
        aws_access_key_id=settings.aws_access_key_id or None,
        aws_secret_access_key=settings.aws_secret_access_key or None,
    )


def _build_token_issuer(settings: Settings) -> Any:
    """Return a ``JobTokenIssuer`` when a job-token secret is configured."""
    if not settings.pulse_job_token_secret:
        return None
    from pulse_engine.core.job_token import JobTokenIssuer

    return JobTokenIssuer(secret=settings.pulse_job_token_secret)


def _build_orchestration_integrations(
    settings: Settings,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Return ``(translators, status_providers)`` keyed by orchestrator name.

    Only Prefect is wired up, and only when ``prefect_api_url`` is set; the
    Prefect translator modules are imported only in that case.
    """
    translators: dict[str, Any] = {}
    status_providers: dict[str, Any] = {}
    if settings.prefect_api_url:
        from pulse_engine.pipeline.translators.prefect_status import (
            PrefectStatusProvider,
        )
        from pulse_engine.pipeline.translators.prefect_translator import (
            PrefectTranslator,
        )

        translators["prefect"] = PrefectTranslator(
            prefect_api_url=settings.prefect_api_url,
            prefect_api_key=settings.prefect_api_key,
            work_pool_name=settings.prefect_ecs_work_pool_name,
            engine_image=settings.prefect_engine_image,
        )
        status_providers["prefect"] = PrefectStatusProvider(
            prefect_api_url=settings.prefect_api_url,
            prefect_api_key=settings.prefect_api_key,
        )
    return translators, status_providers


def _build_processing_pipeline(
    kb_service: KnowledgeBaseService, manifest: ProductManifest | None
) -> ProcessingPipeline:
    """Assemble the processing pipeline from the manifest's stage entries.

    Without a manifest every stage is ``_SENTINEL``. With a manifest, each
    stage comes from ``_resolve_stage``, which maps ``...`` to ``_SENTINEL``.
    Presumably ``_SENTINEL`` means "use the default stage"; confirm in
    ``ProcessingPipeline``.
    """
    pre: BasePreprocessor | None = _SENTINEL
    core: BaseCoreProcessor | None = _SENTINEL
    post: BasePostprocessor | None = _SENTINEL
    if manifest:
        pre = _resolve_stage(manifest.preprocessor)
        core = _resolve_stage(manifest.core_processor)
        post = _resolve_stage(manifest.postprocessor)
    return ProcessingPipeline(
        kb_service=kb_service,
        preprocessor=pre,
        core_processor=core,
        postprocessor=post,
    )


@asynccontextmanager
async def bootstrap_services(
    settings: Settings,
    manifest: ProductManifest | None = None,
) -> AsyncIterator[ServiceContainer]:
    """Initialize shared services, yield them, then release them.

    Each resource that needs releasing is registered on an ``AsyncExitStack``
    as soon as it exists. As a result:

    - a failure part-way through start-up (e.g. a connector's ``initialize()``
      raising) still closes everything created before it;
    - one failing cleanup step does not skip the others.

    Cleanup runs in reverse creation order: database engine, Athena
    connector, OpenSearch connector, OpenSearch client.

    Args:
        settings: Engine configuration.
        manifest: Optional product manifest supplying the Athena database and
            processing-stage overrides.

    Yields:
        The fully built ``ServiceContainer``.
    """
    from contextlib import AsyncExitStack

    setup_logging(log_level=settings.log_level, env=settings.app_env)

    async with AsyncExitStack() as cleanup:
        opensearch = OpenSearchService(
            url=settings.opensearch_url,
            http_auth=_opensearch_http_auth(settings),
            use_ssl=settings.opensearch_use_ssl,
            verify_certs=settings.opensearch_verify_certs,
        )
        cleanup.push_async_callback(opensearch.close)

        opensearch_connector = OpenSearchConnector(
            opensearch,
            settings.embedding_dimension,
            index_prefix=settings.opensearch_index_prefix,
        )
        await opensearch_connector.initialize()
        cleanup.push_async_callback(opensearch_connector.teardown)

        athena_connector = _build_athena_connector(settings, manifest)
        if athena_connector is not None:
            await athena_connector.initialize()
            cleanup.push_async_callback(athena_connector.teardown)

        kb_service = KnowledgeBaseService(opensearch_connector, athena_connector)

        engine = None
        db_session_factory = None
        if settings.database_url:
            engine = build_async_engine(settings.database_url)
            cleanup.push_async_callback(engine.dispose)
            db_session_factory = build_session_factory(engine)

        orchestrator_adapter = get_orchestrator_adapter(settings)
        celery_app = _build_celery(settings, manifest) if settings.redis_url else None
        s3_client = _build_s3_client(settings)
        token_issuer = _build_token_issuer(settings)
        pipeline_translators, pipeline_status_providers = (
            _build_orchestration_integrations(settings)
        )
        pipeline = _build_processing_pipeline(kb_service, manifest)

        yield ServiceContainer(
            settings=settings,
            opensearch=opensearch,
            kb_service=kb_service,
            db_session_factory=db_session_factory,
            orchestrator_adapter=orchestrator_adapter,
            pipeline=pipeline,
            celery_app=celery_app,
            s3_client=s3_client,
            token_issuer=token_issuer,
            pipeline_translators=pipeline_translators,
            pipeline_status_providers=pipeline_status_providers,
            _engine=engine,
            _opensearch_connector=opensearch_connector,
            _athena_connector=athena_connector,
        )
@@ -0,0 +1,84 @@
1
+ from typing import Any
2
+
3
+ import structlog
4
+ from opensearchpy import AsyncOpenSearch
5
+
6
+ logger = structlog.get_logger()
7
+
8
+
9
+ class OpenSearchService:
10
+ def __init__(
11
+ self,
12
+ url: str,
13
+ http_auth: tuple[str, str] | None = None,
14
+ use_ssl: bool = False,
15
+ verify_certs: bool = True,
16
+ ssl_show_warn: bool = True,
17
+ ) -> None:
18
+ hosts = [url]
19
+ kwargs: dict[str, Any] = {
20
+ "hosts": hosts,
21
+ "use_ssl": use_ssl,
22
+ "verify_certs": verify_certs,
23
+ "ssl_show_warn": ssl_show_warn,
24
+ }
25
+ if http_auth:
26
+ kwargs["http_auth"] = http_auth
27
+
28
+ self._client = AsyncOpenSearch(**kwargs)
29
+
30
+ async def ping(self) -> bool:
31
+ try:
32
+ return bool(await self._client.ping())
33
+ except Exception:
34
+ logger.warning("opensearch_ping_failed")
35
+ return False
36
+
37
+ async def index_document(
38
+ self, index: str, doc_id: str | None, body: dict[str, Any]
39
+ ) -> dict[str, Any]:
40
+ result: dict[str, Any] = await self._client.index(
41
+ index=index, id=doc_id, body=body
42
+ )
43
+ return result
44
+
45
+ async def get_document(self, index: str, doc_id: str) -> dict[str, Any]:
46
+ result: dict[str, Any] = await self._client.get(index=index, id=doc_id)
47
+ return result
48
+
49
+ async def search(self, index: str, query: dict[str, Any]) -> dict[str, Any]:
50
+ result: dict[str, Any] = await self._client.search(index=index, body=query)
51
+ return result
52
+
53
+ async def delete_document(self, index: str, doc_id: str) -> dict[str, Any]:
54
+ result: dict[str, Any] = await self._client.delete(index=index, id=doc_id)
55
+ return result
56
+
57
+ async def create_index(
58
+ self, index: str, body: dict[str, Any]
59
+ ) -> dict[str, Any] | None:
60
+ if await self._client.indices.exists(index=index):
61
+ logger.info("index_already_exists", index=index)
62
+ return None
63
+ result: dict[str, Any] = await self._client.indices.create(
64
+ index=index, body=body
65
+ )
66
+ return result
67
+
68
+ async def delete_index(self, index: str) -> dict[str, Any]:
69
+ result: dict[str, Any] = await self._client.indices.delete(index=index)
70
+ return result
71
+
72
+ async def bulk(self, body: list[dict[str, Any]]) -> dict[str, Any]:
73
+ result: dict[str, Any] = await self._client.bulk(body=body)
74
+ return result
75
+
76
+ async def index_stats(self, index: str) -> dict[str, Any]:
77
+ result: dict[str, Any] = await self._client.indices.stats(index=index)
78
+ return result
79
+
80
+ async def index_exists(self, index: str) -> bool:
81
+ return bool(await self._client.indices.exists(index=index))
82
+
83
+ async def close(self) -> None:
84
+ await self._client.close()
File without changes
File without changes
@@ -0,0 +1,226 @@
1
+ import asyncio
2
+ import re
3
+ import time
4
+ from typing import Any
5
+
6
+ import structlog
7
+
8
+ from pulse_engine.core.exceptions import (
9
+ AppError,
10
+ BadRequestError,
11
+ ServiceUnavailableError,
12
+ )
13
+ from pulse_engine.storage.connectors.base import BaseStorageConnector
14
+ from pulse_engine.storage.schemas import (
15
+ ConnectorHealth,
16
+ Document,
17
+ QueryResult,
18
+ SearchQuery,
19
+ SearchResult,
20
+ StoreResult,
21
+ )
22
+
23
+ logger = structlog.get_logger()
24
+
25
+ _TENANT_ID_RE = re.compile(r"^[a-zA-Z0-9_-]{1,128}$")
26
+
27
+ # SQL keywords that are forbidden in user-supplied queries
28
+ _FORBIDDEN_SQL_RE = re.compile(
29
+ r"""
30
+ \b(
31
+ INSERT\s+INTO | UPDATE\s+\S+\s+SET | DELETE\s+FROM |
32
+ DROP\s+(TABLE|DATABASE|INDEX|VIEW|SCHEMA) |
33
+ ALTER\s+(TABLE|DATABASE|INDEX|VIEW|SCHEMA) |
34
+ CREATE\s+(TABLE|DATABASE|INDEX|VIEW|SCHEMA) |
35
+ TRUNCATE | GRANT | REVOKE | EXEC | EXECUTE |
36
+ MERGE\s+INTO | CALL
37
+ )\b
38
+ """,
39
+ re.IGNORECASE | re.VERBOSE,
40
+ )
41
+
42
+ # Detect multiple SQL statements (semicolons not inside string literals)
43
+ _MULTI_STATEMENT_RE = re.compile(r";\s*\S")
44
+
45
+ # Detect UNION-based injection attempts
46
+ _UNION_RE = re.compile(r"\bUNION\b\s+(ALL\s+)?\bSELECT\b", re.IGNORECASE)
47
+
48
+
49
# Keyword searches below run on a masked copy of the SQL (see _mask_nested_sql),
# so they only see top-level clauses.
_WHERE_RE = re.compile(r"\bWHERE\b", re.IGNORECASE)
_TRAILING_CLAUSE_RE = re.compile(
    r"\b(?:GROUP\s+BY|ORDER\s+BY|HAVING|LIMIT|OFFSET|FETCH)\b", re.IGNORECASE
)
# Comments could swallow the appended tenant filter (e.g. a trailing "--").
_SQL_COMMENT_RE = re.compile(r"--|/\*")
# Any set operation lets a second SELECT escape the single injected filter.
_SET_OPERATION_RE = re.compile(r"\b(?:UNION|EXCEPT|INTERSECT)\b", re.IGNORECASE)


def _mask_nested_sql(sql: str) -> str:
    """Return *sql* with quoted text and parenthesised spans blanked out.

    The result has exactly the same length as *sql*, so match offsets found
    in it can be used to slice the original. Only top-level tokens survive.
    This keeps keyword searches from matching inside:

    - string literals and quoted identifiers;
    - sub-queries;
    - ``OVER (ORDER BY ...)`` windows.

    A doubled quote inside a literal (``''``) is treated as an escaped quote.
    """
    masked: list[str] = []
    depth = 0
    quote: str | None = None
    i = 0
    while i < len(sql):
        ch = sql[i]
        if quote is not None:
            if ch == quote and sql[i + 1 : i + 2] == quote:
                masked.append("  ")
                i += 2
                continue
            if ch == quote:
                quote = None
            masked.append(" ")
        elif ch in "'\"`":
            quote = ch
            masked.append(" ")
        elif ch == "(":
            depth += 1
            masked.append(" ")
        elif ch == ")":
            depth = max(depth - 1, 0)
            masked.append(" ")
        else:
            masked.append(" " if depth else ch)
        i += 1
    return "".join(masked)


class AthenaConnector(BaseStorageConnector):
    """Tenant-scoped, read-only SQL access to Athena.

    Only :meth:`execute_query` and :meth:`health_check` do real work; the
    document-style operations raise ``NotImplementedError``. Client calls are
    blocking and are run via ``asyncio.to_thread``.
    """

    def __init__(
        self,
        athena_client: Any,
        database: str,
        output_location: str,
        workgroup: str = "primary",
        query_timeout_seconds: int = 60,
        max_result_rows: int = 10_000,
    ) -> None:
        """Configure the connector.

        Args:
            athena_client: A boto3-style Athena client.
            database: Default database; may be empty, in which case only a
                per-query ``database`` is used.
            output_location: S3 location for Athena query results.
            workgroup: Athena workgroup to run queries in.
            query_timeout_seconds: Wall-clock budget for a query to finish.
            max_result_rows: Upper bound on rows collected across result
                pages; further rows are dropped with a warning.
        """
        self._client = athena_client
        self._database = database
        self._output_location = output_location
        self._workgroup = workgroup
        self._query_timeout = query_timeout_seconds
        self._max_result_rows = max_result_rows

    async def initialize(self) -> None:
        logger.info("athena_connector_initialized")

    async def store(self, tenant_id: str, documents: list[Document]) -> StoreResult:
        raise NotImplementedError("Athena connector does not support store operations")

    async def retrieve(self, tenant_id: str, doc_id: str) -> Document | None:
        raise NotImplementedError(
            "Athena connector does not support retrieve operations"
        )

    async def search(self, tenant_id: str, query: SearchQuery) -> SearchResult:
        raise NotImplementedError("Athena connector does not support search operations")

    @staticmethod
    def _validate_sql(sql: str) -> None:
        """Reject dangerous SQL patterns before execution.

        Only single SELECT statements are permitted. Rejected:

        - DDL and DML mutations;
        - multi-statement payloads;
        - SQL comments, which could hide the tenant filter;
        - set operations (UNION / EXCEPT / INTERSECT).

        Raises:
            BadRequestError: The query is not allowed.
        """
        stripped = sql.strip().rstrip(";").strip()

        if not stripped.upper().startswith("SELECT"):
            raise BadRequestError("Only SELECT queries are allowed against Athena")

        if _FORBIDDEN_SQL_RE.search(stripped):
            raise BadRequestError("Query contains forbidden SQL keywords")

        if _MULTI_STATEMENT_RE.search(stripped):
            raise BadRequestError("Multiple SQL statements are not allowed")

        if _UNION_RE.search(stripped):
            raise BadRequestError("UNION SELECT queries are not allowed")

        if _SQL_COMMENT_RE.search(stripped):
            raise BadRequestError("SQL comments are not allowed")

        if _SET_OPERATION_RE.search(stripped):
            raise BadRequestError(
                "Set operations (UNION, EXCEPT, INTERSECT) are not allowed"
            )

    @staticmethod
    def _inject_tenant_filter(sql: str, tenant_id: str) -> str:
        """Restrict the top-level query to rows of *tenant_id*.

        If the top-level query has a WHERE clause, its condition is wrapped in
        parentheses and AND-ed with the tenant predicate. Without the
        parentheses, a condition like ``a OR b`` would bind as
        ``tenant AND a OR b`` and leak other tenants' rows. Otherwise a WHERE
        clause is inserted before the first top-level
        GROUP BY / ORDER BY / HAVING / LIMIT / OFFSET / FETCH, or appended.

        Keywords inside string literals, quoted identifiers and parentheses
        are ignored. Trailing semicolons are dropped so the filter is not
        placed after the statement terminator.

        Raises:
            BadRequestError: *tenant_id* has an invalid format.
        """
        if not _TENANT_ID_RE.match(tenant_id):
            raise BadRequestError("Invalid tenant ID format")
        tenant_clause = f"tenant_id = '{tenant_id}'"

        sql = sql.strip()
        while sql.endswith(";"):
            sql = sql[:-1].rstrip()
        masked = _mask_nested_sql(sql)

        where = _WHERE_RE.search(masked)
        if where:
            tail = _TRAILING_CLAUSE_RE.search(masked, where.end())
            cond_end = tail.start() if tail else len(sql)
            condition = sql[where.end() : cond_end].strip()
            rest = sql[cond_end:]
            filtered = f"{sql[: where.start()]}WHERE {tenant_clause} AND ({condition})"
            return f"{filtered} {rest}" if rest else filtered

        tail = _TRAILING_CLAUSE_RE.search(masked)
        if tail:
            insert_pos = tail.start()
            return f"{sql[:insert_pos]}WHERE {tenant_clause} {sql[insert_pos:]}"
        return f"{sql} WHERE {tenant_clause}"

    async def execute_query(
        self,
        tenant_id: str,
        sql: str,
        parameters: dict[str, Any] | None = None,
        database: str | None = None,
    ) -> QueryResult:
        """Execute a tenant-scoped Athena SELECT query.

        Args:
            tenant_id: Tenant whose rows the query is restricted to.
            sql: A single SELECT statement.
            parameters: Accepted for interface compatibility.
                NOTE(review): currently NOT forwarded to Athena; callers
                relying on it get unparameterised SQL.
            database: The Athena database to query. Products specify this at
                query time; falls back to the connector-level default.

        Returns:
            Column labels, all collected rows (up to ``max_result_rows``),
            row count and elapsed wall-clock time.

        Raises:
            BadRequestError: The SQL or tenant ID is rejected.
            AppError: The query failed or was cancelled.
            ServiceUnavailableError: The query timed out or Athena failed.
        """
        self._validate_sql(sql)
        filtered_sql = self._inject_tenant_filter(sql, tenant_id)
        target_database = database or self._database

        try:
            start = time.monotonic()
            query_kwargs: dict[str, Any] = {
                "QueryString": filtered_sql,
                "ResultConfiguration": {"OutputLocation": self._output_location},
                "WorkGroup": self._workgroup,
            }
            if target_database:
                query_kwargs["QueryExecutionContext"] = {"Database": target_database}

            response = await asyncio.to_thread(
                self._client.start_query_execution,
                **query_kwargs,
            )
            query_id = response["QueryExecutionId"]

            await self._await_completion(query_id, start)
            columns, rows = await self._fetch_rows(query_id)

            execution_time_ms = int((time.monotonic() - start) * 1000)

            return QueryResult(
                columns=columns,
                rows=rows,
                row_count=len(rows),
                execution_time_ms=execution_time_ms,
            )
        except (AppError, ServiceUnavailableError):
            raise
        except Exception as e:
            logger.error("athena_query_error", exc_info=True)
            raise ServiceUnavailableError("Query execution unavailable") from e

    async def _await_completion(self, query_id: str, started: float) -> None:
        """Poll *query_id* with exponential back-off (0.5 s doubling to 5 s).

        The deadline is wall-clock time measured from *started*, so API
        latency counts against ``query_timeout_seconds``.

        Raises:
            AppError: The query failed or was cancelled.
            ServiceUnavailableError: The deadline passed. A stop request is
                sent first so the abandoned query does not keep running.
        """
        deadline = started + self._query_timeout
        wait = 0.5
        while True:
            status_resp = await asyncio.to_thread(
                self._client.get_query_execution,
                QueryExecutionId=query_id,
            )
            status = status_resp["QueryExecution"]["Status"]
            state = status["State"]

            if state == "SUCCEEDED":
                return
            if state == "FAILED":
                reason = status.get("StateChangeReason", "Query failed")
                logger.error("athena_query_failed", reason=reason, query_id=query_id)
                raise AppError("Query execution failed")
            if state == "CANCELLED":
                raise AppError("Athena query was cancelled")

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                await self._stop_query(query_id)
                raise ServiceUnavailableError("Athena query timed out")
            await asyncio.sleep(min(wait, remaining))
            wait = min(wait * 2, 5.0)

    async def _stop_query(self, query_id: str) -> None:
        """Best-effort request to stop *query_id*; failures are only logged."""
        try:
            await asyncio.to_thread(
                self._client.stop_query_execution,
                QueryExecutionId=query_id,
            )
        except Exception:
            logger.warning("athena_query_stop_failed", query_id=query_id, exc_info=True)

    async def _fetch_rows(self, query_id: str) -> tuple[list[str], list[list[Any]]]:
        """Collect column labels and row values from every result page.

        ``get_query_results`` returns one page per call and signals more data
        with ``NextToken``. The first page of a SELECT result starts with a
        header row of column labels, which is skipped. At most
        ``max_result_rows`` rows are kept.
        """
        columns: list[str] = []
        rows: list[list[Any]] = []
        next_token: str | None = None
        first_page = True
        while True:
            request: dict[str, Any] = {"QueryExecutionId": query_id}
            if next_token:
                request["NextToken"] = next_token
            results_resp = await asyncio.to_thread(
                self._client.get_query_results,
                **request,
            )

            result_set = results_resp["ResultSet"]
            page_rows = result_set.get("Rows", [])
            if first_page:
                columns = [
                    col["Label"] for col in result_set["ResultSetMetadata"]["ColumnInfo"]
                ]
                page_rows = page_rows[1:]
                first_page = False

            for row in page_rows:
                if len(rows) >= self._max_result_rows:
                    logger.warning(
                        "athena_query_results_truncated",
                        query_id=query_id,
                        max_rows=self._max_result_rows,
                    )
                    return columns, rows
                # Cells without VarCharValue (presumably NULLs — verify) become "".
                rows.append([d.get("VarCharValue", "") for d in row["Data"]])

            next_token = results_resp.get("NextToken")
            if not next_token:
                return columns, rows

    async def delete(self, tenant_id: str, doc_id: str) -> bool:
        raise NotImplementedError("Athena connector does not support delete operations")

    async def health_check(self) -> ConnectorHealth:
        """Report "up" if ``list_work_groups`` succeeds, else "down", with latency."""
        start = time.monotonic()
        try:
            await asyncio.to_thread(self._client.list_work_groups)
            latency = (time.monotonic() - start) * 1000
            return ConnectorHealth(connector="athena", status="up", latency_ms=latency)
        except Exception:
            latency = (time.monotonic() - start) * 1000
            return ConnectorHealth(
                connector="athena", status="down", latency_ms=latency
            )

    async def teardown(self) -> None:
        logger.info("athena_connector_teardown")