aimodelshare 0.1.12__py3-none-any.whl → 0.1.64__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aimodelshare might be problematic. Click here for more details.
- aimodelshare/__init__.py +94 -14
- aimodelshare/aimsonnx.py +417 -262
- aimodelshare/api.py +7 -6
- aimodelshare/auth.py +163 -0
- aimodelshare/aws.py +4 -4
- aimodelshare/base_image.py +1 -1
- aimodelshare/containerisation.py +1 -1
- aimodelshare/data_sharing/download_data.py +145 -88
- aimodelshare/generatemodelapi.py +7 -6
- aimodelshare/main/eval_lambda.txt +81 -13
- aimodelshare/model.py +493 -197
- aimodelshare/modeluser.py +89 -1
- aimodelshare/moral_compass/README.md +408 -0
- aimodelshare/moral_compass/__init__.py +37 -0
- aimodelshare/moral_compass/_version.py +3 -0
- aimodelshare/moral_compass/api_client.py +601 -0
- aimodelshare/moral_compass/apps/__init__.py +26 -0
- aimodelshare/moral_compass/apps/ai_consequences.py +297 -0
- aimodelshare/moral_compass/apps/judge.py +299 -0
- aimodelshare/moral_compass/apps/tutorial.py +198 -0
- aimodelshare/moral_compass/apps/what_is_ai.py +426 -0
- aimodelshare/moral_compass/challenge.py +365 -0
- aimodelshare/moral_compass/config.py +187 -0
- aimodelshare/playground.py +26 -14
- aimodelshare/preprocessormodules.py +60 -6
- aimodelshare/reproducibility.py +20 -5
- aimodelshare/utils/__init__.py +78 -0
- aimodelshare/utils/optional_deps.py +38 -0
- aimodelshare-0.1.64.dist-info/METADATA +298 -0
- {aimodelshare-0.1.12.dist-info → aimodelshare-0.1.64.dist-info}/RECORD +33 -22
- {aimodelshare-0.1.12.dist-info → aimodelshare-0.1.64.dist-info}/WHEEL +1 -1
- aimodelshare-0.1.64.dist-info/licenses/LICENSE +5 -0
- {aimodelshare-0.1.12.dist-info → aimodelshare-0.1.64.dist-info}/top_level.txt +0 -1
- aimodelshare-0.1.12.dist-info/LICENSE +0 -22
- aimodelshare-0.1.12.dist-info/METADATA +0 -68
- tests/__init__.py +0 -0
- tests/test_aimsonnx.py +0 -135
- tests/test_playground.py +0 -721
|
@@ -0,0 +1,365 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Challenge Manager for Moral Compass system.
|
|
3
|
+
|
|
4
|
+
Provides a local state manager for tracking multi-metric progress
|
|
5
|
+
and syncing with the Moral Compass API.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Dict, Optional, List
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from .api_client import MoralcompassApiClient
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass()
class Question:
    """A single multiple-choice question attached to a challenge task."""

    # Question identifier, e.g. "A1".
    id: str
    # Question prompt text.
    text: str
    # Candidate answers.
    options: List[str]
    # Index into ``options`` of the correct answer.
    correct_index: int
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass()
class Task:
    """One task of a challenge together with the questions it contains."""

    # Task identifier, e.g. "A".
    id: str
    # Short heading for the task.
    title: str
    # Summary of what the task covers.
    description: str
    # Questions belonging to this task.
    questions: List[Question]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class JusticeAndEquityChallenge:
    """
    The Justice & Equity challenge: six fairness-themed tasks labelled A-F.

    Every task carries exactly one multiple-choice question; the option at
    ``correct_index`` is the expected answer.
    """

    def __init__(self):
        """Build the fixed task list (A-F), one question per task."""
        # Each row: (task id, title, description,
        #            question id, question text, options, correct index).
        # The table is rebuilt on every call so instances never share lists.
        rows = (
            ("A", "Understanding Algorithmic Bias",
             "Learn about different types of bias in AI systems",
             "A1", "What is algorithmic bias?",
             ["Bias in the training data",
              "Systematic and repeatable errors in computer systems",
              "User preference bias",
              "Network latency bias"],
             1),
            ("B", "Identifying Protected Attributes",
             "Understanding which attributes require fairness considerations",
             "B1", "Which is a protected attribute in fairness?",
             ["Email address",
              "Race or ethnicity",
              "Browser type",
              "Screen resolution"],
             1),
            ("C", "Measuring Disparate Impact",
             "Learn to measure fairness using statistical metrics",
             "C1", "What is disparate impact?",
             ["Equal outcome rates across groups",
              "Different outcome rates for different groups",
              "Same prediction accuracy",
              "Uniform data distribution"],
             1),
            ("D", "Evaluating Model Fairness",
             "Apply fairness metrics to assess model performance",
             "D1", "What does equal opportunity mean?",
             ["Same accuracy for all groups",
              "Equal true positive rates across groups",
              "Equal false positive rates",
              "Same number of predictions"],
             1),
            ("E", "Mitigation Strategies",
             "Explore techniques to reduce algorithmic bias",
             "E1", "Which is a bias mitigation technique?",
             ["Ignore protected attributes",
              "Reweighting training samples",
              "Use more servers",
              "Faster algorithms"],
             1),
            ("F", "Ethical Deployment",
             "Best practices for deploying fair AI systems",
             "F1", "What is essential for ethical AI deployment?",
             ["Fastest inference time",
              "Continuous monitoring and auditing",
              "Most complex model",
              "Largest dataset"],
             1),
        )

        self.tasks = [
            Task(
                id=task_id,
                title=title,
                description=description,
                questions=[
                    Question(
                        id=question_id,
                        text=prompt,
                        options=choices,
                        correct_index=answer,
                    )
                ],
            )
            for task_id, title, description, question_id, prompt, choices, answer in rows
        ]

    @property
    def total_tasks(self) -> int:
        """Number of tasks defined by this challenge."""
        return len(self.tasks)

    @property
    def total_questions(self) -> int:
        """Number of questions summed over every task."""
        count = 0
        for task in self.tasks:
            count += len(task.questions)
        return count
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class ChallengeManager:
    """
    Manages local state for a user's challenge progress with multiple metrics.

    Features:
    - Track arbitrary metrics (accuracy, fairness, robustness, etc.)
    - Specify primary metric for scoring
    - Track task and question progress
    - Local preview of moral compass score
    - Sync to server via API
    """

    def __init__(self, table_id: str, username: str,
                 api_client: Optional["MoralcompassApiClient"] = None,
                 challenge: Optional["JusticeAndEquityChallenge"] = None):
        """
        Initialize a challenge manager.

        Args:
            table_id: The table identifier
            username: The username
            api_client: Optional API client instance (a new
                MoralcompassApiClient is created when None)
            challenge: Optional challenge instance (a new
                JusticeAndEquityChallenge is created when None). Its
                ``total_tasks`` / ``total_questions`` seed the progress totals.
        """
        self.table_id = table_id
        self.username = username
        # Compare against None explicitly so a caller-supplied object is never
        # replaced just because it happens to evaluate as falsy.
        self.api_client = api_client if api_client is not None else MoralcompassApiClient()
        self.challenge = challenge if challenge is not None else JusticeAndEquityChallenge()

        # Metrics state
        self.metrics: Dict[str, float] = {}
        self.primary_metric: Optional[str] = None

        # Progress state - initialize with challenge totals
        self.tasks_completed: int = 0
        self.total_tasks: int = self.challenge.total_tasks
        self.questions_correct: int = 0
        self.total_questions: int = self.challenge.total_questions

        # Track completed tasks and answers
        self._completed_task_ids: set = set()
        self._answered_questions: Dict[str, int] = {}  # question_id -> selected_index

    def set_metric(self, name: str, value: float, primary: bool = False) -> None:
        """
        Set a metric value.

        Args:
            name: Metric name (e.g., 'accuracy', 'fairness', 'robustness')
            value: Metric value (should be between 0 and 1 typically)
            primary: If True, sets this as the primary metric for scoring
        """
        self.metrics[name] = value

        if primary:
            self.primary_metric = name

    def set_progress(self, tasks_completed: int = 0, total_tasks: Optional[int] = None,
                     questions_correct: int = 0, total_questions: Optional[int] = None) -> None:
        """
        Set progress counters.

        Args:
            tasks_completed: Number of tasks completed
            total_tasks: Total number of tasks. None (the default) keeps the
                current total, which __init__ seeds from the challenge.
            questions_correct: Number of questions answered correctly
            total_questions: Total number of questions. None (the default)
                keeps the current total.
        """
        self.tasks_completed = tasks_completed
        # Omitted totals must not collapse to 0: a zero denominator makes
        # get_local_score() return 0.0 regardless of progress.
        if total_tasks is not None:
            self.total_tasks = total_tasks
        self.questions_correct = questions_correct
        if total_questions is not None:
            self.total_questions = total_questions

    def complete_task(self, task_id: str) -> None:
        """
        Mark a task as completed.

        Calling this again for an already-completed task is a no-op.

        Args:
            task_id: The task identifier (e.g., 'A', 'B', 'C')
        """
        if task_id not in self._completed_task_ids:
            self._completed_task_ids.add(task_id)
            self.tasks_completed = len(self._completed_task_ids)

    def _find_question(self, task_id: str, question_id: str) -> Optional["Question"]:
        """Return the question ``question_id`` of the first task with id ``task_id``, or None."""
        for task in self.challenge.tasks:
            if task.id == task_id:
                return next((q for q in task.questions if q.id == question_id), None)
        return None

    def answer_question(self, task_id: str, question_id: str, selected_index: int) -> bool:
        """
        Record an answer to a question.

        Re-answering a question replaces the previous answer, and
        ``questions_correct`` is recomputed from all recorded answers.

        Args:
            task_id: The task identifier
            question_id: The question identifier
            selected_index: The index of the selected answer

        Returns:
            True if the answer is correct, False otherwise

        Raises:
            ValueError: If the question is not part of the given task
        """
        question = self._find_question(task_id, question_id)
        if question is None:
            raise ValueError(f"Question {question_id} not found in task {task_id}")

        # Record the answer
        self._answered_questions[question_id] = selected_index

        # Recalculate questions_correct over every recorded answer
        self.questions_correct = sum(
            1 for qid, idx in self._answered_questions.items()
            if self._is_answer_correct(qid, idx)
        )

        return selected_index == question.correct_index

    def _is_answer_correct(self, question_id: str, selected_index: int) -> bool:
        """Check if an answer is correct (False for unknown question ids)."""
        for task in self.challenge.tasks:
            for q in task.questions:
                if q.id == question_id:
                    return selected_index == q.correct_index
        return False

    def get_progress_summary(self) -> Dict:
        """
        Get a summary of current progress.

        Returns:
            Dictionary with progress information including local score preview
        """
        return {
            'tasksCompleted': self.tasks_completed,
            'totalTasks': self.total_tasks,
            'questionsCorrect': self.questions_correct,
            'totalQuestions': self.total_questions,
            'metrics': self.metrics.copy(),
            'primaryMetric': self.primary_metric,
            'localScorePreview': self.get_local_score()
        }

    def get_local_score(self) -> float:
        """
        Calculate moral compass score locally without syncing to server.

        The score is ``primary_value * progress_ratio``.
        - The primary metric is ``primary_metric`` when set. Otherwise it is
          'accuracy' if present, else the alphabetically first metric name.
        - ``progress_ratio`` is
          (tasks_completed + questions_correct) / (total_tasks + total_questions).

        Returns:
            Moral compass score based on current state (0.0 when no metrics
            are set or both totals are zero)
        """
        if not self.metrics:
            return 0.0

        # Determine primary metric
        primary_metric = self.primary_metric
        if primary_metric is None:
            primary_metric = 'accuracy' if 'accuracy' in self.metrics else min(self.metrics)

        # A primary metric that was never given a value counts as 0.0
        primary_value = self.metrics.get(primary_metric, 0.0)

        # Calculate progress ratio
        progress_denominator = self.total_tasks + self.total_questions
        if progress_denominator == 0:
            return 0.0

        progress_ratio = (self.tasks_completed + self.questions_correct) / progress_denominator

        return primary_value * progress_ratio

    def sync(self) -> Dict:
        """
        Sync current state to the Moral Compass API.

        Returns:
            API response dict with moralCompassScore and other fields

        Raises:
            ValueError: If no metrics have been set
        """
        if not self.metrics:
            raise ValueError("No metrics set. Use set_metric() before syncing.")

        return self.api_client.update_moral_compass(
            table_id=self.table_id,
            username=self.username,
            metrics=self.metrics,
            tasks_completed=self.tasks_completed,
            total_tasks=self.total_tasks,
            questions_correct=self.questions_correct,
            total_questions=self.total_questions,
            primary_metric=self.primary_metric
        )

    def __repr__(self) -> str:
        return (
            f"ChallengeManager(table_id={self.table_id!r}, username={self.username!r}, "
            f"metrics={self.metrics}, primary_metric={self.primary_metric!r}, "
            f"local_score={self.get_local_score():.4f})"
        )
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Configuration module for moral_compass API client.
|
|
3
|
+
|
|
4
|
+
Provides API base URL discovery via:
|
|
5
|
+
1. Environment variable MORAL_COMPASS_API_BASE_URL or AIMODELSHARE_API_BASE_URL
|
|
6
|
+
2. Cached terraform outputs file (infra/terraform_outputs.json)
|
|
7
|
+
3. Terraform command execution (fallback)
|
|
8
|
+
|
|
9
|
+
Also provides AWS region discovery for region-aware table naming.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import os
|
|
13
|
+
import json
|
|
14
|
+
import logging
|
|
15
|
+
import subprocess
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import Optional
|
|
18
|
+
|
|
19
|
+
# Module logger. It is named after the parent package (not __name__), so
# records from this module appear under "aimodelshare.moral_compass".
logger = logging.getLogger("aimodelshare.moral_compass")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def get_aws_region() -> Optional[str]:
    """
    Look up the AWS region to use, or None when it cannot be discovered.

    Sources are consulted in this order; the first non-empty value wins:
      1. the AWS_REGION environment variable
      2. the AWS_DEFAULT_REGION environment variable
      3. the cached terraform outputs file
    When none of them yields a value, None is returned and the caller is
    expected to choose its own default.

    Returns:
        Optional[str]: AWS region name or None
    """
    env_region = os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION")
    if not env_region:
        # Environment gave nothing usable - fall back to the terraform cache.
        cached_region = _get_region_from_cached_outputs()
        if not cached_region:
            logger.debug("AWS region not found, caller should use default")
            return None
        logger.debug(f"Using AWS region from cached terraform outputs: {cached_region}")
        return cached_region

    logger.debug(f"Using AWS region from environment: {env_region}")
    return env_region
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def get_api_base_url() -> str:
    """
    Resolve the Moral Compass API base URL.

    Discovery tries, in order:
      1. MORAL_COMPASS_API_BASE_URL, then AIMODELSHARE_API_BASE_URL
         (any trailing "/" is stripped)
      2. the cached terraform outputs file
      3. running ``terraform output`` in the infra directory

    Returns:
        str: The API base URL

    Raises:
        RuntimeError: If no strategy produces a URL
    """
    base_url = os.getenv("MORAL_COMPASS_API_BASE_URL") or os.getenv("AIMODELSHARE_API_BASE_URL")
    if base_url:
        logger.debug(f"Using API base URL from environment: {base_url}")
        return base_url.rstrip("/")

    base_url = _get_url_from_cached_outputs()
    if base_url:
        logger.debug(f"Using API base URL from cached terraform outputs: {base_url}")
        return base_url

    # Last resort: shelling out to terraform is the slowest option.
    base_url = _get_url_from_terraform_command()
    if not base_url:
        raise RuntimeError(
            "Could not determine API base URL. Please set MORAL_COMPASS_API_BASE_URL "
            "environment variable or ensure terraform outputs are accessible."
        )

    logger.debug(f"Using API base URL from terraform command: {base_url}")
    return base_url
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _get_url_from_cached_outputs() -> Optional[str]:
    """
    Read API base URL from cached terraform outputs file.

    The file is ``<repo_root>/infra/terraform_outputs.json``. The value may be
    stored as ``{"api_base_url": {"value": "..."}}`` or as
    ``{"api_base_url": "..."}``.

    Returns:
        Optional[str]: API base URL (trailing "/" stripped) if found in cache.
        None if the file is missing, unreadable, malformed, or holds no
        usable string value.
    """
    # NOTE(review): from .../aimodelshare/moral_compass/config.py, four
    # `.parent` hops land one directory above the directory that contains the
    # `aimodelshare` package - confirm this matches the intended repo layout.
    repo_root = Path(__file__).parent.parent.parent.parent
    outputs_file = repo_root / "infra" / "terraform_outputs.json"

    if not outputs_file.exists():
        logger.debug(f"Cached terraform outputs not found at {outputs_file}")
        return None

    try:
        with open(outputs_file, "r", encoding="utf-8") as f:
            outputs = json.load(f)
    except (ValueError, OSError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning(f"Error reading cached terraform outputs: {e}")
        return None

    # Anything other than a JSON object cannot hold the expected key; bail out
    # instead of letting `.get` raise and abort URL discovery.
    if not isinstance(outputs, dict):
        logger.warning(
            f"Unexpected cached terraform outputs format: expected a JSON object, "
            f"got {type(outputs).__name__}"
        )
        return None

    # Handle both formats: {"api_base_url": {"value": "..."}} or {"api_base_url": "..."}
    api_base_url = outputs.get("api_base_url")
    url = api_base_url.get("value") if isinstance(api_base_url, dict) else api_base_url

    # Only a non-empty string other than the literal "null" is a usable URL.
    if isinstance(url, str) and url and url != "null":
        return url.rstrip("/")

    return None
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _get_region_from_cached_outputs() -> Optional[str]:
    """
    Read AWS region from cached terraform outputs file.

    The keys "region" and then "aws_region" are tried in turn. Each value may be
    stored as ``{"<key>": {"value": "..."}}`` or as ``{"<key>": "..."}``. The
    first key holding a usable value wins.

    Returns:
        Optional[str]: AWS region if found in cache. None if the file is
        missing, unreadable, malformed, or holds no usable string value.
    """
    # NOTE(review): from .../aimodelshare/moral_compass/config.py, four
    # `.parent` hops land one directory above the directory that contains the
    # `aimodelshare` package - confirm this matches the intended repo layout.
    repo_root = Path(__file__).parent.parent.parent.parent
    outputs_file = repo_root / "infra" / "terraform_outputs.json"

    if not outputs_file.exists():
        logger.debug(f"Cached terraform outputs not found at {outputs_file}")
        return None

    try:
        with open(outputs_file, "r", encoding="utf-8") as f:
            outputs = json.load(f)
    except (ValueError, OSError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning(f"Error reading cached terraform outputs: {e}")
        return None

    if not isinstance(outputs, dict):
        logger.warning(
            f"Unexpected cached terraform outputs format: expected a JSON object, "
            f"got {type(outputs).__name__}"
        )
        return None

    # Unwrap each candidate separately so that an unusable "region" entry
    # (e.g. {"value": null} or "null") still lets "aws_region" be consulted.
    for key in ("region", "aws_region"):
        candidate = outputs.get(key)
        if isinstance(candidate, dict):
            candidate = candidate.get("value")
        if isinstance(candidate, str) and candidate and candidate != "null":
            return candidate

    return None
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def _get_url_from_terraform_command() -> Optional[str]:
    """
    Execute terraform command to get API base URL.

    Runs ``terraform output -raw api_base_url`` in ``<repo_root>/infra`` with a
    10-second timeout.

    Returns:
        Optional[str]: API base URL (trailing "/" stripped) if the command
        succeeds and prints a non-empty value other than "null". None otherwise.
    """
    # NOTE(review): four `.parent` hops from .../aimodelshare/moral_compass/config.py
    # land one directory above the package's parent - confirm the repo layout.
    repo_root = Path(__file__).parent.parent.parent.parent
    infra_dir = repo_root / "infra"

    # is_dir() rather than exists(): a regular file named "infra" would make
    # subprocess.run fail with NotADirectoryError.
    if not infra_dir.is_dir():
        logger.debug(f"Infra directory not found at {infra_dir}")
        return None

    try:
        result = subprocess.run(
            ["terraform", "output", "-raw", "api_base_url"],
            cwd=infra_dir,
            capture_output=True,
            text=True,
            timeout=10,
        )
    except (subprocess.SubprocessError, OSError, UnicodeDecodeError) as e:
        # SubprocessError includes TimeoutExpired. OSError includes
        # FileNotFoundError (terraform not installed) and PermissionError.
        # UnicodeDecodeError covers undecodable output under text=True.
        logger.debug(f"Terraform command failed: {e}")
        return None

    if result.returncode != 0:
        logger.debug(f"Terraform command exited with status {result.returncode}")
        return None

    url = result.stdout.strip()
    if url and url != "null":
        return url.rstrip("/")

    return None
|
aimodelshare/playground.py
CHANGED
|
@@ -1246,13 +1246,19 @@ class ModelPlayground:
|
|
|
1246
1246
|
with HiddenPrints():
|
|
1247
1247
|
competition = Competition(self.playground_url)
|
|
1248
1248
|
|
|
1249
|
-
|
|
1250
|
-
|
|
1251
|
-
|
|
1252
|
-
|
|
1253
|
-
|
|
1254
|
-
|
|
1255
|
-
|
|
1249
|
+
comp_result = competition.submit_model(model=model,
|
|
1250
|
+
prediction_submission=prediction_submission,
|
|
1251
|
+
preprocessor=preprocessor,
|
|
1252
|
+
reproducibility_env_filepath=reproducibility_env_filepath,
|
|
1253
|
+
custom_metadata=custom_metadata,
|
|
1254
|
+
input_dict=input_dict,
|
|
1255
|
+
print_output=False)
|
|
1256
|
+
|
|
1257
|
+
# Validate return structure before unpacking
|
|
1258
|
+
if not isinstance(comp_result, tuple) or len(comp_result) != 2:
|
|
1259
|
+
raise RuntimeError(f"Invalid return from competition.submit_model: expected (version, url) tuple, got {type(comp_result)}")
|
|
1260
|
+
|
|
1261
|
+
version_comp, model_page = comp_result
|
|
1256
1262
|
|
|
1257
1263
|
print(f"Your model has been submitted to competition as model version {version_comp}.")
|
|
1258
1264
|
|
|
@@ -1260,13 +1266,19 @@ class ModelPlayground:
|
|
|
1260
1266
|
with HiddenPrints():
|
|
1261
1267
|
experiment = Experiment(self.playground_url)
|
|
1262
1268
|
|
|
1263
|
-
|
|
1264
|
-
|
|
1265
|
-
|
|
1266
|
-
|
|
1267
|
-
|
|
1268
|
-
|
|
1269
|
-
|
|
1269
|
+
exp_result = experiment.submit_model(model=model,
|
|
1270
|
+
prediction_submission=prediction_submission,
|
|
1271
|
+
preprocessor=preprocessor,
|
|
1272
|
+
reproducibility_env_filepath=reproducibility_env_filepath,
|
|
1273
|
+
custom_metadata=custom_metadata,
|
|
1274
|
+
input_dict=input_dict,
|
|
1275
|
+
print_output=False)
|
|
1276
|
+
|
|
1277
|
+
# Validate return structure before unpacking
|
|
1278
|
+
if not isinstance(exp_result, tuple) or len(exp_result) != 2:
|
|
1279
|
+
raise RuntimeError(f"Invalid return from experiment.submit_model: expected (version, url) tuple, got {type(exp_result)}")
|
|
1280
|
+
|
|
1281
|
+
version_exp, model_page = exp_result
|
|
1270
1282
|
|
|
1271
1283
|
print(f"Your model has been submitted to experiment as model version {version_exp}.")
|
|
1272
1284
|
|