aimodelshare 0.1.55__py3-none-any.whl → 0.1.59__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of aimodelshare might be problematic. Click here for more details.

Files changed (35) hide show
  1. aimodelshare/__init__.py +94 -14
  2. aimodelshare/aimsonnx.py +263 -82
  3. aimodelshare/api.py +13 -12
  4. aimodelshare/auth.py +163 -0
  5. aimodelshare/base_image.py +1 -1
  6. aimodelshare/containerisation.py +1 -1
  7. aimodelshare/data_sharing/download_data.py +133 -83
  8. aimodelshare/generatemodelapi.py +7 -6
  9. aimodelshare/main/authorization.txt +275 -275
  10. aimodelshare/main/eval_lambda.txt +81 -13
  11. aimodelshare/model.py +492 -196
  12. aimodelshare/modeluser.py +22 -0
  13. aimodelshare/moral_compass/README.md +367 -0
  14. aimodelshare/moral_compass/__init__.py +58 -0
  15. aimodelshare/moral_compass/_version.py +3 -0
  16. aimodelshare/moral_compass/api_client.py +553 -0
  17. aimodelshare/moral_compass/challenge.py +365 -0
  18. aimodelshare/moral_compass/config.py +187 -0
  19. aimodelshare/playground.py +26 -14
  20. aimodelshare/preprocessormodules.py +60 -6
  21. aimodelshare/pyspark/authorization.txt +258 -258
  22. aimodelshare/pyspark/eval_lambda.txt +1 -1
  23. aimodelshare/reproducibility.py +20 -5
  24. aimodelshare/utils/__init__.py +78 -0
  25. aimodelshare/utils/optional_deps.py +38 -0
  26. aimodelshare-0.1.59.dist-info/METADATA +258 -0
  27. {aimodelshare-0.1.55.dist-info → aimodelshare-0.1.59.dist-info}/RECORD +30 -24
  28. aimodelshare-0.1.59.dist-info/licenses/LICENSE +5 -0
  29. {aimodelshare-0.1.55.dist-info → aimodelshare-0.1.59.dist-info}/top_level.txt +0 -1
  30. aimodelshare-0.1.55.dist-info/METADATA +0 -63
  31. aimodelshare-0.1.55.dist-info/licenses/LICENSE +0 -2
  32. tests/__init__.py +0 -0
  33. tests/test_aimsonnx.py +0 -135
  34. tests/test_playground.py +0 -721
  35. {aimodelshare-0.1.55.dist-info → aimodelshare-0.1.59.dist-info}/WHEEL +0 -0
@@ -0,0 +1,553 @@
1
+ """
2
+ API client for moral_compass REST API.
3
+
4
+ Provides a production-ready client with:
5
+ - Dataclasses for API responses
6
+ - Automatic retries for network and 5xx errors
7
+ - Pagination helpers
8
+ - Structured exceptions
9
+ - Authentication support via JWT tokens
10
+ """
11
+
12
+ import json
13
+ import logging
14
+ import time
15
+ import os
16
+ from dataclasses import dataclass
17
+ from typing import Optional, Dict, Any, Iterator, List
18
+ from urllib.parse import urlencode
19
+
20
+ import requests
21
+ from requests.adapters import HTTPAdapter
22
+ from urllib3.util.retry import Retry
23
+
24
+ from .config import get_api_base_url
25
+
26
+ logger = logging.getLogger("aimodelshare.moral_compass")
27
+
28
+
29
+ # ============================================================================
30
+ # Exceptions
31
+ # ============================================================================
32
+
33
class ApiClientError(Exception):
    """Base class for every error raised by the moral_compass API client.

    The more specific client exceptions in this module derive from it, so
    catching ``ApiClientError`` also catches those.
    """
36
+
37
+
38
class NotFoundError(ApiClientError):
    """Signals that the API answered with HTTP 404 (resource not found)."""
41
+
42
+
43
class ServerError(ApiClientError):
    """Signals that the API answered with an HTTP 5xx status."""
