pixeltable 0.4.17__py3-none-any.whl → 0.4.18__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.

@@ -7,8 +7,10 @@ first `pip install transformers` (or in some cases, `sentence-transformers`, as
7
7
  UDFs).
8
8
  """
9
9
 
10
- from typing import Any, Callable, Optional, TypeVar
10
+ from typing import Any, Callable, Literal, Optional, TypeVar
11
11
 
12
+ import av
13
+ import numpy as np
12
14
  import PIL.Image
13
15
 
14
16
  import pixeltable as pxt
@@ -18,6 +20,9 @@ from pixeltable import env
18
20
  from pixeltable.func import Batch
19
21
  from pixeltable.functions.util import normalize_image_mode, resolve_torch_device
20
22
  from pixeltable.utils.code import local_public_names
23
+ from pixeltable.utils.local_store import TempStore
24
+
25
+ T = TypeVar('T')
21
26
 
22
27
 
23
28
  @pxt.udf(batch_size=32)
@@ -454,7 +459,1031 @@ def detr_to_coco(image: PIL.Image.Image, detr_info: dict[str, Any]) -> dict[str,
454
459
  return {'image': {'width': image.width, 'height': image.height}, 'annotations': annotations}
455
460
 
456
461
 
457
- T = TypeVar('T')
462
+ @pxt.udf
463
+ def text_generation(text: str, *, model_id: str, model_kwargs: Optional[dict[str, Any]] = None) -> str:
464
+ """
465
+ Generates text using a pretrained language model. `model_id` should be a reference to a pretrained
466
+ [text generation model](https://huggingface.co/models?pipeline_tag=text-generation).
467
+
468
+ __Requirements:__
469
+
470
+ - `pip install torch transformers`
471
+
472
+ Args:
473
+ text: The input text to continue/complete.
474
+ model_id: The pretrained model to use for text generation.
475
+ model_kwargs: Additional keyword arguments to pass to the model's `generate` method, such as `max_length`,
476
+ `temperature`, etc. See the
477
+ [Hugging Face text_generation documentation](https://huggingface.co/docs/inference-providers/en/tasks/text-generation)
478
+ for details.
479
+
480
+ Returns:
481
+ The generated text completion.
482
+
483
+ Examples:
484
+ Add a computed column that generates text completions using the `Qwen/Qwen3-0.6B` model:
485
+
486
+ >>> tbl.add_computed_column(completion=text_generation(
487
+ ... tbl.prompt,
488
+ ... model_id='Qwen/Qwen3-0.6B',
489
+ ... model_kwargs={'temperature': 0.5, 'max_length': 150}
490
+ ... ))
491
+ """
492
+ env.Env.get().require_package('transformers')
493
+ device = resolve_torch_device('auto')
494
+ import torch
495
+ from transformers import AutoModelForCausalLM, AutoTokenizer
496
+
497
+ if model_kwargs is None:
498
+ model_kwargs = {}
499
+
500
+ model = _lookup_model(model_id, AutoModelForCausalLM.from_pretrained, device=device)
501
+ tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
502
+
503
+ if tokenizer.pad_token is None:
504
+ tokenizer.pad_token = tokenizer.eos_token
505
+
506
+ with torch.no_grad():
507
+ inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
508
+ outputs = model.generate(**inputs.to(device), pad_token_id=tokenizer.eos_token_id, **model_kwargs)
509
+
510
+ input_length = len(inputs['input_ids'][0])
511
+ generated_text = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
512
+ return generated_text
513
+
514
+
515
+ @pxt.udf(batch_size=16)
516
+ def text_classification(text: Batch[str], *, model_id: str, top_k: int = 5) -> Batch[list[dict[str, Any]]]:
517
+ """
518
+ Classifies text using a pretrained classification model. `model_id` should be a reference to a pretrained
519
+ [text classification model](https://huggingface.co/models?pipeline_tag=text-classification)
520
+ such as BERT, RoBERTa, or DistilBERT.
521
+
522
+ __Requirements:__
523
+
524
+ - `pip install torch transformers`
525
+
526
+ Args:
527
+ text: The text to classify.
528
+ model_id: The pretrained model to use for classification.
529
+ top_k: The number of top predictions to return.
530
+
531
+ Returns:
532
+ A list of dictionaries containing the top-k classification results, each with a label, label text, and score.
533
+
534
+ Examples:
535
+ Add a computed column for sentiment analysis:
536
+
537
+ >>> tbl.add_computed_column(sentiment=text_classification(
538
+ ... tbl.review_text,
539
+ ... model_id='cardiffnlp/twitter-roberta-base-sentiment-latest'
540
+ ... ))
541
+ """
542
+ env.Env.get().require_package('transformers')
543
+ device = resolve_torch_device('auto')
544
+ import torch
545
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
546
+
547
+ model = _lookup_model(model_id, AutoModelForSequenceClassification.from_pretrained, device=device)
548
+ tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
549
+
550
+ with torch.no_grad():
551
+ inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
552
+ outputs = model(**inputs.to(device))
553
+ logits = outputs.logits
554
+
555
+ probs = torch.softmax(logits, dim=-1)
556
+ top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
557
+
558
+ results = []
559
+ for i in range(len(text)):
560
+ # Return as list of individual classification items for HuggingFace compatibility
561
+ classification_items = []
562
+ for k in range(top_k_probs.shape[1]):
563
+ classification_items.append(
564
+ {
565
+ 'label': top_k_indices[i, k].item(),
566
+ 'label_text': model.config.id2label[top_k_indices[i, k].item()],
567
+ 'score': top_k_probs[i, k].item(),
568
+ }
569
+ )
570
+ results.append(classification_items)
571
+
572
+ return results
573
+
574
+
575
+ @pxt.udf(batch_size=4)
576
+ def image_captioning(
577
+ image: Batch[PIL.Image.Image], *, model_id: str, model_kwargs: Optional[dict[str, Any]] = None
578
+ ) -> Batch[str]:
579
+ """
580
+ Generates captions for images using a pretrained image captioning model. `model_id` should be a reference to a
581
+ pretrained [image-to-text model](https://huggingface.co/models?pipeline_tag=image-to-text) such as BLIP,
582
+ Git, or LLaVA.
583
+
584
+ __Requirements:__
585
+
586
+ - `pip install torch transformers`
587
+
588
+ Args:
589
+ image: The image to caption.
590
+ model_id: The pretrained model to use for captioning.
591
+ model_kwargs: Additional keyword arguments to pass to the model's `generate` method, such as `max_length`.
592
+
593
+ Returns:
594
+ The generated caption text.
595
+
596
+ Examples:
597
+ Add a computed column `caption` to an existing table `tbl` that generates captions using the
598
+ `Salesforce/blip-image-captioning-base` model:
599
+
600
+ >>> tbl.add_computed_column(caption=image_captioning(
601
+ ... tbl.image,
602
+ ... model_id='Salesforce/blip-image-captioning-base',
603
+ ... model_kwargs={'max_length': 30}
604
+ ... ))
605
+ """
606
+ env.Env.get().require_package('transformers')
607
+ device = resolve_torch_device('auto')
608
+ import torch
609
+ from transformers import AutoModelForVision2Seq, AutoProcessor
610
+
611
+ if model_kwargs is None:
612
+ model_kwargs = {}
613
+
614
+ model = _lookup_model(model_id, AutoModelForVision2Seq.from_pretrained, device=device)
615
+ processor = _lookup_processor(model_id, AutoProcessor.from_pretrained)
616
+ normalized_images = [normalize_image_mode(img) for img in image]
617
+
618
+ with torch.no_grad():
619
+ inputs = processor(images=normalized_images, return_tensors='pt')
620
+ outputs = model.generate(**inputs.to(device), **model_kwargs)
621
+
622
+ captions = processor.batch_decode(outputs, skip_special_tokens=True)
623
+ return captions
624
+
625
+
626
+ @pxt.udf(batch_size=8)
627
+ def summarization(text: Batch[str], *, model_id: str, model_kwargs: Optional[dict[str, Any]] = None) -> Batch[str]:
628
+ """
629
+ Summarizes text using a pretrained summarization model. `model_id` should be a reference to a pretrained
630
+ [summarization model](https://huggingface.co/models?pipeline_tag=summarization) such as BART, T5, or Pegasus.
631
+
632
+ __Requirements:__
633
+
634
+ - `pip install torch transformers`
635
+
636
+ Args:
637
+ text: The text to summarize.
638
+ model_id: The pretrained model to use for summarization.
639
+ model_kwargs: Additional keyword arguments to pass to the model's `generate` method, such as `max_length`.
640
+
641
+ Returns:
642
+ The generated summary text.
643
+
644
+ Examples:
645
+ Add a computed column that summarizes documents:
646
+
647
+ >>> tbl.add_computed_column(summary=summarization(
648
+ ... tbl.document_text,
649
+ ... model_id='facebook/bart-large-cnn',
650
+ ... model_kwargs={'max_length': 100}
651
+ ... ))
652
+ """
653
+ env.Env.get().require_package('transformers')
654
+ device = resolve_torch_device('auto')
655
+ import torch
656
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
657
+
658
+ if model_kwargs is None:
659
+ model_kwargs = {}
660
+
661
+ model = _lookup_model(model_id, AutoModelForSeq2SeqLM.from_pretrained, device=device)
662
+ tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
663
+
664
+ with torch.no_grad():
665
+ inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
666
+ outputs = model.generate(**inputs.to(device), **model_kwargs)
667
+
668
+ return tokenizer.batch_decode(outputs, skip_special_tokens=True)
669
+
670
+
671
+ @pxt.udf
672
+ def token_classification(
673
+ text: str, *, model_id: str, aggregation_strategy: Literal['simple', 'first', 'average', 'max'] = 'simple'
674
+ ) -> list[dict[str, Any]]:
675
+ """
676
+ Extracts named entities from text using a pretrained named entity recognition (NER) model.
677
+ `model_id` should be a reference to a pretrained
678
+ [token classification model](https://huggingface.co/models?pipeline_tag=token-classification) for NER.
679
+
680
+ __Requirements:__
681
+
682
+ - `pip install torch transformers`
683
+
684
+ Args:
685
+ text: The text to analyze for named entities.
686
+ model_id: The pretrained model to use.
687
+ aggregation_strategy: Method used to aggregate tokens.
688
+
689
+ Returns:
690
+ A list of dictionaries containing entity information (word, entity_group, score, start, end).
691
+
692
+ Examples:
693
+ Add a computed column that extracts named entities:
694
+
695
+ >>> tbl.add_computed_column(entities=token_classification(
696
+ ... tbl.text,
697
+ ... model_id='dbmdz/bert-large-cased-finetuned-conll03-english'
698
+ ... ))
699
+ """
700
+ env.Env.get().require_package('transformers')
701
+ device = resolve_torch_device('auto')
702
+ import torch
703
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
704
+
705
+ # Follow direct model loading pattern like other best practice functions
706
+ model = _lookup_model(model_id, AutoModelForTokenClassification.from_pretrained, device=device)
707
+ tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
708
+
709
+ # Validate aggregation strategy
710
+ valid_strategies = {'simple', 'first', 'average', 'max'}
711
+ if aggregation_strategy not in valid_strategies:
712
+ raise excs.Error(
713
+ f'Invalid aggregation_strategy {aggregation_strategy!r}. Must be one of: {", ".join(valid_strategies)}'
714
+ )
715
+
716
+ with torch.no_grad():
717
+ # Tokenize with special tokens and return offsets for entity extraction
718
+ inputs = tokenizer(
719
+ text,
720
+ return_tensors='pt',
721
+ truncation=True,
722
+ max_length=512,
723
+ return_offsets_mapping=True,
724
+ add_special_tokens=True,
725
+ )
726
+
727
+ # Get model predictions
728
+ outputs = model(**{k: v.to(device) for k, v in inputs.items() if k != 'offset_mapping'})
729
+ predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
730
+
731
+ # Get the predicted labels and confidence scores
732
+ predicted_token_classes = predictions.argmax(dim=-1).squeeze().tolist()
733
+ confidence_scores = predictions.max(dim=-1).values.squeeze().tolist()
734
+
735
+ # Handle single token case
736
+ if not isinstance(predicted_token_classes, list):
737
+ predicted_token_classes = [predicted_token_classes]
738
+ confidence_scores = [confidence_scores]
739
+
740
+ # Extract entities from predictions
741
+ entities = []
742
+ offset_mapping = inputs['offset_mapping'][0].tolist()
743
+
744
+ current_entity = None
745
+
746
+ for token_class, confidence, (start_offset, end_offset) in zip(
747
+ predicted_token_classes, confidence_scores, offset_mapping
748
+ ):
749
+ # Skip special tokens (offset is (0, 0))
750
+ if start_offset == 0 and end_offset == 0:
751
+ continue
752
+
753
+ label = model.config.id2label[token_class]
754
+
755
+ # Skip 'O' (outside) labels
756
+ if label == 'O':
757
+ if current_entity:
758
+ entities.append(current_entity)
759
+ current_entity = None
760
+ continue
761
+
762
+ # Parse BIO/BILOU tags
763
+ if label.startswith('B-') or (label.startswith('I-') and current_entity is None):
764
+ # Begin new entity
765
+ if current_entity:
766
+ entities.append(current_entity)
767
+
768
+ entity_type = label[2:] if label.startswith(('B-', 'I-')) else label
769
+ current_entity = {
770
+ 'word': text[start_offset:end_offset],
771
+ 'entity_group': entity_type,
772
+ 'score': float(confidence),
773
+ 'start': start_offset,
774
+ 'end': end_offset,
775
+ }
776
+
777
+ elif label.startswith('I-') and current_entity:
778
+ # Continue current entity
779
+ entity_type = label[2:]
780
+ if current_entity['entity_group'] == entity_type:
781
+ # Extend the current entity
782
+ current_entity['word'] = text[current_entity['start'] : end_offset]
783
+ current_entity['end'] = end_offset
784
+
785
+ # Update confidence based on aggregation strategy
786
+ if aggregation_strategy == 'average':
787
+ # Simple average (could be improved with token count weighting)
788
+ current_entity['score'] = (current_entity['score'] + float(confidence)) / 2
789
+ elif aggregation_strategy == 'max':
790
+ current_entity['score'] = max(current_entity['score'], float(confidence))
791
+ elif aggregation_strategy == 'first':
792
+ pass # Keep first confidence
793
+ # 'simple' uses the same logic as 'first'
794
+ else:
795
+ # Different entity type, start new entity
796
+ entities.append(current_entity)
797
+ current_entity = {
798
+ 'word': text[start_offset:end_offset],
799
+ 'entity_group': entity_type,
800
+ 'score': float(confidence),
801
+ 'start': start_offset,
802
+ 'end': end_offset,
803
+ }
804
+
805
+ # Don't forget the last entity
806
+ if current_entity:
807
+ entities.append(current_entity)
808
+
809
+ return entities
810
+
811
+
812
+ @pxt.udf
813
+ def question_answering(context: str, question: str, *, model_id: str) -> dict[str, Any]:
814
+ """
815
+ Answers questions based on provided context using a pretrained QA model. `model_id` should be a reference to a
816
+ pretrained [question answering model](https://huggingface.co/models?pipeline_tag=question-answering) such as
817
+ BERT or RoBERTa.
818
+
819
+ __Requirements:__
820
+
821
+ - `pip install torch transformers`
822
+
823
+ Args:
824
+ context: The context text containing the answer.
825
+ question: The question to answer.
826
+ model_id: The pretrained QA model to use.
827
+
828
+ Returns:
829
+ A dictionary containing the answer, its confidence score, and the start/end token positions.
830
+
831
+ Examples:
832
+ Add a computed column that answers questions based on document context:
833
+
834
+ >>> tbl.add_computed_column(answer=question_answering(
835
+ ... tbl.document_text,
836
+ ... tbl.question,
837
+ ... model_id='deepset/roberta-base-squad2'
838
+ ... ))
839
+ """
840
+ env.Env.get().require_package('transformers')
841
+ device = resolve_torch_device('auto')
842
+ import torch
843
+ from transformers import AutoModelForQuestionAnswering, AutoTokenizer
844
+
845
+ model = _lookup_model(model_id, AutoModelForQuestionAnswering.from_pretrained, device=device)
846
+ tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
847
+
848
+ with torch.no_grad():
849
+ # Tokenize the question and context
850
+ inputs = tokenizer.encode_plus(
851
+ question, context, add_special_tokens=True, return_tensors='pt', truncation=True, max_length=512
852
+ )
853
+
854
+ # Get model predictions
855
+ outputs = model(**inputs.to(device))
856
+ start_scores = outputs.start_logits
857
+ end_scores = outputs.end_logits
858
+
859
+ # Find the tokens with the highest start and end scores
860
+ start_idx = torch.argmax(start_scores)
861
+ end_idx = torch.argmax(end_scores)
862
+
863
+ # Ensure end_idx >= start_idx
864
+ end_idx = torch.max(end_idx, start_idx)
865
+
866
+ # Convert token positions to string
867
+ input_ids = inputs['input_ids'][0]
868
+
869
+ # Extract answer tokens
870
+ answer_tokens = input_ids[start_idx : end_idx + 1]
871
+ answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
872
+
873
+ # Calculate confidence score
874
+ start_probs = torch.softmax(start_scores, dim=1)
875
+ end_probs = torch.softmax(end_scores, dim=1)
876
+ confidence = float(start_probs[0][start_idx] * end_probs[0][end_idx])
877
+
878
+ return {'answer': answer.strip(), 'score': confidence, 'start': int(start_idx), 'end': int(end_idx)}
879
+
880
+
881
+ @pxt.udf(batch_size=8)
882
+ def translation(
883
+ text: Batch[str], *, model_id: str, src_lang: Optional[str] = None, target_lang: Optional[str] = None
884
+ ) -> Batch[str]:
885
+ """
886
+ Translates text using a pretrained translation model. `model_id` should be a reference to a pretrained
887
+ [translation model](https://huggingface.co/models?pipeline_tag=translation) such as MarianMT or T5.
888
+
889
+ __Requirements:__
890
+
891
+ - `pip install torch transformers sentencepiece`
892
+
893
+ Args:
894
+ text: The text to translate.
895
+ model_id: The pretrained translation model to use.
896
+ src_lang: Source language code (optional, can be inferred from model).
897
+ target_lang: Target language code (optional, can be inferred from model).
898
+
899
+ Returns:
900
+ The translated text.
901
+
902
+ Examples:
903
+ Add a computed column that translates text:
904
+
905
+ >>> tbl.add_computed_column(french_text=translation(
906
+ ... tbl.english_text,
907
+ ... model_id='Helsinki-NLP/opus-mt-en-fr',
908
+ ... src_lang='en',
909
+ ... target_lang='fr'
910
+ ... ))
911
+ """
912
+ env.Env.get().require_package('transformers')
913
+ device = resolve_torch_device('auto')
914
+ import torch
915
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
916
+
917
+ model = _lookup_model(model_id, AutoModelForSeq2SeqLM.from_pretrained, device=device)
918
+ tokenizer = _lookup_processor(model_id, AutoTokenizer.from_pretrained)
919
+ lang_code_to_id: dict | None = getattr(tokenizer, 'lang_code_to_id', {})
920
+
921
+ # Language validation - following speech2text_for_conditional_generation pattern
922
+ if src_lang is not None and src_lang not in lang_code_to_id:
923
+ raise excs.Error(
924
+ f'Source language code {src_lang!r} is not supported by the model {model_id!r}. '
925
+ f'Supported languages are: {list(lang_code_to_id.keys())}'
926
+ )
927
+
928
+ if target_lang is not None and target_lang not in lang_code_to_id:
929
+ raise excs.Error(
930
+ f'Target language code {target_lang!r} is not supported by the model {model_id!r}. '
931
+ f'Supported languages are: {list(lang_code_to_id.keys())}'
932
+ )
933
+
934
+ with torch.no_grad():
935
+ inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
936
+
937
+ # Set forced_bos_token_id for target language if supported
938
+ generate_kwargs = {'max_length': 512, 'num_beams': 4, 'early_stopping': True}
939
+
940
+ if target_lang is not None:
941
+ generate_kwargs['forced_bos_token_id'] = lang_code_to_id[target_lang]
942
+
943
+ outputs = model.generate(**inputs.to(device), **generate_kwargs)
944
+
945
+ # Decode all outputs at once
946
+ translations = tokenizer.batch_decode(outputs, skip_special_tokens=True)
947
+ return translations
948
+
949
+
950
+ @pxt.udf
951
+ def text_to_image(
952
+ prompt: str,
953
+ *,
954
+ model_id: str,
955
+ height: int = 512,
956
+ width: int = 512,
957
+ seed: Optional[int] = None,
958
+ model_kwargs: Optional[dict[str, Any]] = None,
959
+ ) -> PIL.Image.Image:
960
+ """
961
+ Generates images from text prompts using a pretrained text-to-image model. `model_id` should be a reference to a
962
+ pretrained [text-to-image model](https://huggingface.co/models?pipeline_tag=text-to-image) such as
963
+ Stable Diffusion or FLUX.
964
+
965
+ __Requirements:__
966
+
967
+ - `pip install torch transformers diffusers accelerate`
968
+
969
+ Args:
970
+ prompt: The text prompt describing the desired image.
971
+ model_id: The pretrained text-to-image model to use.
972
+ height: Height of the generated image in pixels.
973
+ width: Width of the generated image in pixels.
974
+ seed: Optional random seed for reproducibility.
975
+ model_kwargs: Additional keyword arguments to pass to the model, such as `num_inference_steps`,
976
+ `guidance_scale`, or `negative_prompt`.
977
+
978
+ Returns:
979
+ The generated Image.
980
+
981
+ Examples:
982
+ Add a computed column that generates images from text prompts:
983
+
984
+ >>> tbl.add_computed_column(generated_image=text_to_image(
985
+ ... tbl.prompt,
986
+ ... model_id='stable-diffusion-v1.5/stable-diffusion-v1-5',
987
+ ... height=512,
988
+ ... width=512,
989
+ ... model_kwargs={'num_inference_steps': 25},
990
+ ... ))
991
+ """
992
+ env.Env.get().require_package('transformers')
993
+ env.Env.get().require_package('diffusers')
994
+ env.Env.get().require_package('accelerate')
995
+ device = resolve_torch_device('auto', allow_mps=False)
996
+ import torch
997
+ from diffusers import AutoPipelineForText2Image
998
+
999
+ if model_kwargs is None:
1000
+ model_kwargs = {}
1001
+
1002
+ # Parameter validation - following best practices pattern
1003
+ if height <= 0 or width <= 0:
1004
+ raise excs.Error(f'Height ({height}) and width ({width}) must be positive integers')
1005
+
1006
+ if height % 8 != 0 or width % 8 != 0:
1007
+ raise excs.Error(f'Height ({height}) and width ({width}) must be divisible by 8 for most diffusion models')
1008
+
1009
+ pipeline = _lookup_model(
1010
+ model_id,
1011
+ lambda x: AutoPipelineForText2Image.from_pretrained(
1012
+ x,
1013
+ torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
1014
+ device_map='auto' if device == 'cuda' else None,
1015
+ safety_checker=None, # Disable safety checker for performance
1016
+ requires_safety_checker=False,
1017
+ ),
1018
+ device=device,
1019
+ )
1020
+
1021
+ try:
1022
+ if device == 'cuda' and hasattr(pipeline, 'enable_model_cpu_offload'):
1023
+ pipeline.enable_model_cpu_offload()
1024
+ if hasattr(pipeline, 'enable_memory_efficient_attention'):
1025
+ pipeline.enable_memory_efficient_attention()
1026
+ except Exception:
1027
+ pass # Ignore optimization failures
1028
+
1029
+ generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)
1030
+
1031
+ with torch.no_grad():
1032
+ result = pipeline(prompt, height=height, width=width, generator=generator, **model_kwargs)
1033
+ return result.images[0]
1034
+
1035
+
1036
+ @pxt.udf
1037
+ def text_to_speech(
1038
+ text: str, *, model_id: str, speaker_id: Optional[int] = None, vocoder: Optional[str] = None
1039
+ ) -> pxt.Audio:
1040
+ """
1041
+ Converts text to speech using a pretrained TTS model. `model_id` should be a reference to a
1042
+ pretrained [text-to-speech model](https://huggingface.co/models?pipeline_tag=text-to-speech).
1043
+
1044
+ __Requirements:__
1045
+
1046
+ - `pip install torch transformers datasets soundfile`
1047
+
1048
+ Args:
1049
+ text: The text to convert to speech.
1050
+ model_id: The pretrained TTS model to use.
1051
+ speaker_id: Speaker ID for multi-speaker models.
1052
+ vocoder: Optional vocoder model for higher quality audio.
1053
+
1054
+ Returns:
1055
+ The generated audio file.
1056
+
1057
+ Examples:
1058
+ Add a computed column that converts text to speech:
1059
+
1060
+ >>> tbl.add_computed_column(audio=text_to_speech(
1061
+ ... tbl.text_content,
1062
+ ... model_id='microsoft/speecht5_tts',
1063
+ ... speaker_id=0
1064
+ ... ))
1065
+ """
1066
+ env.Env.get().require_package('transformers')
1067
+ env.Env.get().require_package('datasets')
1068
+ env.Env.get().require_package('soundfile')
1069
+ device = resolve_torch_device('auto')
1070
+ import soundfile as sf # type: ignore[import-untyped]
1071
+ import torch
1072
+ from datasets import load_dataset # type: ignore[import-untyped]
1073
+ from transformers import (
1074
+ AutoModelForTextToWaveform,
1075
+ AutoProcessor,
1076
+ BarkModel,
1077
+ SpeechT5ForTextToSpeech,
1078
+ SpeechT5HifiGan,
1079
+ SpeechT5Processor,
1080
+ )
1081
+
1082
+ # Model loading with error handling - following best practices pattern
1083
+ if 'speecht5' in model_id.lower():
1084
+ model = _lookup_model(model_id, SpeechT5ForTextToSpeech.from_pretrained, device=device)
1085
+ processor = _lookup_processor(model_id, SpeechT5Processor.from_pretrained)
1086
+ vocoder_model_id = vocoder or 'microsoft/speecht5_hifigan'
1087
+ vocoder_model = _lookup_model(vocoder_model_id, SpeechT5HifiGan.from_pretrained, device=device)
1088
+
1089
+ elif 'bark' in model_id.lower():
1090
+ model = _lookup_model(model_id, BarkModel.from_pretrained, device=device)
1091
+ processor = _lookup_processor(model_id, AutoProcessor.from_pretrained)
1092
+ vocoder_model = None
1093
+
1094
+ else:
1095
+ model = _lookup_model(model_id, AutoModelForTextToWaveform.from_pretrained, device=device)
1096
+ processor = _lookup_processor(model_id, AutoProcessor.from_pretrained)
1097
+ vocoder_model = None
1098
+
1099
+ # Load speaker embeddings once for SpeechT5 (following speech2text pattern)
1100
+ speaker_embeddings = None
1101
+ if 'speecht5' in model_id.lower():
1102
+ embeddings_dataset = load_dataset('Matthijs/cmu-arctic-xvectors', split='validation')
1103
+ speaker_embeddings = torch.tensor(embeddings_dataset[speaker_id or 7306]['xvector']).unsqueeze(0).to(device)
1104
+
1105
+ with torch.no_grad():
1106
+ # Generate speech based on model type
1107
+ if 'speecht5' in model_id.lower():
1108
+ inputs = processor(text=text, return_tensors='pt').to(device)
1109
+ speech = model.generate_speech(inputs['input_ids'], speaker_embeddings, vocoder=vocoder_model)
1110
+ audio_np = speech.cpu().numpy()
1111
+ sample_rate = 16000
1112
+
1113
+ elif 'bark' in model_id.lower():
1114
+ inputs = processor(text, return_tensors='pt').to(device)
1115
+ audio_array = model.generate(**inputs)
1116
+ audio_np = audio_array.cpu().numpy().squeeze()
1117
+ sample_rate = getattr(model.generation_config, 'sample_rate', 24000)
1118
+
1119
+ else:
1120
+ # Generic approach for other TTS models
1121
+ inputs = processor(text, return_tensors='pt').to(device)
1122
+ audio_output = model(**inputs)
1123
+ audio_np = audio_output.waveform.cpu().numpy().squeeze()
1124
+ sample_rate = getattr(model.config, 'sample_rate', 22050)
1125
+
1126
+ # Normalize audio - following consistent pattern
1127
+ if audio_np.dtype != np.float32:
1128
+ audio_np = audio_np.astype(np.float32)
1129
+
1130
+ if np.max(np.abs(audio_np)) > 0:
1131
+ audio_np = audio_np / np.max(np.abs(audio_np)) * 0.9
1132
+
1133
+ # Create output file
1134
+ output_filename = str(TempStore.create_path(extension='.wav'))
1135
+ sf.write(output_filename, audio_np, sample_rate, format='WAV', subtype='PCM_16')
1136
+ return output_filename
1137
+
1138
+
1139
+ @pxt.udf
1140
+ def image_to_image(
1141
+ image: PIL.Image.Image,
1142
+ prompt: str,
1143
+ *,
1144
+ model_id: str,
1145
+ seed: Optional[int] = None,
1146
+ model_kwargs: Optional[dict[str, Any]] = None,
1147
+ ) -> PIL.Image.Image:
1148
+ """
1149
+ Transforms input images based on text prompts using a pretrained image-to-image model.
1150
+ `model_id` should be a reference to a pretrained
1151
+ [image-to-image model](https://huggingface.co/models?pipeline_tag=image-to-image).
1152
+
1153
+ __Requirements:__
1154
+
1155
+ - `pip install torch transformers diffusers accelerate`
1156
+
1157
+ Args:
1158
+ image: The input image to transform.
1159
+ prompt: The text prompt describing the desired transformation.
1160
+ model_id: The pretrained image-to-image model to use.
1161
+ seed: Random seed for reproducibility.
1162
+ model_kwargs: Additional keyword arguments to pass to the model, such as `strength`,
1163
+ `guidance_scale`, or `num_inference_steps`.
1164
+
1165
+ Returns:
1166
+ The transformed image.
1167
+
1168
+ Examples:
1169
+ Add a computed column that transforms images based on prompts:
1170
+
1171
+ >>> tbl.add_computed_column(transformed=image_to_image(
1172
+ ... tbl.source_image,
1173
+ ... tbl.transformation_prompt,
1174
+ ... model_id='runwayml/stable-diffusion-v1-5'
1175
+ ... ))
1176
+ """
1177
+ env.Env.get().require_package('transformers')
1178
+ env.Env.get().require_package('diffusers')
1179
+ env.Env.get().require_package('accelerate')
1180
+ device = resolve_torch_device('auto')
1181
+ import torch
1182
+ from diffusers import StableDiffusionImg2ImgPipeline
1183
+
1184
+ if model_kwargs is None:
1185
+ model_kwargs = {}
1186
+
1187
+ pipe = _lookup_model(
1188
+ model_id,
1189
+ lambda x: StableDiffusionImg2ImgPipeline.from_pretrained(
1190
+ x,
1191
+ torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
1192
+ safety_checker=None,
1193
+ requires_safety_checker=False,
1194
+ ),
1195
+ device=device,
1196
+ )
1197
+
1198
+ try:
1199
+ if device == 'cuda' and hasattr(pipe, 'enable_model_cpu_offload'):
1200
+ pipe.enable_model_cpu_offload()
1201
+ if hasattr(pipe, 'enable_memory_efficient_attention'):
1202
+ pipe.enable_memory_efficient_attention()
1203
+ except Exception:
1204
+ pass # Ignore optimization failures
1205
+
1206
+ generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)
1207
+
1208
+ processed_image = image.convert('RGB')
1209
+
1210
+ with torch.no_grad():
1211
+ result = pipe(prompt=prompt, image=processed_image, generator=generator, **model_kwargs)
1212
+ return result.images[0]
1213
+
1214
+
1215
+ @pxt.udf
1216
+ def automatic_speech_recognition(
1217
+ audio: pxt.Audio,
1218
+ *,
1219
+ model_id: str,
1220
+ language: Optional[str] = None,
1221
+ chunk_length_s: Optional[int] = None,
1222
+ return_timestamps: bool = False,
1223
+ ) -> str:
1224
+ """
1225
+ Transcribes speech to text using a pretrained ASR model. `model_id` should be a reference to a
1226
+ pretrained [automatic-speech-recognition model](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition).
1227
+
1228
+ This is a **generic function** that works with many ASR model families. For production use with
1229
+ specific models, consider specialized functions like `whisper.transcribe()` or
1230
+ `speech2text_for_conditional_generation()`.
1231
+
1232
+ __Requirements:__
1233
+
1234
+ - `pip install torch transformers torchaudio`
1235
+
1236
+ __Recommended Models:__
1237
+
1238
+ - **OpenAI Whisper**: `openai/whisper-tiny.en`, `openai/whisper-small`, `openai/whisper-base`
1239
+ - **Facebook Wav2Vec2**: `facebook/wav2vec2-base-960h`, `facebook/wav2vec2-large-960h-lv60-self`
1240
+ - **Microsoft SpeechT5**: `microsoft/speecht5_asr`
1241
+ - **Meta MMS (Multilingual)**: `facebook/mms-1b-all`
1242
+
1243
+ Args:
1244
+ audio: The audio file to transcribe.
1245
+ model_id: The pretrained ASR model to use.
1246
+ language: Language code for multilingual models (e.g., 'en', 'es', 'fr').
1247
+ chunk_length_s: Maximum length of audio chunks in seconds for long audio processing.
1248
+ return_timestamps: Whether to return word-level timestamps (model dependent).
1249
+
1250
+ Returns:
1251
+ The transcribed text.
1252
+
1253
+ Examples:
1254
+ Add a computed column that transcribes audio files:
1255
+
1256
+ >>> tbl.add_computed_column(transcription=automatic_speech_recognition(
1257
+ ... tbl.audio_file,
1258
+ ... model_id='openai/whisper-tiny.en' # Recommended
1259
+ ... ))
1260
+
1261
+ Transcribe with language specification:
1262
+
1263
+ >>> tbl.add_computed_column(transcription=automatic_speech_recognition(
1264
+ ... tbl.audio_file,
1265
+ ... model_id='facebook/mms-1b-all',
1266
+ ... language='en'
1267
+ ... ))
1268
+ """
1269
+ env.Env.get().require_package('transformers')
1270
+ env.Env.get().require_package('torchaudio')
1271
+ device = resolve_torch_device('auto', allow_mps=False) # Following speech2text pattern
1272
+ import torch
1273
+ import torchaudio
1274
+
1275
+ # Try to load model and processor using direct model loading - following speech2text pattern
1276
+ # Handle different ASR model types
1277
+ if 'whisper' in model_id.lower():
1278
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor
1279
+
1280
+ model = _lookup_model(model_id, WhisperForConditionalGeneration.from_pretrained, device=device)
1281
+ processor = _lookup_processor(model_id, WhisperProcessor.from_pretrained)
1282
+
1283
+ # Language validation for Whisper - following speech2text pattern
1284
+ if language is not None and hasattr(processor.tokenizer, 'get_decoder_prompt_ids'):
1285
+ try:
1286
+ # Test if language is supported
1287
+ _ = processor.tokenizer.get_decoder_prompt_ids(language=language)
1288
+ except Exception:
1289
+ raise excs.Error(
1290
+ f"Language code '{language}' is not supported by Whisper model '{model_id}'. "
1291
+ f"Try common codes like 'en', 'es', 'fr', 'de', 'it', 'pt', 'ru', 'ja', 'ko', 'zh'."
1292
+ ) from None
1293
+
1294
+ elif 'wav2vec2' in model_id.lower():
1295
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
1296
+
1297
+ model = _lookup_model(model_id, Wav2Vec2ForCTC.from_pretrained, device=device)
1298
+ processor = _lookup_processor(model_id, Wav2Vec2Processor.from_pretrained)
1299
+
1300
+ elif 'speech_to_text' in model_id.lower() or 's2t' in model_id.lower():
1301
+ # Use the existing speech2text function for these models
1302
+ from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
1303
+
1304
+ model = _lookup_model(model_id, Speech2TextForConditionalGeneration.from_pretrained, device=device)
1305
+ processor = _lookup_processor(model_id, Speech2TextProcessor.from_pretrained)
1306
+
1307
+ else:
1308
+ # Generic fallback using Auto classes
1309
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
1310
+
1311
+ try:
1312
+ model = _lookup_model(model_id, AutoModelForSpeechSeq2Seq.from_pretrained, device=device)
1313
+ processor = _lookup_processor(model_id, AutoProcessor.from_pretrained)
1314
+ except Exception:
1315
+ # Fallback to CTC models
1316
+ from transformers import AutoModelForCTC
1317
+
1318
+ model = _lookup_model(model_id, AutoModelForCTC.from_pretrained, device=device)
1319
+ processor = _lookup_processor(model_id, AutoProcessor.from_pretrained)
1320
+
1321
+ # Get model's expected sampling rate - following speech2text pattern
1322
+ model_sampling_rate = getattr(model.config, 'sampling_rate', 16_000)
1323
+
1324
+ # Load and preprocess audio - following speech2text pattern
1325
+ waveform, sampling_rate = torchaudio.load(audio)
1326
+
1327
+ # Resample if necessary
1328
+ if sampling_rate != model_sampling_rate:
1329
+ waveform = torchaudio.transforms.Resample(sampling_rate, model_sampling_rate)(waveform)
1330
+
1331
+ # Convert to mono if stereo
1332
+ if waveform.dim() == 2:
1333
+ waveform = torch.mean(waveform, dim=0)
1334
+ assert waveform.dim() == 1
1335
+
1336
+ with torch.no_grad():
1337
+ # Process audio with the model
1338
+ inputs = processor(waveform, sampling_rate=model_sampling_rate, return_tensors='pt')
1339
+
1340
+ # Handle different model types for generation
1341
+ if 'whisper' in model_id.lower():
1342
+ # Whisper-specific generation
1343
+ generate_kwargs = {}
1344
+ if language is not None:
1345
+ generate_kwargs['language'] = language
1346
+ if return_timestamps:
1347
+ generate_kwargs['return_timestamps'] = 'word' if return_timestamps else None
1348
+
1349
+ generated_ids = model.generate(**inputs.to(device), **generate_kwargs)
1350
+ transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
1351
+
1352
+ elif hasattr(model, 'generate'):
1353
+ # Seq2Seq models (Speech2Text, etc.)
1354
+ generated_ids = model.generate(**inputs.to(device))
1355
+ transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
1356
+
1357
+ else:
1358
+ # CTC models (Wav2Vec2, etc.)
1359
+ logits = model(**inputs.to(device)).logits
1360
+ predicted_ids = torch.argmax(logits, dim=-1)
1361
+ transcription = processor.batch_decode(predicted_ids)[0]
1362
+
1363
+ return transcription.strip()
1364
+
1365
+
1366
+ @pxt.udf
1367
+ def image_to_video(
1368
+ image: PIL.Image.Image,
1369
+ *,
1370
+ model_id: str,
1371
+ num_frames: int = 25,
1372
+ fps: int = 6,
1373
+ seed: Optional[int] = None,
1374
+ model_kwargs: Optional[dict[str, Any]] = None,
1375
+ ) -> pxt.Video:
1376
+ """
1377
+ Generates videos from input images using a pretrained image-to-video model.
1378
+ `model_id` should be a reference to a pretrained
1379
+ [image-to-video model](https://huggingface.co/models?pipeline_tag=image-to-video).
1380
+
1381
+ __Requirements:__
1382
+
1383
+ - `pip install torch transformers diffusers accelerate`
1384
+
1385
+ Args:
1386
+ image: The input image to animate into a video.
1387
+ model_id: The pretrained image-to-video model to use.
1388
+ num_frames: Number of video frames to generate.
1389
+ fps: Frames per second for the output video.
1390
+ seed: Random seed for reproducibility.
1391
+ model_kwargs: Additional keyword arguments to pass to the model, such as `num_inference_steps`,
1392
+ `motion_bucket_id`, or `guidance_scale`.
1393
+
1394
+ Returns:
1395
+ The generated video file.
1396
+
1397
+ Examples:
1398
+ Add a computed column that creates videos from images:
1399
+
1400
+ >>> tbl.add_computed_column(video=image_to_video(
1401
+ ... tbl.input_image,
1402
+ ... model_id='stabilityai/stable-video-diffusion-img2vid-xt',
1403
+ ... num_frames=25,
1404
+ ... fps=7
1405
+ ... ))
1406
+ """
1407
+ env.Env.get().require_package('transformers')
1408
+ env.Env.get().require_package('diffusers')
1409
+ env.Env.get().require_package('accelerate')
1410
+ device = resolve_torch_device('auto', allow_mps=False)
1411
+ import numpy as np
1412
+ import torch
1413
+ from diffusers import StableVideoDiffusionPipeline
1414
+
1415
+ if model_kwargs is None:
1416
+ model_kwargs = {}
1417
+
1418
+ # Parameter validation - following best practices pattern
1419
+ if num_frames < 1:
1420
+ raise excs.Error(f'num_frames must be at least 1, got {num_frames}')
1421
+
1422
+ if num_frames > 25:
1423
+ raise excs.Error(f'num_frames cannot exceed 25 for most video diffusion models, got {num_frames}')
1424
+
1425
+ if fps < 1:
1426
+ raise excs.Error(f'fps must be at least 1, got {fps}')
1427
+
1428
+ if fps > 60:
1429
+ raise excs.Error(f'fps should not exceed 60 for reasonable video generation, got {fps}')
1430
+
1431
+ pipe = _lookup_model(
1432
+ model_id,
1433
+ lambda x: StableVideoDiffusionPipeline.from_pretrained(
1434
+ x,
1435
+ torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
1436
+ variant='fp16' if device == 'cuda' else None,
1437
+ ),
1438
+ device=device,
1439
+ )
1440
+
1441
+ try:
1442
+ if device == 'cuda' and hasattr(pipe, 'enable_model_cpu_offload'):
1443
+ pipe.enable_model_cpu_offload()
1444
+ if hasattr(pipe, 'enable_memory_efficient_attention'):
1445
+ pipe.enable_memory_efficient_attention()
1446
+ except Exception:
1447
+ pass # Ignore optimization failures
1448
+
1449
+ generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)
1450
+
1451
+ # Ensure image is in RGB mode and proper size
1452
+ processed_image = image.convert('RGB')
1453
+ target_width, target_height = 512, 320
1454
+ processed_image = processed_image.resize((target_width, target_height), PIL.Image.Resampling.LANCZOS)
1455
+
1456
+ # Generate video frames with proper error handling
1457
+ with torch.no_grad():
1458
+ result = pipe(image=processed_image, num_frames=num_frames, generator=generator, **model_kwargs)
1459
+ frames = result.frames[0]
1460
+
1461
+ # Create output video file
1462
+ output_path = str(TempStore.create_path(extension='.mp4'))
1463
+
1464
+ with av.open(output_path, mode='w') as container:
1465
+ stream = container.add_stream('h264', rate=fps)
1466
+ stream.width = target_width
1467
+ stream.height = target_height
1468
+ stream.pix_fmt = 'yuv420p'
1469
+
1470
+ # Set codec options for better compatibility
1471
+ stream.codec_context.options = {'crf': '23', 'preset': 'medium'}
1472
+
1473
+ for frame_pil in frames:
1474
+ # Convert PIL to numpy array
1475
+ frame_array = np.array(frame_pil)
1476
+ # Create av VideoFrame
1477
+ av_frame = av.VideoFrame.from_ndarray(frame_array, format='rgb24')
1478
+ # Encode and mux
1479
+ for packet in stream.encode(av_frame):
1480
+ container.mux(packet)
1481
+
1482
+ # Flush encoder
1483
+ for packet in stream.encode():
1484
+ container.mux(packet)
1485
+
1486
+ return output_path
458
1487
 
459
1488
 
460
1489
  def _lookup_model(
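
The functions added in this release are ordinary Pixeltable UDFs and are called the same way as the existing Hugging Face helpers in this module. A minimal usage sketch, assuming the diffed module is importable as `pixeltable.functions.huggingface` and that `torch` and `transformers` are installed (the table and column names are illustrative):

    import pixeltable as pxt
    # Assumed import path for the module shown in this diff.
    from pixeltable.functions import huggingface as hf

    # Sentiment analysis over a text column with the new text_classification UDF.
    t = pxt.create_table('reviews', {'review_text': pxt.String})
    t.add_computed_column(
        sentiment=hf.text_classification(
            t.review_text,
            model_id='cardiffnlp/twitter-roberta-base-sentiment-latest',
            top_k=3,
        )
    )
    t.insert([{'review_text': 'The 0.4.18 release works great.'}])
    print(t.select(t.review_text, t.sentiment).collect())

The other additions (text_generation, image_captioning, summarization, translation, text_to_image, text_to_speech, automatic_speech_recognition, image_to_video, and so on) follow the same pattern: pass a column expression plus a `model_id`, with optional `model_kwargs` forwarded to the underlying model.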