sdg-hub 0.1.0a2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdg_hub/__init__.py +4 -0
- sdg_hub/_version.py +21 -0
- sdg_hub/blocks/__init__.py +6 -0
- sdg_hub/blocks/block.py +54 -0
- sdg_hub/blocks/filterblock.py +76 -0
- sdg_hub/blocks/iterblock.py +31 -0
- sdg_hub/blocks/llmblock.py +430 -0
- sdg_hub/blocks/rmblocks.py +194 -0
- sdg_hub/blocks/utilblocks.py +140 -0
- sdg_hub/configs/__init__.py +0 -0
- sdg_hub/configs/annotations/__init__.py +0 -0
- sdg_hub/configs/annotations/cot_reflection.yaml +34 -0
- sdg_hub/configs/annotations/detailed_description.yaml +10 -0
- sdg_hub/configs/annotations/detailed_description_icl.yaml +32 -0
- sdg_hub/configs/annotations/simple.yaml +10 -0
- sdg_hub/configs/knowledge/__init__.py +0 -0
- sdg_hub/configs/knowledge/atomic_facts.yaml +45 -0
- sdg_hub/configs/knowledge/auxilary_instructions.yaml +35 -0
- sdg_hub/configs/knowledge/data_recipe/__init__.py +0 -0
- sdg_hub/configs/knowledge/data_recipe/default_recipe.yaml +3 -0
- sdg_hub/configs/knowledge/detailed_summary.yaml +17 -0
- sdg_hub/configs/knowledge/evaluate_faithfulness.yaml +68 -0
- sdg_hub/configs/knowledge/evaluate_question.yaml +38 -0
- sdg_hub/configs/knowledge/evaluate_relevancy.yaml +85 -0
- sdg_hub/configs/knowledge/extractive_summary.yaml +17 -0
- sdg_hub/configs/knowledge/generate_code_questions_responses.yaml +39 -0
- sdg_hub/configs/knowledge/generate_questions_responses.yaml +56 -0
- sdg_hub/configs/knowledge/mcq_generation.yaml +83 -0
- sdg_hub/configs/knowledge/router.yaml +12 -0
- sdg_hub/configs/knowledge/simple_generate_qa.yaml +34 -0
- sdg_hub/configs/reasoning/dynamic_cot.yaml +40 -0
- sdg_hub/configs/skills/_A_.yaml +97 -0
- sdg_hub/configs/skills/_B_.yaml +36 -0
- sdg_hub/configs/skills/_C_.yaml +71 -0
- sdg_hub/configs/skills/_D_.yaml +85 -0
- sdg_hub/configs/skills/_E_.yaml +30 -0
- sdg_hub/configs/skills/_F_.yaml +45 -0
- sdg_hub/configs/skills/_G_.yaml +56 -0
- sdg_hub/configs/skills/_H_.yaml +80 -0
- sdg_hub/configs/skills/__init__.py +0 -0
- sdg_hub/configs/skills/analyzer.yaml +48 -0
- sdg_hub/configs/skills/annotation.yaml +36 -0
- sdg_hub/configs/skills/contexts.yaml +21 -0
- sdg_hub/configs/skills/critic.yaml +60 -0
- sdg_hub/configs/skills/data_recipe/__init__.py +0 -0
- sdg_hub/configs/skills/data_recipe/default_recipe.yaml +6 -0
- sdg_hub/configs/skills/evaluate_freeform_pair.yaml +44 -0
- sdg_hub/configs/skills/evaluate_freeform_questions.yaml +46 -0
- sdg_hub/configs/skills/evaluate_grounded_pair.yaml +54 -0
- sdg_hub/configs/skills/evaluate_grounded_questions.yaml +51 -0
- sdg_hub/configs/skills/freeform_questions.yaml +29 -0
- sdg_hub/configs/skills/freeform_responses.yaml +45 -0
- sdg_hub/configs/skills/grounded_questions.yaml +38 -0
- sdg_hub/configs/skills/grounded_responses.yaml +59 -0
- sdg_hub/configs/skills/judge.yaml +53 -0
- sdg_hub/configs/skills/planner.yaml +67 -0
- sdg_hub/configs/skills/respond.yaml +8 -0
- sdg_hub/configs/skills/revised_responder.yaml +78 -0
- sdg_hub/configs/skills/router.yaml +12 -0
- sdg_hub/configs/skills/simple_generate_qa_freeform.yaml +27 -0
- sdg_hub/configs/skills/simple_generate_qa_grounded.yaml +31 -0
- sdg_hub/flow.py +127 -0
- sdg_hub/flows/annotation/emotion/detailed_description.yaml +19 -0
- sdg_hub/flows/annotation/emotion/detailed_description_icl.yaml +19 -0
- sdg_hub/flows/annotation/emotion/simple.yaml +19 -0
- sdg_hub/flows/generation/knowledge/mmlu_bench.yaml +13 -0
- sdg_hub/flows/generation/knowledge/simple_knowledge.yaml +12 -0
- sdg_hub/flows/generation/knowledge/synth_knowledge.yaml +89 -0
- sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +136 -0
- sdg_hub/flows/generation/skills/agentic_improve_skill.yaml +108 -0
- sdg_hub/flows/generation/skills/simple_freeform_skill.yaml +12 -0
- sdg_hub/flows/generation/skills/simple_grounded_skill.yaml +12 -0
- sdg_hub/flows/generation/skills/synth_grounded_skills.yaml +80 -0
- sdg_hub/flows/generation/skills/synth_skills.yaml +59 -0
- sdg_hub/logger_config.py +20 -0
- sdg_hub/pipeline.py +66 -0
- sdg_hub/prompts.py +17 -0
- sdg_hub/py.typed +0 -0
- sdg_hub/registry.py +122 -0
- sdg_hub/sdg.py +164 -0
- sdg_hub/utils/__init__.py +5 -0
- sdg_hub/utils/chunking.py +73 -0
- sdg_hub/utils/datamixing.py +123 -0
- sdg_hub/utils/datautils.py +14 -0
- sdg_hub/utils/docprocessor.py +357 -0
- sdg_hub/utils/json.py +48 -0
- sdg_hub/utils/models.py +31 -0
- sdg_hub/utils/parse_and_convert.py +392 -0
- sdg_hub/utils/taxonomy.py +489 -0
- sdg_hub-0.1.0a2.dev0.dist-info/METADATA +154 -0
- sdg_hub-0.1.0a2.dev0.dist-info/RECORD +94 -0
- sdg_hub-0.1.0a2.dev0.dist-info/WHEEL +5 -0
- sdg_hub-0.1.0a2.dev0.dist-info/licenses/LICENSE +201 -0
- sdg_hub-0.1.0a2.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,357 @@
|
|
1
|
+
# Standard
|
2
|
+
from pathlib import Path
|
3
|
+
import json
|
4
|
+
|
5
|
+
# Third Party
|
6
|
+
from datasets import Dataset
|
7
|
+
from tabulate import tabulate
|
8
|
+
from transformers import AutoTokenizer
|
9
|
+
import yaml
|
10
|
+
|
11
|
+
# First Party
|
12
|
+
from sdg_hub.logger_config import setup_logger
|
13
|
+
|
14
|
+
# Local
|
15
|
+
from .datautils import safe_concatenate_datasets
|
16
|
+
from .chunking import chunk_document
|
17
|
+
|
18
|
+
# Module-level logger for this module, configured via sdg_hub's setup_logger.
logger = setup_logger(__name__)
|
19
|
+
|
20
|
+
|
21
|
+
def fuse_texts(text_list, short_length_threshold=100):
    """Merge short texts into the output entry that precedes them.

    A text whose whitespace-delimited word count is at most
    ``short_length_threshold`` is appended (after a blank line) to the last
    output entry, but only once a non-empty "anchor" text has been emitted.
    Every other text starts a new output entry and becomes the new anchor.

    Args:
        text_list: Iterable of strings to fuse.
        short_length_threshold: Maximum word count for a text to be "short".

    Returns:
        list[str]: The fused texts, in input order.
    """
    fused = []
    anchor = ""
    for piece in text_list:
        is_short = len(piece.split()) <= short_length_threshold
        if not (is_short and anchor):
            # Start a new entry; it anchors any short texts that follow.
            fused.append(piece)
            anchor = piece
            continue
        fused[-1] = fused[-1] + "\n\n" + piece
    return fused
