clarifai 10.9.1__py3-none-any.whl → 10.9.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
clarifai/__init__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "10.9.1"
1
+ __version__ = "10.9.4"
File without changes
clarifai/cli/base.py ADDED
@@ -0,0 +1,74 @@
1
+ import os
2
+
3
+ import click
4
+
5
+ from ..utils.cli import dump_yaml, from_yaml, load_command_modules, set_base_url
6
+
7
+
8
@click.group()
@click.pass_context
def cli(ctx):
    """Clarifai CLI"""
    # Settings are read from ``config.yaml`` in the current working directory when present.
    config_path = 'config.yaml'
    loaded = from_yaml(config_path) if os.path.exists(config_path) else None
    # Sub-commands (e.g. ``login``) use ``in`` and item assignment on ctx.obj, so it must
    # always be a dict. An empty YAML file presumably deserializes to None (assumption
    # about from_yaml -- TODO confirm), hence the ``or {}`` fallback.
    ctx.obj = loaded or {}
18
+
19
+
20
@cli.command()
@click.option('--config', type=click.Path(), required=False, help='Path to the config file')
@click.option(
    '-e',
    '--env',
    required=False,
    help='Environment to use, choose from prod, staging and dev',
    type=click.Choice(['prod', 'staging', 'dev']))
@click.option('--user_id', required=False, help='User ID')
@click.pass_context
def login(ctx, config, env, user_id):
    """Login command to set PAT and other configurations."""
    # An explicit --config replaces what the group loaded from ./config.yaml; a path that
    # does not exist is ignored. ``or {}`` keeps ctx.obj a dict if the file is empty.
    if config and os.path.exists(config):
        ctx.obj = from_yaml(config) or {}

    # PAT precedence: config file, then the CLARIFAI_PAT env var, then an interactive prompt.
    if 'pat' in ctx.obj:
        os.environ["CLARIFAI_PAT"] = ctx.obj['pat']
        click.echo("Loaded PAT from config file.")
    elif 'CLARIFAI_PAT' in os.environ:
        ctx.obj['pat'] = os.environ["CLARIFAI_PAT"]
        click.echo("Loaded PAT from environment variable.")
    else:
        _pat = click.prompt(
            "Get your PAT from https://clarifai.com/settings/security and pass it here", type=str)
        os.environ["CLARIFAI_PAT"] = _pat
        ctx.obj['pat'] = _pat
        click.echo("PAT saved successfully.")

    # user_id precedence: --user_id flag, then config value, then CLARIFAI_USER_ID.
    # Each source is checked separately so a config value never requires the env var
    # to be set.
    if user_id:
        ctx.obj['user_id'] = user_id
        os.environ["CLARIFAI_USER_ID"] = ctx.obj['user_id']
    elif 'user_id' in ctx.obj:
        os.environ["CLARIFAI_USER_ID"] = ctx.obj['user_id']
    elif 'CLARIFAI_USER_ID' in os.environ:
        ctx.obj['user_id'] = os.environ["CLARIFAI_USER_ID"]

    # API base precedence: --env flag, then the env stored in config, then CLARIFAI_API_BASE.
    if env:
        ctx.obj['env'] = env
        ctx.obj['base_url'] = set_base_url(env)
        os.environ["CLARIFAI_API_BASE"] = ctx.obj['base_url']
    elif 'env' in ctx.obj:
        ctx.obj['base_url'] = set_base_url(ctx.obj['env'])
        os.environ["CLARIFAI_API_BASE"] = ctx.obj['base_url']
    elif 'CLARIFAI_API_BASE' in os.environ:
        ctx.obj['base_url'] = os.environ["CLARIFAI_API_BASE"]

    # NOTE: all settings, including the PAT in plain text, are written to ./config.yaml,
    # even when they were read from a different --config path.
    dump_yaml(ctx.obj, 'config.yaml')
68
+
69
+
70
# Import the CLI commands to register them
load_command_modules()

