aiecs 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiecs/__init__.py +72 -0
- aiecs/__main__.py +41 -0
- aiecs/aiecs_client.py +469 -0
- aiecs/application/__init__.py +10 -0
- aiecs/application/executors/__init__.py +10 -0
- aiecs/application/executors/operation_executor.py +363 -0
- aiecs/application/knowledge_graph/__init__.py +7 -0
- aiecs/application/knowledge_graph/builder/__init__.py +37 -0
- aiecs/application/knowledge_graph/builder/document_builder.py +375 -0
- aiecs/application/knowledge_graph/builder/graph_builder.py +356 -0
- aiecs/application/knowledge_graph/builder/schema_mapping.py +531 -0
- aiecs/application/knowledge_graph/builder/structured_pipeline.py +443 -0
- aiecs/application/knowledge_graph/builder/text_chunker.py +319 -0
- aiecs/application/knowledge_graph/extractors/__init__.py +27 -0
- aiecs/application/knowledge_graph/extractors/base.py +100 -0
- aiecs/application/knowledge_graph/extractors/llm_entity_extractor.py +327 -0
- aiecs/application/knowledge_graph/extractors/llm_relation_extractor.py +349 -0
- aiecs/application/knowledge_graph/extractors/ner_entity_extractor.py +244 -0
- aiecs/application/knowledge_graph/fusion/__init__.py +23 -0
- aiecs/application/knowledge_graph/fusion/entity_deduplicator.py +387 -0
- aiecs/application/knowledge_graph/fusion/entity_linker.py +343 -0
- aiecs/application/knowledge_graph/fusion/knowledge_fusion.py +580 -0
- aiecs/application/knowledge_graph/fusion/relation_deduplicator.py +189 -0
- aiecs/application/knowledge_graph/pattern_matching/__init__.py +21 -0
- aiecs/application/knowledge_graph/pattern_matching/pattern_matcher.py +344 -0
- aiecs/application/knowledge_graph/pattern_matching/query_executor.py +378 -0
- aiecs/application/knowledge_graph/profiling/__init__.py +12 -0
- aiecs/application/knowledge_graph/profiling/query_plan_visualizer.py +199 -0
- aiecs/application/knowledge_graph/profiling/query_profiler.py +223 -0
- aiecs/application/knowledge_graph/reasoning/__init__.py +27 -0
- aiecs/application/knowledge_graph/reasoning/evidence_synthesis.py +347 -0
- aiecs/application/knowledge_graph/reasoning/inference_engine.py +504 -0
- aiecs/application/knowledge_graph/reasoning/logic_form_parser.py +167 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/__init__.py +79 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/ast_builder.py +513 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/ast_nodes.py +630 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/ast_validator.py +654 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/error_handler.py +477 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/parser.py +390 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/query_context.py +217 -0
- aiecs/application/knowledge_graph/reasoning/logic_query_integration.py +169 -0
- aiecs/application/knowledge_graph/reasoning/query_planner.py +872 -0
- aiecs/application/knowledge_graph/reasoning/reasoning_engine.py +554 -0
- aiecs/application/knowledge_graph/retrieval/__init__.py +19 -0
- aiecs/application/knowledge_graph/retrieval/retrieval_strategies.py +596 -0
- aiecs/application/knowledge_graph/search/__init__.py +59 -0
- aiecs/application/knowledge_graph/search/hybrid_search.py +423 -0
- aiecs/application/knowledge_graph/search/reranker.py +295 -0
- aiecs/application/knowledge_graph/search/reranker_strategies.py +553 -0
- aiecs/application/knowledge_graph/search/text_similarity.py +398 -0
- aiecs/application/knowledge_graph/traversal/__init__.py +15 -0
- aiecs/application/knowledge_graph/traversal/enhanced_traversal.py +329 -0
- aiecs/application/knowledge_graph/traversal/path_scorer.py +269 -0
- aiecs/application/knowledge_graph/validators/__init__.py +13 -0
- aiecs/application/knowledge_graph/validators/relation_validator.py +189 -0
- aiecs/application/knowledge_graph/visualization/__init__.py +11 -0
- aiecs/application/knowledge_graph/visualization/graph_visualizer.py +321 -0
- aiecs/common/__init__.py +9 -0
- aiecs/common/knowledge_graph/__init__.py +17 -0
- aiecs/common/knowledge_graph/runnable.py +484 -0
- aiecs/config/__init__.py +16 -0
- aiecs/config/config.py +498 -0
- aiecs/config/graph_config.py +137 -0
- aiecs/config/registry.py +23 -0
- aiecs/core/__init__.py +46 -0
- aiecs/core/interface/__init__.py +34 -0
- aiecs/core/interface/execution_interface.py +152 -0
- aiecs/core/interface/storage_interface.py +171 -0
- aiecs/domain/__init__.py +289 -0
- aiecs/domain/agent/__init__.py +189 -0
- aiecs/domain/agent/base_agent.py +697 -0
- aiecs/domain/agent/exceptions.py +103 -0
- aiecs/domain/agent/graph_aware_mixin.py +559 -0
- aiecs/domain/agent/hybrid_agent.py +490 -0
- aiecs/domain/agent/integration/__init__.py +26 -0
- aiecs/domain/agent/integration/context_compressor.py +222 -0
- aiecs/domain/agent/integration/context_engine_adapter.py +252 -0
- aiecs/domain/agent/integration/retry_policy.py +219 -0
- aiecs/domain/agent/integration/role_config.py +213 -0
- aiecs/domain/agent/knowledge_aware_agent.py +646 -0
- aiecs/domain/agent/lifecycle.py +296 -0
- aiecs/domain/agent/llm_agent.py +300 -0
- aiecs/domain/agent/memory/__init__.py +12 -0
- aiecs/domain/agent/memory/conversation.py +197 -0
- aiecs/domain/agent/migration/__init__.py +14 -0
- aiecs/domain/agent/migration/conversion.py +160 -0
- aiecs/domain/agent/migration/legacy_wrapper.py +90 -0
- aiecs/domain/agent/models.py +317 -0
- aiecs/domain/agent/observability.py +407 -0
- aiecs/domain/agent/persistence.py +289 -0
- aiecs/domain/agent/prompts/__init__.py +29 -0
- aiecs/domain/agent/prompts/builder.py +161 -0
- aiecs/domain/agent/prompts/formatters.py +189 -0
- aiecs/domain/agent/prompts/template.py +255 -0
- aiecs/domain/agent/registry.py +260 -0
- aiecs/domain/agent/tool_agent.py +257 -0
- aiecs/domain/agent/tools/__init__.py +12 -0
- aiecs/domain/agent/tools/schema_generator.py +221 -0
- aiecs/domain/community/__init__.py +155 -0
- aiecs/domain/community/agent_adapter.py +477 -0
- aiecs/domain/community/analytics.py +481 -0
- aiecs/domain/community/collaborative_workflow.py +642 -0
- aiecs/domain/community/communication_hub.py +645 -0
- aiecs/domain/community/community_builder.py +320 -0
- aiecs/domain/community/community_integration.py +800 -0
- aiecs/domain/community/community_manager.py +813 -0
- aiecs/domain/community/decision_engine.py +879 -0
- aiecs/domain/community/exceptions.py +225 -0
- aiecs/domain/community/models/__init__.py +33 -0
- aiecs/domain/community/models/community_models.py +268 -0
- aiecs/domain/community/resource_manager.py +457 -0
- aiecs/domain/community/shared_context_manager.py +603 -0
- aiecs/domain/context/__init__.py +58 -0
- aiecs/domain/context/context_engine.py +989 -0
- aiecs/domain/context/conversation_models.py +354 -0
- aiecs/domain/context/graph_memory.py +467 -0
- aiecs/domain/execution/__init__.py +12 -0
- aiecs/domain/execution/model.py +57 -0
- aiecs/domain/knowledge_graph/__init__.py +19 -0
- aiecs/domain/knowledge_graph/models/__init__.py +52 -0
- aiecs/domain/knowledge_graph/models/entity.py +130 -0
- aiecs/domain/knowledge_graph/models/evidence.py +194 -0
- aiecs/domain/knowledge_graph/models/inference_rule.py +186 -0
- aiecs/domain/knowledge_graph/models/path.py +179 -0
- aiecs/domain/knowledge_graph/models/path_pattern.py +173 -0
- aiecs/domain/knowledge_graph/models/query.py +272 -0
- aiecs/domain/knowledge_graph/models/query_plan.py +187 -0
- aiecs/domain/knowledge_graph/models/relation.py +136 -0
- aiecs/domain/knowledge_graph/schema/__init__.py +23 -0
- aiecs/domain/knowledge_graph/schema/entity_type.py +135 -0
- aiecs/domain/knowledge_graph/schema/graph_schema.py +271 -0
- aiecs/domain/knowledge_graph/schema/property_schema.py +155 -0
- aiecs/domain/knowledge_graph/schema/relation_type.py +171 -0
- aiecs/domain/knowledge_graph/schema/schema_manager.py +496 -0
- aiecs/domain/knowledge_graph/schema/type_enums.py +205 -0
- aiecs/domain/task/__init__.py +13 -0
- aiecs/domain/task/dsl_processor.py +613 -0
- aiecs/domain/task/model.py +62 -0
- aiecs/domain/task/task_context.py +268 -0
- aiecs/infrastructure/__init__.py +24 -0
- aiecs/infrastructure/graph_storage/__init__.py +11 -0
- aiecs/infrastructure/graph_storage/base.py +601 -0
- aiecs/infrastructure/graph_storage/batch_operations.py +449 -0
- aiecs/infrastructure/graph_storage/cache.py +429 -0
- aiecs/infrastructure/graph_storage/distributed.py +226 -0
- aiecs/infrastructure/graph_storage/error_handling.py +390 -0
- aiecs/infrastructure/graph_storage/graceful_degradation.py +306 -0
- aiecs/infrastructure/graph_storage/health_checks.py +378 -0
- aiecs/infrastructure/graph_storage/in_memory.py +514 -0
- aiecs/infrastructure/graph_storage/index_optimization.py +483 -0
- aiecs/infrastructure/graph_storage/lazy_loading.py +410 -0
- aiecs/infrastructure/graph_storage/metrics.py +357 -0
- aiecs/infrastructure/graph_storage/migration.py +413 -0
- aiecs/infrastructure/graph_storage/pagination.py +471 -0
- aiecs/infrastructure/graph_storage/performance_monitoring.py +466 -0
- aiecs/infrastructure/graph_storage/postgres.py +871 -0
- aiecs/infrastructure/graph_storage/query_optimizer.py +635 -0
- aiecs/infrastructure/graph_storage/schema_cache.py +290 -0
- aiecs/infrastructure/graph_storage/sqlite.py +623 -0
- aiecs/infrastructure/graph_storage/streaming.py +495 -0
- aiecs/infrastructure/messaging/__init__.py +13 -0
- aiecs/infrastructure/messaging/celery_task_manager.py +383 -0
- aiecs/infrastructure/messaging/websocket_manager.py +298 -0
- aiecs/infrastructure/monitoring/__init__.py +34 -0
- aiecs/infrastructure/monitoring/executor_metrics.py +174 -0
- aiecs/infrastructure/monitoring/global_metrics_manager.py +213 -0
- aiecs/infrastructure/monitoring/structured_logger.py +48 -0
- aiecs/infrastructure/monitoring/tracing_manager.py +410 -0
- aiecs/infrastructure/persistence/__init__.py +24 -0
- aiecs/infrastructure/persistence/context_engine_client.py +187 -0
- aiecs/infrastructure/persistence/database_manager.py +333 -0
- aiecs/infrastructure/persistence/file_storage.py +754 -0
- aiecs/infrastructure/persistence/redis_client.py +220 -0
- aiecs/llm/__init__.py +86 -0
- aiecs/llm/callbacks/__init__.py +11 -0
- aiecs/llm/callbacks/custom_callbacks.py +264 -0
- aiecs/llm/client_factory.py +420 -0
- aiecs/llm/clients/__init__.py +33 -0
- aiecs/llm/clients/base_client.py +193 -0
- aiecs/llm/clients/googleai_client.py +181 -0
- aiecs/llm/clients/openai_client.py +131 -0
- aiecs/llm/clients/vertex_client.py +437 -0
- aiecs/llm/clients/xai_client.py +184 -0
- aiecs/llm/config/__init__.py +51 -0
- aiecs/llm/config/config_loader.py +275 -0
- aiecs/llm/config/config_validator.py +236 -0
- aiecs/llm/config/model_config.py +151 -0
- aiecs/llm/utils/__init__.py +10 -0
- aiecs/llm/utils/validate_config.py +91 -0
- aiecs/main.py +363 -0
- aiecs/scripts/__init__.py +3 -0
- aiecs/scripts/aid/VERSION_MANAGEMENT.md +97 -0
- aiecs/scripts/aid/__init__.py +19 -0
- aiecs/scripts/aid/version_manager.py +215 -0
- aiecs/scripts/dependance_check/DEPENDENCY_SYSTEM_SUMMARY.md +242 -0
- aiecs/scripts/dependance_check/README_DEPENDENCY_CHECKER.md +310 -0
- aiecs/scripts/dependance_check/__init__.py +17 -0
- aiecs/scripts/dependance_check/dependency_checker.py +938 -0
- aiecs/scripts/dependance_check/dependency_fixer.py +391 -0
- aiecs/scripts/dependance_check/download_nlp_data.py +396 -0
- aiecs/scripts/dependance_check/quick_dependency_check.py +270 -0
- aiecs/scripts/dependance_check/setup_nlp_data.sh +217 -0
- aiecs/scripts/dependance_patch/__init__.py +7 -0
- aiecs/scripts/dependance_patch/fix_weasel/README_WEASEL_PATCH.md +126 -0
- aiecs/scripts/dependance_patch/fix_weasel/__init__.py +11 -0
- aiecs/scripts/dependance_patch/fix_weasel/fix_weasel_validator.py +128 -0
- aiecs/scripts/dependance_patch/fix_weasel/fix_weasel_validator.sh +82 -0
- aiecs/scripts/dependance_patch/fix_weasel/patch_weasel_library.sh +188 -0
- aiecs/scripts/dependance_patch/fix_weasel/run_weasel_patch.sh +41 -0
- aiecs/scripts/tools_develop/README.md +449 -0
- aiecs/scripts/tools_develop/TOOL_AUTO_DISCOVERY.md +234 -0
- aiecs/scripts/tools_develop/__init__.py +21 -0
- aiecs/scripts/tools_develop/check_type_annotations.py +259 -0
- aiecs/scripts/tools_develop/validate_tool_schemas.py +422 -0
- aiecs/scripts/tools_develop/verify_tools.py +356 -0
- aiecs/tasks/__init__.py +1 -0
- aiecs/tasks/worker.py +172 -0
- aiecs/tools/__init__.py +299 -0
- aiecs/tools/apisource/__init__.py +99 -0
- aiecs/tools/apisource/intelligence/__init__.py +19 -0
- aiecs/tools/apisource/intelligence/data_fusion.py +381 -0
- aiecs/tools/apisource/intelligence/query_analyzer.py +413 -0
- aiecs/tools/apisource/intelligence/search_enhancer.py +388 -0
- aiecs/tools/apisource/monitoring/__init__.py +9 -0
- aiecs/tools/apisource/monitoring/metrics.py +303 -0
- aiecs/tools/apisource/providers/__init__.py +115 -0
- aiecs/tools/apisource/providers/base.py +664 -0
- aiecs/tools/apisource/providers/census.py +401 -0
- aiecs/tools/apisource/providers/fred.py +564 -0
- aiecs/tools/apisource/providers/newsapi.py +412 -0
- aiecs/tools/apisource/providers/worldbank.py +357 -0
- aiecs/tools/apisource/reliability/__init__.py +12 -0
- aiecs/tools/apisource/reliability/error_handler.py +375 -0
- aiecs/tools/apisource/reliability/fallback_strategy.py +391 -0
- aiecs/tools/apisource/tool.py +850 -0
- aiecs/tools/apisource/utils/__init__.py +9 -0
- aiecs/tools/apisource/utils/validators.py +338 -0
- aiecs/tools/base_tool.py +201 -0
- aiecs/tools/docs/__init__.py +121 -0
- aiecs/tools/docs/ai_document_orchestrator.py +599 -0
- aiecs/tools/docs/ai_document_writer_orchestrator.py +2403 -0
- aiecs/tools/docs/content_insertion_tool.py +1333 -0
- aiecs/tools/docs/document_creator_tool.py +1317 -0
- aiecs/tools/docs/document_layout_tool.py +1166 -0
- aiecs/tools/docs/document_parser_tool.py +994 -0
- aiecs/tools/docs/document_writer_tool.py +1818 -0
- aiecs/tools/knowledge_graph/__init__.py +17 -0
- aiecs/tools/knowledge_graph/graph_reasoning_tool.py +734 -0
- aiecs/tools/knowledge_graph/graph_search_tool.py +923 -0
- aiecs/tools/knowledge_graph/kg_builder_tool.py +476 -0
- aiecs/tools/langchain_adapter.py +542 -0
- aiecs/tools/schema_generator.py +275 -0
- aiecs/tools/search_tool/__init__.py +100 -0
- aiecs/tools/search_tool/analyzers.py +589 -0
- aiecs/tools/search_tool/cache.py +260 -0
- aiecs/tools/search_tool/constants.py +128 -0
- aiecs/tools/search_tool/context.py +216 -0
- aiecs/tools/search_tool/core.py +749 -0
- aiecs/tools/search_tool/deduplicator.py +123 -0
- aiecs/tools/search_tool/error_handler.py +271 -0
- aiecs/tools/search_tool/metrics.py +371 -0
- aiecs/tools/search_tool/rate_limiter.py +178 -0
- aiecs/tools/search_tool/schemas.py +277 -0
- aiecs/tools/statistics/__init__.py +80 -0
- aiecs/tools/statistics/ai_data_analysis_orchestrator.py +643 -0
- aiecs/tools/statistics/ai_insight_generator_tool.py +505 -0
- aiecs/tools/statistics/ai_report_orchestrator_tool.py +694 -0
- aiecs/tools/statistics/data_loader_tool.py +564 -0
- aiecs/tools/statistics/data_profiler_tool.py +658 -0
- aiecs/tools/statistics/data_transformer_tool.py +573 -0
- aiecs/tools/statistics/data_visualizer_tool.py +495 -0
- aiecs/tools/statistics/model_trainer_tool.py +487 -0
- aiecs/tools/statistics/statistical_analyzer_tool.py +459 -0
- aiecs/tools/task_tools/__init__.py +86 -0
- aiecs/tools/task_tools/chart_tool.py +732 -0
- aiecs/tools/task_tools/classfire_tool.py +922 -0
- aiecs/tools/task_tools/image_tool.py +447 -0
- aiecs/tools/task_tools/office_tool.py +684 -0
- aiecs/tools/task_tools/pandas_tool.py +635 -0
- aiecs/tools/task_tools/report_tool.py +635 -0
- aiecs/tools/task_tools/research_tool.py +392 -0
- aiecs/tools/task_tools/scraper_tool.py +715 -0
- aiecs/tools/task_tools/stats_tool.py +688 -0
- aiecs/tools/temp_file_manager.py +130 -0
- aiecs/tools/tool_executor/__init__.py +37 -0
- aiecs/tools/tool_executor/tool_executor.py +881 -0
- aiecs/utils/LLM_output_structor.py +445 -0
- aiecs/utils/__init__.py +34 -0
- aiecs/utils/base_callback.py +47 -0
- aiecs/utils/cache_provider.py +695 -0
- aiecs/utils/execution_utils.py +184 -0
- aiecs/utils/logging.py +1 -0
- aiecs/utils/prompt_loader.py +14 -0
- aiecs/utils/token_usage_repository.py +323 -0
- aiecs/ws/__init__.py +0 -0
- aiecs/ws/socket_server.py +52 -0
- aiecs-1.5.1.dist-info/METADATA +608 -0
- aiecs-1.5.1.dist-info/RECORD +302 -0
- aiecs-1.5.1.dist-info/WHEEL +5 -0
- aiecs-1.5.1.dist-info/entry_points.txt +10 -0
- aiecs-1.5.1.dist-info/licenses/LICENSE +225 -0
- aiecs-1.5.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,487 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model Trainer Tool - AutoML and machine learning model training
|
|
3
|
+
|
|
4
|
+
This tool provides AutoML capabilities with:
|
|
5
|
+
- Automatic model selection for classification and regression
|
|
6
|
+
- Hyperparameter tuning
|
|
7
|
+
- Model evaluation and comparison
|
|
8
|
+
- Feature importance analysis
|
|
9
|
+
- Model explanation support
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import logging
|
|
13
|
+
from typing import Dict, Any, List, Optional, Union
|
|
14
|
+
from enum import Enum
|
|
15
|
+
|
|
16
|
+
import pandas as pd
|
|
17
|
+
import numpy as np
|
|
18
|
+
from sklearn.model_selection import train_test_split, cross_val_score
|
|
19
|
+
from sklearn.metrics import (
|
|
20
|
+
accuracy_score,
|
|
21
|
+
precision_score,
|
|
22
|
+
recall_score,
|
|
23
|
+
f1_score,
|
|
24
|
+
r2_score,
|
|
25
|
+
mean_squared_error,
|
|
26
|
+
)
|
|
27
|
+
from sklearn.ensemble import (
|
|
28
|
+
RandomForestClassifier,
|
|
29
|
+
RandomForestRegressor,
|
|
30
|
+
GradientBoostingClassifier,
|
|
31
|
+
GradientBoostingRegressor,
|
|
32
|
+
)
|
|
33
|
+
from sklearn.linear_model import LogisticRegression, LinearRegression
|
|
34
|
+
from sklearn.preprocessing import LabelEncoder
|
|
35
|
+
from pydantic import BaseModel, Field, ConfigDict
|
|
36
|
+
|
|
37
|
+
from aiecs.tools.base_tool import BaseTool
|
|
38
|
+
from aiecs.tools import register_tool
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class ModelType(str, Enum):
    """Supported model types.

    The enum mixes in ``str``, so every member compares equal to its string
    value (e.g. ``ModelType.AUTO == "auto"``).
    """

    LOGISTIC_REGRESSION = "logistic_regression"  # classifier
    LINEAR_REGRESSION = "linear_regression"  # regressor
    RANDOM_FOREST_CLASSIFIER = "random_forest_classifier"
    RANDOM_FOREST_REGRESSOR = "random_forest_regressor"
    GRADIENT_BOOSTING_CLASSIFIER = "gradient_boosting_classifier"
    GRADIENT_BOOSTING_REGRESSOR = "gradient_boosting_regressor"
    # Sentinel: let the trainer choose a concrete model from the task type
    # inferred from the target column.
    AUTO = "auto"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class TaskType(str, Enum):
    """Machine learning task types.

    Like ``ModelType``, members compare equal to their string values.
    """

    CLASSIFICATION = "classification"
    REGRESSION = "regression"
    # NOTE(review): no member of ``ModelType`` is a clustering model, so this
    # value has no estimator to train - confirm whether it is still needed.
    CLUSTERING = "clustering"
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class ModelTrainerError(Exception):
    """Base exception for ModelTrainer errors.

    Root of this module's exception hierarchy; catch it to handle any
    trainer-specific exception such as ``TrainingError``.
    """
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class TrainingError(ModelTrainerError):
    """Raised when model training fails.

    The training, model-selection, evaluation and tuning operations wrap any
    error they hit in this exception.
    """
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@register_tool("model_trainer")
|
|
70
|
+
class ModelTrainerTool(BaseTool):
|
|
71
|
+
"""
|
|
72
|
+
AutoML tool that can:
|
|
73
|
+
1. Train multiple model types
|
|
74
|
+
2. Perform hyperparameter tuning
|
|
75
|
+
3. Evaluate and compare models
|
|
76
|
+
4. Generate feature importance
|
|
77
|
+
5. Provide model explanations
|
|
78
|
+
"""
|
|
79
|
+
|
|
80
|
+
# Configuration schema
|
|
81
|
+
class Config(BaseModel):
    """Settings for the model trainer: data split, cross-validation and tuning knobs."""

    # NOTE(review): ``env_prefix`` is a pydantic-settings option; on a plain
    # ``BaseModel`` it is most likely ignored, so no environment variables are
    # read here - confirm whether ``BaseSettings`` was intended.
    model_config = ConfigDict(env_prefix="MODEL_TRAINER_")

    # Passed to ``train_test_split`` as the held-out fraction.
    test_size: float = Field(0.2, description="Proportion of data to use for testing")
    # Seed shared by the data split and all seeded estimators.
    random_state: int = Field(42, description="Random state for reproducibility")
    cv_folds: int = Field(5, description="Number of cross-validation folds")
    # NOTE(review): the two tuning settings below appear unused by the training
    # methods; confirm before relying on them.
    enable_hyperparameter_tuning: bool = Field(
        False,
        description="Whether to enable hyperparameter tuning",
    )
    max_tuning_iterations: int = Field(
        20,
        description="Maximum number of hyperparameter tuning iterations",
    )
