genxai_framework-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156)
  1. cli/__init__.py +3 -0
  2. cli/commands/__init__.py +6 -0
  3. cli/commands/approval.py +85 -0
  4. cli/commands/audit.py +127 -0
  5. cli/commands/metrics.py +25 -0
  6. cli/commands/tool.py +389 -0
  7. cli/main.py +32 -0
  8. genxai/__init__.py +81 -0
  9. genxai/api/__init__.py +5 -0
  10. genxai/api/app.py +21 -0
  11. genxai/config/__init__.py +5 -0
  12. genxai/config/settings.py +37 -0
  13. genxai/connectors/__init__.py +19 -0
  14. genxai/connectors/base.py +122 -0
  15. genxai/connectors/kafka.py +92 -0
  16. genxai/connectors/postgres_cdc.py +95 -0
  17. genxai/connectors/registry.py +44 -0
  18. genxai/connectors/sqs.py +94 -0
  19. genxai/connectors/webhook.py +73 -0
  20. genxai/core/__init__.py +37 -0
  21. genxai/core/agent/__init__.py +32 -0
  22. genxai/core/agent/base.py +206 -0
  23. genxai/core/agent/config_io.py +59 -0
  24. genxai/core/agent/registry.py +98 -0
  25. genxai/core/agent/runtime.py +970 -0
  26. genxai/core/communication/__init__.py +6 -0
  27. genxai/core/communication/collaboration.py +44 -0
  28. genxai/core/communication/message_bus.py +192 -0
  29. genxai/core/communication/protocols.py +35 -0
  30. genxai/core/execution/__init__.py +22 -0
  31. genxai/core/execution/metadata.py +181 -0
  32. genxai/core/execution/queue.py +201 -0
  33. genxai/core/graph/__init__.py +30 -0
  34. genxai/core/graph/checkpoints.py +77 -0
  35. genxai/core/graph/edges.py +131 -0
  36. genxai/core/graph/engine.py +813 -0
  37. genxai/core/graph/executor.py +516 -0
  38. genxai/core/graph/nodes.py +161 -0
  39. genxai/core/graph/trigger_runner.py +40 -0
  40. genxai/core/memory/__init__.py +19 -0
  41. genxai/core/memory/base.py +72 -0
  42. genxai/core/memory/embedding.py +327 -0
  43. genxai/core/memory/episodic.py +448 -0
  44. genxai/core/memory/long_term.py +467 -0
  45. genxai/core/memory/manager.py +543 -0
  46. genxai/core/memory/persistence.py +297 -0
  47. genxai/core/memory/procedural.py +461 -0
  48. genxai/core/memory/semantic.py +526 -0
  49. genxai/core/memory/shared.py +62 -0
  50. genxai/core/memory/short_term.py +303 -0
  51. genxai/core/memory/vector_store.py +508 -0
  52. genxai/core/memory/working.py +211 -0
  53. genxai/core/state/__init__.py +6 -0
  54. genxai/core/state/manager.py +293 -0
  55. genxai/core/state/schema.py +115 -0
  56. genxai/llm/__init__.py +14 -0
  57. genxai/llm/base.py +150 -0
  58. genxai/llm/factory.py +329 -0
  59. genxai/llm/providers/__init__.py +1 -0
  60. genxai/llm/providers/anthropic.py +249 -0
  61. genxai/llm/providers/cohere.py +274 -0
  62. genxai/llm/providers/google.py +334 -0
  63. genxai/llm/providers/ollama.py +147 -0
  64. genxai/llm/providers/openai.py +257 -0
  65. genxai/llm/routing.py +83 -0
  66. genxai/observability/__init__.py +6 -0
  67. genxai/observability/logging.py +327 -0
  68. genxai/observability/metrics.py +494 -0
  69. genxai/observability/tracing.py +372 -0
  70. genxai/performance/__init__.py +39 -0
  71. genxai/performance/cache.py +256 -0
  72. genxai/performance/pooling.py +289 -0
  73. genxai/security/audit.py +304 -0
  74. genxai/security/auth.py +315 -0
  75. genxai/security/cost_control.py +528 -0
  76. genxai/security/default_policies.py +44 -0
  77. genxai/security/jwt.py +142 -0
  78. genxai/security/oauth.py +226 -0
  79. genxai/security/pii.py +366 -0
  80. genxai/security/policy_engine.py +82 -0
  81. genxai/security/rate_limit.py +341 -0
  82. genxai/security/rbac.py +247 -0
  83. genxai/security/validation.py +218 -0
  84. genxai/tools/__init__.py +21 -0
  85. genxai/tools/base.py +383 -0
  86. genxai/tools/builtin/__init__.py +131 -0
  87. genxai/tools/builtin/communication/__init__.py +15 -0
  88. genxai/tools/builtin/communication/email_sender.py +159 -0
  89. genxai/tools/builtin/communication/notification_manager.py +167 -0
  90. genxai/tools/builtin/communication/slack_notifier.py +118 -0
  91. genxai/tools/builtin/communication/sms_sender.py +118 -0
  92. genxai/tools/builtin/communication/webhook_caller.py +136 -0
  93. genxai/tools/builtin/computation/__init__.py +15 -0
  94. genxai/tools/builtin/computation/calculator.py +101 -0
  95. genxai/tools/builtin/computation/code_executor.py +183 -0
  96. genxai/tools/builtin/computation/data_validator.py +259 -0
  97. genxai/tools/builtin/computation/hash_generator.py +129 -0
  98. genxai/tools/builtin/computation/regex_matcher.py +201 -0
  99. genxai/tools/builtin/data/__init__.py +15 -0
  100. genxai/tools/builtin/data/csv_processor.py +213 -0
  101. genxai/tools/builtin/data/data_transformer.py +299 -0
  102. genxai/tools/builtin/data/json_processor.py +233 -0
  103. genxai/tools/builtin/data/text_analyzer.py +288 -0
  104. genxai/tools/builtin/data/xml_processor.py +175 -0
  105. genxai/tools/builtin/database/__init__.py +15 -0
  106. genxai/tools/builtin/database/database_inspector.py +157 -0
  107. genxai/tools/builtin/database/mongodb_query.py +196 -0
  108. genxai/tools/builtin/database/redis_cache.py +167 -0
  109. genxai/tools/builtin/database/sql_query.py +145 -0
  110. genxai/tools/builtin/database/vector_search.py +163 -0
  111. genxai/tools/builtin/file/__init__.py +17 -0
  112. genxai/tools/builtin/file/directory_scanner.py +214 -0
  113. genxai/tools/builtin/file/file_compressor.py +237 -0
  114. genxai/tools/builtin/file/file_reader.py +102 -0
  115. genxai/tools/builtin/file/file_writer.py +122 -0
  116. genxai/tools/builtin/file/image_processor.py +186 -0
  117. genxai/tools/builtin/file/pdf_parser.py +144 -0
  118. genxai/tools/builtin/test/__init__.py +15 -0
  119. genxai/tools/builtin/test/async_simulator.py +62 -0
  120. genxai/tools/builtin/test/data_transformer.py +99 -0
  121. genxai/tools/builtin/test/error_generator.py +82 -0
  122. genxai/tools/builtin/test/simple_math.py +94 -0
  123. genxai/tools/builtin/test/string_processor.py +72 -0
  124. genxai/tools/builtin/web/__init__.py +15 -0
  125. genxai/tools/builtin/web/api_caller.py +161 -0
  126. genxai/tools/builtin/web/html_parser.py +330 -0
  127. genxai/tools/builtin/web/http_client.py +187 -0
  128. genxai/tools/builtin/web/url_validator.py +162 -0
  129. genxai/tools/builtin/web/web_scraper.py +170 -0
  130. genxai/tools/custom/my_test_tool_2.py +9 -0
  131. genxai/tools/dynamic.py +105 -0
  132. genxai/tools/mcp_server.py +167 -0
  133. genxai/tools/persistence/__init__.py +6 -0
  134. genxai/tools/persistence/models.py +55 -0
  135. genxai/tools/persistence/service.py +322 -0
  136. genxai/tools/registry.py +227 -0
  137. genxai/tools/security/__init__.py +11 -0
  138. genxai/tools/security/limits.py +214 -0
  139. genxai/tools/security/policy.py +20 -0
  140. genxai/tools/security/sandbox.py +248 -0
  141. genxai/tools/templates.py +435 -0
  142. genxai/triggers/__init__.py +19 -0
  143. genxai/triggers/base.py +104 -0
  144. genxai/triggers/file_watcher.py +75 -0
  145. genxai/triggers/queue.py +68 -0
  146. genxai/triggers/registry.py +82 -0
  147. genxai/triggers/schedule.py +66 -0
  148. genxai/triggers/webhook.py +68 -0
  149. genxai/utils/__init__.py +1 -0
  150. genxai/utils/tokens.py +295 -0
  151. genxai_framework-0.1.0.dist-info/METADATA +495 -0
  152. genxai_framework-0.1.0.dist-info/RECORD +156 -0
  153. genxai_framework-0.1.0.dist-info/WHEEL +5 -0
  154. genxai_framework-0.1.0.dist-info/entry_points.txt +2 -0
  155. genxai_framework-0.1.0.dist-info/licenses/LICENSE +21 -0
  156. genxai_framework-0.1.0.dist-info/top_level.txt +2 -0
genxai/triggers/registry.py ADDED
@@ -0,0 +1,82 @@
+ """Trigger registry for managing workflow triggers."""
+
+ from __future__ import annotations
+
+ from typing import Dict, List, Optional
+ import logging
+
+ from genxai.triggers.base import BaseTrigger, TriggerStatus
+
+ logger = logging.getLogger(__name__)
+
+
+ class TriggerRegistry:
+     """Central registry for triggers."""
+
+     _instance: Optional["TriggerRegistry"] = None
+     _triggers: Dict[str, BaseTrigger] = {}
+
+     def __new__(cls) -> "TriggerRegistry":
+         if cls._instance is None:
+             cls._instance = super().__new__(cls)
+         return cls._instance
+
+     @classmethod
+     def register(cls, trigger: BaseTrigger) -> None:
+         """Register a trigger instance."""
+         if trigger.trigger_id in cls._triggers:
+             logger.warning("Trigger %s already registered, overwriting", trigger.trigger_id)
+         cls._triggers[trigger.trigger_id] = trigger
+         logger.info("Registered trigger: %s", trigger.trigger_id)
+
+     @classmethod
+     def unregister(cls, trigger_id: str) -> None:
+         """Unregister a trigger by id."""
+         trigger = cls._triggers.pop(trigger_id, None)
+         if trigger:
+             logger.info("Unregistered trigger: %s", trigger_id)
+         else:
+             logger.warning("Trigger %s not found in registry", trigger_id)
+
+     @classmethod
+     def get(cls, trigger_id: str) -> Optional[BaseTrigger]:
+         """Get a trigger by id."""
+         return cls._triggers.get(trigger_id)
+
+     @classmethod
+     def list_all(cls) -> List[BaseTrigger]:
+         """List all registered triggers."""
+         return list(cls._triggers.values())
+
+     @classmethod
+     def clear(cls) -> None:
+         """Clear all triggers from the registry."""
+         cls._triggers.clear()
+         logger.info("Cleared all triggers from registry")
+
+     @classmethod
+     async def start_all(cls) -> None:
+         """Start all registered triggers."""
+         for trigger in cls._triggers.values():
+             if trigger.status != TriggerStatus.RUNNING:
+                 await trigger.start()
+
+     @classmethod
+     async def stop_all(cls) -> None:
+         """Stop all registered triggers."""
+         for trigger in cls._triggers.values():
+             if trigger.status != TriggerStatus.STOPPED:
+                 await trigger.stop()
+
+     @classmethod
+     def get_stats(cls) -> Dict[str, int]:
+         """Return registry stats by trigger status."""
+         stats: Dict[str, int] = {}
+         for trigger in cls._triggers.values():
+             status = trigger.status.value
+             stats[status] = stats.get(status, 0) + 1
+         stats["total"] = len(cls._triggers)
+         return stats
+
+     def __repr__(self) -> str:
+         return f"TriggerRegistry(triggers={len(self._triggers)})"
genxai/triggers/schedule.py ADDED
@@ -0,0 +1,66 @@
+ """Schedule-based trigger implementation."""
+
+ from __future__ import annotations
+
+ from datetime import datetime
+ from typing import Any, Dict, Optional
+ import logging
+
+ from genxai.triggers.base import BaseTrigger
+
+ logger = logging.getLogger(__name__)
+
+
+ class ScheduleTrigger(BaseTrigger):
+     """Cron/interval trigger using APScheduler if available."""
+
+     def __init__(
+         self,
+         trigger_id: str,
+         cron: Optional[str] = None,
+         interval_seconds: Optional[int] = None,
+         payload: Optional[Dict[str, Any]] = None,
+         name: Optional[str] = None,
+         timezone: str = "UTC",
+     ) -> None:
+         super().__init__(trigger_id=trigger_id, name=name)
+         self.cron = cron
+         self.interval_seconds = interval_seconds
+         self.payload = payload or {}
+         self.timezone = timezone
+         self._scheduler = None
+         self._job = None
+
+         if not self.cron and not self.interval_seconds:
+             raise ValueError("Either cron or interval_seconds must be provided")
+
+     async def _start(self) -> None:
+         try:
+             from apscheduler.schedulers.asyncio import AsyncIOScheduler
+             from apscheduler.triggers.cron import CronTrigger
+             from apscheduler.triggers.interval import IntervalTrigger
+         except ImportError as exc:
+             raise ImportError(
+                 "APScheduler is required for ScheduleTrigger. Install with: pip install apscheduler"
+             ) from exc
+
+         scheduler = AsyncIOScheduler(timezone=self.timezone)
+
+         if self.cron:
+             trigger = CronTrigger.from_crontab(self.cron, timezone=self.timezone)
+         else:
+             trigger = IntervalTrigger(seconds=self.interval_seconds)
+
+         async def _emit_wrapper() -> None:
+             await self.emit(payload={"scheduled_at": datetime.utcnow().isoformat(), **self.payload})
+
+         scheduler.add_job(_emit_wrapper, trigger=trigger)
+         scheduler.start()
+         self._scheduler = scheduler
+         logger.info("ScheduleTrigger %s started", self.trigger_id)
+
+     async def _stop(self) -> None:
+         if self._scheduler:
+             self._scheduler.shutdown(wait=False)
+             self._scheduler = None
+         logger.info("ScheduleTrigger %s stopped", self.trigger_id)
genxai/triggers/webhook.py ADDED
@@ -0,0 +1,68 @@
+ """Webhook trigger implementation."""
+
+ from __future__ import annotations
+
+ from typing import Any, Dict, Optional
+ import hmac
+ import hashlib
+ import logging
+
+ from genxai.triggers.base import BaseTrigger
+
+ logger = logging.getLogger(__name__)
+
+
+ class WebhookTrigger(BaseTrigger):
+     """HTTP webhook trigger.
+
+     This trigger does not start its own web server; it provides a handler that
+     can be mounted in FastAPI or other ASGI frameworks.
+     """
+
+     def __init__(
+         self,
+         trigger_id: str,
+         secret: Optional[str] = None,
+         name: Optional[str] = None,
+         header_name: str = "X-GenXAI-Signature",
+         hash_alg: str = "sha256",
+     ) -> None:
+         super().__init__(trigger_id=trigger_id, name=name)
+         self.secret = secret
+         self.header_name = header_name
+         self.hash_alg = hash_alg
+
+     async def _start(self) -> None:
+         logger.debug("Webhook trigger %s ready for requests", self.trigger_id)
+
+     async def _stop(self) -> None:
+         logger.debug("Webhook trigger %s stopped", self.trigger_id)
+
+     def validate_signature(self, payload: bytes, signature: Optional[str]) -> bool:
+         """Validate the webhook signature when a secret is provided."""
+         if not self.secret:
+             return True
+         if not signature:
+             return False
+
+         digest = hmac.new(self.secret.encode(), payload, getattr(hashlib, self.hash_alg)).hexdigest()
+         expected = f"{self.hash_alg}={digest}"
+         return hmac.compare_digest(expected, signature)
+
+     async def handle_request(
+         self,
+         payload: Dict[str, Any],
+         raw_body: Optional[bytes] = None,
+         headers: Optional[Dict[str, str]] = None,
+     ) -> Dict[str, Any]:
+         """Process an inbound webhook request and emit a trigger event."""
+         headers = headers or {}
+         signature = headers.get(self.header_name)
+
+         if self.secret and raw_body is not None:
+             if not self.validate_signature(raw_body, signature):
+                 logger.warning("Webhook signature validation failed for %s", self.trigger_id)
+                 return {"status": "rejected", "reason": "invalid signature"}
+
+         await self.emit(payload=payload, metadata={"headers": headers})
+         return {"status": "accepted", "trigger_id": self.trigger_id}
genxai/utils/__init__.py ADDED
@@ -0,0 +1 @@
+ """Utility functions for GenXAI."""
genxai/utils/tokens.py ADDED
@@ -0,0 +1,295 @@
+ """Token counting and context window management utilities."""
+
+ from typing import Any, Dict, List
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ # Model token limits (context window sizes)
+ MODEL_TOKEN_LIMITS: Dict[str, int] = {
+     # OpenAI models
+     "gpt-4": 8192,
+     "gpt-4-32k": 32768,
+     "gpt-4-turbo": 128000,
+     "gpt-4-turbo-preview": 128000,
+     "gpt-3.5-turbo": 4096,
+     "gpt-3.5-turbo-16k": 16384,
+     # Anthropic models
+     "claude-3-opus": 200000,
+     "claude-3-sonnet": 200000,
+     "claude-3-haiku": 200000,
+     "claude-2.1": 200000,
+     "claude-2": 100000,
+     "claude-instant": 100000,
+     # Google models
+     "gemini-pro": 32768,
+     "gemini-pro-vision": 16384,
+     # Cohere models
+     "command": 4096,
+     "command-light": 4096,
+     "command-nightly": 8192,
+ }
+
+
+ def get_model_token_limit(model: str) -> int:
+     """Get token limit for a model.
+
+     Args:
+         model: Model name
+
+     Returns:
+         Token limit for the model, or 4096 as default
+     """
+     # Try exact match first
+     if model in MODEL_TOKEN_LIMITS:
+         return MODEL_TOKEN_LIMITS[model]
+
+     # Try longest prefix first so "gpt-4-32k-0613" matches "gpt-4-32k", not "gpt-4"
+     for model_prefix in sorted(MODEL_TOKEN_LIMITS, key=len, reverse=True):
+         if model.startswith(model_prefix):
+             return MODEL_TOKEN_LIMITS[model_prefix]
+
+     # Default to conservative 4K limit
+     logger.warning(f"Unknown model '{model}', using default 4096 token limit")
+     return 4096
+
+
+ def estimate_tokens(text: str) -> int:
+     """Estimate token count for text.
+
+     This is a simple estimation based on character count.
+     For production use, consider using the tiktoken library for accurate counting.
+
+     Args:
+         text: Text to estimate tokens for
+
+     Returns:
+         Estimated token count
+     """
+     # Rough estimation: ~4 characters per token for English text
+     # This is conservative and works reasonably well for most cases
+     return len(text) // 4
+
+
+ def truncate_to_token_limit(
+     text: str,
+     max_tokens: int,
+     preserve_start: bool = True,
+ ) -> str:
+     """Truncate text to fit within token limit.
+
+     Args:
+         text: Text to truncate
+         max_tokens: Maximum tokens allowed
+         preserve_start: If True, keep start of text; if False, keep end
+
+     Returns:
+         Truncated text
+     """
+     estimated_tokens = estimate_tokens(text)
+
+     if estimated_tokens <= max_tokens:
+         return text
+
+     # Calculate how many characters to keep
+     # Using 4 chars per token estimation
+     max_chars = max_tokens * 4
+
+     if preserve_start:
+         truncated = text[:max_chars]
+         logger.debug(f"Truncated text from {len(text)} to {len(truncated)} chars (start preserved)")
+     else:
+         truncated = text[len(text) - max_chars:]  # text[-max_chars:] would return everything when max_chars == 0
+         logger.debug(f"Truncated text from {len(text)} to {len(truncated)} chars (end preserved)")
+
+     return truncated
+
+
+ def manage_context_window(
+     system_prompt: str,
+     user_prompt: str,
+     memory_context: str,
+     model: str,
+     reserve_tokens: int = 1000,
+ ) -> tuple[str, str, str]:
+     """Manage context window to fit within model limits.
+
+     Args:
+         system_prompt: System prompt text
+         user_prompt: User prompt text
+         memory_context: Memory context text
+         model: Model name
+         reserve_tokens: Tokens to reserve for response
+
+     Returns:
+         Tuple of (system_prompt, user_prompt, memory_context) adjusted to fit
+     """
+     model_limit = get_model_token_limit(model)
+     available_tokens = model_limit - reserve_tokens
+
+     # Estimate current token usage
+     system_tokens = estimate_tokens(system_prompt)
+     user_tokens = estimate_tokens(user_prompt)
+     memory_tokens = estimate_tokens(memory_context)
+     total_tokens = system_tokens + user_tokens + memory_tokens
+
+     logger.debug(
+         f"Context window: {total_tokens}/{model_limit} tokens "
+         f"(system: {system_tokens}, user: {user_tokens}, memory: {memory_tokens})"
+     )
+
+     # If within limit, return as-is
+     if total_tokens <= available_tokens:
+         return system_prompt, user_prompt, memory_context
+
+     # Over budget: truncate memory first, then system, and the user prompt only as a last resort
+     tokens_to_remove = total_tokens - available_tokens
+
+     logger.warning(
+         f"Context window exceeded by {tokens_to_remove} tokens, truncating..."
+     )
+
+     # First, try truncating memory context
+     if memory_tokens > 0 and tokens_to_remove > 0:
+         memory_reduction = min(memory_tokens, tokens_to_remove)
+         new_memory_tokens = max(0, memory_tokens - memory_reduction)
+         memory_context = truncate_to_token_limit(
+             memory_context,
+             new_memory_tokens,
+             preserve_start=False  # Keep most recent memories
+         )
+         tokens_to_remove -= memory_reduction
+         logger.debug(f"Truncated memory context by {memory_reduction} tokens")
+
+     # If still over limit, truncate system prompt
+     if tokens_to_remove > 0 and system_tokens > 500:  # Keep at least 500 tokens
+         system_reduction = min(system_tokens - 500, tokens_to_remove)
+         new_system_tokens = max(500, system_tokens - system_reduction)
+         system_prompt = truncate_to_token_limit(
+             system_prompt,
+             new_system_tokens,
+             preserve_start=True  # Keep role/goal at start
+         )
+         tokens_to_remove -= system_reduction
+         logger.debug(f"Truncated system prompt by {system_reduction} tokens")
+
+     # If still over limit, truncate user prompt (last resort)
+     if tokens_to_remove > 0 and user_tokens > 0:
+         new_user_tokens = max(100, user_tokens - tokens_to_remove)  # Keep at least 100 tokens
+         user_prompt = truncate_to_token_limit(
+             user_prompt,
+             new_user_tokens,
+             preserve_start=True  # Keep task description
+         )
+         logger.warning(f"Had to truncate user prompt by {tokens_to_remove} tokens")
+
+     return system_prompt, user_prompt, memory_context
+
+
+ def split_text_by_tokens(
+     text: str,
+     max_tokens_per_chunk: int,
+     overlap_tokens: int = 100,
+ ) -> List[str]:
+     """Split text into chunks by token count.
+
+     Args:
+         text: Text to split
+         max_tokens_per_chunk: Maximum tokens per chunk
+         overlap_tokens: Number of tokens to overlap between chunks
+
+     Returns:
+         List of text chunks
+     """
+     estimated_total_tokens = estimate_tokens(text)
+
+     if estimated_total_tokens <= max_tokens_per_chunk:
+         return [text]
+
+     chunks = []
+     chars_per_chunk = max_tokens_per_chunk * 4  # 4 chars per token
+     overlap_chars = overlap_tokens * 4
+
+     start = 0
+     while start < len(text):
+         end = start + chars_per_chunk
+         chunk = text[start:end]
+         chunks.append(chunk)
+
+         # Move start forward, accounting for overlap
+         start = end - overlap_chars
+
+         if end >= len(text):  # last chunk reached the end; avoids emitting a duplicate overlap-only tail
+             break
+
+     logger.debug(f"Split text into {len(chunks)} chunks")
+     return chunks
+
+
+ class TokenCounter:
+     """Token counter with caching for efficiency."""
+
+     def __init__(self, model: str):
+         """Initialize token counter.
+
+         Args:
+             model: Model name for token limit
+         """
+         self.model = model
+         self.token_limit = get_model_token_limit(model)
+         self._cache: Dict[str, int] = {}
+
+     def count(self, text: str, use_cache: bool = True) -> int:
+         """Count tokens in text.
+
+         Args:
+             text: Text to count tokens for
+             use_cache: Whether to use cache
+
+         Returns:
+             Token count
+         """
+         if use_cache and text in self._cache:
+             return self._cache[text]
+
+         count = estimate_tokens(text)
+
+         if use_cache:
+             self._cache[text] = count
+
+         return count
+
+     def fits_in_context(
+         self,
+         *texts: str,
+         reserve_tokens: int = 1000,
+     ) -> bool:
+         """Check if texts fit in context window.
+
+         Args:
+             *texts: Texts to check
+             reserve_tokens: Tokens to reserve for response
+
+         Returns:
+             True if texts fit in context window
+         """
+         total_tokens = sum(self.count(text) for text in texts)
+         available_tokens = self.token_limit - reserve_tokens
+         return total_tokens <= available_tokens
+
+     def clear_cache(self) -> None:
+         """Clear token count cache."""
+         self._cache.clear()
+
+     def get_stats(self) -> Dict[str, Any]:
+         """Get counter statistics.
+
+         Returns:
+             Statistics dictionary
+         """
+         return {
+             "model": self.model,
+             "token_limit": self.token_limit,
+             "cache_size": len(self._cache),
+         }
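
To show the intended call pattern end to end, here is a hedged sketch using the utilities above. The long_text value is a stand-in, and the tiktoken lines are an assumption prompted by the estimate_tokens docstring; tiktoken is not a dependency declared anywhere in this diff.

    from genxai.utils.tokens import TokenCounter, manage_context_window, split_text_by_tokens

    counter = TokenCounter("gpt-4-turbo")
    print(counter.token_limit)             # 128000, straight from MODEL_TOKEN_LIMITS
    print(counter.count("Hello, world!"))  # 3 under the ~4-chars-per-token heuristic
    print(counter.fits_in_context("system prompt", "user prompt", reserve_tokens=1000))

    # Trim prompts to a model's window before an LLM call; memory is cut first,
    # then the system prompt, and the user prompt only as a last resort.
    system, user, memory = manage_context_window(
        system_prompt="You are a helpful agent.",
        user_prompt="Summarize the attached report.",
        memory_context="(long retrieved history)",
        model="gpt-3.5-turbo",
    )

    # Chunk a long document with a 100-token overlap between consecutive chunks.
    long_text = "lorem ipsum dolor sit amet " * 2000
    chunks = split_text_by_tokens(long_text, max_tokens_per_chunk=512)

    # Exact counting, if tiktoken is installed (assumption, not part of this package):
    import tiktoken
    encoding = tiktoken.encoding_for_model("gpt-4")
    exact = len(encoding.encode("Hello, world!"))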