@huggingface/tasks 0.13.1-test → 0.13.1-test2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198) hide show
  1. package/package.json +4 -2
  2. package/src/dataset-libraries.ts +89 -0
  3. package/src/default-widget-inputs.ts +718 -0
  4. package/src/gguf.ts +40 -0
  5. package/src/hardware.ts +482 -0
  6. package/src/index.ts +59 -0
  7. package/src/library-to-tasks.ts +76 -0
  8. package/src/local-apps.ts +412 -0
  9. package/src/model-data.ts +149 -0
  10. package/src/model-libraries-downloads.ts +18 -0
  11. package/src/model-libraries-snippets.ts +1128 -0
  12. package/src/model-libraries.ts +820 -0
  13. package/src/pipelines.ts +698 -0
  14. package/src/snippets/common.ts +39 -0
  15. package/src/snippets/curl.spec.ts +94 -0
  16. package/src/snippets/curl.ts +120 -0
  17. package/src/snippets/index.ts +7 -0
  18. package/src/snippets/inputs.ts +167 -0
  19. package/src/snippets/js.spec.ts +148 -0
  20. package/src/snippets/js.ts +305 -0
  21. package/src/snippets/python.spec.ts +144 -0
  22. package/src/snippets/python.ts +321 -0
  23. package/src/snippets/types.ts +16 -0
  24. package/src/tasks/audio-classification/about.md +86 -0
  25. package/src/tasks/audio-classification/data.ts +81 -0
  26. package/src/tasks/audio-classification/inference.ts +52 -0
  27. package/src/tasks/audio-classification/spec/input.json +35 -0
  28. package/src/tasks/audio-classification/spec/output.json +11 -0
  29. package/src/tasks/audio-to-audio/about.md +56 -0
  30. package/src/tasks/audio-to-audio/data.ts +70 -0
  31. package/src/tasks/automatic-speech-recognition/about.md +90 -0
  32. package/src/tasks/automatic-speech-recognition/data.ts +82 -0
  33. package/src/tasks/automatic-speech-recognition/inference.ts +160 -0
  34. package/src/tasks/automatic-speech-recognition/spec/input.json +35 -0
  35. package/src/tasks/automatic-speech-recognition/spec/output.json +38 -0
  36. package/src/tasks/chat-completion/inference.ts +322 -0
  37. package/src/tasks/chat-completion/spec/input.json +350 -0
  38. package/src/tasks/chat-completion/spec/output.json +206 -0
  39. package/src/tasks/chat-completion/spec/stream_output.json +213 -0
  40. package/src/tasks/common-definitions.json +100 -0
  41. package/src/tasks/depth-estimation/about.md +45 -0
  42. package/src/tasks/depth-estimation/data.ts +70 -0
  43. package/src/tasks/depth-estimation/inference.ts +35 -0
  44. package/src/tasks/depth-estimation/spec/input.json +25 -0
  45. package/src/tasks/depth-estimation/spec/output.json +16 -0
  46. package/src/tasks/document-question-answering/about.md +53 -0
  47. package/src/tasks/document-question-answering/data.ts +85 -0
  48. package/src/tasks/document-question-answering/inference.ts +110 -0
  49. package/src/tasks/document-question-answering/spec/input.json +85 -0
  50. package/src/tasks/document-question-answering/spec/output.json +36 -0
  51. package/src/tasks/feature-extraction/about.md +72 -0
  52. package/src/tasks/feature-extraction/data.ts +57 -0
  53. package/src/tasks/feature-extraction/inference.ts +40 -0
  54. package/src/tasks/feature-extraction/spec/input.json +47 -0
  55. package/src/tasks/feature-extraction/spec/output.json +15 -0
  56. package/src/tasks/fill-mask/about.md +51 -0
  57. package/src/tasks/fill-mask/data.ts +79 -0
  58. package/src/tasks/fill-mask/inference.ts +62 -0
  59. package/src/tasks/fill-mask/spec/input.json +38 -0
  60. package/src/tasks/fill-mask/spec/output.json +29 -0
  61. package/src/tasks/image-classification/about.md +50 -0
  62. package/src/tasks/image-classification/data.ts +88 -0
  63. package/src/tasks/image-classification/inference.ts +52 -0
  64. package/src/tasks/image-classification/spec/input.json +35 -0
  65. package/src/tasks/image-classification/spec/output.json +11 -0
  66. package/src/tasks/image-feature-extraction/about.md +23 -0
  67. package/src/tasks/image-feature-extraction/data.ts +59 -0
  68. package/src/tasks/image-segmentation/about.md +63 -0
  69. package/src/tasks/image-segmentation/data.ts +99 -0
  70. package/src/tasks/image-segmentation/inference.ts +69 -0
  71. package/src/tasks/image-segmentation/spec/input.json +45 -0
  72. package/src/tasks/image-segmentation/spec/output.json +26 -0
  73. package/src/tasks/image-text-to-text/about.md +76 -0
  74. package/src/tasks/image-text-to-text/data.ts +102 -0
  75. package/src/tasks/image-to-3d/about.md +62 -0
  76. package/src/tasks/image-to-3d/data.ts +75 -0
  77. package/src/tasks/image-to-image/about.md +129 -0
  78. package/src/tasks/image-to-image/data.ts +101 -0
  79. package/src/tasks/image-to-image/inference.ts +68 -0
  80. package/src/tasks/image-to-image/spec/input.json +55 -0
  81. package/src/tasks/image-to-image/spec/output.json +12 -0
  82. package/src/tasks/image-to-text/about.md +61 -0
  83. package/src/tasks/image-to-text/data.ts +82 -0
  84. package/src/tasks/image-to-text/inference.ts +143 -0
  85. package/src/tasks/image-to-text/spec/input.json +34 -0
  86. package/src/tasks/image-to-text/spec/output.json +14 -0
  87. package/src/tasks/index.ts +312 -0
  88. package/src/tasks/keypoint-detection/about.md +57 -0
  89. package/src/tasks/keypoint-detection/data.ts +50 -0
  90. package/src/tasks/mask-generation/about.md +65 -0
  91. package/src/tasks/mask-generation/data.ts +55 -0
  92. package/src/tasks/object-detection/about.md +37 -0
  93. package/src/tasks/object-detection/data.ts +86 -0
  94. package/src/tasks/object-detection/inference.ts +75 -0
  95. package/src/tasks/object-detection/spec/input.json +31 -0
  96. package/src/tasks/object-detection/spec/output.json +50 -0
  97. package/src/tasks/placeholder/about.md +15 -0
  98. package/src/tasks/placeholder/data.ts +21 -0
  99. package/src/tasks/placeholder/spec/input.json +35 -0
  100. package/src/tasks/placeholder/spec/output.json +17 -0
  101. package/src/tasks/question-answering/about.md +56 -0
  102. package/src/tasks/question-answering/data.ts +75 -0
  103. package/src/tasks/question-answering/inference.ts +99 -0
  104. package/src/tasks/question-answering/spec/input.json +67 -0
  105. package/src/tasks/question-answering/spec/output.json +29 -0
  106. package/src/tasks/reinforcement-learning/about.md +167 -0
  107. package/src/tasks/reinforcement-learning/data.ts +75 -0
  108. package/src/tasks/sentence-similarity/about.md +97 -0
  109. package/src/tasks/sentence-similarity/data.ts +101 -0
  110. package/src/tasks/sentence-similarity/inference.ts +32 -0
  111. package/src/tasks/sentence-similarity/spec/input.json +40 -0
  112. package/src/tasks/sentence-similarity/spec/output.json +12 -0
  113. package/src/tasks/summarization/about.md +58 -0
  114. package/src/tasks/summarization/data.ts +76 -0
  115. package/src/tasks/summarization/inference.ts +57 -0
  116. package/src/tasks/summarization/spec/input.json +42 -0
  117. package/src/tasks/summarization/spec/output.json +14 -0
  118. package/src/tasks/table-question-answering/about.md +43 -0
  119. package/src/tasks/table-question-answering/data.ts +59 -0
  120. package/src/tasks/table-question-answering/inference.ts +61 -0
  121. package/src/tasks/table-question-answering/spec/input.json +44 -0
  122. package/src/tasks/table-question-answering/spec/output.json +40 -0
  123. package/src/tasks/tabular-classification/about.md +65 -0
  124. package/src/tasks/tabular-classification/data.ts +68 -0
  125. package/src/tasks/tabular-regression/about.md +87 -0
  126. package/src/tasks/tabular-regression/data.ts +57 -0
  127. package/src/tasks/text-classification/about.md +173 -0
  128. package/src/tasks/text-classification/data.ts +103 -0
  129. package/src/tasks/text-classification/inference.ts +51 -0
  130. package/src/tasks/text-classification/spec/input.json +35 -0
  131. package/src/tasks/text-classification/spec/output.json +11 -0
  132. package/src/tasks/text-generation/about.md +154 -0
  133. package/src/tasks/text-generation/data.ts +114 -0
  134. package/src/tasks/text-generation/inference.ts +200 -0
  135. package/src/tasks/text-generation/spec/input.json +219 -0
  136. package/src/tasks/text-generation/spec/output.json +179 -0
  137. package/src/tasks/text-generation/spec/stream_output.json +103 -0
  138. package/src/tasks/text-to-3d/about.md +62 -0
  139. package/src/tasks/text-to-3d/data.ts +56 -0
  140. package/src/tasks/text-to-audio/inference.ts +143 -0
  141. package/src/tasks/text-to-audio/spec/input.json +31 -0
  142. package/src/tasks/text-to-audio/spec/output.json +17 -0
  143. package/src/tasks/text-to-image/about.md +96 -0
  144. package/src/tasks/text-to-image/data.ts +100 -0
  145. package/src/tasks/text-to-image/inference.ts +75 -0
  146. package/src/tasks/text-to-image/spec/input.json +63 -0
  147. package/src/tasks/text-to-image/spec/output.json +13 -0
  148. package/src/tasks/text-to-speech/about.md +63 -0
  149. package/src/tasks/text-to-speech/data.ts +79 -0
  150. package/src/tasks/text-to-speech/inference.ts +145 -0
  151. package/src/tasks/text-to-speech/spec/input.json +31 -0
  152. package/src/tasks/text-to-speech/spec/output.json +7 -0
  153. package/src/tasks/text-to-video/about.md +41 -0
  154. package/src/tasks/text-to-video/data.ts +102 -0
  155. package/src/tasks/text2text-generation/inference.ts +55 -0
  156. package/src/tasks/text2text-generation/spec/input.json +55 -0
  157. package/src/tasks/text2text-generation/spec/output.json +14 -0
  158. package/src/tasks/token-classification/about.md +76 -0
  159. package/src/tasks/token-classification/data.ts +92 -0
  160. package/src/tasks/token-classification/inference.ts +85 -0
  161. package/src/tasks/token-classification/spec/input.json +65 -0
  162. package/src/tasks/token-classification/spec/output.json +37 -0
  163. package/src/tasks/translation/about.md +65 -0
  164. package/src/tasks/translation/data.ts +70 -0
  165. package/src/tasks/translation/inference.ts +67 -0
  166. package/src/tasks/translation/spec/input.json +50 -0
  167. package/src/tasks/translation/spec/output.json +14 -0
  168. package/src/tasks/unconditional-image-generation/about.md +50 -0
  169. package/src/tasks/unconditional-image-generation/data.ts +72 -0
  170. package/src/tasks/video-classification/about.md +37 -0
  171. package/src/tasks/video-classification/data.ts +84 -0
  172. package/src/tasks/video-classification/inference.ts +59 -0
  173. package/src/tasks/video-classification/spec/input.json +42 -0
  174. package/src/tasks/video-classification/spec/output.json +10 -0
  175. package/src/tasks/video-text-to-text/about.md +98 -0
  176. package/src/tasks/video-text-to-text/data.ts +66 -0
  177. package/src/tasks/visual-question-answering/about.md +48 -0
  178. package/src/tasks/visual-question-answering/data.ts +97 -0
  179. package/src/tasks/visual-question-answering/inference.ts +62 -0
  180. package/src/tasks/visual-question-answering/spec/input.json +41 -0
  181. package/src/tasks/visual-question-answering/spec/output.json +21 -0
  182. package/src/tasks/zero-shot-classification/about.md +40 -0
  183. package/src/tasks/zero-shot-classification/data.ts +70 -0
  184. package/src/tasks/zero-shot-classification/inference.ts +67 -0
  185. package/src/tasks/zero-shot-classification/spec/input.json +50 -0
  186. package/src/tasks/zero-shot-classification/spec/output.json +11 -0
  187. package/src/tasks/zero-shot-image-classification/about.md +75 -0
  188. package/src/tasks/zero-shot-image-classification/data.ts +84 -0
  189. package/src/tasks/zero-shot-image-classification/inference.ts +61 -0
  190. package/src/tasks/zero-shot-image-classification/spec/input.json +45 -0
  191. package/src/tasks/zero-shot-image-classification/spec/output.json +10 -0
  192. package/src/tasks/zero-shot-object-detection/about.md +45 -0
  193. package/src/tasks/zero-shot-object-detection/data.ts +67 -0
  194. package/src/tasks/zero-shot-object-detection/inference.ts +66 -0
  195. package/src/tasks/zero-shot-object-detection/spec/input.json +40 -0
  196. package/src/tasks/zero-shot-object-detection/spec/output.json +47 -0
  197. package/src/tokenizer-data.ts +32 -0
  198. package/src/widget-example.ts +125 -0