|
|
97
|
+
|
|
98
|
+
def __init__(self, config: Optional[Dict[str, Any]] = None):
    """Create the trainer and validate ``config`` against ``Config``.

    Args:
        config: Optional mapping of configuration overrides; ``None`` (or an
            empty mapping) keeps every default.
    """
    super().__init__(config)

    # Validate and store the tool settings.
    overrides = config if config else {}
    self.config = self.Config(**overrides)

    logger = logging.getLogger(__name__)
    if not logger.handlers:
        # Attach a console handler only if the module logger has none yet.
        console = logging.StreamHandler()
        console.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
        logger.addHandler(console)
        logger.setLevel(logging.INFO)
    self.logger = logger

    self._init_external_tools()
    # model_id -> stored model bundle (model, types, feature names, encoder).
    self.trained_models = {}
|
|
114
|
+
|
|
115
|
+
def _init_external_tools(self):
    """Reset the registry of external task tools to an empty mapping."""
    self.external_tools = dict()
|
|
118
|
+
|
|
119
|
+
# Schema definitions
|
|
120
|
+
class TrainModelSchema(BaseModel):
    """Arguments accepted by the ``train_model`` operation."""

    data: Union[Dict[str, Any], List[Dict[str, Any]]] = Field(..., description="Training data")
    target: str = Field(..., description="Target column name")
    model_type: ModelType = Field(ModelType.AUTO, description="Model type to train")
    auto_tune: bool = Field(False, description="Enable hyperparameter tuning")
    cross_validation: int = Field(5, description="Number of CV folds")
