aiecs 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. aiecs/__init__.py +72 -0
  2. aiecs/__main__.py +41 -0
  3. aiecs/aiecs_client.py +469 -0
  4. aiecs/application/__init__.py +10 -0
  5. aiecs/application/executors/__init__.py +10 -0
  6. aiecs/application/executors/operation_executor.py +363 -0
  7. aiecs/application/knowledge_graph/__init__.py +7 -0
  8. aiecs/application/knowledge_graph/builder/__init__.py +37 -0
  9. aiecs/application/knowledge_graph/builder/document_builder.py +375 -0
  10. aiecs/application/knowledge_graph/builder/graph_builder.py +356 -0
  11. aiecs/application/knowledge_graph/builder/schema_mapping.py +531 -0
  12. aiecs/application/knowledge_graph/builder/structured_pipeline.py +443 -0
  13. aiecs/application/knowledge_graph/builder/text_chunker.py +319 -0
  14. aiecs/application/knowledge_graph/extractors/__init__.py +27 -0
  15. aiecs/application/knowledge_graph/extractors/base.py +100 -0
  16. aiecs/application/knowledge_graph/extractors/llm_entity_extractor.py +327 -0
  17. aiecs/application/knowledge_graph/extractors/llm_relation_extractor.py +349 -0
  18. aiecs/application/knowledge_graph/extractors/ner_entity_extractor.py +244 -0
  19. aiecs/application/knowledge_graph/fusion/__init__.py +23 -0
  20. aiecs/application/knowledge_graph/fusion/entity_deduplicator.py +387 -0
  21. aiecs/application/knowledge_graph/fusion/entity_linker.py +343 -0
  22. aiecs/application/knowledge_graph/fusion/knowledge_fusion.py +580 -0
  23. aiecs/application/knowledge_graph/fusion/relation_deduplicator.py +189 -0
  24. aiecs/application/knowledge_graph/pattern_matching/__init__.py +21 -0
  25. aiecs/application/knowledge_graph/pattern_matching/pattern_matcher.py +344 -0
  26. aiecs/application/knowledge_graph/pattern_matching/query_executor.py +378 -0
  27. aiecs/application/knowledge_graph/profiling/__init__.py +12 -0
  28. aiecs/application/knowledge_graph/profiling/query_plan_visualizer.py +199 -0
  29. aiecs/application/knowledge_graph/profiling/query_profiler.py +223 -0
  30. aiecs/application/knowledge_graph/reasoning/__init__.py +27 -0
  31. aiecs/application/knowledge_graph/reasoning/evidence_synthesis.py +347 -0
  32. aiecs/application/knowledge_graph/reasoning/inference_engine.py +504 -0
  33. aiecs/application/knowledge_graph/reasoning/logic_form_parser.py +167 -0
  34. aiecs/application/knowledge_graph/reasoning/logic_parser/__init__.py +79 -0
  35. aiecs/application/knowledge_graph/reasoning/logic_parser/ast_builder.py +513 -0
  36. aiecs/application/knowledge_graph/reasoning/logic_parser/ast_nodes.py +630 -0
  37. aiecs/application/knowledge_graph/reasoning/logic_parser/ast_validator.py +654 -0
  38. aiecs/application/knowledge_graph/reasoning/logic_parser/error_handler.py +477 -0
  39. aiecs/application/knowledge_graph/reasoning/logic_parser/parser.py +390 -0
  40. aiecs/application/knowledge_graph/reasoning/logic_parser/query_context.py +217 -0
  41. aiecs/application/knowledge_graph/reasoning/logic_query_integration.py +169 -0
  42. aiecs/application/knowledge_graph/reasoning/query_planner.py +872 -0
  43. aiecs/application/knowledge_graph/reasoning/reasoning_engine.py +554 -0
  44. aiecs/application/knowledge_graph/retrieval/__init__.py +19 -0
  45. aiecs/application/knowledge_graph/retrieval/retrieval_strategies.py +596 -0
  46. aiecs/application/knowledge_graph/search/__init__.py +59 -0
  47. aiecs/application/knowledge_graph/search/hybrid_search.py +423 -0
  48. aiecs/application/knowledge_graph/search/reranker.py +295 -0
  49. aiecs/application/knowledge_graph/search/reranker_strategies.py +553 -0
  50. aiecs/application/knowledge_graph/search/text_similarity.py +398 -0
  51. aiecs/application/knowledge_graph/traversal/__init__.py +15 -0
  52. aiecs/application/knowledge_graph/traversal/enhanced_traversal.py +329 -0
  53. aiecs/application/knowledge_graph/traversal/path_scorer.py +269 -0
  54. aiecs/application/knowledge_graph/validators/__init__.py +13 -0
  55. aiecs/application/knowledge_graph/validators/relation_validator.py +189 -0
  56. aiecs/application/knowledge_graph/visualization/__init__.py +11 -0
  57. aiecs/application/knowledge_graph/visualization/graph_visualizer.py +321 -0
  58. aiecs/common/__init__.py +9 -0
  59. aiecs/common/knowledge_graph/__init__.py +17 -0
  60. aiecs/common/knowledge_graph/runnable.py +484 -0
  61. aiecs/config/__init__.py +16 -0
  62. aiecs/config/config.py +498 -0
  63. aiecs/config/graph_config.py +137 -0
  64. aiecs/config/registry.py +23 -0
  65. aiecs/core/__init__.py +46 -0
  66. aiecs/core/interface/__init__.py +34 -0
  67. aiecs/core/interface/execution_interface.py +152 -0
  68. aiecs/core/interface/storage_interface.py +171 -0
  69. aiecs/domain/__init__.py +289 -0
  70. aiecs/domain/agent/__init__.py +189 -0
  71. aiecs/domain/agent/base_agent.py +697 -0
  72. aiecs/domain/agent/exceptions.py +103 -0
  73. aiecs/domain/agent/graph_aware_mixin.py +559 -0
  74. aiecs/domain/agent/hybrid_agent.py +490 -0
  75. aiecs/domain/agent/integration/__init__.py +26 -0
  76. aiecs/domain/agent/integration/context_compressor.py +222 -0
  77. aiecs/domain/agent/integration/context_engine_adapter.py +252 -0
  78. aiecs/domain/agent/integration/retry_policy.py +219 -0
  79. aiecs/domain/agent/integration/role_config.py +213 -0
  80. aiecs/domain/agent/knowledge_aware_agent.py +646 -0
  81. aiecs/domain/agent/lifecycle.py +296 -0
  82. aiecs/domain/agent/llm_agent.py +300 -0
  83. aiecs/domain/agent/memory/__init__.py +12 -0
  84. aiecs/domain/agent/memory/conversation.py +197 -0
  85. aiecs/domain/agent/migration/__init__.py +14 -0
  86. aiecs/domain/agent/migration/conversion.py +160 -0
  87. aiecs/domain/agent/migration/legacy_wrapper.py +90 -0
  88. aiecs/domain/agent/models.py +317 -0
  89. aiecs/domain/agent/observability.py +407 -0
  90. aiecs/domain/agent/persistence.py +289 -0
  91. aiecs/domain/agent/prompts/__init__.py +29 -0
  92. aiecs/domain/agent/prompts/builder.py +161 -0
  93. aiecs/domain/agent/prompts/formatters.py +189 -0
  94. aiecs/domain/agent/prompts/template.py +255 -0
  95. aiecs/domain/agent/registry.py +260 -0
  96. aiecs/domain/agent/tool_agent.py +257 -0
  97. aiecs/domain/agent/tools/__init__.py +12 -0
  98. aiecs/domain/agent/tools/schema_generator.py +221 -0
  99. aiecs/domain/community/__init__.py +155 -0
  100. aiecs/domain/community/agent_adapter.py +477 -0
  101. aiecs/domain/community/analytics.py +481 -0
  102. aiecs/domain/community/collaborative_workflow.py +642 -0
  103. aiecs/domain/community/communication_hub.py +645 -0
  104. aiecs/domain/community/community_builder.py +320 -0
  105. aiecs/domain/community/community_integration.py +800 -0
  106. aiecs/domain/community/community_manager.py +813 -0
  107. aiecs/domain/community/decision_engine.py +879 -0
  108. aiecs/domain/community/exceptions.py +225 -0
  109. aiecs/domain/community/models/__init__.py +33 -0
  110. aiecs/domain/community/models/community_models.py +268 -0
  111. aiecs/domain/community/resource_manager.py +457 -0
  112. aiecs/domain/community/shared_context_manager.py +603 -0
  113. aiecs/domain/context/__init__.py +58 -0
  114. aiecs/domain/context/context_engine.py +989 -0
  115. aiecs/domain/context/conversation_models.py +354 -0
  116. aiecs/domain/context/graph_memory.py +467 -0
  117. aiecs/domain/execution/__init__.py +12 -0
  118. aiecs/domain/execution/model.py +57 -0
  119. aiecs/domain/knowledge_graph/__init__.py +19 -0
  120. aiecs/domain/knowledge_graph/models/__init__.py +52 -0
  121. aiecs/domain/knowledge_graph/models/entity.py +130 -0
  122. aiecs/domain/knowledge_graph/models/evidence.py +194 -0
  123. aiecs/domain/knowledge_graph/models/inference_rule.py +186 -0
  124. aiecs/domain/knowledge_graph/models/path.py +179 -0
  125. aiecs/domain/knowledge_graph/models/path_pattern.py +173 -0
  126. aiecs/domain/knowledge_graph/models/query.py +272 -0
  127. aiecs/domain/knowledge_graph/models/query_plan.py +187 -0
  128. aiecs/domain/knowledge_graph/models/relation.py +136 -0
  129. aiecs/domain/knowledge_graph/schema/__init__.py +23 -0
  130. aiecs/domain/knowledge_graph/schema/entity_type.py +135 -0
  131. aiecs/domain/knowledge_graph/schema/graph_schema.py +271 -0
  132. aiecs/domain/knowledge_graph/schema/property_schema.py +155 -0
  133. aiecs/domain/knowledge_graph/schema/relation_type.py +171 -0
  134. aiecs/domain/knowledge_graph/schema/schema_manager.py +496 -0
  135. aiecs/domain/knowledge_graph/schema/type_enums.py +205 -0
  136. aiecs/domain/task/__init__.py +13 -0
  137. aiecs/domain/task/dsl_processor.py +613 -0
  138. aiecs/domain/task/model.py +62 -0
  139. aiecs/domain/task/task_context.py +268 -0
  140. aiecs/infrastructure/__init__.py +24 -0
  141. aiecs/infrastructure/graph_storage/__init__.py +11 -0
  142. aiecs/infrastructure/graph_storage/base.py +601 -0
  143. aiecs/infrastructure/graph_storage/batch_operations.py +449 -0
  144. aiecs/infrastructure/graph_storage/cache.py +429 -0
  145. aiecs/infrastructure/graph_storage/distributed.py +226 -0
  146. aiecs/infrastructure/graph_storage/error_handling.py +390 -0
  147. aiecs/infrastructure/graph_storage/graceful_degradation.py +306 -0
  148. aiecs/infrastructure/graph_storage/health_checks.py +378 -0
  149. aiecs/infrastructure/graph_storage/in_memory.py +514 -0
  150. aiecs/infrastructure/graph_storage/index_optimization.py +483 -0
  151. aiecs/infrastructure/graph_storage/lazy_loading.py +410 -0
  152. aiecs/infrastructure/graph_storage/metrics.py +357 -0
  153. aiecs/infrastructure/graph_storage/migration.py +413 -0
  154. aiecs/infrastructure/graph_storage/pagination.py +471 -0
  155. aiecs/infrastructure/graph_storage/performance_monitoring.py +466 -0
  156. aiecs/infrastructure/graph_storage/postgres.py +871 -0
  157. aiecs/infrastructure/graph_storage/query_optimizer.py +635 -0
  158. aiecs/infrastructure/graph_storage/schema_cache.py +290 -0
  159. aiecs/infrastructure/graph_storage/sqlite.py +623 -0
  160. aiecs/infrastructure/graph_storage/streaming.py +495 -0
  161. aiecs/infrastructure/messaging/__init__.py +13 -0
  162. aiecs/infrastructure/messaging/celery_task_manager.py +383 -0
  163. aiecs/infrastructure/messaging/websocket_manager.py +298 -0
  164. aiecs/infrastructure/monitoring/__init__.py +34 -0
  165. aiecs/infrastructure/monitoring/executor_metrics.py +174 -0
  166. aiecs/infrastructure/monitoring/global_metrics_manager.py +213 -0
  167. aiecs/infrastructure/monitoring/structured_logger.py +48 -0
  168. aiecs/infrastructure/monitoring/tracing_manager.py +410 -0
  169. aiecs/infrastructure/persistence/__init__.py +24 -0
  170. aiecs/infrastructure/persistence/context_engine_client.py +187 -0
  171. aiecs/infrastructure/persistence/database_manager.py +333 -0
  172. aiecs/infrastructure/persistence/file_storage.py +754 -0
  173. aiecs/infrastructure/persistence/redis_client.py +220 -0
  174. aiecs/llm/__init__.py +86 -0
  175. aiecs/llm/callbacks/__init__.py +11 -0
  176. aiecs/llm/callbacks/custom_callbacks.py +264 -0
  177. aiecs/llm/client_factory.py +420 -0
  178. aiecs/llm/clients/__init__.py +33 -0
  179. aiecs/llm/clients/base_client.py +193 -0
  180. aiecs/llm/clients/googleai_client.py +181 -0
  181. aiecs/llm/clients/openai_client.py +131 -0
  182. aiecs/llm/clients/vertex_client.py +437 -0
  183. aiecs/llm/clients/xai_client.py +184 -0
  184. aiecs/llm/config/__init__.py +51 -0
  185. aiecs/llm/config/config_loader.py +275 -0
  186. aiecs/llm/config/config_validator.py +236 -0
  187. aiecs/llm/config/model_config.py +151 -0
  188. aiecs/llm/utils/__init__.py +10 -0
  189. aiecs/llm/utils/validate_config.py +91 -0
  190. aiecs/main.py +363 -0
  191. aiecs/scripts/__init__.py +3 -0
  192. aiecs/scripts/aid/VERSION_MANAGEMENT.md +97 -0
  193. aiecs/scripts/aid/__init__.py +19 -0
  194. aiecs/scripts/aid/version_manager.py +215 -0
  195. aiecs/scripts/dependance_check/DEPENDENCY_SYSTEM_SUMMARY.md +242 -0
  196. aiecs/scripts/dependance_check/README_DEPENDENCY_CHECKER.md +310 -0
  197. aiecs/scripts/dependance_check/__init__.py +17 -0
  198. aiecs/scripts/dependance_check/dependency_checker.py +938 -0
  199. aiecs/scripts/dependance_check/dependency_fixer.py +391 -0
  200. aiecs/scripts/dependance_check/download_nlp_data.py +396 -0
  201. aiecs/scripts/dependance_check/quick_dependency_check.py +270 -0
  202. aiecs/scripts/dependance_check/setup_nlp_data.sh +217 -0
  203. aiecs/scripts/dependance_patch/__init__.py +7 -0
  204. aiecs/scripts/dependance_patch/fix_weasel/README_WEASEL_PATCH.md +126 -0
  205. aiecs/scripts/dependance_patch/fix_weasel/__init__.py +11 -0
  206. aiecs/scripts/dependance_patch/fix_weasel/fix_weasel_validator.py +128 -0
  207. aiecs/scripts/dependance_patch/fix_weasel/fix_weasel_validator.sh +82 -0
  208. aiecs/scripts/dependance_patch/fix_weasel/patch_weasel_library.sh +188 -0
  209. aiecs/scripts/dependance_patch/fix_weasel/run_weasel_patch.sh +41 -0
  210. aiecs/scripts/tools_develop/README.md +449 -0
  211. aiecs/scripts/tools_develop/TOOL_AUTO_DISCOVERY.md +234 -0
  212. aiecs/scripts/tools_develop/__init__.py +21 -0
  213. aiecs/scripts/tools_develop/check_type_annotations.py +259 -0
  214. aiecs/scripts/tools_develop/validate_tool_schemas.py +422 -0
  215. aiecs/scripts/tools_develop/verify_tools.py +356 -0
  216. aiecs/tasks/__init__.py +1 -0
  217. aiecs/tasks/worker.py +172 -0
  218. aiecs/tools/__init__.py +299 -0
  219. aiecs/tools/apisource/__init__.py +99 -0
  220. aiecs/tools/apisource/intelligence/__init__.py +19 -0
  221. aiecs/tools/apisource/intelligence/data_fusion.py +381 -0
  222. aiecs/tools/apisource/intelligence/query_analyzer.py +413 -0
  223. aiecs/tools/apisource/intelligence/search_enhancer.py +388 -0
  224. aiecs/tools/apisource/monitoring/__init__.py +9 -0
  225. aiecs/tools/apisource/monitoring/metrics.py +303 -0
  226. aiecs/tools/apisource/providers/__init__.py +115 -0
  227. aiecs/tools/apisource/providers/base.py +664 -0
  228. aiecs/tools/apisource/providers/census.py +401 -0
  229. aiecs/tools/apisource/providers/fred.py +564 -0
  230. aiecs/tools/apisource/providers/newsapi.py +412 -0
  231. aiecs/tools/apisource/providers/worldbank.py +357 -0
  232. aiecs/tools/apisource/reliability/__init__.py +12 -0
  233. aiecs/tools/apisource/reliability/error_handler.py +375 -0
  234. aiecs/tools/apisource/reliability/fallback_strategy.py +391 -0
  235. aiecs/tools/apisource/tool.py +850 -0
  236. aiecs/tools/apisource/utils/__init__.py +9 -0
  237. aiecs/tools/apisource/utils/validators.py +338 -0
  238. aiecs/tools/base_tool.py +201 -0
  239. aiecs/tools/docs/__init__.py +121 -0
  240. aiecs/tools/docs/ai_document_orchestrator.py +599 -0
  241. aiecs/tools/docs/ai_document_writer_orchestrator.py +2403 -0
  242. aiecs/tools/docs/content_insertion_tool.py +1333 -0
  243. aiecs/tools/docs/document_creator_tool.py +1317 -0
  244. aiecs/tools/docs/document_layout_tool.py +1166 -0
  245. aiecs/tools/docs/document_parser_tool.py +994 -0
  246. aiecs/tools/docs/document_writer_tool.py +1818 -0
  247. aiecs/tools/knowledge_graph/__init__.py +17 -0
  248. aiecs/tools/knowledge_graph/graph_reasoning_tool.py +734 -0
  249. aiecs/tools/knowledge_graph/graph_search_tool.py +923 -0
  250. aiecs/tools/knowledge_graph/kg_builder_tool.py +476 -0
  251. aiecs/tools/langchain_adapter.py +542 -0
  252. aiecs/tools/schema_generator.py +275 -0
  253. aiecs/tools/search_tool/__init__.py +100 -0
  254. aiecs/tools/search_tool/analyzers.py +589 -0
  255. aiecs/tools/search_tool/cache.py +260 -0
  256. aiecs/tools/search_tool/constants.py +128 -0
  257. aiecs/tools/search_tool/context.py +216 -0
  258. aiecs/tools/search_tool/core.py +749 -0
  259. aiecs/tools/search_tool/deduplicator.py +123 -0
  260. aiecs/tools/search_tool/error_handler.py +271 -0
  261. aiecs/tools/search_tool/metrics.py +371 -0
  262. aiecs/tools/search_tool/rate_limiter.py +178 -0
  263. aiecs/tools/search_tool/schemas.py +277 -0
  264. aiecs/tools/statistics/__init__.py +80 -0
  265. aiecs/tools/statistics/ai_data_analysis_orchestrator.py +643 -0
  266. aiecs/tools/statistics/ai_insight_generator_tool.py +505 -0
  267. aiecs/tools/statistics/ai_report_orchestrator_tool.py +694 -0
  268. aiecs/tools/statistics/data_loader_tool.py +564 -0
  269. aiecs/tools/statistics/data_profiler_tool.py +658 -0
  270. aiecs/tools/statistics/data_transformer_tool.py +573 -0
  271. aiecs/tools/statistics/data_visualizer_tool.py +495 -0
  272. aiecs/tools/statistics/model_trainer_tool.py +487 -0
  273. aiecs/tools/statistics/statistical_analyzer_tool.py +459 -0
  274. aiecs/tools/task_tools/__init__.py +86 -0
  275. aiecs/tools/task_tools/chart_tool.py +732 -0
  276. aiecs/tools/task_tools/classfire_tool.py +922 -0
  277. aiecs/tools/task_tools/image_tool.py +447 -0
  278. aiecs/tools/task_tools/office_tool.py +684 -0
  279. aiecs/tools/task_tools/pandas_tool.py +635 -0
  280. aiecs/tools/task_tools/report_tool.py +635 -0
  281. aiecs/tools/task_tools/research_tool.py +392 -0
  282. aiecs/tools/task_tools/scraper_tool.py +715 -0
  283. aiecs/tools/task_tools/stats_tool.py +688 -0
  284. aiecs/tools/temp_file_manager.py +130 -0
  285. aiecs/tools/tool_executor/__init__.py +37 -0
  286. aiecs/tools/tool_executor/tool_executor.py +881 -0
  287. aiecs/utils/LLM_output_structor.py +445 -0
  288. aiecs/utils/__init__.py +34 -0
  289. aiecs/utils/base_callback.py +47 -0
  290. aiecs/utils/cache_provider.py +695 -0
  291. aiecs/utils/execution_utils.py +184 -0
  292. aiecs/utils/logging.py +1 -0
  293. aiecs/utils/prompt_loader.py +14 -0
  294. aiecs/utils/token_usage_repository.py +323 -0
  295. aiecs/ws/__init__.py +0 -0
  296. aiecs/ws/socket_server.py +52 -0
  297. aiecs-1.5.1.dist-info/METADATA +608 -0
  298. aiecs-1.5.1.dist-info/RECORD +302 -0
  299. aiecs-1.5.1.dist-info/WHEEL +5 -0
  300. aiecs-1.5.1.dist-info/entry_points.txt +10 -0
  301. aiecs-1.5.1.dist-info/licenses/LICENSE +225 -0
  302. aiecs-1.5.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,487 @@
