@aws/ml-container-creator 1.0.0 → 1.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,15 +8,29 @@ Subcommands:
8
8
  create-mpg - Create a Model Package Group (idempotent)
9
9
  register-model - Register a model as a versioned Model Package
10
10
  register-adapter - Register an adapter as a versioned Model Package linked to base model
11
+ register-dataset - Register a dataset with content-aware versioning
12
+ resolve-dataset - Resolve a dataset by name (with optional version)
11
13
 
12
14
  Uses sagemaker-core ModelPackageGroup and ModelPackage resource APIs (SDK v3).
13
15
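  No boto3 sagemaker client per NFR-3, with one exception: _register_to_hub() calls boto3 create_hub_content directly because DataSet.create() cannot target a hub by name.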
14
16
 
15
17
  All output is JSON on stdout for bash consumption.
16
18
  Diagnostic messages go to stderr.
19
+
20
+ Dataset Versioning (F4 Research Spike Findings):
21
+ - DataSet.create() in sagemaker.ai_registry.dataset does NOT accept a `hub_content_version`
22
+ parameter directly. The API signature is: DataSet.create(name=, source=, customization_technique=).
23
+ - DataSet.get() does NOT accept a version filter — it retrieves by name only (latest).
24
+ - There is no `list_hub_content_versions` equivalent for DataSet objects.
25
+ - Conclusion: Native versioning is NOT supported via the DataSet API.
26
+ - Implementation approach: Use description field to embed hash (`[hash:<hex>] description`)
27
+ and local JSON registry with `versions[]` array for version tracking.
28
+ - Multipart S3 ETags (format: `hash-parts`) are not true content hashes but serve as
29
+ change-detection proxies. This is documented and acceptable per design.
17
30
  """
18
31
 
19
32
  import argparse
33
+ import hashlib
20
34
  import json
21
35
  import logging
22
36
  import os
@@ -133,7 +147,7 @@ def _build_adapter_metadata(args):
133
147
  """Build customer_metadata_properties dict for adapter registration.
134
148
 
135
149
  Includes all standard fields plus adapter-specific fields (AC-2.2):
136
- isAdapter, parentModelVersionArn, tuneTechnique, datasetS3Uri.
150
+ isAdapter, parentModelVersionArn, tuneTechnique, datasetS3Uri, datasetVersion.
137
151
  """
138
152
  props = {
139
153
  "deploymentConfig": args.deployment_config or "",
@@ -152,6 +166,11 @@ def _build_adapter_metadata(args):
152
166
  "datasetS3Uri": args.dataset_s3_uri or "",
153
167
  }
154
168
 
169
+ # Include dataset version lineage if available (AC-2.7)
170
+ dataset_version = getattr(args, "dataset_version", "") or ""
171
+ if dataset_version:
172
+ props["datasetVersion"] = dataset_version
173
+
155
174
  return _truncate_metadata(props)
156
175
 
157
176
 
@@ -281,7 +300,10 @@ def cmd_register_model(args):
281
300
  "ModelPackageDescription": description,
282
301
  "ModelApprovalStatus": "Approved",
283
302
  }
284
- if container_image:
303
+ # Only include InferenceSpecification if container image is a valid ECR URI.
304
+ # Non-ECR images (e.g., vllm/vllm-openai:v0.20.2 from DockerHub) cause
305
+ # ValidationException: "Provided image is not a valid ECR image."
306
+ if container_image and ".dkr.ecr." in container_image:
285
307
  create_params["InferenceSpecification"] = {
286
308
  "Containers": [{"Image": container_image}],
287
309
  "SupportedContentTypes": ["application/json"],
@@ -432,7 +454,10 @@ def cmd_register_adapter(args):
432
454
  "ModelPackageDescription": description,
433
455
  "ModelApprovalStatus": "Approved",
434
456
  }
435
- if container_image:
457
+ # Only include InferenceSpecification if container image is a valid ECR URI.
458
+ # Non-ECR images (e.g., vllm/vllm-openai:v0.20.2 from DockerHub) cause
459
+ # ValidationException: "Provided image is not a valid ECR image."
460
+ if container_image and ".dkr.ecr." in container_image:
436
461
  create_params["InferenceSpecification"] = {
437
462
  "Containers": [{"Image": container_image}],
438
463
  "SupportedContentTypes": ["application/json"],
@@ -471,6 +496,7 @@ def cmd_register_adapter(args):
471
496
  # TODO: Once an evaluator registry API is available, upgrade evaluators too.
472
497
 
473
498
  _REGISTRY_DIR = os.path.join(os.path.expanduser("~"), ".ml-container-creator")
499
+ _CONFIG_PATH = os.path.join(_REGISTRY_DIR, "config.json")
474
500
  _DATASETS_REGISTRY = os.path.join(_REGISTRY_DIR, "datasets.json")
475
501
  _EVALUATORS_REGISTRY = os.path.join(_REGISTRY_DIR, "evaluators.json")
476
502
 
@@ -487,6 +513,146 @@ def _check_ai_registry():
487
513
  return False
488
514
 
489
515
 
516
def _get_hub_name_from_profile(region=None):
    """Read aiRegistryHubName from the bootstrap profile config.

    Loads the JSON file at ``_CONFIG_PATH`` and inspects its ``profiles``
    mapping (profile key format: ``<region>-<accountId>``).

    Lookup order:
      1. If ``region`` is given, the first profile whose key starts with
         that region and that has a non-empty ``aiRegistryHubName``.
      2. If no region is given or no such profile exists, the first profile
         (any region) with a non-empty ``aiRegistryHubName``.

    Args:
        region: AWS region to match against profile keys, or None.

    Returns:
        Hub name string (e.g., "mlcc-registry-123456789012"), or None if the
        config file is missing, unreadable, not valid JSON, not shaped as
        expected, or contains no profile with a hub name.
    """
    try:
        with open(_CONFIG_PATH) as f:
            config = json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: covers both
        # json.JSONDecodeError and UnicodeDecodeError (both subclass it).
        return None

    # The file may have been hand-edited; a top-level list/string or a
    # non-object "profiles" value must yield "not found", not a crash.
    if not isinstance(config, dict):
        return None
    profiles = config.get("profiles")
    if not isinstance(profiles, dict) or not profiles:
        return None

    def _hub_name_of(profile_data):
        """Return the profile's non-empty hub name, or None."""
        if not isinstance(profile_data, dict):
            return None
        return profile_data.get("aiRegistryHubName") or None

    # Prefer a profile matching the requested region.
    if region:
        for profile_key, profile_data in profiles.items():
            if not isinstance(profile_key, str) or not profile_key.startswith(region):
                continue
            hub_name = _hub_name_of(profile_data)
            if hub_name:
                return hub_name

    # Fallback: the first profile that has an aiRegistryHubName.
    for profile_data in profiles.values():
        hub_name = _hub_name_of(profile_data)
        if hub_name:
            return hub_name

    return None
559
+
560
+
561
def _register_to_hub(hub_name, name, s3_uri, technique, description, region):
    """Register a dataset as "Dataset" hub content in an explicitly named hub.

    The hub is always targeted by name; nothing here relies on SDK
    auto-discovery (AC-2.2). Registration goes through the boto3 SageMaker
    ``create_hub_content`` call (AC-2.4).

    Research note kept from the original implementation: the
    sagemaker.ai_registry.dataset.DataSet.create() signature is
    DataSet.create(name=, source=, customization_technique=, description=);
    it accepts no hub_name / hub_config parameter and there is no documented
    env var or session config to override the target hub, which is why boto3
    is used directly.

    Args:
        hub_name: The hub name to target (e.g., "mlcc-registry-123456789012")
        name: Dataset name (used as HubContentName)
        s3_uri: S3 URI of the dataset
        technique: Tuning technique string (e.g., "sft"); defaults to "sft"
            when empty
        description: Dataset description (may contain hash tag); omitted from
            the request when empty
        region: AWS region for the SageMaker client

    Returns:
        str: Hub content ARN on success. May be an empty string when the
            content already exists but its ARN could not be retrieved.
        None: Registration failed; the caller should fall back to the local
            JSON registry.
    """
    # Create the client in its own try block so that `sm_client` is always
    # bound by the time the error handler below uses it.
    try:
        import boto3

        sm_client = boto3.client("sagemaker", region_name=region)
    except Exception as e:
        _warn(
            f"Failed to register dataset to hub '{hub_name}': {e}\n"
            " If this persists, run `ml-container-creator bootstrap` to verify hub provisioning.\n"
            " Falling back to local JSON registry."
        )
        return None

    try:
        # Build the document schema for the dataset hub content
        hub_content_document = json.dumps({
            "Source": s3_uri,
            "CustomizationTechnique": technique or "sft",
        })

        create_params = {
            "HubName": hub_name,
            "HubContentName": name,
            "HubContentType": "Dataset",
            "DocumentSchemaVersion": "1.0.0",
            "HubContentDocument": hub_content_document,
        }

        if description:
            create_params["HubContentDescription"] = description

        response = sm_client.create_hub_content(**create_params)
        hub_content_arn = response.get("HubContentArn", "")
        print(f"Registered dataset '{name}' to hub '{hub_name}' (ARN: {hub_content_arn})", file=sys.stderr)
        return hub_content_arn

    except Exception as e:
        # Errors are classified by message text (lower-cased).
        error_msg = str(e).lower()

        # Hub not found — clear actionable message (AC-2.5).
        # Parenthesized explicitly: the last alternative requires BOTH
        # "hub" and "not found" to appear.
        hub_missing = (
            "resourcenotfound" in error_msg
            or "resource not found" in error_msg
            or "does not exist" in error_msg
            or ("hub" in error_msg and "not found" in error_msg)
        )
        if hub_missing:
            _warn(
                f"Hub '{hub_name}' not found. "
                "Run `ml-container-creator bootstrap` to provision the AI Registry hub."
            )
            print(
                " Falling back to local JSON registry.",
                file=sys.stderr,
            )
            return None

        # Already exists — idempotent, treat as success
        if "already exists" in error_msg or "resourceinuse" in error_msg:
            print(f"Dataset '{name}' already exists in hub '{hub_name}' (idempotent)", file=sys.stderr)
            # Best effort: look up the ARN of the existing content.
            try:
                describe_resp = sm_client.describe_hub_content(
                    HubName=hub_name,
                    HubContentName=name,
                    HubContentType="Dataset",
                )
                return describe_resp.get("HubContentArn", "")
            except Exception:
                # Still a success from the caller's point of view (not None).
                return ""

        # Any other error — warn and fall back
        _warn(
            f"Failed to register dataset to hub '{hub_name}': {e}\n"
            " If this persists, run `ml-container-creator bootstrap` to verify hub provisioning.\n"
            " Falling back to local JSON registry."
        )
        return None
654
+
655
+
490
656
  def _ensure_registry_dir():
491
657
  """Create the registry directory if it doesn't exist."""
492
658
  os.makedirs(_REGISTRY_DIR, exist_ok=True)
@@ -511,16 +677,154 @@ def _save_registry(path, entries):
511
677
  json.dump(entries, f, indent=2)
512
678
 
513
679
 
680
+ # ── Dataset Versioning Helpers ─────────────────────────────────────────────────
681
+
682
+
683
+ def _parse_s3_uri(s3_uri):
684
+ """Parse an S3 URI into (bucket, key) tuple.
685
+
686
+ Args:
687
+ s3_uri: S3 URI in format s3://bucket/key or s3://bucket/prefix/
688
+
689
+ Returns:
690
+ Tuple of (bucket, key)
691
+ """
692
+ if not s3_uri.startswith("s3://"):
693
+ raise ValueError(f"Invalid S3 URI: {s3_uri}")
694
+ parts = s3_uri[5:].split("/", 1)
695
+ bucket = parts[0]
696
+ key = parts[1] if len(parts) > 1 else ""
697
+ return bucket, key
698
+
699
+
700
+ def _is_s3_prefix(key):
701
+ """Determine if an S3 key represents a prefix (directory) vs single file.
702
+
703
+ Heuristic: ends with '/' or has no file extension in the last path component.
704
+ """
705
+ if not key or key.endswith("/"):
706
+ return True
707
+ last_part = key.rstrip("/").split("/")[-1]
708
+ return "." not in last_part
709
+
710
+
711
def _compute_content_hash(s3_uri, region):
    """Compute a content hash for a dataset at an S3 URI.

    Single file: S3 ETag (truncated to 16 chars). For non-multipart uploads,
    the ETag is the MD5 of the content. For multipart uploads, ETag is in
    format `hash-parts` — not a true content hash but serves as a
    change-detection proxy.

    Directory/prefix: collect a "key:etag" string for every object under the
    prefix, sort them, join with newlines, then SHA256 the result. Truncated
    to 16 hex chars.

    Whether the URI is a file or a prefix is decided by ``_is_s3_prefix``.
    If a presumed prefix contains no objects, the key is retried as a single
    object (it may be an object path without a file extension).

    Args:
        s3_uri: S3 URI (s3://bucket/key or s3://bucket/prefix/)
        region: AWS region for the S3 client

    Returns:
        16-character hex hash string

    Raises:
        ValueError: If the URI is invalid, or it names a bucket root that
            contains no objects.
        Exception: Errors raised by the boto3 S3 client propagate unchanged.
    """
    import boto3

    s3 = boto3.client("s3", region_name=region)
    bucket, key = _parse_s3_uri(s3_uri)

    if not _is_s3_prefix(key):
        # Single file — use ETag directly
        head = s3.head_object(Bucket=bucket, Key=key)
        return head["ETag"].strip('"')[:16]

    # Prefix/directory — list and hash all objects.
    # An empty key means the bucket root and must be listed with an empty
    # prefix; appending "/" would only match keys that literally start
    # with "/".
    if not key or key.endswith("/"):
        prefix = key
    else:
        prefix = key + "/"

    paginator = s3.get_paginator("list_objects_v2")
    etags = []
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get("Contents", []):
            etag = obj["ETag"].strip('"')
            etags.append(f"{obj['Key']}:{etag}")

    if not etags:
        if not key:
            # Nothing to fall back to: head_object cannot take an empty key.
            raise ValueError(f"No objects found at {s3_uri}")
        # Try as a single object (might be an object path without extension)
        head = s3.head_object(Bucket=bucket, Key=key)
        return head["ETag"].strip('"')[:16]

    # Sort so the hash does not depend on listing order.
    etags.sort()
    return hashlib.sha256("\n".join(etags).encode()).hexdigest()[:16]
752
+
753
+
754
def _get_latest_version(name):
    """Look up the most recent version of a named dataset in the local registry.

    Only the first registry entry whose "name" equals ``name`` is considered.
    The last element of its ``versions`` list is taken as the latest version.
    An entry with no (or an empty) ``versions`` list is a legacy entry and is
    reported as version "1.0.0" with no hash (NFR-3).

    Args:
        name: Dataset name to look up

    Returns:
        dict with keys: version (str), hash (str|None), ordinal (int, the
        number of recorded versions), or None if the dataset is not found
    """
    record = next(
        (item for item in _load_registry(_DATASETS_REGISTRY) if item.get("name") == name),
        None,
    )
    if record is None:
        return None

    history = record.get("versions")
    if not history:
        # Legacy entry without versions — treat as v1.0.0 with hash=null
        return {"version": "1.0.0", "hash": None, "ordinal": 1}

    newest = history[-1]
    return {
        "version": newest.get("version", "1.0.0"),
        "hash": newest.get("hash"),
        "ordinal": len(history),
    }
789
+
790
+
791
+ def _increment_version(version_str):
792
+ """Increment a semver-like version string (minor bump).
793
+
794
+ 1.0.0 → 1.1.0, 1.1.0 → 1.2.0, etc.
795
+
796
+ Args:
797
+ version_str: Current version string (e.g., "1.0.0")
798
+
799
+ Returns:
800
+ New version string with minor incremented
801
+ """
802
+ parts = version_str.split(".")
803
+ if len(parts) != 3:
804
+ return "1.1.0"
805
+ major, minor, patch = int(parts[0]), int(parts[1]), int(parts[2])
806
+ return f"{major}.{minor + 1}.{patch}"
807
+
808
+
514
809
  # ── Subcommand: register-dataset ─────────────────────────────────────────────
515
810
 
516
811
 
517
812
  def cmd_register_dataset(args):
518
- """Register a dataset into SageMaker AI Registry (preferred) or local registry (fallback).
813
+ """Register a dataset with content-aware versioning.
814
+
815
+ Version logic (AC-1.1 through AC-1.8):
816
+ 1. Compute content hash of the S3 URI
817
+ 2. Look up latest version for this name
818
+ 3. If no existing entry → create version 1.0.0
819
+ 4. If hash matches latest → skip (print "Dataset unchanged (v{N})")
820
+ 5. If hash differs → create new version (minor increment)
821
+ 6. --force flag bypasses hash comparison (always creates new version)
519
822
 
520
823
  Uses sagemaker.ai_registry.dataset.DataSet API (SDK v3) when available.
521
- Falls back to local JSON registry if the API is not installed (Backlog #023).
824
+ Falls back to local JSON registry if the API is not installed.
522
825
 
523
- Returns JSON: {"name": str, "s3_uri": str, "format": str, "technique": str, "arn": str|null, "registered": bool}
826
+ Returns JSON: {"name": str, "s3_uri": str, "format": str, "technique": str,
827
+ "version": str, "hash": str|null, "arn": str|null, "registered": bool, "skipped": bool}
524
828
  """
525
829
  name = args.name
526
830
  s3_uri = args.s3_uri
@@ -529,6 +833,7 @@ def cmd_register_dataset(args):
529
833
  row_count = args.row_count
530
834
  column_schema = args.column_schema
531
835
  project_name = args.project_name or ""
836
+ force = getattr(args, "force", False)
532
837
 
533
838
  # Set region before any sagemaker import (creates boto3 clients at import time)
534
839
  region = getattr(args, 'region', None) or os.environ.get('AWS_DEFAULT_REGION') or os.environ.get('AWS_REGION')
@@ -548,98 +853,187 @@ def cmd_register_dataset(args):
548
853
  except json.JSONDecodeError:
549
854
  _error_exit("--column-schema must be valid JSON", code="INVALID_ARGUMENT")
550
855
 
551
- # Try SageMaker AI Registry API first (Backlog #023)
552
- if _check_ai_registry():
856
+ # ── Step 1: Compute content hash (AC-1.5) ─────────────────────────────────
857
+ content_hash = None
858
+ if region:
553
859
  try:
554
- from sagemaker.ai_registry.dataset import DataSet
555
- from sagemaker.ai_registry.dataset import CustomizationTechnique
556
-
557
- # Map technique string to enum
558
- technique_enum = None
559
- technique_map = {t.name.lower(): t for t in CustomizationTechnique}
560
- if technique.lower() in technique_map:
561
- technique_enum = technique_map[technique.lower()]
562
-
563
- print(f"Registering dataset '{name}' via SageMaker AI Registry...", file=sys.stderr)
564
- dataset = DataSet.create(
565
- name=name,
566
- source=s3_uri,
567
- customization_technique=technique_enum,
568
- )
569
- dataset_arn = dataset.arn
570
-
571
- # Also write to local registry for offline fallback
572
- _write_dataset_to_local_registry(
573
- name=name, s3_uri=s3_uri, data_format=data_format,
574
- technique=technique, row_count=row_count,
575
- column_schema=column_schema, project_name=project_name,
576
- arn=dataset_arn,
577
- )
860
+ content_hash = _compute_content_hash(s3_uri, region)
861
+ print(f"Content hash: {content_hash}", file=sys.stderr)
862
+ except Exception as e:
863
+ _warn(f"Could not compute content hash: {e}. Proceeding without hash.")
864
+ else:
865
+ _warn("No region specified skipping content hash computation.")
866
+
867
+ # ── Step 2: Get latest version (AC-1.2) ───────────────────────────────────
868
+ latest = _get_latest_version(name)
578
869
 
579
- print(f"Registered dataset '{name}' {s3_uri} (ARN: {dataset_arn})", file=sys.stderr)
870
+ # ── Step 3: Version decision (AC-1.3, AC-1.4, AC-1.7) ────────────────────
871
+ if latest is None:
872
+ # First registration — version 1.0.0 (AC-1.1)
873
+ new_version = "1.0.0"
874
+ ordinal = 1
875
+ print(f"First registration of '{name}' → v1 ({new_version})", file=sys.stderr)
876
+ else:
877
+ latest_hash = latest["hash"]
878
+ latest_version = latest["version"]
879
+ ordinal = latest["ordinal"]
880
+
881
+ if not force and content_hash is not None and latest_hash is not None and content_hash == latest_hash:
882
+ # Hash matches — skip (AC-1.3)
883
+ print(f"Dataset unchanged (v{ordinal})", file=sys.stderr)
580
884
  _output({
581
885
  "name": name,
582
886
  "s3_uri": s3_uri,
583
887
  "format": data_format,
584
888
  "technique": technique,
585
- "arn": dataset_arn,
586
- "registered": True,
889
+ "version": latest_version,
890
+ "hash": latest_hash,
891
+ "arn": None,
892
+ "registered": False,
893
+ "skipped": True,
587
894
  })
588
- except Exception as e:
589
- _warn(f"AI Registry registration failed: {e}. Falling back to local registry.")
590
- # Fall through to local registry below
895
+
896
+ # Hash differs or force create new version (AC-1.4, AC-1.7)
897
+ new_version = _increment_version(latest_version)
898
+ ordinal = ordinal + 1
899
+ if force:
900
+ print(f"Force re-registration of '{name}' → v{ordinal} ({new_version})", file=sys.stderr)
901
+ else:
902
+ print(f"Dataset changed — new version v{ordinal} ({new_version})", file=sys.stderr)
903
+
904
+ # ── Step 4: Register via AI Registry (preferred) ──────────────────────────
905
+ description = f"[hash:{content_hash}]" if content_hash else ""
906
+ dataset_arn = None
907
+
908
+ # ── Step 4a: Try hub-targeted registration (AC-2.1, AC-2.2) ───────────
909
+ hub_name = _get_hub_name_from_profile(region)
910
+
911
+ if hub_name:
912
+ # Hub name available in profile — target it explicitly (never auto-discover)
913
+ print(f"Targeting hub '{hub_name}' for dataset registration...", file=sys.stderr)
914
+ hub_arn = _register_to_hub(hub_name, name, s3_uri, technique, description, region)
915
+ if hub_arn is not None:
916
+ dataset_arn = hub_arn
917
+ else:
918
+ # Hub registration failed — fall back to local JSON only (AC-2.5)
919
+ print("Continuing with local JSON registry only.", file=sys.stderr)
591
920
  else:
921
+ # No hub name in profile (legacy/pre-bootstrap) — local JSON only (AC-2.3)
592
922
  _warn(
593
- "sagemaker.ai_registry.dataset.DataSet not available (older SDK). "
594
- "Using local registry fallback."
923
+ "No AI Registry hub configured in profile. "
924
+ "Using local JSON registry only.\n"
925
+ " To enable hub registration, run `ml-container-creator bootstrap`."
595
926
  )
596
927
 
597
- # Fallback: local JSON registry
598
- _write_dataset_to_local_registry(
928
+ # ── Step 5: Write to local registry with versioning (AC-1.8) ──────────────
929
+ _write_dataset_version_to_local_registry(
599
930
  name=name, s3_uri=s3_uri, data_format=data_format,
600
931
  technique=technique, row_count=row_count,
601
932
  column_schema=column_schema, project_name=project_name,
602
- arn=None,
933
+ arn=dataset_arn, version=new_version, content_hash=content_hash,
603
934
  )
604
935
 
605
- print(f"Registered dataset '{name}' → {s3_uri} (local registry)", file=sys.stderr)
936
+ print(f"Registered dataset '{name}' v{ordinal} ({new_version}) → {s3_uri}", file=sys.stderr)
606
937
  _output({
607
938
  "name": name,
608
939
  "s3_uri": s3_uri,
609
940
  "format": data_format,
610
941
  "technique": technique,
611
- "arn": None,
942
+ "version": new_version,
943
+ "hash": content_hash,
944
+ "arn": dataset_arn,
612
945
  "registered": True,
946
+ "skipped": False,
613
947
  })
614
948
 
615
949
 
616
- def _write_dataset_to_local_registry(*, name, s3_uri, data_format, technique,
617
- row_count, column_schema, project_name, arn):
618
- """Write a dataset entry to the local JSON registry (for offline fallback)."""
950
+ def _write_dataset_version_to_local_registry(*, name, s3_uri, data_format, technique,
951
+ row_count, column_schema, project_name,
952
+ arn, version, content_hash):
953
+ """Write a versioned dataset entry to the local JSON registry.
954
+
955
+ Schema (AC-1.8, backward compatible):
956
+ - Each dataset has a `versions[]` array
957
+ - Existing entries without `versions` are treated as v1.0.0 with hash=null (NFR-3)
958
+ - New versions are appended to the array
959
+
960
+ Args:
961
+ name: Dataset name
962
+ s3_uri: S3 URI of the dataset
963
+ data_format: Format (jsonl/parquet/csv)
964
+ technique: Tuning technique
965
+ row_count: Number of rows (optional)
966
+ column_schema: Column schema JSON string (optional)
967
+ project_name: Project name for context
968
+ arn: AI Registry ARN (if registered there)
969
+ version: Version string (e.g., "1.0.0")
970
+ content_hash: Content hash string (16-char hex) or None
971
+ """
619
972
  import datetime
620
973
 
621
974
  entries = _load_registry(_DATASETS_REGISTRY)
622
975
 
623
- entry = {
624
- "name": name,
976
+ now = datetime.datetime.now(datetime.timezone.utc).isoformat().replace("+00:00", "Z")
977
+
978
+ version_entry = {
979
+ "version": version,
625
980
  "s3_uri": s3_uri,
626
- "format": data_format,
981
+ "hash": content_hash,
627
982
  "technique": technique,
628
- "row_count": row_count,
629
- "column_schema": column_schema,
630
- "project_name": project_name,
631
- "arn": arn,
632
- "registered_at": datetime.datetime.now(datetime.timezone.utc).isoformat().replace("+00:00", "Z"),
983
+ "rows": row_count,
984
+ "registered_at": now,
633
985
  }
634
986
 
635
- # Upsert: replace existing entry with same name, or append
636
- updated = False
987
+ # Find existing entry for this name
988
+ found = False
637
989
  for i, existing in enumerate(entries):
638
990
  if existing.get("name") == name:
639
- entries[i] = entry
640
- updated = True
991
+ found = True
992
+ # Migrate legacy entry (no versions array) to new schema
993
+ if "versions" not in existing:
994
+ legacy_version = {
995
+ "version": "1.0.0",
996
+ "s3_uri": existing.get("s3_uri", ""),
997
+ "hash": None,
998
+ "technique": existing.get("technique", ""),
999
+ "rows": existing.get("row_count"),
1000
+ "registered_at": existing.get("registered_at", now),
1001
+ }
1002
+ existing["versions"] = [legacy_version]
1003
+
1004
+ # Append new version
1005
+ existing["versions"].append(version_entry)
1006
+
1007
+ # Update top-level fields to reflect latest
1008
+ existing["s3_uri"] = s3_uri
1009
+ existing["format"] = data_format
1010
+ existing["technique"] = technique
1011
+ existing["row_count"] = row_count
1012
+ existing["column_schema"] = column_schema
1013
+ existing["project_name"] = project_name
1014
+ existing["arn"] = arn
1015
+ existing["registered_at"] = now
1016
+ existing["latest_version"] = version
1017
+ existing["content_hash"] = content_hash
1018
+ entries[i] = existing
641
1019
  break
642
- if not updated:
1020
+
1021
+ if not found:
1022
+ # New dataset entry
1023
+ entry = {
1024
+ "name": name,
1025
+ "s3_uri": s3_uri,
1026
+ "format": data_format,
1027
+ "technique": technique,
1028
+ "row_count": row_count,
1029
+ "column_schema": column_schema,
1030
+ "project_name": project_name,
1031
+ "arn": arn,
1032
+ "registered_at": now,
1033
+ "latest_version": version,
1034
+ "content_hash": content_hash,
1035
+ "versions": [version_entry],
1036
+ }
643
1037
  entries.append(entry)
644
1038
 
645
1039
  _save_registry(_DATASETS_REGISTRY, entries)
@@ -649,16 +1043,88 @@ def _write_dataset_to_local_registry(*, name, s3_uri, data_format, technique,
649
1043
 
650
1044
 
651
1045
def cmd_list_datasets(args):
    """Emit every registered dataset with a per-dataset version summary (AC-3.1).

    Reads the local datasets registry, optionally keeps only the entries whose
    top-level ``technique`` equals ``args.technique``, and adds two fields to
    a copy of each entry:
      - version_count: number of items in ``versions`` (1 for legacy entries)
      - latest_version: version of the last item in ``versions``; for legacy
        entries the stored ``latest_version`` if present, else "1.0.0" (NFR-3)

    Returns JSON: {"datasets": [{..., "version_count": int, "latest_version": str}, ...]}
    """
    def _with_version_summary(record):
        summary = dict(record)
        history = record.get("versions", [])
        if history:
            summary["version_count"] = len(history)
            summary["latest_version"] = history[-1].get("version", "1.0.0")
        else:
            # Legacy entry without versions array — treat as v1.0.0
            summary["version_count"] = 1
            summary["latest_version"] = summary.get("latest_version", "1.0.0")
        return summary

    records = _load_registry(_DATASETS_REGISTRY)

    # Optional technique filter
    wanted = getattr(args, 'technique', None)
    if wanted:
        records = [r for r in records if r.get('technique') == wanted]

    _output({"datasets": [_with_version_summary(r) for r in records]})
1075
+
1076
+
1077
+ # ── Subcommand: list-dataset-versions ─────────────────────────────────────────
1078
+
1079
+
1080
def cmd_list_dataset_versions(args):
    """List all versions for a specific dataset by name (AC-3.3).

    Looks up the first local-registry entry whose "name" equals ``args.name``
    and emits its versions in a normalized shape. A legacy entry without a
    ``versions`` array is presented as a single v1.0.0 with hash=null, built
    from the entry's top-level fields (NFR-3).

    Args (via argparse):
        --name: Dataset name (required)

    Returns JSON: {"name": str, "versions": [{"version": str, "hash": str|null,
        "date": str, "rows": int|null, "s3_uri": str}, ...]}
    or a MISSING_ARGUMENT / DATASET_NOT_FOUND error.
    """
    name = args.name
    if not name:
        _error_exit("--name is required", code="MISSING_ARGUMENT")

    entries = _load_registry(_DATASETS_REGISTRY)

    for entry in entries:
        if entry.get("name") != name:
            continue

        versions = entry.get("versions", [])
        if not versions:
            # Legacy entry without versions array — present as single v1.0.0
            versions = [{
                "version": "1.0.0",
                "hash": None,
                "registered_at": entry.get("registered_at", ""),
                "rows": entry.get("row_count"),
                "s3_uri": entry.get("s3_uri", ""),
            }]

        # Normalize output format ("registered_at" is exposed as "date")
        result_versions = [
            {
                "version": v.get("version", "1.0.0"),
                "hash": v.get("hash"),
                "date": v.get("registered_at", ""),
                "rows": v.get("rows"),
                "s3_uri": v.get("s3_uri", ""),
            }
            for v in versions
        ]

        _output({
            "name": name,
            "versions": result_versions,
        })
        # Return explicitly so the success path can never fall through to
        # the not-found error below, whether or not _output() terminates
        # the process (same pattern as cmd_resolve_dataset).
        return

    _error_exit(f"Dataset not found: {name}", code="DATASET_NOT_FOUND")
662
1128
 
663
1129
 
664
1130
  # ── Subcommand: register-evaluator ───────────────────────────────────────────
@@ -941,18 +1407,30 @@ def cmd_get_version(args):
941
1407
 
942
1408
 
943
1409
  def cmd_resolve_dataset(args):
944
- """Resolve a registered dataset by name.
1410
+ """Resolve a registered dataset by name (with optional version pinning).
945
1411
 
946
1412
  Uses SageMaker AI Registry DataSet.get() when available, falls back to
947
1413
  local JSON registry. Includes ARN in output when available (Backlog #023).
948
1414
 
949
- Returns JSON: {"name": str, "s3_uri": str, "arn": str|null, "format": str, "technique": str, ...}
1415
+ Version resolution (AC-2.1, AC-2.4):
1416
+ - --version N: resolve the Nth version (ordinal, 1-based) for this name
1417
+ - No --version: resolve latest (existing behavior)
1418
+ - If requested version doesn't exist: print available versions and exit 1 (AC-2.5)
1419
+
1420
+ Returns JSON: {"name": str, "s3_uri": str, "arn": str|null, "format": str, "technique": str, "version": str|null, "ordinal": int|null}
950
1421
  or error if not found.
951
1422
  """
952
1423
  name = args.name
1424
+ version_ordinal = getattr(args, "version", None)
1425
+
953
1426
  if not name:
954
1427
  _error_exit("--name is required", code="MISSING_ARGUMENT")
955
1428
 
1429
+ # If version is specified, use version-aware resolution
1430
+ if version_ordinal is not None:
1431
+ return _resolve_dataset_version(name, version_ordinal)
1432
+
1433
+ # No version — resolve latest (existing behavior)
956
1434
  # Try SageMaker AI Registry API first
957
1435
  if _check_ai_registry():
958
1436
  try:
@@ -966,6 +1444,8 @@ def cmd_resolve_dataset(args):
966
1444
  "arn": dataset.arn if hasattr(dataset, 'arn') else None,
967
1445
  "format": "jsonl", # AI Registry may not store format
968
1446
  "technique": getattr(dataset, 'customization_technique', '').lower() if hasattr(dataset, 'customization_technique') else "",
1447
+ "version": None,
1448
+ "ordinal": None,
969
1449
  })
970
1450
  except Exception as e:
971
1451
  # AI Registry lookup failed — fall through to local registry
@@ -979,11 +1459,92 @@ def cmd_resolve_dataset(args):
979
1459
  output = dict(entry)
980
1460
  if "arn" not in output:
981
1461
  output["arn"] = None
1462
+ # Include latest version info if available
1463
+ versions = entry.get("versions")
1464
+ if versions and len(versions) > 0:
1465
+ latest = versions[-1]
1466
+ output["s3_uri"] = latest.get("s3_uri", output.get("s3_uri", ""))
1467
+ output["version"] = latest.get("version")
1468
+ output["ordinal"] = len(versions)
1469
+ else:
1470
+ output["version"] = None
1471
+ output["ordinal"] = None
982
1472
  _output(output)
1473
+ return
983
1474
 
984
1475
  _error_exit(f"Dataset not found: {name}", code="DATASET_NOT_FOUND")
985
1476
 
986
1477
 
1478
def _exit_version_not_found(name, version_ordinal, available, stderr_lines):
    """Report a missing dataset version and terminate with exit status 1.

    Writes the human-readable error (plus ``stderr_lines``) to stderr and a
    JSON error object with code ``VERSION_NOT_FOUND`` to stdout.

    Args:
        name: Dataset name that was looked up.
        version_ordinal: The 1-based ordinal that was requested.
        available: List of ``{"ordinal": int, "version": str}`` dicts that is
            embedded in the JSON error as ``available_versions``.
        stderr_lines: Extra diagnostic lines printed to stderr after the
            error line (the listing of available versions).
    """
    message = f"Version v{version_ordinal} not found for dataset '{name}'"
    print(f"Error: {message}", file=sys.stderr)
    for line in stderr_lines:
        print(line, file=sys.stderr)
    print(json.dumps({
        "error": message,
        "code": "VERSION_NOT_FOUND",
        "available_versions": available,
    }))
    sys.exit(1)


def _resolve_dataset_version(name, version_ordinal):
    """Resolve a specific version (by ordinal) of a named dataset.

    Ordinal is 1-based: @v1 = first registered version, @v2 = second, etc.
    Internally, versions may be semver strings (1.0.0, 1.1.0, 1.2.0).

    Only the local JSON registry is consulted. The first entry whose name
    matches is used; the resolved record is emitted through ``_output`` and
    the function returns immediately afterwards, so a successful resolution
    can never fall through into one of the error paths below.

    If the version doesn't exist, prints available versions and exits 1
    (AC-2.5). If no entry matches ``name``, reports ``DATASET_NOT_FOUND``
    through ``_error_exit``.

    Args:
        name: Dataset name
        version_ordinal: 1-based version ordinal (e.g., 2 for the 2nd version)
    """
    # Load local registry
    entries = _load_registry(_DATASETS_REGISTRY)

    for entry in entries:
        if entry.get("name") != name:
            continue

        versions = entry.get("versions", [])

        if not versions:
            # Legacy entry without versions array — treat as v1
            if version_ordinal != 1:
                _exit_version_not_found(
                    name,
                    version_ordinal,
                    [{"ordinal": 1, "version": "1.0.0"}],
                    ["Available versions: v1 (1.0.0)"],
                )
            output = dict(entry)
            output["version"] = "1.0.0"
            output["ordinal"] = 1
            if "arn" not in output:
                output["arn"] = None
            _output(output)
            # Stop here: without this return the empty versions list would
            # fail the range check below and report VERSION_NOT_FOUND.
            return

        # Check if requested ordinal is valid (1-based index)
        if version_ordinal < 1 or version_ordinal > len(versions):
            available = []
            listing = []
            for i, v in enumerate(versions, 1):
                ver_str = v.get("version", f"{i}.0.0")
                available.append({"ordinal": i, "version": ver_str})
                listing.append(f"  v{i} ({ver_str})")
            _exit_version_not_found(name, version_ordinal, available, listing)

        # Resolve the specific version (0-based index from 1-based ordinal);
        # per-version fields win, entry-level fields are the fallback.
        target_version = versions[version_ordinal - 1]
        _output({
            "name": name,
            "s3_uri": target_version.get("s3_uri", entry.get("s3_uri", "")),
            "arn": entry.get("arn"),
            "format": target_version.get("format", entry.get("format", "jsonl")),
            "technique": target_version.get("technique", entry.get("technique", "")),
            "version": target_version.get("version", "1.0.0"),
            "ordinal": version_ordinal,
            "hash": target_version.get("hash"),
        })
        # Stop here: otherwise the loop would finish and the function would
        # report DATASET_NOT_FOUND after a successful resolution.
        return

    # Dataset name not found at all
    _error_exit(f"Dataset not found: {name}", code="DATASET_NOT_FOUND")
1546
+
1547
+
987
1548
  # ── Subcommand: resolve-evaluator ────────────────────────────────────────────
988
1549
 
989
1550
 
@@ -1052,6 +1613,7 @@ def main():
1052
1613
  adapter_parser.add_argument("--parent-version-arn", required=True, help="Base model version ARN in the same MPG")
1053
1614
  adapter_parser.add_argument("--tune-technique", default="", help="Tune technique (sft/dpo/rlvr)")
1054
1615
  adapter_parser.add_argument("--dataset-s3-uri", default="", help="Training dataset S3 URI")
1616
+ adapter_parser.add_argument("--dataset-version", default="", help="Dataset version ordinal (lineage: trained on dataset X version N)")
1055
1617
  adapter_parser.add_argument("--deployment-config", default="", help="Deployment config (e.g., gpu-vllm)")
1056
1618
  adapter_parser.add_argument("--container-image", default="", help="Container image URI")
1057
1619
  adapter_parser.add_argument("--model-data-url", default="", help="Model/adapter data S3 URI")
@@ -1068,7 +1630,7 @@ def main():
1068
1630
  # ── register-dataset ─────────────────────────────────────────────────
1069
1631
  dataset_parser = subparsers.add_parser(
1070
1632
  "register-dataset",
1071
- help="Register a dataset into the local registry (AI Registry fallback)",
1633
+ help="Register a dataset with content-aware versioning",
1072
1634
  )
1073
1635
  dataset_parser.add_argument("--name", required=True, help="Dataset name (unique identifier)")
1074
1636
  dataset_parser.add_argument("--s3-uri", required=True, help="S3 URI of the dataset")
@@ -1080,6 +1642,9 @@ def main():
1080
1642
  dataset_parser.add_argument("--column-schema", default=None,
1081
1643
  help="Column schema as JSON string")
1082
1644
  dataset_parser.add_argument("--project-name", default=None, help="Project name for context")
1645
+ dataset_parser.add_argument("--region", default=None, help="AWS region (for S3 hash computation)")
1646
+ dataset_parser.add_argument("--force", action="store_true", default=False,
1647
+ help="Force new version even if content hash matches (AC-1.7)")
1083
1648
 
1084
1649
  # ── list-datasets ─────────────────────────────────────────────────────────
1085
1650
  list_datasets_parser = subparsers.add_parser(
@@ -1089,6 +1654,13 @@ def main():
1089
1654
  list_datasets_parser.add_argument("--technique", default=None, choices=["sft", "dpo", "rlaif", "rlvr"],
1090
1655
  help="Filter by tuning technique")
1091
1656
 
1657
+ # ── list-dataset-versions ─────────────────────────────────────────────
1658
+ list_dataset_versions_parser = subparsers.add_parser(
1659
+ "list-dataset-versions",
1660
+ help="List all versions for a specific dataset by name (AC-3.3)",
1661
+ )
1662
+ list_dataset_versions_parser.add_argument("--name", required=True, help="Dataset name to list versions for")
1663
+
1092
1664
  # ── register-evaluator ────────────────────────────────────────────────
1093
1665
  evaluator_parser = subparsers.add_parser(
1094
1666
  "register-evaluator",
@@ -1134,6 +1706,8 @@ def main():
1134
1706
  help="Resolve a registered dataset by name",
1135
1707
  )
1136
1708
  resolve_dataset_parser.add_argument("--name", required=True, help="Dataset name to resolve")
1709
+ resolve_dataset_parser.add_argument("--version", type=int, default=None,
1710
+ help="Version ordinal to resolve (e.g., 2 for the 2nd version). Default: latest.")
1137
1711
 
1138
1712
  # ── resolve-evaluator ─────────────────────────────────────────────────
1139
1713
  resolve_evaluator_parser = subparsers.add_parser(
@@ -1165,6 +1739,8 @@ def main():
1165
1739
  cmd_register_dataset(args)
1166
1740
  elif args.command == "list-datasets":
1167
1741
  cmd_list_datasets(args)
1742
+ elif args.command == "list-dataset-versions":
1743
+ cmd_list_dataset_versions(args)
1168
1744
  elif args.command == "register-evaluator":
1169
1745
  cmd_register_evaluator(args)
1170
1746
  elif args.command == "list-adapters":