hackagent 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hackagent/__init__.py +12 -0
- hackagent/agent.py +214 -0
- hackagent/api/__init__.py +1 -0
- hackagent/api/agent/__init__.py +1 -0
- hackagent/api/agent/agent_create.py +347 -0
- hackagent/api/agent/agent_destroy.py +140 -0
- hackagent/api/agent/agent_list.py +242 -0
- hackagent/api/agent/agent_partial_update.py +361 -0
- hackagent/api/agent/agent_retrieve.py +235 -0
- hackagent/api/agent/agent_update.py +361 -0
- hackagent/api/apilogs/__init__.py +1 -0
- hackagent/api/apilogs/apilogs_list.py +170 -0
- hackagent/api/apilogs/apilogs_retrieve.py +162 -0
- hackagent/api/attack/__init__.py +1 -0
- hackagent/api/attack/attack_create.py +275 -0
- hackagent/api/attack/attack_destroy.py +146 -0
- hackagent/api/attack/attack_list.py +254 -0
- hackagent/api/attack/attack_partial_update.py +289 -0
- hackagent/api/attack/attack_retrieve.py +247 -0
- hackagent/api/attack/attack_update.py +289 -0
- hackagent/api/checkout/__init__.py +1 -0
- hackagent/api/checkout/checkout_create.py +225 -0
- hackagent/api/generate/__init__.py +1 -0
- hackagent/api/generate/generate_create.py +253 -0
- hackagent/api/judge/__init__.py +1 -0
- hackagent/api/judge/judge_create.py +253 -0
- hackagent/api/key/__init__.py +1 -0
- hackagent/api/key/key_create.py +179 -0
- hackagent/api/key/key_destroy.py +103 -0
- hackagent/api/key/key_list.py +170 -0
- hackagent/api/key/key_retrieve.py +162 -0
- hackagent/api/organization/__init__.py +1 -0
- hackagent/api/organization/organization_create.py +208 -0
- hackagent/api/organization/organization_destroy.py +104 -0
- hackagent/api/organization/organization_list.py +170 -0
- hackagent/api/organization/organization_me_retrieve.py +126 -0
- hackagent/api/organization/organization_partial_update.py +222 -0
- hackagent/api/organization/organization_retrieve.py +163 -0
- hackagent/api/organization/organization_update.py +222 -0
- hackagent/api/prompt/__init__.py +1 -0
- hackagent/api/prompt/prompt_create.py +171 -0
- hackagent/api/prompt/prompt_destroy.py +104 -0
- hackagent/api/prompt/prompt_list.py +185 -0
- hackagent/api/prompt/prompt_partial_update.py +185 -0
- hackagent/api/prompt/prompt_retrieve.py +163 -0
- hackagent/api/prompt/prompt_update.py +185 -0
- hackagent/api/result/__init__.py +1 -0
- hackagent/api/result/result_create.py +175 -0
- hackagent/api/result/result_destroy.py +106 -0
- hackagent/api/result/result_list.py +249 -0
- hackagent/api/result/result_partial_update.py +193 -0
- hackagent/api/result/result_retrieve.py +167 -0
- hackagent/api/result/result_trace_create.py +177 -0
- hackagent/api/result/result_update.py +189 -0
- hackagent/api/run/__init__.py +1 -0
- hackagent/api/run/run_create.py +187 -0
- hackagent/api/run/run_destroy.py +112 -0
- hackagent/api/run/run_list.py +291 -0
- hackagent/api/run/run_partial_update.py +201 -0
- hackagent/api/run/run_result_create.py +177 -0
- hackagent/api/run/run_retrieve.py +179 -0
- hackagent/api/run/run_run_tests_create.py +187 -0
- hackagent/api/run/run_update.py +201 -0
- hackagent/api/user/__init__.py +1 -0
- hackagent/api/user/user_create.py +212 -0
- hackagent/api/user/user_destroy.py +106 -0
- hackagent/api/user/user_list.py +174 -0
- hackagent/api/user/user_me_retrieve.py +126 -0
- hackagent/api/user/user_me_update.py +196 -0
- hackagent/api/user/user_partial_update.py +226 -0
- hackagent/api/user/user_retrieve.py +167 -0
- hackagent/api/user/user_update.py +226 -0
- hackagent/attacks/AdvPrefix/__init__.py +41 -0
- hackagent/attacks/AdvPrefix/completions.py +416 -0
- hackagent/attacks/AdvPrefix/config.py +259 -0
- hackagent/attacks/AdvPrefix/evaluation.py +745 -0
- hackagent/attacks/AdvPrefix/evaluators.py +564 -0
- hackagent/attacks/AdvPrefix/generate.py +711 -0
- hackagent/attacks/AdvPrefix/utils.py +307 -0
- hackagent/attacks/__init__.py +35 -0
- hackagent/attacks/advprefix.py +507 -0
- hackagent/attacks/base.py +106 -0
- hackagent/attacks/strategies.py +906 -0
- hackagent/cli/__init__.py +19 -0
- hackagent/cli/commands/__init__.py +20 -0
- hackagent/cli/commands/agent.py +100 -0
- hackagent/cli/commands/attack.py +417 -0
- hackagent/cli/commands/config.py +301 -0
- hackagent/cli/commands/results.py +327 -0
- hackagent/cli/config.py +249 -0
- hackagent/cli/main.py +515 -0
- hackagent/cli/tui/__init__.py +31 -0
- hackagent/cli/tui/actions_logger.py +200 -0
- hackagent/cli/tui/app.py +288 -0
- hackagent/cli/tui/base.py +137 -0
- hackagent/cli/tui/logger.py +318 -0
- hackagent/cli/tui/views/__init__.py +33 -0
- hackagent/cli/tui/views/agents.py +488 -0
- hackagent/cli/tui/views/attacks.py +624 -0
- hackagent/cli/tui/views/config.py +244 -0
- hackagent/cli/tui/views/dashboard.py +307 -0
- hackagent/cli/tui/views/results.py +1210 -0
- hackagent/cli/tui/widgets/__init__.py +24 -0
- hackagent/cli/tui/widgets/actions.py +346 -0
- hackagent/cli/tui/widgets/logs.py +435 -0
- hackagent/cli/utils.py +276 -0
- hackagent/client.py +286 -0
- hackagent/errors.py +37 -0
- hackagent/logger.py +83 -0
- hackagent/models/__init__.py +109 -0
- hackagent/models/agent.py +223 -0
- hackagent/models/agent_request.py +129 -0
- hackagent/models/api_token_log.py +184 -0
- hackagent/models/attack.py +154 -0
- hackagent/models/attack_request.py +82 -0
- hackagent/models/checkout_session_request_request.py +76 -0
- hackagent/models/checkout_session_response.py +59 -0
- hackagent/models/choice.py +81 -0
- hackagent/models/choice_message.py +67 -0
- hackagent/models/evaluation_status_enum.py +14 -0
- hackagent/models/generate_error_response.py +59 -0
- hackagent/models/generate_request_request.py +212 -0
- hackagent/models/generate_success_response.py +115 -0
- hackagent/models/generic_error_response.py +70 -0
- hackagent/models/message_request.py +67 -0
- hackagent/models/organization.py +102 -0
- hackagent/models/organization_minimal.py +68 -0
- hackagent/models/organization_request.py +71 -0
- hackagent/models/paginated_agent_list.py +123 -0
- hackagent/models/paginated_api_token_log_list.py +123 -0
- hackagent/models/paginated_attack_list.py +123 -0
- hackagent/models/paginated_organization_list.py +123 -0
- hackagent/models/paginated_prompt_list.py +123 -0
- hackagent/models/paginated_result_list.py +123 -0
- hackagent/models/paginated_run_list.py +123 -0
- hackagent/models/paginated_user_api_key_list.py +123 -0
- hackagent/models/paginated_user_profile_list.py +123 -0
- hackagent/models/patched_agent_request.py +128 -0
- hackagent/models/patched_attack_request.py +92 -0
- hackagent/models/patched_organization_request.py +71 -0
- hackagent/models/patched_prompt_request.py +125 -0
- hackagent/models/patched_result_request.py +237 -0
- hackagent/models/patched_run_request.py +138 -0
- hackagent/models/patched_user_profile_request.py +99 -0
- hackagent/models/prompt.py +220 -0
- hackagent/models/prompt_request.py +126 -0
- hackagent/models/result.py +294 -0
- hackagent/models/result_list_evaluation_status.py +14 -0
- hackagent/models/result_request.py +232 -0
- hackagent/models/run.py +233 -0
- hackagent/models/run_list_status.py +12 -0
- hackagent/models/run_request.py +133 -0
- hackagent/models/status_enum.py +12 -0
- hackagent/models/step_type_enum.py +14 -0
- hackagent/models/trace.py +121 -0
- hackagent/models/trace_request.py +94 -0
- hackagent/models/usage.py +75 -0
- hackagent/models/user_api_key.py +201 -0
- hackagent/models/user_api_key_request.py +73 -0
- hackagent/models/user_profile.py +135 -0
- hackagent/models/user_profile_minimal.py +76 -0
- hackagent/models/user_profile_request.py +99 -0
- hackagent/router/__init__.py +25 -0
- hackagent/router/adapters/__init__.py +20 -0
- hackagent/router/adapters/base.py +63 -0
- hackagent/router/adapters/google_adk.py +671 -0
- hackagent/router/adapters/litellm_adapter.py +524 -0
- hackagent/router/adapters/openai_adapter.py +426 -0
- hackagent/router/router.py +969 -0
- hackagent/router/types.py +54 -0
- hackagent/tracking/__init__.py +42 -0
- hackagent/tracking/context.py +163 -0
- hackagent/tracking/decorators.py +299 -0
- hackagent/tracking/tracker.py +441 -0
- hackagent/types.py +54 -0
- hackagent/utils.py +194 -0
- hackagent/vulnerabilities/__init__.py +13 -0
- hackagent/vulnerabilities/prompts.py +81 -0
- hackagent-0.3.1.dist-info/METADATA +122 -0
- hackagent-0.3.1.dist-info/RECORD +183 -0
- hackagent-0.3.1.dist-info/WHEEL +4 -0
- hackagent-0.3.1.dist-info/entry_points.txt +2 -0
- hackagent-0.3.1.dist-info/licenses/LICENSE +202 -0
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# Copyright 2025 - AI4I. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
SDK-specific types for HackAgent router.
|
|
17
|
+
|
|
18
|
+
These types are used internally by the SDK and are not part of the API models.
|
|
19
|
+
The API uses plain strings for agent_type, but the SDK provides enums for type safety.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from enum import Enum
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class AgentTypeEnum(str, Enum):
    """
    Agent types understood by the HackAgent SDK.

    Every member's value is the literal string used in the API's
    ``agent_type`` field. Because the enum mixes in ``str``, members compare
    equal to those plain strings, and ``str(member)`` yields the bare value.

    Expected endpoint per type:
        GOOGLE_ADK -- Google Agent Development Kit endpoint (custom protocol)
        LITELLM    -- any LLM endpoint reachable through LiteLLM (multi-provider)
        OPENAI_SDK -- OpenAI-compatible endpoint (base path should end with /v1)
        LANGCHAIN  -- LangServe endpoint (typically /invoke or /stream)
        MCP        -- Model Context Protocol endpoint (MCP-specific protocol)
        A2A        -- Agent-to-Agent protocol endpoint (A2A-specific protocol)
        UNKNOWN    -- fallback when the agent type is not known

    Note:
        For OpenAI-compatible endpoints (OPENAI_SDK, or LITELLM pointed at a
        custom endpoint) supply the base URL ending in /v1, for example
        http://localhost:8000/v1; the OpenAI client appends /chat/completions
        on its own.
    """

    GOOGLE_ADK = "GOOGLE_ADK"
    LITELLM = "LITELLM"
    OPENAI_SDK = "OPENAI_SDK"
    LANGCHAIN = "LANGCHAIN"
    MCP = "MCP"
    A2A = "A2A"
    UNKNOWN = "UNKNOWN"

    def __str__(self) -> str:
        # Render as the bare value ("LITELLM") instead of Enum's default
        # "AgentTypeEnum.LITELLM"; the value is already a plain str.
        return self.value
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# Copyright 2025 - AI4I. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
Operation tracking and synchronization module.
|
|
17
|
+
|
|
18
|
+
This module provides components for tracking pipeline operations and
|
|
19
|
+
synchronizing state with the HackAgent backend API. It includes:
|
|
20
|
+
|
|
21
|
+
- StepTracker: Main tracking class for managing operation lifecycle
|
|
22
|
+
- StepTracker.track_step: Context manager for tracking individual steps
|
|
23
|
+
- track_operation / track_pipeline: Function and class decorators for automatic operation tracking
|
|
24
|
+
- TrackingContext: Shared context for tracking state
|
|
25
|
+
|
|
26
|
+
The tracking system is designed to be:
|
|
27
|
+
- Modular: Each component has a single responsibility
|
|
28
|
+
- Reusable: Works with any attack or pipeline implementation
|
|
29
|
+
- Optional: Gracefully degrades when tracking is disabled
|
|
30
|
+
- Thread-safe: Safe for concurrent operations
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
from .context import TrackingContext
|
|
34
|
+
from .decorators import track_operation, track_pipeline
|
|
35
|
+
from .tracker import StepTracker
|
|
36
|
+
|
|
37
|
+
# Public API of the tracking package. These names are imported from the
# context, decorators and tracker submodules at the top of this module.
__all__ = [
    "StepTracker",
    "TrackingContext",
    "track_operation",
    "track_pipeline",
]
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
# Copyright 2025 - AI4I. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
Tracking context management.
|
|
17
|
+
|
|
18
|
+
This module provides the TrackingContext class for managing shared state
|
|
19
|
+
across tracking operations. It acts as a lightweight container for tracking
|
|
20
|
+
configuration and state that can be passed between components.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import logging
|
|
24
|
+
from dataclasses import dataclass, field
|
|
25
|
+
from typing import Any, Dict, Optional
|
|
26
|
+
from uuid import UUID
|
|
27
|
+
|
|
28
|
+
from hackagent.client import AuthenticatedClient
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class TrackingContext:
    """
    Shared context for operation tracking.

    Bundles the state needed to track operations and synchronize them with
    the backend API, so it can be passed between components as one object.

    Attributes:
        client: Authenticated client for API communication
        run_id: Server-generated run ID for this execution, expected to be a
            UUID string (a ``uuid.UUID`` instance is also accepted by
            ``get_run_uuid``)
        parent_result_id: ID of the parent result record (same format rules
            as ``run_id``)
        logger: Logger used for tracking warnings; defaults to this module's
            logger when not provided
        sequence_counter: Counter for trace sequence numbers, advanced by
            ``increment_sequence``
        metadata: Free-form key/value metadata for tracking

    Tracking is considered enabled (see ``is_enabled``) only when ``client``,
    ``run_id`` and ``parent_result_id`` are all set.

    Example:
        >>> context = TrackingContext(
        ...     client=authenticated_client,
        ...     run_id="0b6f3c1e-5d2a-4c1b-9e0f-7a8b9c0d1e2f",
        ...     parent_result_id="5c1d2e3f-4a5b-4c6d-8e7f-9a0b1c2d3e4f",
        ... )
        >>> if context.is_enabled:
        ...     tracker = StepTracker(context)
    """

    # Quoted so the client type is not evaluated at class-creation time;
    # dataclasses do not need field annotations resolved.
    client: Optional["AuthenticatedClient"] = None
    run_id: Optional[str] = None
    parent_result_id: Optional[str] = None
    logger: Optional[logging.Logger] = None
    sequence_counter: int = 0
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Fall back to this module's logger when none was supplied."""
        if self.logger is None:
            self.logger = logging.getLogger(__name__)

    @property
    def is_enabled(self) -> bool:
        """
        Check if tracking is enabled.

        Tracking is enabled when all required components are set (not None):
        client, run_id, and parent_result_id.

        Returns:
            True if tracking is enabled, False otherwise
        """
        return (
            self.client is not None
            and self.run_id is not None
            and self.parent_result_id is not None
        )

    def increment_sequence(self) -> int:
        """
        Increment and return the sequence counter.

        Returns:
            The new sequence number
        """
        self.sequence_counter += 1
        return self.sequence_counter

    def _parse_uuid(self, value: Any, field_name: str) -> Optional[UUID]:
        """
        Convert an ID value to a UUID.

        Args:
            value: The raw ID (a string, a ``uuid.UUID`` instance, or empty/None)
            field_name: Field name used in the warning message

        Returns:
            The parsed UUID (or ``value`` itself if it already is one), or
            None when ``value`` is empty or not a valid UUID; a warning is
            logged in the invalid case.
        """
        if not value:
            return None
        if isinstance(value, UUID):
            # UUID(UUID(...)) raises AttributeError, so pass instances through.
            return value
        try:
            return UUID(str(value))
        except ValueError:
            self.logger.warning("Invalid UUID format for %s: %s", field_name, value)
            return None

    def get_run_uuid(self) -> Optional[UUID]:
        """
        Get run_id as UUID.

        Returns:
            UUID instance, or None if run_id is not set or is not a valid UUID
        """
        return self._parse_uuid(self.run_id, "run_id")

    def get_result_uuid(self) -> Optional[UUID]:
        """
        Get parent_result_id as UUID.

        Returns:
            UUID instance, or None if parent_result_id is not set or is not a
            valid UUID
        """
        return self._parse_uuid(self.parent_result_id, "parent_result_id")

    def add_metadata(self, key: str, value: Any) -> None:
        """
        Add metadata to the context.

        Args:
            key: Metadata key
            value: Metadata value (overwrites any existing value for ``key``)
        """
        self.metadata[key] = value

    def get_metadata(self, key: str, default: Any = None) -> Any:
        """
        Get metadata from the context.

        Args:
            key: Metadata key
            default: Default value if key not found

        Returns:
            Metadata value or default
        """
        return self.metadata.get(key, default)

    @classmethod
    def create_disabled(cls) -> "TrackingContext":
        """
        Create a disabled tracking context.

        Returns:
            A TrackingContext with no client, run_id or parent_result_id, so
            ``is_enabled`` is False
        """
        return cls(
            client=None,
            run_id=None,
            parent_result_id=None,
        )
