shotgun-sh 0.2.23.dev1__py3-none-any.whl → 0.2.29.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of shotgun-sh might be problematic.
- shotgun/agents/agent_manager.py +3 -3
- shotgun/agents/common.py +1 -1
- shotgun/agents/config/manager.py +36 -21
- shotgun/agents/config/models.py +30 -0
- shotgun/agents/config/provider.py +27 -14
- shotgun/agents/context_analyzer/analyzer.py +6 -2
- shotgun/agents/conversation/__init__.py +18 -0
- shotgun/agents/conversation/filters.py +164 -0
- shotgun/agents/conversation/history/chunking.py +278 -0
- shotgun/agents/{history → conversation/history}/compaction.py +27 -1
- shotgun/agents/{history → conversation/history}/constants.py +5 -0
- shotgun/agents/conversation/history/file_content_deduplication.py +216 -0
- shotgun/agents/{history → conversation/history}/history_processors.py +267 -3
- shotgun/agents/{conversation_manager.py → conversation/manager.py} +1 -1
- shotgun/agents/{conversation_history.py → conversation/models.py} +8 -94
- shotgun/agents/tools/web_search/openai.py +1 -1
- shotgun/cli/clear.py +1 -1
- shotgun/cli/compact.py +5 -3
- shotgun/cli/context.py +1 -1
- shotgun/cli/spec/__init__.py +5 -0
- shotgun/cli/spec/backup.py +81 -0
- shotgun/cli/spec/commands.py +130 -0
- shotgun/cli/spec/models.py +30 -0
- shotgun/cli/spec/pull_service.py +165 -0
- shotgun/codebase/core/ingestor.py +153 -7
- shotgun/codebase/models.py +2 -0
- shotgun/exceptions.py +5 -3
- shotgun/main.py +2 -0
- shotgun/posthog_telemetry.py +1 -1
- shotgun/prompts/agents/partials/common_agent_system_prompt.j2 +3 -3
- shotgun/prompts/agents/partials/interactive_mode.j2 +3 -3
- shotgun/prompts/agents/research.j2 +0 -3
- shotgun/prompts/history/chunk_summarization.j2 +34 -0
- shotgun/prompts/history/combine_summaries.j2 +53 -0
- shotgun/shotgun_web/__init__.py +67 -1
- shotgun/shotgun_web/client.py +42 -1
- shotgun/shotgun_web/constants.py +46 -0
- shotgun/shotgun_web/exceptions.py +29 -0
- shotgun/shotgun_web/models.py +390 -0
- shotgun/shotgun_web/shared_specs/__init__.py +32 -0
- shotgun/shotgun_web/shared_specs/file_scanner.py +175 -0
- shotgun/shotgun_web/shared_specs/hasher.py +83 -0
- shotgun/shotgun_web/shared_specs/models.py +71 -0
- shotgun/shotgun_web/shared_specs/upload_pipeline.py +291 -0
- shotgun/shotgun_web/shared_specs/utils.py +34 -0
- shotgun/shotgun_web/specs_client.py +703 -0
- shotgun/shotgun_web/supabase_client.py +31 -0
- shotgun/tui/app.py +39 -0
- shotgun/tui/containers.py +1 -1
- shotgun/tui/layout.py +5 -0
- shotgun/tui/screens/chat/chat_screen.py +212 -16
- shotgun/tui/screens/chat/codebase_index_prompt_screen.py +147 -19
- shotgun/tui/screens/chat_screen/command_providers.py +10 -0
- shotgun/tui/screens/chat_screen/history/chat_history.py +0 -36
- shotgun/tui/screens/confirmation_dialog.py +40 -0
- shotgun/tui/screens/model_picker.py +7 -1
- shotgun/tui/screens/onboarding.py +149 -0
- shotgun/tui/screens/pipx_migration.py +46 -0
- shotgun/tui/screens/provider_config.py +41 -0
- shotgun/tui/screens/shared_specs/__init__.py +21 -0
- shotgun/tui/screens/shared_specs/create_spec_dialog.py +273 -0
- shotgun/tui/screens/shared_specs/models.py +56 -0
- shotgun/tui/screens/shared_specs/share_specs_dialog.py +390 -0
- shotgun/tui/screens/shared_specs/upload_progress_screen.py +452 -0
- shotgun/tui/screens/shotgun_auth.py +60 -6
- shotgun/tui/screens/spec_pull.py +286 -0
- shotgun/tui/screens/welcome.py +91 -0
- shotgun/tui/services/conversation_service.py +5 -2
- shotgun/tui/widgets/widget_coordinator.py +1 -1
- {shotgun_sh-0.2.23.dev1.dist-info → shotgun_sh-0.2.29.dev2.dist-info}/METADATA +1 -1
- {shotgun_sh-0.2.23.dev1.dist-info → shotgun_sh-0.2.29.dev2.dist-info}/RECORD +86 -59
- {shotgun_sh-0.2.23.dev1.dist-info → shotgun_sh-0.2.29.dev2.dist-info}/WHEEL +1 -1
- /shotgun/agents/{history → conversation/history}/__init__.py +0 -0
- /shotgun/agents/{history → conversation/history}/context_extraction.py +0 -0
- /shotgun/agents/{history → conversation/history}/history_building.py +0 -0
- /shotgun/agents/{history → conversation/history}/message_utils.py +0 -0
- /shotgun/agents/{history → conversation/history}/token_counting/__init__.py +0 -0
- /shotgun/agents/{history → conversation/history}/token_counting/anthropic.py +0 -0
- /shotgun/agents/{history → conversation/history}/token_counting/base.py +0 -0
- /shotgun/agents/{history → conversation/history}/token_counting/openai.py +0 -0
- /shotgun/agents/{history → conversation/history}/token_counting/sentencepiece_counter.py +0 -0
- /shotgun/agents/{history → conversation/history}/token_counting/tokenizer_cache.py +0 -0
- /shotgun/agents/{history → conversation/history}/token_counting/utils.py +0 -0
- /shotgun/agents/{history → conversation/history}/token_estimation.py +0 -0
- {shotgun_sh-0.2.23.dev1.dist-info → shotgun_sh-0.2.29.dev2.dist-info}/entry_points.txt +0 -0
- {shotgun_sh-0.2.23.dev1.dist-info → shotgun_sh-0.2.29.dev2.dist-info}/licenses/LICENSE +0 -0
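The core of this release is a package reorganization: shotgun/agents/history and shotgun/agents/conversation_history.py are consolidated under a new shotgun/agents/conversation package. For downstream code, the rename entries above translate to an import migration roughly like the following sketch (old paths inferred from the rename list; not an exhaustive mapping):

# 0.2.23.dev1
from shotgun.agents.history import token_limit_compactor
from shotgun.agents.history.compaction import apply_persistent_compaction

# 0.2.29.dev2
from shotgun.agents.conversation.history import token_limit_compactor
from shotgun.agents.conversation.history.compaction import apply_persistent_compaction
from shotgun.agents.conversation import ConversationManager, ConversationState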
shotgun/agents/agent_manager.py
CHANGED
@@ -17,7 +17,7 @@ from tenacity import (
 )
 
 if TYPE_CHECKING:
-    from shotgun.agents.
+    from shotgun.agents.conversation import ConversationState
 
 from pydantic_ai import (
     Agent,
@@ -68,8 +68,8 @@ from shotgun.posthog_telemetry import track_event
 from shotgun.tui.screens.chat_screen.hint_message import HintMessage
 from shotgun.utils.source_detection import detect_source
 
+from .conversation.history.compaction import apply_persistent_compaction
 from .export import create_export_agent
-from .history.compaction import apply_persistent_compaction
 from .messages import AgentSystemPrompt
 from .models import AgentDeps, AgentRuntimeOptions
 from .plan import create_plan_agent
@@ -1314,7 +1314,7 @@ class AgentManager(Widget):
        Returns:
            ConversationState object containing UI and agent messages and current type
        """
-        from shotgun.agents.
+        from shotgun.agents.conversation import ConversationState
 
        return ConversationState(
            agent_messages=self.message_history.copy(),
shotgun/agents/common.py
CHANGED
@@ -25,7 +25,7 @@ from shotgun.utils import ensure_shotgun_directory_exists
 from shotgun.utils.datetime_utils import get_datetime_context
 from shotgun.utils.file_system_utils import get_shotgun_base_path
 
-from .history import token_limit_compactor
+from .conversation.history import token_limit_compactor
 from .messages import AgentSystemPrompt, SystemStatusPrompt
 from .models import AgentDeps, AgentRuntimeOptions, PipelineConfigEntry
 from .tools import (
shotgun/agents/config/manager.py
CHANGED
@@ -307,29 +307,41 @@ class ConfigManager:
            # Convert plain text secrets to SecretStr objects
            self._convert_secrets_to_secretstr(data)
 
+            # Clean up invalid selected_model before Pydantic validation
+            if "selected_model" in data and data["selected_model"] is not None:
+                from .models import MODEL_SPECS, ModelName
+
+                try:
+                    # Try to convert to ModelName enum
+                    model_name = ModelName(data["selected_model"])
+                    # Check if it exists in MODEL_SPECS
+                    if model_name not in MODEL_SPECS:
+                        data["selected_model"] = None
+                except (ValueError, KeyError):
+                    # Invalid model name - reset to None
+                    data["selected_model"] = None
+
            self._config = ShotgunConfig.model_validate(data)
            logger.debug("Configuration loaded successfully from %s", self.config_path)
 
-            #
-
-
+            # Clear migration_failed flag if config loaded successfully
+            should_save = False
+            if self._config.migration_failed:
+                self._config.migration_failed = False
+                self._config.migration_backup_path = None
+                should_save = True
 
+            # Validate selected_model for BYOK mode - verify provider has a key
+            if not self._provider_has_api_key(self._config.shotgun):
                # If selected_model is set, verify its provider has a key
                if self._config.selected_model:
                    from .models import MODEL_SPECS
 
-
-
-
-                    logger.info(
-                        "Selected model %s provider has no API key, finding available model",
-                        self._config.selected_model.value,
-                    )
-                    self._config.selected_model = None
-                    should_save = True
-                else:
+                    spec = MODEL_SPECS[self._config.selected_model]
+                    if not await self.has_provider_key(spec.provider):
+                        # Provider has no key - reset to None
                        logger.info(
-                            "Selected model %s
+                            "Selected model %s provider has no API key, finding available model",
                            self._config.selected_model.value,
                        )
                        self._config.selected_model = None
@@ -344,17 +356,13 @@ class ConfigManager:
 
            # Find default model for this provider
            provider_models = {
-                ProviderType.OPENAI: ModelName.
+                ProviderType.OPENAI: ModelName.GPT_5_1,
                ProviderType.ANTHROPIC: ModelName.CLAUDE_HAIKU_4_5,
                ProviderType.GOOGLE: ModelName.GEMINI_2_5_PRO,
            }
 
            if provider in provider_models:
                self._config.selected_model = provider_models[provider]
-                logger.info(
-                    "Set selected_model to %s (first available provider)",
-                    self._config.selected_model.value,
-                )
                should_save = True
                break
 
@@ -498,7 +506,7 @@ class ConfigManager:
        from .models import ModelName
 
        provider_models = {
-            ProviderType.OPENAI: ModelName.
+            ProviderType.OPENAI: ModelName.GPT_5_1,
            ProviderType.ANTHROPIC: ModelName.CLAUDE_HAIKU_4_5,
            ProviderType.GOOGLE: ModelName.GEMINI_2_5_PRO,
        }
@@ -736,13 +744,17 @@ class ConfigManager:
        return config.shotgun_instance_id
 
    async def update_shotgun_account(
-        self,
+        self,
+        api_key: str | None = None,
+        supabase_jwt: str | None = None,
+        workspace_id: str | None = None,
    ) -> None:
        """Update Shotgun Account configuration.
 
        Args:
            api_key: LiteLLM proxy API key (optional)
            supabase_jwt: Supabase authentication JWT (optional)
+            workspace_id: Default workspace ID for shared specs (optional)
        """
        config = await self.load()
 
@@ -754,6 +766,9 @@ class ConfigManager:
            SecretStr(supabase_jwt) if supabase_jwt else None
        )
 
+        if workspace_id is not None:
+            config.shotgun.workspace_id = workspace_id
+
        await self.save(config)
        logger.info("Updated Shotgun Account configuration")
 
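The net effect of the load-path changes above: a selected_model persisted by an older release can no longer break ShotgunConfig.model_validate, and a BYOK selection whose provider key was removed is reset rather than kept stale. A minimal standalone sketch of the sanitize-before-validate step, with MODEL_SPECS and ModelName standing in for the real objects in shotgun.agents.config.models:

def sanitize_selected_model(data: dict) -> None:
    # Downgrade unknown or retired model names to None so Pydantic
    # validation succeeds and a default can be picked on the next save.
    raw = data.get("selected_model")
    if raw is None:
        return
    try:
        if ModelName(raw) not in MODEL_SPECS:
            data["selected_model"] = None
    except (ValueError, KeyError):
        data["selected_model"] = None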
shotgun/agents/config/models.py
CHANGED
@@ -27,6 +27,9 @@ class ModelName(StrEnum):
 
    GPT_5 = "gpt-5"
    GPT_5_MINI = "gpt-5-mini"
+    GPT_5_1 = "gpt-5.1"
+    GPT_5_1_CODEX = "gpt-5.1-codex"
+    GPT_5_1_CODEX_MINI = "gpt-5.1-codex-mini"
    CLAUDE_OPUS_4_1 = "claude-opus-4-1"
    CLAUDE_SONNET_4_5 = "claude-sonnet-4-5"
    CLAUDE_HAIKU_4_5 = "claude-haiku-4-5"
@@ -114,6 +117,30 @@ MODEL_SPECS: dict[ModelName, ModelSpec] = {
        litellm_proxy_model_name="openai/gpt-5-mini",
        short_name="GPT-5 Mini",
    ),
+    ModelName.GPT_5_1: ModelSpec(
+        name=ModelName.GPT_5_1,
+        provider=ProviderType.OPENAI,
+        max_input_tokens=272_000,
+        max_output_tokens=128_000,
+        litellm_proxy_model_name="openai/gpt-5.1",
+        short_name="GPT-5.1",
+    ),
+    ModelName.GPT_5_1_CODEX: ModelSpec(
+        name=ModelName.GPT_5_1_CODEX,
+        provider=ProviderType.OPENAI,
+        max_input_tokens=272_000,
+        max_output_tokens=128_000,
+        litellm_proxy_model_name="openai/gpt-5.1-codex",
+        short_name="GPT-5.1 Codex",
+    ),
+    ModelName.GPT_5_1_CODEX_MINI: ModelSpec(
+        name=ModelName.GPT_5_1_CODEX_MINI,
+        provider=ProviderType.OPENAI,
+        max_input_tokens=272_000,
+        max_output_tokens=128_000,
+        litellm_proxy_model_name="openai/gpt-5.1-codex-mini",
+        short_name="GPT-5.1 Codex Mini",
+    ),
    ModelName.CLAUDE_OPUS_4_1: ModelSpec(
        name=ModelName.CLAUDE_OPUS_4_1,
        provider=ProviderType.ANTHROPIC,
@@ -186,6 +213,9 @@ class ShotgunAccountConfig(BaseModel):
    supabase_jwt: SecretStr | None = Field(
        default=None, description="Supabase authentication JWT"
    )
+    workspace_id: str | None = Field(
+        default=None, description="Default workspace ID for shared specs"
+    )
 
 
 class MarketingMessageRecord(BaseModel):
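Assuming the MODEL_SPECS entries added above, the new GPT-5.1 family resolves through the same lookup as every other model; for example:

from shotgun.agents.config.models import MODEL_SPECS, ModelName

spec = MODEL_SPECS[ModelName.GPT_5_1_CODEX]
assert spec.litellm_proxy_model_name == "openai/gpt-5.1-codex"
assert spec.max_input_tokens == 272_000   # shared across the GPT-5.1 family
assert spec.max_output_tokens == 128_000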
shotgun/agents/config/provider.py
CHANGED
@@ -47,13 +47,13 @@ def get_default_model_for_provider(config: ShotgunConfig) -> ModelName:
    """
    # Priority 1: Shotgun Account
    if _get_api_key(config.shotgun.api_key):
-        return ModelName.
+        return ModelName.GPT_5_1
 
    # Priority 2: Individual provider keys
    if _get_api_key(config.anthropic.api_key):
        return ModelName.CLAUDE_HAIKU_4_5
    if _get_api_key(config.openai.api_key):
-        return ModelName.
+        return ModelName.GPT_5_1
    if _get_api_key(config.google.api_key):
        return ModelName.GEMINI_2_5_PRO
 
@@ -201,10 +201,12 @@ async def get_provider_model(
        model_name = provider_or_model
    else:
        # No specific model requested - use selected or default
-        model_name = config.selected_model or
+        model_name = config.selected_model or get_default_model_for_provider(config)
 
+    # Gracefully fall back if the selected model doesn't exist (backwards compatibility)
    if model_name not in MODEL_SPECS:
-
+        model_name = get_default_model_for_provider(config)
+
    spec = MODEL_SPECS[model_name]
 
    # Use Shotgun Account with determined model (provider = actual LLM provider)
@@ -225,10 +227,12 @@ async def get_provider_model(
    if isinstance(provider_or_model, ModelName):
        # Look up the model spec
        if provider_or_model not in MODEL_SPECS:
-
-
-
-
+            requested_model = None  # Fall back to provider default
+            provider_enum = None  # Will be determined below
+        else:
+            spec = MODEL_SPECS[provider_or_model]
+            provider_enum = spec.provider
+            requested_model = provider_or_model
    else:
        # Convert string to ProviderType enum if needed (backward compatible)
        if provider_or_model:
@@ -257,15 +261,22 @@ async def get_provider_model(
        if not api_key:
            raise ValueError("OpenAI API key not configured. Set via config.")
 
-        # Use requested model or default to gpt-5
-        model_name = requested_model if requested_model else ModelName.
+        # Use requested model or default to gpt-5.1
+        model_name = requested_model if requested_model else ModelName.GPT_5_1
+        # Gracefully fall back if model doesn't exist
        if model_name not in MODEL_SPECS:
-
+            model_name = ModelName.GPT_5_1
        spec = MODEL_SPECS[model_name]
 
        # Check and test streaming capability for GPT-5 family models
        supports_streaming = True  # Default to True for all models
-        if model_name in (
+        if model_name in (
+            ModelName.GPT_5,
+            ModelName.GPT_5_MINI,
+            ModelName.GPT_5_1,
+            ModelName.GPT_5_1_CODEX,
+            ModelName.GPT_5_1_CODEX_MINI,
+        ):
            # Check if streaming capability has been tested
            streaming_capability = config.openai.supports_streaming
 
@@ -304,8 +315,9 @@ async def get_provider_model(
 
        # Use requested model or default to claude-haiku-4-5
        model_name = requested_model if requested_model else ModelName.CLAUDE_HAIKU_4_5
+        # Gracefully fall back if model doesn't exist
        if model_name not in MODEL_SPECS:
-
+            model_name = ModelName.CLAUDE_HAIKU_4_5
        spec = MODEL_SPECS[model_name]
 
        # Create fully configured ModelConfig
@@ -325,8 +337,9 @@ async def get_provider_model(
 
        # Use requested model or default to gemini-2.5-pro
        model_name = requested_model if requested_model else ModelName.GEMINI_2_5_PRO
+        # Gracefully fall back if model doesn't exist
        if model_name not in MODEL_SPECS:
-
+            model_name = ModelName.GEMINI_2_5_PRO
        spec = MODEL_SPECS[model_name]
 
        # Create fully configured ModelConfig
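Taken together, the provider.py hunks make model resolution self-healing: a model name missing from MODEL_SPECS now falls back to the per-provider default instead of raising a KeyError at lookup time. Condensed from the OpenAI branch above (surrounding plumbing elided; the Anthropic and Google branches mirror it with CLAUDE_HAIKU_4_5 and GEMINI_2_5_PRO as their defaults):

model_name = requested_model if requested_model else ModelName.GPT_5_1
if model_name not in MODEL_SPECS:   # e.g. a model retired between releases
    model_name = ModelName.GPT_5_1  # fall back to the provider default
spec = MODEL_SPECS[model_name]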
shotgun/agents/context_analyzer/analyzer.py
CHANGED
@@ -15,8 +15,12 @@ from pydantic_ai.messages import (
 )
 
 from shotgun.agents.config.models import ModelConfig
-from shotgun.agents.history.token_counting.utils import
-
+from shotgun.agents.conversation.history.token_counting.utils import (
+    count_tokens_from_messages,
+)
+from shotgun.agents.conversation.history.token_estimation import (
+    estimate_tokens_from_messages,
+)
 from shotgun.agents.messages import AgentSystemPrompt, SystemStatusPrompt
 from shotgun.logging_config import get_logger
 from shotgun.tui.screens.chat_screen.hint_message import HintMessage
shotgun/agents/conversation/__init__.py
ADDED
@@ -0,0 +1,18 @@
+"""Conversation module for managing conversation history and persistence."""
+
+from .filters import (
+    filter_incomplete_messages,
+    filter_orphaned_tool_responses,
+    is_tool_call_complete,
+)
+from .manager import ConversationManager
+from .models import ConversationHistory, ConversationState
+
+__all__ = [
+    "ConversationHistory",
+    "ConversationManager",
+    "ConversationState",
+    "filter_incomplete_messages",
+    "filter_orphaned_tool_responses",
+    "is_tool_call_complete",
+]
shotgun/agents/conversation/filters.py
ADDED
@@ -0,0 +1,164 @@
+"""Filter functions for conversation message validation."""
+
+import json
+import logging
+
+from pydantic_ai.messages import (
+    ModelMessage,
+    ModelRequest,
+    ModelRequestPart,
+    ModelResponse,
+    ToolCallPart,
+    ToolReturnPart,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def is_tool_call_complete(tool_call: ToolCallPart) -> bool:
+    """Check if a tool call has valid, complete JSON arguments.
+
+    Args:
+        tool_call: The tool call part to validate
+
+    Returns:
+        True if the tool call args are valid JSON, False otherwise
+    """
+    if tool_call.args is None:
+        return True  # No args is valid
+
+    if isinstance(tool_call.args, dict):
+        return True  # Already parsed dict is valid
+
+    if not isinstance(tool_call.args, str):
+        return False
+
+    # Try to parse the JSON string
+    try:
+        json.loads(tool_call.args)
+        return True
+    except (json.JSONDecodeError, ValueError) as e:
+        # Log incomplete tool call detection
+        args_preview = (
+            tool_call.args[:100] + "..."
+            if len(tool_call.args) > 100
+            else tool_call.args
+        )
+        logger.info(
+            "Detected incomplete tool call in validation",
+            extra={
+                "tool_name": tool_call.tool_name,
+                "tool_call_id": tool_call.tool_call_id,
+                "args_preview": args_preview,
+                "error": str(e),
+            },
+        )
+        return False
+
+
+def filter_incomplete_messages(messages: list[ModelMessage]) -> list[ModelMessage]:
+    """Filter out messages with incomplete tool calls.
+
+    Args:
+        messages: List of messages to filter
+
+    Returns:
+        List of messages with only complete tool calls
+    """
+    filtered: list[ModelMessage] = []
+    filtered_count = 0
+    filtered_tool_names: list[str] = []
+
+    for message in messages:
+        # Only check ModelResponse messages for tool calls
+        if not isinstance(message, ModelResponse):
+            filtered.append(message)
+            continue
+
+        # Check if any tool calls are incomplete
+        has_incomplete_tool_call = False
+        for part in message.parts:
+            if isinstance(part, ToolCallPart) and not is_tool_call_complete(part):
+                has_incomplete_tool_call = True
+                filtered_tool_names.append(part.tool_name)
+                break
+
+        # Only include messages without incomplete tool calls
+        if not has_incomplete_tool_call:
+            filtered.append(message)
+        else:
+            filtered_count += 1
+
+    # Log if any messages were filtered
+    if filtered_count > 0:
+        logger.info(
+            "Filtered incomplete messages before saving",
+            extra={
+                "filtered_count": filtered_count,
+                "total_messages": len(messages),
+                "filtered_tool_names": filtered_tool_names,
+            },
+        )
+
+    return filtered
+
+
+def filter_orphaned_tool_responses(messages: list[ModelMessage]) -> list[ModelMessage]:
+    """Filter out tool responses without corresponding tool calls.
+
+    This ensures message history is valid for OpenAI API which requires
+    tool responses to follow their corresponding tool calls.
+
+    Args:
+        messages: List of messages to filter
+
+    Returns:
+        List of messages with orphaned tool responses removed
+    """
+    # Collect all tool_call_ids from ToolCallPart in ModelResponse
+    valid_tool_call_ids: set[str] = set()
+    for msg in messages:
+        if isinstance(msg, ModelResponse):
+            for part in msg.parts:
+                if isinstance(part, ToolCallPart) and part.tool_call_id:
+                    valid_tool_call_ids.add(part.tool_call_id)
+
+    # Filter out orphaned ToolReturnPart from ModelRequest
+    filtered: list[ModelMessage] = []
+    orphaned_count = 0
+    orphaned_tool_names: list[str] = []
+
+    for msg in messages:
+        if isinstance(msg, ModelRequest):
+            # Filter parts, removing orphaned ToolReturnPart
+            filtered_parts: list[ModelRequestPart] = []
+            request_part: ModelRequestPart
+            for request_part in msg.parts:
+                if isinstance(request_part, ToolReturnPart):
+                    if request_part.tool_call_id in valid_tool_call_ids:
+                        filtered_parts.append(request_part)
+                    else:
+                        # Skip orphaned tool response
+                        orphaned_count += 1
+                        orphaned_tool_names.append(request_part.tool_name or "unknown")
+                else:
+                    filtered_parts.append(request_part)
+
+            # Only add if there are remaining parts
+            if filtered_parts:
+                filtered.append(ModelRequest(parts=filtered_parts))
+        else:
+            filtered.append(msg)
+
+    # Log if any tool responses were filtered
+    if orphaned_count > 0:
+        logger.info(
+            "Filtered orphaned tool responses",
+            extra={
+                "orphaned_count": orphaned_count,
+                "total_messages": len(messages),
+                "orphaned_tool_names": orphaned_tool_names,
+            },
+        )
+
+    return filtered
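A usage sketch for the new filters (the message constructors are from pydantic_ai.messages; the truncated-args scenario is illustrative): an interrupted stream can leave a ToolCallPart whose args string is cut off mid-JSON, and persisting it would corrupt the saved history.

from pydantic_ai.messages import ModelResponse, ToolCallPart

from shotgun.agents.conversation import filter_incomplete_messages, is_tool_call_complete

# A tool call whose JSON arguments were truncated mid-stream
truncated = ModelResponse(
    parts=[
        ToolCallPart(
            tool_name="read_file",
            args='{"path": "src/ma',  # invalid JSON: cut off mid-value
            tool_call_id="call_1",
        )
    ]
)

assert not is_tool_call_complete(truncated.parts[0])
assert filter_incomplete_messages([truncated]) == []  # dropped before saving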