shotgun-sh 0.1.0.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of shotgun-sh might be problematic. Click here for more details.

Files changed (94)
  1. shotgun/__init__.py +3 -0
  2. shotgun/agents/__init__.py +1 -0
  3. shotgun/agents/agent_manager.py +196 -0
  4. shotgun/agents/common.py +295 -0
  5. shotgun/agents/config/__init__.py +13 -0
  6. shotgun/agents/config/manager.py +215 -0
  7. shotgun/agents/config/models.py +120 -0
  8. shotgun/agents/config/provider.py +91 -0
  9. shotgun/agents/history/__init__.py +5 -0
  10. shotgun/agents/history/history_processors.py +213 -0
  11. shotgun/agents/models.py +94 -0
  12. shotgun/agents/plan.py +119 -0
  13. shotgun/agents/research.py +131 -0
  14. shotgun/agents/tasks.py +122 -0
  15. shotgun/agents/tools/__init__.py +26 -0
  16. shotgun/agents/tools/codebase/__init__.py +28 -0
  17. shotgun/agents/tools/codebase/codebase_shell.py +256 -0
  18. shotgun/agents/tools/codebase/directory_lister.py +141 -0
  19. shotgun/agents/tools/codebase/file_read.py +144 -0
  20. shotgun/agents/tools/codebase/models.py +252 -0
  21. shotgun/agents/tools/codebase/query_graph.py +67 -0
  22. shotgun/agents/tools/codebase/retrieve_code.py +81 -0
  23. shotgun/agents/tools/file_management.py +130 -0
  24. shotgun/agents/tools/user_interaction.py +36 -0
  25. shotgun/agents/tools/web_search.py +69 -0
  26. shotgun/cli/__init__.py +1 -0
  27. shotgun/cli/codebase/__init__.py +5 -0
  28. shotgun/cli/codebase/commands.py +202 -0
  29. shotgun/cli/codebase/models.py +21 -0
  30. shotgun/cli/config.py +261 -0
  31. shotgun/cli/models.py +10 -0
  32. shotgun/cli/plan.py +65 -0
  33. shotgun/cli/research.py +78 -0
  34. shotgun/cli/tasks.py +71 -0
  35. shotgun/cli/utils.py +25 -0
  36. shotgun/codebase/__init__.py +12 -0
  37. shotgun/codebase/core/__init__.py +46 -0
  38. shotgun/codebase/core/change_detector.py +358 -0
  39. shotgun/codebase/core/code_retrieval.py +243 -0
  40. shotgun/codebase/core/ingestor.py +1497 -0
  41. shotgun/codebase/core/language_config.py +297 -0
  42. shotgun/codebase/core/manager.py +1554 -0
  43. shotgun/codebase/core/nl_query.py +327 -0
  44. shotgun/codebase/core/parser_loader.py +152 -0
  45. shotgun/codebase/models.py +107 -0
  46. shotgun/codebase/service.py +148 -0
  47. shotgun/logging_config.py +172 -0
  48. shotgun/main.py +73 -0
  49. shotgun/prompts/__init__.py +5 -0
  50. shotgun/prompts/agents/__init__.py +1 -0
  51. shotgun/prompts/agents/partials/codebase_understanding.j2 +79 -0
  52. shotgun/prompts/agents/partials/common_agent_system_prompt.j2 +10 -0
  53. shotgun/prompts/agents/partials/interactive_mode.j2 +8 -0
  54. shotgun/prompts/agents/plan.j2 +57 -0
  55. shotgun/prompts/agents/research.j2 +38 -0
  56. shotgun/prompts/agents/state/codebase/codebase_graphs_available.j2 +13 -0
  57. shotgun/prompts/agents/state/system_state.j2 +1 -0
  58. shotgun/prompts/agents/tasks.j2 +67 -0
  59. shotgun/prompts/codebase/__init__.py +1 -0
  60. shotgun/prompts/codebase/cypher_query_patterns.j2 +221 -0
  61. shotgun/prompts/codebase/cypher_system.j2 +28 -0
  62. shotgun/prompts/codebase/enhanced_query_context.j2 +10 -0
  63. shotgun/prompts/codebase/partials/cypher_rules.j2 +24 -0
  64. shotgun/prompts/codebase/partials/graph_schema.j2 +28 -0
  65. shotgun/prompts/codebase/partials/temporal_context.j2 +21 -0
  66. shotgun/prompts/history/__init__.py +1 -0
  67. shotgun/prompts/history/summarization.j2 +46 -0
  68. shotgun/prompts/loader.py +140 -0
  69. shotgun/prompts/user/research.j2 +5 -0
  70. shotgun/py.typed +0 -0
  71. shotgun/sdk/__init__.py +13 -0
  72. shotgun/sdk/codebase.py +195 -0
  73. shotgun/sdk/exceptions.py +17 -0
  74. shotgun/sdk/models.py +189 -0
  75. shotgun/sdk/services.py +23 -0
  76. shotgun/telemetry.py +68 -0
  77. shotgun/tui/__init__.py +0 -0
  78. shotgun/tui/app.py +49 -0
  79. shotgun/tui/components/prompt_input.py +69 -0
  80. shotgun/tui/components/spinner.py +86 -0
  81. shotgun/tui/components/splash.py +25 -0
  82. shotgun/tui/components/vertical_tail.py +28 -0
  83. shotgun/tui/screens/chat.py +415 -0
  84. shotgun/tui/screens/chat.tcss +28 -0
  85. shotgun/tui/screens/provider_config.py +221 -0
  86. shotgun/tui/screens/splash.py +31 -0
  87. shotgun/tui/styles.tcss +10 -0
  88. shotgun/utils/__init__.py +5 -0
  89. shotgun/utils/file_system_utils.py +31 -0
  90. shotgun_sh-0.1.0.dev1.dist-info/METADATA +318 -0
  91. shotgun_sh-0.1.0.dev1.dist-info/RECORD +94 -0
  92. shotgun_sh-0.1.0.dev1.dist-info/WHEEL +4 -0
  93. shotgun_sh-0.1.0.dev1.dist-info/entry_points.txt +3 -0
  94. shotgun_sh-0.1.0.dev1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,195 @@
