mlx-raclate 0.1.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx_raclate/__init__.py +1 -0
- mlx_raclate/models/__init__.py +0 -0
- mlx_raclate/models/base.py +225 -0
- mlx_raclate/models/gemma3_text.py +913 -0
- mlx_raclate/models/lfm2.py +671 -0
- mlx_raclate/models/modernbert.py +900 -0
- mlx_raclate/models/qwen3.py +582 -0
- mlx_raclate/models/t5gemma_encoder.py +857 -0
- mlx_raclate/py.typed +0 -0
- mlx_raclate/tuner/TUNER.md +305 -0
- mlx_raclate/tuner/__init__.py +0 -0
- mlx_raclate/tuner/collators.py +291 -0
- mlx_raclate/tuner/datasets.py +247 -0
- mlx_raclate/tuner/model_card_utils.py +206 -0
- mlx_raclate/tuner/trainer.py +648 -0
- mlx_raclate/tuner/utils.py +292 -0
- mlx_raclate/utils/__init__.py +0 -0
- mlx_raclate/utils/server.py +390 -0
- mlx_raclate/utils/tokenizer_utils.py +353 -0
- mlx_raclate/utils/train.py +249 -0
- mlx_raclate/utils/utils.py +625 -0
- mlx_raclate-0.1.0b1.dist-info/METADATA +216 -0
- mlx_raclate-0.1.0b1.dist-info/RECORD +25 -0
- mlx_raclate-0.1.0b1.dist-info/WHEEL +4 -0
- mlx_raclate-0.1.0b1.dist-info/licenses/LICENSE +19 -0
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from typing import Dict, List, Optional, Tuple, Any
|
|
3
|
+
from datasets import load_dataset as hf_load_dataset
|
|
4
|
+
from datasets import DatasetDict, ClassLabel, Sequence, Value
|
|
5
|
+
from datasets import Dataset as HFDataset
|
|
6
|
+
|
|
7
|
+
class DatasetArgs:
    """
    Options controlling how a dataset is loaded and which columns it uses.

    When the source data uses non-standard column names, pass them through the
    ``*_field`` arguments so they can be remapped:

    - main text                : ``text_field``
    - label / classification   : ``label_field``
    - paired text (optional; contrastive learning, sentence similarity, or
      sequence classification with two inputs) : ``text_pair_field``
    - negative example (optional; triplet loss) : ``negative_field``

    ``test`` states whether a test split should be created when the source does
    not already provide one.
    """

    def __init__(
        self,
        data: str,
        task_type: str,
        text_field: Optional[str] = "text",
        label_field: Optional[str] = "label",
        text_pair_field: Optional[str] = None,
        negative_field: Optional[str] = None,
        test: Optional[bool] = False,
    ):
        # Data location (local path or Hub id) and the task it will be used for.
        self.data, self.task_type = data, task_type
        # Column names in the source data for the main text and the label.
        self.text_field, self.label_field = text_field, label_field
        # Optional column names for the paired text and the negative example.
        self.text_pair_field, self.negative_field = text_pair_field, negative_field
        # Whether to create a test set if not present.
        self.test = test
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _standardize_column_names(dataset: HFDataset, args: DatasetArgs) -> HFDataset:
    """
    Rename columns to the standard names expected by the collators and drop the rest.

    Standard names: 'text', 'text_pair', 'negative', 'label' and 'labels'.

    Manual mappings come from ``args``. They are applied first, whenever the
    named source column exists:
        text_field      : column name for the main text input
        label_field     : column name for the label / score
        text_pair_field : (optional) column name for the paired text input /
                          sentence B (used for cross-encoders or bi-encoders)
        negative_field  : (optional) column name for the negative example
                          (used for triplet training)

    Automatic mappings for "sentence-similarity" / "sentence-transformers":
        - "sentence1" -> 'text', "sentence2" -> 'text_pair', "score" -> 'label' (optional)
        - "anchor" -> 'text', "positive" -> 'text_pair'
          (a "negative" column is already standard and is kept as-is)
        - "pos" -> 'text_pair', "neg" -> 'negative'

    Automatic mappings for "token-classification":
        - "tokens" -> 'text', "ner_tags" -> 'labels'

    An automatic mapping is skipped in two cases. The first is when its target
    name has already been claimed by an earlier (manual or automatic) mapping.
    The second is when the target already exists as a column that is not being
    renamed away. Manual mappings therefore take precedence, and renaming never
    produces duplicate column names.

    Args:
        dataset: One split of a Hugging Face dataset.
        args: Dataset options; reads ``task_type`` and the ``*_field`` names.

    Returns:
        The dataset restricted to the standard columns that are present, in the
        fixed order 'text', 'text_pair', 'negative', 'label', 'labels'.
    """
    columns = set(dataset.column_names)
    mapping: Dict[str, str] = {}

    # Manual field mappings
    for source, target in (
        (args.text_field, "text"),
        (args.text_pair_field, "text_pair"),
        (args.label_field, "label"),
        (args.negative_field, "negative"),
    ):
        if source and source != target and source in columns:
            mapping[source] = target

    def _auto(source: str, target: str) -> None:
        # Register an automatic rename only if it cannot collide with an earlier
        # mapping or with an existing column that already carries the target name.
        if source not in columns or source in mapping:
            return
        if target in mapping.values():
            return
        if target in columns and target not in mapping:
            return
        mapping[source] = target

    # Sentence similarity / embedding training datasets
    if args.task_type in ("sentence-similarity", "sentence-transformers"):
        # Sentence pairs, optionally with a similarity score for regression
        if "sentence1" in columns and "sentence2" in columns:
            _auto("sentence1", "text")
            _auto("sentence2", "text_pair")
            _auto("score", "label")

        # Anchor / positive pairs, optionally with a hard negative for triplets
        if "anchor" in columns and "positive" in columns:
            _auto("anchor", "text")
            _auto("positive", "text_pair")

        # Short-hand column names
        _auto("pos", "text_pair")
        _auto("neg", "negative")

    # Token classification: usually "tokens" -> "text", "ner_tags" -> "labels"
    if args.task_type == "token-classification":
        _auto("tokens", "text")
        _auto("ner_tags", "labels")

    if mapping:
        dataset = dataset.rename_columns(mapping)

    # A fixed order (instead of set iteration order) keeps the schema deterministic
    existing_columns = set(dataset.column_names)
    columns_to_select = [
        name
        for name in ("text", "text_pair", "negative", "label", "labels")
        if name in existing_columns
    ]

    # Check if we have at least 'text'
    if "text" not in columns_to_select:
        print(f"Warning: Standard 'text' column not found in dataset columns: {dataset.column_names}")

    return dataset.select_columns(columns_to_select)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def get_label_mapping(dataset: HFDataset, args: DatasetArgs) -> Tuple[Optional[Dict[int, str]], Optional[Dict[str, int]]]:
    """
    Derive ``id2label`` and ``label2id`` from a dataset split.

    Only "text-classification" and "token-classification" produce a mapping.
    Every other task returns ``(None, None)``.

    Label names are read from the dataset features first. This covers a
    ``ClassLabel`` column, or a list/sequence of ``ClassLabel`` for token
    classification. Otherwise the unique values present in the data are used,
    sorted, with missing (``None``) values ignored.

    Args:
        dataset: A dataset split whose columns were already standardized.
        args: Dataset options; only ``args.task_type`` is read.

    Returns:
        ``(id2label, label2id)`` with string label names. Returns
        ``(None, None)`` if there is no label column or no label was found.
    """
    if args.task_type not in ("text-classification", "token-classification"):
        return None, None

    # Determine the target column name based on task
    target_col = "labels" if args.task_type == "token-classification" else "label"
    if target_col not in dataset.column_names:
        # Fallback: sometimes text-classification uses 'labels' or vice versa
        if "label" in dataset.column_names:
            target_col = "label"
        elif "labels" in dataset.column_names:
            target_col = "labels"
        else:
            return None, None

    labels = []

    # Strategy 1: Check Features (Config/Hub Metadata)
    feature = dataset.features[target_col]
    if isinstance(feature, ClassLabel):
        # Case A: Standard ClassLabel (Text Classification)
        labels = list(feature.names)
    else:
        # Case B: list of ClassLabels (Token Classification).
        # Sequence / List features expose their item type as ``.feature``.
        # A bare Python list ``[ClassLabel(...)]`` is also a valid list feature.
        inner = getattr(feature, "feature", None)
        if inner is None and isinstance(feature, list) and len(feature) == 1:
            inner = feature[0]
        if isinstance(inner, ClassLabel):
            labels = list(inner.names)

    # Strategy 2: Scan Data (Raw JSONL/CSV)
    if not labels and len(dataset) > 0:
        unique_values = set()
        for value in dataset[target_col]:
            if value is None:
                # Unlabeled rows must not become a label (and would break sorting)
                continue
            if args.task_type == "token-classification":
                # Flatten list of tags per row
                unique_values.update(tag for tag in value if tag is not None)
            else:
                unique_values.add(value)
        try:
            labels = sorted(unique_values)
        except TypeError:
            # Mixed value types (e.g. ints and strings) cannot be compared directly
            labels = sorted(unique_values, key=str)

    if not labels:
        return None, None

    # Construct mappings
    id2label = {idx: str(name) for idx, name in enumerate(labels)}
    label2id = {name: idx for idx, name in id2label.items()}

    return id2label, label2id
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _loader_for_extension(ext: str) -> str:
    """Map a file extension (with or without the dot) to a Hugging Face builder name."""
    ext = ext.lower().lstrip(".")
    # "jsonl" is read by the "json" builder and ".txt" by the "text" builder.
    # Other extensions ("json", "csv", "parquet", ...) already match a builder name.
    return {"jsonl": "json", "txt": "text"}.get(ext, ext)