46
+
47
+
48
+ # ============================================================================
49
+ # Dataclasses
50
+ # ============================================================================
51
+
52
@dataclass
class MoralcompassTableMeta:
    """Metadata record describing one moral compass table."""

    # Required identity fields.
    table_id: str
    display_name: str
    # Optional fields; the defaults apply when a value is not supplied.
    created_at: Optional[str] = None
    is_archived: bool = False
    user_count: int = 0
60
+
61
+
62
@dataclass
class MoralcompassUserStats:
    """Per-user statistics record for a single table."""

    # Required identity field.
    username: str
    # Counters default to zero and the timestamp to None when not supplied.
    submission_count: int = 0
    total_count: int = 0
    last_updated: Optional[str] = None
69
+
70
+
71
+ # ============================================================================
72
+ # API Client
73
+ # ============================================================================
74
+
75
class MoralcompassApiClient:
    """
    Production-ready client for moral_compass REST API.

    Features:
    - Automatic API base URL discovery
    - Network retries with exponential backoff
    - Pagination helpers
    - Structured exceptions
    - Automatic authentication token attachment
    """

    def __init__(self, api_base_url: Optional[str] = None, timeout: int = 30, auth_token: Optional[str] = None):
        """
        Initialize the API client.

        Args:
            api_base_url: Optional explicit API base URL. If None (or empty),
                the value returned by ``get_api_base_url()`` is used.
            timeout: Default request timeout in seconds (default: 30)
            auth_token: Optional JWT authentication token. If None (or empty),
                the token is looked up via ``_get_auth_token_from_env()``.
        """
        # Strip trailing slashes so _request can always join with a single "/".
        self.api_base_url = (api_base_url or get_api_base_url()).rstrip("/")
        self.timeout = timeout
        self.auth_token = auth_token or self._get_auth_token_from_env()
        self.session = self._create_session()
        logger.info("MoralcompassApiClient initialized with base URL: %s", self.api_base_url)

    # ========================================================================
    # Internal helpers
    # ========================================================================

    @staticmethod
    def _quote_segment(value: Any) -> str:
        """
        Percent-encode a value for use as a single URL path segment.

        Characters that are legal inside a path segment (unreserved characters,
        sub-delimiters, ':' and '@') are left untouched, so ordinary IDs and
        e-mail style usernames produce exactly the same URL as plain string
        interpolation. Characters that would change the URL structure
        ('/', '?', '#', '%', ...) are escaped.
        """
        from urllib.parse import quote

        return quote(str(value), safe="!$&'()*+,;=:@")

    @staticmethod
    def _table_meta_from_payload(data: Dict[str, Any]) -> "MoralcompassTableMeta":
        """Build a MoralcompassTableMeta from one table dict of an API response."""
        return MoralcompassTableMeta(
            table_id=data["tableId"],
            # Fall back to the table ID when no display name was returned.
            display_name=data.get("displayName", data["tableId"]),
            created_at=data.get("createdAt"),
            is_archived=data.get("isArchived", False),
            user_count=data.get("userCount", 0),
        )

    @staticmethod
    def _user_stats_from_payload(data: Dict[str, Any]) -> "MoralcompassUserStats":
        """Build a MoralcompassUserStats from one user dict of an API response."""
        return MoralcompassUserStats(
            username=data["username"],
            submission_count=data.get("submissionCount", 0),
            total_count=data.get("totalCount", 0),
            last_updated=data.get("lastUpdated"),
        )

    def _get_auth_token_from_env(self) -> Optional[str]:
        """
        Get authentication token without an explicit argument.

        Delegates to ``get_primary_token()`` from the package's ``auth`` module.
        If that module cannot be imported, reads JWT_AUTHORIZATION_TOKEN first
        and then falls back to AWS_TOKEN.

        Returns:
            Optional[str]: Token or None if not found
        """
        try:
            from ..auth import get_primary_token
        except ImportError:
            # Fallback to direct environment variable access if auth module not available
            return os.getenv('JWT_AUTHORIZATION_TOKEN') or os.getenv('AWS_TOKEN')
        return get_primary_token()

    def _create_session(self) -> "requests.Session":
        """
        Create a requests session with retry configuration.

        Returns:
            Configured requests.Session with retry adapter mounted for both
            http:// and https:// URLs
        """
        session = requests.Session()

        # Retry network errors and selected 5xx responses.
        # NOTE(review): POST/PATCH/DELETE are retried as well, which assumes the
        # server tolerates repeated writes - confirm this is intended.
        retry_strategy = Retry(
            total=3,
            backoff_factor=1,  # exponential backoff; exact delays depend on the urllib3 version
            status_forcelist=[500, 502, 503, 504],
            allowed_methods=["HEAD", "GET", "PUT", "PATCH", "POST", "DELETE", "OPTIONS"],
            # Hand the final 5xx response back instead of raising once retries
            # are exhausted, so _request can map it to ServerError (otherwise
            # it would surface as a generic "Request failed" ApiClientError).
            raise_on_status=False,
        )

        adapter = HTTPAdapter(max_retries=retry_strategy)
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        return session

    def _request(self, method: str, path: str, **kwargs) -> "requests.Response":
        """
        Make an HTTP request with error handling and automatic auth header attachment.

        Args:
            method: HTTP method
            path: API path (without base URL)
            **kwargs: Additional arguments to pass to requests. A ``timeout``
                entry overrides the client default; ``headers`` are copied
                before the Authorization header is added.

        Returns:
            requests.Response object

        Raises:
            NotFoundError: If resource not found (404)
            ServerError: If server error (5xx)
            ApiClientError: For timeouts, connection problems and any other
                unsuccessful response
        """
        url = f"{self.api_base_url}/{path.lstrip('/')}"
        timeout = kwargs.pop("timeout", self.timeout)

        # Add Authorization header if token is available. Work on a copy so the
        # caller's dict is never modified (and a headers=None argument is fine).
        if self.auth_token:
            headers = dict(kwargs.get('headers') or {})
            headers['Authorization'] = f'Bearer {self.auth_token}'
            kwargs['headers'] = headers

        try:
            response = self.session.request(method, url, timeout=timeout, **kwargs)
        except requests.exceptions.Timeout as e:
            raise ApiClientError(f"Request timeout: {e}") from e
        except requests.exceptions.ConnectionError as e:
            raise ApiClientError(f"Connection error: {e}") from e
        except requests.exceptions.RequestException as e:
            raise ApiClientError(f"Request failed: {e}") from e

        # Handle specific error codes
        if response.status_code == 404:
            raise NotFoundError(f"Resource not found: {path} | body={response.text}")
        if 500 <= response.status_code < 600:
            raise ServerError(f"Server error {response.status_code}: {response.text}")

        # Any remaining 4xx becomes a generic client error.
        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            raise ApiClientError(f"Request failed: {e}") from e
        return response

    # ========================================================================
    # Health endpoint
    # ========================================================================

    def health(self) -> Dict[str, Any]:
        """
        Check API health status.

        Returns:
            Dict containing the decoded JSON body of GET /health
        """
        return self._request("GET", "/health").json()

    # ========================================================================
    # Table endpoints
    # ========================================================================

    def create_table(self, table_id: str, display_name: Optional[str] = None,
                     playground_url: Optional[str] = None) -> Dict[str, Any]:
        """
        Create a new table.

        Args:
            table_id: Unique identifier for the table
            display_name: Optional display name; only sent when non-empty
            playground_url: Optional playground URL; only sent when non-empty

        Returns:
            Dict containing creation response
        """
        payload = {"tableId": table_id}
        if display_name:
            payload["displayName"] = display_name
        if playground_url:
            payload["playgroundUrl"] = playground_url

        return self._request("POST", "/tables", json=payload).json()

    def create_table_for_playground(self, playground_url: str, suffix: str = '-mc',
                                    display_name: Optional[str] = None,
                                    region: Optional[str] = None) -> Dict[str, Any]:
        """
        Convenience method to create a moral compass table for a playground.

        Automatically derives the table ID from the playground URL and suffix.
        Supports region-aware table naming.

        Args:
            playground_url: URL of the playground
            suffix: Suffix for the table ID (default: '-mc')
            display_name: Optional display name
            region: Optional AWS region for region-aware naming (e.g., 'us-east-1').
                If provided, table ID will be <playgroundId>-<region><suffix>

        Returns:
            Dict containing creation response

        Raises:
            ValueError: If playground ID cannot be extracted from URL

        Examples:
            # Non-region-aware
            create_table_for_playground('https://example.com/playground/my-pg')
            # Creates table: my-pg-mc

            # Region-aware
            create_table_for_playground('https://example.com/playground/my-pg', region='us-east-1')
            # Creates table: my-pg-us-east-1-mc
        """
        from urllib.parse import urlparse

        # Extract playground ID from URL
        segments = [s for s in urlparse(playground_url).path.split('/') if s]

        # Preferred: the segment right after the first "playground(s)" marker
        # (a marker that is the last segment has nothing after it).
        playground_id = None
        for index, segment in enumerate(segments):
            if segment.lower() in ('playground', 'playgrounds'):
                if index + 1 < len(segments):
                    playground_id = segments[index + 1]
                break

        if not playground_id and segments:
            # Fallback: use last path component
            playground_id = segments[-1]

        if not playground_id:
            raise ValueError(f"Could not extract playground ID from URL: {playground_url}")

        # Build table ID with optional region
        if region:
            table_id = f"{playground_id}-{region}{suffix}"
        else:
            table_id = f"{playground_id}{suffix}"

        if not display_name:
            region_suffix = f" ({region})" if region else ""
            display_name = f"Moral Compass - {playground_id}{region_suffix}"

        return self.create_table(table_id=table_id, display_name=display_name,
                                 playground_url=playground_url)

    def list_tables(self, limit: int = 50, last_key: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
        """
        List tables with pagination.

        Args:
            limit: Maximum number of tables to return (default: 50)
            last_key: Pagination key from previous response; sent JSON-encoded
                in the ``lastKey`` query parameter

        Returns:
            Dict containing 'tables' list and optional 'lastKey' for pagination
        """
        params = {"limit": limit}
        if last_key:
            params["lastKey"] = json.dumps(last_key)

        return self._request("GET", f"/tables?{urlencode(params)}").json()

    def iter_tables(self, limit: int = 50) -> Iterator["MoralcompassTableMeta"]:
        """
        Iterate over all tables with automatic pagination.

        Args:
            limit: Page size (default: 50)

        Yields:
            MoralcompassTableMeta objects
        """
        last_key = None

        while True:
            page = self.list_tables(limit=limit, last_key=last_key)

            for table_data in page.get("tables", []):
                yield self._table_meta_from_payload(table_data)

            # A missing/empty lastKey marks the final page.
            last_key = page.get("lastKey")
            if not last_key:
                break

    def get_table(self, table_id: str) -> "MoralcompassTableMeta":
        """
        Get a specific table by ID.

        Args:
            table_id: The table identifier

        Returns:
            MoralcompassTableMeta object

        Raises:
            NotFoundError: If table not found
        """
        response = self._request("GET", f"/tables/{self._quote_segment(table_id)}")
        return self._table_meta_from_payload(response.json())

    def patch_table(self, table_id: str, display_name: Optional[str] = None,
                    is_archived: Optional[bool] = None) -> Dict[str, Any]:
        """
        Update table metadata.

        Only the arguments that are not None are included in the request body.

        Args:
            table_id: The table identifier
            display_name: Optional new display name
            is_archived: Optional archive status

        Returns:
            Dict containing update response
        """
        payload = {}
        if display_name is not None:
            payload["displayName"] = display_name
        if is_archived is not None:
            payload["isArchived"] = is_archived

        response = self._request("PATCH", f"/tables/{self._quote_segment(table_id)}", json=payload)
        return response.json()

    def delete_table(self, table_id: str) -> Dict[str, Any]:
        """
        Delete a table and all associated data.

        Requires owner or admin authorization when AUTH_ENABLED=true.
        Only works when ALLOW_TABLE_DELETE=true on server.

        Args:
            table_id: The table identifier

        Returns:
            Dict containing deletion confirmation

        Raises:
            NotFoundError: If table not found
            ApiClientError: If deletion not allowed or authorization fails
        """
        response = self._request("DELETE", f"/tables/{self._quote_segment(table_id)}")
        return response.json()

    # ========================================================================
    # User endpoints
    # ========================================================================

    def list_users(self, table_id: str, limit: int = 50,
                   last_key: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
        """
        List users in a table with pagination.

        Args:
            table_id: The table identifier
            limit: Maximum number of users to return (default: 50)
            last_key: Pagination key from previous response; sent JSON-encoded
                in the ``lastKey`` query parameter

        Returns:
            Dict containing 'users' list and optional 'lastKey' for pagination
        """
        params = {"limit": limit}
        if last_key:
            params["lastKey"] = json.dumps(last_key)

        path = f"/tables/{self._quote_segment(table_id)}/users?{urlencode(params)}"
        return self._request("GET", path).json()

    def iter_users(self, table_id: str, limit: int = 50) -> Iterator["MoralcompassUserStats"]:
        """
        Iterate over all users in a table with automatic pagination.

        Args:
            table_id: The table identifier
            limit: Page size (default: 50)

        Yields:
            MoralcompassUserStats objects
        """
        last_key = None

        while True:
            page = self.list_users(table_id, limit=limit, last_key=last_key)

            for user_data in page.get("users", []):
                yield self._user_stats_from_payload(user_data)

            # A missing/empty lastKey marks the final page.
            last_key = page.get("lastKey")
            if not last_key:
                break

    def get_user(self, table_id: str, username: str) -> "MoralcompassUserStats":
        """
        Get a specific user's stats in a table.

        Args:
            table_id: The table identifier
            username: The username

        Returns:
            MoralcompassUserStats object

        Raises:
            NotFoundError: If user or table not found
        """
        path = f"/tables/{self._quote_segment(table_id)}/users/{self._quote_segment(username)}"
        return self._user_stats_from_payload(self._request("GET", path).json())

    def put_user(self, table_id: str, username: str,
                 submission_count: int, total_count: int) -> Dict[str, Any]:
        """
        Create or update a user's stats in a table.

        Args:
            table_id: The table identifier
            username: The username
            submission_count: Number of submissions
            total_count: Total count

        Returns:
            Dict containing update response
        """
        payload = {
            "submissionCount": submission_count,
            "totalCount": total_count,
        }

        path = f"/tables/{self._quote_segment(table_id)}/users/{self._quote_segment(username)}"
        return self._request("PUT", path, json=payload).json()

    def update_moral_compass(self, table_id: str, username: str,
                             metrics: Dict[str, float],
                             tasks_completed: int = 0,
                             total_tasks: int = 0,
                             questions_correct: int = 0,
                             total_questions: int = 0,
                             primary_metric: Optional[str] = None) -> Dict[str, Any]:
        """
        Update a user's moral compass score with dynamic metrics.

        The hyphenated ``/moral-compass`` route is tried first. If the server
        answers 404 with a "route not found" message, the request is repeated
        once against the legacy ``/moralcompass`` route.

        Args:
            table_id: The table identifier
            username: The username
            metrics: Dictionary of metric_name -> numeric_value
            tasks_completed: Number of tasks completed (default: 0)
            total_tasks: Total number of tasks (default: 0)
            questions_correct: Number of questions answered correctly (default: 0)
            total_questions: Total number of questions (default: 0)
            primary_metric: Optional primary metric name; omitted from the
                request when None

        Returns:
            Dict containing the decoded JSON response

        Raises:
            NotFoundError: For a resource-level 404 (e.g., table or user not found)
        """
        payload = {
            "metrics": metrics,
            "tasksCompleted": tasks_completed,
            "totalTasks": total_tasks,
            "questionsCorrect": questions_correct,
            "totalQuestions": total_questions,
        }

        if primary_metric is not None:
            payload["primaryMetric"] = primary_metric

        user_path = f"/tables/{self._quote_segment(table_id)}/users/{self._quote_segment(username)}"

        # Try hyphenated path first
        try:
            return self._request("PUT", f"{user_path}/moral-compass", json=payload).json()
        except NotFoundError as e:
            if "route not found" not in str(e).lower():
                # Resource-level 404 (e.g., table or user not found), don't retry
                raise
            # If route not found, retry with legacy path (no hyphen)
            logger.warning("Hyphenated path failed with 404, retrying with legacy path: %s", e)
            return self._request("PUT", f"{user_path}/moralcompass", json=payload).json()