aimodelshare 0.1.55__py3-none-any.whl → 0.1.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of aimodelshare might be problematic. Click here for more details.

Files changed (36)
  1. aimodelshare/__init__.py +94 -14
  2. aimodelshare/aimsonnx.py +263 -82
  3. aimodelshare/api.py +13 -12
  4. aimodelshare/auth.py +163 -0
  5. aimodelshare/aws.py +4 -4
  6. aimodelshare/base_image.py +1 -1
  7. aimodelshare/containerisation.py +1 -1
  8. aimodelshare/data_sharing/download_data.py +133 -83
  9. aimodelshare/generatemodelapi.py +7 -6
  10. aimodelshare/main/authorization.txt +275 -275
  11. aimodelshare/main/eval_lambda.txt +81 -13
  12. aimodelshare/model.py +493 -197
  13. aimodelshare/modeluser.py +89 -1
  14. aimodelshare/moral_compass/README.md +408 -0
  15. aimodelshare/moral_compass/__init__.py +58 -0
  16. aimodelshare/moral_compass/_version.py +3 -0
  17. aimodelshare/moral_compass/api_client.py +601 -0
  18. aimodelshare/moral_compass/challenge.py +365 -0
  19. aimodelshare/moral_compass/config.py +187 -0
  20. aimodelshare/playground.py +26 -14
  21. aimodelshare/preprocessormodules.py +60 -6
  22. aimodelshare/pyspark/authorization.txt +258 -258
  23. aimodelshare/pyspark/eval_lambda.txt +1 -1
  24. aimodelshare/reproducibility.py +20 -5
  25. aimodelshare/utils/__init__.py +78 -0
  26. aimodelshare/utils/optional_deps.py +38 -0
  27. aimodelshare-0.1.60.dist-info/METADATA +258 -0
  28. {aimodelshare-0.1.55.dist-info → aimodelshare-0.1.60.dist-info}/RECORD +31 -25
  29. aimodelshare-0.1.60.dist-info/licenses/LICENSE +5 -0
  30. {aimodelshare-0.1.55.dist-info → aimodelshare-0.1.60.dist-info}/top_level.txt +0 -1
  31. aimodelshare-0.1.55.dist-info/METADATA +0 -63
  32. aimodelshare-0.1.55.dist-info/licenses/LICENSE +0 -2
  33. tests/__init__.py +0 -0
  34. tests/test_aimsonnx.py +0 -135
  35. tests/test_playground.py +0 -721
  36. {aimodelshare-0.1.55.dist-info → aimodelshare-0.1.60.dist-info}/WHEEL +0 -0
@@ -0,0 +1,601 @@
1
+ """
2
+ API client for moral_compass REST API.
3
+
4
+ Provides a production-ready client with:
5
+ - Dataclasses for API responses
6
+ - Automatic retries for network and 5xx errors
7
+ - Pagination helpers
8
+ - Structured exceptions
9
+ - Authentication support via JWT tokens
10
+ """
11
+
12
+ import json
13
+ import logging
14
+ import time
15
+ import os
16
+ from dataclasses import dataclass
17
+ from typing import Optional, Dict, Any, Iterator, List
18
+ from urllib.parse import urlencode
19
+
20
+ import requests
21
+ from requests.adapters import HTTPAdapter
22
+ from urllib3.util.retry import Retry
23
+
24
+ from .config import get_api_base_url
25
+
26
+ logger = logging.getLogger("aimodelshare.moral_compass")
27
+
28
+
29
+ # ============================================================================
30
+ # Exceptions
31
+ # ============================================================================
32
+
33
class ApiClientError(Exception):
    """Base class for the errors raised by the moral_compass API client."""


class NotFoundError(ApiClientError):
    """The API answered 404: the requested resource could not be found."""


class ServerError(ApiClientError):
    """The API answered with a 5xx server-side error status."""