if __name__ == '__main__':
    # Run as a script, this module executes as ``__main__``, a different module object
    # from ``clarifai.cli.base``. Command modules attach their sub-commands to
    # ``clarifai.cli.base.cli`` (clarifai/cli/model.py imports it from there), so
    # dispatch through that object to expose the registered sub-commands.
    from clarifai.cli.base import cli as registered_cli
    registered_cli()
clarifai/cli/model.py ADDED
@@ -0,0 +1,65 @@
1
+ import click
2
+ from clarifai.cli.base import cli
3
+
4
+
5
@cli.group()
def model():
    """Manage models: upload, test locally"""
    # Container group only: sub-commands attach via ``@model.command()``. The
    # docstring above is also the group's ``--help`` text.
9
+
10
+
11
@model.command()
@click.option(
    '--model_path', type=click.Path(exists=True), required=True,
    help='Path to the model directory.')
@click.option(
    '--download_checkpoints',
    is_flag=True,
    help=('Flag to download checkpoints before uploading and including them in the tar file '
          'that is uploaded. Defaults to False, which will attempt to download them at docker '
          'build time.'),
)
@click.option(
    '--skip_dockerfile',
    is_flag=True,
    help=('Flag to skip generating a dockerfile so that you can manually edit an already '
          'created dockerfile.'),
)
def upload(model_path, download_checkpoints, skip_dockerfile):
    """Upload a model to Clarifai."""
    # Imported inside the command, so the upload module is only loaded when it runs.
    from clarifai.runners.models.model_upload import main as run_upload
    run_upload(model_path, download_checkpoints, skip_dockerfile)
34
+
35
+
36
@model.command()
@click.option(
    '--model_path',
    type=click.Path(exists=True),
    required=True,
    help='Path to the model directory.')
def test_locally(model_path):
    """Test model locally."""
    try:
        # Imported inside the command, so the runner module is only loaded when it runs.
        from clarifai.runners.models import model_run_locally
        model_run_locally.main(model_path)
    except Exception as e:
        click.echo(f"Failed to test model locally: {e}", err=True)
        # Exit non-zero so shells and CI can detect the failure.
        raise SystemExit(1) from e
    click.echo(f"Model tested locally from {model_path}.")
50
+
51
+
52
@model.command()
@click.option(
    '--model_path',
    type=click.Path(exists=True),
    required=True,
    help='Path to the model directory.')
def run_locally(model_path):
    """Run model locally and start a GRPC server to serve the model."""
    try:
        # Imported inside the command, so the runner module is only loaded when it runs.
        from clarifai.runners.models import model_run_locally
        model_run_locally.main(model_path, run_model_server=True)
    except Exception as e:
        click.echo(f"Failed to start model server locally: {e}", err=True)
        # Exit non-zero so shells and CI can detect the failure.
        raise SystemExit(1) from e
    click.echo(f"Model server started locally from {model_path}.")
@@ -5,11 +5,16 @@ from urllib.parse import urlparse
5
5
 
6
6
  from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
7
7
  from clarifai_grpc.grpc.api import resources_pb2, service_pb2_grpc
8
+
9
+ from clarifai import __version__
8
10
  from clarifai.utils.constants import CLARIFAI_PAT_ENV_VAR, CLARIFAI_SESSION_TOKEN_ENV_VAR
9
11
 
10
12
  DEFAULT_BASE = "https://api.clarifai.com"
11
13
  DEFAULT_UI = "https://clarifai.com"
12
14
 
15
+ REQUEST_ID_PREFIX_HEADER = "x-clarifai-request-id-prefix"
16
+ REQUEST_ID_PREFIX = f"sdk-python-{__version__}"
17
+
13
18
  # Map from base domain to True / False for whether the base has https or http.
14
19
  # This is filled in get_stub() if it's not in there already.
15
20
  base_https_cache = {}
@@ -136,7 +141,7 @@ class ClarifaiAuthHelper:
136
141
 
137
142
  # Then add in the query params.
138
143
  try:
139
- auth.add_streamlit_query_params(st.experimental_get_query_params())
144
+ auth.add_streamlit_query_params(dict(st.query_params))
140
145
  except Exception as e:
141
146
  st.error(e)
142
147
  st.stop()
@@ -269,9 +274,11 @@ Additionally, these optional params are supported:
269
274
  metadata: the metadata need to send with all grpc API calls in the API client.
270
275
  """
271
276
  if self._pat != "":
272
- return (("authorization", "Key %s" % self._pat),)
277
+ return (("authorization", "Key %s" % self._pat), (REQUEST_ID_PREFIX_HEADER,
278
+ REQUEST_ID_PREFIX))
273
279
  elif self._token != "":
274
- return (("x-clarifai-session-token", self._token),)
280
+ return (("x-clarifai-session-token", self._token), (REQUEST_ID_PREFIX_HEADER,
281
+ REQUEST_ID_PREFIX))
275
282
  else:
276
283
  raise Exception("'token' or 'pat' needed to be provided in the query params or env vars.")
277
284
 
@@ -628,6 +628,28 @@ class Dataset(Lister, BaseClient):
628
628
  if delete_version:
629
629
  self.delete_version(dataset_version_id)
630
630
 
631
def merge_dataset(self, merge_dataset_id: str) -> None:
    """Merge the inputs of another dataset into this dataset.

    Args:
        merge_dataset_id (str): ID of the dataset whose inputs are added to this one.

    Raises:
        Exception: If the PostDatasetInputs call does not return SUCCESS.

    Example:
        >>> from clarifai.client.dataset import Dataset
        >>> dataset = Dataset(dataset_id='dataset_id', user_id='user_id', app_id='app_id')
        >>> dataset.merge_dataset(merge_dataset_id='merge_dataset_id')
    """
    # Select every input that belongs to the source dataset...
    source_filter = resources_pb2.Filter(
        input=resources_pb2.Input(dataset_ids=[merge_dataset_id]))
    selection = resources_pb2.Search(query=resources_pb2.Query(filters=[source_filter]))

    # ...and post that selection into this dataset.
    response = self._grpc_request(
        self.STUB.PostDatasetInputs,
        service_pb2.PostDatasetInputsRequest(
            user_app_id=self.user_app_id,
            dataset_id=self.id,
            search=selection,
        ),
    )
    if response.status.code == status_code_pb2.SUCCESS:
        self.logger.info("\nDataset Merged\n%s", response.status)
        return
    raise Exception(response.status)
652
+
631
653
  def archive_zip(self, wait: bool = True) -> str:
632
654
  """Exports the dataset to a zip file URL."""
633
655
  request = service_pb2.PutDatasetVersionExportsRequest(
clarifai/client/input.py CHANGED
@@ -100,7 +100,7 @@ class Inputs(Lister, BaseClient):
100
100
  if not label_ids:
101
101
  concepts=[
102
102
  resources_pb2.Concept(
103
- id=f"id-{''.join(_label.split(' '))}", name=_label, value=1.)\
103
+ id=_label, name=_label, value=1.)\
104
104
  for _label in labels
105
105
  ]
106
106
  else:
@@ -516,7 +516,7 @@ class Inputs(Lister, BaseClient):
516
516
  right_col=bbox[2] #x_max
517
517
  )),
518
518
  data=resources_pb2.Data(concepts=[
519
- resources_pb2.Concept(id=f"id-{''.join(label.split(' '))}", name=label, value=1.)
519
+ resources_pb2.Concept(id=label, name=label, value=1.)
520
520
  if not label_id else resources_pb2.Concept(id=label_id, name=label, value=1.)
521
521
  ]))
522
522
  ])
@@ -561,7 +561,7 @@ class Inputs(Lister, BaseClient):
561
561
  visibility="VISIBLE") for _point in polygons
562
562
  ])),
563
563
  data=resources_pb2.Data(concepts=[
564
- resources_pb2.Concept(id=f"id-{''.join(label.split(' '))}", name=label, value=1.)
564
+ resources_pb2.Concept(id=label, name=label, value=1.)
565
565
  if not label_id else resources_pb2.Concept(id=label_id, name=label, value=1.)
566
566
  ]))
567
567
  ])
clarifai/client/model.py CHANGED
@@ -72,6 +72,7 @@ class Model(Lister, BaseClient):
72
72
  self.model_info = resources_pb2.Model(**self.kwargs)
73
73
  self.logger = logger
74
74
  self.training_params = {}
75
+ self.input_types = None
75
76
  BaseClient.__init__(
76
77
  self,
77
78
  user_id=self.user_id,
@@ -450,9 +451,55 @@ class Model(Lister, BaseClient):
450
451
 
451
452
  return response
452
453
 
454
+ def _check_predict_input_type(self, input_type: str) -> None:
455
+ """Checks if the input type is valid for the model.
456
+
457
+ Args:
458
+ input_type (str): The input type to check.
459
+ Returns:
460
+ None
461
+ """
462
+ if not input_type:
463
+ self.load_input_types()
464
+ if len(self.input_types) > 1:
465
+ raise UserError(
466
+ "Model has multiple input types. Please use model.predict() for this multi-modal model."
467
+ )
468
+ else:
469
+ self.input_types = [input_type]
470
+ if self.input_types[0] not in {'image', 'text', 'video', 'audio'}:
471
+ raise UserError(
472
+ f"Got input type {input_type} but expected one of image, text, video, audio.")
473
+
474
def load_input_types(self) -> list:
    """Load and cache the input types declared by the model's type.

    Returns:
        list: The input field names stored in ``self.input_types``. If the cache is
        already populated it is returned without making a request.

    Raises:
        Exception: If the GetModelType request does not return SUCCESS.

    Example:
        >>> from clarifai.client.model import Model
        >>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
        or
        >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
        >>> model.load_input_types()
    """
    if self.input_types:
        return self.input_types
    # The request needs model_type_id; fetch model info if it has not been loaded yet.
    if self.model_info.model_type_id == "":
        self.load_info()
    request = service_pb2.GetModelTypeRequest(
        user_app_id=self.user_app_id,
        model_type_id=self.model_info.model_type_id,
    )
    response = self._grpc_request(self.STUB.GetModelType, request)
    if response.status.code != status_code_pb2.SUCCESS:
        raise Exception(response.status)
    # Copy the repeated protobuf field into a plain list, so the cache does not hold a
    # container that references the response message.
    self.input_types = list(response.model_type.input_fields)
    return self.input_types
499
+
453
500
  def predict_by_filepath(self,
454
501
  filepath: str,
455
- input_type: str,
502
+ input_type: str = None,
456
503
  compute_cluster_id: str = None,
457
504
  nodepool_id: str = None,
458
505
  deployment_id: str = None,
@@ -462,7 +509,7 @@ class Model(Lister, BaseClient):
462
509
 
463
510
  Args:
464
511
  filepath (str): The filepath to predict.
465
- input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
512
+ input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
466
513
  compute_cluster_id (str): The compute cluster ID to use for the model.
467
514
  nodepool_id (str): The nodepool ID to use for the model.
468
515
  deployment_id (str): The deployment ID to use for the model.
@@ -477,8 +524,8 @@ class Model(Lister, BaseClient):
477
524
  >>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
478
525
  or
479
526
  >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
480
- >>> model_prediction = model.predict_by_filepath('/path/to/image.jpg', 'image')
481
- >>> model_prediction = model.predict_by_filepath('/path/to/text.txt', 'text')
527
+ >>> model_prediction = model.predict_by_filepath('/path/to/image.jpg')
528
+ >>> model_prediction = model.predict_by_filepath('/path/to/text.txt')
482
529
  """