1
+ """Codebase SDK for framework-agnostic business logic."""
2
+
3
+ import asyncio
4
+ from collections.abc import Awaitable, Callable
5
+ from pathlib import Path
6
+
7
+ from shotgun.codebase.models import CodebaseGraph, QueryType
8
+
9
+ from .exceptions import CodebaseNotFoundError, InvalidPathError
10
+ from .models import (
11
+ DeleteResult,
12
+ IndexResult,
13
+ InfoResult,
14
+ ListResult,
15
+ QueryCommandResult,
16
+ ReindexResult,
17
+ )
18
+ from .services import get_codebase_service
19
+
20
+
21
class CodebaseSDK:
    """Framework-agnostic SDK for codebase operations.

    This SDK provides business logic for codebase management that can be
    used by both CLI and TUI implementations without framework dependencies.
    Each method delegates to ``self.service`` (built by
    ``get_codebase_service``) and wraps the outcome in a result model.
    """

    def __init__(self, storage_dir: Path | None = None):
        """Initialize SDK with optional storage directory.

        Args:
            storage_dir: Optional custom storage directory, forwarded to
                ``get_codebase_service``. Defaults to ~/.shotgun-sh/codebases/
        """
        self.service = get_codebase_service(storage_dir)

    async def _require_graph(self, graph_id: str) -> CodebaseGraph:
        """Return the graph with ``graph_id`` or raise if the service has none.

        Raises:
            CodebaseNotFoundError: If the service returns no graph for the ID.
        """
        graph = await self.service.get_graph(graph_id)
        if not graph:
            raise CodebaseNotFoundError(f"Graph not found: {graph_id}")
        return graph

    async def list_codebases(self) -> ListResult:
        """List all indexed codebases.

        Returns:
            ListResult containing list of codebases
        """
        graphs = await self.service.list_graphs()
        return ListResult(graphs=graphs)

    async def index_codebase(self, path: Path, name: str) -> IndexResult:
        """Index a new codebase.

        Args:
            path: Path to the repository to index (resolved to an absolute path)
            name: Human-readable name for the codebase

        Returns:
            IndexResult with indexing details

        Raises:
            InvalidPathError: If the path does not exist
        """
        resolved_path = path.resolve()
        if not resolved_path.exists():
            raise InvalidPathError(f"Path does not exist: {resolved_path}")

        graph = await self.service.create_graph(resolved_path, name)
        # language_stats holds per-language file counts; their sum is the
        # total number of files processed.
        file_count = sum(graph.language_stats.values()) if graph.language_stats else 0

        return IndexResult(
            graph_id=graph.graph_id,
            name=name,
            repo_path=str(resolved_path),
            file_count=file_count,
            node_count=graph.node_count,
            relationship_count=graph.relationship_count,
        )

    async def delete_codebase(
        self,
        graph_id: str,
        confirm_callback: Callable[[CodebaseGraph], bool]
        | Callable[[CodebaseGraph], Awaitable[bool]]
        | None = None,
    ) -> DeleteResult:
        """Delete a codebase with optional confirmation.

        Args:
            graph_id: ID of the graph to delete
            confirm_callback: Optional callback for confirmation. It receives
                the CodebaseGraph object and returns a boolean, either
                directly or as an awaitable (async functions, lambdas or
                partials returning coroutines, objects with async __call__).

        Returns:
            DeleteResult indicating success or cancellation

        Raises:
            CodebaseNotFoundError: If the graph is not found
        """
        graph = await self._require_graph(graph_id)

        if confirm_callback is not None:
            # Call first and await the result if needed. Checking
            # iscoroutinefunction() up front misses callables that return an
            # awaitable without being declared `async def`; their coroutine
            # object would otherwise be treated as a truthy "yes".
            outcome = confirm_callback(graph)
            confirmed = await outcome if isinstance(outcome, Awaitable) else outcome

            if not confirmed:
                return DeleteResult(
                    graph_id=graph_id,
                    name=graph.name,
                    deleted=False,
                    cancelled=True,
                )

        await self.service.delete_graph(graph_id)
        return DeleteResult(
            graph_id=graph_id,
            name=graph.name,
            deleted=True,
            cancelled=False,
        )

    async def get_info(self, graph_id: str) -> InfoResult:
        """Get detailed information about a codebase.

        Args:
            graph_id: ID of the graph to get info for

        Returns:
            InfoResult with detailed graph information

        Raises:
            CodebaseNotFoundError: If the graph is not found
        """
        graph = await self._require_graph(graph_id)
        return InfoResult(graph=graph)

    async def query_codebase(
        self, graph_id: str, query_text: str, query_type: QueryType
    ) -> QueryCommandResult:
        """Query a codebase using natural language or Cypher.

        Args:
            graph_id: ID of the graph to query
            query_text: Query text (natural language or Cypher)
            query_type: Type of query (NATURAL_LANGUAGE or CYPHER)

        Returns:
            QueryCommandResult with query results

        Raises:
            CodebaseNotFoundError: If the graph is not found
        """
        graph = await self._require_graph(graph_id)

        query_result = await self.service.execute_query(
            graph_id, query_text, query_type
        )

        return QueryCommandResult(
            graph_name=graph.name,
            query_type="Cypher"
            if query_type == QueryType.CYPHER
            else "natural language",
            result=query_result,
        )

    async def reindex_codebase(self, graph_id: str) -> ReindexResult:
        """Reindex an existing codebase.

        Args:
            graph_id: ID of the graph to reindex

        Returns:
            ReindexResult with reindexing details

        Raises:
            CodebaseNotFoundError: If the graph is not found
        """
        graph = await self._require_graph(graph_id)

        stats = await self.service.reindex_graph(graph_id)

        return ReindexResult(
            graph_id=graph_id,
            name=graph.name,
            stats=stats,
        )