46
+
47
+
48
+ # ============================================================================
49
+ # Dataclasses
50
+ # ============================================================================
51
+
52
@dataclass(init=True, repr=True, eq=True)
class MoralcompassTableMeta:
    """
    Metadata record describing one moral compass table.

    Only ``table_id`` and ``display_name`` are required; the remaining
    fields fall back to "unknown creation time", "not archived" and
    "no users" respectively.
    """

    table_id: str                      # unique table identifier
    display_name: str                  # human-readable table name
    created_at: Optional[str] = None   # creation timestamp as returned by the API, if any
    is_archived: bool = False          # whether the table has been archived
    user_count: int = 0                # number of users recorded in the table
60
+
61
+
62
@dataclass(init=True, repr=True, eq=True)
class MoralcompassUserStats:
    """
    Per-user statistics stored in a moral compass table.

    ``username`` is required; the counters default to zero and
    ``last_updated`` to ``None`` when the API does not supply them.
    """

    username: str                        # user the statistics belong to
    submission_count: int = 0            # number of submissions made by the user
    total_count: int = 0                 # aggregate count tracked for the user
    last_updated: Optional[str] = None   # last-update timestamp as returned by the API, if any
69
+
70
+
71
+ # ============================================================================
72
+ # API Client
73
+ # ============================================================================
74
+
75
class MoralcompassApiClient:
    """
    Production-ready client for moral_compass REST API.

    Features:
    - Automatic API base URL discovery
    - Network retries with exponential backoff
    - Pagination helpers
    - Structured exceptions
    - Automatic authentication token attachment

    The client owns a ``requests.Session``. Call :meth:`close`, or use the
    client as a context manager (``with MoralcompassApiClient() as client:``),
    to release its pooled connections.
    """

    # Characters left unescaped when a value is embedded as a single URL path
    # segment: RFC 3986 sub-delims plus ":" and "@". Anything that would change
    # the path structure ("/", "?", "#", "%", ...) is percent-encoded.
    _PATH_SEGMENT_SAFE = "!$&'()*+,;=:@"

    def __init__(self, api_base_url: Optional[str] = None, timeout: int = 30, auth_token: Optional[str] = None):
        """
        Initialize the API client.

        Args:
            api_base_url: Optional explicit API base URL. If None, will auto-discover
                via ``get_api_base_url()``. Any trailing "/" is stripped.
            timeout: Request timeout in seconds (default: 30)
            auth_token: Optional JWT authentication token. If None, will try to get it from
                the environment, then to generate one from username/password env vars.
        """
        self.api_base_url = (api_base_url or get_api_base_url()).rstrip("/")
        self.timeout = timeout
        self.auth_token = auth_token or self._get_auth_token_from_env()

        # Auto-generate JWT if no token found but credentials available
        if not self.auth_token:
            self._auto_generate_jwt_if_possible()

        self.session = self._create_session()
        logger.info(f"MoralcompassApiClient initialized with base URL: {self.api_base_url}")

    # ========================================================================
    # Resource management
    # ========================================================================

    def close(self) -> None:
        """Close the underlying HTTP session and release its pooled connections."""
        self.session.close()

    def __enter__(self) -> "MoralcompassApiClient":
        """Return the client itself so it can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the session on leaving the ``with`` block; exceptions are not suppressed."""
        self.close()

    # ========================================================================
    # Internal helpers
    # ========================================================================

    @classmethod
    def _quote_segment(cls, value: Any) -> str:
        """
        Percent-encode ``value`` for use as exactly one URL path segment.

        Without this, a "/", "?" or "#" inside a table id or username would
        split the path, start a query string, or be dropped as a fragment,
        silently addressing the wrong route.
        """
        from urllib.parse import quote

        return quote(str(value), safe=cls._PATH_SEGMENT_SAFE)

    @staticmethod
    def _table_meta_from_dict(data: Dict[str, Any]) -> MoralcompassTableMeta:
        """Build a MoralcompassTableMeta from one table object of an API response."""
        table_id = data["tableId"]
        return MoralcompassTableMeta(
            table_id=table_id,
            display_name=data.get("displayName", table_id),
            created_at=data.get("createdAt"),
            is_archived=data.get("isArchived", False),
            user_count=data.get("userCount", 0),
        )

    @staticmethod
    def _user_stats_from_dict(data: Dict[str, Any]) -> MoralcompassUserStats:
        """Build a MoralcompassUserStats from one user object of an API response."""
        return MoralcompassUserStats(
            username=data["username"],
            submission_count=data.get("submissionCount", 0),
            total_count=data.get("totalCount", 0),
            last_updated=data.get("lastUpdated"),
        )

    def _get_auth_token_from_env(self) -> Optional[str]:
        """
        Get authentication token from environment variables.

        Tries JWT_AUTHORIZATION_TOKEN first, then falls back to AWS_TOKEN.

        Returns:
            Optional[str]: Token or None if not found
        """
        try:
            from ..auth import get_primary_token
            return get_primary_token()
        except ImportError:
            # Fallback to direct environment variable access if auth module not available
            return os.getenv('JWT_AUTHORIZATION_TOKEN') or os.getenv('AWS_TOKEN')

    def _auto_generate_jwt_if_possible(self) -> None:
        """
        Attempt to auto-generate a JWT token if credentials are available.

        Checks for username/password environment variables and uses them to generate
        a JWT token via aimodelshare.modeluser.get_jwt_token if possible.

        Sets self.auth_token from JWT_AUTHORIZATION_TOKEN if that variable is present
        after the call. Failures are logged at debug level and otherwise ignored
        (best effort): API calls will surface authorization errors themselves.
        """
        # Check for username/password environment variables.
        # NOTE(review): the lowercase names are presumably set by aimodelshare's own
        # credential helpers. On Windows, environment lookups are case-insensitive,
        # so 'username' also matches the OS USERNAME variable. Confirm this fallback
        # is intended.
        username = os.getenv('AIMODELSHARE_USERNAME') or os.getenv('username')
        password = os.getenv('AIMODELSHARE_PASSWORD') or os.getenv('password')

        if not (username and password):
            logger.debug("Auto JWT generation skipped: No username/password credentials found in environment")
            return

        try:
            from aimodelshare.modeluser import get_jwt_token

            # Generate JWT token
            logger.debug(f"Attempting to auto-generate JWT token for user: {username[:3]}***")
            get_jwt_token(username, password)

            # get_jwt_token sets JWT_AUTHORIZATION_TOKEN in environment, retrieve it
            token = os.getenv('JWT_AUTHORIZATION_TOKEN')
            if token:
                self.auth_token = token
                # Never log token material, not even a prefix.
                logger.info("Auto-generated JWT token for moral_compass client.")
            else:
                logger.debug("JWT token generation completed but JWT_AUTHORIZATION_TOKEN not found in environment")

        except Exception as e:
            logger.debug(f"Auto JWT generation failed: {e}")
            # Continue without token - let the actual API calls handle authorization errors

    def _create_session(self) -> requests.Session:
        """
        Create a requests session with retry configuration.

        Returns:
            Configured requests.Session with retry adapter
        """
        session = requests.Session()

        # Configure retries for network errors and 5xx server errors
        retry_strategy = Retry(
            total=3,
            backoff_factor=1,  # exponential backoff between attempts
            status_forcelist=[500, 502, 503, 504],
            allowed_methods=["HEAD", "GET", "PUT", "PATCH", "POST", "DELETE", "OPTIONS"],
            # Return the final 5xx response once retries are exhausted instead of
            # raising RetryError, so _request can map it to ServerError.
            raise_on_status=False,
        )

        adapter = HTTPAdapter(max_retries=retry_strategy)
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        return session

    def _request(self, method: str, path: str, **kwargs) -> requests.Response:
        """
        Make an HTTP request with error handling and automatic auth header attachment.

        Args:
            method: HTTP method
            path: API path (without base URL); any leading "/" is ignored
            **kwargs: Additional arguments to pass to requests. A ``timeout``
                entry overrides ``self.timeout``; a ``headers`` dict is copied,
                never mutated.

        Returns:
            requests.Response object for a successful (non-error) response

        Raises:
            NotFoundError: If resource not found (404)
            ServerError: If server error (5xx), after retries are exhausted
            ApiClientError: For 401, other 4xx responses, timeouts, connection
                errors and any other request failure
        """
        url = f"{self.api_base_url}/{path.lstrip('/')}"

        # Add Authorization header if token is available. Copy the caller's
        # headers so their dict is not mutated, and tolerate headers=None.
        if self.auth_token:
            headers = dict(kwargs.get('headers') or {})
            headers['Authorization'] = f'Bearer {self.auth_token}'
            kwargs['headers'] = headers

        timeout = kwargs.pop("timeout", self.timeout)

        try:
            response = self.session.request(method, url, timeout=timeout, **kwargs)
        except requests.exceptions.Timeout as e:
            raise ApiClientError(f"Request timeout: {e}") from e
        except requests.exceptions.ConnectionError as e:
            raise ApiClientError(f"Connection error: {e}") from e
        except requests.exceptions.RequestException as e:
            raise ApiClientError(f"Request failed: {e}") from e

        # Handle specific error codes
        status = response.status_code
        if status == 401:
            auth_msg = "Authentication failed (401 Unauthorized)"
            if not self.auth_token:
                auth_msg += ". No authentication token provided. Set JWT_AUTHORIZATION_TOKEN environment variable or set AIMODELSHARE_USERNAME/AIMODELSHARE_PASSWORD for automatic JWT generation."
            else:
                # Do not echo any part of the token into the exception message.
                auth_msg += ". Token present but invalid or expired."
            raise ApiClientError(f"{auth_msg} | URL: {response.url} | Response: {response.text}")
        if status == 404:
            raise NotFoundError(f"Resource not found: {path} | body={response.text}")
        if 500 <= status < 600:
            raise ServerError(f"Server error {status}: {response.text}")

        # Remaining 4xx codes: keep the server's explanation in the message.
        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            raise ApiClientError(f"Request failed: {e} | body={response.text}") from e
        return response

    # ========================================================================
    # Health endpoint
    # ========================================================================

    def health(self) -> Dict[str, Any]:
        """
        Check API health status.

        Returns:
            Dict containing health status information
        """
        response = self._request("GET", "/health")
        return response.json()

    # ========================================================================
    # Table endpoints
    # ========================================================================

    def create_table(self, table_id: str, display_name: Optional[str] = None,
                     playground_url: Optional[str] = None) -> Dict[str, Any]:
        """
        Create a new table.

        Args:
            table_id: Unique identifier for the table
            display_name: Optional display name (omitted from the request when empty)
            playground_url: Optional playground URL for ownership and naming validation

        Returns:
            Dict containing creation response
        """
        payload = {"tableId": table_id}
        if display_name:
            payload["displayName"] = display_name
        if playground_url:
            payload["playgroundUrl"] = playground_url

        response = self._request("POST", "/tables", json=payload)
        return response.json()

    def create_table_for_playground(self, playground_url: str, suffix: str = '-mc',
                                    display_name: Optional[str] = None, region: Optional[str] = None) -> Dict[str, Any]:
        """
        Convenience method to create a moral compass table for a playground.

        Automatically derives the table ID from the playground URL and suffix.
        Supports region-aware table naming.

        Args:
            playground_url: URL of the playground
            suffix: Suffix for the table ID (default: '-mc')
            display_name: Optional display name
            region: Optional AWS region for region-aware naming (e.g., 'us-east-1').
                If provided, table ID will be <playgroundId>-<region><suffix>

        Returns:
            Dict containing creation response

        Raises:
            ValueError: If playground ID cannot be extracted from URL

        Examples:
            # Non-region-aware
            create_table_for_playground('https://example.com/playground/my-pg')
            # Creates table: my-pg-mc

            # Region-aware
            create_table_for_playground('https://example.com/playground/my-pg', region='us-east-1')
            # Creates table: my-pg-us-east-1-mc
        """
        from urllib.parse import urlparse

        # Extract playground ID from URL: the segment after "playground(s)",
        # otherwise the last path segment.
        parsed = urlparse(playground_url)
        path_parts = [p for p in parsed.path.split('/') if p]

        playground_id = None
        for i, part in enumerate(path_parts):
            if part.lower() in ['playground', 'playgrounds']:
                if i + 1 < len(path_parts):
                    playground_id = path_parts[i + 1]
                break

        if not playground_id and path_parts:
            # Fallback: use last path component
            playground_id = path_parts[-1]

        if not playground_id:
            raise ValueError(f"Could not extract playground ID from URL: {playground_url}")

        # Build table ID with optional region
        if region:
            table_id = f"{playground_id}-{region}{suffix}"
        else:
            table_id = f"{playground_id}{suffix}"

        if not display_name:
            region_suffix = f" ({region})" if region else ""
            display_name = f"Moral Compass - {playground_id}{region_suffix}"

        return self.create_table(table_id=table_id, display_name=display_name,
                                 playground_url=playground_url)

    def list_tables(self, limit: int = 50, last_key: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
        """
        List tables with pagination.

        Args:
            limit: Maximum number of tables to return (default: 50)
            last_key: Pagination key from previous response (sent JSON-encoded)

        Returns:
            Dict containing 'tables' list and optional 'lastKey' for pagination
        """
        params = {"limit": limit}
        if last_key:
            params["lastKey"] = json.dumps(last_key)

        response = self._request("GET", f"/tables?{urlencode(params)}")
        return response.json()

    def iter_tables(self, limit: int = 50) -> Iterator[MoralcompassTableMeta]:
        """
        Iterate over all tables with automatic pagination.

        Args:
            limit: Page size (default: 50)

        Yields:
            MoralcompassTableMeta objects
        """
        last_key = None

        while True:
            response = self.list_tables(limit=limit, last_key=last_key)
            for table_data in response.get("tables", []):
                yield self._table_meta_from_dict(table_data)

            # A missing/empty lastKey marks the final page.
            last_key = response.get("lastKey")
            if not last_key:
                break

    def get_table(self, table_id: str) -> MoralcompassTableMeta:
        """
        Get a specific table by ID.

        Args:
            table_id: The table identifier

        Returns:
            MoralcompassTableMeta object

        Raises:
            NotFoundError: If table not found
        """
        response = self._request("GET", f"/tables/{self._quote_segment(table_id)}")
        return self._table_meta_from_dict(response.json())

    def patch_table(self, table_id: str, display_name: Optional[str] = None,
                    is_archived: Optional[bool] = None) -> Dict[str, Any]:
        """
        Update table metadata.

        Only the fields that are not None are sent.

        Args:
            table_id: The table identifier
            display_name: Optional new display name
            is_archived: Optional archive status

        Returns:
            Dict containing update response
        """
        payload = {}
        if display_name is not None:
            payload["displayName"] = display_name
        if is_archived is not None:
            payload["isArchived"] = is_archived

        response = self._request("PATCH", f"/tables/{self._quote_segment(table_id)}", json=payload)
        return response.json()

    def delete_table(self, table_id: str) -> Dict[str, Any]:
        """
        Delete a table and all associated data.

        Requires owner or admin authorization when AUTH_ENABLED=true.
        Only works when ALLOW_TABLE_DELETE=true on server.

        Args:
            table_id: The table identifier

        Returns:
            Dict containing deletion confirmation

        Raises:
            NotFoundError: If table not found
            ApiClientError: If deletion not allowed or authorization fails
        """
        response = self._request("DELETE", f"/tables/{self._quote_segment(table_id)}")
        return response.json()

    # ========================================================================
    # User endpoints
    # ========================================================================

    def list_users(self, table_id: str, limit: int = 50,
                   last_key: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
        """
        List users in a table with pagination.

        Args:
            table_id: The table identifier
            limit: Maximum number of users to return (default: 50)
            last_key: Pagination key from previous response (sent JSON-encoded)

        Returns:
            Dict containing 'users' list and optional 'lastKey' for pagination
        """
        params = {"limit": limit}
        if last_key:
            params["lastKey"] = json.dumps(last_key)

        response = self._request(
            "GET", f"/tables/{self._quote_segment(table_id)}/users?{urlencode(params)}"
        )
        return response.json()

    def iter_users(self, table_id: str, limit: int = 50) -> Iterator[MoralcompassUserStats]:
        """
        Iterate over all users in a table with automatic pagination.

        Args:
            table_id: The table identifier
            limit: Page size (default: 50)

        Yields:
            MoralcompassUserStats objects
        """
        last_key = None

        while True:
            response = self.list_users(table_id, limit=limit, last_key=last_key)
            for user_data in response.get("users", []):
                yield self._user_stats_from_dict(user_data)

            # A missing/empty lastKey marks the final page.
            last_key = response.get("lastKey")
            if not last_key:
                break

    def get_user(self, table_id: str, username: str) -> MoralcompassUserStats:
        """
        Get a specific user's stats in a table.

        Args:
            table_id: The table identifier
            username: The username

        Returns:
            MoralcompassUserStats object

        Raises:
            NotFoundError: If user or table not found
        """
        response = self._request(
            "GET",
            f"/tables/{self._quote_segment(table_id)}/users/{self._quote_segment(username)}",
        )
        return self._user_stats_from_dict(response.json())

    def put_user(self, table_id: str, username: str,
                 submission_count: int, total_count: int) -> Dict[str, Any]:
        """
        Create or update a user's stats in a table.

        Args:
            table_id: The table identifier
            username: The username
            submission_count: Number of submissions
            total_count: Total count

        Returns:
            Dict containing update response
        """
        payload = {
            "submissionCount": submission_count,
            "totalCount": total_count,
        }

        response = self._request(
            "PUT",
            f"/tables/{self._quote_segment(table_id)}/users/{self._quote_segment(username)}",
            json=payload,
        )
        return response.json()

    def update_moral_compass(self, table_id: str, username: str,
                             metrics: Dict[str, float],
                             tasks_completed: int = 0,
                             total_tasks: int = 0,
                             questions_correct: int = 0,
                             total_questions: int = 0,
                             primary_metric: Optional[str] = None) -> Dict[str, Any]:
        """
        Update a user's moral compass score with dynamic metrics.

        Sends to ``.../moral-compass`` first; if that answers 404 with a body
        mentioning "route not found", retries once on the legacy
        ``.../moralcompass`` route. Other 404s (e.g. missing table/user) are
        re-raised unchanged.

        Args:
            table_id: The table identifier
            username: The username
            metrics: Dictionary of metric_name -> numeric_value
            tasks_completed: Number of tasks completed (default: 0)
            total_tasks: Total number of tasks (default: 0)
            questions_correct: Number of questions answered correctly (default: 0)
            total_questions: Total number of questions (default: 0)
            primary_metric: Optional primary metric name (omitted from the request when None)

        Returns:
            Dict containing moralCompassScore and other fields

        Raises:
            NotFoundError: If the table or user is not found
        """
        payload = {
            "metrics": metrics,
            "tasksCompleted": tasks_completed,
            "totalTasks": total_tasks,
            "questionsCorrect": questions_correct,
            "totalQuestions": total_questions,
        }

        if primary_metric is not None:
            payload["primaryMetric"] = primary_metric

        user_path = f"/tables/{self._quote_segment(table_id)}/users/{self._quote_segment(username)}"

        # Try hyphenated path first
        try:
            response = self._request("PUT", f"{user_path}/moral-compass", json=payload)
            return response.json()
        except NotFoundError as e:
            # If route not found, retry with legacy path (no hyphen)
            if "route not found" in str(e).lower():
                logger.warning(f"Hyphenated path failed with 404, retrying with legacy path: {e}")
                response = self._request("PUT", f"{user_path}/moralcompass", json=payload)
                return response.json()
            # Resource-level 404 (e.g., table or user not found), don't retry
            raise