483
530
  if not os.path.isfile(filepath):
484
531
  raise UserError('Invalid filepath.')
@@ -491,7 +538,7 @@ class Model(Lister, BaseClient):
491
538
 
492
539
  def predict_by_bytes(self,
493
540
  input_bytes: bytes,
494
- input_type: str,
541
+ input_type: str = None,
495
542
  compute_cluster_id: str = None,
496
543
  nodepool_id: str = None,
497
544
  deployment_id: str = None,
@@ -501,7 +548,7 @@ class Model(Lister, BaseClient):
501
548
 
502
549
  Args:
503
550
  input_bytes (bytes): File Bytes to predict on.
504
- input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
551
+ input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
505
552
  compute_cluster_id (str): The compute cluster ID to use for the model.
506
553
  nodepool_id (str): The nodepool ID to use for the model.
507
554
  deployment_id (str): The deployment ID to use for the model.
@@ -515,22 +562,17 @@ class Model(Lister, BaseClient):
515
562
  >>> from clarifai.client.model import Model
516
563
  >>> model = Model("https://clarifai.com/openai/chat-completion/models/GPT-4")
517
564
  >>> model_prediction = model.predict_by_bytes(b'Write a tweet on future of AI',
518
- input_type='text',
519
565
  inference_params=dict(temperature=str(0.7), max_tokens=30)))