def _load_raw_datasets(args: DatasetArgs) -> DatasetDict:
    """Load ``args.data`` from a local file, a local directory, or the Hugging Face Hub."""
    data_path = Path(args.data)

    if not data_path.exists():
        # Load from Hub
        try:
            return hf_load_dataset(args.data)
        except Exception as e:
            # Best effort: some Hub datasets only load when a split is requested
            print(f"Failed to load as standard dataset: {e}. Trying simple load...")
            return DatasetDict({"train": hf_load_dataset(args.data, split="train")})

    if data_path.is_file():
        # Single file: loaded as 'train' only; other splits are created later
        return hf_load_dataset(_loader_for_extension(data_path.suffix), data_files=str(data_path))

    # Directory: look for <split>.<ext>; for each split, take the first
    # extension found in this preference order.
    data_files = {}
    loaders = set()
    for split in ("train", "validation", "test"):
        for ext in ("jsonl", "json", "parquet", "csv", "txt"):
            candidate = data_path / f"{split}.{ext}"
            if candidate.exists():
                data_files[split] = str(candidate)
                loaders.add(_loader_for_extension(ext))
                break

    if not data_files:
        raise ValueError(f"No train/val/test files found in {data_path}")
    if len(loaders) > 1:
        # One builder reads every split, so all split files must share a format
        raise ValueError(
            f"Split files in {data_path} use different formats {sorted(loaders)}; "
            f"use a single format for all splits: {data_files}"
        )
    return hf_load_dataset(loaders.pop(), data_files=data_files)


