daita-agents 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- daita/__init__.py +216 -0
- daita/agents/__init__.py +33 -0
- daita/agents/base.py +743 -0
- daita/agents/substrate.py +1141 -0
- daita/cli/__init__.py +145 -0
- daita/cli/__main__.py +7 -0
- daita/cli/ascii_art.py +44 -0
- daita/cli/core/__init__.py +0 -0
- daita/cli/core/create.py +254 -0
- daita/cli/core/deploy.py +473 -0
- daita/cli/core/deployments.py +309 -0
- daita/cli/core/import_detector.py +219 -0
- daita/cli/core/init.py +481 -0
- daita/cli/core/logs.py +239 -0
- daita/cli/core/managed_deploy.py +709 -0
- daita/cli/core/run.py +648 -0
- daita/cli/core/status.py +421 -0
- daita/cli/core/test.py +239 -0
- daita/cli/core/webhooks.py +172 -0
- daita/cli/main.py +588 -0
- daita/cli/utils.py +541 -0
- daita/config/__init__.py +62 -0
- daita/config/base.py +159 -0
- daita/config/settings.py +184 -0
- daita/core/__init__.py +262 -0
- daita/core/decision_tracing.py +701 -0
- daita/core/exceptions.py +480 -0
- daita/core/focus.py +251 -0
- daita/core/interfaces.py +76 -0
- daita/core/plugin_tracing.py +550 -0
- daita/core/relay.py +779 -0
- daita/core/reliability.py +381 -0
- daita/core/scaling.py +459 -0
- daita/core/tools.py +554 -0
- daita/core/tracing.py +770 -0
- daita/core/workflow.py +1144 -0
- daita/display/__init__.py +1 -0
- daita/display/console.py +160 -0
- daita/execution/__init__.py +58 -0
- daita/execution/client.py +856 -0
- daita/execution/exceptions.py +92 -0
- daita/execution/models.py +317 -0
- daita/llm/__init__.py +60 -0
- daita/llm/anthropic.py +291 -0
- daita/llm/base.py +530 -0
- daita/llm/factory.py +101 -0
- daita/llm/gemini.py +355 -0
- daita/llm/grok.py +219 -0
- daita/llm/mock.py +172 -0
- daita/llm/openai.py +220 -0
- daita/plugins/__init__.py +141 -0
- daita/plugins/base.py +37 -0
- daita/plugins/base_db.py +167 -0
- daita/plugins/elasticsearch.py +849 -0
- daita/plugins/mcp.py +481 -0
- daita/plugins/mongodb.py +520 -0
- daita/plugins/mysql.py +362 -0
- daita/plugins/postgresql.py +342 -0
- daita/plugins/redis_messaging.py +500 -0
- daita/plugins/rest.py +537 -0
- daita/plugins/s3.py +770 -0
- daita/plugins/slack.py +729 -0
- daita/utils/__init__.py +18 -0
- daita_agents-0.2.0.dist-info/METADATA +409 -0
- daita_agents-0.2.0.dist-info/RECORD +69 -0
- daita_agents-0.2.0.dist-info/WHEEL +5 -0
- daita_agents-0.2.0.dist-info/entry_points.txt +2 -0
- daita_agents-0.2.0.dist-info/licenses/LICENSE +56 -0
- daita_agents-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,309 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Deployment management commands for Daita CLI.
|
|
3
|
+
|
|
4
|
+
Provides commands to list, download, and manage deployments.
|
|
5
|
+
"""
|
|
6
|
+
import os
|
|
7
|
+
import json
|
|
8
|
+
import asyncio
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from datetime import datetime
|
|
11
|
+
from typing import Optional
|
|
12
|
+
from ..utils import find_project_root
|
|
13
|
+
|
|
14
|
+
async def list_deployments(project_name: Optional[str] = None, environment: Optional[str] = None, limit: int = 10):
    """List deployment history from the managed cloud API.

    Reads the API key from ``DAITA_API_KEY`` and the endpoint from
    ``DAITA_API_ENDPOINT`` (falling back to the hosted default). When
    *project_name* is omitted, the ``name`` field of the project's
    ``daita-project.yaml`` is used if a project root can be found.
    Results and errors are printed to stdout; nothing is returned.

    Args:
        project_name: Project to filter on.
        environment: Optional environment filter.
        limit: Maximum number of deployments to request and display. A falsy
            value sends no ``limit`` parameter and displays every deployment
            the API returns.
    """
    api_key = os.getenv('DAITA_API_KEY')
    if not api_key:
        print(" No DAITA_API_KEY found")
        print(" Get your API key at daita-tech.io")
        print(" Then: export DAITA_API_KEY='your-key-here'")
        return

    # Imported lazily: aiohttp is only needed once we actually call the API.
    import aiohttp

    try:
        # If no project specified, try to get it from the current directory.
        if not project_name:
            project_root = find_project_root()
            if project_root:
                config = _load_project_config(project_root)
                # The YAML file may parse to a non-mapping; only dicts have .get().
                if isinstance(config, dict) and config:
                    project_name = config.get('name', 'unknown')

        api_endpoint = os.getenv('DAITA_API_ENDPOINT', 'https://ondk4sdyv0.execute-api.us-east-1.amazonaws.com')
        headers = {
            "Authorization": f"Bearer {api_key}",
            "User-Agent": "Daita-CLI/1.0.0",
        }

        params = {}
        if project_name:
            params['project_name'] = project_name
        if environment:
            params['environment'] = environment
        if limit:
            params['limit'] = limit

        print(f" Deployment History{' for ' + repr(project_name) if project_name else ''}")
        if environment:
            print(f" Environment: {environment}")
        print()

        url = f"{api_endpoint}/api/v1/deployments/api-key"
        # ClientTimeout is the documented type for aiohttp's timeout parameter.
        timeout = aiohttp.ClientTimeout(total=10)
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers, params=params, timeout=timeout) as response:
                if response.status == 200:
                    _print_deployment_list(await response.json(), limit)
                elif response.status == 401:
                    print(" Authentication failed")
                    print(" Check your DAITA_API_KEY")
                else:
                    error_text = await response.text()
                    print(f" Failed to fetch deployments (HTTP {response.status})")
                    print(f" {error_text}")

    except aiohttp.ClientConnectorError:
        print(" Cannot connect to deployment API")
        print(" Check your internet connection")
    except asyncio.TimeoutError:
        # An expired aiohttp timeout surfaces as asyncio.TimeoutError, whose
        # message is empty, so report it explicitly.
        print(" Deployment API request timed out")
        print(" Check your internet connection")
    except Exception as e:
        print(f" Failed to list deployments: {e}")


def _format_deployment_timestamp(raw) -> str:
    """Render an ISO-8601 timestamp as ``YYYY-MM-DD HH:MM:SS UTC``.

    A trailing ``Z`` is accepted. Timestamps carrying an offset are converted
    to UTC before formatting so the ``UTC`` label is accurate; naive
    timestamps are assumed to already be UTC. Unparseable values are returned
    as their string form, and ``None`` becomes ``''``.
    """
    from datetime import timezone

    if raw is None:
        return ''
    text = str(raw)
    try:
        parsed = datetime.fromisoformat(text.replace('Z', '+00:00'))
    except ValueError:
        return text
    if parsed.tzinfo is not None:
        parsed = parsed.astimezone(timezone.utc)
    return parsed.strftime('%Y-%m-%d %H:%M:%S UTC')


def _print_deployment_list(data, limit) -> None:
    """Print deployments from an API body (a list, or a dict holding one under 'deployments')."""
    if isinstance(data, dict):
        deployments = data.get('deployments') or []
    elif isinstance(data, list):
        deployments = data
    else:
        deployments = []

    if not deployments:
        print(" No deployments found.")
        return

    # A falsy limit means "no limit", matching how the query parameter is built.
    shown = deployments[:limit] if limit else deployments
    for index, deployment in enumerate(shown, 1):
        _print_deployment_entry(index, deployment)

    total_count = len(deployments)
    if limit and total_count > limit:
        print(f" Showing {limit} of {total_count} deployments")
        print(f" Use --limit {total_count} to see all deployments")


def _print_deployment_entry(index: int, deployment: dict) -> None:
    """Print one numbered deployment record (ID, environment, version, time, size, agents)."""
    status_marker = "●" if deployment.get('status') == 'active' else "○"
    # The ID may be present but null; str() keeps the slice below safe.
    deployment_id = str(deployment.get('deployment_id') or 'unknown')

    print(f"{index:2}. {status_marker} {deployment_id[:36]}")
    print(f" Environment: {deployment.get('environment', 'unknown')}")
    print(f" Version: {deployment.get('version', '1.0.0')}")
    print(f" Deployed: {_format_deployment_timestamp(deployment.get('deployed_at', ''))}")

    # Only format real numbers; a null size must not abort the whole listing.
    size_bytes = deployment.get('package_size_bytes')
    if isinstance(size_bytes, (int, float)):
        print(f" Package: {size_bytes / 1024 / 1024:.1f}MB")

    # Agents are listed from deployment_info['functions'][*]['agent_name'].
    deployment_info = deployment.get('deployment_info') or {}
    functions = deployment_info.get('functions') if isinstance(deployment_info, dict) else None
    if functions:
        agent_names = [str(f.get('agent_name', 'Unknown')) for f in functions if isinstance(f, dict)]
        if agent_names:
            print(f" Agents: {', '.join(agent_names)}")

    print()
|
|
127
|
+
|
|
128
|
+
async def download_deployment(deployment_id: str, output_path: Optional[str] = None):
    """Download a deployment package to a local file.

    Args:
        deployment_id: Identifier of the deployment to fetch.
        output_path: Destination path; defaults to ``<deployment_id>.zip`` in
            the current directory.

    Progress and errors are printed to stdout rather than raised.
    """
    # The deployer lives in daita.cloud, which may be absent from the install;
    # report that clearly instead of as a generic download failure.
    try:
        from ...cloud.lambda_deploy import LambdaDeployer
    except ImportError as e:
        print(f" Deployment download is not available: {e}")
        print(" Could not import the cloud deployment backend (daita.cloud.lambda_deploy)")
        return

    try:
        aws_region = os.getenv('AWS_REGION', 'us-east-1')
        deployer = LambdaDeployer(aws_region)

        output_file = Path(output_path or f"{deployment_id}.zip")

        print(f" Downloading deployment: {deployment_id}")
        print(f" Output: {output_file.absolute()}")

        success = await deployer.download_deployment(deployment_id, output_file)

        # Guard stat(): a truthy result without a file on disk is still a failure.
        if success and output_file.exists():
            file_size = output_file.stat().st_size
            print(f" Downloaded successfully ({file_size / 1024 / 1024:.1f}MB)")
        else:
            print(" Download failed")

    except Exception as e:
        print(f" Failed to download deployment: {e}")
|
|
157
|
+
|
|
158
|
+
async def show_deployment_details(deployment_id: str):
    """Print detailed information about one deployment.

    The deployment is looked up in the history returned by
    ``LambdaDeployer.get_deployment_history('all', limit=100)``, so only the
    entries returned for that limit are searched. Missing or malformed fields
    are shown with defaults instead of aborting the report; errors are printed.
    """
    try:
        from ...cloud.lambda_deploy import LambdaDeployer
    except ImportError as e:
        print(f" Deployment details are not available: {e}")
        print(" Could not import the cloud deployment backend (daita.cloud.lambda_deploy)")
        return

    from datetime import timezone

    def print_components(title, items, default_type):
        """Print agent/workflow entries (name, type, enabled, file) under *title*."""
        entries = [item for item in (items or []) if isinstance(item, dict)]
        if not entries:
            return
        print(f"{title}:")
        for item in entries:
            print(f" {item.get('name', 'Unknown')}")
            print(f" Type: {item.get('type', default_type)}")
            print(f" Enabled: {item.get('enabled', True)}")
            if item.get('file'):
                print(f" File: {item['file']}")
            print()

    try:
        aws_region = os.getenv('AWS_REGION', 'us-east-1')
        deployer = LambdaDeployer(aws_region)

        print(f" Deployment Details: {deployment_id}")
        print()

        all_deployments = await deployer.get_deployment_history('all', limit=100) or []
        target_deployment = next(
            (
                d for d in all_deployments
                if isinstance(d, dict) and d.get('deployment_id') == deployment_id
            ),
            None,
        )

        if target_deployment is None:
            print(f" Deployment {deployment_id} not found")
            return

        # Convert to UTC before labelling the time as UTC; fall back to the raw value.
        deployed_at = target_deployment.get('deployed_at')
        try:
            deployed_date = datetime.fromisoformat(str(deployed_at).replace('Z', '+00:00'))
        except ValueError:
            deployed_str = str(deployed_at) if deployed_at is not None else 'unknown'
        else:
            if deployed_date.tzinfo is not None:
                deployed_date = deployed_date.astimezone(timezone.utc)
            deployed_str = deployed_date.strftime('%Y-%m-%d %H:%M:%S UTC')

        size_bytes = target_deployment.get('package_size_bytes')
        if not isinstance(size_bytes, (int, float)):
            size_bytes = 0

        print(f"Project: {target_deployment.get('project_name', 'unknown')}")
        print(f"Environment: {target_deployment.get('environment', 'unknown')}")
        print(f"Version: {target_deployment.get('version', '1.0.0')}")
        print(f"Deployed At: {deployed_str}")
        print(f"Package Size: {size_bytes / 1024 / 1024:.1f}MB")
        print(f"Package Hash: {target_deployment.get('package_hash', 'N/A')}")

        if target_deployment.get('s3_bucket'):
            print(f"S3 Location: s3://{target_deployment['s3_bucket']}/{target_deployment.get('s3_key', '')}")

        print()

        print_components("Agents", target_deployment.get('agents'), 'substrate')
        print_components("Workflows", target_deployment.get('workflows'), 'basic')

    except Exception as e:
        print(f" Failed to get deployment details: {e}")
|
|
224
|
+
|
|
225
|
+
async def rollback_deployment(deployment_id: str, environment: str = 'production'):
    """Roll *environment* back to a previous deployment.

    Production rollbacks require typing ``yes`` at an interactive prompt;
    end-of-input on stdin counts as declining. Results and errors are printed.
    """
    try:
        from ...cloud.lambda_deploy import LambdaDeployer
    except ImportError as e:
        print(f" Deployment rollback is not available: {e}")
        print(" Could not import the cloud deployment backend (daita.cloud.lambda_deploy)")
        return

    try:
        aws_region = os.getenv('AWS_REGION', 'us-east-1')
        deployer = LambdaDeployer(aws_region)

        print(f" Rolling back to deployment: {deployment_id}")
        print(f" Target environment: {environment}")

        if environment == 'production':
            try:
                confirm = input(" Rollback PRODUCTION environment? Type 'yes' to confirm: ")
            except EOFError:
                # Non-interactive stdin: never roll back production implicitly.
                confirm = ''
            if confirm.strip().lower() != 'yes':
                print(" Rollback cancelled")
                return

        # Treat a missing result like a failure rather than crashing on subscripting.
        result = await deployer.rollback_deployment(deployment_id, environment) or {}

        if result.get('status') == 'success':
            print(" Rollback initiated")
            print(f" New deployment ID: {result.get('rollback_deployment_id', 'unknown')}")
            print(f" Original deployment: {result.get('original_deployment_id', deployment_id)}")
            if result.get('message'):
                print(f" {result['message']}")
        else:
            print(f" Rollback failed: {result.get('error', 'unknown error')}")

    except Exception as e:
        print(f" Failed to rollback deployment: {e}")
|
|
257
|
+
|
|
258
|
+
async def delete_deployment(deployment_id: str, force: bool = False):
    """Delete a deployment and its Lambda functions.

    Unless *force* is true, the user must type ``yes`` at a prompt;
    end-of-input on stdin counts as declining. Deleted function names and any
    errors reported by the deployer are printed.
    """
    try:
        from ...cloud.lambda_deploy import LambdaDeployer
    except ImportError as e:
        print(f" Deployment deletion is not available: {e}")
        print(" Could not import the cloud deployment backend (daita.cloud.lambda_deploy)")
        return

    try:
        aws_region = os.getenv('AWS_REGION', 'us-east-1')
        deployer = LambdaDeployer(aws_region)

        print(f" Deleting deployment: {deployment_id}")

        if not force:
            try:
                confirm = input(" Delete deployment and all Lambda functions? Type 'yes' to confirm: ")
            except EOFError:
                # Non-interactive stdin: never delete without explicit consent.
                confirm = ''
            if confirm.strip().lower() != 'yes':
                print(" Deletion cancelled")
                return

        result = await deployer.delete_deployment(deployment_id) or {}
        deleted_functions = result.get('deleted_functions') or []
        errors = result.get('errors') or []

        if deleted_functions:
            print(f" Deleted {len(deleted_functions)} Lambda functions:")
            for func_name in deleted_functions:
                print(f" {func_name}")
        elif not errors:
            print(" No Lambda functions were deleted")

        if errors:
            print(f" {len(errors)} errors occurred:")
            for error in errors:
                print(f" {error}")

        print(" Note: S3 packages are retained for audit purposes")

    except Exception as e:
        print(f" Failed to delete deployment: {e}")
|
|
293
|
+
|
|
294
|
+
# Helper functions
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def _load_project_config(project_root: Path) -> Optional[dict]:
    """Load ``daita-project.yaml`` from *project_root*.

    Args:
        project_root: Project directory (a ``Path`` or a path string).

    Returns:
        The parsed top-level mapping, or ``None`` when the file is missing,
        unreadable, invalid YAML, or not a mapping (callers call ``.get`` on
        the result, so non-dict YAML such as a bare list is rejected here).
    """
    config_file = Path(project_root) / 'daita-project.yaml'
    if not config_file.is_file():
        return None

    # Imported only when a config file exists, so a missing PyYAML does not
    # break projects without one.
    import yaml

    try:
        with open(config_file, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
    except (OSError, ValueError, yaml.YAMLError):
        # Best-effort: an unreadable or malformed config is treated as absent.
        return None

    return data if isinstance(data, dict) else None
|
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Import Detection for Daita Lambda Layer Optimization
|
|
4
|
+
|
|
5
|
+
Analyzes user code to detect package imports and determine which Lambda layers
|
|
6
|
+
are needed for deployment. This enables smart layer selection to minimize
|
|
7
|
+
package sizes while ensuring all dependencies are available.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import ast
|
|
11
|
+
import os
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Set, Dict, List, Optional
|
|
14
|
+
import logging
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
class ImportDetector:
    """Detect package imports in Python code to optimize Lambda layer selection.

    Source files are parsed with :mod:`ast` (never executed). The top-level
    package of every absolute import is collected and matched against the
    packages bundled in each known Lambda layer.
    """

    def __init__(self):
        # Packages bundled in each layer. Where a package's import name and
        # distribution name differ, both spellings are listed.
        self.layer_packages = {
            'common_dependencies': {
                'requests',
                'dateutil', 'python_dateutil',  # python-dateutil can be imported as either
                'PIL', 'pillow',  # Pillow can be imported as PIL
                'tqdm',
                'joblib',
            },
            'core_dependencies': {
                'pydantic', 'pydantic_core',
                'httpx',
                'aiofiles',
                'boto3', 'botocore',
            },
        }

        # Import name -> distribution name for packages whose names differ.
        self.import_mappings = {
            'PIL': 'pillow',
            'dateutil': 'python_dateutil',
            'cv2': 'opencv_python',
            'sklearn': 'scikit_learn',
            'skimage': 'scikit_image',
        }

    def analyze_file(self, file_path: Path) -> Set[str]:
        """Return the top-level packages imported by one Python file.

        Relative imports (``from . import x``, ``from .mod import y``) refer to
        the project's own modules and are skipped. Files that cannot be read or
        parsed are logged and contribute no imports.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            tree = ast.parse(content, filename=str(file_path))
        except Exception as e:
            logger.warning(f"Could not parse {file_path}: {e}")
            return set()

        imports = set()
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                imports.update(alias.name.split('.')[0] for alias in node.names)
            # level > 0 marks a relative import; its module is local, not a package.
            elif isinstance(node, ast.ImportFrom) and node.module and not node.level:
                imports.add(node.module.split('.')[0])
        return imports

    def analyze_directory(self, directory: Path, exclude_patterns: Optional[List[str]] = None) -> Set[str]:
        """Return the union of imports over every ``.py`` file under *directory*.

        A subdirectory is skipped when any entry of *exclude_patterns* occurs
        anywhere in its name (substring match, so ``'venv'`` also skips
        ``'myvenv'``). Defaults to common cache, VCS and virtualenv directories.
        """
        if exclude_patterns is None:
            exclude_patterns = ['__pycache__', '.git', 'node_modules', '.venv', 'venv']

        all_imports = set()
        for root, dirs, files in os.walk(directory):
            # Prune in place so os.walk does not descend into excluded directories.
            dirs[:] = [d for d in dirs if not any(pattern in d for pattern in exclude_patterns)]
            for file in files:
                if file.endswith('.py'):
                    all_imports.update(self.analyze_file(Path(root) / file))
        return all_imports

    def normalize_package_name(self, import_name: str) -> str:
        """Map an import name to its distribution name (unchanged if unmapped)."""
        return self.import_mappings.get(import_name, import_name)

    def detect_required_layers(self, imports: Set[str]) -> Dict[str, List[str]]:
        """Map each layer providing at least one imported package to those packages.

        A layer package matches when it equals either a raw import name or its
        normalized distribution name. Package lists are sorted so the result
        does not depend on set iteration order.
        """
        normalized_imports = {self.normalize_package_name(imp) for imp in imports}
        candidates = normalized_imports | set(imports)

        required_layers = {}
        for layer_name, layer_packages in self.layer_packages.items():
            matching_packages = sorted(p for p in layer_packages if p in candidates)
            if matching_packages:
                required_layers[layer_name] = matching_packages
        return required_layers

    def analyze_project(self, project_path: Path) -> Dict[str, object]:
        """Analyze every Python file under *project_path*.

        Returns a dict with ``total_imports`` (int), ``all_imports`` (sorted
        list), ``required_layers`` (as from :meth:`detect_required_layers`),
        ``layer_packages_detected`` (sorted list) and ``optimization_potential``
        (True when any layer package was found).
        """
        logger.info(f" Analyzing imports in project: {project_path}")

        all_imports = self.analyze_directory(project_path)
        required_layers = self.detect_required_layers(all_imports)

        layer_packages_found = set()
        for packages in required_layers.values():
            layer_packages_found.update(packages)

        analysis = {
            'total_imports': len(all_imports),
            'all_imports': sorted(all_imports),
            'required_layers': required_layers,
            'layer_packages_detected': sorted(layer_packages_found),
            'optimization_potential': len(layer_packages_found) > 0,
        }

        logger.info(" Analysis complete:")
        logger.info(f" Total imports detected: {analysis['total_imports']}")
        logger.info(f" Layer packages found: {len(layer_packages_found)}")
        logger.info(f" Layers needed: {list(required_layers.keys())}")

        return analysis

    def get_layer_arns_for_project(self, project_path: Path, layer_config_path: Path) -> List[str]:
        """Return the Lambda layer ARNs a project needs, based on its imports.

        Reads a JSON config with optional keys ``ml_dependencies_layer_arn``,
        ``dependencies_layer_arn`` and ``framework_layer_arn``. The framework
        layer (when configured) is always included; each ARN appears once.
        Returns ``[]`` when the config file does not exist.
        """
        import json

        layer_config_path = Path(layer_config_path)
        if not layer_config_path.exists():
            logger.warning(f"Layer config not found at {layer_config_path}")
            return []

        required_layers = self.analyze_project(project_path)['required_layers']

        with open(layer_config_path, 'r', encoding='utf-8') as f:
            layer_config = json.load(f)

        # NOTE(review): common_dependencies is served by the config key
        # 'ml_dependencies_layer_arn' — confirm this naming against the layer build.
        layer_arn_mapping = {
            'common_dependencies': layer_config.get('ml_dependencies_layer_arn'),
            'core_dependencies': layer_config.get('dependencies_layer_arn'),
        }

        required_arns: List[str] = []
        for layer_name in required_layers:
            arn = layer_arn_mapping.get(layer_name)
            # Skip unconfigured layers and ARNs already listed (duplicates are redundant).
            if arn and arn not in required_arns:
                required_arns.append(arn)
                logger.info(f" Including {layer_name} layer: {arn}")

        # Always include the framework layer.
        framework_arn = layer_config.get('framework_layer_arn')
        if framework_arn and framework_arn not in required_arns:
            required_arns.append(framework_arn)
            logger.info(f" Including framework layer: {framework_arn}")

        return required_arns