@@ -0,0 +1,17 @@
1
+ """SDK-specific exceptions."""
2
+
3
+
4
class ShotgunSDKError(Exception):
    """Common base class shared by the SDK's own exception types."""
6
+
7
+
8
class CodebaseNotFoundError(ShotgunSDKError):
    """Signals that no codebase/graph exists for the requested identifier."""
10
+
11
+
12
class CodebaseOperationError(ShotgunSDKError):
    """Signals that an operation on a codebase did not complete successfully."""
14
+
15
+
16
class InvalidPathError(ShotgunSDKError):
    """Signals that a caller-supplied filesystem path was rejected."""
shotgun/sdk/models.py ADDED
@@ -0,0 +1,189 @@
1
+ """Result models for SDK operations."""
2
+
3
+ from typing import Any
4
+
5
+ from pydantic import BaseModel
6
+
7
+ from shotgun.codebase.models import CodebaseGraph, QueryResult
8
+
9
+
10
class ListResult(BaseModel):
    """Outcome of the list command: every indexed codebase graph."""

    graphs: list[CodebaseGraph]

    def __str__(self) -> str:
        """Render the graphs as a fixed-width plain-text table.

        ID and name are truncated to their column widths (12 and 30 chars);
        the file count is the sum of the graph's per-language counts.
        """
        if not self.graphs:
            return "No codebases found."

        header = f"{'ID':<12} {'Name':<30} {'Status':<10} {'Files':<8} {'Path'}"
        rows = [header, "-" * 80]
        for g in self.graphs:
            files = sum(g.language_stats.values()) if g.language_stats else 0
            short_id = g.graph_id[:12]
            short_name = g.name[:30]
            rows.append(
                f"{short_id:<12} {short_name:<30} {g.status.value:<10} {files:<8} {g.repo_path}"
            )
        return "\n".join(rows)