|
|
128
|
+
|
|
129
|
+
class AutoSelectModelSchema(BaseModel):
    """Arguments accepted by the ``auto_select_model`` operation."""

    data: Union[Dict[str, Any], List[Dict[str, Any]]] = Field(
        ...,
        description="Data for model selection",
    )
    target: str = Field(..., description="Target column name")
    # ``None`` means: infer the task type from the target column.
    task_type: Optional[TaskType] = Field(None, description="Task type")
|
|
137
|
+
|
|
138
|
+
class EvaluateModelSchema(BaseModel):
    """Arguments accepted by the ``evaluate_model`` operation."""

    model_id: str = Field(..., description="ID of trained model")
    test_data: Union[Dict[str, Any], List[Dict[str, Any]]] = Field(..., description="Test data")
    target: str = Field(..., description="Target column name")
|
|
144
|
+
|
|
145
|
+
def train_model(
    self,
    data: Union[Dict[str, Any], List[Dict[str, Any]], pd.DataFrame],
    target: str,
    model_type: ModelType = ModelType.AUTO,
    auto_tune: bool = False,
    cross_validation: int = 5,
) -> Dict[str, Any]:
    """
    Train and evaluate model.

    The data is split using ``self.config.test_size`` and
    ``self.config.random_state``. The model is then fitted on the training
    split, scored on the held-out split, and cross-validated on the full
    preprocessed data. The fitted model is stored in ``self.trained_models``
    under the returned ``model_id``.

    The task type follows the requested model:
    - an explicit regressor means regression;
    - any other explicit model means classification;
    - only for ``ModelType.AUTO`` is it inferred from the target column.

    Args:
        data: Training data (row dict, list of row dicts, or DataFrame)
        target: Target column name
        model_type: Type of model to train (auto-selected if AUTO). A plain
            string equal to a ``ModelType`` value is accepted as well.
        auto_tune: Enable hyperparameter tuning. NOTE: currently unused;
            models are trained with their default parameters.
        cross_validation: Number of cross-validation folds

    Returns:
        Dict containing:
        - model_id: Unique identifier for trained model
        - model_type: Type of model trained
        - task_type: "classification" or "regression"
        - performance: Metrics on the held-out test split
        - feature_importance: Feature importance scores
        - cross_validation_scores: CV ``scores`` plus their ``mean`` and ``std``
        - training_samples / test_samples: Sizes of the two splits

    Raises:
        TrainingError: If any step fails; the original error is chained.
    """
    regression_models = (
        ModelType.LINEAR_REGRESSION,
        ModelType.RANDOM_FOREST_REGRESSOR,
        ModelType.GRADIENT_BOOSTING_REGRESSOR,
    )
    try:
        df = self._to_dataframe(data)

        # Separate features and target
        X = df.drop(columns=[target])
        y = df[target]

        # Normalise string arguments (e.g. from tool calls) to the enum so the
        # ``.value`` accesses below work.
        model_type = ModelType(model_type)

        # Determine task type and model. An explicitly chosen model fixes the
        # task type; otherwise target encoding and metrics could be computed
        # for a different task than the model actually solves.
        if model_type == ModelType.AUTO:
            task_type = self._determine_task_type(y)
            model_type = self._auto_select_model_type(task_type)
            self.logger.info(f"Auto-selected model type: {model_type.value}")
        elif model_type in regression_models:
            task_type = TaskType.REGRESSION
        else:
            task_type = TaskType.CLASSIFICATION

        # Prepare data
        X_processed, feature_names = self._preprocess_features(X)

        # Classification targets are label-encoded; the encoder is stored with
        # the model so evaluate_model can encode test targets the same way.
        label_encoder = None
        if task_type == TaskType.CLASSIFICATION:
            label_encoder = LabelEncoder()
            y = label_encoder.fit_transform(y)

        # Split data
        X_train, X_test, y_train, y_test = train_test_split(
            X_processed,
            y,
            test_size=self.config.test_size,
            random_state=self.config.random_state,
        )

        # Create and train model
        model = self._create_model(model_type)
        model.fit(X_train, y_train)

        # Score on the held-out split
        y_pred = model.predict(X_test)
        performance = self._calculate_metrics(y_test, y_pred, task_type)

        # k-fold cross-validation over the full preprocessed dataset
        cv_scores = cross_val_score(model, X_processed, y, cv=cross_validation)

        # Feature importance
        feature_importance = self._get_feature_importance(model, feature_names)

        # Store model
        model_id = f"model_{len(self.trained_models) + 1}"
        self.trained_models[model_id] = {
            "model": model,
            "model_type": model_type.value,
            "task_type": task_type.value,
            "feature_names": feature_names,
            "label_encoder": label_encoder,
        }

        return {
            "model_id": model_id,
            "model_type": model_type.value,
            "task_type": task_type.value,
            "performance": performance,
            "feature_importance": feature_importance,
            "cross_validation_scores": {
                "scores": cv_scores.tolist(),
                "mean": float(cv_scores.mean()),
                "std": float(cv_scores.std()),
            },
            "training_samples": len(X_train),
            "test_samples": len(X_test),
        }

    except Exception as e:
        self.logger.error(f"Error training model: {e}")
        raise TrainingError(f"Model training failed: {e}") from e
