dtSpark 1.1.0a3__py3-none-any.whl → 1.1.0a6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. dtSpark/_version.txt +1 -1
  2. dtSpark/aws/authentication.py +1 -1
  3. dtSpark/aws/bedrock.py +238 -239
  4. dtSpark/aws/costs.py +9 -5
  5. dtSpark/aws/pricing.py +25 -21
  6. dtSpark/cli_interface.py +69 -62
  7. dtSpark/conversation_manager.py +54 -47
  8. dtSpark/core/application.py +112 -91
  9. dtSpark/core/context_compaction.py +241 -226
  10. dtSpark/daemon/__init__.py +36 -22
  11. dtSpark/daemon/action_monitor.py +46 -17
  12. dtSpark/daemon/daemon_app.py +126 -104
  13. dtSpark/daemon/daemon_manager.py +59 -23
  14. dtSpark/daemon/pid_file.py +3 -2
  15. dtSpark/database/autonomous_actions.py +3 -0
  16. dtSpark/database/credential_prompt.py +52 -54
  17. dtSpark/files/manager.py +6 -12
  18. dtSpark/limits/__init__.py +1 -1
  19. dtSpark/limits/tokens.py +2 -2
  20. dtSpark/llm/anthropic_direct.py +246 -141
  21. dtSpark/llm/ollama.py +3 -1
  22. dtSpark/mcp_integration/manager.py +4 -4
  23. dtSpark/mcp_integration/tool_selector.py +83 -77
  24. dtSpark/resources/config.yaml.template +10 -0
  25. dtSpark/safety/patterns.py +45 -46
  26. dtSpark/safety/prompt_inspector.py +8 -1
  27. dtSpark/scheduler/creation_tools.py +273 -181
  28. dtSpark/scheduler/executor.py +503 -221
  29. dtSpark/tools/builtin.py +70 -53
  30. dtSpark/web/endpoints/autonomous_actions.py +12 -9
  31. dtSpark/web/endpoints/chat.py +8 -6
  32. dtSpark/web/endpoints/conversations.py +11 -9
  33. dtSpark/web/endpoints/main_menu.py +132 -105
  34. dtSpark/web/endpoints/streaming.py +2 -2
  35. dtSpark/web/server.py +65 -5
  36. dtSpark/web/ssl_utils.py +3 -3
  37. dtSpark/web/static/css/dark-theme.css +8 -29
  38. dtSpark/web/static/js/chat.js +6 -8
  39. dtSpark/web/static/js/main.js +8 -8
  40. dtSpark/web/static/js/sse-client.js +130 -122
  41. dtSpark/web/templates/actions.html +5 -5
  42. dtSpark/web/templates/base.html +13 -0
  43. dtSpark/web/templates/chat.html +10 -10
  44. dtSpark/web/templates/conversations.html +2 -2
  45. dtSpark/web/templates/goodbye.html +2 -2
  46. dtSpark/web/templates/main_menu.html +17 -17
  47. dtSpark/web/web_interface.py +2 -2
  48. {dtspark-1.1.0a3.dist-info → dtspark-1.1.0a6.dist-info}/METADATA +9 -2
  49. dtspark-1.1.0a6.dist-info/RECORD +96 -0
  50. dtspark-1.1.0a3.dist-info/RECORD +0 -96
  51. {dtspark-1.1.0a3.dist-info → dtspark-1.1.0a6.dist-info}/WHEEL +0 -0
  52. {dtspark-1.1.0a3.dist-info → dtspark-1.1.0a6.dist-info}/entry_points.txt +0 -0
  53. {dtspark-1.1.0a3.dist-info → dtspark-1.1.0a6.dist-info}/licenses/LICENSE +0 -0
  54. {dtspark-1.1.0a3.dist-info → dtspark-1.1.0a6.dist-info}/top_level.txt +0 -0
dtSpark/_version.txt CHANGED
@@ -1 +1 @@
- 1.1.0a3
+ 1.1.0a6
dtSpark/aws/authentication.py CHANGED
@@ -161,7 +161,7 @@ class AWSAuthenticator:
              # Check if this is an SSO token expiration error
              if 'Token has expired' in error_str or 'refresh failed' in error_str:
                  logging.warning("AWS SSO token has expired")
-                 logging.info(f"Attempting automatic re-authentication...")
+                 logging.info("Attempting automatic re-authentication...")
 
                  # Try to trigger SSO login automatically
                  if self.trigger_sso_login():
dtSpark/aws/bedrock.py CHANGED
@@ -57,77 +57,28 @@ class BedrockService(LLMService):
          """
          models = []
 
-         # Get inference profiles (recommended approach)
          try:
              response = self.bedrock_client.list_inference_profiles()
 
              for profile in response.get('inferenceProfileSummaries', []):
-                 # Only include ACTIVE profiles
                  if profile.get('status') != 'ACTIVE':
                      continue
 
-                 # Extract model info from the profile
                  profile_models = profile.get('models', [])
                  model_id = profile_models[0].get('modelArn', '').split('/')[-1] if profile_models else 'unknown'
-                 profile_name_lower = profile['inferenceProfileName'].lower()
-                 model_id_lower = model_id.lower()
 
-                 # Filter out embedding models (they can't be used for chat)
-                 if 'embed' in profile_name_lower or 'embed' in model_id_lower:
-                     logging.debug(f"Skipping embedding model: {profile['inferenceProfileName']}")
+                 if self._should_skip_profile(profile, model_id):
                      continue
 
-                 # Filter out image/vision-only models
-                 if 'stable-diffusion' in profile_name_lower or 'stable-diffusion' in model_id_lower:
-                     logging.debug(f"Skipping image generation model: {profile['inferenceProfileName']}")
+                 if not self._verify_model_access(profile, profile_models):
                      continue
 
-                 # Verify access to the underlying foundation model
-                 # Check if the model has been granted access
-                 try:
-                     # Try to get the foundation model details to verify access
-                     if profile_models and len(profile_models) > 0:
-                         foundation_model_arn = profile_models[0].get('modelArn', '')
-                         if foundation_model_arn:
-                             # Extract the model ID from the ARN
-                             foundation_model_id = foundation_model_arn.split('/')[-1]
-                             try:
-                                 # Attempt to get foundation model details
-                                 self.bedrock_client.get_foundation_model(modelIdentifier=foundation_model_id)
-                             except ClientError as model_error:
-                                 # If we get access denied or validation error, skip this model
-                                 error_code = model_error.response.get('Error', {}).get('Code', '')
-                                 if error_code in ['AccessDeniedException', 'ValidationException', 'ResourceNotFoundException']:
-                                     logging.debug(f"Skipping model without access: {profile['inferenceProfileName']} ({error_code})")
-                                     continue
-                                 # For other errors, log but continue (might be accessible)
-                                 logging.debug(f"Could not verify access for {profile['inferenceProfileName']}: {error_code}")
-                 except Exception as verify_error:
-                     logging.debug(f"Error verifying model access for {profile['inferenceProfileName']}: {verify_error}")
-                     # If we can't verify, skip it to be safe
-                     continue
-
-                 # Determine model maker from model ID or profile name
-                 model_maker = 'Unknown'
-
-                 if 'anthropic' in model_id_lower or 'anthropic' in profile_name_lower or 'claude' in profile_name_lower:
-                     model_maker = 'Anthropic'
-                 elif 'amazon' in model_id_lower or 'amazon' in profile_name_lower or 'titan' in profile_name_lower:
-                     model_maker = 'Amazon'
-                 elif 'meta' in model_id_lower or 'meta' in profile_name_lower or 'llama' in profile_name_lower:
-                     model_maker = 'Meta'
-                 elif 'ai21' in model_id_lower or 'ai21' in profile_name_lower or 'jamba' in profile_name_lower:
-                     model_maker = 'AI21'
-                 elif 'cohere' in model_id_lower or 'cohere' in profile_name_lower:
-                     model_maker = 'Cohere'
-                 elif 'mistral' in model_id_lower or 'mistral' in profile_name_lower:
-                     model_maker = 'Mistral'
+                 model_maker = self._detect_model_maker(model_id, profile['inferenceProfileName'])
 
                  models.append({
                      'id': profile['inferenceProfileArn'],
                      'name': profile['inferenceProfileName'],
-                     'model_maker': model_maker,  # Model creator (Anthropic, Meta, etc.)
-                     # 'provider' will be added by LLM manager to indicate service (AWS Bedrock)
+                     'model_maker': model_maker,
                      'access_info': self.get_access_info(),
                      'input_modalities': ['TEXT'],
                      'output_modalities': ['TEXT'],
@@ -142,12 +93,67 @@ class BedrockService(LLMService):
          except Exception as e:
              logging.error(f"Unexpected error listing inference profiles: {e}")
 
-         # Sort models by model maker and name for better display
          models.sort(key=lambda x: (x.get('model_maker', 'Unknown'), x['name']))
-
          logging.info(f"Total available models: {len(models)}")
          return models
 
+     @staticmethod
+     def _should_skip_profile(profile: Dict[str, Any], model_id: str) -> bool:
+         """Check whether an inference profile should be excluded from the model list."""
+         profile_name_lower = profile['inferenceProfileName'].lower()
+         model_id_lower = model_id.lower()
+
+         if 'embed' in profile_name_lower or 'embed' in model_id_lower:
+             logging.debug(f"Skipping embedding model: {profile['inferenceProfileName']}")
+             return True
+
+         if 'stable-diffusion' in profile_name_lower or 'stable-diffusion' in model_id_lower:
+             logging.debug(f"Skipping image generation model: {profile['inferenceProfileName']}")
+             return True
+
+         return False
+
+     def _verify_model_access(self, profile: Dict[str, Any], profile_models: List[Dict[str, Any]]) -> bool:
+         """Verify that the underlying foundation model is accessible. Returns True if accessible."""
+         _NO_ACCESS_CODES = {'AccessDeniedException', 'ValidationException', 'ResourceNotFoundException'}
+         try:
+             if not profile_models:
+                 return True
+             foundation_model_arn = profile_models[0].get('modelArn', '')
+             if not foundation_model_arn:
+                 return True
+             foundation_model_id = foundation_model_arn.split('/')[-1]
+             try:
+                 self.bedrock_client.get_foundation_model(modelIdentifier=foundation_model_id)
+             except ClientError as model_error:
+                 error_code = model_error.response.get('Error', {}).get('Code', '')
+                 if error_code in _NO_ACCESS_CODES:
+                     logging.debug(f"Skipping model without access: {profile['inferenceProfileName']} ({error_code})")
+                     return False
+                 logging.debug(f"Could not verify access for {profile['inferenceProfileName']}: {error_code}")
+         except Exception as verify_error:
+             logging.debug(f"Error verifying model access for {profile['inferenceProfileName']}: {verify_error}")
+             return False
+         return True
+
+     @staticmethod
+     def _detect_model_maker(model_id: str, profile_name: str) -> str:
+         """Determine the model maker from a model ID and profile name."""
+         id_lower = model_id.lower()
+         name_lower = profile_name.lower()
+         maker_keywords = [
+             ('Anthropic', ['anthropic', 'claude']),
+             ('Amazon', ['amazon', 'titan']),
+             ('Meta', ['meta', 'llama']),
+             ('AI21', ['ai21', 'jamba']),
+             ('Cohere', ['cohere']),
+             ('Mistral', ['mistral']),
+         ]
+         for maker, keywords in maker_keywords:
+             if any(kw in id_lower or kw in name_lower for kw in keywords):
+                 return maker
+         return 'Unknown'
+
      def set_model(self, model_id: str):
          """
          Set the current model for chat operations.
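
Note: the _detect_model_maker rewrite replaces the removed if/elif chain with a keyword table, and the table order preserves the old chain's top-down precedence. A standalone sketch of the same lookup for sanity-checking (the test values below are illustrative, not real inference profile names):

    MAKER_KEYWORDS = [
        ('Anthropic', ['anthropic', 'claude']),
        ('Amazon', ['amazon', 'titan']),
        ('Meta', ['meta', 'llama']),
        ('AI21', ['ai21', 'jamba']),
        ('Cohere', ['cohere']),
        ('Mistral', ['mistral']),
    ]

    def detect_model_maker(model_id: str, profile_name: str) -> str:
        # First table entry whose keywords hit either string wins,
        # mirroring the removed if/elif chain.
        id_lower, name_lower = model_id.lower(), profile_name.lower()
        for maker, keywords in MAKER_KEYWORDS:
            if any(kw in id_lower or kw in name_lower for kw in keywords):
                return maker
        return 'Unknown'

    assert detect_model_maker('meta.llama3-70b-instruct-v1:0', 'US Meta Llama 3 70B') == 'Meta'
    assert detect_model_maker('example.model-v1', 'Example Profile') == 'Unknown'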
@@ -173,6 +179,17 @@ class BedrockService(LLMService):
 
          logging.info(f"{'Inference profile' if self.is_inference_profile else 'Model'} set to: {model_id}")
 
+     # Transient error codes that should be retried
+     _TRANSIENT_ERRORS = {
+         'ThrottlingException',
+         'TooManyRequestsException',
+         'ModelTimeoutException',
+         'ServiceUnavailableException',
+         'InternalServerError',
+         'ModelNotReadyException',
+         'ModelStreamErrorException',
+     }
+
      def invoke_model(self, messages: List[Dict[str, str]], max_tokens: int = 4096,
                       temperature: float = 0.7, tools: Optional[List[Dict[str, Any]]] = None,
                       system: Optional[str] = None, max_retries: int = 3) -> Optional[Dict[str, Any]]:
@@ -200,17 +217,6 @@ class BedrockService(LLMService):
                  'error_type': 'ConfigurationError'
              }
 
-         # Transient error codes that should be retried
-         transient_errors = [
-             'ThrottlingException',
-             'TooManyRequestsException',
-             'ModelTimeoutException',
-             'ServiceUnavailableException',
-             'InternalServerError',
-             'ModelNotReadyException',
-             'ModelStreamErrorException'
-         ]
-
          import time
          attempt = 0
 
@@ -218,77 +224,14 @@ class BedrockService(LLMService):
              if attempt > 1:
                  logging.info(f"Retry attempt {attempt}/{max_retries} for model invocation")
 
-             try:
-                 # Format the request based on the model provider
-                 request_body = self._format_request(messages, max_tokens, temperature, tools, system)
-
-                 # Invoke the model or inference profile
-                 # Note: modelId parameter accepts both model IDs and inference profile ARNs
-                 logging.debug(f"Invoking {'inference profile' if self.is_inference_profile else 'model'}: {self.current_model_id}")
-
-                 # Log the request for debugging
-                 logging.debug(f"Request body keys: {list(request_body.keys())}")
-                 if 'tools' in request_body:
-                     logging.debug(f"Tools count: {len(request_body['tools'])}")
-                 logging.debug(f"max_tokens is set to {max_tokens}")
-                 try:
-                     response = self.bedrock_runtime_client.invoke_model(
-                         modelId=self.current_model_id,
-                         contentType='application/json',
-                         accept='application/json',
-                         body=json.dumps(request_body)
-                     )
-                 except Exception as api_error:
-                     logging.error(f"Bedrock API error: {api_error}")
-                     logging.error(f"Request body: {json.dumps(request_body, indent=2)}")
-                     raise
-
-                 # Parse the response
-                 response_body = json.loads(response['body'].read())
-                 parsed_response = self._parse_response(response_body)
-
-                 logging.debug(f"{'Inference profile' if self.is_inference_profile else 'Model'} invoked successfully: {self.current_model_id}")
-                 return parsed_response
-
-             except ClientError as e:
-                 error_code = e.response['Error']['Code']
-                 error_message = e.response['Error']['Message']
-
-                 # Log detailed error information
-                 logging.error(f"Bedrock API error - Code: {error_code}, Message: {error_message}")
-
-                 # Check if this is a transient error that should be retried
-                 if error_code in transient_errors and attempt <= max_retries:
-                     wait_time = min(2 ** (attempt - 1), 30)  # Exponential backoff, max 30 seconds
-                     logging.warning(f"Transient error {error_code}, retrying in {wait_time} seconds... (attempt {attempt}/{max_retries})")
-                     time.sleep(wait_time)
-                     continue  # Retry
-
-                 # Non-transient error or max retries reached - return error details
-                 return {
-                     'error': True,
-                     'error_code': error_code,
-                     'error_message': error_message,
-                     'error_type': 'ClientError',
-                     'retries_attempted': attempt - 1
-                 }
-
-             except Exception as e:
-                 logging.error(f"Unexpected error invoking {'inference profile' if self.is_inference_profile else 'model'}: {e}")
-                 logging.error(f"Error type: {type(e).__name__}")
-                 import traceback
-                 logging.error(f"Traceback: {traceback.format_exc()}")
-
-                 # Return error details (unexpected errors are not retried)
-                 return {
-                     'error': True,
-                     'error_code': type(e).__name__,
-                     'error_message': str(e),
-                     'error_type': 'Exception',
-                     'retries_attempted': 0
-                 }
-
-             # Should not reach here, but just in case
+             result = self._attempt_invocation(messages, max_tokens, temperature, tools, system, attempt, max_retries)
+             if result.get('_retry'):
+                 wait_time = min(2 ** (attempt - 1), 30)
+                 time.sleep(wait_time)
+                 attempt += 1
+                 continue
+             return result
+
          return {
              'error': True,
              'error_code': 'MaxRetriesExceeded',
@@ -297,6 +240,67 @@ class BedrockService(LLMService):
              'retries_attempted': max_retries
          }
 
+     def _attempt_invocation(self, messages, max_tokens, temperature, tools, system,
+                             attempt, max_retries) -> Dict[str, Any]:
+         """Execute a single model invocation attempt. Returns a _retry sentinel on transient failure."""
+         model_label = 'inference profile' if self.is_inference_profile else 'model'
+         try:
+             request_body = self._format_request(messages, max_tokens, temperature, tools, system)
+             logging.debug(f"Invoking {model_label}: {self.current_model_id}")
+             logging.debug(f"Request body keys: {list(request_body.keys())}")
+             if 'tools' in request_body:
+                 logging.debug(f"Tools count: {len(request_body['tools'])}")
+             logging.debug(f"max_tokens is set to {max_tokens}")
+
+             try:
+                 response = self.bedrock_runtime_client.invoke_model(
+                     modelId=self.current_model_id,
+                     contentType='application/json',
+                     accept='application/json',
+                     body=json.dumps(request_body)
+                 )
+             except Exception as api_error:
+                 logging.error(f"Bedrock API error: {api_error}")
+                 logging.error(f"Request body: {json.dumps(request_body, indent=2)}")
+                 raise
+
+             response_body = json.loads(response['body'].read())
+             parsed_response = self._parse_response(response_body)
+             logging.debug(f"{model_label} invoked successfully: {self.current_model_id}")
+             return parsed_response
+
+         except ClientError as e:
+             error_code = e.response['Error']['Code']
+             error_message = e.response['Error']['Message']
+             logging.error(f"Bedrock API error - Code: {error_code}, Message: {error_message}")
+
+             if error_code in self._TRANSIENT_ERRORS and attempt <= max_retries:
+                 wait_time = min(2 ** (attempt - 1), 30)
+                 logging.warning(f"Transient error {error_code}, retrying in {wait_time} seconds... "
+                                 f"(attempt {attempt}/{max_retries})")
+                 return {'_retry': True}
+
+             return {
+                 'error': True,
+                 'error_code': error_code,
+                 'error_message': error_message,
+                 'error_type': 'ClientError',
+                 'retries_attempted': attempt - 1
+             }
+
+         except Exception as e:
+             logging.error(f"Unexpected error invoking {model_label}: {e}")
+             logging.error(f"Error type: {type(e).__name__}")
+             import traceback
+             logging.error(f"Traceback: {traceback.format_exc()}")
+             return {
+                 'error': True,
+                 'error_code': type(e).__name__,
+                 'error_message': str(e),
+                 'error_type': 'Exception',
+                 'retries_attempted': 0
+             }
+
      def _format_request(self, messages: List[Dict[str, str]], max_tokens: int,
                          temperature: float, tools: Optional[List[Dict[str, Any]]] = None,
                          system: Optional[str] = None) -> Dict[str, Any]:
402
406
  Returns:
403
407
  Standardised response dictionary
404
408
  """
405
- # Use model_identifier for provider detection (works for both direct models and profiles)
406
409
  model_id = self.model_identifier or self.current_model_id
410
+ model_lower = model_id.lower()
411
+
412
+ if 'anthropic.claude' in model_id or 'anthropic' in model_lower:
413
+ return self._parse_anthropic_response(response_body)
414
+ if 'amazon.titan' in model_id or 'titan' in model_lower:
415
+ return self._parse_titan_response(response_body)
416
+ if 'meta.llama' in model_id or 'llama' in model_lower:
417
+ return self._parse_llama_response(response_body)
418
+ if 'ai21' in model_id:
419
+ return self._parse_ai21_response(response_body)
420
+ if 'cohere' in model_id:
421
+ return self._parse_cohere_response(response_body)
422
+
423
+ return {'content': str(response_body), 'stop_reason': None, 'usage': {}}
424
+
425
+ @staticmethod
426
+ def _parse_anthropic_response(response_body: Dict[str, Any]) -> Dict[str, Any]:
427
+ """Parse an Anthropic Claude response."""
428
+ content_blocks = response_body.get('content', [])
429
+ parsed_content = []
430
+ text_parts = []
431
+
432
+ for block in content_blocks:
433
+ if block.get('type') == 'text':
434
+ text_parts.append(block.get('text', ''))
435
+ parsed_content.append({'type': 'text', 'text': block.get('text', '')})
436
+ elif block.get('type') == 'tool_use':
437
+ parsed_content.append({
438
+ 'type': 'tool_use',
439
+ 'id': block.get('id'),
440
+ 'name': block.get('name'),
441
+ 'input': block.get('input', {})
442
+ })
407
443
 
408
- # Anthropic Claude models
409
- if 'anthropic.claude' in model_id or 'anthropic' in model_id.lower():
410
- content_blocks = response_body.get('content', [])
411
-
412
- # Parse content blocks (can be text or tool_use)
413
- parsed_content = []
414
- text_parts = []
415
-
416
- for block in content_blocks:
417
- if block.get('type') == 'text':
418
- text_parts.append(block.get('text', ''))
419
- parsed_content.append({
420
- 'type': 'text',
421
- 'text': block.get('text', '')
422
- })
423
- elif block.get('type') == 'tool_use':
424
- parsed_content.append({
425
- 'type': 'tool_use',
426
- 'id': block.get('id'),
427
- 'name': block.get('name'),
428
- 'input': block.get('input', {})
429
- })
430
-
431
- # Combine text for backwards compatibility
432
- text = '\n'.join(text_parts) if text_parts else ''
433
-
434
- return {
435
- 'content': text,
436
- 'content_blocks': parsed_content,
437
- 'stop_reason': response_body.get('stop_reason'),
438
- 'usage': response_body.get('usage', {})
439
- }
440
-
441
- # Amazon Titan models
442
- elif 'amazon.titan' in model_id or 'titan' in model_id.lower():
443
- results = response_body.get('results', [])
444
- text = results[0].get('outputText', '') if results else ''
445
-
446
- return {
447
- 'content': text,
448
- 'stop_reason': results[0].get('completionReason') if results else None,
449
- 'usage': {
450
- 'input_tokens': response_body.get('inputTextTokenCount', 0),
451
- 'output_tokens': response_body.get('results', [{}])[0].get('tokenCount', 0)
452
- }
453
- }
444
+ return {
445
+ 'content': '\n'.join(text_parts) if text_parts else '',
446
+ 'content_blocks': parsed_content,
447
+ 'stop_reason': response_body.get('stop_reason'),
448
+ 'usage': response_body.get('usage', {})
449
+ }
454
450
 
455
- # Meta Llama models
456
- elif 'meta.llama' in model_id or 'llama' in model_id.lower():
457
- return {
458
- 'content': response_body.get('generation', ''),
459
- 'stop_reason': response_body.get('stop_reason'),
460
- 'usage': {
461
- 'input_tokens': response_body.get('prompt_token_count', 0),
462
- 'output_tokens': response_body.get('generation_token_count', 0)
463
- }
451
+ @staticmethod
452
+ def _parse_titan_response(response_body: Dict[str, Any]) -> Dict[str, Any]:
453
+ """Parse an Amazon Titan response."""
454
+ results = response_body.get('results', [])
455
+ return {
456
+ 'content': results[0].get('outputText', '') if results else '',
457
+ 'stop_reason': results[0].get('completionReason') if results else None,
458
+ 'usage': {
459
+ 'input_tokens': response_body.get('inputTextTokenCount', 0),
460
+ 'output_tokens': response_body.get('results', [{}])[0].get('tokenCount', 0)
464
461
  }
462
+ }
465
463
 
466
- # AI21 models
467
- elif 'ai21' in model_id:
468
- completions = response_body.get('completions', [])
469
- text = completions[0].get('data', {}).get('text', '') if completions else ''
470
-
471
- return {
472
- 'content': text,
473
- 'stop_reason': completions[0].get('finishReason', {}).get('reason') if completions else None,
474
- 'usage': {}
464
+ @staticmethod
465
+ def _parse_llama_response(response_body: Dict[str, Any]) -> Dict[str, Any]:
466
+ """Parse a Meta Llama response."""
467
+ return {
468
+ 'content': response_body.get('generation', ''),
469
+ 'stop_reason': response_body.get('stop_reason'),
470
+ 'usage': {
471
+ 'input_tokens': response_body.get('prompt_token_count', 0),
472
+ 'output_tokens': response_body.get('generation_token_count', 0)
475
473
  }
474
+ }
476
475
 
477
- # Cohere models
478
- elif 'cohere' in model_id:
479
- generations = response_body.get('generations', [])
480
- text = generations[0].get('text', '') if generations else ''
481
-
482
- return {
483
- 'content': text,
484
- 'stop_reason': generations[0].get('finish_reason') if generations else None,
485
- 'usage': {}
486
- }
476
+ @staticmethod
477
+ def _parse_ai21_response(response_body: Dict[str, Any]) -> Dict[str, Any]:
478
+ """Parse an AI21 response."""
479
+ completions = response_body.get('completions', [])
480
+ return {
481
+ 'content': completions[0].get('data', {}).get('text', '') if completions else '',
482
+ 'stop_reason': completions[0].get('finishReason', {}).get('reason') if completions else None,
483
+ 'usage': {}
484
+ }
487
485
 
488
- # Default fallback
489
- else:
490
- return {
491
- 'content': str(response_body),
492
- 'stop_reason': None,
493
- 'usage': {}
494
- }
486
+ @staticmethod
487
+ def _parse_cohere_response(response_body: Dict[str, Any]) -> Dict[str, Any]:
488
+ """Parse a Cohere response."""
489
+ generations = response_body.get('generations', [])
490
+ return {
491
+ 'content': generations[0].get('text', '') if generations else '',
492
+ 'stop_reason': generations[0].get('finish_reason') if generations else None,
493
+ 'usage': {}
494
+ }
495
495
 
496
496
  def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
497
497
  """
@@ -544,30 +544,29 @@ class BedrockService(LLMService):
544
544
  Estimated total token count for all messages
545
545
  """
546
546
  total_tokens = 0
547
-
548
547
  for message in messages:
549
- # Count tokens for role (small overhead)
550
548
  total_tokens += 4 # Approximate overhead for role formatting
551
-
552
- # Count tokens in content
553
- content = message.get('content', '')
554
-
555
- # Handle content that might be a list (for multi-part content)
556
- if isinstance(content, list):
557
- for part in content:
558
- if isinstance(part, dict):
559
- if 'text' in part:
560
- total_tokens += self.count_tokens(part['text'], model_id)
561
- # Add overhead for other content types (images, documents, etc.)
562
- elif 'image' in part or 'document' in part:
563
- total_tokens += 1000 # Rough estimate for non-text content
564
- elif isinstance(part, str):
565
- total_tokens += self.count_tokens(part, model_id)
566
- elif isinstance(content, str):
567
- total_tokens += self.count_tokens(content, model_id)
568
-
549
+ total_tokens += self._count_content_tokens(message.get('content', ''), model_id)
569
550
  return total_tokens
570
551
 
552
+ def _count_content_tokens(self, content, model_id: Optional[str] = None) -> int:
553
+ """Count tokens for a single message content value (string or list of parts)."""
554
+ if isinstance(content, str):
555
+ return self.count_tokens(content, model_id)
556
+ if not isinstance(content, list):
557
+ return 0
558
+
559
+ total = 0
560
+ for part in content:
561
+ if isinstance(part, str):
562
+ total += self.count_tokens(part, model_id)
563
+ elif isinstance(part, dict):
564
+ if 'text' in part:
565
+ total += self.count_tokens(part['text'], model_id)
566
+ elif 'image' in part or 'document' in part:
567
+ total += 1000 # Rough estimate for non-text content
568
+ return total
569
+
571
570
  def get_current_model_id(self) -> Optional[str]:
572
571
  """
573
572
  Get the currently selected model ID.
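
Note: _count_content_tokens flattens the old nested isinstance logic into one helper that accepts either a plain string or the multi-part list form. A standalone mirror showing the shapes it handles (the estimator and values are illustrative; the package's own count_tokens may differ):

    def count_content_tokens(content, count_tokens) -> int:
        # count_tokens: any str -> int estimator.
        if isinstance(content, str):
            return count_tokens(content)
        if not isinstance(content, list):
            return 0
        total = 0
        for part in content:
            if isinstance(part, str):
                total += count_tokens(part)
            elif isinstance(part, dict):
                if 'text' in part:
                    total += count_tokens(part['text'])
                elif 'image' in part or 'document' in part:
                    total += 1000  # flat estimate for non-text parts
        return total

    estimate = lambda s: max(1, len(s) // 4)  # rough 4-chars-per-token heuristic
    print(count_content_tokens('hello world', estimate))                           # 2
    print(count_content_tokens([{'text': 'hello'}, {'image': '<bytes>'}], estimate))  # 1001

The list branch now checks isinstance(part, str) before the dict cases, the reverse of the old ordering, but the outcome is identical since a part is never both.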
dtSpark/aws/costs.py CHANGED
@@ -12,6 +12,10 @@ from datetime import datetime, timedelta
  from typing import Dict, Optional, List
  from botocore.exceptions import ClientError
 
+ # Period name constants
+ _PERIOD_CURRENT_MONTH = 'Current Month'
+ _PERIOD_LAST_MONTH = 'Last Month'
+
 
  class CostTracker:
      """Tracks AWS Bedrock costs using Cost Explorer API."""
@@ -95,14 +99,14 @@ class CostTracker:
          current_month_costs = self._get_costs_for_period(
              first_day_this_month,
              today + timedelta(days=1),  # End date is exclusive, so add 1 day to include today
-             'Current Month'
+             _PERIOD_CURRENT_MONTH
          )
 
          # Get last month's costs
          last_month_costs = self._get_costs_for_period(
              first_day_last_month,
              first_day_this_month,  # End date is exclusive
-             'Last Month'
+             _PERIOD_LAST_MONTH
          )
 
          # Get last 24 hours costs
@@ -144,7 +148,7 @@ class CostTracker:
                  'Start': start_date.strftime('%Y-%m-%d'),
                  'End': end_date.strftime('%Y-%m-%d')
              },
-             Granularity='MONTHLY' if period_name in ['Last Month', 'Current Month'] else 'DAILY',
+             Granularity='MONTHLY' if period_name in [_PERIOD_LAST_MONTH, _PERIOD_CURRENT_MONTH] else 'DAILY',
              Metrics=['UnblendedCost'],
              GroupBy=[{
                  'Type': 'DIMENSION',
@@ -269,7 +273,7 @@ class CostTracker:
          if current_month:
              total = current_month.get('total', 0.0)
              breakdown = current_month.get('breakdown', {})
-             period = current_month.get('period', 'Current Month')
+             period = current_month.get('period', _PERIOD_CURRENT_MONTH)
 
              lines.append(f"{period}: ${total:.2f} {currency}")
 
@@ -285,7 +289,7 @@ class CostTracker:
          if last_month:
              total = last_month.get('total', 0.0)
              breakdown = last_month.get('breakdown', {})
-             period = last_month.get('period', 'Last Month')
+             period = last_month.get('period', _PERIOD_LAST_MONTH)
 
              lines.append(f"{period}: ${total:.2f} {currency}")
 
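
Note: hoisting 'Current Month' and 'Last Month' into module constants matters because _get_costs_for_period branches on the period name to pick a Cost Explorer granularity, so a typo in a caller would silently fall back to DAILY. A minimal sketch of that call shape (the GroupBy key is an assumption; the diff only shows 'Type': 'DIMENSION'):

    import boto3

    _PERIOD_CURRENT_MONTH = 'Current Month'
    _PERIOD_LAST_MONTH = 'Last Month'

    def get_costs_for_period(start_date, end_date, period_name):
        ce = boto3.client('ce')  # Cost Explorer client
        return ce.get_cost_and_usage(
            TimePeriod={
                'Start': start_date.strftime('%Y-%m-%d'),
                'End': end_date.strftime('%Y-%m-%d'),  # end is exclusive, as the comments above note
            },
            # Monthly buckets for the named month periods, daily otherwise (e.g. last 24 hours)
            Granularity='MONTHLY' if period_name in [_PERIOD_LAST_MONTH, _PERIOD_CURRENT_MONTH] else 'DAILY',
            Metrics=['UnblendedCost'],
            GroupBy=[{'Type': 'DIMENSION', 'Key': 'SERVICE'}],  # 'SERVICE' key is assumed, not from the diff
        )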