empathy-framework 5.0.1__py3-none-any.whl → 5.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {empathy_framework-5.0.1.dist-info → empathy_framework-5.1.0.dist-info}/METADATA +311 -150
- {empathy_framework-5.0.1.dist-info → empathy_framework-5.1.0.dist-info}/RECORD +60 -33
- empathy_framework-5.1.0.dist-info/licenses/LICENSE +201 -0
- empathy_framework-5.1.0.dist-info/licenses/LICENSE_CHANGE_ANNOUNCEMENT.md +101 -0
- empathy_llm_toolkit/providers.py +175 -35
- empathy_llm_toolkit/utils/tokens.py +150 -30
- empathy_os/__init__.py +1 -1
- empathy_os/cli/commands/batch.py +256 -0
- empathy_os/cli/commands/cache.py +248 -0
- empathy_os/cli/commands/inspect.py +1 -2
- empathy_os/cli/commands/metrics.py +1 -1
- empathy_os/cli/commands/routing.py +285 -0
- empathy_os/cli/commands/workflow.py +2 -1
- empathy_os/cli/parsers/__init__.py +6 -0
- empathy_os/cli/parsers/batch.py +118 -0
- empathy_os/cli/parsers/cache 2.py +65 -0
- empathy_os/cli/parsers/cache.py +65 -0
- empathy_os/cli/parsers/routing.py +110 -0
- empathy_os/cli_minimal.py +3 -3
- empathy_os/cli_router 2.py +416 -0
- empathy_os/dashboard/__init__.py +1 -2
- empathy_os/dashboard/app 2.py +512 -0
- empathy_os/dashboard/app.py +1 -1
- empathy_os/dashboard/simple_server 2.py +403 -0
- empathy_os/dashboard/standalone_server 2.py +536 -0
- empathy_os/dashboard/standalone_server.py +22 -11
- empathy_os/memory/types 2.py +441 -0
- empathy_os/metrics/collector.py +31 -0
- empathy_os/models/__init__.py +19 -0
- empathy_os/models/adaptive_routing 2.py +437 -0
- empathy_os/models/auth_cli.py +444 -0
- empathy_os/models/auth_strategy.py +450 -0
- empathy_os/models/token_estimator.py +21 -13
- empathy_os/project_index/scanner_parallel 2.py +291 -0
- empathy_os/telemetry/agent_coordination 2.py +478 -0
- empathy_os/telemetry/agent_coordination.py +14 -16
- empathy_os/telemetry/agent_tracking 2.py +350 -0
- empathy_os/telemetry/agent_tracking.py +18 -20
- empathy_os/telemetry/approval_gates 2.py +563 -0
- empathy_os/telemetry/approval_gates.py +27 -39
- empathy_os/telemetry/event_streaming 2.py +405 -0
- empathy_os/telemetry/event_streaming.py +22 -22
- empathy_os/telemetry/feedback_loop 2.py +557 -0
- empathy_os/telemetry/feedback_loop.py +14 -17
- empathy_os/workflows/__init__.py +8 -0
- empathy_os/workflows/autonomous_test_gen.py +569 -0
- empathy_os/workflows/batch_processing.py +56 -10
- empathy_os/workflows/bug_predict.py +45 -0
- empathy_os/workflows/code_review.py +92 -22
- empathy_os/workflows/document_gen.py +594 -62
- empathy_os/workflows/llm_base.py +363 -0
- empathy_os/workflows/perf_audit.py +69 -0
- empathy_os/workflows/release_prep.py +54 -0
- empathy_os/workflows/security_audit.py +154 -79
- empathy_os/workflows/test_gen.py +60 -0
- empathy_os/workflows/test_gen_behavioral.py +477 -0
- empathy_os/workflows/test_gen_parallel.py +341 -0
- empathy_framework-5.0.1.dist-info/licenses/LICENSE +0 -139
- {empathy_framework-5.0.1.dist-info → empathy_framework-5.1.0.dist-info}/WHEEL +0 -0
- {empathy_framework-5.0.1.dist-info → empathy_framework-5.1.0.dist-info}/entry_points.txt +0 -0
- {empathy_framework-5.0.1.dist-info → empathy_framework-5.1.0.dist-info}/top_level.txt +0 -0
empathy_llm_toolkit/providers.py CHANGED

```diff
@@ -322,6 +322,93 @@ class AnthropicProvider(BaseLLMProvider):
             },
         )
 
+    def estimate_tokens(self, text: str) -> int:
+        """Estimate token count using accurate token counter (overrides base class).
+
+        Uses tiktoken for fast local estimation (~98% accurate).
+        Falls back to heuristic if tiktoken unavailable.
+
+        Args:
+            text: Text to count tokens for
+
+        Returns:
+            Estimated token count
+        """
+        try:
+            from .utils.tokens import count_tokens
+
+            return count_tokens(text, model=self.model, use_api=False)
+        except ImportError:
+            # Fallback to base class heuristic if utils not available
+            return super().estimate_tokens(text)
+
+    def calculate_actual_cost(
+        self,
+        input_tokens: int,
+        output_tokens: int,
+        cache_creation_tokens: int = 0,
+        cache_read_tokens: int = 0,
+    ) -> dict[str, Any]:
+        """Calculate actual cost based on precise token counts.
+
+        Includes Anthropic prompt caching cost adjustments:
+        - Cache writes: 25% markup over standard input pricing
+        - Cache reads: 90% discount from standard input pricing
+
+        Args:
+            input_tokens: Regular input tokens (not cached)
+            output_tokens: Output tokens
+            cache_creation_tokens: Tokens written to cache
+            cache_read_tokens: Tokens read from cache
+
+        Returns:
+            Dictionary with cost breakdown:
+            - base_cost: Cost for regular input/output tokens
+            - cache_write_cost: Cost for cache creation (if any)
+            - cache_read_cost: Cost for cache reads (if any)
+            - total_cost: Total cost including all components
+            - savings: Amount saved by cache reads vs. full price
+
+        Example:
+            >>> provider = AnthropicProvider(api_key="...")
+            >>> cost = provider.calculate_actual_cost(
+            ...     input_tokens=1000,
+            ...     output_tokens=500,
+            ...     cache_read_tokens=10000
+            ... )
+            >>> cost["total_cost"]
+            0.0105  # Significantly less than without cache
+        """
+        # Get pricing for this model
+        model_info = self.get_model_info()
+        input_price_per_million = model_info["cost_per_1m_input"]
+        output_price_per_million = model_info["cost_per_1m_output"]
+
+        # Base cost (non-cached tokens)
+        base_cost = (input_tokens / 1_000_000) * input_price_per_million
+        base_cost += (output_tokens / 1_000_000) * output_price_per_million
+
+        # Cache write cost (25% markup)
+        cache_write_price = input_price_per_million * 1.25
+        cache_write_cost = (cache_creation_tokens / 1_000_000) * cache_write_price
+
+        # Cache read cost (90% discount = 10% of input price)
+        cache_read_price = input_price_per_million * 0.1
+        cache_read_cost = (cache_read_tokens / 1_000_000) * cache_read_price
+
+        # Calculate savings from cache reads
+        full_price_for_cached = (cache_read_tokens / 1_000_000) * input_price_per_million
+        savings = full_price_for_cached - cache_read_cost
+
+        return {
+            "base_cost": round(base_cost, 6),
+            "cache_write_cost": round(cache_write_cost, 6),
+            "cache_read_cost": round(cache_read_cost, 6),
+            "total_cost": round(base_cost + cache_write_cost + cache_read_cost, 6),
+            "savings": round(savings, 6),
+            "currency": "USD",
+        }
+
 
 class AnthropicBatchProvider:
     """Provider for Anthropic Batch API (50% cost reduction).
```
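
The cache pricing adjustments in `calculate_actual_cost` are plain linear scalings of the model's input price. A minimal worked example of that arithmetic, assuming an input price of $3.00 per million tokens purely for illustration (the real figure comes from `get_model_info()["cost_per_1m_input"]`):

```python
# Assumed illustrative pricing - not taken from the package.
input_price_per_million = 3.00

cache_creation_tokens = 2_000
cache_read_tokens = 10_000

# Cache writes: 25% markup over standard input pricing
cache_write_cost = (cache_creation_tokens / 1_000_000) * input_price_per_million * 1.25
# -> 0.0075

# Cache reads: billed at 10% of standard input pricing (90% discount)
cache_read_cost = (cache_read_tokens / 1_000_000) * input_price_per_million * 0.1
# -> 0.003

# Savings relative to paying full input price for the cached tokens
full_price_for_cached = (cache_read_tokens / 1_000_000) * input_price_per_million
savings = full_price_for_cached - cache_read_cost
# -> 0.030 - 0.003 = 0.027

print(cache_write_cost, cache_read_cost, savings)
```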
```diff
@@ -370,7 +457,8 @@ class AnthropicBatchProvider:
         """Create a batch job.
 
         Args:
-            requests: List of request dicts with 'custom_id'
+            requests: List of request dicts with 'custom_id' and 'params' containing message creation parameters.
+                Format: [{"custom_id": "id1", "params": {"model": "...", "messages": [...], "max_tokens": 1024}}]
             job_id: Optional job identifier for tracking (unused, for API compatibility)
 
         Returns:
@@ -384,22 +472,46 @@ class AnthropicBatchProvider:
             >>> requests = [
             ...     {
             ...         "custom_id": "task_1",
-            ...         "
-            ...
-            ...
+            ...         "params": {
+            ...             "model": "claude-sonnet-4-5-20250929",
+            ...             "messages": [{"role": "user", "content": "Test"}],
+            ...             "max_tokens": 1024
+            ...         }
             ...     }
             ... ]
             >>> batch_id = provider.create_batch(requests)
             >>> print(f"Batch created: {batch_id}")
-            Batch created:
+            Batch created: msgbatch_abc123
         """
         if not requests:
             raise ValueError("requests cannot be empty")
 
+        # Validate and convert old format to new format if needed
+        formatted_requests = []
+        for req in requests:
+            if "params" not in req:
+                # Old format: convert to new format with params wrapper
+                formatted_req = {
+                    "custom_id": req.get("custom_id", f"req_{id(req)}"),
+                    "params": {
+                        "model": req.get("model", "claude-sonnet-4-5-20250929"),
+                        "messages": req.get("messages", []),
+                        "max_tokens": req.get("max_tokens", 4096),
+                    },
+                }
+                # Copy other optional params
+                for key in ["temperature", "system", "stop_sequences"]:
+                    if key in req:
+                        formatted_req["params"][key] = req[key]
+                formatted_requests.append(formatted_req)
+            else:
+                formatted_requests.append(req)
+
         try:
-
+            # Use correct Message Batches API endpoint
+            batch = self.client.messages.batches.create(requests=formatted_requests)
             self._batch_jobs[batch.id] = batch
-            logger.info(f"Created batch {batch.id} with {len(
+            logger.info(f"Created batch {batch.id} with {len(formatted_requests)} requests")
             return batch.id
         except Exception as e:
             logger.error(f"Failed to create batch: {e}")
```
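
`create_batch` now accepts both request shapes: dicts that already carry a `params` wrapper pass through unchanged, while flat pre-5.1.0 dicts are wrapped on the fly. For illustration, the flat request below is what the conversion loop above would turn into the wrapped form; the default `max_tokens=4096` and the copied optional keys come from the code, the concrete values are made up:

```python
# Flat, pre-5.1.0 shape - still accepted
old_style = {
    "custom_id": "task_1",
    "model": "claude-sonnet-4-5-20250929",
    "messages": [{"role": "user", "content": "Test"}],
    "temperature": 0.2,
}

# What create_batch builds from it before calling the Message Batches API:
# "max_tokens" falls back to 4096, and "temperature"/"system"/"stop_sequences"
# are copied into params when present.
normalized = {
    "custom_id": "task_1",
    "params": {
        "model": "claude-sonnet-4-5-20250929",
        "messages": [{"role": "user", "content": "Test"}],
        "max_tokens": 4096,
        "temperature": 0.2,
    },
}
```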
```diff
@@ -412,18 +524,20 @@ class AnthropicBatchProvider:
             batch_id: Batch job ID
 
         Returns:
-
-            - "
-            - "
-            - "
+            MessageBatch object with processing_status field:
+            - "in_progress": Batch is being processed
+            - "canceling": Cancellation initiated
+            - "ended": Batch processing ended (check request_counts for success/errors)
 
         Example:
-            >>> status = provider.get_batch_status("
-            >>> print(status.
-
+            >>> status = provider.get_batch_status("msgbatch_abc123")
+            >>> print(status.processing_status)
+            in_progress
+            >>> print(f"Succeeded: {status.request_counts.succeeded}")
         """
         try:
-
+            # Use correct Message Batches API endpoint
+            batch = self.client.messages.batches.retrieve(batch_id)
             self._batch_jobs[batch_id] = batch
             return batch
         except Exception as e:
@@ -437,25 +551,37 @@ class AnthropicBatchProvider:
             batch_id: Batch job ID
 
         Returns:
-            List of result dicts
+            List of result dicts. Each dict contains:
+            - custom_id: Request identifier
+            - result: Either {"type": "succeeded", "message": {...}} or {"type": "errored", "error": {...}}
 
         Raises:
-            ValueError: If batch
+            ValueError: If batch has not ended processing
             RuntimeError: If API call fails
 
         Example:
-            >>> results = provider.get_batch_results("
+            >>> results = provider.get_batch_results("msgbatch_abc123")
             >>> for result in results:
-            ...
+            ...     if result['result']['type'] == 'succeeded':
+            ...         message = result['result']['message']
+            ...         print(f"{result['custom_id']}: {message.content[0].text}")
+            ...     else:
+            ...         error = result['result']['error']
+            ...         print(f"{result['custom_id']}: Error {error['type']}")
         """
         status = self.get_batch_status(batch_id)
 
-
-
+        # Check processing_status instead of status
+        if status.processing_status != "ended":
+            raise ValueError(
+                f"Batch {batch_id} has not ended processing (status: {status.processing_status})"
+            )
 
         try:
-
-
+            # Use correct Message Batches API endpoint
+            # results() returns an iterator, convert to list
+            results_iterator = self.client.messages.batches.results(batch_id)
+            return list(results_iterator)
         except Exception as e:
             logger.error(f"Failed to get batch results for {batch_id}: {e}")
             raise RuntimeError(f"Failed to get batch results: {e}") from e
@@ -474,15 +600,15 @@ class AnthropicBatchProvider:
             timeout: Maximum wait time in seconds (default: 86400 = 24 hours)
 
         Returns:
-            Batch results when
+            Batch results when processing ends
 
         Raises:
             TimeoutError: If batch doesn't complete within timeout
-            RuntimeError: If batch processing
+            RuntimeError: If batch had errors during processing
 
         Example:
             >>> results = await provider.wait_for_batch(
-            ...     "
+            ...     "msgbatch_abc123",
             ...     poll_interval=300,  # Check every 5 minutes
             ... )
             >>> print(f"Batch completed: {len(results)} results")
@@ -493,22 +619,36 @@ class AnthropicBatchProvider:
         while True:
             status = self.get_batch_status(batch_id)
 
-            if
-
-
+            # Check if batch processing has ended
+            if status.processing_status == "ended":
+                # Check request counts to see if there were errors
+                counts = status.request_counts
+                logger.info(
+                    f"Batch {batch_id} ended: "
+                    f"{counts.succeeded} succeeded, {counts.errored} errored, "
+                    f"{counts.canceled} canceled, {counts.expired} expired"
+                )
 
-
-
-
-                raise RuntimeError(f"Batch {batch_id} failed: {error_msg}")
+                # Return results even if some requests failed
+                # The caller can inspect individual results for errors
+                return self.get_batch_results(batch_id)
 
             # Check timeout
             elapsed = (datetime.now() - start_time).total_seconds()
             if elapsed > timeout:
                 raise TimeoutError(f"Batch {batch_id} did not complete within {timeout}s")
 
-            # Log progress
-
+            # Log progress with request counts
+            try:
+                counts = status.request_counts
+                logger.debug(
+                    f"Batch {batch_id} status: {status.processing_status} "
+                    f"(processing: {counts.processing}, elapsed: {elapsed:.0f}s)"
+                )
+            except AttributeError:
+                logger.debug(
+                    f"Batch {batch_id} status: {status.processing_status} (elapsed: {elapsed:.0f}s)"
+                )
 
             # Wait before next poll
             await asyncio.sleep(poll_interval)
```
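
Taken together, the batch changes move everything onto the Message Batches API (`client.messages.batches.*`) and key completion on `processing_status == "ended"`. A minimal end-to-end sketch; the import path and the no-argument constructor are assumptions for illustration, while `create_batch`, `wait_for_batch`, and the result shape come from the diff above:

```python
import asyncio

from empathy_llm_toolkit.providers import AnthropicBatchProvider  # import path assumed


async def main() -> None:
    provider = AnthropicBatchProvider()  # constructor arguments assumed

    requests = [
        {
            "custom_id": "task_1",
            "params": {
                "model": "claude-sonnet-4-5-20250929",
                "messages": [{"role": "user", "content": "Test"}],
                "max_tokens": 1024,
            },
        }
    ]

    batch_id = provider.create_batch(requests)

    # Polls get_batch_status until processing_status == "ended", then fetches results.
    results = await provider.wait_for_batch(batch_id, poll_interval=300)

    for result in results:
        if result["result"]["type"] == "succeeded":
            print(result["custom_id"], "ok")
        else:
            print(result["custom_id"], "error:", result["result"]["error"]["type"])


asyncio.run(main())
```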
empathy_llm_toolkit/utils/tokens.py CHANGED

```diff
@@ -7,10 +7,35 @@ Copyright 2025 Smart-AI-Memory
 Licensed under Fair Source License 0.9
 """
 
+import functools
+import logging
+import os
+from dataclasses import dataclass
 from typing import Any
 
-
+logger = logging.getLogger(__name__)
+
+# Lazy import to avoid requiring dependencies if not used
 _client = None
+_tiktoken_encoding = None
+
+# Try to import tiktoken for fast local estimation
+try:
+    import tiktoken
+
+    TIKTOKEN_AVAILABLE = True
+except ImportError:
+    TIKTOKEN_AVAILABLE = False
+    logger.debug("tiktoken not available - will use API or heuristic fallback")
+
+
+@dataclass
+class TokenCount:
+    """Token count result with metadata."""
+
+    tokens: int
+    method: str  # "anthropic_api", "tiktoken", "heuristic"
+    model: str | None = None
 
 
 def _get_client():
@@ -20,7 +45,12 @@ def _get_client():
     try:
         from anthropic import Anthropic
 
-
+        api_key = os.getenv("ANTHROPIC_API_KEY")
+        if not api_key:
+            raise ValueError(
+                "ANTHROPIC_API_KEY environment variable required for API token counting"
+            )
+        _client = Anthropic(api_key=api_key)
     except ImportError as e:
         raise ImportError(
             "anthropic package required for token counting. Install with: pip install anthropic"
```
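
The client helper now fails fast when no key is configured. A tiny sketch of that guard, calling the private helper purely for illustration (normal callers go through the public counting functions shown below):

```python
import os

# With ANTHROPIC_API_KEY unset, _get_client() raises ValueError instead of
# constructing an unauthenticated client.
os.environ.pop("ANTHROPIC_API_KEY", None)

from empathy_llm_toolkit.utils.tokens import _get_client  # import path assumed

try:
    _get_client()
except ValueError as exc:
    print(exc)  # "ANTHROPIC_API_KEY environment variable required for API token counting"
```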
```diff
@@ -28,57 +58,109 @@ def _get_client():
     return _client
 
 
-
-
+@functools.lru_cache(maxsize=4)
+def _get_tiktoken_encoding(model: str) -> Any:
+    """Get tiktoken encoding for Claude models (cached)."""
+    if not TIKTOKEN_AVAILABLE:
+        return None
+    try:
+        # Claude uses cl100k_base encoding (similar to GPT-4)
+        return tiktoken.get_encoding("cl100k_base")
+    except Exception as e:
+        logger.warning(f"Failed to get tiktoken encoding: {e}")
+        return None
+
+
+def _count_tokens_tiktoken(text: str, model: str) -> int:
+    """Count tokens using tiktoken (fast local estimation)."""
+    if not text:
+        return 0
+
+    encoding = _get_tiktoken_encoding(model)
+    if not encoding:
+        return 0
+
+    try:
+        return len(encoding.encode(text))
+    except Exception as e:
+        logger.warning(f"tiktoken encoding failed: {e}")
+        return 0
+
+
+def _count_tokens_heuristic(text: str) -> int:
+    """Fallback heuristic token counting (~4 chars per token)."""
+    if not text:
+        return 0
+    return max(1, len(text) // 4)
+
+
+def count_tokens(text: str, model: str = "claude-sonnet-4-5-20250929", use_api: bool = False) -> int:
+    """Count tokens using best available method.
+
+    By default, uses tiktoken for fast local estimation (~98% accurate).
+    Set use_api=True for exact count via Anthropic API (requires network call).
 
     Args:
         text: Text to tokenize
         model: Model ID (different models may have different tokenizers)
+        use_api: Whether to use Anthropic API for exact count (slower, requires API key)
 
     Returns:
-
+        Token count
 
     Example:
         >>> count_tokens("Hello, world!")
         4
-        >>> count_tokens("def hello():\\n    print('hi')")
+        >>> count_tokens("def hello():\\n    print('hi')", use_api=True)
         8
 
     Raises:
-        ImportError: If anthropic package not installed
-        ValueError: If
+        ImportError: If anthropic package not installed (when use_api=True)
+        ValueError: If API key missing (when use_api=True)
 
     """
     if not text:
         return 0
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # Use API if explicitly requested
+    if use_api:
+        try:
+            client = _get_client()
+            # FIXED: Use correct API method - client.messages.count_tokens()
+            result = client.messages.count_tokens(
+                model=model,
+                messages=[{"role": "user", "content": text}],
+            )
+            return int(result.input_tokens)
+        except Exception as e:
+            logger.warning(f"API token counting failed, using fallback: {e}")
+            # Continue to fallback methods
+
+    # Try tiktoken first (fast and accurate)
+    if TIKTOKEN_AVAILABLE:
+        tokens = _count_tokens_tiktoken(text, model)
+        if tokens > 0:
+            return tokens
+
+    # Fallback to heuristic
+    return _count_tokens_heuristic(text)
 
 
 def count_message_tokens(
     messages: list[dict[str, str]],
     system_prompt: str | None = None,
-    model: str = "claude-sonnet-4-5",
+    model: str = "claude-sonnet-4-5-20250929",
+    use_api: bool = False,
 ) -> dict[str, int]:
     """Count tokens in a conversation.
 
+    By default uses tiktoken for fast estimation. Set use_api=True for exact count.
+
     Args:
         messages: List of message dicts with "role" and "content"
         system_prompt: Optional system prompt
         model: Model ID
+        use_api: Whether to use Anthropic API for exact count
 
     Returns:
         Dict with token counts by component:
```
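
The rewritten `count_tokens` resolves in a fixed order: the Anthropic count-tokens API only when `use_api=True`, then tiktoken's `cl100k_base` encoding if installed, then the ~4-characters-per-token heuristic. A small usage sketch (the import path is inferred from the wheel's file layout):

```python
from empathy_llm_toolkit.utils.tokens import count_tokens  # import path assumed

text = "def hello():\n    print('hi')"

# Local estimate: tiktoken if available, otherwise the heuristic (len(text) // 4).
print(count_tokens(text))

# Exact count via client.messages.count_tokens; on any failure this logs a
# warning and silently falls back to the local estimate above.
print(count_tokens(text, model="claude-sonnet-4-5-20250929", use_api=True))
```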
```diff
@@ -92,21 +174,59 @@ def count_message_tokens(
         {"system": 4, "messages": 6, "total": 10}
 
     """
+    if not messages:
+        if system_prompt:
+            tokens = count_tokens(system_prompt, model, use_api)
+            return {"system": tokens, "messages": 0, "total": tokens}
+        return {"system": 0, "messages": 0, "total": 0}
+
+    # Use Anthropic API for exact count if requested
+    if use_api:
+        try:
+            client = _get_client()
+            kwargs: dict[str, Any] = {"model": model, "messages": messages}
+            if system_prompt:
+                kwargs["system"] = system_prompt
+
+            result = client.messages.count_tokens(**kwargs)
+            # API returns total input tokens, estimate breakdown
+            total_tokens = result.input_tokens
+
+            # Estimate system vs message breakdown
+            if system_prompt:
+                system_tokens = count_tokens(system_prompt, model, use_api=False)
+                message_tokens = max(0, total_tokens - system_tokens)
+            else:
+                system_tokens = 0
+                message_tokens = total_tokens
+
+            return {
+                "system": system_tokens,
+                "messages": message_tokens,
+                "total": total_tokens,
+            }
+        except Exception as e:
+            logger.warning(f"API token counting failed, using fallback: {e}")
+            # Continue to fallback method
+
+    # Fallback: count each component separately
     counts: dict[str, int] = {}
 
     # Count system prompt
     if system_prompt:
-        counts["system"] = count_tokens(system_prompt, model)
+        counts["system"] = count_tokens(system_prompt, model, use_api=False)
     else:
         counts["system"] = 0
 
-    # Count messages
-
-
-
+    # Count messages with overhead
+    message_tokens = 0
+    for message in messages:
+        content = message.get("content", "")
+        message_tokens += count_tokens(content, model, use_api=False)
+        message_tokens += 4  # Overhead for role markers
 
-
-    counts["total"] = counts["system"] +
+    counts["messages"] = message_tokens
+    counts["total"] = counts["system"] + message_tokens
 
     return counts
 
```
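
In the fallback path, each message contributes its content estimate plus a flat 4-token allowance for role markers. A quick worked check using the character heuristic, i.e. assuming neither tiktoken nor an API key is available (import path inferred from the wheel's file layout):

```python
from empathy_llm_toolkit.utils.tokens import count_message_tokens  # import path assumed

messages = [
    {"role": "user", "content": "Hello, world!"},   # 13 chars -> 13 // 4 = 3, plus 4 overhead
    {"role": "assistant", "content": "Hi there."},  # 9 chars  ->  9 // 4 = 2, plus 4 overhead
]

# Heuristic-only arithmetic:
#   system   = len("Be brief.") // 4 = 2
#   messages = (3 + 4) + (2 + 4)     = 13
#   total    = 2 + 13                = 15
print(count_message_tokens(messages, system_prompt="Be brief."))
# {'system': 2, 'messages': 13, 'total': 15}
```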