spark-advisor-analyzer 0.1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. spark_advisor_analyzer-0.1.6/.gitignore +38 -0
  2. spark_advisor_analyzer-0.1.6/Dockerfile +13 -0
  3. spark_advisor_analyzer-0.1.6/PKG-INFO +73 -0
  4. spark_advisor_analyzer-0.1.6/README.md +47 -0
  5. spark_advisor_analyzer-0.1.6/pyproject.toml +87 -0
  6. spark_advisor_analyzer-0.1.6/src/spark_advisor_analyzer/__init__.py +1 -0
  7. spark_advisor_analyzer-0.1.6/src/spark_advisor_analyzer/agent/__init__.py +4 -0
  8. spark_advisor_analyzer-0.1.6/src/spark_advisor_analyzer/agent/context.py +29 -0
  9. spark_advisor_analyzer-0.1.6/src/spark_advisor_analyzer/agent/handlers.py +182 -0
  10. spark_advisor_analyzer-0.1.6/src/spark_advisor_analyzer/agent/orchestrator.py +167 -0
  11. spark_advisor_analyzer-0.1.6/src/spark_advisor_analyzer/agent/prompts.py +67 -0
  12. spark_advisor_analyzer-0.1.6/src/spark_advisor_analyzer/agent/tools.py +109 -0
  13. spark_advisor_analyzer-0.1.6/src/spark_advisor_analyzer/ai/__init__.py +4 -0
  14. spark_advisor_analyzer-0.1.6/src/spark_advisor_analyzer/ai/client.py +54 -0
  15. spark_advisor_analyzer-0.1.6/src/spark_advisor_analyzer/ai/prompts.py +166 -0
  16. spark_advisor_analyzer-0.1.6/src/spark_advisor_analyzer/ai/report_builder.py +39 -0
  17. spark_advisor_analyzer-0.1.6/src/spark_advisor_analyzer/ai/service.py +50 -0
  18. spark_advisor_analyzer-0.1.6/src/spark_advisor_analyzer/ai/tool_config.py +118 -0
  19. spark_advisor_analyzer-0.1.6/src/spark_advisor_analyzer/app.py +56 -0
  20. spark_advisor_analyzer-0.1.6/src/spark_advisor_analyzer/config.py +27 -0
  21. spark_advisor_analyzer-0.1.6/src/spark_advisor_analyzer/factory.py +32 -0
  22. spark_advisor_analyzer-0.1.6/src/spark_advisor_analyzer/handlers.py +44 -0
  23. spark_advisor_analyzer-0.1.6/src/spark_advisor_analyzer/orchestrator.py +38 -0
  24. spark_advisor_analyzer-0.1.6/src/spark_advisor_analyzer/py.typed +0 -0
  25. spark_advisor_analyzer-0.1.6/tests/__init__.py +0 -0
  26. spark_advisor_analyzer-0.1.6/tests/agent_factories.py +75 -0
  27. spark_advisor_analyzer-0.1.6/tests/conftest.py +29 -0
  28. spark_advisor_analyzer-0.1.6/tests/test_agent_orchestrator.py +165 -0
  29. spark_advisor_analyzer-0.1.6/tests/test_agent_prompts.py +42 -0
  30. spark_advisor_analyzer-0.1.6/tests/test_agent_tools.py +112 -0
  31. spark_advisor_analyzer-0.1.6/tests/test_ai_client.py +37 -0
  32. spark_advisor_analyzer-0.1.6/tests/test_ai_prompts.py +311 -0
  33. spark_advisor_analyzer-0.1.6/tests/test_ai_service.py +122 -0
  34. spark_advisor_analyzer-0.1.6/tests/test_handlers.py +25 -0
  35. spark_advisor_analyzer-0.1.6/tests/test_orchestrator.py +63 -0