|
37
|
+
|
38
|
+
|
39
|
+
def handle_footnote(book_element):
    """Hook for footnote elements; intentionally a no-op for now.

    Args:
        book_element: A footnote entry from the docling ``main-text`` list.

    Returns:
        None
    """
    return None
|
41
|
+
|
42
|
+
|
43
|
+
def create_tokenizer(model_name="instructlab/granite-7b-lab"):
    """Load a Hugging Face tokenizer.

    Args:
        model_name: Model id or local path passed to
            ``AutoTokenizer.from_pretrained``. Defaults to the identifier that
            used to be hard-coded here, so no-argument callers are unaffected.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained``.
    """
    return AutoTokenizer.from_pretrained(model_name)
|
45
|
+
|
46
|
+
|
47
|
+
def get_token_count(text, tokenizer):
    """Return the number of tokens ``tokenizer.tokenize`` yields for ``text``."""
    tokens = tokenizer.tokenize(text)
    return len(tokens)
|
49
|
+
|
50
|
+
|
51
|
+
def add_heading_formatting(text, max_heading_words=3):
    """Bold a short leading "heading" sentence of a paragraph.

    The segment before the first ``"."`` is treated as a heading when:

    * a period is present;
    * the segment is not blank;
    * it has fewer than ``max_heading_words`` space-separated tokens.

    A heading is wrapped in Markdown bold and the period that ended it is
    dropped, e.g. ``"Intro. Body."`` -> ``"**Intro** Body."``. Anything else
    is returned unchanged.

    Args:
        text: Paragraph text.
        max_heading_words: Exclusive upper bound on the heading's token count.
            The default of 3 is the value that was previously hard-coded.

    Returns:
        str: The possibly formatted text.
    """
    head, sep, rest = text.partition(".")
    # split(" ") rather than split() keeps the original counting heuristic
    # (repeated spaces yield empty tokens). The blank-head guard prevents an
    # empty "****" marker for text such as ".5 percent growth".
    if sep and head.strip() and len(head.split(" ")) < max_heading_words:
        return f"**{head}**{rest}"
    return text