def _ensure_splits(raw_datasets: DatasetDict, args: DatasetArgs) -> None:
    """
    Create missing 'validation' (and, if ``args.test``, 'test') splits in place from 'train'.

    When neither split exists and ``args.test`` is set, the split is roughly
    70/15/15. First 15% goes to test. Then 17.6% of the remaining 85% (about
    15% of the total) goes to validation. A provided 'test' split without a
    'validation' split is reused as validation when ``args.test`` is not set.
    """
    has_validation = "validation" in raw_datasets
    has_test = "test" in raw_datasets

    if not has_validation and not has_test:
        if args.test:
            t_t_split = raw_datasets["train"].train_test_split(test_size=0.15, seed=42)
            raw_datasets["test"] = t_t_split["test"]
            raw_datasets["train"] = t_t_split["train"]
        t_v_split = raw_datasets["train"].train_test_split(test_size=0.176, seed=42)
        raw_datasets["train"] = t_v_split["train"]
        raw_datasets["validation"] = t_v_split["test"]
    elif not has_validation:  # a test split exists
        if args.test:
            t_v_split = raw_datasets["train"].train_test_split(test_size=0.176, seed=42)
            raw_datasets["train"] = t_v_split["train"]
            raw_datasets["validation"] = t_v_split["test"]
        else:
            # Use the test split as validation; remove it rather than storing None
            raw_datasets["validation"] = raw_datasets.pop("test")
    elif not has_test and args.test:
        t_t_split = raw_datasets["train"].train_test_split(test_size=0.176, seed=42)
        raw_datasets["train"] = t_t_split["train"]
        raw_datasets["test"] = t_t_split["test"]