|
|
@@ -0,0 +1,299 @@
|
|
|
1
|
+
# Copyright 2025 - AI4I. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
Decorators for automatic operation tracking.
|
|
17
|
+
|
|
18
|
+
This module provides decorator functions that can be applied to functions
|
|
19
|
+
or methods to automatically track their execution. Decorators offer a
|
|
20
|
+
declarative way to add tracking without modifying function bodies.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import functools
import inspect
from typing import Any, Callable, Dict, Optional, TypeVar
|
|
25
|
+
|
|
26
|
+
from .tracker import StepTracker
|
|
27
|
+
|
|
28
|
+
# Type variable for preserving function signatures
|
|
29
|
+
F = TypeVar("F", bound=Callable[..., Any])
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _lookup_argument(
    signature: Optional[inspect.Signature],
    name: str,
    args: tuple,
    kwargs: dict,
) -> Any:
    """
    Return the value a call supplies for parameter ``name``, or None.

    Keyword arguments are checked first. Otherwise, when ``signature``
    declares a regular parameter called ``name``, the call is bound to the
    signature so that a positionally passed value is found as well. Calls
    whose arguments do not bind yield None; the wrapped function raises its
    own TypeError when it is invoked.

    Args:
        signature: Signature of the wrapped function, or None if unavailable
        name: Parameter name to look up
        args: Positional arguments of the call
        kwargs: Keyword arguments of the call

    Returns:
        The supplied value, or None when the parameter was not supplied
    """
    if name in kwargs:
        return kwargs[name]
    if signature is None:
        return None
    param = signature.parameters.get(name)
    if param is None or param.kind in (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    ):
        return None
    try:
        bound = signature.bind_partial(*args, **kwargs)
    except TypeError:
        return None
    # bind_partial does not apply defaults, so an omitted argument stays absent.
    return bound.arguments.get(name)