|
59
|
+
|
60
|
+
|
61
|
+
def generate_table_from_parsed_rep(item):
    """Render a parsed table entry as a GitHub-flavoured Markdown table.

    Args:
        item: Mapping whose ``"data"`` is a list of rows, each row a list of
            cell mappings with a ``"text"`` key. An optional ``"text"`` key on
            ``item`` itself is used as the caption.

    Returns:
        str: The rendered table, followed by a ``Caption:`` line when a
        non-empty caption exists. Returns ``""`` when ``data`` has at most one
        row or its first row has at most one cell.

    Raises:
        KeyError: If ``item`` lacks ``"data"`` or a cell lacks ``"text"``.
    """
    caption = item["text"] if "text" in item else ""
    data = item["data"]

    # Single-row or single-column tables are skipped entirely.
    if len(data) <= 1 or len(data[0]) <= 1:
        return ""

    rows = [[cell["text"] for cell in row] for row in data]
    table_text = tabulate(rows, tablefmt="github")
    if caption:
        table_text += f"\nCaption: {caption}\n"
    return table_text
|
86
|
+
|
87
|
+
|
88
|
+
def get_table(json_book, table_ref):
    """Resolve a ``/``-separated table reference and render that table.

    The second component of ``table_ref`` names a top-level collection in
    ``json_book``; the third is an integer index into it. Presumably refs look
    like ``"#/tables/0"`` (docling layout) -- TODO confirm.

    Returns:
        str: Output of ``generate_table_from_parsed_rep`` for that entry.
    """
    segments = table_ref.split("/")
    collection = json_book[segments[1]]
    table_item = collection[int(segments[2])]
    return generate_table_from_parsed_rep(table_item)
|
92
|
+
|
93
|
+
|
94
|
+
def get_table_page_number(json_book, idx):
    """Infer the page number of the element at ``json_book["main-text"][idx]``.

    The element with provenance (a ``"prov"`` key) nearest to ``idx`` is used.
    The forward scan from ``idx`` (inclusive) takes priority. The backward scan
    over the elements before ``idx`` is used only when nothing from ``idx``
    onward carries provenance.

    Args:
        json_book: Parsed docling JSON with a ``"main-text"`` list.
        idx: Index of the (table) element in ``json_book["main-text"]``.

    Returns:
        The ``page`` of the chosen element's first ``prov`` entry, or ``None``
        when no element has provenance.
    """
    elements = json_book["main-text"]

    def _first_page(candidates):
        # Page of the first candidate carrying provenance, else None.
        for element in candidates:
            if "prov" in element:
                return element["prov"][0]["page"]
        return None

    next_page_num = _first_page(elements[idx:])
    if next_page_num is not None:
        return next_page_num
    # Slice up to idx before reversing: a negative start (idx - 1 when
    # idx == 0) would otherwise wrap around to the end of the list.
    return _first_page(reversed(elements[:idx]))