|
|
247
|
+
|
|
248
|
+
def auto_select_model(
    self,
    data: Union[Dict[str, Any], List[Dict[str, Any]], pd.DataFrame],
    target: str,
    task_type: Optional[TaskType] = None,
) -> Dict[str, Any]:
    """
    Automatically select best model based on data characteristics.

    Args:
        data: Data for model selection
        target: Target column name
        task_type: Optional task type (auto-determined from the target column
            if None). A plain string equal to a ``TaskType`` value is accepted.

    Returns:
        Dict containing:
        - recommended_model: Value of the selected ``ModelType``
        - task_type: Value of the task type used for the selection
        - reasoning: Explanation produced by ``_explain_model_selection``
        - confidence: Always the constant "high"

    Raises:
        TrainingError: If selection fails; the original error is chained.
    """
    try:
        df = self._to_dataframe(data)
        y = df[target]

        # Determine task type; coerce strings so ``.value`` below works.
        if task_type is None:
            task_type = self._determine_task_type(y)
        else:
            task_type = TaskType(task_type)

        # Select model
        model_type = self._auto_select_model_type(task_type)

        # Provide reasoning
        reasoning = self._explain_model_selection(df, y, task_type, model_type)

        return {
            "recommended_model": model_type.value,
            "task_type": task_type.value,
            "reasoning": reasoning,
            "confidence": "high",
        }

    except Exception as e:
        self.logger.error(f"Error in auto model selection: {e}")
        raise TrainingError(f"Model selection failed: {e}") from e