def track_operation(
    step_name: str,
    step_type: str,
    extract_input: Optional[Callable[[Any, Any], Dict[str, Any]]] = None,
    extract_config: Optional[Callable[[Any, Any], Dict[str, Any]]] = None,
) -> Callable[[F], F]:
    """
    Decorator for automatic operation tracking.

    Wraps a function so that each call is recorded through a StepTracker.
    The tracker is taken from the call's ``tracker`` argument, passed either
    by keyword or positionally. If it is missing or not a StepTracker, the
    function simply runs untracked.

    Input data and configuration are captured with the optional extractor
    callables, each called as ``extractor(args, kwargs)``. Without
    ``extract_input``, a default heuristic samples common input arguments.
    Without ``extract_config``, the call's ``config`` argument (keyword or
    positional) is recorded. Extractor failures are logged as warnings and
    never prevent the function from running.

    Args:
        step_name: Human-readable name for the operation
        step_type: Step type identifier (e.g., "STEP1_GENERATE")
        extract_input: Optional function to extract input data from args/kwargs
        extract_config: Optional function to extract config from args/kwargs

    Returns:
        Decorated function with automatic tracking

    Example:
        >>> @track_operation("Generate Prefixes", "STEP1_GENERATE")
        ... def generate_prefixes(goals, config, tracker=None):
        ...     # Function logic
        ...     return results

        >>> # With custom extractors
        >>> def get_input(args, kwargs):
        ...     return {"goals": kwargs.get("goals", [])}
        >>>
        >>> @track_operation(
        ...     "Process Data",
        ...     "STEP2_PROCESS",
        ...     extract_input=get_input
        ... )
        ... def process_data(data, config, tracker=None):
        ...     return processed_data
    """

    def decorator(func: F) -> F:
        # Resolve the signature once so each call can find arguments passed
        # positionally; some callables (e.g. C builtins) expose none.
        try:
            signature: Optional[inspect.Signature] = inspect.signature(func)
        except (TypeError, ValueError):
            signature = None

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            tracker = _lookup_argument(signature, "tracker", args, kwargs)

            # Without a usable tracker, behave exactly like the undecorated function.
            if not isinstance(tracker, StepTracker):
                return func(*args, **kwargs)

            input_data = None
            if extract_input is not None:
                try:
                    input_data = extract_input(args, kwargs)
                except Exception as e:
                    # Best effort: a broken extractor must not break the operation.
                    tracker.logger.warning(
                        f"Failed to extract input data for '{step_name}': {e}"
                    )
            else:
                # Default extraction: look for common parameter names
                input_data = _default_extract_input(args, kwargs)

            config = None
            if extract_config is not None:
                try:
                    config = extract_config(args, kwargs)
                except Exception as e:
                    tracker.logger.warning(
                        f"Failed to extract config for '{step_name}': {e}"
                    )
            else:
                # Default extraction: the 'config' argument, keyword or positional
                config = _lookup_argument(signature, "config", args, kwargs)

            with tracker.track_step(
                step_name=step_name,
                step_type=step_type,
                input_data=input_data,
                config=config,
            ):
                return func(*args, **kwargs)

        return wrapper  # type: ignore

    return decorator
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _default_extract_input(args: tuple, kwargs: dict) -> Optional[Dict[str, Any]]:
    """
    Default input extractor for the tracking decorators.

    Probes the call arguments in this order and returns the first match:

    1. Keyword arguments named ``input_df``, ``df``, ``data`` or ``dataframe``
       whose value has a ``head()`` method (DataFrame-like), sampled as
       ``{"input_sample": value.head().to_dict()}``.
    2. Keyword arguments named ``goals``, ``targets`` or ``inputs`` holding a
       list, sampled as ``{key: <new list of at most the first 5 items>}``.
    3. The first positional argument, if DataFrame-like, sampled as in (1).

    A DataFrame-like value whose ``head().to_dict()`` raises is skipped.

    Args:
        args: Positional arguments of the call
        kwargs: Keyword arguments of the call

    Returns:
        Dictionary with an input sample, or None if nothing matched
    """
    for key in ("input_df", "df", "data", "dataframe"):
        if key in kwargs:
            sample = _sample_frame(kwargs[key])
            if sample is not None:
                return sample

    for key in ("goals", "targets", "inputs"):
        value = kwargs.get(key)
        if isinstance(value, list):
            # Slicing always copies, so the recorded sample is a snapshot that
            # does not alias the caller's list.
            return {key: value[:5]}

    if args:
        return _sample_frame(args[0])

    return None