@@ -0,0 +1,38 @@
1
+ CLAUDE.md
2
+ .claude/
3
+
4
+ # Python
5
+ __pycache__/
6
+ *.py[cod]
7
+ *.pyo
8
+ *.egg-info/
9
+ dist/
10
+ build/
11
+
12
+ # Virtual environment
13
+ .venv/
14
+ .envrc
15
+
16
+ # Testing
17
+ .coverage
18
+ .pytest_cache/
19
+ htmlcov/
20
+
21
+ # Type checking
22
+ .mypy_cache/
23
+
24
+ # Ruff
25
+ .ruff_cache/
26
+
27
+ # IDE
28
+ .idea/
29
+ .vscode/
30
+ *.swp
31
+ *.swo
32
+ *~
33
+
34
+ # OS
35
+ .DS_Store
36
+ Thumbs.db
37
+ /.claude/
38
+ tasks
@@ -0,0 +1,13 @@
1
+ FROM python:3.12-slim AS base
2
+
3
+ COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
4
+
5
+ WORKDIR /app
6
+ COPY pyproject.toml uv.lock ./
7
+ COPY packages/spark-advisor-models/ packages/spark-advisor-models/
8
+ COPY packages/spark-advisor-rules/ packages/spark-advisor-rules/
9
+ COPY packages/spark-advisor-analyzer/ packages/spark-advisor-analyzer/
10
+
11
+ RUN uv sync --frozen --no-dev --package spark-advisor-analyzer
12
+
13
+ CMD ["uv", "run", "--package", "spark-advisor-analyzer", "spark-advisor-analyzer"]
@@ -0,0 +1,73 @@
1
+ Metadata-Version: 2.4
2
+ Name: spark-advisor-analyzer
3
+ Version: 0.1.6
4
+ Summary: AI-powered Spark job analyzer - NATS worker service
5
+ Project-URL: Homepage, https://github.com/pstysz/spark-advisor
6
+ Project-URL: Repository, https://github.com/pstysz/spark-advisor
7
+ Project-URL: Issues, https://github.com/pstysz/spark-advisor/issues
8
+ Author: Pawel Stysz
9
+ License-Expression: Apache-2.0
10
+ Keywords: ai,analyzer,apache-spark,claude,nats,performance,spark
11
+ Classifier: Development Status :: 4 - Beta
12
+ Classifier: Intended Audience :: Developers
13
+ Classifier: License :: OSI Approved :: Apache Software License
14
+ Classifier: Programming Language :: Python :: 3
15
+ Classifier: Programming Language :: Python :: 3.12
16
+ Classifier: Programming Language :: Python :: 3.13
17
+ Classifier: Topic :: Software Development :: Quality Assurance
18
+ Classifier: Typing :: Typed
19
+ Requires-Python: >=3.12
20
+ Requires-Dist: anthropic>=0.52
21
+ Requires-Dist: faststream[nats]>=0.6
22
+ Requires-Dist: pyyaml>=6.0
23
+ Requires-Dist: spark-advisor-models==0.1.6
24
+ Requires-Dist: spark-advisor-rules==0.1.6
25
+ Description-Content-Type: text/markdown
26
+
27
+ # spark-advisor-analyzer
28
+
29
+ AI-powered Spark job analyzer — NATS worker service. Part of the [spark-advisor](https://github.com/pstysz/spark-advisor) ecosystem.
30
+
31
+ ## Install
32
+
33
+ ```bash
34
+ pip install spark-advisor-analyzer
35
+ ```
36
+
37
+ ## What it does
38
+
39
+ Combines deterministic rules engine with optional Claude AI analysis to provide actionable Spark configuration recommendations.
40
+
41
+ - **Rules engine** — 11 expert rules detecting data skew, GC pressure, disk spill, and more
42
+ - **AI analysis** — Claude API integration for prioritized recommendations with causal chains
43
+ - **Agent mode** — multi-turn Claude tool_use loop for autonomous job exploration
44
+ - **NATS worker** — FastStream subscriber for `analyze.request` and `analyze.agent.request`
45
+
46
+ ## Deployment
47
+
48
+ As a standalone NATS worker:
49
+
50
+ ```bash
51
+ export SA_ANALYZER_NATS__URL=nats://localhost:4222
52
+ export ANTHROPIC_API_KEY=sk-ant-...
53
+ spark-advisor-analyzer
54
+ ```
55
+
56
+ Or use via `spark-advisor` CLI or `spark-advisor-mcp` server.
57
+
58
+ ## Configuration
59
+
60
+ | Variable | Default | Description |
61
+ |----------|---------|-------------|
62
+ | `SA_ANALYZER_NATS__URL` | `nats://localhost:4222` | NATS broker URL |
63
+ | `SA_ANALYZER_AI__ENABLED` | `true` | Enable AI analysis |
64
+ | `ANTHROPIC_API_KEY` | — | Claude API key (required if AI enabled) |
65
+
66
+ ## Links
67
+
68
+ - [Main project](https://github.com/pstysz/spark-advisor)
69
+ - [Contributing](https://github.com/pstysz/spark-advisor/blob/main/CONTRIBUTING.md)
70
+
71
+ ## License
72
+
73
+ Apache 2.0
@@ -0,0 +1,47 @@
1
+ # spark-advisor-analyzer
2
+
3
+ AI-powered Spark job analyzer — NATS worker service. Part of the [spark-advisor](https://github.com/pstysz/spark-advisor) ecosystem.
4
+
5
+ ## Install
6
+
7
+ ```bash
8
+ pip install spark-advisor-analyzer
9
+ ```
10
+
11
+ ## What it does
12
+
13
+ Combines deterministic rules engine with optional Claude AI analysis to provide actionable Spark configuration recommendations.
14
+
15
+ - **Rules engine** — 11 expert rules detecting data skew, GC pressure, disk spill, and more
16
+ - **AI analysis** — Claude API integration for prioritized recommendations with causal chains
17
+ - **Agent mode** — multi-turn Claude tool_use loop for autonomous job exploration
18
+ - **NATS worker** — FastStream subscriber for `analyze.request` and `analyze.agent.request`
19
+
20
+ ## Deployment
21
+
22
+ As a standalone NATS worker:
23
+
24
+ ```bash
25
+ export SA_ANALYZER_NATS__URL=nats://localhost:4222
26
+ export ANTHROPIC_API_KEY=sk-ant-...
27
+ spark-advisor-analyzer
28
+ ```
29
+
30
+ Or use via `spark-advisor` CLI or `spark-advisor-mcp` server.
31
+
32
+ ## Configuration
33
+
34
+ | Variable | Default | Description |
35
+ |----------|---------|-------------|
36
+ | `SA_ANALYZER_NATS__URL` | `nats://localhost:4222` | NATS broker URL |
37
+ | `SA_ANALYZER_AI__ENABLED` | `true` | Enable AI analysis |
38
+ | `ANTHROPIC_API_KEY` | — | Claude API key (required if AI enabled) |
39
+
40
+ ## Links
41
+
42
+ - [Main project](https://github.com/pstysz/spark-advisor)
43
+ - [Contributing](https://github.com/pstysz/spark-advisor/blob/main/CONTRIBUTING.md)
44
+
45
+ ## License
46
+
47
+ Apache 2.0
@@ -0,0 +1,87 @@
1
+ [project]
2
+ name = "spark-advisor-analyzer"
3
+ version = "0.1.6" # x-release-please-version
4
+ description = "AI-powered Spark job analyzer - NATS worker service"
5
+ readme = "README.md"
6
+ license = "Apache-2.0"
7
+ requires-python = ">=3.12"
8
+ authors = [
9
+ { name = "Pawel Stysz" },
10
+ ]
11
+ keywords = ["spark", "apache-spark", "ai", "claude", "nats", "performance", "analyzer"]
12
+ classifiers = [
13
+ "Development Status :: 4 - Beta",
14
+ "Intended Audience :: Developers",
15
+ "License :: OSI Approved :: Apache Software License",
16
+ "Programming Language :: Python :: 3",
17
+ "Programming Language :: Python :: 3.12",
18
+ "Programming Language :: Python :: 3.13",
19
+ "Topic :: Software Development :: Quality Assurance",
20
+ "Typing :: Typed",
21
+ ]
22
+ dependencies = [
23
+ "spark-advisor-models==0.1.6", # x-release-please-version
24
+ "spark-advisor-rules==0.1.6", # x-release-please-version
25
+ "anthropic>=0.52",
26
+ "faststream[nats]>=0.6",
27
+ "pyyaml>=6.0",
28
+ ]
29
+
30
+ [project.urls]
31
+ Homepage = "https://github.com/pstysz/spark-advisor"
32
+ Repository = "https://github.com/pstysz/spark-advisor"
33
+ Issues = "https://github.com/pstysz/spark-advisor/issues"
34
+
35
+ [project.scripts]
36
+ spark-advisor-analyzer = "spark_advisor_analyzer.app:main"
37
+
38
+ [build-system]
39
+ requires = ["hatchling"]
40
+ build-backend = "hatchling.build"
41
+
42
+ [tool.hatch.build.targets.wheel]
43
+ packages = ["src/spark_advisor_analyzer"]
44
+
45
+ [tool.uv.sources]
46
+ spark-advisor-models = { workspace = true }
47
+ spark-advisor-rules = { workspace = true }
48
+
49
+ [dependency-groups]
50
+ dev = [
51
+ "pytest>=8.3",
52
+ "pytest-asyncio>=1.0",
53
+ "pytest-cov>=6.1",
54
+ "mypy>=1.15",
55
+ "ruff>=0.11",
56
+ ]
57
+
58
+ [tool.pytest.ini_options]
59
+ testpaths = ["tests"]
60
+ pythonpath = ["src", "tests"]
61
+ asyncio_mode = "auto"
62
+
63
+ [tool.ruff]
64
+ target-version = "py312"
65
+ line-length = 120
66
+ src = ["src", "tests"]
67
+
68
+ [tool.ruff.lint]
69
+ select = ["E", "W", "F", "I", "UP", "B", "SIM", "TCH", "RUF"]
70
+
71
+ [tool.ruff.lint.isort]
72
+ known-first-party = ["spark_advisor_analyzer", "spark_advisor_models", "spark_advisor_rules"]
73
+
74
+ [tool.ruff.lint.flake8-type-checking]
75
+ runtime-evaluated-base-classes = ["pydantic.BaseModel", "pydantic_settings.BaseSettings"]
76
+
77
+ [tool.mypy]
78
+ python_version = "3.12"
79
+ strict = true
80
+ plugins = ["pydantic.mypy"]
81
+
82
+ [tool.coverage.run]
83
+ source = ["spark_advisor_analyzer"]
84
+ branch = true
85
+
86
+ [tool.coverage.report]
87
+ show_missing = true
@@ -0,0 +1 @@
1
+ """AI-powered Spark job analyzer - NATS worker service."""
@@ -0,0 +1,4 @@
1
+ from spark_advisor_analyzer.agent.orchestrator import AgentOrchestrator
2
+ from spark_advisor_analyzer.agent.tools import AgentToolName
3
+
4
+ __all__ = ["AgentOrchestrator", "AgentToolName"]
@@ -0,0 +1,29 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import TYPE_CHECKING
5
+
6
+ if TYPE_CHECKING:
7
+ from spark_advisor_models.model import JobAnalysis, RuleResult
8
+ from spark_advisor_rules import StaticAnalysisService
9
+
10
+
11
+ @dataclass
12
+ class AgentContext:
13
+ job: JobAnalysis
14
+ _rule_results: list[RuleResult] = field(default_factory=list, init=False)
15
+ _rules_executed: bool = field(default=False, init=False)
16
+
17
+ @property
18
+ def rules_executed(self) -> bool:
19
+ return self._rules_executed
20
+
21
+ @property
22
+ def rule_results(self) -> list[RuleResult]:
23
+ return self._rule_results
24
+
25
+ def get_or_run_rules(self, static: StaticAnalysisService) -> list[RuleResult]:
26
+ if not self._rules_executed:
27
+ self._rule_results = static.analyze(self.job)
28
+ self._rules_executed = True
29
+ return self._rule_results
@@ -0,0 +1,182 @@
1
+ import json
2
+ import math
3
+ from typing import Any
4
+
5
+ from pydantic import BaseModel, ValidationError
6
+
7
+ from spark_advisor_analyzer.agent.context import AgentContext
8
+ from spark_advisor_analyzer.agent.tools import (
9
+ AgentToolName,
10
+ CalculateOptimalPartitionsInput,
11
+ CompareConfigsInput,
12
+ GetStageDetailsInput,
13
+ )
14
+ from spark_advisor_analyzer.ai.tool_config import IMPORTANT_KEYS
15
+ from spark_advisor_models.util import format_bytes
16
+ from spark_advisor_rules import StaticAnalysisService
17
+
18
# Maps each agent tool to the Pydantic model used to validate its input payload.
# Tools absent from this mapping take no validated input (see execute_tool).
_TOOL_INPUT_MODELS: dict[AgentToolName, type[BaseModel]] = {
    AgentToolName.GET_STAGE_DETAILS: GetStageDetailsInput,
    AgentToolName.CALCULATE_OPTIMAL_PARTITIONS: CalculateOptimalPartitionsInput,
    AgentToolName.COMPARE_CONFIGS: CompareConfigsInput,
}
23
+
24
+
25
class ToolExecutionError(Exception):
    """Raised when an agent tool cannot be executed (unknown name, invalid input,
    or a pseudo-tool that must be handled by the orchestrator instead)."""

    pass
27
+
28
+
29
def execute_tool(
    name: str,
    input_data: dict[str, Any],
    context: AgentContext,
    static_analysis: StaticAnalysisService,
) -> str:
    """Dispatch a single agent tool call and return its result as a JSON string.

    Raises:
        ToolExecutionError: for an unknown tool name, an input payload that
            fails schema validation, or ``submit_final_report`` (which the
            orchestrator handles itself, never this dispatcher).
    """
    try:
        tool = AgentToolName(name)
    except ValueError as err:
        raise ToolExecutionError(f"Unknown tool: {name}") from err

    if tool == AgentToolName.SUBMIT_FINAL_REPORT:
        raise ToolExecutionError("submit_final_report is handled by the orchestrator, not execute_tool")

    # Validate the payload up front for tools that declare an input schema.
    validated: BaseModel | None = None
    schema = _TOOL_INPUT_MODELS.get(tool)
    if schema is not None:
        try:
            validated = schema.model_validate(input_data)
        except ValidationError as e:
            raise ToolExecutionError(f"Invalid input for {name}: {e}") from e

    # Route to the handler; the isinstance asserts narrow `validated` for mypy.
    if tool == AgentToolName.GET_JOB_OVERVIEW:
        return _handle_get_job_overview(context)
    if tool == AgentToolName.GET_STAGE_DETAILS:
        assert isinstance(validated, GetStageDetailsInput)
        return _handle_get_stage_details(context, validated)
    if tool == AgentToolName.RUN_RULES_ENGINE:
        return _handle_run_rules_engine(context, static_analysis)
    if tool == AgentToolName.CALCULATE_OPTIMAL_PARTITIONS:
        assert isinstance(validated, CalculateOptimalPartitionsInput)
        return _handle_calculate_optimal_partitions(context, validated)
    if tool == AgentToolName.COMPARE_CONFIGS:
        assert isinstance(validated, CompareConfigsInput)
        return _handle_compare_configs(context, validated)
    raise ToolExecutionError(f"Unhandled tool: {tool}")
67
+
68
+
69
def _handle_get_job_overview(context: AgentContext) -> str:
    """Return a JSON overview of the job: totals, per-stage summary, key config.

    This is the agent's cheap first look at a job, giving it enough signal to
    decide which stages deserve a detailed ``get_stage_details`` follow-up.
    """
    job = context.job
    total_tasks = sum(s.tasks.task_count for s in job.stages)
    total_shuffle_read = sum(s.total_shuffle_read_bytes for s in job.stages)
    total_shuffle_write = sum(s.total_shuffle_write_bytes for s in job.stages)
    total_spill = sum(s.spill_to_disk_bytes for s in job.stages)
    total_input = sum(s.input_bytes for s in job.stages)

    # Only surface the curated important keys that are actually set; the walrus
    # binding avoids the double job.config.get(k) lookup of the previous version.
    config_snapshot = {k: v for k in IMPORTANT_KEYS if (v := job.config.get(k))}

    overview: dict[str, Any] = {
        "app_id": job.app_id,
        "app_name": job.app_name,
        "spark_version": job.spark_version,
        "duration_ms": job.duration_ms,
        "duration_min": round(job.duration_ms / 60_000, 1),
        "stage_count": len(job.stages),
        "total_tasks": total_tasks,
        "total_input": format_bytes(total_input),
        "total_shuffle_read": format_bytes(total_shuffle_read),
        "total_shuffle_write": format_bytes(total_shuffle_write),
        "total_spill_to_disk": format_bytes(total_spill),
        "stages_summary": [
            {
                "stage_id": s.stage_id,
                "name": s.stage_name,
                "tasks": s.tasks.task_count,
                # Report 0.0 rather than a GC ratio when no executor time was recorded.
                "gc_percent": round(s.gc_time_percent, 1) if s.sum_executor_run_time_ms > 0 else 0.0,
                "has_spill": s.spill_to_disk_bytes > 0,
                "has_shuffle": s.total_shuffle_read_bytes > 0 or s.total_shuffle_write_bytes > 0,
                "skew_ratio": round(s.tasks.duration_skew_ratio, 1),
                "failed_tasks": s.failed_task_count,
            }
            for s in job.stages
        ],
        "config": config_snapshot,
    }

    # Executor metrics are optional on JobAnalysis; include them only when present.
    if job.executors:
        overview["executors"] = {
            "count": job.executors.executor_count,
            "total_cores": job.executors.total_cores,
            "memory_utilization_percent": round(job.executors.memory_utilization_percent, 1),
            "failed_tasks": job.executors.failed_tasks,
        }

    return json.dumps(overview)
116
+
117
+
118
def _handle_get_stage_details(context: AgentContext, input_data: GetStageDetailsInput) -> str:
    """Return the full JSON dump of one stage, looked up by its stage_id.

    Raises:
        ToolExecutionError: if the requested stage_id does not exist; the
            message lists the valid ids so the model can self-correct.
    """
    for candidate in context.job.stages:
        if candidate.stage_id == input_data.stage_id:
            return candidate.model_dump_json()

    available = [s.stage_id for s in context.job.stages]
    raise ToolExecutionError(f"Stage {input_data.stage_id} not found. Available stages: {available}")
125
+
126
+
127
+ def _handle_run_rules_engine(context: AgentContext, static_analysis: StaticAnalysisService) -> str:
128
+ results = context.get_or_run_rules(static_analysis)
129
+
130
+ findings = [
131
+ {
132
+ "rule_id": r.rule_id,
133
+ "severity": r.severity.value,
134
+ "title": r.title,
135
+ "message": r.message,
136
+ "stage_id": r.stage_id,
137
+ "current_value": r.current_value,
138
+ "recommended_value": r.recommended_value,
139
+ }
140
+ for r in results
141
+ ]
142
+ return json.dumps({"findings_count": len(findings), "findings": findings})
143
+
144
+
145
def _handle_calculate_optimal_partitions(
    context: AgentContext,
    input_data: CalculateOptimalPartitionsInput,
) -> str:
    """Compute the shuffle partition count that hits the requested partition size.

    Returns a JSON comparison between the job's current shuffle partition
    setting and the count derived from total shuffle volume / target size.
    """
    total_bytes = input_data.total_shuffle_bytes
    target_mb = input_data.target_partition_mb
    target_bytes = target_mb * 1024 * 1024

    # At least one partition; ceil so each partition stays at or under the target.
    optimal = max(1, math.ceil(total_bytes / target_bytes)) if total_bytes > 0 else 1
    current = context.job.config.shuffle_partitions

    # `optimal` is always >= 1 (both branches above guarantee it), so the
    # division is safe — the previous `if optimal > 0` guard was dead code.
    actual_size_mb = round(total_bytes / (optimal * 1024 * 1024), 1)

    return json.dumps(
        {
            "current_partitions": current,
            "optimal_partitions": optimal,
            "total_shuffle_bytes": total_bytes,
            "total_shuffle_formatted": format_bytes(total_bytes),
            "target_partition_size_mb": target_mb,
            "actual_partition_size_mb": actual_size_mb,
        }
    )
168
+
169
+
170
+ def _handle_compare_configs(context: AgentContext, input_data: CompareConfigsInput) -> str:
171
+ changes = []
172
+ for key, new_value in input_data.proposed_changes.items():
173
+ current = context.job.config.get(key)
174
+ changes.append(
175
+ {
176
+ "parameter": key,
177
+ "current_value": current or "(not set)",
178
+ "proposed_value": new_value,
179
+ "changed": current != new_value,
180
+ }
181
+ )
182
+ return json.dumps({"changes": changes, "total_changes": len(changes)})
@@ -0,0 +1,167 @@
1
+ import json
2
+ import logging
3
+ from typing import Any
4
+
5
+ from anthropic.types import (
6
+ Message,
7
+ MessageParam,
8
+ ToolChoiceAutoParam,
9
+ ToolChoiceToolParam,
10
+ ToolResultBlockParam,
11
+ ToolUseBlock,
12
+ )
13
+ from pydantic import ValidationError
14
+
15
+ from spark_advisor_analyzer.agent.context import AgentContext
16
+ from spark_advisor_analyzer.agent.handlers import ToolExecutionError, execute_tool
17
+ from spark_advisor_analyzer.agent.prompts import build_agent_system_prompt, build_initial_message
18
+ from spark_advisor_analyzer.agent.tools import AGENT_TOOLS, AgentToolName
19
+ from spark_advisor_analyzer.ai.client import AnthropicClient
20
+ from spark_advisor_analyzer.ai.report_builder import build_advisor_report
21
+ from spark_advisor_models.config import AiSettings
22
+ from spark_advisor_models.model import AnalysisResult, AnalysisToolInput, JobAnalysis
23
+ from spark_advisor_rules import StaticAnalysisService
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class AgentOrchestrator:
    """Drives a multi-turn Claude tool_use loop over a Spark job analysis.

    Each iteration sends the conversation to Claude, executes whatever tools
    the model requested (via the local handlers), appends the results, and
    repeats until the model calls ``submit_final_report`` or the iteration
    budget is exhausted.
    """

    def __init__(
        self,
        client: AnthropicClient,
        static_analysis: StaticAnalysisService,
        ai_settings: AiSettings,
    ) -> None:
        self._client = client
        self._static = static_analysis
        self._ai = ai_settings
        # Hard cap on Claude round trips; also baked into the system prompt
        # so the model knows its budget up front.
        self._max_iterations = ai_settings.max_agent_iterations
        self._system_prompt = build_agent_system_prompt(self._max_iterations)

    def run(self, job: JobAnalysis) -> AnalysisResult:
        """Analyze *job* through the agent loop and return the combined result.

        If the model never submits a final report within the iteration budget,
        falls back to ``_force_final_report`` (and ultimately to a rules-only
        result).
        """
        context = AgentContext(job=job)
        messages: list[MessageParam] = [
            MessageParam(role="user", content=build_initial_message(job)),
        ]

        for iteration in range(self._max_iterations):
            logger.info("Agent iteration %d/%d", iteration + 1, self._max_iterations)

            response = self._call_claude(messages, tool_choice=ToolChoiceAutoParam(type="auto"))

            # A valid submit_final_report call ends the loop immediately.
            final_input = self._extract_final_report(response)
            if final_input is not None:
                logger.info("Agent completed after %d iterations", iteration + 1)
                return self._build_result(job, context, final_input)

            tool_uses = [b for b in response.content if isinstance(b, ToolUseBlock)]

            if not tool_uses:
                # Plain-text reply with no tool calls: nudge the model to submit.
                messages.append(MessageParam(role="assistant", content=response.content))
                messages.append(
                    MessageParam(
                        role="user",
                        content=f"Please call {AgentToolName.SUBMIT_FINAL_REPORT} with your analysis now.",
                    )
                )
                continue

            messages.append(MessageParam(role="assistant", content=response.content))

            # Execute every requested tool and return all results in one
            # follow-up user turn.
            tool_results: list[ToolResultBlockParam] = []
            for tool_use in tool_uses:
                logger.info("Executing tool: %s", tool_use.name)
                result_content = self._execute_single_tool(tool_use.name, tool_use.input, context)
                tool_results.append(
                    {
                        "type": "tool_result",
                        "tool_use_id": tool_use.id,
                        "content": result_content,
                    }
                )

            messages.append(MessageParam(role="user", content=tool_results))

        logger.warning("Agent hit max iterations (%d), forcing final report", self._max_iterations)
        return self._force_final_report(job, context, messages)

    def _call_claude(
        self,
        messages: list[MessageParam],
        *,
        tool_choice: ToolChoiceAutoParam | ToolChoiceToolParam,
    ) -> Message:
        """Send the conversation to Claude with the agent tool set attached."""
        return self._client.create_message(
            model=self._ai.model,
            max_tokens=self._ai.max_tokens,
            system=self._system_prompt,
            messages=messages,
            tools=AGENT_TOOLS,
            tool_choice=tool_choice,
        )

    def _execute_single_tool(
        self,
        name: str,
        input_data: Any,
        context: AgentContext,
    ) -> str:
        """Run one tool call; never raises — errors come back as JSON for Claude."""
        try:
            return execute_tool(name, input_data, context, self._static)
        except ToolExecutionError as e:
            logger.warning("Tool %s failed: %s", name, e)
            return json.dumps({"error": str(e)})
        except Exception:
            # Unexpected handler bug: log the traceback but keep the loop alive.
            logger.exception("Unexpected error in tool %s", name)
            return json.dumps({"error": f"Internal error executing {name}"})

    def _force_final_report(
        self,
        job: JobAnalysis,
        context: AgentContext,
        messages: list[MessageParam],
    ) -> AnalysisResult:
        """Demand a final report with forced tool_choice; degrade to rules-only."""
        messages.append(
            MessageParam(
                role="user",
                content=(
                    "You have reached the maximum number of tool calls. "
                    "You MUST call submit_final_report now with your best analysis "
                    "based on the information gathered so far."
                ),
            )
        )

        # tool_choice type="tool" forces the model to call submit_final_report.
        response = self._call_claude(
            messages,
            tool_choice=ToolChoiceToolParam(type="tool", name=AgentToolName.SUBMIT_FINAL_REPORT),
        )

        final_input = self._extract_final_report(response)
        if final_input is not None:
            return self._build_result(job, context, final_input)

        # Even the forced call produced no valid report: return findings only.
        logger.error("Force-submit failed, returning rules-only result")
        rule_results = context.get_or_run_rules(self._static)
        return AnalysisResult(app_id=job.app_id, job=job, rule_results=rule_results)

    def _build_result(
        self,
        job: JobAnalysis,
        context: AgentContext,
        parsed: AnalysisToolInput,
    ) -> AnalysisResult:
        """Combine cached rule findings with the model's report into one result."""
        rule_results = context.get_or_run_rules(self._static)
        report = build_advisor_report(job.app_id, parsed, rule_results)
        return AnalysisResult(app_id=job.app_id, job=job, rule_results=rule_results, ai_report=report)

    @staticmethod
    def _extract_final_report(response: Message) -> AnalysisToolInput | None:
        """Return the validated submit_final_report payload, or None if absent/invalid."""
        for block in response.content:
            if isinstance(block, ToolUseBlock) and block.name == AgentToolName.SUBMIT_FINAL_REPORT:
                try:
                    return AnalysisToolInput.model_validate(block.input)
                except ValidationError:
                    # Malformed payload: treat as "no report" so the loop can retry.
                    logger.warning("Failed to validate submit_final_report input, skipping")
                    return None
        return None
+ return None