|
|
182
|
+
|
|
183
|
+
def main():
    """Command-line entry point: ``python import_detector.py <project_path>``.

    Prints an import analysis report for the project directory given as the
    first argument. Exits with status 1 when the argument is missing, or the
    path does not exist or is not a directory.
    """
    import sys

    if len(sys.argv) < 2:
        print("Usage: python import_detector.py <project_path>")
        sys.exit(1)

    project_path = Path(sys.argv[1])
    if not project_path.exists():
        print(f"Error: Project path {project_path} does not exist")
        sys.exit(1)
    if not project_path.is_dir():
        # analyze_project walks a directory tree; a plain file would report 0 imports.
        print(f"Error: Project path {project_path} is not a directory")
        sys.exit(1)

    detector = ImportDetector()
    analysis = detector.analyze_project(project_path)
    layer_packages = set(analysis['layer_packages_detected'])

    print("\n" + "=" * 60)
    print(" IMPORT ANALYSIS RESULTS")
    print("=" * 60)
    print(f" Total imports: {analysis['total_imports']}")
    print(f" Layer-optimizable packages: {len(layer_packages)}")
    print(f" Optimization potential: {'Yes' if analysis['optimization_potential'] else 'No'}")

    if analysis['required_layers']:
        print("\n Required layers:")
        for layer, packages in analysis['required_layers'].items():
            print(f" {layer}: {', '.join(packages)}")

    print("\n All detected imports:")
    for imp in analysis['all_imports']:
        # ● marks packages supplied by a Lambda layer, ○ everything else.
        status = "●" if imp in layer_packages else "○"
        print(f" {status} {imp}")

    print("=" * 60)


if __name__ == "__main__":
    main()
|