def _sample_frame(value: Any) -> Optional[Dict[str, Any]]:
    """
    Sample a DataFrame-like value as ``{"input_sample": value.head().to_dict()}``.

    Args:
        value: Candidate object; only objects with a ``head`` attribute are sampled

    Returns:
        The sample dictionary, or None if ``value`` has no ``head`` attribute
        or sampling raises
    """
    if not hasattr(value, "head"):
        return None
    try:
        return {"input_sample": value.head().to_dict()}
    except Exception:
        # Best effort: objects that only look DataFrame-like are skipped.
        return None
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _wants_injected_tracker(func: Callable[..., Any], tracker_param: str) -> bool:
    """
    Decide whether track_pipeline should inject ``tracker_param`` into ``func``.

    ``inspect.signature`` follows ``__wrapped__``, so a function decorated
    with a ``functools.wraps``-based decorator (such as track_operation) is
    judged by the signature of the function it wraps.

    Args:
        func: Plain function found on the decorated class
        tracker_param: Name of the tracker keyword argument

    Returns:
        True if ``func`` declares ``tracker_param`` as a keyword-capable
        parameter, or accepts ``**kwargs`` and references ``tracker_param``
        in its own code object; False otherwise.
    """
    try:
        parameters = inspect.signature(func).parameters
    except (TypeError, ValueError):
        return False

    param = parameters.get(tracker_param)
    if param is not None and param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    ):
        return True

    # Without an explicit parameter, injection only works through **kwargs.
    # Requiring the name in the outer code object (the historical check)
    # keeps arbitrary **kwargs methods from being opted in.
    accepts_var_keyword = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters.values()
    )
    code = getattr(func, "__code__", None)
    return (
        accepts_var_keyword
        and code is not None
        and tracker_param in code.co_varnames
    )