34
+
35
+
36
class IndexResult(BaseModel):
    """Outcome of the index command for a newly created graph."""

    graph_id: str
    name: str
    repo_path: str
    file_count: int
    node_count: int
    relationship_count: int

    def __str__(self) -> str:
        """Summarize a successful indexing run, one fact per line."""
        summary = [
            "Successfully indexed codebase!",
            f"Graph ID: {self.graph_id}",
            f"Files processed: {self.file_count}",
            f"Nodes: {self.node_count}",
            f"Relationships: {self.relationship_count}",
        ]
        return "\n".join(summary)
55
+
56
+
57
class DeleteResult(BaseModel):
    """Outcome of the delete command (deleted, failed, or cancelled)."""

    graph_id: str
    name: str
    deleted: bool
    cancelled: bool = False

    def __str__(self) -> str:
        """Describe the outcome; cancellation takes precedence over the flag."""
        if self.cancelled:
            return "Deletion cancelled."
        if not self.deleted:
            return f"Failed to delete codebase: {self.graph_id}"
        return f"Successfully deleted codebase: {self.graph_id}"
73
+
74
+
75
class InfoResult(BaseModel):
    """Outcome of the info command: full details for a single graph."""

    graph: CodebaseGraph

    def __str__(self) -> str:
        """Render graph metadata followed by any non-empty statistics sections."""
        g = self.graph
        out = [
            f"Graph ID: {g.graph_id}",
            f"Name: {g.name}",
            f"Status: {g.status.value}",
            f"Repository Path: {g.repo_path}",
            f"Database Path: {g.graph_path}",
            f"Created: {g.created_at}",
            f"Updated: {g.updated_at}",
            f"Schema Version: {g.schema_version}",
            f"Total Nodes: {g.node_count}",
            f"Total Relationships: {g.relationship_count}",
        ]

        # Each statistics section is emitted only when it has entries.
        if g.language_stats:
            out.append("\nLanguage Statistics:")
            out.extend(
                f" {lang}: {count} files" for lang, count in g.language_stats.items()
            )
        if g.node_stats:
            out.append("\nNode Statistics:")
            out.extend(
                f" {node_type}: {count}" for node_type, count in g.node_stats.items()
            )
        if g.relationship_stats:
            out.append("\nRelationship Statistics:")
            out.extend(
                f" {rel_type}: {count}"
                for rel_type, count in g.relationship_stats.items()
            )

        return "\n".join(out)
112
+
113
+
114
class QueryCommandResult(BaseModel):
    """Outcome of the query command, wrapping the raw QueryResult."""

    graph_name: str
    query_type: str
    result: QueryResult

    def __str__(self) -> str:
        """Render query timing, optional generated Cypher, and the result rows.

        Rows are shown as a ``|``-separated table (20-char columns) when
        column names are known, otherwise as a ``key: value`` list per row.
        """
        qr = self.result

        if not qr.success:
            return f"Query failed: {qr.error}"
        if not qr.results:
            return "No results found."

        out = [
            f"Query executed in {qr.execution_time_ms:.2f}ms",
            f"Results: {qr.row_count} rows",
        ]
        if qr.cypher_query:
            out.append(f"Generated Cypher: {qr.cypher_query}")
        out.append("")  # blank separator line before the rows

        columns = qr.column_names
        if columns:
            header = " | ".join(f"{col:<20}" for col in columns)
            out += [header, "-" * len(header)]
            out.extend(
                " | ".join(f"{str(row.get(col, '')):<20}" for col in columns)
                for row in qr.results
            )
        else:
            # No column metadata: list each row's key/value pairs instead.
            for idx, row in enumerate(qr.results, start=1):
                out.append(f"Row {idx}:")
                out.extend(f" {key}: {value}" for key, value in row.items())
                out.append("")

        return "\n".join(out)