1
+ """
2
+ Model Trainer Tool - AutoML and machine learning model training
3
+
4
+ This tool provides AutoML capabilities with:
5
+ - Automatic model selection for classification and regression
6
+ - Hyperparameter tuning
7
+ - Model evaluation and comparison
8
+ - Feature importance analysis
9
+ - Model explanation support
10
+ """
11
+
12
+ import logging
13
+ from typing import Dict, Any, List, Optional, Union
14
+ from enum import Enum
15
+
16
+ import pandas as pd
17
+ import numpy as np
18
+ from sklearn.model_selection import train_test_split, cross_val_score
19
+ from sklearn.metrics import (
20
+ accuracy_score,
21
+ precision_score,
22
+ recall_score,
23
+ f1_score,
24
+ r2_score,
25
+ mean_squared_error,
26
+ )
27
+ from sklearn.ensemble import (
28
+ RandomForestClassifier,
29
+ RandomForestRegressor,
30
+ GradientBoostingClassifier,
31
+ GradientBoostingRegressor,
32
+ )
33
+ from sklearn.linear_model import LogisticRegression, LinearRegression
34
+ from sklearn.preprocessing import LabelEncoder
35
+ from pydantic import BaseModel, Field, ConfigDict
36
+
37
+ from aiecs.tools.base_tool import BaseTool
38
+ from aiecs.tools import register_tool
39
+
40
+
41
class ModelType(str, Enum):
    """Estimators that the model trainer can build.

    Because of the ``str`` mix-in, every member is also a string and compares
    equal to its value (e.g. ``ModelType.AUTO == "auto"``). ``AUTO`` is a
    sentinel meaning "let the trainer choose an estimator".
    """

    # Linear baselines
    LOGISTIC_REGRESSION = "logistic_regression"
    LINEAR_REGRESSION = "linear_regression"

    # Random forest ensembles
    RANDOM_FOREST_CLASSIFIER = "random_forest_classifier"
    RANDOM_FOREST_REGRESSOR = "random_forest_regressor"

    # Gradient boosting ensembles
    GRADIENT_BOOSTING_CLASSIFIER = "gradient_boosting_classifier"
    GRADIENT_BOOSTING_REGRESSOR = "gradient_boosting_regressor"

    # Sentinel: no specific estimator requested
    AUTO = "auto"