@@ -0,0 +1,1128 @@
1
+ import type { ModelData } from "./model-data.js";
2
+ import type { WidgetExampleTextInput, WidgetExampleSentenceSimilarityInput } from "./widget-example.js";
3
+ import { LIBRARY_TASK_MAPPING } from "./library-to-tasks.js";
4
+
5
+ const TAG_CUSTOM_CODE = "custom_code";
6
+
7
+ function nameWithoutNamespace(modelId: string): string {
8
+ const splitted = modelId.split("/");
9
+ return splitted.length === 1 ? splitted[0] : splitted[1];
10
+ }
11
+
12
+ const escapeStringForJson = (str: string): string => JSON.stringify(str).slice(1, -1); // slice is needed to remove surrounding quotes added by JSON.stringify
13
+
14
+ //#region snippets
15
+
16
+ export const adapters = (model: ModelData): string[] => [
17
+ `from adapters import AutoAdapterModel
18
+
19
+ model = AutoAdapterModel.from_pretrained("${model.config?.adapter_transformers?.model_name}")
20
+ model.load_adapter("${model.id}", set_active=True)`,
21
+ ];
22
+
23
+ const allennlpUnknown = (model: ModelData) => [
24
+ `import allennlp_models
25
+ from allennlp.predictors.predictor import Predictor
26
+
27
+ predictor = Predictor.from_path("hf://${model.id}")`,
28
+ ];
29
+
30
+ const allennlpQuestionAnswering = (model: ModelData) => [
31
+ `import allennlp_models
32
+ from allennlp.predictors.predictor import Predictor
33
+
34
+ predictor = Predictor.from_path("hf://${model.id}")
35
+ predictor_input = {"passage": "My name is Wolfgang and I live in Berlin", "question": "Where do I live?"}
36
+ predictions = predictor.predict_json(predictor_input)`,
37
+ ];
38
+
39
+ export const allennlp = (model: ModelData): string[] => {
40
+ if (model.tags.includes("question-answering")) {
41
+ return allennlpQuestionAnswering(model);
42
+ }
43
+ return allennlpUnknown(model);
44
+ };
45
+
46
+ export const asteroid = (model: ModelData): string[] => [
47
+ `from asteroid.models import BaseModel
48
+
49
+ model = BaseModel.from_pretrained("${model.id}")`,
50
+ ];
51
+
52
+ export const audioseal = (model: ModelData): string[] => {
53
+ const watermarkSnippet = `# Watermark Generator
54
+ from audioseal import AudioSeal
55
+
56
+ model = AudioSeal.load_generator("${model.id}")
57
+ # pass a tensor (tensor_wav) of shape (batch, channels, samples) and a sample rate
58
+ wav, sr = tensor_wav, 16000
59
+
60
+ watermark = model.get_watermark(wav, sr)
61
+ watermarked_audio = wav + watermark`;
62
+
63
+ const detectorSnippet = `# Watermark Detector
64
+ from audioseal import AudioSeal
65
+
66
+ detector = AudioSeal.load_detector("${model.id}")
67
+
68
+ result, message = detector.detect_watermark(watermarked_audio, sr)`;
69
+ return [watermarkSnippet, detectorSnippet];
70
+ };
71
+
72
+ function get_base_diffusers_model(model: ModelData): string {
73
+ return model.cardData?.base_model?.toString() ?? "fill-in-base-model";
74
+ }
75
+
76
+ function get_prompt_from_diffusers_model(model: ModelData): string | undefined {
77
+ const prompt = (model.widgetData?.[0] as WidgetExampleTextInput | undefined)?.text ?? model.cardData?.instance_prompt;
78
+ if (prompt) {
79
+ return escapeStringForJson(prompt);
80
+ }
81
+ }
82
+
83
+ export const bertopic = (model: ModelData): string[] => [
84
+ `from bertopic import BERTopic
85
+
86
+ model = BERTopic.load("${model.id}")`,
87
+ ];
88
+
89
+ export const bm25s = (model: ModelData): string[] => [
90
+ `from bm25s.hf import BM25HF
91
+
92
+ retriever = BM25HF.load_from_hub("${model.id}")`,
93
+ ];
94
+
95
+ export const depth_anything_v2 = (model: ModelData): string[] => {
96
+ let encoder: string;
97
+ let features: string;
98
+ let out_channels: string;
99
+
100
+ encoder = "<ENCODER>";
101
+ features = "<NUMBER_OF_FEATURES>";
102
+ out_channels = "<OUT_CHANNELS>";
103
+
104
+ if (model.id === "depth-anything/Depth-Anything-V2-Small") {
105
+ encoder = "vits";
106
+ features = "64";
107
+ out_channels = "[48, 96, 192, 384]";
108
+ } else if (model.id === "depth-anything/Depth-Anything-V2-Base") {
109
+ encoder = "vitb";
110
+ features = "128";
111
+ out_channels = "[96, 192, 384, 768]";
112
+ } else if (model.id === "depth-anything/Depth-Anything-V2-Large") {
113
+ encoder = "vitl";
114
+ features = "256";
115
+ out_channels = "[256, 512, 1024, 1024]";
116
+ }
117
+
118
+ return [
119
+ `
120
+ # Install from https://github.com/DepthAnything/Depth-Anything-V2
121
+
122
+ # Load the model and infer depth from an image
123
+ import cv2
124
+ import torch
125
+
126
+ from depth_anything_v2.dpt import DepthAnythingV2
127
+
128
+ # instantiate the model
129
+ model = DepthAnythingV2(encoder="${encoder}", features=${features}, out_channels=${out_channels})
130
+
131
+ # load the weights
132
+ filepath = hf_hub_download(repo_id="${model.id}", filename="depth_anything_v2_${encoder}.pth", repo_type="model")
133
+ state_dict = torch.load(filepath, map_location="cpu")
134
+ model.load_state_dict(state_dict); model.eval()
135
+
136
+ raw_img = cv2.imread("your/image/path")
137
+ depth = model.infer_image(raw_img) # HxW raw depth map in numpy
138
+ `,
139
+ ];
140
+ };
141
+
142
+ export const depth_pro = (model: ModelData): string[] => {
143
+ const installSnippet = `# Download checkpoint
144
+ pip install huggingface-hub
145
+ huggingface-cli download --local-dir checkpoints ${model.id}`;
146
+
147
+ const inferenceSnippet = `import depth_pro
148
+
149
+ # Load model and preprocessing transform
150
+ model, transform = depth_pro.create_model_and_transforms()
151
+ model.eval()
152
+
153
+ # Load and preprocess an image.
154
+ image, _, f_px = depth_pro.load_rgb("example.png")
155
+ image = transform(image)
156
+
157
+ # Run inference.
158
+ prediction = model.infer(image, f_px=f_px)
159
+
160
+ # Results: 1. Depth in meters
161
+ depth = prediction["depth"]
162
+ # Results: 2. Focal length in pixels
163
+ focallength_px = prediction["focallength_px"]`;
164
+
165
+ return [installSnippet, inferenceSnippet];
166
+ };
167
+
168
+ const diffusersDefaultPrompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k";
169
+
170
+ const diffusers_default = (model: ModelData) => [
171
+ `from diffusers import DiffusionPipeline
172
+
173
+ pipe = DiffusionPipeline.from_pretrained("${model.id}")
174
+
175
+ prompt = "${get_prompt_from_diffusers_model(model) ?? diffusersDefaultPrompt}"
176
+ image = pipe(prompt).images[0]`,
177
+ ];
178
+
179
+ const diffusers_controlnet = (model: ModelData) => [
180
+ `from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
181
+
182
+ controlnet = ControlNetModel.from_pretrained("${model.id}")
183
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
184
+ "${get_base_diffusers_model(model)}", controlnet=controlnet
185
+ )`,
186
+ ];
187
+
188
+ const diffusers_lora = (model: ModelData) => [
189
+ `from diffusers import DiffusionPipeline
190
+
191
+ pipe = DiffusionPipeline.from_pretrained("${get_base_diffusers_model(model)}")
192
+ pipe.load_lora_weights("${model.id}")
193
+
194
+ prompt = "${get_prompt_from_diffusers_model(model) ?? diffusersDefaultPrompt}"
195
+ image = pipe(prompt).images[0]`,
196
+ ];
197
+
198
+ const diffusers_textual_inversion = (model: ModelData) => [
199
+ `from diffusers import DiffusionPipeline
200
+
201
+ pipe = DiffusionPipeline.from_pretrained("${get_base_diffusers_model(model)}")
202
+ pipe.load_textual_inversion("${model.id}")`,
203
+ ];
204
+
205
+ export const diffusers = (model: ModelData): string[] => {
206
+ if (model.tags.includes("controlnet")) {
207
+ return diffusers_controlnet(model);
208
+ } else if (model.tags.includes("lora")) {
209
+ return diffusers_lora(model);
210
+ } else if (model.tags.includes("textual_inversion")) {
211
+ return diffusers_textual_inversion(model);
212
+ } else {
213
+ return diffusers_default(model);
214
+ }
215
+ };
216
+
217
+ export const diffusionkit = (model: ModelData): string[] => {
218
+ const sd3Snippet = `# Pipeline for Stable Diffusion 3
219
+ from diffusionkit.mlx import DiffusionPipeline
220
+
221
+ pipeline = DiffusionPipeline(
222
+ shift=3.0,
223
+ use_t5=False,
224
+ model_version="${model.id}",
225
+ low_memory_mode=True,
226
+ a16=True,
227
+ w16=True,
228
+ )`;
229
+
230
+ const fluxSnippet = `# Pipeline for Flux
231
+ from diffusionkit.mlx import FluxPipeline
232
+
233
+ pipeline = FluxPipeline(
234
+ shift=1.0,
235
+ model_version="${model.id}",
236
+ low_memory_mode=True,
237
+ a16=True,
238
+ w16=True,
239
+ )`;
240
+
241
+ const generateSnippet = `# Image Generation
242
+ HEIGHT = 512
243
+ WIDTH = 512
244
+ NUM_STEPS = ${model.tags.includes("flux") ? 4 : 50}
245
+ CFG_WEIGHT = ${model.tags.includes("flux") ? 0 : 5}
246
+
247
+ image, _ = pipeline.generate_image(
248
+ "a photo of a cat",
249
+ cfg_weight=CFG_WEIGHT,
250
+ num_steps=NUM_STEPS,
251
+ latent_size=(HEIGHT // 8, WIDTH // 8),
252
+ )`;
253
+
254
+ const pipelineSnippet = model.tags.includes("flux") ? fluxSnippet : sd3Snippet;
255
+
256
+ return [pipelineSnippet, generateSnippet];
257
+ };
258
+
259
+ export const cartesia_pytorch = (model: ModelData): string[] => [
260
+ `# pip install --no-binary :all: cartesia-pytorch
261
+ from cartesia_pytorch import ReneLMHeadModel
262
+ from transformers import AutoTokenizer
263
+
264
+ model = ReneLMHeadModel.from_pretrained("${model.id}")
265
+ tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-hf")
266
+
267
+ in_message = ["Rene Descartes was"]
268
+ inputs = tokenizer(in_message, return_tensors="pt")
269
+
270
+ outputs = model.generate(inputs.input_ids, max_length=50, top_k=100, top_p=0.99)
271
+ out_message = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
272
+
273
+ print(out_message)
274
+ `,
275
+ ];
276
+
277
+ export const cartesia_mlx = (model: ModelData): string[] => [
278
+ `import mlx.core as mx
279
+ import cartesia_mlx as cmx
280
+
281
+ model = cmx.from_pretrained("${model.id}")
282
+ model.set_dtype(mx.float32)
283
+
284
+ prompt = "Rene Descartes was"
285
+
286
+ for text in model.generate(
287
+ prompt,
288
+ max_tokens=500,
289
+ eval_every_n=5,
290
+ verbose=True,
291
+ top_p=0.99,
292
+ temperature=0.85,
293
+ ):
294
+ print(text, end="", flush=True)
295
+ `,
296
+ ];
297
+
298
+ export const edsnlp = (model: ModelData): string[] => {
299
+ const packageName = nameWithoutNamespace(model.id).replaceAll("-", "_");
300
+ return [
301
+ `# Load it from the Hub directly
302
+ import edsnlp
303
+ nlp = edsnlp.load("${model.id}")
304
+ `,
305
+ `# Or install it as a package
306
+ !pip install git+https://huggingface.co/${model.id}
307
+
308
+ # and import it as a module
309
+ import ${packageName}
310
+
311
+ nlp = ${packageName}.load() # or edsnlp.load("${packageName}")
312
+ `,
313
+ ];
314
+ };
315
+
316
+ export const espnetTTS = (model: ModelData): string[] => [
317
+ `from espnet2.bin.tts_inference import Text2Speech
318
+
319
+ model = Text2Speech.from_pretrained("${model.id}")
320
+
321
+ speech, *_ = model("text to generate speech from")`,
322
+ ];
323
+
324
+ export const espnetASR = (model: ModelData): string[] => [
325
+ `from espnet2.bin.asr_inference import Speech2Text
326
+
327
+ model = Speech2Text.from_pretrained(
328
+ "${model.id}"
329
+ )
330
+
331
+ speech, rate = soundfile.read("speech.wav")
332
+ text, *_ = model(speech)[0]`,
333
+ ];
334
+
335
+ const espnetUnknown = () => [`unknown model type (must be text-to-speech or automatic-speech-recognition)`];
336
+
337
+ export const espnet = (model: ModelData): string[] => {
338
+ if (model.tags.includes("text-to-speech")) {
339
+ return espnetTTS(model);
340
+ } else if (model.tags.includes("automatic-speech-recognition")) {
341
+ return espnetASR(model);
342
+ }
343
+ return espnetUnknown();
344
+ };
345
+
346
+ export const fairseq = (model: ModelData): string[] => [
347
+ `from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
348
+
349
+ models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
350
+ "${model.id}"
351
+ )`,
352
+ ];
353
+
354
+ export const flair = (model: ModelData): string[] => [
355
+ `from flair.models import SequenceTagger
356
+
357
+ tagger = SequenceTagger.load("${model.id}")`,
358
+ ];
359
+
360
+ export const gliner = (model: ModelData): string[] => [
361
+ `from gliner import GLiNER
362
+
363
+ model = GLiNER.from_pretrained("${model.id}")`,
364
+ ];
365
+
366
+ export const htrflow = (model: ModelData): string[] => [
367
+ `# CLI usage
368
+ # see docs: https://ai-riksarkivet.github.io/htrflow/latest/getting_started/quick_start.html
369
+ htrflow pipeline <path/to/pipeline.yaml> <path/to/image>`,
370
+ `# Python usage
371
+ from htrflow.pipeline.pipeline import Pipeline
372
+ from htrflow.pipeline.steps import Task
373
+ from htrflow.models.framework.model import ModelClass
374
+
375
+ pipeline = Pipeline(
376
+ [
377
+ Task(
378
+ ModelClass, {"model": "${model.id}"}, {}
379
+ ),
380
+ ])`,
381
+ ];
382
+
383
+ export const keras = (model: ModelData): string[] => [
384
+ `# Available backend options are: "jax", "torch", "tensorflow".
385
+ import os
386
+ os.environ["KERAS_BACKEND"] = "jax"
387
+
388
+ import keras
389
+
390
+ model = keras.saving.load_model("hf://${model.id}")
391
+ `,
392
+ ];
393
+
394
+ export const keras_nlp = (model: ModelData): string[] => [
395
+ `# Available backend options are: "jax", "torch", "tensorflow".
396
+ import os
397
+ os.environ["KERAS_BACKEND"] = "jax"
398
+
399
+ import keras_nlp
400
+
401
+ tokenizer = keras_nlp.models.Tokenizer.from_preset("hf://${model.id}")
402
+ backbone = keras_nlp.models.Backbone.from_preset("hf://${model.id}")
403
+ `,
404
+ ];
405
+
406
+ export const keras_hub = (model: ModelData): string[] => [
407
+ `# Available backend options are: "jax", "torch", "tensorflow".
408
+ import os
409
+ os.environ["KERAS_BACKEND"] = "jax"
410
+
411
+ import keras_hub
412
+
413
+ # Load a task-specific model (*replace CausalLM with your task*)
414
+ model = keras_hub.models.CausalLM.from_preset("hf://${model.id}", dtype="bfloat16")
415
+
416
+ # Possible tasks are CausalLM, TextToImage, ImageClassifier, ...
417
+ # full list here: https://keras.io/api/keras_hub/models/#api-documentation
418
+ `,
419
+ ];
420
+
421
+ export const llama_cpp_python = (model: ModelData): string[] => [
422
+ `from llama_cpp import Llama
423
+
424
+ llm = Llama.from_pretrained(
425
+ repo_id="${model.id}",
426
+ filename="{{GGUF_FILE}}",
427
+ )
428
+
429
+ llm.create_chat_completion(
430
+ messages = [
431
+ {
432
+ "role": "user",
433
+ "content": "What is the capital of France?"
434
+ }
435
+ ]
436
+ )`,
437
+ ];
438
+
439
+ export const tf_keras = (model: ModelData): string[] => [
440
+ `# Note: 'keras<3.x' or 'tf_keras' must be installed (legacy)
441
+ # See https://github.com/keras-team/tf-keras for more details.
442
+ from huggingface_hub import from_pretrained_keras
443
+
444
+ model = from_pretrained_keras("${model.id}")
445
+ `,
446
+ ];
447
+
448
+ export const mamba_ssm = (model: ModelData): string[] => [
449
+ `from mamba_ssm import MambaLMHeadModel
450
+
451
+ model = MambaLMHeadModel.from_pretrained("${model.id}")`,
452
+ ];
453
+
454
+ export const mars5_tts = (model: ModelData): string[] => [
455
+ `# Install from https://github.com/Camb-ai/MARS5-TTS
456
+
457
+ from inference import Mars5TTS
458
+ mars5 = Mars5TTS.from_pretrained("${model.id}")`,
459
+ ];
460
+
461
+ export const mesh_anything = (): string[] => [
462
+ `# Install from https://github.com/buaacyw/MeshAnything.git
463
+
464
+ from MeshAnything.models.meshanything import MeshAnything
465
+
466
+ # refer to https://github.com/buaacyw/MeshAnything/blob/main/main.py#L91 on how to define args
467
+ # and https://github.com/buaacyw/MeshAnything/blob/main/app.py regarding usage
468
+ model = MeshAnything(args)`,
469
+ ];
470
+
471
+ export const open_clip = (model: ModelData): string[] => [
472
+ `import open_clip
473
+
474
+ model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:${model.id}')
475
+ tokenizer = open_clip.get_tokenizer('hf-hub:${model.id}')`,
476
+ ];
477
+
478
+ export const paddlenlp = (model: ModelData): string[] => {
479
+ if (model.config?.architectures?.[0]) {
480
+ const architecture = model.config.architectures[0];
481
+ return [
482
+ [
483
+ `from paddlenlp.transformers import AutoTokenizer, ${architecture}`,
484
+ "",
485
+ `tokenizer = AutoTokenizer.from_pretrained("${model.id}", from_hf_hub=True)`,
486
+ `model = ${architecture}.from_pretrained("${model.id}", from_hf_hub=True)`,
487
+ ].join("\n"),
488
+ ];
489
+ } else {
490
+ return [
491
+ [
492
+ `# ⚠️ Type of model unknown`,
493
+ `from paddlenlp.transformers import AutoTokenizer, AutoModel`,
494
+ "",
495
+ `tokenizer = AutoTokenizer.from_pretrained("${model.id}", from_hf_hub=True)`,
496
+ `model = AutoModel.from_pretrained("${model.id}", from_hf_hub=True)`,
497
+ ].join("\n"),
498
+ ];
499
+ }
500
+ };
501
+
502
+ export const pyannote_audio_pipeline = (model: ModelData): string[] => [
503
+ `from pyannote.audio import Pipeline
504
+
505
+ pipeline = Pipeline.from_pretrained("${model.id}")
506
+
507
+ # inference on the whole file
508
+ pipeline("file.wav")
509
+
510
+ # inference on an excerpt
511
+ from pyannote.core import Segment
512
+ excerpt = Segment(start=2.0, end=5.0)
513
+
514
+ from pyannote.audio import Audio
515
+ waveform, sample_rate = Audio().crop("file.wav", excerpt)
516
+ pipeline({"waveform": waveform, "sample_rate": sample_rate})`,
517
+ ];
518
+
519
+ const pyannote_audio_model = (model: ModelData): string[] => [
520
+ `from pyannote.audio import Model, Inference
521
+
522
+ model = Model.from_pretrained("${model.id}")
523
+ inference = Inference(model)
524
+
525
+ # inference on the whole file
526
+ inference("file.wav")
527
+
528
+ # inference on an excerpt
529
+ from pyannote.core import Segment
530
+ excerpt = Segment(start=2.0, end=5.0)
531
+ inference.crop("file.wav", excerpt)`,
532
+ ];
533
+
534
+ export const pyannote_audio = (model: ModelData): string[] => {
535
+ if (model.tags.includes("pyannote-audio-pipeline")) {
536
+ return pyannote_audio_pipeline(model);
537
+ }
538
+ return pyannote_audio_model(model);
539
+ };
540
+
541
+ export const relik = (model: ModelData): string[] => [
542
+ `from relik import Relik
543
+
544
+ relik = Relik.from_pretrained("${model.id}")`,
545
+ ];
546
+
547
+ const tensorflowttsTextToMel = (model: ModelData): string[] => [
548
+ `from tensorflow_tts.inference import AutoProcessor, TFAutoModel
549
+
550
+ processor = AutoProcessor.from_pretrained("${model.id}")
551
+ model = TFAutoModel.from_pretrained("${model.id}")
552
+ `,
553
+ ];
554
+
555
+ const tensorflowttsMelToWav = (model: ModelData): string[] => [
556
+ `from tensorflow_tts.inference import TFAutoModel
557
+
558
+ model = TFAutoModel.from_pretrained("${model.id}")
559
+ audios = model.inference(mels)
560
+ `,
561
+ ];
562
+
563
+ const tensorflowttsUnknown = (model: ModelData): string[] => [
564
+ `from tensorflow_tts.inference import TFAutoModel
565
+
566
+ model = TFAutoModel.from_pretrained("${model.id}")
567
+ `,
568
+ ];
569
+
570
+ export const tensorflowtts = (model: ModelData): string[] => {
571
+ if (model.tags.includes("text-to-mel")) {
572
+ return tensorflowttsTextToMel(model);
573
+ } else if (model.tags.includes("mel-to-wav")) {
574
+ return tensorflowttsMelToWav(model);
575
+ }
576
+ return tensorflowttsUnknown(model);
577
+ };
578
+
579
+ export const timm = (model: ModelData): string[] => [
580
+ `import timm
581
+
582
+ model = timm.create_model("hf_hub:${model.id}", pretrained=True)`,
583
+ ];
584
+
585
+ export const saelens = (/* model: ModelData */): string[] => [
586
+ `# pip install sae-lens
587
+ from sae_lens import SAE
588
+
589
+ sae, cfg_dict, sparsity = SAE.from_pretrained(
590
+ release = "RELEASE_ID", # e.g., "gpt2-small-res-jb". See other options in https://github.com/jbloomAus/SAELens/blob/main/sae_lens/pretrained_saes.yaml
591
+ sae_id = "SAE_ID", # e.g., "blocks.8.hook_resid_pre". Won't always be a hook point
592
+ )`,
593
+ ];
594
+
595
+ export const seed_story = (): string[] => [
596
+ `# seed_story_cfg_path refers to 'https://github.com/TencentARC/SEED-Story/blob/master/configs/clm_models/agent_7b_sft.yaml'
597
+ # llm_cfg_path refers to 'https://github.com/TencentARC/SEED-Story/blob/master/configs/clm_models/llama2chat7b_lora.yaml'
598
+ from omegaconf import OmegaConf
599
+ import hydra
600
+
601
+ # load Llama2
602
+ llm_cfg = OmegaConf.load(llm_cfg_path)
603
+ llm = hydra.utils.instantiate(llm_cfg, torch_dtype="fp16")
604
+
605
+ # initialize seed_story
606
+ seed_story_cfg = OmegaConf.load(seed_story_cfg_path)
607
+ seed_story = hydra.utils.instantiate(seed_story_cfg, llm=llm) `,
608
+ ];
609
+
610
+ const skopsPickle = (model: ModelData, modelFile: string) => {
611
+ return [
612
+ `import joblib
613
+ from skops.hub_utils import download
614
+ download("${model.id}", "path_to_folder")
615
+ model = joblib.load(
616
+ "${modelFile}"
617
+ )
618
+ # only load pickle files from sources you trust
619
+ # read more about it here https://skops.readthedocs.io/en/stable/persistence.html`,
620
+ ];
621
+ };
622
+
623
+ const skopsFormat = (model: ModelData, modelFile: string) => {
624
+ return [
625
+ `from skops.hub_utils import download
626
+ from skops.io import load
627
+ download("${model.id}", "path_to_folder")
628
+ # make sure model file is in skops format
629
+ # if model is a pickle file, make sure it's from a source you trust
630
+ model = load("path_to_folder/${modelFile}")`,
631
+ ];
632
+ };
633
+
634
+ const skopsJobLib = (model: ModelData) => {
635
+ return [
636
+ `from huggingface_hub import hf_hub_download
637
+ import joblib
638
+ model = joblib.load(
639
+ hf_hub_download("${model.id}", "sklearn_model.joblib")
640
+ )
641
+ # only load pickle files from sources you trust
642
+ # read more about it here https://skops.readthedocs.io/en/stable/persistence.html`,
643
+ ];
644
+ };
645
+
646
+ export const sklearn = (model: ModelData): string[] => {
647
+ if (model.tags.includes("skops")) {
648
+ const skopsmodelFile = model.config?.sklearn?.model?.file;
649
+ const skopssaveFormat = model.config?.sklearn?.model_format;
650
+ if (!skopsmodelFile) {
651
+ return [`# ⚠️ Model filename not specified in config.json`];
652
+ }
653
+ if (skopssaveFormat === "pickle") {
654
+ return skopsPickle(model, skopsmodelFile);
655
+ } else {
656
+ return skopsFormat(model, skopsmodelFile);
657
+ }
658
+ } else {
659
+ return skopsJobLib(model);
660
+ }
661
+ };
662
+
663
+ export const stable_audio_tools = (model: ModelData): string[] => [
664
+ `import torch
665
+ import torchaudio
666
+ from einops import rearrange
667
+ from stable_audio_tools import get_pretrained_model
668
+ from stable_audio_tools.inference.generation import generate_diffusion_cond
669
+
670
+ device = "cuda" if torch.cuda.is_available() else "cpu"
671
+
672
+ # Download model
673
+ model, model_config = get_pretrained_model("${model.id}")
674
+ sample_rate = model_config["sample_rate"]
675
+ sample_size = model_config["sample_size"]
676
+
677
+ model = model.to(device)
678
+
679
+ # Set up text and timing conditioning
680
+ conditioning = [{
681
+ "prompt": "128 BPM tech house drum loop",
682
+ }]
683
+
684
+ # Generate stereo audio
685
+ output = generate_diffusion_cond(
686
+ model,
687
+ conditioning=conditioning,
688
+ sample_size=sample_size,
689
+ device=device
690
+ )
691
+
692
+ # Rearrange audio batch to a single sequence
693
+ output = rearrange(output, "b d n -> d (b n)")
694
+
695
+ # Peak normalize, clip, convert to int16, and save to file
696
+ output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
697
+ torchaudio.save("output.wav", output, sample_rate)`,
698
+ ];
699
+
700
+ export const fastai = (model: ModelData): string[] => [
701
+ `from huggingface_hub import from_pretrained_fastai
702
+
703
+ learn = from_pretrained_fastai("${model.id}")`,
704
+ ];
705
+
706
+ export const sam2 = (model: ModelData): string[] => {
707
+ const image_predictor = `# Use SAM2 with images
708
+ import torch
709
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
710
+
711
+ predictor = SAM2ImagePredictor.from_pretrained(${model.id})
712
+
713
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
714
+ predictor.set_image(<your_image>)
715
+ masks, _, _ = predictor.predict(<input_prompts>)`;
716
+
717
+ const video_predictor = `# Use SAM2 with videos
718
+ import torch
719
+ from sam2.sam2_video_predictor import SAM2VideoPredictor
720
+
721
+ predictor = SAM2VideoPredictor.from_pretrained(${model.id})
722
+
723
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
724
+ state = predictor.init_state(<your_video>)
725
+
726
+ # add new prompts and instantly get the output on the same frame
727
+ frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>):
728
+
729
+ # propagate the prompts to get masklets throughout the video
730
+ for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
731
+ ...`;
732
+ return [image_predictor, video_predictor];
733
+ };
734
+
735
+ export const sampleFactory = (model: ModelData): string[] => [
736
+ `python -m sample_factory.huggingface.load_from_hub -r ${model.id} -d ./train_dir`,
737
+ ];
738
+
739
+ function get_widget_examples_from_st_model(model: ModelData): string[] | undefined {
740
+ const widgetExample = model.widgetData?.[0] as WidgetExampleSentenceSimilarityInput | undefined;
741
+ if (widgetExample) {
742
+ return [widgetExample.source_sentence, ...widgetExample.sentences];
743
+ }
744
+ }
745
+
746
+ export const sentenceTransformers = (model: ModelData): string[] => {
747
+ const remote_code_snippet = model.tags.includes(TAG_CUSTOM_CODE) ? ", trust_remote_code=True" : "";
748
+ const exampleSentences = get_widget_examples_from_st_model(model) ?? [
749
+ "The weather is lovely today.",
750
+ "It's so sunny outside!",
751
+ "He drove to the stadium.",
752
+ ];
753
+
754
+ return [
755
+ `from sentence_transformers import SentenceTransformer
756
+
757
+ model = SentenceTransformer("${model.id}"${remote_code_snippet})
758
+
759
+ sentences = ${JSON.stringify(exampleSentences, null, 4)}
760
+ embeddings = model.encode(sentences)
761
+
762
+ similarities = model.similarity(embeddings, embeddings)
763
+ print(similarities.shape)
764
+ # [${exampleSentences.length}, ${exampleSentences.length}]`,
765
+ ];
766
+ };
767
+
768
+ export const setfit = (model: ModelData): string[] => [
769
+ `from setfit import SetFitModel
770
+
771
+ model = SetFitModel.from_pretrained("${model.id}")`,
772
+ ];
773
+
774
+ export const spacy = (model: ModelData): string[] => [
775
+ `!pip install https://huggingface.co/${model.id}/resolve/main/${nameWithoutNamespace(model.id)}-any-py3-none-any.whl
776
+
777
+ # Using spacy.load().
778
+ import spacy
779
+ nlp = spacy.load("${nameWithoutNamespace(model.id)}")
780
+
781
+ # Importing as module.
782
+ import ${nameWithoutNamespace(model.id)}
783
+ nlp = ${nameWithoutNamespace(model.id)}.load()`,
784
+ ];
785
+
786
+ export const span_marker = (model: ModelData): string[] => [
787
+ `from span_marker import SpanMarkerModel
788
+
789
+ model = SpanMarkerModel.from_pretrained("${model.id}")`,
790
+ ];
791
+
792
+ export const stanza = (model: ModelData): string[] => [
793
+ `import stanza
794
+
795
+ stanza.download("${nameWithoutNamespace(model.id).replace("stanza-", "")}")
796
+ nlp = stanza.Pipeline("${nameWithoutNamespace(model.id).replace("stanza-", "")}")`,
797
+ ];
798
+
799
+ const speechBrainMethod = (speechbrainInterface: string) => {
800
+ switch (speechbrainInterface) {
801
+ case "EncoderClassifier":
802
+ return "classify_file";
803
+ case "EncoderDecoderASR":
804
+ case "EncoderASR":
805
+ return "transcribe_file";
806
+ case "SpectralMaskEnhancement":
807
+ return "enhance_file";
808
+ case "SepformerSeparation":
809
+ return "separate_file";
810
+ default:
811
+ return undefined;
812
+ }
813
+ };
814
+
815
+ export const speechbrain = (model: ModelData): string[] => {
816
+ const speechbrainInterface = model.config?.speechbrain?.speechbrain_interface;
817
+ if (speechbrainInterface === undefined) {
818
+ return [`# interface not specified in config.json`];
819
+ }
820
+
821
+ const speechbrainMethod = speechBrainMethod(speechbrainInterface);
822
+ if (speechbrainMethod === undefined) {
823
+ return [`# interface in config.json invalid`];
824
+ }
825
+
826
+ return [
827
+ `from speechbrain.pretrained import ${speechbrainInterface}
828
+ model = ${speechbrainInterface}.from_hparams(
829
+ "${model.id}"
830
+ )
831
+ model.${speechbrainMethod}("file.wav")`,
832
+ ];
833
+ };
834
+
835
+ export const transformers = (model: ModelData): string[] => {
836
+ const info = model.transformersInfo;
837
+ if (!info) {
838
+ return [`# ⚠️ Type of model unknown`];
839
+ }
840
+ const remote_code_snippet = model.tags.includes(TAG_CUSTOM_CODE) ? ", trust_remote_code=True" : "";
841
+
842
+ let autoSnippet: string;
843
+ if (info.processor) {
844
+ const varName =
845
+ info.processor === "AutoTokenizer"
846
+ ? "tokenizer"
847
+ : info.processor === "AutoFeatureExtractor"
848
+ ? "extractor"
849
+ : "processor";
850
+ autoSnippet = [
851
+ "# Load model directly",
852
+ `from transformers import ${info.processor}, ${info.auto_model}`,
853
+ "",
854
+ `${varName} = ${info.processor}.from_pretrained("${model.id}"` + remote_code_snippet + ")",
855
+ `model = ${info.auto_model}.from_pretrained("${model.id}"` + remote_code_snippet + ")",
856
+ ].join("\n");
857
+ } else {
858
+ autoSnippet = [
859
+ "# Load model directly",
860
+ `from transformers import ${info.auto_model}`,
861
+ `model = ${info.auto_model}.from_pretrained("${model.id}"` + remote_code_snippet + ")",
862
+ ].join("\n");
863
+ }
864
+
865
+ if (model.pipeline_tag && LIBRARY_TASK_MAPPING.transformers?.includes(model.pipeline_tag)) {
866
+ const pipelineSnippet = ["# Use a pipeline as a high-level helper", "from transformers import pipeline", ""];
867
+
868
+ if (model.tags.includes("conversational") && model.config?.tokenizer_config?.chat_template) {
869
+ pipelineSnippet.push("messages = [", ' {"role": "user", "content": "Who are you?"},', "]");
870
+ }
871
+ pipelineSnippet.push(`pipe = pipeline("${model.pipeline_tag}", model="${model.id}"` + remote_code_snippet + ")");
872
+ if (model.tags.includes("conversational") && model.config?.tokenizer_config?.chat_template) {
873
+ pipelineSnippet.push("pipe(messages)");
874
+ }
875
+
876
+ return [pipelineSnippet.join("\n"), autoSnippet];
877
+ }
878
+ return [autoSnippet];
879
+ };
880
+
881
+ export const transformersJS = (model: ModelData): string[] => {
882
+ if (!model.pipeline_tag) {
883
+ return [`// ⚠️ Unknown pipeline tag`];
884
+ }
885
+
886
+ const libName = "@huggingface/transformers";
887
+
888
+ return [
889
+ `// npm i ${libName}
890
+ import { pipeline } from '${libName}';
891
+
892
+ // Allocate pipeline
893
+ const pipe = await pipeline('${model.pipeline_tag}', '${model.id}');`,
894
+ ];
895
+ };
896
+
897
+ const peftTask = (peftTaskType?: string) => {
898
+ switch (peftTaskType) {
899
+ case "CAUSAL_LM":
900
+ return "CausalLM";
901
+ case "SEQ_2_SEQ_LM":
902
+ return "Seq2SeqLM";
903
+ case "TOKEN_CLS":
904
+ return "TokenClassification";
905
+ case "SEQ_CLS":
906
+ return "SequenceClassification";
907
+ default:
908
+ return undefined;
909
+ }
910
+ };
911
+
912
+ export const peft = (model: ModelData): string[] => {
913
+ const { base_model_name_or_path: peftBaseModel, task_type: peftTaskType } = model.config?.peft ?? {};
914
+ const pefttask = peftTask(peftTaskType);
915
+ if (!pefttask) {
916
+ return [`Task type is invalid.`];
917
+ }
918
+ if (!peftBaseModel) {
919
+ return [`Base model is not found.`];
920
+ }
921
+
922
+ return [
923
+ `from peft import PeftModel, PeftConfig
924
+ from transformers import AutoModelFor${pefttask}
925
+
926
+ config = PeftConfig.from_pretrained("${model.id}")
927
+ base_model = AutoModelFor${pefttask}.from_pretrained("${peftBaseModel}")
928
+ model = PeftModel.from_pretrained(base_model, "${model.id}")`,
929
+ ];
930
+ };
931
+
932
+ export const fasttext = (model: ModelData): string[] => [
933
+ `from huggingface_hub import hf_hub_download
934
+ import fasttext
935
+
936
+ model = fasttext.load_model(hf_hub_download("${model.id}", "model.bin"))`,
937
+ ];
938
+
939
+ export const stableBaselines3 = (model: ModelData): string[] => [
940
+ `from huggingface_sb3 import load_from_hub
941
+ checkpoint = load_from_hub(
942
+ repo_id="${model.id}",
943
+ filename="{MODEL FILENAME}.zip",
944
+ )`,
945
+ ];
946
+
947
+ const nemoDomainResolver = (domain: string, model: ModelData): string[] | undefined => {
948
+ switch (domain) {
949
+ case "ASR":
950
+ return [
951
+ `import nemo.collections.asr as nemo_asr
952
+ asr_model = nemo_asr.models.ASRModel.from_pretrained("${model.id}")
953
+
954
+ transcriptions = asr_model.transcribe(["file.wav"])`,
955
+ ];
956
+ default:
957
+ return undefined;
958
+ }
959
+ };
960
+
961
+ export const mlAgents = (model: ModelData): string[] => [
962
+ `mlagents-load-from-hf --repo-id="${model.id}" --local-dir="./download: string[]s"`,
963
+ ];
964
+
965
+ export const sentis = (/* model: ModelData */): string[] => [
966
+ `string modelName = "[Your model name here].sentis";
967
+ Model model = ModelLoader.Load(Application.streamingAssetsPath + "/" + modelName);
968
+ IWorker engine = WorkerFactory.CreateWorker(BackendType.GPUCompute, model);
969
+ // Please see provided C# file for more details
970
+ `,
971
+ ];
972
+
973
+ export const vfimamba = (model: ModelData): string[] => [
974
+ `from Trainer_finetune import Model
975
+
976
+ model = Model.from_pretrained("${model.id}")`,
977
+ ];
978
+
979
+ export const voicecraft = (model: ModelData): string[] => [
980
+ `from voicecraft import VoiceCraft
981
+
982
+ model = VoiceCraft.from_pretrained("${model.id}")`,
983
+ ];
984
+
985
+ export const chattts = (): string[] => [
986
+ `import ChatTTS
987
+ import torchaudio
988
+
989
+ chat = ChatTTS.Chat()
990
+ chat.load_models(compile=False) # Set to True for better performance
991
+
992
+ texts = ["PUT YOUR TEXT HERE",]
993
+
994
+ wavs = chat.infer(texts, )
995
+
996
+ torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)`,
997
+ ];
998
+
999
+ export const yolov10 = (model: ModelData): string[] => [
1000
+ `from ultralytics import YOLOv10
1001
+
1002
+ model = YOLOv10.from_pretrained("${model.id}")
1003
+ source = 'http://images.cocodataset.org/val2017/000000039769.jpg'
1004
+ model.predict(source=source, save=True)
1005
+ `,
1006
+ ];
1007
+
1008
+ export const birefnet = (model: ModelData): string[] => [
1009
+ `# Option 1: use with transformers
1010
+
1011
+ from transformers import AutoModelForImageSegmentation
1012
+ birefnet = AutoModelForImageSegmentation.from_pretrained("${model.id}", trust_remote_code=True)
1013
+ `,
1014
+ `# Option 2: use with BiRefNet
1015
+
1016
+ # Install from https://github.com/ZhengPeng7/BiRefNet
1017
+
1018
+ from models.birefnet import BiRefNet
1019
+ model = BiRefNet.from_pretrained("${model.id}")`,
1020
+ ];
1021
+
1022
+ export const mlx = (model: ModelData): string[] => [
1023
+ `pip install huggingface_hub hf_transfer
1024
+
1025
+ export HF_HUB_ENABLE_HF_TRANS: string[]FER=1
1026
+ huggingface-cli download --local-dir ${nameWithoutNamespace(model.id)} ${model.id}`,
1027
+ ];
1028
+
1029
+ export const mlxim = (model: ModelData): string[] => [
1030
+ `from mlxim.model import create_model
1031
+
1032
+ model = create_model(${model.id})`,
1033
+ ];
1034
+
1035
+ export const model2vec = (model: ModelData): string[] => [
1036
+ `from model2vec import StaticModel
1037
+
1038
+ model = StaticModel.from_pretrained("${model.id}")`,
1039
+ ];
1040
+
1041
+ export const nemo = (model: ModelData): string[] => {
1042
+ let command: string[] | undefined = undefined;
1043
+ // Resolve the tag to a nemo domain/sub-domain
1044
+ if (model.tags.includes("automatic-speech-recognition")) {
1045
+ command = nemoDomainResolver("ASR", model);
1046
+ }
1047
+
1048
+ return command ?? [`# tag did not correspond to a valid NeMo domain.`];
1049
+ };
1050
+
1051
+ export const pxia = (model: ModelData): string[] => [
1052
+ `from pxia import AutoModel
1053
+
1054
+ model = AutoModel.from_pretrained("${model.id}")`,
1055
+ ];
1056
+
1057
+ export const pythae = (model: ModelData): string[] => [
1058
+ `from pythae.models import AutoModel
1059
+
1060
+ model = AutoModel.load_from_hf_hub("${model.id}")`,
1061
+ ];
1062
+
1063
+ const musicgen = (model: ModelData): string[] => [
1064
+ `from audiocraft.models import MusicGen
1065
+
1066
+ model = MusicGen.get_pretrained("${model.id}")
1067
+
1068
+ descriptions = ['happy rock', 'energetic EDM', 'sad jazz']
1069
+ wav = model.generate(descriptions) # generates 3 samples.`,
1070
+ ];
1071
+
1072
+ const magnet = (model: ModelData): string[] => [
1073
+ `from audiocraft.models import MAGNeT
1074
+
1075
+ model = MAGNeT.get_pretrained("${model.id}")
1076
+
1077
+ descriptions = ['disco beat', 'energetic EDM', 'funky groove']
1078
+ wav = model.generate(descriptions) # generates 3 samples.`,
1079
+ ];
1080
+
1081
+ const audiogen = (model: ModelData): string[] => [
1082
+ `from audiocraft.models import AudioGen
1083
+
1084
+ model = AudioGen.get_pretrained("${model.id}")
1085
+ model.set_generation_params(duration=5) # generate 5 seconds.
1086
+ descriptions = ['dog barking', 'sirene of an emergency vehicle', 'footsteps in a corridor']
1087
+ wav = model.generate(descriptions) # generates 3 samples.`,
1088
+ ];
1089
+
1090
+ export const audiocraft = (model: ModelData): string[] => {
1091
+ if (model.tags.includes("musicgen")) {
1092
+ return musicgen(model);
1093
+ } else if (model.tags.includes("audiogen")) {
1094
+ return audiogen(model);
1095
+ } else if (model.tags.includes("magnet")) {
1096
+ return magnet(model);
1097
+ } else {
1098
+ return [`# Type of model unknown.`];
1099
+ }
1100
+ };
1101
+
1102
+ export const whisperkit = (): string[] => [
1103
+ `# Install CLI with Homebrew on macOS device
1104
+ brew install whisperkit-cli
1105
+
1106
+ # View all available inference options
1107
+ whisperkit-cli transcribe --help
1108
+
1109
+ # Download and run inference using whisper base model
1110
+ whisperkit-cli transcribe --audio-path /path/to/audio.mp3
1111
+
1112
+ # Or use your preferred model variant
1113
+ whisperkit-cli transcribe --model "large-v3" --model-prefix "distil" --audio-path /path/to/audio.mp3 --verbose`,
1114
+ ];
1115
+
1116
+ export const threedtopia_xl = (model: ModelData): string[] => [
1117
+ `from threedtopia_xl.models import threedtopia_xl
1118
+
1119
+ model = threedtopia_xl.from_pretrained("${model.id}")
1120
+ model.generate(cond="path/to/image.png")`,
1121
+ ];
1122
+
1123
+ export const hezar = (model: ModelData): string[] => [
1124
+ `from hezar import Model
1125
+
1126
+ model = Model.load("${model.id}")`,
1127
+ ];
1128
+ //#endregion