|
114
|
+
|
115
|
+
|
116
|
+
def build_chunks_from_docling_json(
    json_book,
    max_token_per_chunk,
    tokenizer,
    keep_same_page_thing_together=False,
    chunking_criteria=None,
):
    """Split a docling JSON document into text chunks.

    Walks ``json_book["main-text"]`` in order, handling element types as follows:

    * Skipped: page footers and headers, pictures, figures, references and
      metadata.
    * Footnotes: only update the current page number.
    * Titles, level-1 subtitles, paragraphs, tables and equations: rendered to
      Markdown-style text and appended to a buffer.
    * Any other type: contributes no text.

    Before each content element is appended, a chunk boundary may be emitted.
    The rules are checked in this order:

    * ``chunking_criteria`` is given and returns truthy for the element text.
      The page and token rules are then not applied.
    * ``keep_same_page_thing_together`` is set and the page changed since the
      previous element.
    * The buffer holds more than one piece and its joined token count reaches
      ``max_token_per_chunk``. All but the last buffered piece become a chunk.
      If the last piece alone also reaches the limit, it becomes its own chunk.

    Processing stops at the first piece containing the literal text
    ``"## References"`` or ``"## Acknowledgements"``.

    Args:
        json_book: Parsed docling JSON (mapping with a ``"main-text"`` list).
        max_token_per_chunk: Token budget for the size-based rule.
        tokenizer: Object exposing ``tokenize(str)``.
        keep_same_page_thing_together: Start a new chunk when the page changes.
        chunking_criteria: Optional ``callable(str) -> bool`` forcing a chunk
            boundary before the element.

    Returns:
        list[str]: Chunks whose pieces are joined with blank lines.
    """
    skipped_types = [
        "page-footer",
        "picture",
        "reference",
        "meta-data",
        "figure",
        "page-header",
    ]
    content_types = [
        "subtitle-level-1",
        "paragraph",
        "table",
        "title",
        "equation",
    ]

    current_buffer = []
    document_chunks = []
    prev_page_number = None
    book_title = None

    for idx, book_element in enumerate(json_book["main-text"]):
        element_type = book_element["type"]
        if element_type in skipped_types:
            continue
        elif element_type == "footnote":
            handle_footnote(book_element)
            current_book_page_number = book_element["prov"][0]["page"]
        elif element_type in content_types:
            if element_type == "table":
                current_book_page_number = get_table_page_number(json_book, idx)
                # Render the table before the chunk decision so that
                # chunking_criteria sees this table's text. Otherwise it sees
                # the previous element's text, or an unbound name if a table
                # is the first content element.
                book_text = get_table(json_book, book_element["$ref"])
            else:
                current_book_page_number = book_element["prov"][0]["page"]
                book_text = book_element["text"]

            if element_type == "subtitle-level-1":
                if book_title is None:
                    # The first level-1 subtitle is rendered as the title.
                    book_title = book_text
                    book_text = f"# Title: **{book_text}**"
                else:
                    book_text = f"## **{book_text}**"
            if element_type == "title":
                book_text = f"# **{book_text}**"
            # "page-header" elements are skipped above, so no page-header
            # formatting is needed here.

            if chunking_criteria is not None:
                # Custom break function that can be used to chunk the document.
                if chunking_criteria(book_text):
                    document_chunks.append("\n\n".join(current_buffer))
                    current_buffer = []
            elif (
                keep_same_page_thing_together
                and prev_page_number is not None
                and prev_page_number != current_book_page_number
            ):
                document_chunks.append("\n\n".join(current_buffer))
                current_buffer = []
            elif (
                # Cheap length check first to avoid tokenizing needlessly.
                len(current_buffer) > 1
                and get_token_count("\n\n".join(current_buffer), tokenizer)
                >= max_token_per_chunk
            ):
                # Flush everything except the most recent piece.
                document_chunks.append("\n\n".join(current_buffer[:-1]))
                if (
                    get_token_count(current_buffer[-1], tokenizer)
                    >= max_token_per_chunk
                ):
                    # The last piece alone is over budget: emit it by itself.
                    document_chunks.append(current_buffer[-1])
                    current_buffer = []
                else:
                    current_buffer = current_buffer[-1:]

            if element_type == "paragraph":
                book_text = add_heading_formatting(book_text)
            # NOTE(review): later subtitles are rendered as "## **...**", which
            # never contains "## References"; confirm whether this check should
            # also match the bold form.
            if "## References" in book_text or "## Acknowledgements" in book_text:
                # For research papers we ignore everything after these sections.
                break
            current_buffer.append(book_text)

        try:
            prev_page_number = current_book_page_number
        except NameError:
            # No page number seen yet, e.g. the first non-skipped element has a
            # type handled by none of the branches above. UnboundLocalError is
            # a NameError subclass.
            logger.error(book_element)

    final_chunk = "\n\n".join(current_buffer)
    if final_chunk not in document_chunks:
        document_chunks.append(final_chunk)
    return document_chunks