520
566
  """
521
- if input_type not in {'image', 'text', 'video', 'audio'}:
522
- raise UserError(
523
- f"Got input type {input_type} but expected one of image, text, video, audio.")
524
- if not isinstance(input_bytes, bytes):
525
- raise UserError('Invalid bytes.')
567
+ self._check_predict_input_type(input_type)
526
568
 
527
- if input_type == "image":
569
+ if self.input_types[0] == "image":
528
570
  input_proto = Inputs.get_input_from_bytes("", image_bytes=input_bytes)
529
- elif input_type == "text":
571
+ elif self.input_types[0] == "text":
530
572
  input_proto = Inputs.get_input_from_bytes("", text_bytes=input_bytes)
531
- elif input_type == "video":
573
+ elif self.input_types[0] == "video":
532
574
  input_proto = Inputs.get_input_from_bytes("", video_bytes=input_bytes)
533
- elif input_type == "audio":
575
+ elif self.input_types[0] == "audio":
534
576
  input_proto = Inputs.get_input_from_bytes("", audio_bytes=input_bytes)
535
577
 
536
578
  if deployment_id and (compute_cluster_id or nodepool_id):
@@ -553,7 +595,7 @@ class Model(Lister, BaseClient):
553
595
 
554
596
  def predict_by_url(self,
555
597
  url: str,
556
- input_type: str,
598
+ input_type: str = None,
557
599
  compute_cluster_id: str = None,
558
600
  nodepool_id: str = None,
559
601
  deployment_id: str = None,
@@ -563,7 +605,7 @@ class Model(Lister, BaseClient):
563
605
 
564
606
  Args:
565
607
  url (str): The URL to predict.
566
- input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
608
+ input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
567
609
  compute_cluster_id (str): The compute cluster ID to use for the model.
568
610
  nodepool_id (str): The nodepool ID to use for the model.
569
611
  deployment_id (str): The deployment ID to use for the model.
@@ -578,19 +620,17 @@ class Model(Lister, BaseClient):
578
620
  >>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
579
621
  or
580
622
  >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
581
- >>> model_prediction = model.predict_by_url('url', 'image')
623
+ >>> model_prediction = model.predict_by_url('url')
582
624
  """
583
- if input_type not in {'image', 'text', 'video', 'audio'}:
584
- raise UserError(
585
- f"Got input type {input_type} but expected one of image, text, video, audio.")
625
+ self._check_predict_input_type(input_type)
586
626
 
587
- if input_type == "image":
627
+ if self.input_types[0] == "image":
588
628
  input_proto = Inputs.get_input_from_url("", image_url=url)
589
- elif input_type == "text":
629
+ elif self.input_types[0] == "text":
590
630
  input_proto = Inputs.get_input_from_url("", text_url=url)
591
- elif input_type == "video":
631
+ elif self.input_types[0] == "video":
592
632
  input_proto = Inputs.get_input_from_url("", video_url=url)
593
- elif input_type == "audio":
633
+ elif self.input_types[0] == "audio":
594
634
  input_proto = Inputs.get_input_from_url("", audio_url=url)
