mlx-raclate 0.1.0b1__py3-none-any.whl

@@ -0,0 +1,247 @@
+ from pathlib import Path
+ from typing import Dict, Optional, Tuple
+ from datasets import load_dataset as hf_load_dataset
+ from datasets import DatasetDict, ClassLabel, Sequence
+ from datasets import Dataset as HFDataset
+
+ class DatasetArgs:
+     """
+     Arguments for dataset loading.
+     If a remapping of column names is needed, specify the field names here:
+     - text_field : main text
+     - label_field : label / classification target
+     - text_pair_field : text pair (optional; for contrastive learning, sentence similarity, or sequence classification with two inputs)
+     - negative_field : negative example (optional; for triplet loss)
+     """
+     def __init__(self, data: str, task_type: str,
+                  text_field: Optional[str] = "text", label_field: Optional[str] = "label",
+                  text_pair_field: Optional[str] = None, negative_field: Optional[str] = None,
+                  test: Optional[bool] = False
+                  ):
+         self.data = data
+         self.task_type = task_type
+         self.text_field = text_field
+         self.label_field = label_field
+         self.text_pair_field = text_pair_field
+         self.negative_field = negative_field
+         self.test = test  # whether to create a test set if not present
+
+
+ def _standardize_column_names(dataset: HFDataset, args: DatasetArgs) -> HFDataset:
+     """
+     Renames columns to the standard 'text', 'label', 'text_pair', 'negative' names expected by collators.
+
+     Common mappings for similarity tasks:
+     - anchor / sentence A -> 'text'
+     - positive / reference / sentence B -> 'text_pair'
+     - hard negative / sentence C -> 'negative' (optional)
+     - similarity score for regression -> 'label' (optional)
+
+     Manual mappings can be specified via args using text_field, label_field, text_pair_field, negative_field:
+     text_field : column name for the main text input
+     label_field : column name for the label / score
+     text_pair_field (optional) : column name for the paired text input / sentence B (used for cross-encoders or bi-encoders)
+     negative_field (optional) : column name for the negative example (used for triplet training)
+     """
+
+     mapping = {}
+     # Manual field mappings
+     if args.text_field != "text" and args.text_field in dataset.column_names:
+         mapping[args.text_field] = "text"
+
+     if args.text_pair_field and args.text_pair_field != "text_pair" and args.text_pair_field in dataset.column_names:
+         mapping[args.text_pair_field] = "text_pair"
+
+     if args.label_field != "label" and args.label_field in dataset.column_names:
+         mapping[args.label_field] = "label"
+
+     if args.negative_field and args.negative_field != "negative" and args.negative_field in dataset.column_names:
+         mapping[args.negative_field] = "negative"
+
+     # Handle common alternative column names per task type
+     if args.task_type in ("sentence-similarity", "sentence-transformers"):
+         # Handle scored sentence pairs: "sentence1" -> "text", "sentence2" -> "text_pair", "score" -> "label"
+         if "sentence1" in dataset.column_names and "sentence2" in dataset.column_names and "score" in dataset.column_names:
+             mapping["sentence1"] = "text"
+             mapping["sentence2"] = "text_pair"
+             mapping["score"] = "label"
+
+         # Handle anchor, positive and negative columns for triplet training
+         if "anchor" in dataset.column_names and "positive" in dataset.column_names and "negative" in dataset.column_names:
+             mapping["anchor"] = "text"
+             mapping["positive"] = "text_pair"
+             mapping["negative"] = "negative"
+
+         if "pos" in dataset.column_names:
+             mapping["pos"] = "text_pair"
+         if "neg" in dataset.column_names:
+             mapping["neg"] = "negative"
+
+     # Handle token classification: usually "tokens" -> "text", "ner_tags" -> "labels"
+     if args.task_type == "token-classification":
+         if "tokens" in dataset.column_names and "text" not in mapping.values():
+             mapping["tokens"] = "text"
+         if "ner_tags" in dataset.column_names and "labels" not in mapping.values():
+             mapping["ner_tags"] = "labels"
+
+     if mapping:
+         dataset = dataset.rename_columns(mapping)
+
+     keep_columns = {"text", "text_pair", "label", "labels", "negative"}
+     existing_columns = set(dataset.column_names)
+     columns_to_select = list(keep_columns.intersection(existing_columns))
+
+     # Check that we have at least 'text'
+     if "text" not in columns_to_select:
+         print(f"Warning: Standard 'text' column not found in dataset columns: {dataset.column_names}")
+
+     dataset = dataset.select_columns(columns_to_select)
+
+     return dataset
+
+
+ def get_label_mapping(dataset: HFDataset, args: DatasetArgs) -> Tuple[Optional[Dict[int, str]], Optional[Dict[str, int]]]:
+     """
+     Derives id2label and label2id from a dataset.
+     Prioritizes dataset features (from the config), and falls back to scanning unique values in the data.
+     """
+     if args.task_type not in ["text-classification", "token-classification"]:
+         return None, None
+
+     # Determine the target column name based on the task
+     target_col = "labels" if args.task_type == "token-classification" else "label"
+     if target_col not in dataset.column_names:
+         # Fallback: sometimes text-classification uses 'labels' or vice versa
+         if "label" in dataset.column_names: target_col = "label"
+         elif "labels" in dataset.column_names: target_col = "labels"
+         else: return None, None
+
+     labels = []
+
+     # Strategy 1: check features (config / Hub metadata)
+     feature = dataset.features[target_col]
+
+     # Case A: standard ClassLabel (text classification)
+     if isinstance(feature, ClassLabel):
+         labels = feature.names
+
+     # Case B: sequence of ClassLabels (token classification)
+     elif isinstance(feature, Sequence) and isinstance(feature.feature, ClassLabel):
+         labels = feature.feature.names
+
+     # Strategy 2: scan the data (raw JSONL/CSV)
+     if not labels:
+         if len(dataset) > 0:
+             if args.task_type == "token-classification":
+                 # Flatten the list of lists to find unique tags
+                 unique_tags = set()
+                 for row in dataset[target_col]:
+                     unique_tags.update(row)
+                 labels = sorted(list(unique_tags))
+             else:
+                 # Standard text classification scan
+                 labels = sorted(list(set(dataset[target_col])))
+
+     if not labels:
+         return None, None
+
+     # Construct mappings
+     id2label = {k: str(v) for k, v in enumerate(labels)}
+     label2id = {str(v): k for k, v in enumerate(labels)}
+
+     return id2label, label2id
+
+
+ def load_dataset(args: DatasetArgs) -> Tuple[Optional[HFDataset], Optional[HFDataset], Optional[HFDataset], Optional[Dict[int, str]], Optional[Dict[str, int]]]:
+     if not hasattr(args, "task_type"):
+         raise ValueError("Must specify task_type in args")
+
+     supported_tasks = ["text-classification", "masked-lm", "token-classification", "sentence-transformers", "sentence-similarity"]
+     if args.task_type not in supported_tasks:
+         raise ValueError(f"Unsupported task type: {args.task_type}")
+
+     # Load from the Hub or a local path
+     data_path = Path(args.data)
+     if data_path.exists():
+         # Detect the format from the extension if it's a file, or assume a split structure if it's a folder
+         if data_path.is_file():
+             # Single file loading
+             ext = data_path.suffix[1:]  # remove the dot
+             ext = "json" if ext == "jsonl" else ext
+             raw_datasets = hf_load_dataset(ext, data_files=str(data_path))
+             # If it loaded as 'train' only, we split later
+         else:
+             # It's a directory. Check for split-specific files.
+             data_files = {}
+             for split in ["train", "validation", "test"]:
+                 for ext in ["jsonl", "json", "parquet", "csv"]:
+                     fname = f"{split}.{ext}"
+                     if (data_path / fname).exists():
+                         data_files[split] = str(data_path / fname)
+
+             if not data_files:
+                 raise ValueError(f"No train/val/test files found in {data_path}")
+
+             # Determine the loader type from the first file found
+             first_file = list(data_files.values())[0]
+             ext = first_file.split(".")[-1]
+             ext = "json" if ext == "jsonl" else ext
+             raw_datasets = hf_load_dataset(ext, data_files=data_files)
+
+     else:
+         # Load from the Hub
+         try:
+             raw_datasets = hf_load_dataset(args.data)
+         except Exception as e:
+             print(f"Failed to load as standard dataset: {e}. Trying simple load...")
+             raw_datasets = hf_load_dataset(args.data, split="train")
+             raw_datasets = DatasetDict({"train": raw_datasets})
+
+     if "train" not in raw_datasets:
+         raise ValueError("Training split not found in dataset")
+
+     # Handle splits: keep the ones provided, or carve out missing ones (targeting roughly 70/15/15)
+     if "validation" not in raw_datasets and "test" not in raw_datasets:
+         if args.test:
+             t_t_split = raw_datasets["train"].train_test_split(test_size=0.15, seed=42)
+             raw_datasets["test"] = t_t_split["test"]
+             # 0.176 of the remaining 85% is roughly 15% of the original data
+             t_v_split = t_t_split["train"].train_test_split(test_size=0.176, seed=42)
+             raw_datasets["train"] = t_v_split["train"]
+             raw_datasets["validation"] = t_v_split["test"]
+         else:  # create only a validation split
+             t_v_split = raw_datasets["train"].train_test_split(test_size=0.176, seed=42)
+             raw_datasets["train"] = t_v_split["train"]
+             raw_datasets["validation"] = t_v_split["test"]
+     elif "validation" not in raw_datasets and "test" in raw_datasets:
+         if args.test:
+             t_v_split = raw_datasets["train"].train_test_split(test_size=0.176, seed=42)
+             raw_datasets["train"] = t_v_split["train"]
+             raw_datasets["validation"] = t_v_split["test"]
+         else:  # use the test split as the validation split
+             raw_datasets["validation"] = raw_datasets["test"]
+             raw_datasets["test"] = None
+     elif "test" not in raw_datasets and args.test:
+         t_t_split = raw_datasets["train"].train_test_split(test_size=0.176, seed=42)
+         raw_datasets["train"] = t_t_split["train"]
+         raw_datasets["test"] = t_t_split["test"]
+
+     # Standardize columns
+     for split in raw_datasets.keys():
+         if raw_datasets[split] is not None:
+             print(f"Standardizing columns for split '{split}' ({len(raw_datasets[split])} examples)...")
+             raw_datasets[split] = _standardize_column_names(raw_datasets[split], args)
+
+     # Get label mappings if applicable
+     id2label, label2id = None, None
+     if raw_datasets.get("train") is not None:
+         id2label, label2id = get_label_mapping(raw_datasets["train"], args)
+
+     if id2label:
+         print(f"Found {len(id2label)} labels. First 5: {list(id2label.values())[:5]}")
+
+     return (
+         raw_datasets.get("train"),
+         raw_datasets.get("validation"),
+         raw_datasets.get("test"),
+         id2label,
+         label2id
+     )
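
For orientation, here is a minimal usage sketch of the loader above. The import path `mlx_raclate.data` is an assumption (the diff does not show file names), and `stanfordnlp/sst2` is only an illustrative Hub dataset whose `sentence` column needs remapping to the standard `text` column.

    from mlx_raclate.data import DatasetArgs, load_dataset  # assumed import path

    args = DatasetArgs(
        data="stanfordnlp/sst2",           # Hub ID, local file, or folder with split files
        task_type="text-classification",
        text_field="sentence",             # remapped to the standard "text" column
        test=True,                         # carve out a test split only if one is missing
    )

    # Returns (train, validation, test, id2label, label2id)
    train_ds, valid_ds, test_ds, id2label, label2id = load_dataset(args)
    print(len(train_ds), id2label)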
@@ -0,0 +1,206 @@
+ from typing import List, Optional
+ import importlib
+
+ # Pipeline to module mapping
+ _PIPELINE_TO_MODULE = {
+     "text-classification": "text_classification",
+     "sentence-similarity": "sentence_similarity",
+     "sentence-transformers": "sentence_similarity",  # same module, similar code
+     "embeddings": "embeddings",
+     "masked-lm": "masked_lm",
+     "zero-shot-classification": "zero_shot",
+ }
+
+
+ def get_inference_code(
+     pipeline: str,
+     model_path: str = "{{MODEL_PATH}}",
+     **kwargs,
+ ) -> str:
+     """
+     Get inference example code for a model card.
+
+     This function returns clean, runnable Python code that can be directly
+     used in HuggingFace model cards. The code comes from the same source
+     as the test suite, ensuring consistency.
+
+     Args:
+         pipeline: The pipeline type (e.g., "text-classification", "sentence-similarity")
+         model_path: The model path to use in the example. Use "{{MODEL_PATH}}" as a
+             placeholder if the actual path isn't known yet.
+         **kwargs: Additional arguments passed to the specific pipeline's get_example_code()
+             function. Common options:
+             - text: str - for masked-lm, zero-shot
+             - texts: List[str] - for text-classification, embeddings
+             - text_pairs: List[str] - for text-classification
+             - documents: List[str] - for text-classification
+             - queries: List[str] - for text-classification
+             - is_regression: bool - for text-classification
+             - use_late_interaction: bool - for sentence-similarity (ColBERT-style)
+             - label_candidates: List or Dict - for zero-shot
+
+     Returns:
+         Formatted Python code string ready for inclusion in a model card.
+     """
+     if pipeline not in _PIPELINE_TO_MODULE:
+         raise ValueError(
+             f"Unknown pipeline: {pipeline}. "
+             f"Supported pipelines: {list(_PIPELINE_TO_MODULE.keys())}"
+         )
+
+     module_name = _PIPELINE_TO_MODULE[pipeline]
+
+     # Try importing from tests.inference_examples first (development);
+     # fall back if the submodule import fails
+     try:
+         module = importlib.import_module(f"tests.inference_examples.{module_name}")
+     except ImportError:
+         # If the direct import fails, try attribute access on the tests package
+         try:
+             import tests.inference_examples
+             module = getattr(tests.inference_examples, module_name)
+         except (ImportError, AttributeError):
+             raise ImportError(
+                 f"Could not import inference example module for {pipeline}. "
+                 "Make sure the tests package is installed or accessible."
+             )
+
+     # Call the module's get_example_code function
+     return module.get_example_code(model_path=model_path, **kwargs)
+
+
+ def get_available_pipelines() -> List[str]:
+     """Get list of pipelines that have model card code templates."""
+     return list(_PIPELINE_TO_MODULE.keys())
+
+
+ def generate_model_card_section(
+     pipeline: str,
+     model_path: str,
+     title: str = "Usage with mlx-raclate",
+     **kwargs,
+ ) -> str:
+     """
+     Generate a complete model card section with title and code block.
+
+     Args:
+         pipeline: The pipeline type
+         model_path: The model path
+         title: Section title
+         **kwargs: Additional arguments for get_inference_code()
+
+     Returns:
+         Markdown-formatted section for a model card
+     """
+     code = get_inference_code(pipeline=pipeline, model_path=model_path, **kwargs)
+
+     return f"""## {title}
+
+ This model can be used with [mlx-raclate](https://github.com/pappitti/mlx-raclate) for native inference on Apple Silicon.
+
+ ```python
+ {code}
+ ```
+ """
+
+
+ def get_code_for_trained_model(
+     pipeline: str,
+     model_path: str,
+     base_model: str,
+     training_task: Optional[str] = None,
+     **kwargs,
+ ) -> str:
+     """
+     Generate model card content for a newly trained model.
+
+     This is intended to be called after training, to generate the
+     inference example code for the model card before uploading to HuggingFace.
+
+     Args:
+         pipeline: Pipeline the model was trained for
+         model_path: Path where the model will be uploaded (e.g., "my-org/my-model")
+         base_model: The base model used for training
+         training_task: Optional description of the training task
+         **kwargs: Additional arguments for the code template
+
+     Returns:
+         Complete markdown section for the model card
+     """
+     section = generate_model_card_section(
+         pipeline=pipeline,
+         model_path=model_path,
+         **kwargs,
+     )
+
+     # Add metadata about training
+     metadata = f"""
+ ### Model Details
+
+ - **Base Model**: [{base_model}](https://huggingface.co/{base_model})
+ - **Pipeline**: `{pipeline}`
+ - **Framework**: [mlx-raclate](https://github.com/pappitti/mlx-raclate) (MLX)
+ """
+     if training_task:
+         metadata += f"- **Training Task**: {training_task}\n"
+
+     return section + metadata
+
+
+ # ============================================================================
+ # CLI INTERFACE
+ # ============================================================================
+
+ def main():
+     """CLI for generating model card code snippets."""
+     import argparse
+
+     parser = argparse.ArgumentParser(
+         description="Generate inference code for model cards"
+     )
+     parser.add_argument(
+         "pipeline",
+         choices=get_available_pipelines(),
+         help="Pipeline type"
+     )
+     parser.add_argument(
+         "--model-path",
+         default="{{MODEL_PATH}}",
+         help="Model path/ID for the example"
+     )
+     parser.add_argument(
+         "--late-interaction",
+         action="store_true",
+         help="Use late interaction for sentence-similarity"
+     )
+     parser.add_argument(
+         "--full-section",
+         action="store_true",
+         help="Generate a full markdown section instead of just code"
+     )
+
+     args = parser.parse_args()
+
+     kwargs = {}
+     if args.late_interaction:
+         kwargs["use_late_interaction"] = True
+
+     if args.full_section:
+         output = generate_model_card_section(
+             pipeline=args.pipeline,
+             model_path=args.model_path,
+             **kwargs,
+         )
+     else:
+         output = get_inference_code(
+             pipeline=args.pipeline,
+             model_path=args.model_path,
+             **kwargs,
+         )
+
+     print(output)
+
+
+ if __name__ == "__main__":
+     main()
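
As a quick orientation, here is how the helpers above would be called. The import path `mlx_raclate.model_cards` is an assumption (the diff does not name the file), and the model IDs are placeholders.

    from mlx_raclate.model_cards import (  # assumed import path
        get_inference_code,
        generate_model_card_section,
    )

    # Snippet only, with the "{{MODEL_PATH}}" placeholder left to fill in later
    code = get_inference_code("text-classification")

    # Full markdown section for a known model path
    section = generate_model_card_section(
        pipeline="sentence-similarity",
        model_path="my-org/my-model",
    )
    print(section)

If the file is runnable as a module, the CLI would be invoked along the lines of `python -m mlx_raclate.model_cards sentence-similarity --full-section` (module name assumed), which prints the same output as the programmatic call.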