|
|
289
|
+
|
|
290
|
+
def evaluate_model(
    self,
    model_id: str,
    test_data: Union[Dict[str, Any], List[Dict[str, Any]], pd.DataFrame],
    target: str,
) -> Dict[str, Any]:
    """
    Evaluate trained model on test data.

    NOTE(review): ``_preprocess_features`` fits fresh ``LabelEncoder``s on the
    test features and fills missing values with the test-set means. Categorical
    feature codes can therefore differ from those used during training. Storing
    the training-time encoders would make evaluation consistent.

    Args:
        model_id: ID of trained model
        test_data: Test data
        target: Target column name

    Returns:
        Dict containing ``model_id``, ``performance`` (evaluation metrics) and
        ``test_samples``.

    Raises:
        TrainingError: If the model is unknown or evaluation fails; the
            original error is chained.
    """
    try:
        if model_id not in self.trained_models:
            raise TrainingError(f"Model {model_id} not found")

        df = self._to_dataframe(test_data)
        X_test = df.drop(columns=[target])
        y_test = df[target]

        model_info = self.trained_models[model_id]
        model = model_info["model"]
        task_type = TaskType(model_info["task_type"])

        # Preprocess features
        X_processed, _ = self._preprocess_features(X_test)

        # Put the columns in training order. scikit-learn estimators fitted on
        # a DataFrame reject the same feature names given in a different order.
        feature_names = model_info.get("feature_names") or []
        if (
            isinstance(X_processed, pd.DataFrame)
            and len(feature_names) == X_processed.shape[1]
            and set(feature_names) == set(X_processed.columns)
        ):
            X_processed = X_processed[feature_names]

        # Encode classification targets with the encoder fitted at training time.
        label_encoder = model_info["label_encoder"]
        if label_encoder is not None:
            y_test = label_encoder.transform(y_test)

        # Make predictions
        y_pred = model.predict(X_processed)

        # Calculate metrics
        performance = self._calculate_metrics(y_test, y_pred, task_type)

        return {
            "model_id": model_id,
            "performance": performance,
            "test_samples": len(X_test),
        }

    except Exception as e:
        self.logger.error(f"Error evaluating model: {e}")
        raise TrainingError(f"Model evaluation failed: {e}") from e
