genxai-framework 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__init__.py +3 -0
- cli/commands/__init__.py +6 -0
- cli/commands/approval.py +85 -0
- cli/commands/audit.py +127 -0
- cli/commands/metrics.py +25 -0
- cli/commands/tool.py +389 -0
- cli/main.py +32 -0
- genxai/__init__.py +81 -0
- genxai/api/__init__.py +5 -0
- genxai/api/app.py +21 -0
- genxai/config/__init__.py +5 -0
- genxai/config/settings.py +37 -0
- genxai/connectors/__init__.py +19 -0
- genxai/connectors/base.py +122 -0
- genxai/connectors/kafka.py +92 -0
- genxai/connectors/postgres_cdc.py +95 -0
- genxai/connectors/registry.py +44 -0
- genxai/connectors/sqs.py +94 -0
- genxai/connectors/webhook.py +73 -0
- genxai/core/__init__.py +37 -0
- genxai/core/agent/__init__.py +32 -0
- genxai/core/agent/base.py +206 -0
- genxai/core/agent/config_io.py +59 -0
- genxai/core/agent/registry.py +98 -0
- genxai/core/agent/runtime.py +970 -0
- genxai/core/communication/__init__.py +6 -0
- genxai/core/communication/collaboration.py +44 -0
- genxai/core/communication/message_bus.py +192 -0
- genxai/core/communication/protocols.py +35 -0
- genxai/core/execution/__init__.py +22 -0
- genxai/core/execution/metadata.py +181 -0
- genxai/core/execution/queue.py +201 -0
- genxai/core/graph/__init__.py +30 -0
- genxai/core/graph/checkpoints.py +77 -0
- genxai/core/graph/edges.py +131 -0
- genxai/core/graph/engine.py +813 -0
- genxai/core/graph/executor.py +516 -0
- genxai/core/graph/nodes.py +161 -0
- genxai/core/graph/trigger_runner.py +40 -0
- genxai/core/memory/__init__.py +19 -0
- genxai/core/memory/base.py +72 -0
- genxai/core/memory/embedding.py +327 -0
- genxai/core/memory/episodic.py +448 -0
- genxai/core/memory/long_term.py +467 -0
- genxai/core/memory/manager.py +543 -0
- genxai/core/memory/persistence.py +297 -0
- genxai/core/memory/procedural.py +461 -0
- genxai/core/memory/semantic.py +526 -0
- genxai/core/memory/shared.py +62 -0
- genxai/core/memory/short_term.py +303 -0
- genxai/core/memory/vector_store.py +508 -0
- genxai/core/memory/working.py +211 -0
- genxai/core/state/__init__.py +6 -0
- genxai/core/state/manager.py +293 -0
- genxai/core/state/schema.py +115 -0
- genxai/llm/__init__.py +14 -0
- genxai/llm/base.py +150 -0
- genxai/llm/factory.py +329 -0
- genxai/llm/providers/__init__.py +1 -0
- genxai/llm/providers/anthropic.py +249 -0
- genxai/llm/providers/cohere.py +274 -0
- genxai/llm/providers/google.py +334 -0
- genxai/llm/providers/ollama.py +147 -0
- genxai/llm/providers/openai.py +257 -0
- genxai/llm/routing.py +83 -0
- genxai/observability/__init__.py +6 -0
- genxai/observability/logging.py +327 -0
- genxai/observability/metrics.py +494 -0
- genxai/observability/tracing.py +372 -0
- genxai/performance/__init__.py +39 -0
- genxai/performance/cache.py +256 -0
- genxai/performance/pooling.py +289 -0
- genxai/security/audit.py +304 -0
- genxai/security/auth.py +315 -0
- genxai/security/cost_control.py +528 -0
- genxai/security/default_policies.py +44 -0
- genxai/security/jwt.py +142 -0
- genxai/security/oauth.py +226 -0
- genxai/security/pii.py +366 -0
- genxai/security/policy_engine.py +82 -0
- genxai/security/rate_limit.py +341 -0
- genxai/security/rbac.py +247 -0
- genxai/security/validation.py +218 -0
- genxai/tools/__init__.py +21 -0
- genxai/tools/base.py +383 -0
- genxai/tools/builtin/__init__.py +131 -0
- genxai/tools/builtin/communication/__init__.py +15 -0
- genxai/tools/builtin/communication/email_sender.py +159 -0
- genxai/tools/builtin/communication/notification_manager.py +167 -0
- genxai/tools/builtin/communication/slack_notifier.py +118 -0
- genxai/tools/builtin/communication/sms_sender.py +118 -0
- genxai/tools/builtin/communication/webhook_caller.py +136 -0
- genxai/tools/builtin/computation/__init__.py +15 -0
- genxai/tools/builtin/computation/calculator.py +101 -0
- genxai/tools/builtin/computation/code_executor.py +183 -0
- genxai/tools/builtin/computation/data_validator.py +259 -0
- genxai/tools/builtin/computation/hash_generator.py +129 -0
- genxai/tools/builtin/computation/regex_matcher.py +201 -0
- genxai/tools/builtin/data/__init__.py +15 -0
- genxai/tools/builtin/data/csv_processor.py +213 -0
- genxai/tools/builtin/data/data_transformer.py +299 -0
- genxai/tools/builtin/data/json_processor.py +233 -0
- genxai/tools/builtin/data/text_analyzer.py +288 -0
- genxai/tools/builtin/data/xml_processor.py +175 -0
- genxai/tools/builtin/database/__init__.py +15 -0
- genxai/tools/builtin/database/database_inspector.py +157 -0
- genxai/tools/builtin/database/mongodb_query.py +196 -0
- genxai/tools/builtin/database/redis_cache.py +167 -0
- genxai/tools/builtin/database/sql_query.py +145 -0
- genxai/tools/builtin/database/vector_search.py +163 -0
- genxai/tools/builtin/file/__init__.py +17 -0
- genxai/tools/builtin/file/directory_scanner.py +214 -0
- genxai/tools/builtin/file/file_compressor.py +237 -0
- genxai/tools/builtin/file/file_reader.py +102 -0
- genxai/tools/builtin/file/file_writer.py +122 -0
- genxai/tools/builtin/file/image_processor.py +186 -0
- genxai/tools/builtin/file/pdf_parser.py +144 -0
- genxai/tools/builtin/test/__init__.py +15 -0
- genxai/tools/builtin/test/async_simulator.py +62 -0
- genxai/tools/builtin/test/data_transformer.py +99 -0
- genxai/tools/builtin/test/error_generator.py +82 -0
- genxai/tools/builtin/test/simple_math.py +94 -0
- genxai/tools/builtin/test/string_processor.py +72 -0
- genxai/tools/builtin/web/__init__.py +15 -0
- genxai/tools/builtin/web/api_caller.py +161 -0
- genxai/tools/builtin/web/html_parser.py +330 -0
- genxai/tools/builtin/web/http_client.py +187 -0
- genxai/tools/builtin/web/url_validator.py +162 -0
- genxai/tools/builtin/web/web_scraper.py +170 -0
- genxai/tools/custom/my_test_tool_2.py +9 -0
- genxai/tools/dynamic.py +105 -0
- genxai/tools/mcp_server.py +167 -0
- genxai/tools/persistence/__init__.py +6 -0
- genxai/tools/persistence/models.py +55 -0
- genxai/tools/persistence/service.py +322 -0
- genxai/tools/registry.py +227 -0
- genxai/tools/security/__init__.py +11 -0
- genxai/tools/security/limits.py +214 -0
- genxai/tools/security/policy.py +20 -0
- genxai/tools/security/sandbox.py +248 -0
- genxai/tools/templates.py +435 -0
- genxai/triggers/__init__.py +19 -0
- genxai/triggers/base.py +104 -0
- genxai/triggers/file_watcher.py +75 -0
- genxai/triggers/queue.py +68 -0
- genxai/triggers/registry.py +82 -0
- genxai/triggers/schedule.py +66 -0
- genxai/triggers/webhook.py +68 -0
- genxai/utils/__init__.py +1 -0
- genxai/utils/tokens.py +295 -0
- genxai_framework-0.1.0.dist-info/METADATA +495 -0
- genxai_framework-0.1.0.dist-info/RECORD +156 -0
- genxai_framework-0.1.0.dist-info/WHEEL +5 -0
- genxai_framework-0.1.0.dist-info/entry_points.txt +2 -0
- genxai_framework-0.1.0.dist-info/licenses/LICENSE +21 -0
- genxai_framework-0.1.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
"""Trigger registry for managing workflow triggers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Dict, List, Optional
|
|
6
|
+
import logging
|
|
7
|
+
|
|
8
|
+
from genxai.triggers.base import BaseTrigger, TriggerStatus
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TriggerRegistry:
    """Process-wide singleton registry of trigger instances.

    All state lives on the class itself, so the classmethods operate on one
    shared trigger table no matter how many instances are constructed.
    """

    _instance: Optional["TriggerRegistry"] = None
    _triggers: Dict[str, BaseTrigger] = {}

    def __new__(cls) -> "TriggerRegistry":
        # Classic lazy singleton: always hand back the one shared instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def register(cls, trigger: BaseTrigger) -> None:
        """Register a trigger instance, overwriting any existing entry."""
        key = trigger.trigger_id
        if key in cls._triggers:
            logger.warning("Trigger %s already registered, overwriting", key)
        cls._triggers[key] = trigger
        logger.info("Registered trigger: %s", key)

    @classmethod
    def unregister(cls, trigger_id: str) -> None:
        """Unregister a trigger by id; logs a warning when it is unknown."""
        removed = cls._triggers.pop(trigger_id, None)
        if not removed:
            logger.warning("Trigger %s not found in registry", trigger_id)
        else:
            logger.info("Unregistered trigger: %s", trigger_id)

    @classmethod
    def get(cls, trigger_id: str) -> Optional[BaseTrigger]:
        """Return the trigger registered under *trigger_id*, or None."""
        return cls._triggers.get(trigger_id)

    @classmethod
    def list_all(cls) -> List[BaseTrigger]:
        """Return every registered trigger as a list."""
        return list(cls._triggers.values())

    @classmethod
    def clear(cls) -> None:
        """Drop every trigger from the registry."""
        cls._triggers.clear()
        logger.info("Cleared all triggers from registry")

    @classmethod
    async def start_all(cls) -> None:
        """Start every registered trigger that is not already running."""
        for entry in cls._triggers.values():
            if entry.status == TriggerStatus.RUNNING:
                continue
            await entry.start()

    @classmethod
    async def stop_all(cls) -> None:
        """Stop every registered trigger that is not already stopped."""
        for entry in cls._triggers.values():
            if entry.status == TriggerStatus.STOPPED:
                continue
            await entry.stop()

    @classmethod
    def get_stats(cls) -> Dict[str, int]:
        """Return a per-status trigger count plus a "total" entry."""
        counts: Dict[str, int] = {}
        for entry in cls._triggers.values():
            label = entry.status.value
            counts[label] = counts.get(label, 0) + 1
        counts["total"] = len(cls._triggers)
        return counts

    def __repr__(self) -> str:
        return f"TriggerRegistry(triggers={len(self._triggers)})"
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
"""Schedule-based trigger implementation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from typing import Any, Dict, Optional
|
|
7
|
+
import logging
|
|
8
|
+
|
|
9
|
+
from genxai.triggers.base import BaseTrigger
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ScheduleTrigger(BaseTrigger):
    """Cron/interval trigger using APScheduler if available.

    A crontab expression (``cron``) takes precedence over ``interval_seconds``
    when both are given. On each firing the trigger emits its configured
    payload plus a ``scheduled_at`` ISO timestamp.
    """

    def __init__(
        self,
        trigger_id: str,
        cron: Optional[str] = None,
        interval_seconds: Optional[int] = None,
        payload: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
        timezone: str = "UTC",
    ) -> None:
        """Initialize the schedule trigger.

        Args:
            trigger_id: Unique trigger identifier.
            cron: Crontab expression (takes precedence over interval_seconds).
            interval_seconds: Fixed firing interval in seconds.
            payload: Extra payload merged into each emitted event.
            name: Optional human-readable name.
            timezone: Scheduler timezone name (default "UTC").

        Raises:
            ValueError: If neither cron nor interval_seconds is provided.
        """
        super().__init__(trigger_id=trigger_id, name=name)
        self.cron = cron
        self.interval_seconds = interval_seconds
        self.payload = payload or {}
        self.timezone = timezone
        self._scheduler = None  # AsyncIOScheduler while running, else None
        self._job = None  # APScheduler Job handle while running, else None

        # Fail fast on misconfiguration instead of failing at start() time.
        if not self.cron and not self.interval_seconds:
            raise ValueError("Either cron or interval_seconds must be provided")

    async def _start(self) -> None:
        # APScheduler is an optional dependency; import lazily so this module
        # stays importable without it installed.
        try:
            from apscheduler.schedulers.asyncio import AsyncIOScheduler
            from apscheduler.triggers.cron import CronTrigger
            from apscheduler.triggers.interval import IntervalTrigger
        except ImportError as exc:
            raise ImportError(
                "APScheduler is required for ScheduleTrigger. Install with: pip install apscheduler"
            ) from exc

        scheduler = AsyncIOScheduler(timezone=self.timezone)

        if self.cron:
            trigger = CronTrigger.from_crontab(self.cron, timezone=self.timezone)
        else:
            trigger = IntervalTrigger(seconds=self.interval_seconds)

        async def _emit_wrapper() -> None:
            # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
            # kept here to preserve the emitted timestamp format (no tz
            # offset) — consider datetime.now(timezone.utc) after confirming
            # downstream consumers.
            await self.emit(payload={"scheduled_at": datetime.utcnow().isoformat(), **self.payload})

        # Bug fix: keep the Job handle — self._job was declared in __init__
        # but the add_job() return value was previously discarded.
        self._job = scheduler.add_job(_emit_wrapper, trigger=trigger)
        scheduler.start()
        self._scheduler = scheduler
        logger.info("ScheduleTrigger %s started", self.trigger_id)

    async def _stop(self) -> None:
        if self._scheduler:
            self._scheduler.shutdown(wait=False)
            self._scheduler = None
            self._job = None  # Bug fix: release the job handle on stop.
        logger.info("ScheduleTrigger %s stopped", self.trigger_id)
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"""Webhook trigger implementation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict, Optional
|
|
6
|
+
import hmac
|
|
7
|
+
import hashlib
|
|
8
|
+
import logging
|
|
9
|
+
|
|
10
|
+
from genxai.triggers.base import BaseTrigger
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class WebhookTrigger(BaseTrigger):
    """HTTP webhook trigger.

    This trigger does not start its own web server; it provides a handler that
    can be mounted in FastAPI or other ASGI frameworks. When a ``secret`` is
    configured, inbound requests must carry a valid HMAC signature header of
    the form ``"<alg>=<hexdigest>"`` computed over the raw request body.
    """

    def __init__(
        self,
        trigger_id: str,
        secret: Optional[str] = None,
        name: Optional[str] = None,
        header_name: str = "X-GenXAI-Signature",
        hash_alg: str = "sha256",
    ) -> None:
        """Initialize the webhook trigger.

        Args:
            trigger_id: Unique trigger identifier.
            secret: Shared HMAC secret; when None, signature checks are skipped.
            name: Optional human-readable name.
            header_name: Request header carrying the signature.
            hash_alg: hashlib algorithm name used for the HMAC digest.
        """
        super().__init__(trigger_id=trigger_id, name=name)
        self.secret = secret
        self.header_name = header_name
        self.hash_alg = hash_alg

    async def _start(self) -> None:
        # Nothing to launch; the handler is mounted by an external app.
        logger.debug("Webhook trigger %s ready for requests", self.trigger_id)

    async def _stop(self) -> None:
        logger.debug("Webhook trigger %s stopped", self.trigger_id)

    def validate_signature(self, payload: bytes, signature: Optional[str]) -> bool:
        """Validate the webhook signature when a secret is provided.

        Returns True when no secret is configured, or when ``signature``
        matches the expected "<alg>=<hexdigest>" HMAC of ``payload``.
        """
        if not self.secret:
            return True
        if not signature:
            return False

        digest = hmac.new(self.secret.encode(), payload, getattr(hashlib, self.hash_alg)).hexdigest()
        expected = f"{self.hash_alg}={digest}"
        # compare_digest avoids leaking the match position via timing.
        return hmac.compare_digest(expected, signature)

    async def handle_request(
        self,
        payload: Dict[str, Any],
        raw_body: Optional[bytes] = None,
        headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """Process an inbound webhook request and emit a trigger event.

        Returns {"status": "accepted", ...} on success, or
        {"status": "rejected", ...} when signature validation fails.
        """
        headers = headers or {}
        signature = headers.get(self.header_name)

        if self.secret:
            # Security fix: previously a request with raw_body=None bypassed
            # signature validation entirely even though a secret was
            # configured. With a secret set, the raw body and a valid
            # signature are now both required.
            if raw_body is None or not self.validate_signature(raw_body, signature):
                logger.warning("Webhook signature validation failed for %s", self.trigger_id)
                return {"status": "rejected", "reason": "invalid signature"}

        await self.emit(payload=payload, metadata={"headers": headers})
        return {"status": "accepted", "trigger_id": self.trigger_id}
|
genxai/utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Utility functions for GenXAI."""
|
genxai/utils/tokens.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
"""Token counting and context window management utilities."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, List, Optional
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
logger = logging.getLogger(__name__)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Model token limits (context window sizes)
MODEL_TOKEN_LIMITS: Dict[str, int] = {
    # OpenAI models
    "gpt-4": 8192,
    "gpt-4-32k": 32768,
    "gpt-4-turbo": 128000,
    "gpt-4-turbo-preview": 128000,
    "gpt-3.5-turbo": 4096,
    "gpt-3.5-turbo-16k": 16384,
    # Anthropic models
    "claude-3-opus": 200000,
    "claude-3-sonnet": 200000,
    "claude-3-haiku": 200000,
    "claude-2.1": 200000,
    "claude-2": 100000,
    "claude-instant": 100000,
    # Google models
    "gemini-pro": 32768,
    "gemini-pro-vision": 16384,
    # Cohere models
    "command": 4096,
    "command-light": 4096,
    "command-nightly": 8192,
}


def get_model_token_limit(model: str) -> int:
    """Get the context-window token limit for a model.

    Args:
        model: Model name (e.g. "gpt-4-0125-preview").

    Returns:
        Token limit for the model, or 4096 as a conservative default.
    """
    # Exact match first.
    if model in MODEL_TOKEN_LIMITS:
        return MODEL_TOKEN_LIMITS[model]

    # Prefix match for versioned names (e.g. "gpt-4-0125-preview" -> "gpt-4").
    # Bug fix: pick the LONGEST matching prefix. The old first-match-wins scan
    # depended on dict insertion order, so "gpt-4-32k-0613" matched "gpt-4"
    # and got the 8K limit instead of the 32K "gpt-4-32k" limit.
    matches = [prefix for prefix in MODEL_TOKEN_LIMITS if model.startswith(prefix)]
    if matches:
        return MODEL_TOKEN_LIMITS[max(matches, key=len)]

    # Default to a conservative 4K limit for unknown models.
    logger.warning(f"Unknown model '{model}', using default 4096 token limit")
    return 4096
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def estimate_tokens(text: str) -> int:
    """Estimate the token count of *text*.

    Rough heuristic of ~4 characters per token for English text. For accurate
    counting in production, use a real tokenizer such as tiktoken.

    Args:
        text: Text to estimate tokens for.

    Returns:
        Estimated token count.
    """
    return len(text) // 4


def truncate_to_token_limit(
    text: str,
    max_tokens: int,
    preserve_start: bool = True,
) -> str:
    """Truncate text to fit within a token limit.

    Args:
        text: Text to truncate.
        max_tokens: Maximum tokens allowed; once truncation is needed, a
            non-positive budget yields an empty string.
        preserve_start: If True, keep the start of the text; if False, keep
            the end.

    Returns:
        Truncated text.
    """
    estimated_tokens = estimate_tokens(text)

    if estimated_tokens <= max_tokens:
        return text

    # Bug fix: with max_tokens <= 0 and preserve_start=False the old code
    # sliced text[-0:], which returns the WHOLE string instead of an empty
    # one — so a zero-token budget performed no truncation at all.
    if max_tokens <= 0:
        return ""

    # Convert the token budget back to characters (same ~4 chars/token rule).
    max_chars = max_tokens * 4

    if preserve_start:
        truncated = text[:max_chars]
        logging.getLogger(__name__).debug(f"Truncated text from {len(text)} to {len(truncated)} chars (start preserved)")
    else:
        truncated = text[-max_chars:]
        logging.getLogger(__name__).debug(f"Truncated text from {len(text)} to {len(truncated)} chars (end preserved)")

    return truncated
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def manage_context_window(
    system_prompt: str,
    user_prompt: str,
    memory_context: str,
    model: str,
    reserve_tokens: int = 1000,
) -> tuple[str, str, str]:
    """Shrink prompt sections so the combined prompt fits the model's context.

    Truncation order (least to most important): memory context first, then the
    system prompt (keeping at least ~500 tokens), and finally the user prompt
    as a last resort (keeping at least ~100 tokens). All sizes come from the
    rough 4-chars-per-token heuristic in estimate_tokens(), so the resulting
    fit is approximate, not exact.

    Args:
        system_prompt: System prompt text.
        user_prompt: User prompt text.
        memory_context: Memory context text.
        model: Model name, used to look up the context-window size.
        reserve_tokens: Tokens to reserve for the model's response.

    Returns:
        Tuple of (system_prompt, user_prompt, memory_context) adjusted to fit.
    """
    model_limit = get_model_token_limit(model)
    available_tokens = model_limit - reserve_tokens

    # Estimate current token usage per section (heuristic, not tokenizer-exact).
    system_tokens = estimate_tokens(system_prompt)
    user_tokens = estimate_tokens(user_prompt)
    memory_tokens = estimate_tokens(memory_context)
    total_tokens = system_tokens + user_tokens + memory_tokens

    logger.debug(
        f"Context window: {total_tokens}/{model_limit} tokens "
        f"(system: {system_tokens}, user: {user_tokens}, memory: {memory_tokens})"
    )

    # Everything fits — return the inputs unchanged.
    if total_tokens <= available_tokens:
        return system_prompt, user_prompt, memory_context

    # Over budget: remove this many tokens, taking from the least important
    # section first (memory -> system -> user).
    tokens_to_remove = total_tokens - available_tokens

    logger.warning(
        f"Context window exceeded by {tokens_to_remove} tokens, truncating..."
    )

    # Step 1: truncate memory context, keeping its tail (most recent entries).
    # NOTE(review): when new_memory_tokens hits 0, truncate_to_token_limit with
    # preserve_start=False may not actually empty the context (text[-0:] slice
    # returns the whole string) — verify against that helper's behavior.
    if memory_tokens > 0 and tokens_to_remove > 0:
        memory_reduction = min(memory_tokens, tokens_to_remove)
        new_memory_tokens = max(0, memory_tokens - memory_reduction)
        memory_context = truncate_to_token_limit(
            memory_context,
            new_memory_tokens,
            preserve_start=False  # Keep most recent memories
        )
        tokens_to_remove -= memory_reduction
        logger.debug(f"Truncated memory context by {memory_reduction} tokens")

    # Step 2: truncate the system prompt, never below a 500-token floor.
    if tokens_to_remove > 0 and system_tokens > 500:  # Keep at least 500 tokens
        system_reduction = min(system_tokens - 500, tokens_to_remove)
        new_system_tokens = max(500, system_tokens - system_reduction)
        system_prompt = truncate_to_token_limit(
            system_prompt,
            new_system_tokens,
            preserve_start=True  # Keep role/goal at start
        )
        tokens_to_remove -= system_reduction
        logger.debug(f"Truncated system prompt by {system_reduction} tokens")

    # Step 3 (last resort): truncate the user prompt, never below ~100 tokens.
    # NOTE(review): because of the 100-token floor, the final result can still
    # exceed available_tokens when the deficit is large — callers should not
    # assume a guaranteed fit.
    if tokens_to_remove > 0 and user_tokens > 0:
        new_user_tokens = max(100, user_tokens - tokens_to_remove)  # Keep at least 100 tokens
        user_prompt = truncate_to_token_limit(
            user_prompt,
            new_user_tokens,
            preserve_start=True  # Keep task description
        )
        logger.warning(f"Had to truncate user prompt by {tokens_to_remove} tokens")

    return system_prompt, user_prompt, memory_context
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def split_text_by_tokens(
    text: str,
    max_tokens_per_chunk: int,
    overlap_tokens: int = 100,
) -> List[str]:
    """Split text into overlapping chunks by estimated token count.

    Args:
        text: Text to split.
        max_tokens_per_chunk: Maximum tokens per chunk; must be positive.
        overlap_tokens: Tokens to overlap between consecutive chunks. Values
            at or above max_tokens_per_chunk are clamped so the scan always
            advances.

    Returns:
        List of text chunks.

    Raises:
        ValueError: If max_tokens_per_chunk is not positive.
    """
    if max_tokens_per_chunk <= 0:
        raise ValueError("max_tokens_per_chunk must be positive")

    estimated_total_tokens = estimate_tokens(text)

    if estimated_total_tokens <= max_tokens_per_chunk:
        return [text]

    chars_per_chunk = max_tokens_per_chunk * 4  # ~4 chars per token
    # Bug fix: if overlap_chars >= chars_per_chunk the old loop moved `start`
    # backwards (or not at all) and never terminated — this happened with the
    # DEFAULT overlap of 100 tokens whenever chunks were <= 100 tokens. Clamp
    # the overlap so every iteration advances by at least one character.
    overlap_chars = max(0, min(overlap_tokens * 4, chars_per_chunk - 1))

    chunks: List[str] = []
    start = 0
    while start < len(text):
        end = start + chars_per_chunk
        chunks.append(text[start:end])

        # Bug fix: stop once the end of the text is covered; the old code
        # could emit extra, fully-redundant tail chunks due to the overlap.
        if end >= len(text):
            break

        # Move start forward, keeping the requested overlap with the previous chunk.
        start = end - overlap_chars

    logger.debug(f"Split text into {len(chunks)} chunks")
    return chunks
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
class TokenCounter:
    """Token counter with per-text caching for efficiency.

    NOTE: the cache is unbounded and keyed by the full text, so call
    clear_cache() periodically in long-running processes that count many
    distinct strings.
    """

    def __init__(self, model: str):
        """Initialize token counter.

        Args:
            model: Model name used to look up the context-window limit.
        """
        self.model = model
        self.token_limit = get_model_token_limit(model)
        # Maps text -> estimated token count.
        self._cache: Dict[str, int] = {}

    def count(self, text: str, use_cache: bool = True) -> int:
        """Count (estimate) tokens in text.

        Args:
            text: Text to count tokens for.
            use_cache: Whether to read from / write to the cache.

        Returns:
            Estimated token count.
        """
        if use_cache and text in self._cache:
            return self._cache[text]

        count = estimate_tokens(text)

        if use_cache:
            self._cache[text] = count

        return count

    def fits_in_context(
        self,
        *texts: str,
        reserve_tokens: int = 1000,
    ) -> bool:
        """Check whether the given texts fit in the model's context window.

        Args:
            *texts: Texts to check.
            reserve_tokens: Tokens to reserve for the response.

        Returns:
            True if the combined estimate fits within the remaining budget.
        """
        total_tokens = sum(self.count(text) for text in texts)
        available_tokens = self.token_limit - reserve_tokens
        return total_tokens <= available_tokens

    def clear_cache(self) -> None:
        """Clear token count cache."""
        self._cache.clear()

    def get_stats(self) -> Dict[str, Any]:
        """Get counter statistics.

        Returns:
            Dict with the model name, its token limit, and the cache size.
        """
        # Bug fix: the annotation previously read Dict[str, any], using the
        # builtin any() function as a type instead of typing.Any.
        return {
            "model": self.model,
            "token_limit": self.token_limit,
            "cache_size": len(self._cache),
        }
|