def _make_tracker_injector(
    method: Callable[..., Any], tracker_param: str
) -> Callable[..., Any]:
    """
    Wrap ``method`` so ``self.<tracker_param>`` is supplied when the caller omits it.

    The tracker counts as supplied if it was passed by keyword, or, when
    ``method`` declares it as a parameter, positionally. In either case the
    caller's value is left untouched.

    Args:
        method: Plain function defined on (or inherited by) the class
        tracker_param: Name of the tracker argument / instance attribute

    Returns:
        The wrapping function, carrying ``method``'s metadata via functools.wraps
    """
    try:
        signature: Optional[inspect.Signature] = inspect.signature(method)
    except (TypeError, ValueError):
        signature = None
    if signature is not None and tracker_param not in signature.parameters:
        # Only reachable through **kwargs, which the kwargs check already covers.
        signature = None

    @functools.wraps(method)
    def wrapped_method(self, *args, **kwargs):
        if tracker_param not in kwargs and hasattr(self, tracker_param):
            passed_positionally = False
            if signature is not None:
                try:
                    bound = signature.bind_partial(self, *args, **kwargs)
                    passed_positionally = tracker_param in bound.arguments
                except TypeError:
                    # Arguments don't bind; the method call below will raise.
                    pass
            if not passed_positionally:
                kwargs[tracker_param] = getattr(self, tracker_param)
        return method(self, *args, **kwargs)

    return wrapped_method