161
+
162
+
163
class ReindexResult(BaseModel):
    """Outcome of the reindex command, with optional service statistics."""

    graph_id: str
    name: str
    stats: dict[str, Any] | None = None

    def __str__(self) -> str:
        """Confirm completion, appending the stats dict when it is non-empty."""
        if not self.stats:
            return "Reindexing completed!"
        return "\n".join(("Reindexing completed!", f"Stats: {self.stats}"))
176
+
177
+
178
class ErrorResult(BaseModel):
    """Generic failure result with an optional details line."""

    error_message: str
    details: str | None = None

    def __str__(self) -> str:
        """Render ``Error: <message>``, plus details on a new line if present."""
        head = f"Error: {self.error_message}"
        return f"{head}\n{self.details}" if self.details else head
@@ -0,0 +1,23 @@
1
+ """Service factory functions for SDK."""
2
+
3
+ from pathlib import Path
4
+
5
+ from shotgun.codebase.service import CodebaseService
6
+ from shotgun.utils import get_shotgun_home
7
+
8
+
9
+ def get_codebase_service(storage_dir: Path | str | None = None) -> CodebaseService:
10
+ """Get CodebaseService instance with configurable storage.
11
+
12
+ Args:
13
+ storage_dir: Optional custom storage directory.
14
+ Defaults to ~/.shotgun-sh/codebases/
15
+
16
+ Returns:
17
+ Configured CodebaseService instance
18
+ """
19
+ if storage_dir is None:
20
+ storage_dir = get_shotgun_home() / "codebases"
21
+ elif isinstance(storage_dir, str):
22
+ storage_dir = Path(storage_dir)
23
+ return CodebaseService(storage_dir)
shotgun/telemetry.py ADDED
@@ -0,0 +1,68 @@
1
+ """Phoenix AI observability setup for cloud and local deployment."""
2
+
3
+ import logging
4
+ import os
5
+
6
logger = logging.getLogger(__name__)


def setup_phoenix_observability() -> bool:
    """Set up Phoenix AI observability (OpenTelemetry tracing) if enabled.

    Tracing is configured only when all of the following hold:

    * ``PHOENIX_ENABLED`` is ``true``/``1``/``yes`` (case-insensitive);
    * both ``PHOENIX_COLLECTOR_ENDPOINT`` and ``PHOENIX_API_KEY`` are set;
    * the optional ``openinference`` / ``opentelemetry`` packages import.

    Spans are exported over OTLP/HTTP to the collector endpoint with a
    ``Bearer`` authorization header built from the API key. Because an API
    key is required, an unauthenticated local Phoenix collector is not
    configured by this function.

    Returns:
        True if tracing was configured, False otherwise. Failures are logged
        and never raised.
    """
    enabled_flag = os.getenv("PHOENIX_ENABLED", "false").strip().lower()
    if enabled_flag not in ("true", "1", "yes"):
        logger.debug("Phoenix AI observability disabled via PHOENIX_ENABLED env var")
        return False

    phoenix_collector_endpoint = os.getenv("PHOENIX_COLLECTOR_ENDPOINT")
    phoenix_api_key = os.getenv("PHOENIX_API_KEY")
    if not phoenix_collector_endpoint or not phoenix_api_key:
        # The feature was explicitly enabled, so say why nothing happens
        # instead of returning silently.
        missing = [
            var_name
            for var_name, value in (
                ("PHOENIX_COLLECTOR_ENDPOINT", phoenix_collector_endpoint),
                ("PHOENIX_API_KEY", phoenix_api_key),
            )
            if not value
        ]
        logger.warning(
            "Phoenix AI observability enabled but %s not set; skipping setup",
            ", ".join(missing),
        )
        return False

    try:
        from openinference.instrumentation.pydantic_ai import (
            OpenInferenceSpanProcessor,
        )
        from opentelemetry import trace
        from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
            OTLPSpanExporter,
        )
        from opentelemetry.sdk.trace import TracerProvider
        from opentelemetry.sdk.trace.export import BatchSpanProcessor
    except ImportError as e:
        logger.warning("Phoenix AI not available: %s", e)
        return False

    try:
        # Cloud Phoenix setup (Arize)
        logger.debug("Setting up cloud Phoenix AI observability")
        tracer_provider = TracerProvider()

        # Phoenix cloud expects an Authorization header, not api_key
        otlp_exporter = OTLPSpanExporter(
            endpoint=phoenix_collector_endpoint,
            headers={"authorization": f"Bearer {phoenix_api_key}"},
        )

        # OpenInference adds LLM semantics; BatchSpanProcessor handles export.
        tracer_provider.add_span_processor(OpenInferenceSpanProcessor())
        tracer_provider.add_span_processor(BatchSpanProcessor(otlp_exporter))

        # Install globally only once fully configured, so a failure above
        # never leaves a half-built provider as the process-wide default.
        trace.set_tracer_provider(tracer_provider)
    except Exception as e:
        logger.warning("Failed to setup Phoenix AI observability: %s", e)
        return False

    logger.debug("Cloud Phoenix AI observability configured successfully")
    logger.debug("Endpoint: %s", phoenix_collector_endpoint)
    return True