|
213
|
+
|
214
|
+
|
215
|
+
class DocProcessor:
    """Build a chunked, ICL-augmented dataset from parsed documents.

    Documents come either from docling JSON files in ``parsed_doc_dir`` or
    from Markdown files passed to ``get_processed_markdown_dataset``. Metadata
    (``document_outline``, ``domain``) and in-context-learning seed examples
    (``seed_examples``) are read from a YAML user config.
    """

    def __init__(
        self,
        parsed_doc_dir: Path,
        tokenizer: str = "instructlab/granite-7b-lab",
        user_config_path: Path = None,
    ):
        """Validate inputs and load the user config and the tokenizer.

        Args:
            parsed_doc_dir: Directory whose top-level ``*.json`` files are
                treated as parsed docling documents.
            tokenizer: Model id or path for ``AutoTokenizer.from_pretrained``.
            user_config_path: Path to the YAML user config. It defaults to
                ``None`` but is required.

        Raises:
            ValueError: If either path is ``None``.
            FileNotFoundError: If either path does not exist.
        """
        self.parsed_doc_dir = self._path_validator(parsed_doc_dir)
        self.user_config = self._load_user_config(
            self._path_validator(user_config_path)
        )
        self.docling_jsons = list(self.parsed_doc_dir.glob("*.json"))
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)

    def _path_validator(self, path) -> Path:
        """Convert ``path`` to a ``Path`` and check that it exists.

        Args:
            path (str | os.PathLike): Path to validate.

        Returns:
            Path: The validated path.

        Raises:
            ValueError: If ``path`` is ``None``. Previously this surfaced as an
                opaque ``AttributeError``.
            FileNotFoundError: If ``path`` does not exist.
        """
        if path is None:
            raise ValueError("A path is required but None was given.")
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"{path} does not exist.")
        return path

    def _load_user_config(self, user_config_path: Path) -> dict:
        """Load the YAML user config.

        Args:
            user_config_path (Path): Path to the user config file.

        Returns:
            dict: The parsed config (``None`` if the file is empty). This class
            reads the keys ``document_outline``, ``domain`` and
            ``seed_examples``.
        """
        # safe_load: the config is user-supplied, so never use yaml.load here.
        with open(user_config_path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)

    def _process_parsed_docling_json(self, json_fp: Path) -> Dataset:
        """Chunk one parsed docling JSON file into a dataset.

        Chunks are built with a 500-token budget. Chunks of at most 200 words
        are then merged into the preceding chunk via ``fuse_texts``.

        Args:
            json_fp (Path): Path to the parsed docling JSON file.

        Returns:
            Dataset: One row per chunk with columns ``document``,
            ``document_outline``, ``document_title`` (file name minus its final
            extension) and ``domain``.
        """
        logger.info(f"Processing parsed docling json file: {json_fp}")
        with open(json_fp, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Drop only the final suffix so names containing dots, such as
        # "report.v2.json", are not truncated to "report".
        file_name = Path(json_fp).stem
        chunks = build_chunks_from_docling_json(
            data,
            max_token_per_chunk=500,
            tokenizer=self.tokenizer,
        )
        chunks = fuse_texts(chunks, 200)
        return Dataset.from_dict(
            {
                "document": chunks,
                "document_outline": [self.user_config["document_outline"]]
                * len(chunks),
                "document_title": [file_name] * len(chunks),
                "domain": [self.user_config["domain"]] * len(chunks),
            }
        )

    def _add_icls(self, chunked_document: Dataset) -> Dataset:
        """Attach ICL seed examples and re-chunk long documents.

        Steps:

        1. Replicate ``chunked_document`` once per entry of
           ``user_config["seed_examples"]``. Each copy gets that example's
           ``context`` as ``icl_document`` and its first three
           ``questions_and_answers`` as ``icl_query_{1..3}`` /
           ``icl_response_{1..3}``.
        2. Split documents of more than 1024 tokens with ``chunk_document``
           (``server_ctx_size=4096``, ``chunk_word_count=1024``) and explode
           them into one row per piece.
        3. Keep only rows whose document has more than 100 tokens.

        Args:
            chunked_document (Dataset): Dataset with at least a ``document``
                column.

        Returns:
            Dataset: The augmented, re-chunked, filtered dataset.
        """
        icl = self.user_config["seed_examples"]
        chunked_document_all_icl = []
        for icl_ in icl:
            # NOTE(review): a seed example with fewer than three Q&A pairs
            # fails inside map with an IndexError.
            chunked_document_all_icl.append(
                chunked_document.map(
                    lambda x: {
                        "icl_document": icl_["context"],
                        "icl_query_1": icl_["questions_and_answers"][0]["question"],
                        "icl_response_1": icl_["questions_and_answers"][0]["answer"],
                        "icl_query_2": icl_["questions_and_answers"][1]["question"],
                        "icl_response_2": icl_["questions_and_answers"][1]["answer"],
                        "icl_query_3": icl_["questions_and_answers"][2]["question"],
                        "icl_response_3": icl_["questions_and_answers"][2]["answer"],
                    }
                )
            )
        chunked_document_all_icl = safe_concatenate_datasets(chunked_document_all_icl)
        chunked_document_all_icl = chunked_document_all_icl.map(
            lambda x: {
                "chunks": chunk_document(
                    [x["document"]], server_ctx_size=4096, chunk_word_count=1024
                )
                if get_token_count(x["document"], self.tokenizer) > 1024
                else [x["document"]]
            }
        )
        df = chunked_document_all_icl.to_pandas()
        df_exploded = df.explode("chunks").reset_index(drop=True)
        new_ds = Dataset.from_pandas(df_exploded)
        new_ds = new_ds.remove_columns("document").rename_columns(
            {"chunks": "document"}
        )

        # Only keep documents longer than 100 tokens.
        new_ds = new_ds.filter(
            lambda x: get_token_count(x["document"], self.tokenizer) > 100
        )
        return new_ds

    def get_processed_dataset(self) -> Dataset:
        """Process every parsed docling JSON file and concatenate the results.

        Returns:
            Dataset: The ICL-augmented chunks of all files.
        """
        datasets = []
        for json_fp in self.docling_jsons:
            chunk_ds = self._process_parsed_docling_json(json_fp)
            chunk_ds_with_icls = self._add_icls(chunk_ds)
            datasets.append(chunk_ds_with_icls)
        return safe_concatenate_datasets(datasets)

    def get_processed_markdown_dataset(self, list_md_files: list[Path]) -> Dataset:
        """Build the ICL-augmented dataset from whole Markdown files.

        Args:
            list_md_files: Paths (str or Path) of Markdown files. Each file
                becomes one document before re-chunking in ``_add_icls``.

        Returns:
            Dataset: The ICL-augmented, re-chunked dataset.
        """
        chunks_mds = []
        for md_file in list_md_files:
            with open(md_file, "r", encoding="utf-8") as f:
                text = f.read()
            chunks_mds.append(
                {
                    "document": text,
                    "document_outline": self.user_config["document_outline"],
                    # str(): Arrow cannot infer a column type from Path objects.
                    "document_title": str(md_file),
                    "domain": self.user_config["domain"],
                }
            )
        chunk_ds = Dataset.from_list(chunks_mds)
        chunk_ds_with_icls = self._add_icls(chunk_ds)
        return chunk_ds_with_icls
|
sdg_hub/utils/json.py
ADDED
@@ -0,0 +1,48 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
|
3
|
+
# Standard
|
4
|
+
import io
|
5
|
+
import json
|
6
|
+
import os
|
7
|
+
|
8
|
+
|
9
|
+
def _make_w_io_base(f, mode: str):
|
10
|
+
# pylint: disable=consider-using-with
|
11
|
+
if not isinstance(f, io.IOBase):
|
12
|
+
f_dirname = os.path.dirname(f)
|
13
|
+
if f_dirname != "":
|
14
|
+
os.makedirs(f_dirname, exist_ok=True)
|
15
|
+
f = open(f, mode=mode, encoding="utf-8")
|
16
|
+
return f
|
17
|
+
|
18
|
+
|
19
|
+
def _make_r_io_base(f, mode: str):
|
20
|
+
# pylint: disable=consider-using-with
|
21
|
+
if not isinstance(f, io.IOBase):
|
22
|
+
f = open(f, mode=mode, encoding="utf-8")
|
23
|
+
return f
|
24
|
+
|
25
|
+
|
26
|
+
def jdump(obj, f, mode="w", indent=4, default=str):
    """Write a dict/list as JSON, or a str verbatim, to ``f``.

    Args:
        obj: A ``dict`` or ``list`` (serialized with ``json.dump``) or a
            ``str`` (written as-is).
        f: A filesystem path, whose missing parent directories are created, or
            an already-open text stream. The stream is closed on return because
            it is used as a context manager.
        mode: Mode used when opening a path.
        indent: Indent for JSON output.
        default: Fallback for non-serializable values; ``str`` stringifies them.

    Raises:
        ValueError: If ``obj`` is not a dict, list or str. This is checked
            before ``f`` is opened, so an existing file is not truncated.
    """
    if not isinstance(obj, (dict, list, str)):
        raise ValueError(f"Unexpected type: {type(obj)}")
    with _make_w_io_base(f, mode) as f_:
        if isinstance(obj, str):
            f_.write(obj)
        else:
            json.dump(obj, f_, indent=indent, default=default)
|
43
|
+
|
44
|
+
|
45
|
+
def jload(f, mode="r"):
    """Decode JSON from ``f`` (a path or an open text stream) and return it.

    The stream is closed afterwards, including one supplied by the caller.
    """
    with _make_r_io_base(f, mode) as handle:
        decoded = json.load(handle)
    return decoded
|
sdg_hub/utils/models.py
ADDED
@@ -0,0 +1,31 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
|
3
|
+
# Standard
|
4
|
+
import os
|
5
|
+
import re
|
6
|
+
|
7
|
+
# First Party
|
8
|
+
from sdg_hub.utils import GenerateException
|
9
|
+
|
10
|
+
# Family reported when the model family cannot otherwise be determined.
DEFAULT_MODEL_FAMILY = "merlinite"

# Model families understood by ilab.
MODEL_FAMILIES = {"merlinite", "mixtral"}

# Maps model-name prefixes to the family they belong to.
MODEL_FAMILY_MAPPINGS = {
    "granite": "merlinite",
}
|
20
|
+
|
21
|
+
|
22
|
+
def get_model_family(forced, model_path):
    """Determine the model family for a model.

    Args:
        forced: Explicitly requested family, or a name listed in
            ``MODEL_FAMILY_MAPPINGS``. Falsy means "not forced". Matching is
            case-insensitive, as it already is for the filename guess.
        model_path: Path to the model. When nothing is forced, the leading
            word characters of its basename are used to guess the family.

    Returns:
        str: A member of ``MODEL_FAMILIES``. A valid forced family is returned
        directly; otherwise the filename guess is used, falling back to
        ``DEFAULT_MODEL_FAMILY``.

    Raises:
        GenerateException: If ``forced`` is given but is not a known family.
    """
    if forced:
        family = forced.lower()
        family = MODEL_FAMILY_MAPPINGS.get(family, family)
        if family not in MODEL_FAMILIES:
            raise GenerateException(f"Unknown model family: {forced}")
        # Honour the validated forced family instead of overriding it with the
        # filename guess below.
        return family

    # Try to guess the model family based on the model's filename.
    guess = re.match(r"^\w*", os.path.basename(model_path)).group(0).lower()
    guess = MODEL_FAMILY_MAPPINGS.get(guess, guess)
    return guess if guess in MODEL_FAMILIES else DEFAULT_MODEL_FAMILY
|