def track_pipeline(tracker_param: str = "tracker"):
    """
    Class decorator for automatic pipeline tracking.

    Wraps every public instance method that accepts a ``tracker_param``
    argument so that, when the caller does not pass one (by keyword or
    positionally), ``self.<tracker_param>`` is passed in automatically. This
    is useful for pipeline classes where several methods should be tracked.

    Methods are left untouched when they are static or class methods, when
    their name starts with an underscore, or when they cannot accept the
    argument.

    Args:
        tracker_param: Name of the parameter (and instance attribute) that
            holds the tracker

    Returns:
        Class decorator that patches the class in place and returns it

    Example:
        >>> @track_pipeline(tracker_param="tracker")
        ... class MyPipeline:
        ...     def __init__(self, tracker=None):
        ...         self.tracker = tracker
        ...
        ...     @track_operation("Step 1", "STEP1")
        ...     def step1(self, data, tracker=None):
        ...         return processed_data
        ...
        ...     @track_operation("Step 2", "STEP2")
        ...     def step2(self, data, tracker=None):
        ...         return final_data

        >>> # All methods will automatically use self.tracker
        >>> pipeline = MyPipeline(tracker=my_tracker)
        >>> pipeline.step1(data)  # Automatically tracked
    """

    def decorator(cls):
        # Collect first, then patch, so setattr never mutates what we iterate.
        methods_to_wrap = {}
        for attr_name in dir(cls):
            if attr_name.startswith("_"):
                continue
            # getattr_static returns the raw class attribute, so staticmethod,
            # classmethod and property objects are seen as such. They are
            # skipped: rebinding them as instance methods would break them.
            raw_attr = inspect.getattr_static(cls, attr_name, None)
            if inspect.isfunction(raw_attr) and _wants_injected_tracker(
                raw_attr, tracker_param
            ):
                methods_to_wrap[attr_name] = raw_attr

        for method_name, original_method in methods_to_wrap.items():
            setattr(
                cls,
                method_name,
                _make_tracker_injector(original_method, tracker_param),
            )

        return cls

    return decorator
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def track_method(step_name: str, step_type: str, tracker_attr: str = "tracker"):
    """
    Method decorator that tracks calls using the instance's tracker.

    A specialized variant of track_operation for methods. Instead of reading
    a ``tracker`` argument, it reads ``self.<tracker_attr>`` (``self.tracker``
    by default). If that attribute is missing or not a StepTracker, the
    method runs untracked.

    Input data is sampled with the default extractor from the call arguments
    (excluding ``self``). The recorded config is the call's ``config``
    keyword argument; when that keyword is absent or None, ``self.config`` is
    used instead (if present).

    Args:
        step_name: Human-readable name for the operation
        step_type: Step type identifier
        tracker_attr: Name of the instance attribute holding the tracker

    Returns:
        Decorated method with automatic tracking

    Example:
        >>> class Pipeline:
        ...     def __init__(self, tracker):
        ...         self.tracker = tracker
        ...
        ...     @track_method("Generate Data", "STEP1")
        ...     def generate(self, goals):
        ...         return generated_data
        ...
        ...     @track_method("Process Data", "STEP2")
        ...     def process(self, data):
        ...         return processed_data
    """

    def decorator(func: F) -> F:
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            tracker = getattr(self, tracker_attr, None)

            # Without a usable tracker, behave exactly like the undecorated method.
            if not isinstance(tracker, StepTracker):
                return func(self, *args, **kwargs)

            input_data = _default_extract_input(args, kwargs)

            # Fall back to self.config only when no config was given. A plain
            # `or` would also discard explicit falsy configs such as {}.
            config = kwargs.get("config")
            if config is None:
                config = getattr(self, "config", None)

            with tracker.track_step(
                step_name=step_name,
                step_type=step_type,
                input_data=input_data,
                config=config,
            ):
                return func(self, *args, **kwargs)

        return wrapper  # type: ignore

    return decorator
|