spark-nlp 5.5.2__py2.py3-none-any.whl → 6.0.0__py2.py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their supported public registries. It is provided for informational purposes only.

Potentially problematic release: this version of spark-nlp might be problematic.

Files changed (34)
  1. {spark_nlp-5.5.2.dist-info → spark_nlp-6.0.0.dist-info}/METADATA +20 -11
  2. {spark_nlp-5.5.2.dist-info → spark_nlp-6.0.0.dist-info}/RECORD +33 -18
  3. sparknlp/__init__.py +2 -2
  4. sparknlp/annotator/classifier_dl/__init__.py +4 -0
  5. sparknlp/annotator/classifier_dl/albert_for_multiple_choice.py +161 -0
  6. sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py +2 -2
  7. sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py +161 -0
  8. sparknlp/annotator/classifier_dl/roberta_for_multiple_choice.py +161 -0
  9. sparknlp/annotator/classifier_dl/xlm_roberta_for_multiple_choice.py +149 -0
  10. sparknlp/annotator/cleaners/__init__.py +15 -0
  11. sparknlp/annotator/cleaners/cleaner.py +202 -0
  12. sparknlp/annotator/cleaners/extractor.py +191 -0
  13. sparknlp/annotator/cv/__init__.py +6 -1
  14. sparknlp/annotator/cv/janus_for_multimodal.py +356 -0
  15. sparknlp/annotator/cv/llava_for_multimodal.py +328 -0
  16. sparknlp/annotator/cv/mllama_for_multimodal.py +340 -0
  17. sparknlp/annotator/cv/phi3_vision_for_multimodal.py +328 -0
  18. sparknlp/annotator/cv/qwen2vl_transformer.py +332 -0
  19. sparknlp/annotator/embeddings/auto_gguf_embeddings.py +10 -6
  20. sparknlp/annotator/embeddings/bge_embeddings.py +7 -3
  21. sparknlp/annotator/seq2seq/__init__.py +3 -0
  22. sparknlp/annotator/seq2seq/auto_gguf_model.py +8 -503
  23. sparknlp/annotator/seq2seq/auto_gguf_vision_model.py +333 -0
  24. sparknlp/annotator/seq2seq/cohere_transformer.py +357 -0
  25. sparknlp/annotator/seq2seq/llama3_transformer.py +4 -4
  26. sparknlp/annotator/seq2seq/olmo_transformer.py +326 -0
  27. sparknlp/base/image_assembler.py +58 -0
  28. sparknlp/common/properties.py +632 -96
  29. sparknlp/internal/__init__.py +100 -2
  30. sparknlp/reader/pdf_to_text.py +65 -0
  31. sparknlp/reader/sparknlp_reader.py +260 -60
  32. spark_nlp-5.5.2.dist-info/.uuid +0 -1
  33. {spark_nlp-5.5.2.dist-info → spark_nlp-6.0.0.dist-info}/WHEEL +0 -0
  34. {spark_nlp-5.5.2.dist-info → spark_nlp-6.0.0.dist-info}/top_level.txt +0 -0
@@ -12,6 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  """Contains classes for Annotator properties."""
+ from typing import List, Dict

  from pyspark.ml.param import Param, Params, TypeConverters

@@ -67,6 +68,33 @@ class HasCaseSensitiveProperties:
          return self.getOrDefault(self.caseSensitive)


+ class HasClsTokenProperties:
+     useCLSToken = Param(Params._dummy(),
+                         "useCLSToken",
+                         "Whether to use CLS token for pooling (true) or attention-based average pooling (false)",
+                         typeConverter=TypeConverters.toBoolean)
+
+     def setUseCLSToken(self, value):
+         """Sets whether to use the CLS token for pooling.
+
+         Parameters
+         ----------
+         value : bool
+             Whether to use CLS token for pooling (true) or attention-based average pooling (false)
+         """
+         return self._set(useCLSToken=value)
+
+     def getUseCLSToken(self):
+         """Gets whether the CLS token is used for pooling.
+
+         Returns
+         -------
+         bool
+             Whether to use CLS token for pooling (true) or attention-based average pooling (false)
+         """
+         return self.getOrDefault(self.useCLSToken)
+
+
  class HasClassifierActivationProperties:
      activation = Param(Params._dummy(),
                         "activation",
@@ -601,133 +629,641 @@ class HasGeneratorProperties:
                               typeConverter=TypeConverters.toInt)


-     def setTask(self, value):
-         """Sets the transformer's task, e.g. ``summarize:``.
+     def setTask(self, value):
+         """Sets the transformer's task, e.g. ``summarize:``.
+
+         Parameters
+         ----------
+         value : str
+             The transformer's task
+         """
+         return self._set(task=value)

-         Parameters
-         ----------
-         value : str
-             The transformer's task
-         """
-         return self._set(task=value)

+     def setMinOutputLength(self, value):
+         """Sets minimum length of the sequence to be generated.
+
+         Parameters
+         ----------
+         value : int
+             Minimum length of the sequence to be generated
+         """
+         return self._set(minOutputLength=value)
+
+
+     def setMaxOutputLength(self, value):
+         """Sets maximum length of output text.
+
+         Parameters
+         ----------
+         value : int
+             Maximum length of output text
+         """
+         return self._set(maxOutputLength=value)
+
+
+     def setDoSample(self, value):
+         """Sets whether or not to use sampling, use greedy decoding otherwise.
+
+         Parameters
+         ----------
+         value : bool
+             Whether or not to use sampling; use greedy decoding otherwise
+         """
+         return self._set(doSample=value)

-     def setMinOutputLength(self, value):
-         """Sets minimum length of the sequence to be generated.

-         Parameters
-         ----------
-         value : int
-             Minimum length of the sequence to be generated
-         """
-         return self._set(minOutputLength=value)
+     def setTemperature(self, value):
+         """Sets the value used to module the next token probabilities.

+         Parameters
+         ----------
+         value : float
+             The value used to module the next token probabilities
+         """
+         return self._set(temperature=value)

-     def setMaxOutputLength(self, value):
-         """Sets maximum length of output text.

-         Parameters
-         ----------
-         value : int
-             Maximum length of output text
-         """
-         return self._set(maxOutputLength=value)
+     def setTopK(self, value):
+         """Sets the number of highest probability vocabulary tokens to keep for
+         top-k-filtering.

+         Parameters
+         ----------
+         value : int
+             Number of highest probability vocabulary tokens to keep
+         """
+         return self._set(topK=value)

-     def setDoSample(self, value):
-         """Sets whether or not to use sampling, use greedy decoding otherwise.

-         Parameters
-         ----------
-         value : bool
-             Whether or not to use sampling; use greedy decoding otherwise
-         """
-         return self._set(doSample=value)
+     def setTopP(self, value):
+         """Sets the top cumulative probability for vocabulary tokens.

+         If set to float < 1, only the most probable tokens with probabilities
+         that add up to ``topP`` or higher are kept for generation.

-     def setTemperature(self, value):
-         """Sets the value used to module the next token probabilities.
+         Parameters
+         ----------
+         value : float
+             Cumulative probability for vocabulary tokens
+         """
+         return self._set(topP=value)

-         Parameters
-         ----------
-         value : float
-             The value used to module the next token probabilities
-         """
-         return self._set(temperature=value)

+     def setRepetitionPenalty(self, value):
+         """Sets the parameter for repetition penalty. 1.0 means no penalty.

-     def setTopK(self, value):
-         """Sets the number of highest probability vocabulary tokens to keep for
-         top-k-filtering.
+         Parameters
+         ----------
+         value : float
+             The repetition penalty

-         Parameters
-         ----------
-         value : int
-             Number of highest probability vocabulary tokens to keep
-         """
-         return self._set(topK=value)
+         References
+         ----------
+         See `Ctrl: A Conditional Transformer Language Model For Controllable
+         Generation <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
+         """
+         return self._set(repetitionPenalty=value)


-     def setTopP(self, value):
-         """Sets the top cumulative probability for vocabulary tokens.
+     def setNoRepeatNgramSize(self, value):
+         """Sets size of n-grams that can only occur once.

-         If set to float < 1, only the most probable tokens with probabilities
-         that add up to ``topP`` or higher are kept for generation.
+         If set to int > 0, all ngrams of that size can only occur once.

-         Parameters
-         ----------
-         value : float
-             Cumulative probability for vocabulary tokens
-         """
-         return self._set(topP=value)
+         Parameters
+         ----------
+         value : int
+             N-gram size can only occur once
+         """
+         return self._set(noRepeatNgramSize=value)


-     def setRepetitionPenalty(self, value):
-         """Sets the parameter for repetition penalty. 1.0 means no penalty.
+     def setBeamSize(self, value):
+         """Sets the number of beam size for beam search.

-         Parameters
-         ----------
-         value : float
-             The repetition penalty
+         Parameters
+         ----------
+         value : int
+             Number of beam size for beam search
+         """
+         return self._set(beamSize=value)

-         References
-         ----------
-         See `Ctrl: A Conditional Transformer Language Model For Controllable
-         Generation <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
-         """
-         return self._set(repetitionPenalty=value)

+     def setNReturnSequences(self, value):
+         """Sets the number of sequences to return from the beam search.

-     def setNoRepeatNgramSize(self, value):
-         """Sets size of n-grams that can only occur once.
+         Parameters
+         ----------
+         value : int
+             Number of sequences to return
+         """
+         return self._set(nReturnSequences=value)

-         If set to int > 0, all ngrams of that size can only occur once.

-         Parameters
-         ----------
-         value : int
-             N-gram size can only occur once
-         """
-         return self._set(noRepeatNgramSize=value)
+ class HasLlamaCppProperties:
+     # -------- MODEL PARAMETERS --------
+     nThreads = Param(Params._dummy(), "nThreads", "Set the number of threads to use during generation",
+                      typeConverter=TypeConverters.toInt)
+     nThreadsDraft = Param(Params._dummy(), "nThreadsDraft", "Set the number of threads to use during draft generation",
+                           typeConverter=TypeConverters.toInt)
+     nThreadsBatch = Param(Params._dummy(), "nThreadsBatch",
+                           "Set the number of threads to use during batch and prompt processing",
+                           typeConverter=TypeConverters.toInt)
+     nThreadsBatchDraft = Param(Params._dummy(), "nThreadsBatchDraft",
+                                "Set the number of threads to use during batch and prompt processing",
+                                typeConverter=TypeConverters.toInt)
+     nCtx = Param(Params._dummy(), "nCtx", "Set the size of the prompt context", typeConverter=TypeConverters.toInt)
+     nBatch = Param(Params._dummy(), "nBatch",
+                    "Set the logical batch size for prompt processing (must be >=32 to use BLAS)",
+                    typeConverter=TypeConverters.toInt)
+     nUbatch = Param(Params._dummy(), "nUbatch",
+                     "Set the physical batch size for prompt processing (must be >=32 to use BLAS)",
+                     typeConverter=TypeConverters.toInt)
+     nDraft = Param(Params._dummy(), "nDraft", "Set the number of tokens to draft for speculative decoding",
+                    typeConverter=TypeConverters.toInt)
+     nChunks = Param(Params._dummy(), "nChunks", "Set the maximal number of chunks to process",
+                     typeConverter=TypeConverters.toInt)
+     nSequences = Param(Params._dummy(), "nSequences", "Set the number of sequences to decode",
+                        typeConverter=TypeConverters.toInt)
+     pSplit = Param(Params._dummy(), "pSplit", "Set the speculative decoding split probability",
+                    typeConverter=TypeConverters.toFloat)
+     nGpuLayers = Param(Params._dummy(), "nGpuLayers", "Set the number of layers to store in VRAM (-1 - use default)",
+                        typeConverter=TypeConverters.toInt)
+     nGpuLayersDraft = Param(Params._dummy(), "nGpuLayersDraft",
+                             "Set the number of layers to store in VRAM for the draft model (-1 - use default)",
+                             typeConverter=TypeConverters.toInt)
+     # Set how to split the model across GPUs
+     #
+     # - NONE: No GPU split
+     # - LAYER: Split the model across GPUs by layer
+     # - ROW: Split the model across GPUs by rows
+     gpuSplitMode = Param(Params._dummy(), "gpuSplitMode", "Set how to split the model across GPUs",
+                          typeConverter=TypeConverters.toString)
+     mainGpu = Param(Params._dummy(), "mainGpu", "Set the main GPU that is used for scratch and small tensors.",
+                     typeConverter=TypeConverters.toInt)
+     tensorSplit = Param(Params._dummy(), "tensorSplit", "Set how split tensors should be distributed across GPUs",
+                         typeConverter=TypeConverters.toListFloat)
+     grpAttnN = Param(Params._dummy(), "grpAttnN", "Set the group-attention factor", typeConverter=TypeConverters.toInt)
+     grpAttnW = Param(Params._dummy(), "grpAttnW", "Set the group-attention width", typeConverter=TypeConverters.toInt)
+     ropeFreqBase = Param(Params._dummy(), "ropeFreqBase", "Set the RoPE base frequency, used by NTK-aware scaling",
+                          typeConverter=TypeConverters.toFloat)
+     ropeFreqScale = Param(Params._dummy(), "ropeFreqScale",
+                           "Set the RoPE frequency scaling factor, expands context by a factor of 1/N",
+                           typeConverter=TypeConverters.toFloat)
+     yarnExtFactor = Param(Params._dummy(), "yarnExtFactor", "Set the YaRN extrapolation mix factor",
+                           typeConverter=TypeConverters.toFloat)
+     yarnAttnFactor = Param(Params._dummy(), "yarnAttnFactor", "Set the YaRN scale sqrt(t) or attention magnitude",
+                            typeConverter=TypeConverters.toFloat)
+     yarnBetaFast = Param(Params._dummy(), "yarnBetaFast", "Set the YaRN low correction dim or beta",
+                          typeConverter=TypeConverters.toFloat)
+     yarnBetaSlow = Param(Params._dummy(), "yarnBetaSlow", "Set the YaRN high correction dim or alpha",
+                          typeConverter=TypeConverters.toFloat)
+     yarnOrigCtx = Param(Params._dummy(), "yarnOrigCtx", "Set the YaRN original context size of model",
+                         typeConverter=TypeConverters.toInt)
+     defragmentationThreshold = Param(Params._dummy(), "defragmentationThreshold",
+                                      "Set the KV cache defragmentation threshold", typeConverter=TypeConverters.toFloat)
+     # Set optimization strategies that help on some NUMA systems (if available)
+     #
+     # Available Strategies:
+     #
+     # - DISABLED: No NUMA optimizations
+     # - DISTRIBUTE: Spread execution evenly over all
+     # - ISOLATE: Only spawn threads on CPUs on the node that execution started on
+     # - NUMA_CTL: Use the CPU map provided by numactl
+     # - MIRROR: Mirrors the model across NUMA nodes
+     numaStrategy = Param(Params._dummy(), "numaStrategy",
+                          "Set optimization strategies that help on some NUMA systems (if available)",
+                          typeConverter=TypeConverters.toString)
+     # Set the RoPE frequency scaling method, defaults to linear unless specified by the model.
+     #
+     # - UNSPECIFIED: Don't use any scaling
+     # - LINEAR: Linear scaling
+     # - YARN: YaRN RoPE scaling
+     ropeScalingType = Param(Params._dummy(), "ropeScalingType",
+                             "Set the RoPE frequency scaling method, defaults to linear unless specified by the model",
+                             typeConverter=TypeConverters.toString)
+     # Set the pooling type for embeddings, use model default if unspecified
+     #
+     # - 0 NONE: Don't use any pooling
+     # - 1 MEAN: Mean Pooling
+     # - 2 CLS: CLS Pooling
+     poolingType = Param(Params._dummy(), "poolingType",
+                         "Set the pooling type for embeddings, use model default if unspecified",
+                         typeConverter=TypeConverters.toString)
+     modelDraft = Param(Params._dummy(), "modelDraft", "Set the draft model for speculative decoding",
+                        typeConverter=TypeConverters.toString)
+     modelAlias = Param(Params._dummy(), "modelAlias", "Set a model alias", typeConverter=TypeConverters.toString)
+     lookupCacheStaticFilePath = Param(Params._dummy(), "lookupCacheStaticFilePath",
+                                       "Set path to static lookup cache to use for lookup decoding (not updated by generation)",
+                                       typeConverter=TypeConverters.toString)
+     lookupCacheDynamicFilePath = Param(Params._dummy(), "lookupCacheDynamicFilePath",
+                                        "Set path to dynamic lookup cache to use for lookup decoding (updated by generation)",
+                                        typeConverter=TypeConverters.toString)
+     # loraAdapters = new StructFeature[Map[String, Float]](this, "loraAdapters")
+     embedding = Param(Params._dummy(), "embedding", "Whether to load model with embedding support",
+                       typeConverter=TypeConverters.toBoolean)
+     flashAttention = Param(Params._dummy(), "flashAttention", "Whether to enable Flash Attention",
+                            typeConverter=TypeConverters.toBoolean)
+     inputPrefixBos = Param(Params._dummy(), "inputPrefixBos",
+                            "Whether to add prefix BOS to user inputs, preceding the `--in-prefix` string",
+                            typeConverter=TypeConverters.toBoolean)
+     useMmap = Param(Params._dummy(), "useMmap",
+                     "Whether to use memory-map model (faster load but may increase pageouts if not using mlock)",
+                     typeConverter=TypeConverters.toBoolean)
+     useMlock = Param(Params._dummy(), "useMlock",
+                      "Whether to force the system to keep model in RAM rather than swapping or compressing",
+                      typeConverter=TypeConverters.toBoolean)
+     noKvOffload = Param(Params._dummy(), "noKvOffload", "Whether to disable KV offload",
+                         typeConverter=TypeConverters.toBoolean)
+     systemPrompt = Param(Params._dummy(), "systemPrompt", "Set a system prompt to use",
+                          typeConverter=TypeConverters.toString)
+     chatTemplate = Param(Params._dummy(), "chatTemplate", "The chat template to use",
+                          typeConverter=TypeConverters.toString)
+
+     # -------- INFERENCE PARAMETERS --------
+     inputPrefix = Param(Params._dummy(), "inputPrefix", "Set the prompt to start generation with",
+                         typeConverter=TypeConverters.toString)
+     inputSuffix = Param(Params._dummy(), "inputSuffix", "Set a suffix for infilling",
+                         typeConverter=TypeConverters.toString)
+     cachePrompt = Param(Params._dummy(), "cachePrompt", "Whether to remember the prompt to avoid reprocessing it",
+                         typeConverter=TypeConverters.toBoolean)
+     nPredict = Param(Params._dummy(), "nPredict", "Set the number of tokens to predict",
+                      typeConverter=TypeConverters.toInt)
+     topK = Param(Params._dummy(), "topK", "Set top-k sampling", typeConverter=TypeConverters.toInt)
+     topP = Param(Params._dummy(), "topP", "Set top-p sampling", typeConverter=TypeConverters.toFloat)
+     minP = Param(Params._dummy(), "minP", "Set min-p sampling", typeConverter=TypeConverters.toFloat)
+     tfsZ = Param(Params._dummy(), "tfsZ", "Set tail free sampling, parameter z", typeConverter=TypeConverters.toFloat)
+     typicalP = Param(Params._dummy(), "typicalP", "Set locally typical sampling, parameter p",
+                      typeConverter=TypeConverters.toFloat)
+     temperature = Param(Params._dummy(), "temperature", "Set the temperature", typeConverter=TypeConverters.toFloat)
+     dynamicTemperatureRange = Param(Params._dummy(), "dynatempRange", "Set the dynamic temperature range",
+                                     typeConverter=TypeConverters.toFloat)
+     dynamicTemperatureExponent = Param(Params._dummy(), "dynatempExponent", "Set the dynamic temperature exponent",
+                                        typeConverter=TypeConverters.toFloat)
+     repeatLastN = Param(Params._dummy(), "repeatLastN", "Set the last n tokens to consider for penalties",
+                         typeConverter=TypeConverters.toInt)
+     repeatPenalty = Param(Params._dummy(), "repeatPenalty", "Set the penalty of repeated sequences of tokens",
+                           typeConverter=TypeConverters.toFloat)
+     frequencyPenalty = Param(Params._dummy(), "frequencyPenalty", "Set the repetition alpha frequency penalty",
+                              typeConverter=TypeConverters.toFloat)
+     presencePenalty = Param(Params._dummy(), "presencePenalty", "Set the repetition alpha presence penalty",
+                             typeConverter=TypeConverters.toFloat)
+     miroStat = Param(Params._dummy(), "miroStat", "Set MiroStat sampling strategies.",
+                      typeConverter=TypeConverters.toString)
+     miroStatTau = Param(Params._dummy(), "mirostatTau", "Set the MiroStat target entropy, parameter tau",
+                         typeConverter=TypeConverters.toFloat)
+     miroStatEta = Param(Params._dummy(), "mirostatEta", "Set the MiroStat learning rate, parameter eta",
+                         typeConverter=TypeConverters.toFloat)
+     penalizeNl = Param(Params._dummy(), "penalizeNl", "Whether to penalize newline tokens",
+                        typeConverter=TypeConverters.toBoolean)
+     nKeep = Param(Params._dummy(), "nKeep", "Set the number of tokens to keep from the initial prompt",
+                   typeConverter=TypeConverters.toInt)
+     seed = Param(Params._dummy(), "seed", "Set the RNG seed", typeConverter=TypeConverters.toInt)
+     nProbs = Param(Params._dummy(), "nProbs", "Set the amount top tokens probabilities to output if greater than 0.",
+                    typeConverter=TypeConverters.toInt)
+     minKeep = Param(Params._dummy(), "minKeep",
+                     "Set the amount of tokens the samplers should return at least (0 = disabled)",
+                     typeConverter=TypeConverters.toInt)
+     grammar = Param(Params._dummy(), "grammar", "Set BNF-like grammar to constrain generations",
+                     typeConverter=TypeConverters.toString)
+     penaltyPrompt = Param(Params._dummy(), "penaltyPrompt",
+                           "Override which part of the prompt is penalized for repetition.",
+                           typeConverter=TypeConverters.toString)
+     ignoreEos = Param(Params._dummy(), "ignoreEos",
+                       "Set whether to ignore end of stream token and continue generating (implies --logit-bias 2-inf)",
+                       typeConverter=TypeConverters.toBoolean)
+     disableTokenIds = Param(Params._dummy(), "disableTokenIds", "Set the token ids to disable in the completion",
+                             typeConverter=TypeConverters.toListInt)
+     stopStrings = Param(Params._dummy(), "stopStrings", "Set strings upon seeing which token generation is stopped",
+                         typeConverter=TypeConverters.toListString)
+     samplers = Param(Params._dummy(), "samplers", "Set which samplers to use for token generation in the given order",
+                      typeConverter=TypeConverters.toListString)
+     useChatTemplate = Param(Params._dummy(), "useChatTemplate",
+                             "Set whether or not generate should apply a chat template",
+                             typeConverter=TypeConverters.toBoolean)
+
+     # -------- MODEL SETTERS --------
+     def setNThreads(self, nThreads: int):
+         """Set the number of threads to use during generation"""
+         return self._set(nThreads=nThreads)
+
+     def setNThreadsDraft(self, nThreadsDraft: int):
+         """Set the number of threads to use during draft generation"""
+         return self._set(nThreadsDraft=nThreadsDraft)
+
+     def setNThreadsBatch(self, nThreadsBatch: int):
+         """Set the number of threads to use during batch and prompt processing"""
+         return self._set(nThreadsBatch=nThreadsBatch)
+
+     def setNThreadsBatchDraft(self, nThreadsBatchDraft: int):
+         """Set the number of threads to use during batch and prompt processing"""
+         return self._set(nThreadsBatchDraft=nThreadsBatchDraft)
+
+     def setNCtx(self, nCtx: int):
+         """Set the size of the prompt context"""
+         return self._set(nCtx=nCtx)
+
+     def setNBatch(self, nBatch: int):
+         """Set the logical batch size for prompt processing (must be >=32 to use BLAS)"""
+         return self._set(nBatch=nBatch)
+
+     def setNUbatch(self, nUbatch: int):
+         """Set the physical batch size for prompt processing (must be >=32 to use BLAS)"""
+         return self._set(nUbatch=nUbatch)
+
+     def setNDraft(self, nDraft: int):
+         """Set the number of tokens to draft for speculative decoding"""
+         return self._set(nDraft=nDraft)
+
+     def setNChunks(self, nChunks: int):
+         """Set the maximal number of chunks to process"""
+         return self._set(nChunks=nChunks)
+
+     def setNSequences(self, nSequences: int):
+         """Set the number of sequences to decode"""
+         return self._set(nSequences=nSequences)
+
+     def setPSplit(self, pSplit: float):
+         """Set the speculative decoding split probability"""
+         return self._set(pSplit=pSplit)
+
+     def setNGpuLayers(self, nGpuLayers: int):
+         """Set the number of layers to store in VRAM (-1 - use default)"""
+         return self._set(nGpuLayers=nGpuLayers)
+
+     def setNGpuLayersDraft(self, nGpuLayersDraft: int):
+         """Set the number of layers to store in VRAM for the draft model (-1 - use default)"""
+         return self._set(nGpuLayersDraft=nGpuLayersDraft)
+
+     def setGpuSplitMode(self, gpuSplitMode: str):
+         """Set how to split the model across GPUs"""
+         return self._set(gpuSplitMode=gpuSplitMode)
+
+     def setMainGpu(self, mainGpu: int):
+         """Set the main GPU that is used for scratch and small tensors."""
+         return self._set(mainGpu=mainGpu)
+
+     def setTensorSplit(self, tensorSplit: List[float]):
+         """Set how split tensors should be distributed across GPUs"""
+         return self._set(tensorSplit=tensorSplit)
+
+     def setGrpAttnN(self, grpAttnN: int):
+         """Set the group-attention factor"""
+         return self._set(grpAttnN=grpAttnN)
+
+     def setGrpAttnW(self, grpAttnW: int):
+         """Set the group-attention width"""
+         return self._set(grpAttnW=grpAttnW)
+
+     def setRopeFreqBase(self, ropeFreqBase: float):
+         """Set the RoPE base frequency, used by NTK-aware scaling"""
+         return self._set(ropeFreqBase=ropeFreqBase)
+
+     def setRopeFreqScale(self, ropeFreqScale: float):
+         """Set the RoPE frequency scaling factor, expands context by a factor of 1/N"""
+         return self._set(ropeFreqScale=ropeFreqScale)
+
+     def setYarnExtFactor(self, yarnExtFactor: float):
+         """Set the YaRN extrapolation mix factor"""
+         return self._set(yarnExtFactor=yarnExtFactor)
+
+     def setYarnAttnFactor(self, yarnAttnFactor: float):
+         """Set the YaRN scale sqrt(t) or attention magnitude"""
+         return self._set(yarnAttnFactor=yarnAttnFactor)
+
+     def setYarnBetaFast(self, yarnBetaFast: float):
+         """Set the YaRN low correction dim or beta"""
+         return self._set(yarnBetaFast=yarnBetaFast)
+
+     def setYarnBetaSlow(self, yarnBetaSlow: float):
+         """Set the YaRN high correction dim or alpha"""
+         return self._set(yarnBetaSlow=yarnBetaSlow)
+
+     def setYarnOrigCtx(self, yarnOrigCtx: int):
+         """Set the YaRN original context size of model"""
+         return self._set(yarnOrigCtx=yarnOrigCtx)
+
+     def setDefragmentationThreshold(self, defragmentationThreshold: float):
+         """Set the KV cache defragmentation threshold"""
+         return self._set(defragmentationThreshold=defragmentationThreshold)

+     def setNumaStrategy(self, numaStrategy: str):
+         """Set optimization strategies that help on some NUMA systems (if available)"""
+         numaUpper = numaStrategy.upper()
+         numaStrategies = ["DISABLED", "DISTRIBUTE", "ISOLATE", "NUMA_CTL", "MIRROR"]
+         if numaUpper not in numaStrategies:
+             raise ValueError(
+                 f"Invalid NUMA strategy: {numaUpper}. "
+                 + f"Valid values are: {numaStrategies}"
+             )
+         return self._set(numaStrategy=numaStrategy)
+
+     def setRopeScalingType(self, ropeScalingType: str):
+         """Set the RoPE frequency scaling method, defaults to linear unless specified by the model"""
+         return self._set(ropeScalingType=ropeScalingType)
+
+     def setPoolingType(self, poolingType: str):
+         """Set the pooling type for embeddings, use model default if unspecified"""
+         poolingTypeUpper = poolingType.upper()
+         poolingTypes = ["NONE", "MEAN", "CLS", "LAST"]
+         if poolingTypeUpper not in poolingTypes:
+             raise ValueError(
+                 f"Invalid pooling type: {poolingType}. "
+                 + f"Valid values are: {poolingTypes}"
+             )
+         return self._set(poolingType=poolingType)
+
+     def setModelDraft(self, modelDraft: str):
+         """Set the draft model for speculative decoding"""
+         return self._set(modelDraft=modelDraft)
+
+     def setModelAlias(self, modelAlias: str):
+         """Set a model alias"""
+         return self._set(modelAlias=modelAlias)
+
+     def setLookupCacheStaticFilePath(self, lookupCacheStaticFilePath: str):
+         """Set path to static lookup cache to use for lookup decoding (not updated by generation)"""
+         return self._set(lookupCacheStaticFilePath=lookupCacheStaticFilePath)
+
+     def setLookupCacheDynamicFilePath(self, lookupCacheDynamicFilePath: str):
+         """Set path to dynamic lookup cache to use for lookup decoding (updated by generation)"""
+         return self._set(lookupCacheDynamicFilePath=lookupCacheDynamicFilePath)
+
+     def setEmbedding(self, embedding: bool):
+         """Whether to load model with embedding support"""
+         return self._set(embedding=embedding)
+
+     def setFlashAttention(self, flashAttention: bool):
+         """Whether to enable Flash Attention"""
+         return self._set(flashAttention=flashAttention)
+
+     def setInputPrefixBos(self, inputPrefixBos: bool):
+         """Whether to add prefix BOS to user inputs, preceding the `--in-prefix` string"""
+         return self._set(inputPrefixBos=inputPrefixBos)
+
+     def setUseMmap(self, useMmap: bool):
+         """Whether to use memory-map model (faster load but may increase pageouts if not using mlock)"""
+         return self._set(useMmap=useMmap)
+
+     def setUseMlock(self, useMlock: bool):
+         """Whether to force the system to keep model in RAM rather than swapping or compressing"""
+         return self._set(useMlock=useMlock)
+
+     def setNoKvOffload(self, noKvOffload: bool):
+         """Whether to disable KV offload"""
+         return self._set(noKvOffload=noKvOffload)
+
+     def setSystemPrompt(self, systemPrompt: str):
+         """Set a system prompt to use"""
+         return self._set(systemPrompt=systemPrompt)
+
+     def setChatTemplate(self, chatTemplate: str):
+         """The chat template to use"""
+         return self._set(chatTemplate=chatTemplate)
+
+     # -------- INFERENCE SETTERS --------
+     def setInputPrefix(self, inputPrefix: str):
+         """Set the prompt to start generation with"""
+         return self._set(inputPrefix=inputPrefix)

-     def setBeamSize(self, value):
-         """Sets the number of beam size for beam search.
+     def setInputSuffix(self, inputSuffix: str):
+         """Set a suffix for infilling"""
+         return self._set(inputSuffix=inputSuffix)

-         Parameters
-         ----------
-         value : int
-             Number of beam size for beam search
-         """
-         return self._set(beamSize=value)
+     def setCachePrompt(self, cachePrompt: bool):
+         """Whether to remember the prompt to avoid reprocessing it"""
+         return self._set(cachePrompt=cachePrompt)

+     def setNPredict(self, nPredict: int):
+         """Set the number of tokens to predict"""
+         return self._set(nPredict=nPredict)

-     def setNReturnSequences(self, value):
-         """Sets the number of sequences to return from the beam search.
+     def setTopK(self, topK: int):
+         """Set top-k sampling"""
+         return self._set(topK=topK)

-         Parameters
-         ----------
-         value : int
-             Number of sequences to return
-         """
-         return self._set(nReturnSequences=value)
+     def setTopP(self, topP: float):
+         """Set top-p sampling"""
+         return self._set(topP=topP)
+
+     def setMinP(self, minP: float):
+         """Set min-p sampling"""
+         return self._set(minP=minP)
+
+     def setTfsZ(self, tfsZ: float):
+         """Set tail free sampling, parameter z"""
+         return self._set(tfsZ=tfsZ)
+
+     def setTypicalP(self, typicalP: float):
+         """Set locally typical sampling, parameter p"""
+         return self._set(typicalP=typicalP)
+
+     def setTemperature(self, temperature: float):
+         """Set the temperature"""
+         return self._set(temperature=temperature)
+
+     def setDynamicTemperatureRange(self, dynamicTemperatureRange: float):
+         """Set the dynamic temperature range"""
+         return self._set(dynamicTemperatureRange=dynamicTemperatureRange)
+
+     def setDynamicTemperatureExponent(self, dynamicTemperatureExponent: float):
+         """Set the dynamic temperature exponent"""
+         return self._set(dynamicTemperatureExponent=dynamicTemperatureExponent)
+
+     def setRepeatLastN(self, repeatLastN: int):
+         """Set the last n tokens to consider for penalties"""
+         return self._set(repeatLastN=repeatLastN)
+
+     def setRepeatPenalty(self, repeatPenalty: float):
+         """Set the penalty of repeated sequences of tokens"""
+         return self._set(repeatPenalty=repeatPenalty)
+
+     def setFrequencyPenalty(self, frequencyPenalty: float):
+         """Set the repetition alpha frequency penalty"""
+         return self._set(frequencyPenalty=frequencyPenalty)
+
+     def setPresencePenalty(self, presencePenalty: float):
+         """Set the repetition alpha presence penalty"""
+         return self._set(presencePenalty=presencePenalty)
+
+     def setMiroStat(self, miroStat: str):
+         """Set MiroStat sampling strategies."""
+         return self._set(miroStat=miroStat)
+
+     def setMiroStatTau(self, miroStatTau: float):
+         """Set the MiroStat target entropy, parameter tau"""
+         return self._set(miroStatTau=miroStatTau)
+
+     def setMiroStatEta(self, miroStatEta: float):
+         """Set the MiroStat learning rate, parameter eta"""
+         return self._set(miroStatEta=miroStatEta)
+
+     def setPenalizeNl(self, penalizeNl: bool):
+         """Whether to penalize newline tokens"""
+         return self._set(penalizeNl=penalizeNl)
+
+     def setNKeep(self, nKeep: int):
+         """Set the number of tokens to keep from the initial prompt"""
+         return self._set(nKeep=nKeep)
+
+     def setSeed(self, seed: int):
+         """Set the RNG seed"""
+         return self._set(seed=seed)
+
+     def setNProbs(self, nProbs: int):
+         """Set the amount top tokens probabilities to output if greater than 0."""
+         return self._set(nProbs=nProbs)
+
+     def setMinKeep(self, minKeep: int):
+         """Set the amount of tokens the samplers should return at least (0 = disabled)"""
+         return self._set(minKeep=minKeep)
+
+     def setGrammar(self, grammar: str):
+         """Set BNF-like grammar to constrain generations"""
+         return self._set(grammar=grammar)
+
+     def setPenaltyPrompt(self, penaltyPrompt: str):
+         """Override which part of the prompt is penalized for repetition."""
+         return self._set(penaltyPrompt=penaltyPrompt)
+
+     def setIgnoreEos(self, ignoreEos: bool):
+         """Set whether to ignore end of stream token and continue generating (implies --logit-bias 2-inf)"""
+         return self._set(ignoreEos=ignoreEos)
+
+     def setDisableTokenIds(self, disableTokenIds: List[int]):
+         """Set the token ids to disable in the completion"""
+         return self._set(disableTokenIds=disableTokenIds)
+
+     def setStopStrings(self, stopStrings: List[str]):
+         """Set strings upon seeing which token generation is stopped"""
+         return self._set(stopStrings=stopStrings)
+
+     def setSamplers(self, samplers: List[str]):
+         """Set which samplers to use for token generation in the given order"""
+         return self._set(samplers=samplers)
+
+     def setUseChatTemplate(self, useChatTemplate: bool):
+         """Set whether generate should apply a chat template"""
+         return self._set(useChatTemplate=useChatTemplate)
+
+     def setNParallel(self, nParallel: int):
+         """Sets the number of parallel processes for decoding. This is an alias for `setBatchSize`."""
+         return self.setBatchSize(nParallel)
+
+     # -------- JAVA SETTERS --------
+     def setTokenIdBias(self, tokenIdBias: Dict[int, float]):
+         """Set the token id bias"""
+         return self._call_java("setTokenIdBias", tokenIdBias)
+
+     def setTokenBias(self, tokenBias: Dict[str, float]):
+         """Set the token bias"""
+         return self._call_java("setTokenBias", tokenBias)
+
+     def setLoraAdapters(self, loraAdapters: Dict[str, float]):
+         """Set the LoRA adapters"""
+         return self._call_java("setLoraAdapters", loraAdapters)
+
+     def getMetadata(self):
+         """Gets the metadata of the model"""
+         return self._call_java("getMetadata")
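
The new HasLlamaCppProperties mixin centralizes the llama.cpp model and inference parameters, which is consistent with auto_gguf_model.py shrinking by roughly 500 lines in this release. A minimal, hypothetical sketch of configuring these properties, assuming AutoGGUFModel exposes the mixin's setters:

    # Hypothetical usage; the setters come from HasLlamaCppProperties above.
    from sparknlp.annotator import AutoGGUFModel

    model = (
        AutoGGUFModel.pretrained()
        .setInputCols(["document"])
        .setOutputCol("completions")
        .setNCtx(4096)         # size of the prompt context
        .setNGpuLayers(99)     # number of layers to store in VRAM
        .setTemperature(0.4)
        .setTopK(40)
        .setTopP(0.9)
        .setStopStrings(["</s>"])
    )

Note that setNumaStrategy and setPoolingType validate their arguments and raise a ValueError for values outside the documented sets.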