|
|
341
|
+
|
|
342
|
+
def tune_hyperparameters(
    self,
    data: Union[Dict[str, Any], List[Dict[str, Any]], pd.DataFrame],
    target: str,
    model_type: ModelType,
) -> Dict[str, Any]:
    """
    Tune hyperparameters for specified model type.

    NOTE: tuning is not implemented yet. This trains ``model_type`` with its
    default parameters via ``train_model`` and adds a ``tuning_note`` entry to
    the returned result.

    Args:
        data: Training data
        target: Target column name
        model_type: Model type to tune

    Returns:
        The ``train_model`` result dict plus a ``tuning_note`` string.

    Raises:
        TrainingError: If training fails; the original error is chained.
    """
    try:
        # Full hyperparameter tuning (e.g. GridSearchCV) would be implemented here.
        self.logger.info("Hyperparameter tuning is a placeholder - train with default params")

        result = self.train_model(data, target, model_type, auto_tune=False)
        result["tuning_note"] = "Using default parameters - full tuning not implemented"

        return result

    except Exception as e:
        self.logger.error(f"Error tuning hyperparameters: {e}")
        raise TrainingError(f"Hyperparameter tuning failed: {e}") from e
|
|
372
|
+
|
|
373
|
+
# Internal helper methods
|
|
374
|
+
|
|
375
|
+
def _to_dataframe(self, data: Union[Dict, List, pd.DataFrame]) -> pd.DataFrame:
    """Convert supported input shapes to a DataFrame.

    Accepted inputs:
    - ``pd.DataFrame``: returned unchanged (not copied).
    - ``list``: treated as a sequence of row records.
    - ``dict``: read as column-oriented data (column name -> values) when it
      is non-empty and every value is a 1-D list, tuple, NumPy array or
      Series, all of the same length. Any other dict is treated as a single
      row.

    Raises:
        TrainingError: If ``data`` is of any other type.
    """
    if isinstance(data, pd.DataFrame):
        return data
    if isinstance(data, list):
        return pd.DataFrame(data)
    if isinstance(data, dict):
        values = list(data.values())
        is_columnar = (
            bool(values)
            and all(
                isinstance(v, (list, tuple, np.ndarray, pd.Series))
                and getattr(v, "ndim", 1) == 1
                for v in values
            )
            and len({len(v) for v in values}) == 1
        )
        if is_columnar:
            # Column-oriented input: one row per element, as pandas does for
            # a dict of equal-length sequences.
            return pd.DataFrame(data)
        # A mapping of scalars (or mixed values) describes a single record.
        return pd.DataFrame([data])
    raise TrainingError(f"Unsupported data type: {type(data)}")