File without changes
shotgun/tui/app.py ADDED
@@ -0,0 +1,49 @@
1
+ from textual.app import App
2
+ from textual.binding import Binding
3
+
4
+ from shotgun.agents.config import ConfigManager, get_config_manager
5
+ from shotgun.tui.screens.splash import SplashScreen
6
+
7
+ from .screens.chat import ChatScreen
8
+ from .screens.provider_config import ProviderConfigScreen
9
+
10
+
11
class ShotgunApp(App[None]):
    """Textual application for the Shotgun TUI.

    Shows a splash screen first, then routes to the chat screen when any
    provider key is configured, or to the provider configuration screen
    otherwise.
    """

    SCREENS = {"chat": ChatScreen, "provider_config": ProviderConfigScreen}
    BINDINGS = [
        Binding("ctrl+c", "quit", "Quit the app"),
    ]
    CSS_PATH = "styles.tcss"

    def __init__(self) -> None:
        super().__init__()
        # Consulted to decide which screen follows the splash.
        self.config_manager: ConfigManager = get_config_manager()

    def on_mount(self) -> None:
        """Show the splash screen; choose the start screen once it closes."""

        def _after_splash(_result: object) -> None:
            self.refresh_startup_screen()

        self.push_screen(SplashScreen(), callback=_after_splash)

    def refresh_startup_screen(self) -> None:
        """Push the chat or provider-config screen based on configured keys.

        Does nothing if the appropriate screen is already active. When the
        provider-config screen is pushed, this method runs again after it is
        dismissed, so chat is shown once a provider key is configured.
        """
        has_provider = self.config_manager.has_any_provider_key()
        wanted_screen = ChatScreen if has_provider else ProviderConfigScreen
        if isinstance(self.screen, wanted_screen):
            return

        if has_provider:
            self.push_screen("chat")
        else:
            self.push_screen(
                "provider_config", callback=lambda _arg: self.refresh_startup_screen()
            )
41
+
42
+
43
def run() -> None:
    """Create a ShotgunApp and run it with ``inline_no_clear=True``.

    NOTE(review): Textual documents ``inline_no_clear`` as applying to inline
    mode, which is not enabled here (``inline`` is left at its default) —
    confirm whether inline mode was intended.
    """
    ShotgunApp().run(inline_no_clear=True)


if __name__ == "__main__":
    run()
