pixeltable 0.4.17__py3-none-any.whl → 0.4.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pixeltable might be problematic.
- pixeltable/catalog/catalog.py +26 -19
- pixeltable/catalog/table.py +33 -14
- pixeltable/catalog/table_version.py +16 -12
- pixeltable/dataframe.py +1 -1
- pixeltable/env.py +4 -0
- pixeltable/exec/exec_context.py +15 -2
- pixeltable/exec/sql_node.py +3 -2
- pixeltable/functions/huggingface.py +1031 -2
- pixeltable/functions/video.py +34 -7
- pixeltable/globals.py +23 -4
- pixeltable/iterators/document.py +88 -57
- pixeltable/iterators/video.py +58 -24
- pixeltable/plan.py +2 -6
- pixeltable/store.py +24 -3
- pixeltable/utils/av.py +66 -38
- {pixeltable-0.4.17.dist-info → pixeltable-0.4.18.dist-info}/METADATA +4 -4
- {pixeltable-0.4.17.dist-info → pixeltable-0.4.18.dist-info}/RECORD +20 -20
- {pixeltable-0.4.17.dist-info → pixeltable-0.4.18.dist-info}/WHEEL +0 -0
- {pixeltable-0.4.17.dist-info → pixeltable-0.4.18.dist-info}/entry_points.txt +0 -0
- {pixeltable-0.4.17.dist-info → pixeltable-0.4.18.dist-info}/licenses/LICENSE +0 -0
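The bulk of this release is the expansion of pixeltable/functions/huggingface.py (+1031 -2), which adds generic Hugging Face UDFs: text_generation, text_classification, image_captioning, summarization, token_classification, question_answering, translation, text_to_image, text_to_speech, image_to_image, automatic_speech_recognition, and image_to_video. A minimal usage sketch, mirroring the docstring examples in the diff below (the table and column names are hypothetical; the import path assumes the UDFs are exposed via pixeltable.functions.huggingface):

import pixeltable as pxt
from pixeltable.functions.huggingface import text_generation

# Hypothetical table with a single string column.
tbl = pxt.create_table('prompts', {'prompt': pxt.String})

# Computed column that runs the new text_generation UDF on every inserted row.
tbl.add_computed_column(
    completion=text_generation(
        tbl.prompt,
        model_id='Qwen/Qwen3-0.6B',
        model_kwargs={'temperature': 0.5, 'max_length': 150},
    )
)

tbl.insert([{'prompt': 'Pixeltable is'}])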
pixeltable/functions/huggingface.py

@@ -7,8 +7,10 @@ first `pip install transformers` (or in some cases, `sentence-transformers`, as
 UDFs).
 """

-from typing import Any, Callable, Optional, TypeVar
+from typing import Any, Callable, Literal, Optional, TypeVar

+import av
+import numpy as np
 import PIL.Image

 import pixeltable as pxt
@@ -18,6 +20,9 @@ from pixeltable import env
 from pixeltable.func import Batch
 from pixeltable.functions.util import normalize_image_mode, resolve_torch_device
 from pixeltable.utils.code import local_public_names
+from pixeltable.utils.local_store import TempStore
+
+T = TypeVar('T')


 @pxt.udf(batch_size=32)
@@ -454,7 +459,1031 @@ def detr_to_coco(image: PIL.Image.Image, detr_info: dict[str, Any]) -> dict[str,
     return {'image': {'width': image.width, 'height': image.height}, 'annotations': annotations}


-
+@pxt.udf
+def text_generation(text: str, *, model_id: str, model_kwargs: Optional[dict[str, Any]] = None) -> str:
+    """
+    Generates text using a pretrained language model. `model_id` should be a reference to a pretrained
+    [text generation model](https://huggingface.co/models?pipeline_tag=text-generation).
+
+    __Requirements:__
+
+    - `pip install torch transformers`
+
+    Args:
+        text: The input text to continue/complete.
+        model_id: The pretrained model to use for text generation.
+        model_kwargs: Additional keyword arguments to pass to the model's `generate` method, such as `max_length`,
+            `temperature`, etc. See the
+            [Hugging Face text_generation documentation](https://huggingface.co/docs/inference-providers/en/tasks/text-generation)
+            for details.
+
+    Returns:
+        The generated text completion.
+
+    Examples:
+        Add a computed column that generates text completions using the `Qwen/Qwen3-0.6B` model:
+
+        >>> tbl.add_computed_column(completion=text_generation(
+        ...     tbl.prompt,
+        ...     model_id='Qwen/Qwen3-0.6B',
+        ...     model_kwargs={'temperature': 0.5, 'max_length': 150}
+        ... ))
+    """
+    env.Env.get().require_package('transformers')
+    device = resolve_torch_device('auto')
+    import torch
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    model = _lookup_model(model_id, AutoModelForCausalLM.from_pretrained, device=device)
+    tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
+
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    with torch.no_grad():
+        inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
+        outputs = model.generate(**inputs.to(device), pad_token_id=tokenizer.eos_token_id, **model_kwargs)
+
+    input_length = len(inputs['input_ids'][0])
+    generated_text = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
+    return generated_text
+
+
+@pxt.udf(batch_size=16)
+def text_classification(text: Batch[str], *, model_id: str, top_k: int = 5) -> Batch[list[dict[str, Any]]]:
+    """
+    Classifies text using a pretrained classification model. `model_id` should be a reference to a pretrained
+    [text classification model](https://huggingface.co/models?pipeline_tag=text-classification)
+    such as BERT, RoBERTa, or DistilBERT.
+
+    __Requirements:__
+
+    - `pip install torch transformers`
+
+    Args:
+        text: The text to classify.
+        model_id: The pretrained model to use for classification.
+        top_k: The number of top predictions to return.
+
+    Returns:
+        A dictionary containing classification results with scores, labels, and label text.
+
+    Examples:
+        Add a computed column for sentiment analysis:
+
+        >>> tbl.add_computed_column(sentiment=text_classification(
+        ...     tbl.review_text,
+        ...     model_id='cardiffnlp/twitter-roberta-base-sentiment-latest'
+        ... ))
+    """
+    env.Env.get().require_package('transformers')
+    device = resolve_torch_device('auto')
+    import torch
+    from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+    model = _lookup_model(model_id, AutoModelForSequenceClassification.from_pretrained, device=device)
+    tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
+
+    with torch.no_grad():
+        inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
+        outputs = model(**inputs.to(device))
+        logits = outputs.logits
+
+    probs = torch.softmax(logits, dim=-1)
+    top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
+
+    results = []
+    for i in range(len(text)):
+        # Return as list of individual classification items for HuggingFace compatibility
+        classification_items = []
+        for k in range(top_k_probs.shape[1]):
+            classification_items.append(
+                {
+                    'label': top_k_indices[i, k].item(),
+                    'label_text': model.config.id2label[top_k_indices[i, k].item()],
+                    'score': top_k_probs[i, k].item(),
+                }
+            )
+        results.append(classification_items)
+
+    return results
+
+
+@pxt.udf(batch_size=4)
+def image_captioning(
+    image: Batch[PIL.Image.Image], *, model_id: str, model_kwargs: Optional[dict[str, Any]] = None
+) -> Batch[str]:
+    """
+    Generates captions for images using a pretrained image captioning model. `model_id` should be a reference to a
+    pretrained [image-to-text model](https://huggingface.co/models?pipeline_tag=image-to-text) such as BLIP,
+    Git, or LLaVA.
+
+    __Requirements:__
+
+    - `pip install torch transformers`
+
+    Args:
+        image: The image to caption.
+        model_id: The pretrained model to use for captioning.
+        model_kwargs: Additional keyword arguments to pass to the model's `generate` method, such as `max_length`.
+
+    Returns:
+        The generated caption text.
+
+    Examples:
+        Add a computed column `caption` to an existing table `tbl` that generates captions using the
+        `Salesforce/blip-image-captioning-base` model:
+
+        >>> tbl.add_computed_column(caption=image_captioning(
+        ...     tbl.image,
+        ...     model_id='Salesforce/blip-image-captioning-base',
+        ...     model_kwargs={'max_length': 30}
+        ... ))
+    """
+    env.Env.get().require_package('transformers')
+    device = resolve_torch_device('auto')
+    import torch
+    from transformers import AutoModelForVision2Seq, AutoProcessor
+
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    model = _lookup_model(model_id, AutoModelForVision2Seq.from_pretrained, device=device)
+    processor = _lookup_processor(model_id, AutoProcessor.from_pretrained)
+    normalized_images = [normalize_image_mode(img) for img in image]
+
+    with torch.no_grad():
+        inputs = processor(images=normalized_images, return_tensors='pt')
+        outputs = model.generate(**inputs.to(device), **model_kwargs)
+
+    captions = processor.batch_decode(outputs, skip_special_tokens=True)
+    return captions
+
+
+@pxt.udf(batch_size=8)
+def summarization(text: Batch[str], *, model_id: str, model_kwargs: Optional[dict[str, Any]] = None) -> Batch[str]:
+    """
+    Summarizes text using a pretrained summarization model. `model_id` should be a reference to a pretrained
+    [summarization model](https://huggingface.co/models?pipeline_tag=summarization) such as BART, T5, or Pegasus.
+
+    __Requirements:__
+
+    - `pip install torch transformers`
+
+    Args:
+        text: The text to summarize.
+        model_id: The pretrained model to use for summarization.
+        model_kwargs: Additional keyword arguments to pass to the model's `generate` method, such as `max_length`.
+
+    Returns:
+        The generated summary text.
+
+    Examples:
+        Add a computed column that summarizes documents:
+
+        >>> tbl.add_computed_column(summary=summarization(
+        ...     tbl.document_text,
+        ...     model_id='facebook/bart-large-cnn',
+        ...     model_kwargs={'max_length': 100}
+        ... ))
+    """
+    env.Env.get().require_package('transformers')
+    device = resolve_torch_device('auto')
+    import torch
+    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    model = _lookup_model(model_id, AutoModelForSeq2SeqLM.from_pretrained, device=device)
+    tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
+
+    with torch.no_grad():
+        inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
+        outputs = model.generate(**inputs.to(device), **model_kwargs)
+
+    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+
+@pxt.udf
+def token_classification(
+    text: str, *, model_id: str, aggregation_strategy: Literal['simple', 'first', 'average', 'max'] = 'simple'
+) -> list[dict[str, Any]]:
+    """
+    Extracts named entities from text using a pretrained named entity recognition (NER) model.
+    `model_id` should be a reference to a pretrained
+    [token classification model](https://huggingface.co/models?pipeline_tag=token-classification) for NER.
+
+    __Requirements:__
+
+    - `pip install torch transformers`
+
+    Args:
+        text: The text to analyze for named entities.
+        model_id: The pretrained model to use.
+        aggregation_strategy: Method used to aggregate tokens.
+
+    Returns:
+        A list of dictionaries containing entity information (text, label, confidence, start, end).
+
+    Examples:
+        Add a computed column that extracts named entities:
+
+        >>> tbl.add_computed_column(entities=token_classification(
+        ...     tbl.text,
+        ...     model_id='dbmdz/bert-large-cased-finetuned-conll03-english'
+        ... ))
+    """
+    env.Env.get().require_package('transformers')
+    device = resolve_torch_device('auto')
+    import torch
+    from transformers import AutoModelForTokenClassification, AutoTokenizer
+
+    # Follow direct model loading pattern like other best practice functions
+    model = _lookup_model(model_id, AutoModelForTokenClassification.from_pretrained, device=device)
+    tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
+
+    # Validate aggregation strategy
+    valid_strategies = {'simple', 'first', 'average', 'max'}
+    if aggregation_strategy not in valid_strategies:
+        raise excs.Error(
+            f'Invalid aggregation_strategy {aggregation_strategy!r}. Must be one of: {", ".join(valid_strategies)}'
+        )
+
+    with torch.no_grad():
+        # Tokenize with special tokens and return offsets for entity extraction
+        inputs = tokenizer(
+            text,
+            return_tensors='pt',
+            truncation=True,
+            max_length=512,
+            return_offsets_mapping=True,
+            add_special_tokens=True,
+        )
+
+        # Get model predictions
+        outputs = model(**{k: v.to(device) for k, v in inputs.items() if k != 'offset_mapping'})
+        predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
+
+        # Get the predicted labels and confidence scores
+        predicted_token_classes = predictions.argmax(dim=-1).squeeze().tolist()
+        confidence_scores = predictions.max(dim=-1).values.squeeze().tolist()
+
+    # Handle single token case
+    if not isinstance(predicted_token_classes, list):
+        predicted_token_classes = [predicted_token_classes]
+        confidence_scores = [confidence_scores]
+
+    # Extract entities from predictions
+    entities = []
+    offset_mapping = inputs['offset_mapping'][0].tolist()
+
+    current_entity = None
+
+    for token_class, confidence, (start_offset, end_offset) in zip(
+        predicted_token_classes, confidence_scores, offset_mapping
+    ):
+        # Skip special tokens (offset is (0, 0))
+        if start_offset == 0 and end_offset == 0:
+            continue
+
+        label = model.config.id2label[token_class]
+
+        # Skip 'O' (outside) labels
+        if label == 'O':
+            if current_entity:
+                entities.append(current_entity)
+                current_entity = None
+            continue
+
+        # Parse BIO/BILOU tags
+        if label.startswith('B-') or (label.startswith('I-') and current_entity is None):
+            # Begin new entity
+            if current_entity:
+                entities.append(current_entity)
+
+            entity_type = label[2:] if label.startswith(('B-', 'I-')) else label
+            current_entity = {
+                'word': text[start_offset:end_offset],
+                'entity_group': entity_type,
+                'score': float(confidence),
+                'start': start_offset,
+                'end': end_offset,
+            }
+
+        elif label.startswith('I-') and current_entity:
+            # Continue current entity
+            entity_type = label[2:]
+            if current_entity['entity_group'] == entity_type:
+                # Extend the current entity
+                current_entity['word'] = text[current_entity['start'] : end_offset]
+                current_entity['end'] = end_offset
+
+                # Update confidence based on aggregation strategy
+                if aggregation_strategy == 'average':
+                    # Simple average (could be improved with token count weighting)
+                    current_entity['score'] = (current_entity['score'] + float(confidence)) / 2
+                elif aggregation_strategy == 'max':
+                    current_entity['score'] = max(current_entity['score'], float(confidence))
+                elif aggregation_strategy == 'first':
+                    pass  # Keep first confidence
+                # 'simple' uses the same logic as 'first'
+            else:
+                # Different entity type, start new entity
+                entities.append(current_entity)
+                current_entity = {
+                    'word': text[start_offset:end_offset],
+                    'entity_group': entity_type,
+                    'score': float(confidence),
+                    'start': start_offset,
+                    'end': end_offset,
+                }
+
+    # Don't forget the last entity
+    if current_entity:
+        entities.append(current_entity)
+
+    return entities
+
+
+@pxt.udf
+def question_answering(context: str, question: str, *, model_id: str) -> dict[str, Any]:
+    """
+    Answers questions based on provided context using a pretrained QA model. `model_id` should be a reference to a
+    pretrained [question answering model](https://huggingface.co/models?pipeline_tag=question-answering) such as
+    BERT or RoBERTa.
+
+    __Requirements:__
+
+    - `pip install torch transformers`
+
+    Args:
+        context: The context text containing the answer.
+        question: The question to answer.
+        model_id: The pretrained QA model to use.
+
+    Returns:
+        A dictionary containing the answer, confidence score, and start/end positions.
+
+    Examples:
+        Add a computed column that answers questions based on document context:
+
+        >>> tbl.add_computed_column(answer=question_answering(
+        ...     tbl.document_text,
+        ...     tbl.question,
+        ...     model_id='deepset/roberta-base-squad2'
+        ... ))
+    """
+    env.Env.get().require_package('transformers')
+    device = resolve_torch_device('auto')
+    import torch
+    from transformers import AutoModelForQuestionAnswering, AutoTokenizer
+
+    model = _lookup_model(model_id, AutoModelForQuestionAnswering.from_pretrained, device=device)
+    tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
+
+    with torch.no_grad():
+        # Tokenize the question and context
+        inputs = tokenizer.encode_plus(
+            question, context, add_special_tokens=True, return_tensors='pt', truncation=True, max_length=512
+        )
+
+        # Get model predictions
+        outputs = model(**inputs.to(device))
+        start_scores = outputs.start_logits
+        end_scores = outputs.end_logits
+
+        # Find the tokens with the highest start and end scores
+        start_idx = torch.argmax(start_scores)
+        end_idx = torch.argmax(end_scores)
+
+        # Ensure end_idx >= start_idx
+        end_idx = torch.max(end_idx, start_idx)
+
+        # Convert token positions to string
+        input_ids = inputs['input_ids'][0]
+
+        # Extract answer tokens
+        answer_tokens = input_ids[start_idx : end_idx + 1]
+        answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
+
+        # Calculate confidence score
+        start_probs = torch.softmax(start_scores, dim=1)
+        end_probs = torch.softmax(end_scores, dim=1)
+        confidence = float(start_probs[0][start_idx] * end_probs[0][end_idx])
+
+    return {'answer': answer.strip(), 'score': confidence, 'start': int(start_idx), 'end': int(end_idx)}
+
+
+@pxt.udf(batch_size=8)
+def translation(
+    text: Batch[str], *, model_id: str, src_lang: Optional[str] = None, target_lang: Optional[str] = None
+) -> Batch[str]:
+    """
+    Translates text using a pretrained translation model. `model_id` should be a reference to a pretrained
+    [translation model](https://huggingface.co/models?pipeline_tag=translation) such as MarianMT or T5.
+
+    __Requirements:__
+
+    - `pip install torch transformers sentencepiece`
+
+    Args:
+        text: The text to translate.
+        model_id: The pretrained translation model to use.
+        src_lang: Source language code (optional, can be inferred from model).
+        target_lang: Target language code (optional, can be inferred from model).
+
+    Returns:
+        The translated text.
+
+    Examples:
+        Add a computed column that translates text:
+
+        >>> tbl.add_computed_column(french_text=translation(
+        ...     tbl.english_text,
+        ...     model_id='Helsinki-NLP/opus-mt-en-fr',
+        ...     src_lang='en',
+        ...     target_lang='fr'
+        ... ))
+    """
+    env.Env.get().require_package('transformers')
+    device = resolve_torch_device('auto')
+    import torch
+    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+    model = _lookup_model(model_id, AutoModelForSeq2SeqLM.from_pretrained, device=device)
+    tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
+    lang_code_to_id: dict | None = getattr(tokenizer, 'lang_code_to_id', {})
+
+    # Language validation - following speech2text_for_conditional_generation pattern
+    if src_lang is not None and src_lang not in lang_code_to_id:
+        raise excs.Error(
+            f'Source language code {src_lang!r} is not supported by the model {model_id!r}. '
+            f'Supported languages are: {list(lang_code_to_id.keys())}'
+        )
+
+    if target_lang is not None and target_lang not in lang_code_to_id:
+        raise excs.Error(
+            f'Target language code {target_lang!r} is not supported by the model {model_id!r}. '
+            f'Supported languages are: {list(lang_code_to_id.keys())}'
+        )
+
+    with torch.no_grad():
+        inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
+
+        # Set forced_bos_token_id for target language if supported
+        generate_kwargs = {'max_length': 512, 'num_beams': 4, 'early_stopping': True}
+
+        if target_lang is not None:
+            generate_kwargs['forced_bos_token_id'] = lang_code_to_id[target_lang]
+
+        outputs = model.generate(**inputs.to(device), **generate_kwargs)
+
+    # Decode all outputs at once
+    translations = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+    return translations
+
+
+@pxt.udf
+def text_to_image(
+    prompt: str,
+    *,
+    model_id: str,
+    height: int = 512,
+    width: int = 512,
+    seed: Optional[int] = None,
+    model_kwargs: Optional[dict[str, Any]] = None,
+) -> PIL.Image.Image:
+    """
+    Generates images from text prompts using a pretrained text-to-image model. `model_id` should be a reference to a
+    pretrained [text-to-image model](https://huggingface.co/models?pipeline_tag=text-to-image) such as
+    Stable Diffusion or FLUX.
+
+    __Requirements:__
+
+    - `pip install torch transformers diffusers accelerate`
+
+    Args:
+        prompt: The text prompt describing the desired image.
+        model_id: The pretrained text-to-image model to use.
+        height: Height of the generated image in pixels.
+        width: Width of the generated image in pixels.
+        seed: Optional random seed for reproducibility.
+        model_kwargs: Additional keyword arguments to pass to the model, such as `num_inference_steps`,
+            `guidance_scale`, or `negative_prompt`.
+
+    Returns:
+        The generated Image.
+
+    Examples:
+        Add a computed column that generates images from text prompts:
+
+        >>> tbl.add_computed_column(generated_image=text_to_image(
+        ...     tbl.prompt,
+        ...     model_id='stable-diffusion-v1.5/stable-diffusion-v1-5',
+        ...     height=512,
+        ...     width=512,
+        ...     model_kwargs={'num_inference_steps': 25},
+        ... ))
+    """
+    env.Env.get().require_package('transformers')
+    env.Env.get().require_package('diffusers')
+    env.Env.get().require_package('accelerate')
+    device = resolve_torch_device('auto', allow_mps=False)
+    import torch
+    from diffusers import AutoPipelineForText2Image
+
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    # Parameter validation - following best practices pattern
+    if height <= 0 or width <= 0:
+        raise excs.Error(f'Height ({height}) and width ({width}) must be positive integers')
+
+    if height % 8 != 0 or width % 8 != 0:
+        raise excs.Error(f'Height ({height}) and width ({width}) must be divisible by 8 for most diffusion models')
+
+    pipeline = _lookup_model(
+        model_id,
+        lambda x: AutoPipelineForText2Image.from_pretrained(
+            x,
+            torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
+            device_map='auto' if device == 'cuda' else None,
+            safety_checker=None,  # Disable safety checker for performance
+            requires_safety_checker=False,
+        ),
+        device=device,
+    )
+
+    try:
+        if device == 'cuda' and hasattr(pipeline, 'enable_model_cpu_offload'):
+            pipeline.enable_model_cpu_offload()
+        if hasattr(pipeline, 'enable_memory_efficient_attention'):
+            pipeline.enable_memory_efficient_attention()
+    except Exception:
+        pass  # Ignore optimization failures
+
+    generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)
+
+    with torch.no_grad():
+        result = pipeline(prompt, height=height, width=width, generator=generator, **model_kwargs)
+        return result.images[0]
+
+
+@pxt.udf
+def text_to_speech(
+    text: str, *, model_id: str, speaker_id: Optional[int] = None, vocoder: Optional[str] = None
+) -> pxt.Audio:
+    """
+    Converts text to speech using a pretrained TTS model. `model_id` should be a reference to a
+    pretrained [text-to-speech model](https://huggingface.co/models?pipeline_tag=text-to-speech).
+
+    __Requirements:__
+
+    - `pip install torch transformers datasets soundfile`
+
+    Args:
+        text: The text to convert to speech.
+        model_id: The pretrained TTS model to use.
+        speaker_id: Speaker ID for multi-speaker models.
+        vocoder: Optional vocoder model for higher quality audio.
+
+    Returns:
+        The generated audio file.
+
+    Examples:
+        Add a computed column that converts text to speech:
+
+        >>> tbl.add_computed_column(audio=text_to_speech(
+        ...     tbl.text_content,
+        ...     model_id='microsoft/speecht5_tts',
+        ...     speaker_id=0
+        ... ))
+    """
+    env.Env.get().require_package('transformers')
+    env.Env.get().require_package('datasets')
+    env.Env.get().require_package('soundfile')
+    device = resolve_torch_device('auto')
+    import soundfile as sf  # type: ignore[import-untyped]
+    import torch
+    from datasets import load_dataset  # type: ignore[import-untyped]
+    from transformers import (
+        AutoModelForTextToWaveform,
+        AutoProcessor,
+        BarkModel,
+        SpeechT5ForTextToSpeech,
+        SpeechT5HifiGan,
+        SpeechT5Processor,
+    )
+
+    # Model loading with error handling - following best practices pattern
+    if 'speecht5' in model_id.lower():
+        model = _lookup_model(model_id, SpeechT5ForTextToSpeech.from_pretrained, device=device)
+        processor = _lookup_processor(model_id, SpeechT5Processor.from_pretrained)
+        vocoder_model_id = vocoder or 'microsoft/speecht5_hifigan'
+        vocoder_model = _lookup_model(vocoder_model_id, SpeechT5HifiGan.from_pretrained, device=device)
+
+    elif 'bark' in model_id.lower():
+        model = _lookup_model(model_id, BarkModel.from_pretrained, device=device)
+        processor = _lookup_processor(model_id, AutoProcessor.from_pretrained)
+        vocoder_model = None
+
+    else:
+        model = _lookup_model(model_id, AutoModelForTextToWaveform.from_pretrained, device=device)
+        processor = _lookup_processor(model_id, AutoProcessor.from_pretrained)
+        vocoder_model = None
+
+    # Load speaker embeddings once for SpeechT5 (following speech2text pattern)
+    speaker_embeddings = None
+    if 'speecht5' in model_id.lower():
+        embeddings_dataset = load_dataset('Matthijs/cmu-arctic-xvectors', split='validation')
+        speaker_embeddings = torch.tensor(embeddings_dataset[speaker_id or 7306]['xvector']).unsqueeze(0).to(device)
+
+    with torch.no_grad():
+        # Generate speech based on model type
+        if 'speecht5' in model_id.lower():
+            inputs = processor(text=text, return_tensors='pt').to(device)
+            speech = model.generate_speech(inputs['input_ids'], speaker_embeddings, vocoder=vocoder_model)
+            audio_np = speech.cpu().numpy()
+            sample_rate = 16000
+
+        elif 'bark' in model_id.lower():
+            inputs = processor(text, return_tensors='pt').to(device)
+            audio_array = model.generate(**inputs)
+            audio_np = audio_array.cpu().numpy().squeeze()
+            sample_rate = getattr(model.generation_config, 'sample_rate', 24000)
+
+        else:
+            # Generic approach for other TTS models
+            inputs = processor(text, return_tensors='pt').to(device)
+            audio_output = model(**inputs)
+            audio_np = audio_output.waveform.cpu().numpy().squeeze()
+            sample_rate = getattr(model.config, 'sample_rate', 22050)
+
+    # Normalize audio - following consistent pattern
+    if audio_np.dtype != np.float32:
+        audio_np = audio_np.astype(np.float32)
+
+    if np.max(np.abs(audio_np)) > 0:
+        audio_np = audio_np / np.max(np.abs(audio_np)) * 0.9
+
+    # Create output file
+    output_filename = str(TempStore.create_path(extension='.wav'))
+    sf.write(output_filename, audio_np, sample_rate, format='WAV', subtype='PCM_16')
+    return output_filename
+
+
+@pxt.udf
+def image_to_image(
+    image: PIL.Image.Image,
+    prompt: str,
+    *,
+    model_id: str,
+    seed: Optional[int] = None,
+    model_kwargs: Optional[dict[str, Any]] = None,
+) -> PIL.Image.Image:
+    """
+    Transforms input images based on text prompts using a pretrained image-to-image model.
+    `model_id` should be a reference to a pretrained
+    [image-to-image model](https://huggingface.co/models?pipeline_tag=image-to-image).
+
+    __Requirements:__
+
+    - `pip install torch transformers diffusers accelerate`
+
+    Args:
+        image: The input image to transform.
+        prompt: The text prompt describing the desired transformation.
+        model_id: The pretrained image-to-image model to use.
+        seed: Random seed for reproducibility.
+        model_kwargs: Additional keyword arguments to pass to the model, such as `strength`,
+            `guidance_scale`, or `num_inference_steps`.
+
+    Returns:
+        The transformed image.
+
+    Examples:
+        Add a computed column that transforms images based on prompts:
+
+        >>> tbl.add_computed_column(transformed=image_to_image(
+        ...     tbl.source_image,
+        ...     tbl.transformation_prompt,
+        ...     model_id='runwayml/stable-diffusion-v1-5'
+        ... ))
+    """
+    env.Env.get().require_package('transformers')
+    env.Env.get().require_package('diffusers')
+    env.Env.get().require_package('accelerate')
+    device = resolve_torch_device('auto')
+    import torch
+    from diffusers import StableDiffusionImg2ImgPipeline
+
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    pipe = _lookup_model(
+        model_id,
+        lambda x: StableDiffusionImg2ImgPipeline.from_pretrained(
+            x,
+            torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
+            safety_checker=None,
+            requires_safety_checker=False,
+        ),
+        device=device,
+    )
+
+    try:
+        if device == 'cuda' and hasattr(pipe, 'enable_model_cpu_offload'):
+            pipe.enable_model_cpu_offload()
+        if hasattr(pipe, 'enable_memory_efficient_attention'):
+            pipe.enable_memory_efficient_attention()
+    except Exception:
+        pass  # Ignore optimization failures
+
+    generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)
+
+    processed_image = image.convert('RGB')
+
+    with torch.no_grad():
+        result = pipe(prompt=prompt, image=processed_image, generator=generator, **model_kwargs)
+        return result.images[0]
+
+
+@pxt.udf
+def automatic_speech_recognition(
+    audio: pxt.Audio,
+    *,
+    model_id: str,
+    language: Optional[str] = None,
+    chunk_length_s: Optional[int] = None,
+    return_timestamps: bool = False,
+) -> str:
+    """
+    Transcribes speech to text using a pretrained ASR model. `model_id` should be a reference to a
+    pretrained [automatic-speech-recognition model](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition).
+
+    This is a **generic function** that works with many ASR model families. For production use with
+    specific models, consider specialized functions like `whisper.transcribe()` or
+    `speech2text_for_conditional_generation()`.
+
+    __Requirements:__
+
+    - `pip install torch transformers torchaudio`
+
+    __Recommended Models:__
+
+    - **OpenAI Whisper**: `openai/whisper-tiny.en`, `openai/whisper-small`, `openai/whisper-base`
+    - **Facebook Wav2Vec2**: `facebook/wav2vec2-base-960h`, `facebook/wav2vec2-large-960h-lv60-self`
+    - **Microsoft SpeechT5**: `microsoft/speecht5_asr`
+    - **Meta MMS (Multilingual)**: `facebook/mms-1b-all`
+
+    Args:
+        audio: The audio file(s) to transcribe.
+        model_id: The pretrained ASR model to use.
+        language: Language code for multilingual models (e.g., 'en', 'es', 'fr').
+        chunk_length_s: Maximum length of audio chunks in seconds for long audio processing.
+        return_timestamps: Whether to return word-level timestamps (model dependent).
+
+    Returns:
+        The transcribed text.
+
+    Examples:
+        Add a computed column that transcribes audio files:
+
+        >>> tbl.add_computed_column(transcription=automatic_speech_recognition(
+        ...     tbl.audio_file,
+        ...     model_id='openai/whisper-tiny.en'  # Recommended
+        ... ))
+
+        Transcribe with language specification:
+
+        >>> tbl.add_computed_column(transcription=automatic_speech_recognition(
+        ...     tbl.audio_file,
+        ...     model_id='facebook/mms-1b-all',
+        ...     language='en'
+        ... ))
+    """
+    env.Env.get().require_package('transformers')
+    env.Env.get().require_package('torchaudio')
+    device = resolve_torch_device('auto', allow_mps=False)  # Following speech2text pattern
+    import torch
+    import torchaudio
+
+    # Try to load model and processor using direct model loading - following speech2text pattern
+    # Handle different ASR model types
+    if 'whisper' in model_id.lower():
+        from transformers import WhisperForConditionalGeneration, WhisperProcessor
+
+        model = _lookup_model(model_id, WhisperForConditionalGeneration.from_pretrained, device=device)
+        processor = _lookup_processor(model_id, WhisperProcessor.from_pretrained)
+
+        # Language validation for Whisper - following speech2text pattern
+        if language is not None and hasattr(processor.tokenizer, 'get_decoder_prompt_ids'):
+            try:
+                # Test if language is supported
+                _ = processor.tokenizer.get_decoder_prompt_ids(language=language)
+            except Exception:
+                raise excs.Error(
+                    f"Language code '{language}' is not supported by Whisper model '{model_id}'. "
+                    f"Try common codes like 'en', 'es', 'fr', 'de', 'it', 'pt', 'ru', 'ja', 'ko', 'zh'."
+                ) from None
+
+    elif 'wav2vec2' in model_id.lower():
+        from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+
+        model = _lookup_model(model_id, Wav2Vec2ForCTC.from_pretrained, device=device)
+        processor = _lookup_processor(model_id, Wav2Vec2Processor.from_pretrained)
+
+    elif 'speech_to_text' in model_id.lower() or 's2t' in model_id.lower():
+        # Use the existing speech2text function for these models
+        from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
+
+        model = _lookup_model(model_id, Speech2TextForConditionalGeneration.from_pretrained, device=device)
+        processor = _lookup_processor(model_id, Speech2TextProcessor.from_pretrained)
+
+    else:
+        # Generic fallback using Auto classes
+        from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
+
+        try:
+            model = _lookup_model(model_id, AutoModelForSpeechSeq2Seq.from_pretrained, device=device)
+            processor = _lookup_processor(model_id, AutoProcessor.from_pretrained)
+        except Exception:
+            # Fallback to CTC models
+            from transformers import AutoModelForCTC
+
+            model = _lookup_model(model_id, AutoModelForCTC.from_pretrained, device=device)
+            processor = _lookup_processor(model_id, AutoProcessor.from_pretrained)
+
+    # Get model's expected sampling rate - following speech2text pattern
+    model_sampling_rate = getattr(model.config, 'sampling_rate', 16_000)
+
+    # Load and preprocess audio - following speech2text pattern
+    waveform, sampling_rate = torchaudio.load(audio)
+
+    # Resample if necessary
+    if sampling_rate != model_sampling_rate:
+        waveform = torchaudio.transforms.Resample(sampling_rate, model_sampling_rate)(waveform)
+
+    # Convert to mono if stereo
+    if waveform.dim() == 2:
+        waveform = torch.mean(waveform, dim=0)
+    assert waveform.dim() == 1
+
+    with torch.no_grad():
+        # Process audio with the model
+        inputs = processor(waveform, sampling_rate=model_sampling_rate, return_tensors='pt')
+
+        # Handle different model types for generation
+        if 'whisper' in model_id.lower():
+            # Whisper-specific generation
+            generate_kwargs = {}
+            if language is not None:
+                generate_kwargs['language'] = language
+            if return_timestamps:
+                generate_kwargs['return_timestamps'] = 'word' if return_timestamps else None
+
+            generated_ids = model.generate(**inputs.to(device), **generate_kwargs)
+            transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        elif hasattr(model, 'generate'):
+            # Seq2Seq models (Speech2Text, etc.)
+            generated_ids = model.generate(**inputs.to(device))
+            transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        else:
+            # CTC models (Wav2Vec2, etc.)
+            logits = model(**inputs.to(device)).logits
+            predicted_ids = torch.argmax(logits, dim=-1)
+            transcription = processor.batch_decode(predicted_ids)[0]
+
+    return transcription.strip()
+
+
+@pxt.udf
+def image_to_video(
+    image: PIL.Image.Image,
+    *,
+    model_id: str,
+    num_frames: int = 25,
+    fps: int = 6,
+    seed: Optional[int] = None,
+    model_kwargs: Optional[dict[str, Any]] = None,
+) -> pxt.Video:
+    """
+    Generates videos from input images using a pretrained image-to-video model.
+    `model_id` should be a reference to a pretrained
+    [image-to-video model](https://huggingface.co/models?pipeline_tag=image-to-video).
+
+    __Requirements:__
+
+    - `pip install torch transformers diffusers accelerate`
+
+    Args:
+        image: The input image to animate into a video.
+        model_id: The pretrained image-to-video model to use.
+        num_frames: Number of video frames to generate.
+        fps: Frames per second for the output video.
+        seed: Random seed for reproducibility.
+        model_kwargs: Additional keyword arguments to pass to the model, such as `num_inference_steps`,
+            `motion_bucket_id`, or `guidance_scale`.
+
+    Returns:
+        The generated video file.
+
+    Examples:
+        Add a computed column that creates videos from images:
+
+        >>> tbl.add_computed_column(video=image_to_video(
+        ...     tbl.input_image,
+        ...     model_id='stabilityai/stable-video-diffusion-img2vid-xt',
+        ...     num_frames=25,
+        ...     fps=7
+        ... ))
+    """
+    env.Env.get().require_package('transformers')
+    env.Env.get().require_package('diffusers')
+    env.Env.get().require_package('accelerate')
+    device = resolve_torch_device('auto', allow_mps=False)
+    import numpy as np
+    import torch
+    from diffusers import StableVideoDiffusionPipeline
+
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    # Parameter validation - following best practices pattern
+    if num_frames < 1:
+        raise excs.Error(f'num_frames must be at least 1, got {num_frames}')
+
+    if num_frames > 25:
+        raise excs.Error(f'num_frames cannot exceed 25 for most video diffusion models, got {num_frames}')
+
+    if fps < 1:
+        raise excs.Error(f'fps must be at least 1, got {fps}')
+
+    if fps > 60:
+        raise excs.Error(f'fps should not exceed 60 for reasonable video generation, got {fps}')
+
+    pipe = _lookup_model(
+        model_id,
+        lambda x: StableVideoDiffusionPipeline.from_pretrained(
+            x,
+            torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
+            variant='fp16' if device == 'cuda' else None,
+        ),
+        device=device,
+    )
+
+    try:
+        if device == 'cuda' and hasattr(pipe, 'enable_model_cpu_offload'):
+            pipe.enable_model_cpu_offload()
+        if hasattr(pipe, 'enable_memory_efficient_attention'):
+            pipe.enable_memory_efficient_attention()
+    except Exception:
+        pass  # Ignore optimization failures
+
+    generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)
+
+    # Ensure image is in RGB mode and proper size
+    processed_image = image.convert('RGB')
+    target_width, target_height = 512, 320
+    processed_image = processed_image.resize((target_width, target_height), PIL.Image.Resampling.LANCZOS)
+
+    # Generate video frames with proper error handling
+    with torch.no_grad():
+        result = pipe(image=processed_image, num_frames=num_frames, generator=generator, **model_kwargs)
+        frames = result.frames[0]
+
+    # Create output video file
+    output_path = str(TempStore.create_path(extension='.mp4'))
+
+    with av.open(output_path, mode='w') as container:
+        stream = container.add_stream('h264', rate=fps)
+        stream.width = target_width
+        stream.height = target_height
+        stream.pix_fmt = 'yuv420p'
+
+        # Set codec options for better compatibility
+        stream.codec_context.options = {'crf': '23', 'preset': 'medium'}
+
+        for frame_pil in frames:
+            # Convert PIL to numpy array
+            frame_array = np.array(frame_pil)
+            # Create av VideoFrame
+            av_frame = av.VideoFrame.from_ndarray(frame_array, format='rgb24')
+            # Encode and mux
+            for packet in stream.encode(av_frame):
+                container.mux(packet)
+
+        # Flush encoder
+        for packet in stream.encode():
+            container.mux(packet)
+
+    return output_path


 def _lookup_model(