|
|
385
|
+
|
|
386
|
+
def _determine_task_type(self, y: pd.Series) -> TaskType:
    """Infer the task type from the target column.

    - Boolean, object, string and categorical targets -> classification.
    - Integer targets with fewer than 10 distinct values -> classification.
      Any integer width counts, including unsigned and pandas nullable ``Int``
      dtypes.
    - Everything else (floats, integers with many distinct values, ...) ->
      regression.
    """
    dtype = y.dtype
    api = pd.api.types
    if (
        api.is_bool_dtype(dtype)
        or api.is_object_dtype(dtype)
        or api.is_string_dtype(dtype)
        or isinstance(dtype, pd.CategoricalDtype)
    ):
        return TaskType.CLASSIFICATION
    if api.is_integer_dtype(dtype) and y.nunique() < 10:
        return TaskType.CLASSIFICATION
    return TaskType.REGRESSION
|
|
394
|
+
|
|
395
|
+
def _auto_select_model_type(self, task_type: TaskType) -> ModelType:
    """Choose the default model for a task type.

    A random-forest classifier is returned for classification tasks; any
    other task type gets a random-forest regressor.
    """
    wants_classifier = task_type == TaskType.CLASSIFICATION
    return (
        ModelType.RANDOM_FOREST_CLASSIFIER
        if wants_classifier
        else ModelType.RANDOM_FOREST_REGRESSOR
    )