def load_dataset(args: DatasetArgs) -> Tuple[Optional[HFDataset], Optional[HFDataset], Optional[HFDataset], Optional[Dict[int, str]], Optional[Dict[str, int]]]:
    """
    Load, split and standardize a dataset for training.

    Args:
        args: Dataset options. ``args.data`` is a local file, a local directory
            containing ``train``/``validation``/``test`` files, or a Hub dataset id.

    Returns:
        ``(train, validation, test, id2label, label2id)``. Any element may be
        ``None``. The label mappings are only built for classification tasks.

    Raises:
        ValueError: if ``task_type`` is missing or unsupported, if no data
            files are found in a directory, if split files mix formats, or if
            no 'train' split exists.
    """
    if not hasattr(args, "task_type"):
        raise ValueError("Must specify task_type in args")

    supported_tasks = ["text-classification", "masked-lm", "token-classification", "sentence-transformers", "sentence-similarity"]
    if args.task_type not in supported_tasks:
        raise ValueError(f"Unsupported task type: {args.task_type}")

    # Load from Hub or Local
    raw_datasets = _load_raw_datasets(args)

    if "train" not in raw_datasets:
        raise ValueError("Training split not found in dataset")

    # Handle Splits (Standard 70/15/15) or whatever the actual splits are
    _ensure_splits(raw_datasets, args)

    # Standardize Columns
    for split in list(raw_datasets.keys()):
        if raw_datasets[split] is not None:
            print(f"Standardizing columns for split '{split}' ({len(raw_datasets[split])} examples)...")
            raw_datasets[split] = _standardize_column_names(raw_datasets[split], args)

    # Get label mappings if applicable
    id2label, label2id = None, None
    if raw_datasets.get("train") is not None:
        id2label, label2id = get_label_mapping(raw_datasets["train"], args)

    if id2label:
        print(f"Found {len(id2label)} labels. First 5: {list(id2label.values())[:5]}")

    return (
        raw_datasets.get("train"),
        raw_datasets.get("validation"),
        raw_datasets.get("test"),
        id2label,
        label2id,
    )
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
from typing import List, Optional
|
|
2
|
+
import importlib
|
|
3
|
+
|
|
4
|
+
# Pipeline to module mapping
|
|
5
|
+
_PIPELINE_TO_MODULE = {
|
|
6
|
+
"text-classification": "text_classification",
|
|
7
|
+
"sentence-similarity": "sentence_similarity",
|
|
8
|
+
"sentence-transformers": "sentence_similarity", # Same module, similar code
|
|
9
|
+
"embeddings": "embeddings",
|
|
10
|
+
"masked-lm": "masked_lm",
|
|
11
|
+
"zero-shot-classification": "zero_shot",
|
|
12
|
+
}
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_inference_code(
    pipeline: str,
    model_path: str = "{{MODEL_PATH}}",
    **kwargs,
) -> str:
    """
    Get inference example code for a model card.

    This function returns clean, runnable Python code that can be directly
    used in HuggingFace model cards. The code comes from the same source
    as the test suite, ensuring consistency.

    Args:
        pipeline: The pipeline type (e.g., "text-classification", "sentence-similarity")
        model_path: The model path to use in the example. Use "{{MODEL_PATH}}" as a
            placeholder if the actual path isn't known yet.
        **kwargs: Additional arguments passed to the specific pipeline's get_example_code()
            function. Common options:
            - text: str - for masked-lm, zero-shot
            - texts: List[str] - for text-classification, embeddings
            - text_pairs: List[str] - for text-classification
            - documents: List[str] - for text-classification
            - queries: List[str] - for text-classification
            - is_regression: bool - for text-classification
            - use_late_interaction: bool - for sentence-similarity (ColBERT-style)
            - label_candidates: List or Dict - for zero-shot

    Returns:
        Formatted Python code string ready for inclusion in a model card.

    Raises:
        ValueError: if ``pipeline`` is not a supported pipeline.
        ImportError: if the example module for ``pipeline`` cannot be imported.
            The original import failure is attached as ``__cause__``.
    """
    if pipeline not in _PIPELINE_TO_MODULE:
        raise ValueError(
            f"Unknown pipeline: {pipeline}. "
            f"Supported pipelines: {list(_PIPELINE_TO_MODULE.keys())}"
        )

    module_name = _PIPELINE_TO_MODULE[pipeline]

    # The example code lives in the ``tests.inference_examples`` package.
    # Import the submodule directly first. If that fails, fall back to looking
    # the name up as an attribute of the package itself.
    try:
        module = importlib.import_module(f"tests.inference_examples.{module_name}")
    except ImportError as import_err:
        try:
            import tests.inference_examples
            module = getattr(tests.inference_examples, module_name)
        except (ImportError, AttributeError):
            # Chain the first failure: it tells the user *why* the example
            # module could not be imported (missing package vs. broken module).
            raise ImportError(
                f"Could not import inference example module for {pipeline}. "
                "Make sure the tests package is installed or accessible."
            ) from import_err

    # Call the module's get_example_code function
    return module.get_example_code(model_path=model_path, **kwargs)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def get_available_pipelines() -> List[str]:
    """Return the names of all pipelines for which a model card code template exists."""
    pipeline_names = [*_PIPELINE_TO_MODULE]
    return pipeline_names
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def generate_model_card_section(
    pipeline: str,
    model_path: str,
    title: str = "Usage with mlx-raclate",
    **kwargs,
) -> str:
    """
    Build a Markdown model-card section: a heading, a short intro and a Python code block.

    Args:
        pipeline: The pipeline type
        model_path: The model path
        title: Section title (rendered as a level-2 heading)
        **kwargs: Additional arguments forwarded to get_inference_code()

    Returns:
        Markdown-formatted section for a model card, ending with a newline.
    """
    snippet = get_inference_code(pipeline=pipeline, model_path=model_path, **kwargs)

    section_lines = [
        f"## {title}",
        "",
        "This model can be used with [mlx-raclate](https://github.com/pappitti/mlx-raclate) for native inference on Apple Silicon.",
        "",
        "```python",
        snippet,
        "```",
        "",  # trailing empty entry -> the section ends with a newline
    ]
    return "\n".join(section_lines)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def get_code_for_trained_model(
    pipeline: str,
    model_path: str,
    base_model: str,
    training_task: Optional[str] = None,
    **kwargs,
) -> str:
    """
    Produce the model-card usage section plus a "Model Details" list for a freshly trained model.

    Meant to be called after training to prepare the inference example code
    for the model card before uploading to HuggingFace.

    Args:
        pipeline: Pipeline the model was trained for
        model_path: Path where the model will be uploaded (e.g., "my-org/my-model")
        base_model: The base model used for training (linked as a huggingface.co repo)
        training_task: Optional description of the training task; omitted when falsy
        **kwargs: Additional arguments forwarded to generate_model_card_section()

    Returns:
        Complete markdown section for the model card
    """
    usage_section = generate_model_card_section(
        pipeline=pipeline,
        model_path=model_path,
        **kwargs,
    )

    # Training metadata, appended after the usage section
    detail_lines = [
        "",
        "### Model Details",
        "",
        f"- **Base Model**: [{base_model}](https://huggingface.co/{base_model})",
        f"- **Pipeline**: `{pipeline}`",
        "- **Framework**: [mlx-raclate](https://github.com/pappitti/mlx-raclate) (MLX)",
    ]
    if training_task:
        detail_lines.append(f"- **Training Task**: {training_task}")

    return usage_section + "\n".join(detail_lines) + "\n"
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
# ============================================================================
|
|
152
|
+
# CLI INTERFACE
|
|
153
|
+
# ============================================================================
|
|
154
|
+
|
|
155
|
+
def main():
    """Command-line entry point: print a model-card code snippet (or a full Markdown section)."""
    import argparse

    cli_parser = argparse.ArgumentParser(description="Generate inference code for model cards")
    cli_parser.add_argument("pipeline", choices=get_available_pipelines(), help="Pipeline type")
    cli_parser.add_argument("--model-path", default="{{MODEL_PATH}}", help="Model path/ID for the example")
    cli_parser.add_argument("--late-interaction", action="store_true", help="Use late interaction for sentence-similarity")
    cli_parser.add_argument("--full-section", action="store_true", help="Generate full markdown section instead of just code")
    options = cli_parser.parse_args()

    # Only forward the late-interaction flag when it was requested
    extra_kwargs = {"use_late_interaction": True} if options.late_interaction else {}
    # Choose between the bare code snippet and the complete Markdown section
    render = generate_model_card_section if options.full_section else get_inference_code

    print(render(pipeline=options.pipeline, model_path=options.model_path, **extra_kwargs))


if __name__ == "__main__":
    main()
|