@@ -0,0 +1,69 @@
1
+ from textual import events
2
+ from textual.message import Message
3
+ from textual.widgets import TextArea
4
+
5
+
6
class PromptInput(TextArea):
    """A TextArea that submits its contents when Enter is pressed.

    Enter posts a ``PromptInput.Submitted`` message carrying the current text
    instead of inserting a newline; ``ctrl+j`` inserts a newline.
    """

    DEFAULT_CSS = """
PromptInput {
outline: round $primary;
background: transparent;
}
"""

    def check_action(self, action: str, parameters: tuple[object, ...]) -> bool:
        """Allow every action, but enable ``copy`` only when text is selected.

        With no selection ``copy`` is disabled, so a global ctrl+c binding
        still works.
        """
        if action != "copy":
            return True
        return bool(self.selected_text)

    class Submitted(Message):
        """A message to indicate that the text has been submitted."""

        def __init__(self, text: str) -> None:
            super().__init__()
            self.text = text

    def action_submit(self) -> None:
        """Post a ``Submitted`` message with the current text."""
        self.post_message(self.Submitted(self.text))

    async def _on_key(self, event: events.Key) -> None:
        """Handle key presses: submit on Enter, insert text for other keys.

        The insertion logic follows Textual's TextArea key handling, except
        that ``ctrl+j`` (not Enter) inserts a newline. Every key consumed here
        is stopped and has ``prevent_default()`` called so the base TextArea
        handler does not process it a second time.
        """
        self._restart_blink()

        if event.key == "enter":
            # Without prevent_default() Textual would also run the base
            # TextArea key handler, which inserts "\n" for Enter right after
            # the submission.
            event.stop()
            event.prevent_default()
            self.action_submit()
            return

        if self.read_only:
            return

        key = event.key
        insert_values = {
            "ctrl+j": "\n",
        }
        if self.tab_behavior == "indent":
            if key == "escape":
                # Escape leaves the widget, since Tab is captured for indentation.
                event.stop()
                event.prevent_default()
                self.screen.focus_next()
                return
            if self.indent_type == "tabs":
                insert_values["tab"] = "\t"
            else:
                insert_values["tab"] = " " * self._find_columns_to_next_tab_stop()

        if event.is_printable or key in insert_values:
            event.stop()
            event.prevent_default()
            insert = insert_values.get(key, event.character)
            # `insert` is not None because event.character cannot be
            # None because we've checked that it's printable.
            assert insert is not None  # noqa: S101
            start, end = self.selection
            self._replace_via_keyboard(insert, start, end)
@@ -0,0 +1,86 @@
1
+ """Spinner component for showing loading/working state."""
2
+
3
+ from textual.app import ComposeResult
4
+ from textual.containers import Container
5
+ from textual.css.query import NoMatches
6
+ from textual.reactive import reactive
7
+ from textual.timer import Timer
8
+ from textual.widget import Widget
9
+ from textual.widgets import Static
10
+
11
+
12
class Spinner(Widget):
    """A one-line widget showing an animated braille spinner next to a label.

    The icon advances one frame every 0.1 s once mounted; assigning ``text``
    updates the label.
    """

    DEFAULT_CSS = """
Spinner {
width: auto;
height: 1;
}

Spinner > Container {
width: auto;
height: 1;
layout: horizontal;
}

Spinner .spinner-icon {
width: 1;
margin-right: 1;
}

Spinner .spinner-text {
width: auto;
}
"""

    # Animation frames for the spinner
    FRAMES = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]

    text = reactive("Working...")
    _frame_index = reactive(0)

    def __init__(
        self,
        text: str = "Working...",
        *,
        name: str | None = None,
        id: str | None = None,
        classes: str | None = None,
    ) -> None:
        super().__init__(name=name, id=id, classes=classes)
        self.text = text
        # Handle to the animation interval, set in on_mount.
        self._timer: Timer | None = None

    def compose(self) -> ComposeResult:
        """Compose the icon and label side by side.

        The icon starts on the current frame so it is visible immediately,
        rather than blank until the first timer tick.
        """
        with Container():
            yield Static(self.FRAMES[self._frame_index], classes="spinner-icon")
            yield Static(self.text, classes="spinner-text")

    def on_mount(self) -> None:
        """Start the animation timer (0.1 s per frame) when mounted."""
        self._timer = self.set_interval(0.1, self._advance_frame)

    def _advance_frame(self) -> None:
        """Advance to the next animation frame, wrapping around at the end."""
        self._frame_index = (self._frame_index + 1) % len(self.FRAMES)
        self._update_display()

    def _update_display(self) -> None:
        """Show the current frame in the icon widget, if it exists yet."""
        try:
            icon_widget = self.query_one(".spinner-icon", Static)
        except NoMatches:
            # Child widgets not composed yet; nothing to update.
            return
        icon_widget.update(self.FRAMES[self._frame_index])

    def watch_text(self, text: str) -> None:
        """Mirror changes to ``text`` into the label widget, if it exists yet."""
        try:
            text_widget = self.query_one(".spinner-text", Static)
        except NoMatches:
            # Child widgets not composed yet; compose() uses self.text.
            return
        text_widget.update(text)