|
|
401
|
+
|
|
402
|
+
def _create_model(self, model_type: ModelType):
    """Build a new, unfitted estimator for ``model_type``.

    Estimators that take a ``random_state`` are seeded from
    ``self.config.random_state``. The seed is read only when such an
    estimator is actually constructed.

    Args:
        model_type: Which estimator to construct.

    Returns:
        A freshly constructed estimator instance.

    Raises:
        TrainingError: If ``model_type`` is not a supported type.
    """

    def seeded(estimator_cls, **extra):
        # Defer reading the seed until the factory is invoked.
        return lambda: estimator_cls(random_state=self.config.random_state, **extra)

    factories = (
        (ModelType.LOGISTIC_REGRESSION, seeded(LogisticRegression, max_iter=1000)),
        (ModelType.LINEAR_REGRESSION, lambda: LinearRegression()),
        (ModelType.RANDOM_FOREST_CLASSIFIER, seeded(RandomForestClassifier, n_estimators=100)),
        (ModelType.RANDOM_FOREST_REGRESSOR, seeded(RandomForestRegressor, n_estimators=100)),
        (ModelType.GRADIENT_BOOSTING_CLASSIFIER, seeded(GradientBoostingClassifier)),
        (ModelType.GRADIENT_BOOSTING_REGRESSOR, seeded(GradientBoostingRegressor)),
    )
    for candidate, build in factories:
        if model_type == candidate:
            return build()
    raise TrainingError(f"Unsupported model type: {model_type}")
|
|
418
|
+
|
|
419
|
+
def _preprocess_features(self, X: pd.DataFrame) -> tuple:
    """Turn a feature frame into a numeric matrix ready for model fitting.

    Steps:
      1. Label-encode every object, categorical and pandas string
         (StringDtype) column. Values are cast to ``str`` first, so missing
         entries are encoded as their string form (e.g. ``"nan"``).
      2. Fill remaining missing values in numeric columns with the column mean.

    A new ``LabelEncoder`` is fit per column on each call, so the integer
    codes are only consistent within a single call.

    Args:
        X: Feature DataFrame. It is not modified; a copy is processed.

    Returns:
        Tuple ``(values, feature_names)``. ``values`` is the processed frame's
        ``.values`` and ``feature_names`` lists the columns in matrix order.
    """
    X_processed = X.copy()

    # StringDtype columns (the default string dtype in pandas 3) are not matched
    # by select_dtypes(include=["object", "category"]), so test dtypes explicitly.
    categorical_cols = [
        col
        for col in X_processed.columns
        if pd.api.types.is_object_dtype(X_processed[col].dtype)
        or isinstance(X_processed[col].dtype, (pd.CategoricalDtype, pd.StringDtype))
    ]
    for col in categorical_cols:
        encoder = LabelEncoder()
        X_processed[col] = encoder.fit_transform(X_processed[col].astype(str))

    # Handle missing values in numeric columns
    X_processed = X_processed.fillna(X_processed.mean(numeric_only=True))

    feature_names = X_processed.columns.tolist()

    return X_processed.values, feature_names
|
|
434
|
+
|
|
435
|
+
def _calculate_metrics(self, y_true, y_pred, task_type: TaskType) -> Dict[str, float]:
    """Compute evaluation metrics for predictions.

    Classification returns ``accuracy`` plus support-weighted ``precision``,
    ``recall`` and ``f1_score``. Undefined values (zero division) are reported
    as 0.

    Any other task type returns ``r2_score``, ``mse``, ``rmse`` and ``mae``.

    Args:
        y_true: Ground-truth targets (list, ndarray or Series).
        y_pred: Predicted targets, same length as ``y_true``.
        task_type: Determines which metric family is computed.

    Returns:
        Mapping of metric name to a plain Python float.
    """
    if task_type == TaskType.CLASSIFICATION:
        return {
            "accuracy": float(accuracy_score(y_true, y_pred)),
            "precision": float(
                precision_score(y_true, y_pred, average="weighted", zero_division=0)
            ),
            "recall": float(recall_score(y_true, y_pred, average="weighted", zero_division=0)),
            "f1_score": float(f1_score(y_true, y_pred, average="weighted", zero_division=0)),
        }

    mse = mean_squared_error(y_true, y_pred)
    # Compare positionally via arrays: raw subtraction fails for lists and
    # aligns two Series by index (producing NaN when indexes differ).
    abs_errors = np.abs(np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float))
    return {
        "r2_score": float(r2_score(y_true, y_pred)),
        "mse": float(mse),
        "rmse": float(np.sqrt(mse)),
        "mae": float(np.mean(abs_errors)),
    }
|
|
454
|
+
|
|
455
|
+
def _get_feature_importance(self, model, feature_names: List[str]) -> Dict[str, float]:
|
|
456
|
+
"""Extract feature importance from model"""
|
|
457
|
+
if hasattr(model, "feature_importances_"):
|
|
458
|
+
importance = model.feature_importances_
|
|
459
|
+
return {name: float(imp) for name, imp in zip(feature_names, importance)}
|
|
460
|
+
elif hasattr(model, "coef_"):
|
|
461
|
+
importance = np.abs(model.coef_).flatten()
|
|
462
|
+
return {name: float(imp) for name, imp in zip(feature_names, importance)}
|
|
463
|
+
else:
|
|
464
|
+
return {}
|
|
465
|
+
|
|
466
|
+
def _explain_model_selection(
    self,
    df: pd.DataFrame,
    y: pd.Series,
    task_type: TaskType,
    model_type: ModelType,
) -> str:
    """Produce a short, human-readable rationale for the chosen model.

    The rationale names the task type and the dataset dimensions. For
    random-forest models it also gives a one-line justification.

    The feature count is ``len(df.columns) - 1``. NOTE(review): this
    presumably assumes ``df`` still contains the target column; confirm
    against the caller. ``y`` is accepted but not used.

    Returns:
        The rationale fragments joined with ``"; "``.
    """
    sample_count = len(df)
    feature_count = len(df.columns) - 1

    parts = [
        f"Task type: {task_type.value}",
        f"Dataset size: {sample_count} samples, {feature_count} features",
    ]

    forest_types = (
        ModelType.RANDOM_FOREST_CLASSIFIER,
        ModelType.RANDOM_FOREST_REGRESSOR,
    )
    if model_type in forest_types:
        parts.append("Random Forest selected for robust performance and feature importance")

    return "; ".join(parts)
|