595
635
 
596
636
  if deployment_id and (compute_cluster_id or nodepool_id):
@@ -668,7 +708,7 @@ class Model(Lister, BaseClient):
668
708
 
669
709
  def generate_by_filepath(self,
670
710
  filepath: str,
671
- input_type: str,
711
+ input_type: str = None,
672
712
  compute_cluster_id: str = None,
673
713
  nodepool_id: str = None,
674
714
  deployment_id: str = None,
@@ -678,7 +718,7 @@ class Model(Lister, BaseClient):
678
718
 
679
719
  Args:
680
720
  filepath (str): The filepath to predict.
681
- input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
721
+ input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
682
722
  compute_cluster_id (str): The compute cluster ID to use for the model.
683
723
  nodepool_id (str): The nodepool ID to use for the model.
684
724
  deployment_id (str): The deployment ID to use for the model.
@@ -713,7 +753,7 @@ class Model(Lister, BaseClient):
713
753
 
714
754
  def generate_by_bytes(self,
715
755
  input_bytes: bytes,
716
- input_type: str,
756
+ input_type: str = None,
717
757
  compute_cluster_id: str = None,
718
758
  nodepool_id: str = None,
719
759
  deployment_id: str = None,
@@ -723,7 +763,7 @@ class Model(Lister, BaseClient):
723
763
 
724
764
  Args:
725
765
  input_bytes (bytes): File Bytes to predict on.
726
- input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
766
+ input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
727
767
  compute_cluster_id (str): The compute cluster ID to use for the model.
728
768
  nodepool_id (str): The nodepool ID to use for the model.
729
769
  deployment_id (str): The deployment ID to use for the model.
@@ -737,24 +777,19 @@ class Model(Lister, BaseClient):
737
777
  >>> from clarifai.client.model import Model
738
778
  >>> model = Model("https://clarifai.com/openai/chat-completion/models/GPT-4")
739
779
  >>> stream_response = model.generate_by_bytes(b'Write a tweet on future of AI',
740
- input_type='text',
741
780
  deployment_id='deployment_id',
742
781
  inference_params=dict(temperature=str(0.7), max_tokens=30)))
743
782
  >>> list_stream_response = [response for response in stream_response]
744
783
  """
745
- if input_type not in {'image', 'text', 'video', 'audio'}:
746
- raise UserError(
747
- f"Got input type {input_type} but expected one of image, text, video, audio.")
748
- if not isinstance(input_bytes, bytes):
749
- raise UserError('Invalid bytes.')
784
+ self._check_predict_input_type(input_type)
750
785
 
751
- if input_type == "image":
786
+ if self.input_types[0] == "image":
752
787
  input_proto = Inputs.get_input_from_bytes("", image_bytes=input_bytes)
753
- elif input_type == "text":
788
+ elif self.input_types[0] == "text":
754
789
  input_proto = Inputs.get_input_from_bytes("", text_bytes=input_bytes)
755
- elif input_type == "video":
790
+ elif self.input_types[0] == "video":
756
791
  input_proto = Inputs.get_input_from_bytes("", video_bytes=input_bytes)
757
- elif input_type == "audio":
792
+ elif self.input_types[0] == "audio":
758
793
  input_proto = Inputs.get_input_from_bytes("", audio_bytes=input_bytes)
759
794
 
760
795
  if deployment_id and (compute_cluster_id or nodepool_id):
@@ -777,7 +812,7 @@ class Model(Lister, BaseClient):
777
812
 
778
813
  def generate_by_url(self,
779
814
  url: str,
780
- input_type: str,
815
+ input_type: str = None,
781
816
  compute_cluster_id: str = None,
782
817
  nodepool_id: str = None,
783
818
  deployment_id: str = None,
@@ -787,7 +822,7 @@ class Model(Lister, BaseClient):
787
822
 
788
823
  Args:
789
824
  url (str): The URL to predict.
790
- input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
825
+ input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
791
826
  compute_cluster_id (str): The compute cluster ID to use for the model.
792
827
  nodepool_id (str): The nodepool ID to use for the model.
793
828
  deployment_id (str): The deployment ID to use for the model.
@@ -802,20 +837,18 @@ class Model(Lister, BaseClient):
802
837
  >>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
803
838
  or
804
839
  >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
805
- >>> stream_response = model.generate_by_url('url', 'image', deployment_id='deployment_id')
840
+ >>> stream_response = model.generate_by_url('url', deployment_id='deployment_id')
806
841
  >>> list_stream_response = [response for response in stream_response]
807
842
  """
808
- if input_type not in {'image', 'text', 'video', 'audio'}:
809
- raise UserError(
810
- f"Got input type {input_type} but expected one of image, text, video, audio.")
843
+ self._check_predict_input_type(input_type)
811
844
 
812
- if input_type == "image":
845
+ if self.input_types[0] == "image":
813
846
  input_proto = Inputs.get_input_from_url("", image_url=url)
814
- elif input_type == "text":
847
+ elif self.input_types[0] == "text":
815
848
  input_proto = Inputs.get_input_from_url("", text_url=url)
816
- elif input_type == "video":
849
+ elif self.input_types[0] == "video":
817
850
  input_proto = Inputs.get_input_from_url("", video_url=url)
818
- elif input_type == "audio":
851
+ elif self.input_types[0] == "audio":
819
852
  input_proto = Inputs.get_input_from_url("", audio_url=url)
820
853
 
821
854
  if deployment_id and (compute_cluster_id or nodepool_id):
@@ -893,7 +926,7 @@ class Model(Lister, BaseClient):
893
926
 
894
927
  def stream_by_filepath(self,
895
928
  filepath: str,
896
- input_type: str,
929
+ input_type: str = None,
897
930
  compute_cluster_id: str = None,
898
931
  nodepool_id: str = None,
899
932
  deployment_id: str = None,
@@ -903,7 +936,7 @@ class Model(Lister, BaseClient):
903
936
 
904
937
  Args:
905
938
  filepath (str): The filepath to predict.
906
- input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
939
+ input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
907
940
  compute_cluster_id (str): The compute cluster ID to use for the model.
908
941
  nodepool_id (str): The nodepool ID to use for the model.
909
942
  deployment_id (str): The deployment ID to use for the model.
@@ -916,7 +949,7 @@ class Model(Lister, BaseClient):
916
949
  Example:
917
950
  >>> from clarifai.client.model import Model
918
951
  >>> model = Model("url")
919
- >>> stream_response = model.stream_by_filepath('/path/to/image.jpg', 'image', deployment_id='deployment_id')
952
+ >>> stream_response = model.stream_by_filepath('/path/to/image.jpg', deployment_id='deployment_id')
920
953
  >>> list_stream_response = [response for response in stream_response]
921
954
  """
922
955
  if not os.path.isfile(filepath):
@@ -936,7 +969,7 @@ class Model(Lister, BaseClient):
936
969
 
937
970
  def stream_by_bytes(self,
938
971
  input_bytes_iterator: Iterator[bytes],
939
- input_type: str,
972
+ input_type: str = None,
940
973
  compute_cluster_id: str = None,
941
974
  nodepool_id: str = None,
942
975
  deployment_id: str = None,
@@ -946,7 +979,7 @@ class Model(Lister, BaseClient):
946
979
 
947
980
  Args:
948
981
  input_bytes_iterator (Iterator[bytes]): Iterator of file bytes to predict on.
949
- input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
982
+ input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
950
983
  compute_cluster_id (str): The compute cluster ID to use for the model.
951
984
  nodepool_id (str): The nodepool ID to use for the model.
952
985
  deployment_id (str): The deployment ID to use for the model.
@@ -960,24 +993,21 @@ class Model(Lister, BaseClient):
960
993
  >>> from clarifai.client.model import Model
961
994
  >>> model = Model("https://clarifai.com/openai/chat-completion/models/GPT-4")
962
995
  >>> stream_response = model.stream_by_bytes(iter([b'Write a tweet on future of AI']),
963
- input_type='text',
964
996
  deployment_id='deployment_id',
965
997
  inference_params=dict(temperature=str(0.7), max_tokens=30)))
966
998
  >>> list_stream_response = [response for response in stream_response]
967
999
  """
968
- if input_type not in {'image', 'text', 'video', 'audio'}:
969
- raise UserError(
970
- f"Got input type {input_type} but expected one of image, text, video, audio.")
1000
+ self._check_predict_input_type(input_type)
971
1001
 
972
1002
  def input_generator():
973
1003
  for input_bytes in input_bytes_iterator:
974
- if input_type == "image":
1004
+ if self.input_types[0] == "image":
975
1005
  yield [Inputs.get_input_from_bytes("", image_bytes=input_bytes)]
976
- elif input_type == "text":
1006
+ elif self.input_types[0] == "text":
977
1007
  yield [Inputs.get_input_from_bytes("", text_bytes=input_bytes)]
978
- elif input_type == "video":
1008
+ elif self.input_types[0] == "video":
979
1009
  yield [Inputs.get_input_from_bytes("", video_bytes=input_bytes)]
980
- elif input_type == "audio":
1010
+ elif self.input_types[0] == "audio":
981
1011
  yield [Inputs.get_input_from_bytes("", audio_bytes=input_bytes)]
982
1012
 
983
1013
  if deployment_id and (compute_cluster_id or nodepool_id):
@@ -1000,7 +1030,7 @@ class Model(Lister, BaseClient):
1000
1030
 
1001
1031
  def stream_by_url(self,
1002
1032
  url_iterator: Iterator[str],
1003
- input_type: str,
1033
+ input_type: str = None,
1004
1034
  compute_cluster_id: str = None,
1005
1035
  nodepool_id: str = None,
1006
1036
  deployment_id: str = None,
@@ -1010,7 +1040,7 @@ class Model(Lister, BaseClient):
1010
1040
 
1011
1041
  Args:
1012
1042
  url_iterator (Iterator[str]): Iterator of URLs to predict.
1013
- input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
1043
+ input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
1014
1044
  compute_cluster_id (str): The compute cluster ID to use for the model.
1015
1045
  nodepool_id (str): The nodepool ID to use for the model.
1016
1046
  deployment_id (str): The deployment ID to use for the model.
@@ -1023,22 +1053,20 @@ class Model(Lister, BaseClient):
1023
1053
  Example:
1024
1054
  >>> from clarifai.client.model import Model
1025
1055
  >>> model = Model("url")
1026
- >>> stream_response = model.stream_by_url(iter(['url']), 'image', deployment_id='deployment_id')
1056
+ >>> stream_response = model.stream_by_url(iter(['url']), deployment_id='deployment_id')
1027
1057
  >>> list_stream_response = [response for response in stream_response]
1028
1058
  """
1029
- if input_type not in {'image', 'text', 'video', 'audio'}:
1030
- raise UserError(
1031
- f"Got input type {input_type} but expected one of image, text, video, audio.")
1059
+ self._check_predict_input_type(input_type)
1032
1060
 
1033
1061
  def input_generator():
1034
1062
  for url in url_iterator:
1035
- if input_type == "image":
1063
+ if self.input_types[0] == "image":
1036
1064
  yield [Inputs.get_input_from_url("", image_url=url)]
1037
- elif input_type == "text":
1065
+ elif self.input_types[0] == "text":
1038
1066
  yield [Inputs.get_input_from_url("", text_url=url)]
1039
- elif input_type == "video":
1067
+ elif self.input_types[0] == "video":
1040
1068
  yield [Inputs.get_input_from_url("", video_url=url)]
1041
- elif input_type == "audio":
1069
+ elif self.input_types[0] == "audio":
1042
1070
  yield [Inputs.get_input_from_url("", audio_url=url)]
1043
1071
 
1044
1072
  if deployment_id and (compute_cluster_id or nodepool_id):