51
+
52
+
53
class TaskType(str, Enum):
    """Kinds of learning problem a target column can represent.

    Members are also ``str`` instances and compare equal to their values.
    """

    CLASSIFICATION = "classification"  # discrete labels
    REGRESSION = "regression"  # continuous numeric values
    CLUSTERING = "clustering"  # no ModelType member targets this task
59
+
60
+
61
class ModelTrainerError(Exception):
    """Common base class for every error raised by the model trainer tool."""
63
+
64
+
65
class TrainingError(ModelTrainerError):
    """Signals that training, selecting, tuning or evaluating a model failed."""
67
+
68
+
69
@register_tool("model_trainer")
class ModelTrainerTool(BaseTool):
    """
    AutoML tool that can:
    1. Train multiple model types
    2. Perform hyperparameter tuning
    3. Evaluate and compare models
    4. Generate feature importance
    5. Provide model explanations

    Trained models are kept in memory in ``self.trained_models`` under
    generated ids (``"model_1"``, ``"model_2"``, ...). Each entry also stores
    the feature preprocessing fitted at training time (categorical codes,
    missing-value fill values and column order). ``evaluate_model`` therefore
    encodes new data exactly the way the training data was encoded.

    Note: hyperparameter tuning is currently a placeholder (see
    ``tune_hyperparameters``); models are trained with default parameters.
    """

    # Configuration schema
    class Config(BaseModel):
        """Configuration for the model trainer tool"""

        model_config = ConfigDict(env_prefix="MODEL_TRAINER_")

        # NOTE: cv_folds, enable_hyperparameter_tuning and max_tuning_iterations
        # are not read by this class's methods; train_model uses its own
        # `cross_validation` argument instead of cv_folds.
        test_size: float = Field(default=0.2, description="Proportion of data to use for testing")
        random_state: int = Field(default=42, description="Random state for reproducibility")
        cv_folds: int = Field(default=5, description="Number of cross-validation folds")
        enable_hyperparameter_tuning: bool = Field(
            default=False,
            description="Whether to enable hyperparameter tuning",
        )
        max_tuning_iterations: int = Field(
            default=20,
            description="Maximum number of hyperparameter tuning iterations",
        )

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Initialize ModelTrainerTool with settings.

        Args:
            config: Optional mapping of ``Config`` field overrides.
        """
        super().__init__(config)

        # Parse configuration
        self.config = self.Config(**(config or {}))

        self.logger = logging.getLogger(__name__)
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
            self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO)

        self._init_external_tools()
        # model_id -> dict(model, model_type, task_type, feature_names,
        #                  label_encoder, preprocessor)
        self.trained_models = {}

    def _init_external_tools(self):
        """Initialize external task tools (none are registered currently)."""
        self.external_tools = {}

    # Schema definitions
    class TrainModelSchema(BaseModel):
        """Schema for train_model operation"""

        data: Union[Dict[str, Any], List[Dict[str, Any]]] = Field(description="Training data")
        target: str = Field(description="Target column name")
        model_type: ModelType = Field(default=ModelType.AUTO, description="Model type to train")
        auto_tune: bool = Field(default=False, description="Enable hyperparameter tuning")
        cross_validation: int = Field(default=5, description="Number of CV folds")

    class AutoSelectModelSchema(BaseModel):
        """Schema for auto_select_model operation"""

        data: Union[Dict[str, Any], List[Dict[str, Any]]] = Field(
            description="Data for model selection"
        )
        target: str = Field(description="Target column name")
        task_type: Optional[TaskType] = Field(default=None, description="Task type")

    class EvaluateModelSchema(BaseModel):
        """Schema for evaluate_model operation"""

        model_id: str = Field(description="ID of trained model")
        test_data: Union[Dict[str, Any], List[Dict[str, Any]]] = Field(description="Test data")
        target: str = Field(description="Target column name")

    def train_model(
        self,
        data: Union[Dict[str, Any], List[Dict[str, Any]], pd.DataFrame],
        target: str,
        model_type: ModelType = ModelType.AUTO,
        auto_tune: bool = False,
        cross_validation: int = 5,
    ) -> Dict[str, Any]:
        """
        Train and evaluate model.

        Args:
            data: Training data (DataFrame, list of records, or a dict that is
                either a single record or column-oriented).
            target: Target column name
            model_type: Type of model to train. With ``AUTO`` the task type is
                inferred from the target column and a model is picked for it.
                With an explicit model the task type follows that model
                (classifier -> classification, regressor -> regression).
                Plain strings matching a ``ModelType`` value are accepted.
            auto_tune: Enable hyperparameter tuning (currently not used)
            cross_validation: Number of cross-validation folds

        Returns:
            Dict containing:
            - model_id: Unique identifier for trained model
            - model_type: Type of model trained
            - performance: Performance metrics
            - feature_importance: Feature importance scores
            - cross_validation_scores: CV scores

        Raises:
            TrainingError: if any step fails (wraps the underlying exception).
        """
        try:
            df = self._to_dataframe(data)
            if target not in df.columns:
                raise TrainingError(f"Target column '{target}' not found in data")

            # Separate features and target
            X = df.drop(columns=[target])
            y = df[target]

            # Determine task type and model
            model_type = ModelType(model_type)
            if model_type == ModelType.AUTO:
                task_type = self._determine_task_type(y)
                model_type = self._auto_select_model_type(task_type)
                self.logger.info("Auto-selected model type: %s", model_type.value)
            else:
                # An explicitly requested estimator fixes the task; otherwise a
                # regressor on a low-cardinality int target would be fed label
                # codes and then scored with classification metrics.
                task_type = self._task_type_for_model(model_type)

            # Fit preprocessing once so the same encoding can be reused at evaluation
            preprocessor = self._fit_preprocessor(X)
            X_processed = self._apply_preprocessor(X, preprocessor)
            feature_names = list(preprocessor["columns"])

            # Handle categorical target for classification
            if task_type == TaskType.CLASSIFICATION:
                label_encoder = LabelEncoder()
                y = label_encoder.fit_transform(y)
            else:
                label_encoder = None

            # Split data
            X_train, X_test, y_train, y_test = train_test_split(
                X_processed,
                y,
                test_size=self.config.test_size,
                random_state=self.config.random_state,
            )

            # Create and train model
            model = self._create_model(model_type)
            model.fit(X_train, y_train)

            # Make predictions
            y_pred = model.predict(X_test)

            # Calculate metrics
            performance = self._calculate_metrics(y_test, y_pred, task_type)

            # Cross-validation (cross_val_score clones the estimator internally)
            cv_scores = cross_val_score(model, X_processed, y, cv=cross_validation)

            # Feature importance
            feature_importance = self._get_feature_importance(model, feature_names)

            # Store model
            model_id = f"model_{len(self.trained_models) + 1}"
            self.trained_models[model_id] = {
                "model": model,
                "model_type": model_type.value,
                "task_type": task_type.value,
                "feature_names": feature_names,
                "label_encoder": label_encoder,
                "preprocessor": preprocessor,
            }

            return {
                "model_id": model_id,
                "model_type": model_type.value,
                "task_type": task_type.value,
                "performance": performance,
                "feature_importance": feature_importance,
                "cross_validation_scores": {
                    "scores": cv_scores.tolist(),
                    "mean": float(cv_scores.mean()),
                    "std": float(cv_scores.std()),
                },
                "training_samples": len(X_train),
                "test_samples": len(X_test),
            }

        except Exception as e:
            self.logger.error("Error training model: %s", e)
            raise TrainingError(f"Model training failed: {e}") from e

    def auto_select_model(
        self,
        data: Union[Dict[str, Any], List[Dict[str, Any]], pd.DataFrame],
        target: str,
        task_type: Optional[TaskType] = None,
    ) -> Dict[str, Any]:
        """
        Automatically select best model based on data characteristics.

        Args:
            data: Data for model selection
            target: Target column name
            task_type: Optional task type (auto-determined if None). Plain
                strings matching a ``TaskType`` value are accepted.

        Returns:
            Dict containing recommended model and reasoning

        Raises:
            TrainingError: if selection fails, including for task types with
                no supported model (clustering).
        """
        try:
            df = self._to_dataframe(data)
            if target not in df.columns:
                raise TrainingError(f"Target column '{target}' not found in data")
            y = df[target]

            # Determine task type
            if task_type is None:
                task_type = self._determine_task_type(y)
            else:
                task_type = TaskType(task_type)

            # Select model
            model_type = self._auto_select_model_type(task_type)

            # Provide reasoning
            reasoning = self._explain_model_selection(df, y, task_type, model_type)

            return {
                "recommended_model": model_type.value,
                "task_type": task_type.value,
                "reasoning": reasoning,
                "confidence": "high",
            }

        except Exception as e:
            self.logger.error("Error in auto model selection: %s", e)
            raise TrainingError(f"Model selection failed: {e}") from e

    def evaluate_model(
        self,
        model_id: str,
        test_data: Union[Dict[str, Any], List[Dict[str, Any]], pd.DataFrame],
        target: str,
    ) -> Dict[str, Any]:
        """
        Evaluate trained model on test data.

        The feature preprocessing fitted during training (categorical codes,
        fill values, column order) is reused. Evaluation data is therefore
        encoded consistently with the training data. Extra columns are
        ignored, and missing training columns raise an error.

        Args:
            model_id: ID of trained model
            test_data: Test data
            target: Target column name

        Returns:
            Dict containing evaluation metrics

        Raises:
            TrainingError: if the model id is unknown or evaluation fails.
        """
        try:
            if model_id not in self.trained_models:
                raise TrainingError(f"Model {model_id} not found")

            df = self._to_dataframe(test_data)
            if target not in df.columns:
                raise TrainingError(f"Target column '{target}' not found in data")
            X_test = df.drop(columns=[target])
            y_test = df[target]

            model_info = self.trained_models[model_id]
            model = model_info["model"]
            task_type = TaskType(model_info["task_type"])

            # Preprocess features with the encoding learned at training time
            preprocessor = model_info.get("preprocessor")
            if preprocessor is None:
                # Entry not created by train_model: fall back to fitting afresh
                X_processed, _ = self._preprocess_features(X_test)
            else:
                X_processed = self._apply_preprocessor(X_test, preprocessor)

            # Handle label encoding for classification
            label_encoder = model_info.get("label_encoder")
            if label_encoder is not None:
                y_test = label_encoder.transform(y_test)

            # Make predictions
            y_pred = model.predict(X_processed)

            # Calculate metrics
            performance = self._calculate_metrics(y_test, y_pred, task_type)

            return {
                "model_id": model_id,
                "performance": performance,
                "test_samples": len(X_test),
            }

        except Exception as e:
            self.logger.error("Error evaluating model: %s", e)
            raise TrainingError(f"Model evaluation failed: {e}") from e

    def tune_hyperparameters(
        self,
        data: Union[Dict[str, Any], List[Dict[str, Any]], pd.DataFrame],
        target: str,
        model_type: ModelType,
    ) -> Dict[str, Any]:
        """
        Tune hyperparameters for specified model type.

        Currently a placeholder: trains ``model_type`` with default parameters
        via ``train_model`` and adds a ``tuning_note`` to the result.

        Args:
            data: Training data
            target: Target column name
            model_type: Model type to tune

        Returns:
            Dict containing best parameters and performance
        """
        try:
            # Note: Full hyperparameter tuning with GridSearchCV would be implemented here
            # For now, returning placeholder structure
            self.logger.info("Hyperparameter tuning is a placeholder - train with default params")

            result = self.train_model(data, target, model_type, auto_tune=False)
            result["tuning_note"] = "Using default parameters - full tuning not implemented"

            return result

        except Exception as e:
            self.logger.error("Error tuning hyperparameters: %s", e)
            raise TrainingError(f"Hyperparameter tuning failed: {e}") from e

    # Internal helper methods

    def _to_dataframe(self, data: Union[Dict, List, pd.DataFrame]) -> pd.DataFrame:
        """Convert data to DataFrame.

        A dict whose values are all list-like (list, tuple, ndarray, Series) is
        treated as column-oriented. Any other dict is treated as a single record.
        """
        if isinstance(data, pd.DataFrame):
            return data
        if isinstance(data, list):
            return pd.DataFrame(data)
        if isinstance(data, dict):
            list_like = (list, tuple, np.ndarray, pd.Series)
            if data and all(isinstance(value, list_like) for value in data.values()):
                return pd.DataFrame(data)
            return pd.DataFrame([data])
        raise TrainingError(f"Unsupported data type: {type(data)}")

    def _determine_task_type(self, y: pd.Series) -> TaskType:
        """Determine task type from target variable.

        Boolean, object, categorical and string targets are classification.
        Integer targets (any width, including nullable) with fewer than 10
        distinct values are classification. Everything else is regression.
        """
        dtype = y.dtype
        if (
            pd.api.types.is_bool_dtype(dtype)
            or pd.api.types.is_object_dtype(dtype)
            or isinstance(dtype, pd.CategoricalDtype)
            or pd.api.types.is_string_dtype(dtype)
        ):
            return TaskType.CLASSIFICATION
        if pd.api.types.is_integer_dtype(dtype) and y.nunique() < 10:
            return TaskType.CLASSIFICATION
        return TaskType.REGRESSION

    def _auto_select_model_type(self, task_type: TaskType) -> ModelType:
        """Auto-select model type based on task.

        Raises:
            TrainingError: for task types without a supervised model (clustering).
        """
        task_type = TaskType(task_type)
        if task_type == TaskType.CLASSIFICATION:
            return ModelType.RANDOM_FOREST_CLASSIFIER
        if task_type == TaskType.REGRESSION:
            return ModelType.RANDOM_FOREST_REGRESSOR
        raise TrainingError(f"No supported model for task type: {task_type.value}")

    def _task_type_for_model(self, model_type: ModelType) -> TaskType:
        """Return the task type an explicitly chosen model type solves."""
        classifiers = {
            ModelType.LOGISTIC_REGRESSION,
            ModelType.RANDOM_FOREST_CLASSIFIER,
            ModelType.GRADIENT_BOOSTING_CLASSIFIER,
        }
        regressors = {
            ModelType.LINEAR_REGRESSION,
            ModelType.RANDOM_FOREST_REGRESSOR,
            ModelType.GRADIENT_BOOSTING_REGRESSOR,
        }
        if model_type in classifiers:
            return TaskType.CLASSIFICATION
        if model_type in regressors:
            return TaskType.REGRESSION
        raise TrainingError(f"Unsupported model type: {model_type}")

    def _create_model(self, model_type: ModelType):
        """Create model instance"""
        if model_type == ModelType.LOGISTIC_REGRESSION:
            return LogisticRegression(random_state=self.config.random_state, max_iter=1000)
        elif model_type == ModelType.LINEAR_REGRESSION:
            return LinearRegression()
        elif model_type == ModelType.RANDOM_FOREST_CLASSIFIER:
            return RandomForestClassifier(random_state=self.config.random_state, n_estimators=100)
        elif model_type == ModelType.RANDOM_FOREST_REGRESSOR:
            return RandomForestRegressor(random_state=self.config.random_state, n_estimators=100)
        elif model_type == ModelType.GRADIENT_BOOSTING_CLASSIFIER:
            return GradientBoostingClassifier(random_state=self.config.random_state)
        elif model_type == ModelType.GRADIENT_BOOSTING_REGRESSOR:
            return GradientBoostingRegressor(random_state=self.config.random_state)
        else:
            raise TrainingError(f"Unsupported model type: {model_type}")

    @staticmethod
    def _is_categorical_feature(series: pd.Series) -> bool:
        """True for object, categorical and string columns (they need encoding)."""
        dtype = series.dtype
        return (
            pd.api.types.is_object_dtype(dtype)
            or isinstance(dtype, pd.CategoricalDtype)
            or pd.api.types.is_string_dtype(dtype)
        )

    @staticmethod
    def _encode_categoricals(
        X: pd.DataFrame, categories: Dict[Any, Dict[str, int]]
    ) -> pd.DataFrame:
        """Replace categorical columns by integer codes; unseen values become -1."""
        X_encoded = X.copy()
        for col, mapping in categories.items():
            X_encoded[col] = X_encoded[col].astype(str).map(mapping).fillna(-1).astype(int)
        return X_encoded

    def _fit_preprocessor(self, X: pd.DataFrame) -> Dict[str, Any]:
        """Learn column order, categorical codes and fill values from X.

        Categorical codes match what ``LabelEncoder`` would assign to the
        string-converted column (sorted unique values -> 0..k-1). Fill values
        are the column means of the encoded numeric frame.
        """
        categories: Dict[Any, Dict[str, int]] = {}
        for col in X.columns:
            if self._is_categorical_feature(X[col]):
                values = sorted(X[col].astype(str).unique())
                categories[col] = {value: code for code, value in enumerate(values)}

        encoded = self._encode_categoricals(X, categories)
        return {
            "columns": X.columns.tolist(),
            "categories": categories,
            "fill_values": encoded.mean(numeric_only=True),
        }

    def _apply_preprocessor(self, X: pd.DataFrame, preprocessor: Dict[str, Any]) -> np.ndarray:
        """Transform X with a fitted preprocessor and return the feature matrix.

        Columns are reordered to the training order. Extra columns are dropped.

        Raises:
            TrainingError: if a training feature column is missing from X.
        """
        columns = preprocessor["columns"]
        missing = [col for col in columns if col not in X.columns]
        if missing:
            raise TrainingError(f"Missing feature columns: {missing}")

        X_encoded = self._encode_categoricals(X[columns], preprocessor["categories"])
        X_encoded = X_encoded.fillna(preprocessor["fill_values"])
        return X_encoded.values

    def _preprocess_features(self, X: pd.DataFrame) -> tuple:
        """Fit and apply preprocessing to X in one step.

        Returns:
            Tuple ``(feature_matrix, feature_names)``.
        """
        preprocessor = self._fit_preprocessor(X)
        return self._apply_preprocessor(X, preprocessor), list(preprocessor["columns"])

    def _calculate_metrics(self, y_true, y_pred, task_type: TaskType) -> Dict[str, float]:
        """Calculate performance metrics.

        Classification uses weighted averages over classes. Regression reports
        R², MSE, RMSE and MAE.
        """
        if task_type == TaskType.CLASSIFICATION:
            return {
                "accuracy": float(accuracy_score(y_true, y_pred)),
                "precision": float(
                    precision_score(y_true, y_pred, average="weighted", zero_division=0)
                ),
                "recall": float(recall_score(y_true, y_pred, average="weighted", zero_division=0)),
                "f1_score": float(f1_score(y_true, y_pred, average="weighted", zero_division=0)),
            }
        else:
            mse = mean_squared_error(y_true, y_pred)
            # asarray: compare positionally even if y_true is an indexed Series
            errors = np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)
            return {
                "r2_score": float(r2_score(y_true, y_pred)),
                "mse": float(mse),
                "rmse": float(np.sqrt(mse)),
                "mae": float(np.mean(np.abs(errors))),
            }

    def _get_feature_importance(self, model, feature_names: List[str]) -> Dict[str, float]:
        """Extract feature importance from model.

        Tree ensembles use ``feature_importances_``. Linear models use
        absolute coefficients. For multi-class models (one coefficient row
        per class), the magnitudes are averaged over classes, so each feature
        gets exactly one score.
        """
        if hasattr(model, "feature_importances_"):
            importance = np.asarray(model.feature_importances_)
        elif hasattr(model, "coef_"):
            coef = np.abs(np.asarray(model.coef_))
            importance = coef.mean(axis=0) if coef.ndim > 1 else coef
        else:
            return {}
        return {name: float(imp) for name, imp in zip(feature_names, importance)}

    def _explain_model_selection(
        self,
        df: pd.DataFrame,
        y: pd.Series,
        task_type: TaskType,
        model_type: ModelType,
    ) -> str:
        """Explain why a model was selected"""
        n_samples = len(df)
        n_features = len(df.columns) - 1

        reasons = []
        reasons.append(f"Task type: {task_type.value}")
        reasons.append(f"Dataset size: {n_samples} samples, {n_features} features")

        if model_type in [
            ModelType.RANDOM_FOREST_CLASSIFIER,
            ModelType.RANDOM_FOREST_REGRESSOR,
        ]:
            reasons.append("Random Forest selected for robust performance and feature importance")

        return "; ".join(reasons)