clarifai-10.8.8-py3-none-any.whl → clarifai-10.9.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/__init__.py +1 -1
- clarifai/client/deployment.py +12 -0
- clarifai/client/model.py +217 -21
- clarifai/client/nodepool.py +22 -0
- clarifai/datasets/upload/loaders/imagenet_classification.py +5 -1
- clarifai/runners/dockerfile_template/Dockerfile.cpu.template +1 -1
- clarifai/runners/dockerfile_template/Dockerfile.cuda.template +7 -57
- clarifai/runners/models/model_upload.py +92 -38
- clarifai/utils/evaluation/helpers.py +10 -4
- clarifai/utils/evaluation/main.py +2 -1
- {clarifai-10.8.8.dist-info → clarifai-10.9.0.dist-info}/METADATA +1 -1
- {clarifai-10.8.8.dist-info → clarifai-10.9.0.dist-info}/RECORD +16 -16
- {clarifai-10.8.8.dist-info → clarifai-10.9.0.dist-info}/LICENSE +0 -0
- {clarifai-10.8.8.dist-info → clarifai-10.9.0.dist-info}/WHEEL +0 -0
- {clarifai-10.8.8.dist-info → clarifai-10.9.0.dist-info}/entry_points.txt +0 -0
- {clarifai-10.8.8.dist-info → clarifai-10.9.0.dist-info}/top_level.txt +0 -0
    
        clarifai/__init__.py
    CHANGED
    
    | @@ -1 +1 @@ | |
| 1 | 
            -
            __version__ = "10.8.8"
| 1 | 
            +
            __version__ = "10.9.0"
         | 
    
        clarifai/client/deployment.py
    CHANGED
    
    | @@ -39,6 +39,18 @@ class Deployment(Lister, BaseClient): | |
| 39 39 | 
             
                    root_certificates_path=root_certificates_path)
         | 
| 40 40 | 
             
                Lister.__init__(self)
         | 
| 41 41 |  | 
| 42 | 
            +
              @staticmethod
         | 
| 43 | 
            +
              def get_runner_selector(user_id: str, deployment_id: str) -> resources_pb2.RunnerSelector:
         | 
| 44 | 
            +
                """Returns a RunnerSelector object for the given deployment_id.
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                Args:
         | 
| 47 | 
            +
                    deployment_id (str): The deployment ID for the deployment.
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                Returns:
         | 
| 50 | 
            +
                    resources_pb2.RunnerSelector: A RunnerSelector object for the given deployment_id.
         | 
| 51 | 
            +
                """
         | 
| 52 | 
            +
                return resources_pb2.RunnerSelector(deployment_id=deployment_id, user_id=user_id)
         | 
| 53 | 
            +
             | 
| 42 54 | 
             
              def __getattr__(self, name):
         | 
| 43 55 | 
             
                return getattr(self.deployment_info, name)
         | 
| 44 56 |  | 
    
        clarifai/client/model.py
    CHANGED
    
    | @@ -7,7 +7,7 @@ import numpy as np | |
| 7 7 | 
             
            import requests
         | 
| 8 8 | 
             
            import yaml
         | 
| 9 9 | 
             
            from clarifai_grpc.grpc.api import resources_pb2, service_pb2
         | 
| 10 | 
            -
            from clarifai_grpc.grpc.api.resources_pb2 import Input
         | 
| 10 | 
            +
            from clarifai_grpc.grpc.api.resources_pb2 import Input, RunnerSelector
         | 
| 11 11 | 
             
            from clarifai_grpc.grpc.api.status import status_code_pb2
         | 
| 12 12 | 
             
            from google.protobuf.json_format import MessageToDict
         | 
| 13 13 | 
             
            from google.protobuf.struct_pb2 import Struct, Value
         | 
| @@ -16,8 +16,10 @@ from tqdm import tqdm | |
| 16 16 |  | 
| 17 17 | 
             
            from clarifai.client.base import BaseClient
         | 
| 18 18 | 
             
            from clarifai.client.dataset import Dataset
         | 
| 19 | 
            +
            from clarifai.client.deployment import Deployment
         | 
| 19 20 | 
             
            from clarifai.client.input import Inputs
         | 
| 20 21 | 
             
            from clarifai.client.lister import Lister
         | 
| 22 | 
            +
            from clarifai.client.nodepool import Nodepool
         | 
| 21 23 | 
             
            from clarifai.constants.model import (CHUNK_SIZE, MAX_CHUNK_SIZE, MAX_MODEL_PREDICT_INPUTS,
         | 
| 22 24 | 
             
                                                  MAX_RANGE_SIZE, MIN_CHUNK_SIZE, MIN_RANGE_SIZE,
         | 
| 23 25 | 
             
                                                  MODEL_EXPORT_TIMEOUT, RANGE_SIZE, TRAINABLE_MODEL_TYPES)
         | 
| @@ -404,11 +406,16 @@ class Model(Lister, BaseClient): | |
| 404 406 | 
             
                      model_id=self.id,
         | 
| 405 407 | 
             
                      **dict(self.kwargs, model_version=model_version_info))
         | 
| 406 408 |  | 
| 407 | 
            -
              def predict(self, | 
| 409 | 
            +
              def predict(self,
         | 
| 410 | 
            +
                          inputs: List[Input],
         | 
| 411 | 
            +
                          runner_selector: RunnerSelector = None,
         | 
| 412 | 
            +
                          inference_params: Dict = {},
         | 
| 413 | 
            +
                          output_config: Dict = {}):
         | 
| 408 414 | 
             
                """Predicts the model based on the given inputs.
         | 
| 409 415 |  | 
| 410 416 | 
             
                Args:
         | 
| 411 417 | 
             
                    inputs (list[Input]): The inputs to predict, must be less than 128.
         | 
| 418 | 
            +
                    runner_selector (RunnerSelector): The runner selector to use for the model.
         | 
| 412 419 | 
             
                """
         | 
| 413 420 | 
             
                if not isinstance(inputs, list):
         | 
| 414 421 | 
             
                  raise UserError('Invalid inputs, inputs must be a list of Input objects.')
         | 
| @@ -422,6 +429,7 @@ class Model(Lister, BaseClient): | |
| 422 429 | 
             
                    model_id=self.id,
         | 
| 423 430 | 
             
                    version_id=self.model_version.id,
         | 
| 424 431 | 
             
                    inputs=inputs,
         | 
| 432 | 
            +
                    runner_selector=runner_selector,
         | 
| 425 433 | 
             
                    model=self.model_info)
         | 
| 426 434 |  | 
| 427 435 | 
             
                start_time = time.time()
         | 
| @@ -445,6 +453,9 @@ class Model(Lister, BaseClient): | |
| 445 453 | 
             
              def predict_by_filepath(self,
         | 
| 446 454 | 
             
                                      filepath: str,
         | 
| 447 455 | 
             
                                      input_type: str,
         | 
| 456 | 
            +
                                      compute_cluster_id: str = None,
         | 
| 457 | 
            +
                                      nodepool_id: str = None,
         | 
| 458 | 
            +
                                      deployment_id: str = None,
         | 
| 448 459 | 
             
                                      inference_params: Dict = {},
         | 
| 449 460 | 
             
                                      output_config: Dict = {}):
         | 
| 450 461 | 
             
                """Predicts the model based on the given filepath.
         | 
| @@ -452,6 +463,9 @@ class Model(Lister, BaseClient): | |
| 452 463 | 
             
                Args:
         | 
| 453 464 | 
             
                    filepath (str): The filepath to predict.
         | 
| 454 465 | 
             
                    input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
         | 
| 466 | 
            +
                    compute_cluster_id (str): The compute cluster ID to use for the model.
         | 
| 467 | 
            +
                    nodepool_id (str): The nodepool ID to use for the model.
         | 
| 468 | 
            +
                    deployment_id (str): The deployment ID to use for the model.
         | 
| 455 469 | 
             
                    inference_params (dict): The inference params to override.
         | 
| 456 470 | 
             
                    output_config (dict): The output config to override.
         | 
| 457 471 | 
             
                      min_value (float): The minimum value of the prediction confidence to filter.
         | 
| @@ -472,11 +486,15 @@ class Model(Lister, BaseClient): | |
| 472 486 | 
             
                with open(filepath, "rb") as f:
         | 
| 473 487 | 
             
                  file_bytes = f.read()
         | 
| 474 488 |  | 
| 475 | 
            -
                return self.predict_by_bytes(file_bytes, input_type,  | 
| 489 | 
            +
                return self.predict_by_bytes(file_bytes, input_type, compute_cluster_id, nodepool_id,
         | 
| 490 | 
            +
                                             deployment_id, inference_params, output_config)
         | 
| 476 491 |  | 
| 477 492 | 
             
              def predict_by_bytes(self,
         | 
| 478 493 | 
             
                                   input_bytes: bytes,
         | 
| 479 494 | 
             
                                   input_type: str,
         | 
| 495 | 
            +
                                   compute_cluster_id: str = None,
         | 
| 496 | 
            +
                                   nodepool_id: str = None,
         | 
| 497 | 
            +
                                   deployment_id: str = None,
         | 
| 480 498 | 
             
                                   inference_params: Dict = {},
         | 
| 481 499 | 
             
                                   output_config: Dict = {}):
         | 
| 482 500 | 
             
                """Predicts the model based on the given bytes.
         | 
| @@ -484,6 +502,9 @@ class Model(Lister, BaseClient): | |
| 484 502 | 
             
                Args:
         | 
| 485 503 | 
             
                    input_bytes (bytes): File Bytes to predict on.
         | 
| 486 504 | 
             
                    input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
         | 
| 505 | 
            +
                    compute_cluster_id (str): The compute cluster ID to use for the model.
         | 
| 506 | 
            +
                    nodepool_id (str): The nodepool ID to use for the model.
         | 
| 507 | 
            +
                    deployment_id (str): The deployment ID to use for the model.
         | 
| 487 508 | 
             
                    inference_params (dict): The inference params to override.
         | 
| 488 509 | 
             
                    output_config (dict): The output config to override.
         | 
| 489 510 | 
             
                      min_value (float): The minimum value of the prediction confidence to filter.
         | 
| @@ -512,12 +533,30 @@ class Model(Lister, BaseClient): | |
| 512 533 | 
             
                elif input_type == "audio":
         | 
| 513 534 | 
             
                  input_proto = Inputs.get_input_from_bytes("", audio_bytes=input_bytes)
         | 
| 514 535 |  | 
| 536 | 
            +
                if deployment_id and (compute_cluster_id or nodepool_id):
         | 
| 537 | 
            +
                  raise UserError(
         | 
| 538 | 
            +
                      "You can only specify one of deployment_id or compute_cluster_id and nodepool_id.")
         | 
| 539 | 
            +
             | 
| 540 | 
            +
                runner_selector = None
         | 
| 541 | 
            +
                if deployment_id:
         | 
| 542 | 
            +
                  runner_selector = Deployment.get_runner_selector(
         | 
| 543 | 
            +
                      user_id=self.user_id, deployment_id=deployment_id)
         | 
| 544 | 
            +
                elif compute_cluster_id and nodepool_id:
         | 
| 545 | 
            +
                  runner_selector = Nodepool.get_runner_selector(
         | 
| 546 | 
            +
                      user_id=self.user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
         | 
| 547 | 
            +
             | 
| 515 548 | 
             
                return self.predict(
         | 
| 516 | 
            -
                    inputs=[input_proto], | 
| 549 | 
            +
                    inputs=[input_proto],
         | 
| 550 | 
            +
                    runner_selector=runner_selector,
         | 
| 551 | 
            +
                    inference_params=inference_params,
         | 
| 552 | 
            +
                    output_config=output_config)
         | 
| 517 553 |  | 
| 518 554 | 
             
              def predict_by_url(self,
         | 
| 519 555 | 
             
                                 url: str,
         | 
| 520 556 | 
             
                                 input_type: str,
         | 
| 557 | 
            +
                                 compute_cluster_id: str = None,
         | 
| 558 | 
            +
                                 nodepool_id: str = None,
         | 
| 559 | 
            +
                                 deployment_id: str = None,
         | 
| 521 560 | 
             
                                 inference_params: Dict = {},
         | 
| 522 561 | 
             
                                 output_config: Dict = {}):
         | 
| 523 562 | 
             
                """Predicts the model based on the given URL.
         | 
| @@ -525,6 +564,9 @@ class Model(Lister, BaseClient): | |
| 525 564 | 
             
                Args:
         | 
| 526 565 | 
             
                    url (str): The URL to predict.
         | 
| 527 566 | 
             
                    input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
         | 
| 567 | 
            +
                    compute_cluster_id (str): The compute cluster ID to use for the model.
         | 
| 568 | 
            +
                    nodepool_id (str): The nodepool ID to use for the model.
         | 
| 569 | 
            +
                    deployment_id (str): The deployment ID to use for the model.
         | 
| 528 570 | 
             
                    inference_params (dict): The inference params to override.
         | 
| 529 571 | 
             
                    output_config (dict): The output config to override.
         | 
| 530 572 | 
             
                      min_value (float): The minimum value of the prediction confidence to filter.
         | 
| @@ -551,14 +593,43 @@ class Model(Lister, BaseClient): | |
| 551 593 | 
             
                elif input_type == "audio":
         | 
| 552 594 | 
             
                  input_proto = Inputs.get_input_from_url("", audio_url=url)
         | 
| 553 595 |  | 
| 554 | 
            -
                 | 
| 555 | 
            -
             | 
| 596 | 
            +
                if deployment_id and (compute_cluster_id or nodepool_id):
         | 
| 597 | 
            +
                  raise UserError(
         | 
| 598 | 
            +
                      "You can only specify one of deployment_id or compute_cluster_id and nodepool_id.")
         | 
| 599 | 
            +
             | 
| 600 | 
            +
                runner_selector = None
         | 
| 601 | 
            +
                if deployment_id:
         | 
| 602 | 
            +
                  runner_selector = Deployment.get_runner_selector(
         | 
| 603 | 
            +
                      user_id=self.user_id, deployment_id=deployment_id)
         | 
| 604 | 
            +
                elif compute_cluster_id and nodepool_id:
         | 
| 605 | 
            +
                  runner_selector = Nodepool.get_runner_selector(
         | 
| 606 | 
            +
                      user_id=self.user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
         | 
| 556 607 |  | 
| 557 | 
            -
             | 
| 608 | 
            +
                return self.predict(
         | 
| 609 | 
            +
                    inputs=[input_proto],
         | 
| 610 | 
            +
                    runner_selector=runner_selector,
         | 
| 611 | 
            +
                    inference_params=inference_params,
         | 
| 612 | 
            +
                    output_config=output_config)
         | 
| 613 | 
            +
             | 
| 614 | 
            +
              def generate(self,
         | 
| 615 | 
            +
                           inputs: List[Input],
         | 
| 616 | 
            +
                           runner_selector: RunnerSelector = None,
         | 
| 617 | 
            +
                           inference_params: Dict = {},
         | 
| 618 | 
            +
                           output_config: Dict = {}):
         | 
| 558 619 | 
             
                """Generate the stream output on model based on the given inputs.
         | 
| 559 620 |  | 
| 560 621 | 
             
                Args:
         | 
| 561 | 
            -
                    inputs (list[Input]): The inputs to  | 
| 622 | 
            +
                    inputs (list[Input]): The inputs to generate, must be less than 128.
         | 
| 623 | 
            +
                    runner_selector (RunnerSelector): The runner selector to use for the model.
         | 
| 624 | 
            +
                    inference_params (dict): The inference params to override.
         | 
| 625 | 
            +
             | 
| 626 | 
            +
                Example:
         | 
| 627 | 
            +
                    >>> from clarifai.client.model import Model
         | 
| 628 | 
            +
                    >>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
         | 
| 629 | 
            +
                                or
         | 
| 630 | 
            +
                    >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
         | 
| 631 | 
            +
                    >>> stream_response = model.generate(inputs=[input1, input2], runner_selector=runner_selector)
         | 
| 632 | 
            +
                    >>> list_stream_response = [response for response in stream_response]
         | 
| 562 633 | 
             
                """
         | 
| 563 634 | 
             
                if not isinstance(inputs, list):
         | 
| 564 635 | 
             
                  raise UserError('Invalid inputs, inputs must be a list of Input objects.')
         | 
| @@ -572,6 +643,7 @@ class Model(Lister, BaseClient): | |
| 572 643 | 
             
                    model_id=self.id,
         | 
| 573 644 | 
             
                    version_id=self.model_version.id,
         | 
| 574 645 | 
             
                    inputs=inputs,
         | 
| 646 | 
            +
                    runner_selector=runner_selector,
         | 
| 575 647 | 
             
                    model=self.model_info)
         | 
| 576 648 |  | 
| 577 649 | 
             
                start_time = time.time()
         | 
| @@ -597,6 +669,9 @@ class Model(Lister, BaseClient): | |
| 597 669 | 
             
              def generate_by_filepath(self,
         | 
| 598 670 | 
             
                                       filepath: str,
         | 
| 599 671 | 
             
                                       input_type: str,
         | 
| 672 | 
            +
                                       compute_cluster_id: str = None,
         | 
| 673 | 
            +
                                       nodepool_id: str = None,
         | 
| 674 | 
            +
                                       deployment_id: str = None,
         | 
| 600 675 | 
             
                                       inference_params: Dict = {},
         | 
| 601 676 | 
             
                                       output_config: Dict = {}):
         | 
| 602 677 | 
             
                """Generate the stream output on model based on the given filepath.
         | 
| @@ -604,6 +679,9 @@ class Model(Lister, BaseClient): | |
| 604 679 | 
             
                Args:
         | 
| 605 680 | 
             
                    filepath (str): The filepath to predict.
         | 
| 606 681 | 
             
                    input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
         | 
| 682 | 
            +
                    compute_cluster_id (str): The compute cluster ID to use for the model.
         | 
| 683 | 
            +
                    nodepool_id (str): The nodepool ID to use for the model.
         | 
| 684 | 
            +
                    deployment_id (str): The deployment ID to use for the model.
         | 
| 607 685 | 
             
                    inference_params (dict): The inference params to override.
         | 
| 608 686 | 
             
                    output_config (dict): The output config to override.
         | 
| 609 687 | 
             
                      min_value (float): The minimum value of the prediction confidence to filter.
         | 
| @@ -615,8 +693,7 @@ class Model(Lister, BaseClient): | |
| 615 693 | 
             
                    >>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
         | 
| 616 694 | 
             
                                or
         | 
| 617 695 | 
             
                    >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
         | 
| 618 | 
            -
                    >>> stream_response = model.generate_by_filepath('/path/to/image.jpg', 'image')
         | 
| 619 | 
            -
                    >>> stream_response = model.generate_by_filepath('/path/to/text.txt', 'text')
         | 
| 696 | 
            +
                    >>> stream_response = model.generate_by_filepath('/path/to/image.jpg', 'image', deployment_id='deployment_id')
         | 
| 620 697 | 
             
                    >>> list_stream_response = [response for response in stream_response]
         | 
| 621 698 | 
             
                """
         | 
| 622 699 | 
             
                if not os.path.isfile(filepath):
         | 
| @@ -625,11 +702,21 @@ class Model(Lister, BaseClient): | |
| 625 702 | 
             
                with open(filepath, "rb") as f:
         | 
| 626 703 | 
             
                  file_bytes = f.read()
         | 
| 627 704 |  | 
| 628 | 
            -
                return self.generate_by_bytes( | 
| 705 | 
            +
                return self.generate_by_bytes(
         | 
| 706 | 
            +
                    filepath=file_bytes,
         | 
| 707 | 
            +
                    input_type=input_type,
         | 
| 708 | 
            +
                    compute_cluster_id=compute_cluster_id,
         | 
| 709 | 
            +
                    nodepool_id=nodepool_id,
         | 
| 710 | 
            +
                    deployment_id=deployment_id,
         | 
| 711 | 
            +
                    inference_params=inference_params,
         | 
| 712 | 
            +
                    output_config=output_config)
         | 
| 629 713 |  | 
| 630 714 | 
             
              def generate_by_bytes(self,
         | 
| 631 715 | 
             
                                    input_bytes: bytes,
         | 
| 632 716 | 
             
                                    input_type: str,
         | 
| 717 | 
            +
                                    compute_cluster_id: str = None,
         | 
| 718 | 
            +
                                    nodepool_id: str = None,
         | 
| 719 | 
            +
                                    deployment_id: str = None,
         | 
| 633 720 | 
             
                                    inference_params: Dict = {},
         | 
| 634 721 | 
             
                                    output_config: Dict = {}):
         | 
| 635 722 | 
             
                """Generate the stream output on model based on the given bytes.
         | 
| @@ -637,6 +724,9 @@ class Model(Lister, BaseClient): | |
| 637 724 | 
             
                Args:
         | 
| 638 725 | 
             
                    input_bytes (bytes): File Bytes to predict on.
         | 
| 639 726 | 
             
                    input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
         | 
| 727 | 
            +
                    compute_cluster_id (str): The compute cluster ID to use for the model.
         | 
| 728 | 
            +
                    nodepool_id (str): The nodepool ID to use for the model.
         | 
| 729 | 
            +
                    deployment_id (str): The deployment ID to use for the model.
         | 
| 640 730 | 
             
                    inference_params (dict): The inference params to override.
         | 
| 641 731 | 
             
                    output_config (dict): The output config to override.
         | 
| 642 732 | 
             
                      min_value (float): The minimum value of the prediction confidence to filter.
         | 
| @@ -648,6 +738,7 @@ class Model(Lister, BaseClient): | |
| 648 738 | 
             
                    >>> model = Model("https://clarifai.com/openai/chat-completion/models/GPT-4")
         | 
| 649 739 | 
             
                    >>> stream_response = model.generate_by_bytes(b'Write a tweet on future of AI',
         | 
| 650 740 | 
             
                                                                  input_type='text',
         | 
| 741 | 
            +
                                                                  deployment_id='deployment_id',
         | 
| 651 742 | 
             
                                                                  inference_params=dict(temperature=str(0.7), max_tokens=30)))
         | 
| 652 743 | 
             
                    >>> list_stream_response = [response for response in stream_response]
         | 
| 653 744 | 
             
                """
         | 
| @@ -666,12 +757,30 @@ class Model(Lister, BaseClient): | |
| 666 757 | 
             
                elif input_type == "audio":
         | 
| 667 758 | 
             
                  input_proto = Inputs.get_input_from_bytes("", audio_bytes=input_bytes)
         | 
| 668 759 |  | 
| 760 | 
            +
                if deployment_id and (compute_cluster_id or nodepool_id):
         | 
| 761 | 
            +
                  raise UserError(
         | 
| 762 | 
            +
                      "You can only specify one of deployment_id or compute_cluster_id and nodepool_id.")
         | 
| 763 | 
            +
             | 
| 764 | 
            +
                runner_selector = None
         | 
| 765 | 
            +
                if deployment_id:
         | 
| 766 | 
            +
                  runner_selector = Deployment.get_runner_selector(
         | 
| 767 | 
            +
                      user_id=self.user_id, deployment_id=deployment_id)
         | 
| 768 | 
            +
                elif compute_cluster_id and nodepool_id:
         | 
| 769 | 
            +
                  runner_selector = Nodepool.get_runner_selector(
         | 
| 770 | 
            +
                      user_id=self.user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
         | 
| 771 | 
            +
             | 
| 669 772 | 
             
                return self.generate(
         | 
| 670 | 
            -
                    inputs=[input_proto], | 
| 773 | 
            +
                    inputs=[input_proto],
         | 
| 774 | 
            +
                    runner_selector=runner_selector,
         | 
| 775 | 
            +
                    inference_params=inference_params,
         | 
| 776 | 
            +
                    output_config=output_config)
         | 
| 671 777 |  | 
| 672 778 | 
             
              def generate_by_url(self,
         | 
| 673 779 | 
             
                                  url: str,
         | 
| 674 780 | 
             
                                  input_type: str,
         | 
| 781 | 
            +
                                  compute_cluster_id: str = None,
         | 
| 782 | 
            +
                                  nodepool_id: str = None,
         | 
| 783 | 
            +
                                  deployment_id: str = None,
         | 
| 675 784 | 
             
                                  inference_params: Dict = {},
         | 
| 676 785 | 
             
                                  output_config: Dict = {}):
         | 
| 677 786 | 
             
                """Generate the stream output on model based on the given URL.
         | 
| @@ -679,6 +788,9 @@ class Model(Lister, BaseClient): | |
| 679 788 | 
             
                Args:
         | 
| 680 789 | 
             
                    url (str): The URL to predict.
         | 
| 681 790 | 
             
                    input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
         | 
| 791 | 
            +
                    compute_cluster_id (str): The compute cluster ID to use for the model.
         | 
| 792 | 
            +
                    nodepool_id (str): The nodepool ID to use for the model.
         | 
| 793 | 
            +
                    deployment_id (str): The deployment ID to use for the model.
         | 
| 682 794 | 
             
                    inference_params (dict): The inference params to override.
         | 
| 683 795 | 
             
                    output_config (dict): The output config to override.
         | 
| 684 796 | 
             
                      min_value (float): The minimum value of the prediction confidence to filter.
         | 
| @@ -690,7 +802,7 @@ class Model(Lister, BaseClient): | |
| 690 802 | 
             
                    >>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
         | 
| 691 803 | 
             
                                or
         | 
| 692 804 | 
             
                    >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
         | 
| 693 | 
            -
                    >>> stream_response = model.generate_by_url('url', 'image')
         | 
| 805 | 
            +
                    >>> stream_response = model.generate_by_url('url', 'image', deployment_id='deployment_id')
         | 
| 694 806 | 
             
                    >>> list_stream_response = [response for response in stream_response]
         | 
| 695 807 | 
             
                """
         | 
| 696 808 | 
             
                if input_type not in {'image', 'text', 'video', 'audio'}:
         | 
| @@ -706,32 +818,58 @@ class Model(Lister, BaseClient): | |
| 706 818 | 
             
                elif input_type == "audio":
         | 
| 707 819 | 
             
                  input_proto = Inputs.get_input_from_url("", audio_url=url)
         | 
| 708 820 |  | 
| 821 | 
            +
                if deployment_id and (compute_cluster_id or nodepool_id):
         | 
| 822 | 
            +
                  raise UserError(
         | 
| 823 | 
            +
                      "You can only specify one of deployment_id or compute_cluster_id and nodepool_id.")
         | 
| 824 | 
            +
             | 
| 825 | 
            +
                runner_selector = None
         | 
| 826 | 
            +
                if deployment_id:
         | 
| 827 | 
            +
                  runner_selector = Deployment.get_runner_selector(
         | 
| 828 | 
            +
                      user_id=self.user_id, deployment_id=deployment_id)
         | 
| 829 | 
            +
                elif compute_cluster_id and nodepool_id:
         | 
| 830 | 
            +
                  runner_selector = Nodepool.get_runner_selector(
         | 
| 831 | 
            +
                      user_id=self.user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
         | 
| 832 | 
            +
             | 
| 709 833 | 
             
                return self.generate(
         | 
| 710 | 
            -
                    inputs=[input_proto], | 
| 834 | 
            +
                    inputs=[input_proto],
         | 
| 835 | 
            +
                    runner_selector=runner_selector,
         | 
| 836 | 
            +
                    inference_params=inference_params,
         | 
| 837 | 
            +
                    output_config=output_config)
         | 
| 711 838 |  | 
| 712 | 
            -
              def _req_iterator(self, input_iterator: Iterator[List[Input]]):
         | 
| 839 | 
            +
              def _req_iterator(self, input_iterator: Iterator[List[Input]], runner_selector: RunnerSelector):
         | 
| 713 840 | 
             
                for inputs in input_iterator:
         | 
| 714 841 | 
             
                  yield service_pb2.PostModelOutputsRequest(
         | 
| 715 842 | 
             
                      user_app_id=self.user_app_id,
         | 
| 716 843 | 
             
                      model_id=self.id,
         | 
| 717 844 | 
             
                      version_id=self.model_version.id,
         | 
| 718 845 | 
             
                      inputs=inputs,
         | 
| 846 | 
            +
                      runner_selector=runner_selector,
         | 
| 719 847 | 
             
                      model=self.model_info)
         | 
| 720 848 |  | 
| 721 849 | 
             
              def stream(self,
         | 
| 722 850 | 
             
                         inputs: Iterator[List[Input]],
         | 
| 851 | 
            +
                         runner_selector: RunnerSelector = None,
         | 
| 723 852 | 
             
                         inference_params: Dict = {},
         | 
| 724 853 | 
             
                         output_config: Dict = {}):
         | 
| 725 854 | 
             
                """Generate the stream output on model based on the given stream of inputs.
         | 
| 726 855 |  | 
| 727 856 | 
             
                Args:
         | 
| 728 857 | 
             
                    inputs (Iterator[list[Input]]): stream of inputs to predict, must be less than 128.
         | 
| 858 | 
            +
                    runner_selector (RunnerSelector): The runner selector to use for the model.
         | 
| 859 | 
            +
             | 
| 860 | 
            +
                Example:
         | 
| 861 | 
            +
                    >>> from clarifai.client.model import Model
         | 
| 862 | 
            +
                    >>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
         | 
| 863 | 
            +
                                or
         | 
| 864 | 
            +
                    >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
         | 
| 865 | 
            +
                    >>> stream_response = model.stream(inputs=inputs, runner_selector=runner_selector)
         | 
| 866 | 
            +
                    >>> list_stream_response = [response for response in stream_response]
         | 
| 729 867 | 
             
                """
         | 
| 730 868 | 
             
                # if not isinstance(inputs, Iterator[List[Input]]):
         | 
| 731 869 | 
             
                #   raise UserError('Invalid inputs, inputs must be a iterator of list of Input objects.')
         | 
| 732 870 |  | 
| 733 871 | 
             
                self._override_model_version(inference_params, output_config)
         | 
| 734 | 
            -
                request = self._req_iterator(inputs)
         | 
| 872 | 
            +
                request = self._req_iterator(inputs, runner_selector)
         | 
| 735 873 |  | 
| 736 874 | 
             
                start_time = time.time()
         | 
| 737 875 | 
             
                backoff_iterator = BackoffIterator(10)
         | 
| @@ -756,6 +894,9 @@ class Model(Lister, BaseClient): | |
| 756 894 | 
             
              def stream_by_filepath(self,
         | 
| 757 895 | 
             
                                     filepath: str,
         | 
| 758 896 | 
             
                                     input_type: str,
         | 
| 897 | 
            +
                                     compute_cluster_id: str = None,
         | 
| 898 | 
            +
                                     nodepool_id: str = None,
         | 
| 899 | 
            +
                                     deployment_id: str = None,
         | 
| 759 900 | 
             
                                     inference_params: Dict = {},
         | 
| 760 901 | 
             
                                     output_config: Dict = {}):
         | 
| 761 902 | 
             
                """Stream the model output based on the given filepath.
         | 
| @@ -763,6 +904,9 @@ class Model(Lister, BaseClient): | |
| 763 904 | 
             
                Args:
         | 
| 764 905 | 
             
                    filepath (str): The filepath to predict.
         | 
| 765 906 | 
             
                    input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
         | 
| 907 | 
            +
                    compute_cluster_id (str): The compute cluster ID to use for the model.
         | 
| 908 | 
            +
                    nodepool_id (str): The nodepool ID to use for the model.
         | 
| 909 | 
            +
                    deployment_id (str): The deployment ID to use for the model.
         | 
| 766 910 | 
             
                    inference_params (dict): The inference params to override.
         | 
| 767 911 | 
             
                    output_config (dict): The output config to override.
         | 
| 768 912 | 
             
                      min_value (float): The minimum value of the prediction confidence to filter.
         | 
| @@ -772,7 +916,7 @@ class Model(Lister, BaseClient): | |
| 772 916 | 
             
                Example:
         | 
| 773 917 | 
             
                    >>> from clarifai.client.model import Model
         | 
| 774 918 | 
             
                    >>> model = Model("url")
         | 
| 775 | 
            -
                    >>> stream_response = model.stream_by_filepath('/path/to/image.jpg', 'image')
         | 
| 919 | 
            +
                    >>> stream_response = model.stream_by_filepath('/path/to/image.jpg', 'image', deployment_id='deployment_id')
         | 
| 776 920 | 
             
                    >>> list_stream_response = [response for response in stream_response]
         | 
| 777 921 | 
             
                """
         | 
| 778 922 | 
             
                if not os.path.isfile(filepath):
         | 
| @@ -781,11 +925,21 @@ class Model(Lister, BaseClient): | |
| 781 925 | 
             
                with open(filepath, "rb") as f:
         | 
| 782 926 | 
             
                  file_bytes = f.read()
         | 
| 783 927 |  | 
| 784 | 
            -
                return self.stream_by_bytes( | 
| 928 | 
            +
                return self.stream_by_bytes(
         | 
| 929 | 
            +
                    input_bytes_iterator=iter([file_bytes]),
         | 
| 930 | 
            +
                    input_type=input_type,
         | 
| 931 | 
            +
                    compute_cluster_id=compute_cluster_id,
         | 
| 932 | 
            +
                    nodepool_id=nodepool_id,
         | 
| 933 | 
            +
                    deployment_id=deployment_id,
         | 
| 934 | 
            +
                    inference_params=inference_params,
         | 
| 935 | 
            +
                    output_config=output_config)
         | 
| 785 936 |  | 
| 786 937 | 
             
              def stream_by_bytes(self,
         | 
| 787 938 | 
             
                                  input_bytes_iterator: Iterator[bytes],
         | 
| 788 939 | 
             
                                  input_type: str,
         | 
| 940 | 
            +
                                  compute_cluster_id: str = None,
         | 
| 941 | 
            +
                                  nodepool_id: str = None,
         | 
| 942 | 
            +
                                  deployment_id: str = None,
         | 
| 789 943 | 
             
                                  inference_params: Dict = {},
         | 
| 790 944 | 
             
                                  output_config: Dict = {}):
         | 
| 791 945 | 
             
                """Stream the model output based on the given bytes.
         | 
| @@ -793,6 +947,9 @@ class Model(Lister, BaseClient): | |
| 793 947 | 
             
                Args:
         | 
| 794 948 | 
             
                    input_bytes_iterator (Iterator[bytes]): Iterator of file bytes to predict on.
         | 
| 795 949 | 
             
                    input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
         | 
| 950 | 
            +
                    compute_cluster_id (str): The compute cluster ID to use for the model.
         | 
| 951 | 
            +
                    nodepool_id (str): The nodepool ID to use for the model.
         | 
| 952 | 
            +
                    deployment_id (str): The deployment ID to use for the model.
         | 
| 796 953 | 
             
                    inference_params (dict): The inference params to override.
         | 
| 797 954 | 
             
                    output_config (dict): The output config to override.
         | 
| 798 955 | 
             
                      min_value (float): The minimum value of the prediction confidence to filter.
         | 
| @@ -804,6 +961,7 @@ class Model(Lister, BaseClient): | |
| 804 961 | 
             
                    >>> model = Model("https://clarifai.com/openai/chat-completion/models/GPT-4")
         | 
| 805 962 | 
             
                    >>> stream_response = model.stream_by_bytes(iter([b'Write a tweet on future of AI']),
         | 
| 806 963 | 
             
                                                                input_type='text',
         | 
| 964 | 
            +
                                                                deployment_id='deployment_id',
         | 
| 807 965 | 
             
                                                                inference_params=dict(temperature=str(0.7), max_tokens=30)))
         | 
| 808 966 | 
             
                    >>> list_stream_response = [response for response in stream_response]
         | 
| 809 967 | 
             
                """
         | 
| @@ -822,11 +980,30 @@ class Model(Lister, BaseClient): | |
| 822 980 | 
             
                    elif input_type == "audio":
         | 
| 823 981 | 
             
                      yield [Inputs.get_input_from_bytes("", audio_bytes=input_bytes)]
         | 
| 824 982 |  | 
| 825 | 
            -
                 | 
| 983 | 
            +
                if deployment_id and (compute_cluster_id or nodepool_id):
         | 
| 984 | 
            +
                  raise UserError(
         | 
| 985 | 
            +
                      "You can only specify one of deployment_id or compute_cluster_id and nodepool_id.")
         | 
| 986 | 
            +
             | 
| 987 | 
            +
                runner_selector = None
         | 
| 988 | 
            +
                if deployment_id:
         | 
| 989 | 
            +
                  runner_selector = Deployment.get_runner_selector(
         | 
| 990 | 
            +
                      user_id=self.user_id, deployment_id=deployment_id)
         | 
| 991 | 
            +
                elif compute_cluster_id and nodepool_id:
         | 
| 992 | 
            +
                  runner_selector = Nodepool.get_runner_selector(
         | 
| 993 | 
            +
                      user_id=self.user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
         | 
| 994 | 
            +
             | 
| 995 | 
            +
                return self.stream(
         | 
| 996 | 
            +
                    inputs=input_generator(),
         | 
| 997 | 
            +
                    runner_selector=runner_selector,
         | 
| 998 | 
            +
                    inference_params=inference_params,
         | 
| 999 | 
            +
                    output_config=output_config)
         | 
| 826 1000 |  | 
| 827 1001 | 
             
              def stream_by_url(self,
         | 
| 828 1002 | 
             
                                url_iterator: Iterator[str],
         | 
| 829 1003 | 
             
                                input_type: str,
         | 
| 1004 | 
            +
                                compute_cluster_id: str = None,
         | 
| 1005 | 
            +
                                nodepool_id: str = None,
         | 
| 1006 | 
            +
                                deployment_id: str = None,
         | 
| 830 1007 | 
             
                                inference_params: Dict = {},
         | 
| 831 1008 | 
             
                                output_config: Dict = {}):
         | 
| 832 1009 | 
             
                """Stream the model output based on the given URL.
         | 
| @@ -834,6 +1011,9 @@ class Model(Lister, BaseClient): | |
| 834 1011 | 
             
                Args:
         | 
| 835 1012 | 
             
                    url_iterator (Iterator[str]): Iterator of URLs to predict.
         | 
| 836 1013 | 
             
                    input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
         | 
| 1014 | 
            +
                    compute_cluster_id (str): The compute cluster ID to use for the model.
         | 
| 1015 | 
            +
                    nodepool_id (str): The nodepool ID to use for the model.
         | 
| 1016 | 
            +
                    deployment_id (str): The deployment ID to use for the model.
         | 
| 837 1017 | 
             
                    inference_params (dict): The inference params to override.
         | 
| 838 1018 | 
             
                    output_config (dict): The output config to override.
         | 
| 839 1019 | 
             
                      min_value (float): The minimum value of the prediction confidence to filter.
         | 
| @@ -843,7 +1023,7 @@ class Model(Lister, BaseClient): | |
| 843 1023 | 
             
                Example:
         | 
| 844 1024 | 
             
                    >>> from clarifai.client.model import Model
         | 
| 845 1025 | 
             
                    >>> model = Model("url")
         | 
| 846 | 
            -
                    >>> stream_response = model.stream_by_url(iter(['url']), 'image')
         | 
| 1026 | 
            +
                    >>> stream_response = model.stream_by_url(iter(['url']), 'image', deployment_id='deployment_id')
         | 
| 847 1027 | 
             
                    >>> list_stream_response = [response for response in stream_response]
         | 
| 848 1028 | 
             
                """
         | 
| 849 1029 | 
             
                if input_type not in {'image', 'text', 'video', 'audio'}:
         | 
| @@ -861,7 +1041,23 @@ class Model(Lister, BaseClient): | |
| 861 1041 | 
             
                    elif input_type == "audio":
         | 
| 862 1042 | 
             
                      yield [Inputs.get_input_from_url("", audio_url=url)]
         | 
| 863 1043 |  | 
| 864 | 
            -
                 | 
| 1044 | 
            +
                if deployment_id and (compute_cluster_id or nodepool_id):
         | 
| 1045 | 
            +
                  raise UserError(
         | 
| 1046 | 
            +
                      "You can only specify one of deployment_id or compute_cluster_id and nodepool_id.")
         | 
| 1047 | 
            +
             | 
| 1048 | 
            +
                runner_selector = None
         | 
| 1049 | 
            +
                if deployment_id:
         | 
| 1050 | 
            +
                  runner_selector = Deployment.get_runner_selector(
         | 
| 1051 | 
            +
                      user_id=self.user_id, deployment_id=deployment_id)
         | 
| 1052 | 
            +
                elif compute_cluster_id and nodepool_id:
         | 
| 1053 | 
            +
                  runner_selector = Nodepool.get_runner_selector(
         | 
| 1054 | 
            +
                      user_id=self.user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
         | 
| 1055 | 
            +
             | 
| 1056 | 
            +
                return self.stream(
         | 
| 1057 | 
            +
                    inputs=input_generator(),
         | 
| 1058 | 
            +
                    runner_selector=runner_selector,
         | 
| 1059 | 
            +
                    inference_params=inference_params,
         | 
| 1060 | 
            +
                    output_config=output_config)
         | 
| 865 1061 |  | 
| 866 1062 | 
             
              def _override_model_version(self, inference_params: Dict = {}, output_config: Dict = {}) -> None:
         | 
| 867 1063 | 
             
                """Overrides the model version.
         | 
    
        clarifai/client/nodepool.py
    CHANGED
    
    | @@ -114,6 +114,28 @@ class Nodepool(Lister, BaseClient): | |
| 114 114 | 
             
                  deployment["visibility"] = resources_pb2.Visibility(**deployment["visibility"])
         | 
| 115 115 | 
             
                return deployment
         | 
| 116 116 |  | 
| 117 | 
            +
              @staticmethod
         | 
| 118 | 
            +
              def get_runner_selector(user_id: str, compute_cluster_id: str,
         | 
| 119 | 
            +
                                      nodepool_id: str) -> resources_pb2.RunnerSelector:
         | 
| 120 | 
            +
                """Returns a RunnerSelector object for the specified compute cluster and nodepool.
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                Args:
         | 
| 123 | 
            +
                    user_id (str): The user ID of the user.
         | 
| 124 | 
            +
                    compute_cluster_id (str): The compute cluster ID for the compute cluster.
         | 
| 125 | 
            +
                    nodepool_id (str): The nodepool ID for the nodepool.
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                Returns:
         | 
| 128 | 
            +
                    resources_pb2.RunnerSelector: A RunnerSelector object for the specified compute cluster and nodepool.
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                Example:
         | 
| 131 | 
            +
                    >>> from clarifai.client.nodepool import Nodepool
         | 
| 132 | 
            +
                    >>> nodepool = Nodepool(nodepool_id="nodepool_id", user_id="user_id")
         | 
| 133 | 
            +
                    >>> runner_selector = Nodepool.get_runner_selector(user_id="user_id", compute_cluster_id="compute_cluster_id", nodepool_id="nodepool_id")
         | 
| 134 | 
            +
                """
         | 
| 135 | 
            +
                compute_cluster = resources_pb2.ComputeCluster(id=compute_cluster_id, user_id=user_id)
         | 
| 136 | 
            +
                nodepool = resources_pb2.Nodepool(id=nodepool_id, compute_cluster=compute_cluster)
         | 
| 137 | 
            +
                return resources_pb2.RunnerSelector(nodepool=nodepool)
         | 
| 138 | 
            +
             | 
| 117 139 | 
             
              def create_deployment(self, deployment_id: str, config_filepath: str) -> Deployment:
         | 
| 118 140 | 
             
                """Creates a deployment for the nodepool.
         | 
| 119 141 |  | 
| @@ -24,6 +24,10 @@ class ImageNetDataLoader(ClarifaiDataLoader): | |
| 24 24 |  | 
| 25 25 | 
             
                self.load_data()
         | 
| 26 26 |  | 
| 27 | 
            +
              @property
         | 
| 28 | 
            +
              def task(self):
         | 
| 29 | 
            +
                return "visual_classification"
         | 
| 30 | 
            +
             | 
| 27 31 | 
             
              def load_data(self):
         | 
| 28 32 | 
             
                #Creating label map
         | 
| 29 33 | 
             
                with open(os.path.join(self.data_dir, "LOC_synset_mapping.txt")) as _file:
         | 
| @@ -54,5 +58,5 @@ class ImageNetDataLoader(ClarifaiDataLoader): | |
| 54 58 | 
             
              def __getitem__(self, idx):
         | 
| 55 59 | 
             
                return VisualClassificationFeatures(
         | 
| 56 60 | 
             
                    image_path=self.image_paths[idx],
         | 
| 57 | 
            -
                     | 
| 61 | 
            +
                    labels=self.concepts[idx],
         | 
| 58 62 | 
             
                    id=self.image_paths[idx].split('.')[0].split('/')[-1])
         | 
| @@ -6,37 +6,18 @@ | |
| 6 6 | 
             
            # * Export environment variables to use the virtualenv by default
         | 
| 7 7 | 
             
            # * Create a non-root user with minimal privileges and use it
         | 
| 8 8 | 
             
            ARG TARGET_PLATFORM=linux/amd64
         | 
| 9 | 
            -
             | 
| 10 | 
            -
            FROM --platform=$TARGET_PLATFORM nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04 AS build
         | 
| 11 | 
            -
             | 
| 12 | 
            -
            ARG DRIVER_VERSION=535
         | 
| 13 | 
            -
            ARG PYTHON_VERSION=${PYTHON_VERSION}
         | 
| 9 | 
            +
            FROM --platform=$TARGET_PLATFORM public.ecr.aws/docker/library/python:${PYTHON_VERSION}-slim-bookworm as build
         | 
| 14 10 |  | 
| 15 11 | 
             
            ENV DEBIAN_FRONTEND=noninteractive
         | 
| 16 12 | 
             
            RUN apt-get update && \
         | 
| 17 13 | 
             
                apt-get install --no-install-suggests --no-install-recommends --yes \
         | 
| 18 | 
            -
                software-properties-common | 
| 19 | 
            -
                gpg-agent && \
         | 
| 20 | 
            -
                add-apt-repository ppa:graphics-drivers/ppa && \
         | 
| 21 | 
            -
                add-apt-repository ppa:deadsnakes/ppa && \
         | 
| 22 | 
            -
                apt-get update && \
         | 
| 23 | 
            -
                apt-get install --no-install-suggests --no-install-recommends --yes \
         | 
| 24 | 
            -
                python${PYTHON_VERSION} \
         | 
| 25 | 
            -
                python${PYTHON_VERSION}-venv \
         | 
| 26 | 
            -
                python${PYTHON_VERSION}-dev \
         | 
| 14 | 
            +
                software-properties-common  \
         | 
| 27 15 | 
             
                gcc \
         | 
| 28 | 
            -
                libpython3-dev \
         | 
| 29 | 
            -
                # drivers and nvidia-smi
         | 
| 30 | 
            -
                nvidia-utils-${DRIVER_VERSION} \
         | 
| 31 | 
            -
                nvidia-driver-${DRIVER_VERSION} \
         | 
| 32 | 
            -
                libcap2-bin && \
         | 
| 16 | 
            +
                libpython3-dev && \
         | 
| 33 17 | 
             
                python${PYTHON_VERSION} -m venv /venv && \
         | 
| 34 18 | 
             
                /venv/bin/pip install --disable-pip-version-check --upgrade pip setuptools wheel && \
         | 
| 35 | 
            -
                # Create a non-root user with minimal privileges and set file permissions
         | 
| 36 | 
            -
                ln -sf /usr/bin/python${PYTHON_VERSION} /usr/bin/python3 && \
         | 
| 37 19 | 
             
                apt-get clean && rm -rf /var/lib/apt/lists/*
         | 
| 38 20 |  | 
| 39 | 
            -
             | 
| 40 21 | 
             
            # Set environment variables to use virtualenv by default
         | 
| 41 22 | 
             
            ENV VIRTUAL_ENV=/venv
         | 
| 42 23 | 
             
            ENV PATH="$VIRTUAL_ENV/bin:$PATH"
         | 
| @@ -57,54 +38,24 @@ RUN python -m pip install clarifai | |
| 57 38 | 
             
            # Finally copy everything we built into a distroless image for runtime.
         | 
| 58 39 | 
             
            ######################>#######
         | 
| 59 40 | 
             
            ARG TARGET_PLATFORM=linux/amd64
         | 
| 60 | 
            -
             | 
| 61 | 
            -
            FROM --platform=$TARGET_PLATFORM gcr.io/distroless/python3-debian12:debug
         | 
| 41 | 
            +
            FROM --platform=$TARGET_PLATFORM gcr.io/distroless/python3-debian12:latest
         | 
| 42 | 
            +
            # FROM --platform=$TARGET_PLATFORM gcr.io/distroless/python3-debian12:debug
         | 
| 62 43 | 
             
            ARG PYTHON_VERSION=${PYTHON_VERSION}
         | 
| 63 44 | 
             
            # needed to call pip directly
         | 
| 64 45 | 
             
            COPY --from=build /bin/sh /bin/sh
         | 
| 65 46 |  | 
| 66 | 
            -
            # Copy driver libraries based on architecture
         | 
| 67 | 
            -
            # Set FOLDER_NAME based on TARGET_PLATFORM
         | 
| 68 | 
            -
            ENV TARGET_PLATFORM=${TARGET_PLATFORM}
         | 
| 69 | 
            -
            ENV FOLDER_NAME="x86_64-linux-gnu"
         | 
| 70 | 
            -
            # RUN if [ "${TARGET_PLATFORM}" = "linux/arm64" ]; then \
         | 
| 71 | 
            -
            #         export FOLDER_NAME="aarch64-linux-gnu"; \
         | 
| 72 | 
            -
            #     fi && \
         | 
| 73 | 
            -
            #     echo "FOLDER_NAME=${FOLDER_NAME}"
         | 
| 74 | 
            -
             | 
| 75 | 
            -
             | 
| 76 47 | 
             
            # virtual env
         | 
| 77 48 | 
             
            COPY --from=build /venv /venv
         | 
| 78 49 |  | 
| 79 | 
            -
            # cuda
         | 
| 80 | 
            -
            COPY --from=build --chmod=755 /usr/local/cuda /usr/local/cuda
         | 
| 81 | 
            -
             | 
| 82 50 | 
             
            # We have to overwrite the python3 binary that the distroless image uses
         | 
| 83 | 
            -
            COPY --from=build /usr/bin/python${PYTHON_VERSION} /usr/bin/python3
         | 
| 51 | 
            +
            COPY --from=build /usr/local/bin/python${PYTHON_VERSION} /usr/bin/python3
         | 
| 84 52 | 
             
            # And also copy in all the lib files for it.
         | 
| 85 | 
            -
            COPY --from=build /usr/lib/ | 
| 86 | 
            -
            # Note that distroless comes with a fixed python version, so we may need to overwrite that specific
         | 
| 87 | 
            -
            # version.
         | 
| 88 | 
            -
             | 
| 89 | 
            -
            # for debugging
         | 
| 90 | 
            -
            COPY --from=build /usr/bin/nvidia-smi /usr/bin/nvidia-smi
         | 
| 91 | 
            -
            # Copy driver libraries based on architecture
         | 
| 92 | 
            -
            COPY --from=build /usr/lib/${FOLDER_NAME}/libcuda.so* /usr/lib/${FOLDER_NAME}/
         | 
| 93 | 
            -
            COPY --from=build /usr/lib/${FOLDER_NAME}/libnvidia-ml.so* /usr/lib/${FOLDER_NAME}/
         | 
| 94 | 
            -
             | 
| 95 | 
            -
            # Set environment variables for CUDA
         | 
| 96 | 
            -
            ENV PATH=/usr/local/cuda/bin:${PATH}
         | 
| 97 | 
            -
            ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/lib/${FOLDER_NAME}:$LD_LIBRARY_PATH
         | 
| 98 | 
            -
            ENV CUDA_HOME=/usr/local/cuda
         | 
| 53 | 
            +
            COPY --from=build /usr/local/lib/ /usr/lib/
         | 
| 99 54 |  | 
| 100 55 | 
             
            # Set environment variables to use virtualenv by default
         | 
| 101 56 | 
             
            ENV VIRTUAL_ENV=/venv
         | 
| 102 | 
            -
            # ENV PATH=${VIRTUAL_ENV}/bin:${PATH}
         | 
| 103 57 | 
             
            ENV PYTHONPATH=${PYTHONPATH}:${VIRTUAL_ENV}/lib/python${PYTHON_VERSION}/site-packages
         | 
| 104 58 |  | 
| 105 | 
            -
            # ENTRYPOINT ["${VIRTUAL_ENV}/bin/python"]
         | 
| 106 | 
            -
             | 
| 107 | 
            -
             | 
| 108 59 | 
             
            # These will be set by the templaing system.
         | 
| 109 60 | 
             
            ENV CLARIFAI_PAT=${CLARIFAI_PAT}
         | 
| 110 61 | 
             
            ENV CLARIFAI_USER_ID=${CLARIFAI_USER_ID}
         | 
| @@ -123,7 +74,6 @@ COPY . /app/model_dir/${name} | |
| 123 74 | 
             
            # Add the model directory to the python path.
         | 
| 124 75 | 
             
            ENV PYTHONPATH=${PYTHONPATH}:/app/model_dir/${name}
         | 
| 125 76 |  | 
| 126 | 
            -
             | 
| 127 77 | 
             
            # Finally run the clarifai entrypoint to start the runner loop and local dev server.
         | 
| 128 78 | 
             
            # Note(zeiler): we may want to make this a clarifai CLI call.
         | 
| 129 79 | 
             
            CMD ["-m", "clarifai.runners.server", "--model_path", "/app/model_dir/${name}"]
         | 
| @@ -11,6 +11,7 @@ from rich import print | |
| 11 11 |  | 
| 12 12 | 
             
            from clarifai.client import BaseClient
         | 
| 13 13 | 
             
            from clarifai.runners.utils.loader import HuggingFaceLoarder
         | 
| 14 | 
            +
            from clarifai.urls.helper import ClarifaiUrlHelper
         | 
| 14 15 | 
             
            from clarifai.utils.logging import logger
         | 
| 15 16 |  | 
| 16 17 |  | 
| @@ -33,6 +34,7 @@ class ModelUploader: | |
| 33 34 | 
             
                self.config = self._load_config(os.path.join(self.folder, 'config.yaml'))
         | 
| 34 35 | 
             
                self.model_proto = self._get_model_proto()
         | 
| 35 36 | 
             
                self.model_id = self.model_proto.id
         | 
| 37 | 
            +
                self.model_version_id = None
         | 
| 36 38 | 
             
                self.inference_compute_info = self._get_inference_compute_info()
         | 
| 37 39 | 
             
                self.is_v3 = True  # Do model build for v3
         | 
| 38 40 |  | 
| @@ -64,15 +66,27 @@ class ModelUploader: | |
| 64 66 | 
             
                  model = self.config.get('model')
         | 
| 65 67 | 
             
                  assert "user_id" in model, "user_id not found in the config file"
         | 
| 66 68 | 
             
                  assert "app_id" in model, "app_id not found in the config file"
         | 
| 69 | 
            +
                  # The owner of the model and the app.
         | 
| 67 70 | 
             
                  user_id = model.get('user_id')
         | 
| 68 71 | 
             
                  app_id = model.get('app_id')
         | 
| 69 72 |  | 
| 70 73 | 
             
                  base = os.environ.get('CLARIFAI_API_BASE', 'https://api-dev.clarifai.com')
         | 
| 71 74 |  | 
| 72 75 | 
             
                  self._client = BaseClient(user_id=user_id, app_id=app_id, base=base)
         | 
| 73 | 
            -
             | 
| 76 | 
            +
             | 
| 74 77 | 
             
                return self._client
         | 
| 75 78 |  | 
| 79 | 
            +
              @property
         | 
| 80 | 
            +
              def model_url(self):
         | 
| 81 | 
            +
                url_helper = ClarifaiUrlHelper(self._client.auth_helper)
         | 
| 82 | 
            +
                if self.model_version_id is not None:
         | 
| 83 | 
            +
                  return url_helper.clarifai_url(self.client.user_app_id.user_id,
         | 
| 84 | 
            +
                                                 self.client.user_app_id.app_id, "models", self.model_id)
         | 
| 85 | 
            +
                else:
         | 
| 86 | 
            +
                  return url_helper.clarifai_url(self.client.user_app_id.user_id,
         | 
| 87 | 
            +
                                                 self.client.user_app_id.app_id, "models", self.model_id,
         | 
| 88 | 
            +
                                                 self.model_version_id)
         | 
| 89 | 
            +
             | 
| 76 90 | 
             
              def _get_model_proto(self):
         | 
| 77 91 | 
             
                assert "model" in self.config, "model info not found in the config file"
         | 
| 78 92 | 
             
                model = self.config.get('model')
         | 
| @@ -83,9 +97,6 @@ class ModelUploader: | |
| 83 97 | 
             
                assert "app_id" in model, "app_id not found in the config file"
         | 
| 84 98 |  | 
| 85 99 | 
             
                model_proto = json_format.ParseDict(model, resources_pb2.Model())
         | 
| 86 | 
            -
                assert model_proto.id == model_proto.id.lower(), "Model ID must be lowercase"
         | 
| 87 | 
            -
                assert model_proto.user_id == model_proto.user_id.lower(), "User ID must be lowercase"
         | 
| 88 | 
            -
                assert model_proto.app_id == model_proto.app_id.lower(), "App ID must be lowercase"
         | 
| 89 100 |  | 
| 90 101 | 
             
                return model_proto
         | 
| 91 102 |  | 
| @@ -95,15 +106,20 @@ class ModelUploader: | |
| 95 106 | 
             
                inference_compute_info = self.config.get('inference_compute_info')
         | 
| 96 107 | 
             
                return json_format.ParseDict(inference_compute_info, resources_pb2.ComputeInfo())
         | 
| 97 108 |  | 
| 98 | 
            -
              def  | 
| 109 | 
            +
              def check_model_exists(self):
         | 
| 99 110 | 
             
                resp = self.client.STUB.GetModel(
         | 
| 100 111 | 
             
                    service_pb2.GetModelRequest(
         | 
| 101 112 | 
             
                        user_app_id=self.client.user_app_id, model_id=self.model_proto.id))
         | 
| 102 113 | 
             
                if resp.status.code == status_code_pb2.SUCCESS:
         | 
| 114 | 
            +
                  return True
         | 
| 115 | 
            +
                return False
         | 
| 116 | 
            +
             | 
| 117 | 
            +
              def maybe_create_model(self):
         | 
| 118 | 
            +
                if self.check_model_exists():
         | 
| 103 119 | 
             
                  logger.info(
         | 
| 104 120 | 
             
                      f"Model '{self.client.user_app_id.user_id}/{self.client.user_app_id.app_id}/models/{self.model_proto.id}' already exists, "
         | 
| 105 121 | 
             
                      f"will create a new version for it.")
         | 
| 106 | 
            -
                  return | 
| 122 | 
            +
                  return
         | 
| 107 123 |  | 
| 108 124 | 
             
                request = service_pb2.PostModelsRequest(
         | 
| 109 125 | 
             
                    user_app_id=self.client.user_app_id,
         | 
| @@ -151,6 +167,18 @@ class ModelUploader: | |
| 151 167 | 
             
                with open(os.path.join(self.folder, 'Dockerfile'), 'w') as dockerfile:
         | 
| 152 168 | 
             
                  dockerfile.write(dockerfile_content)
         | 
| 153 169 |  | 
| 170 | 
            +
              @property
         | 
| 171 | 
            +
              def checkpoint_path(self):
         | 
| 172 | 
            +
                return os.path.join(self.folder, self.checkpoint_suffix)
         | 
| 173 | 
            +
             | 
| 174 | 
            +
              @property
         | 
| 175 | 
            +
              def checkpoint_suffix(self):
         | 
| 176 | 
            +
                return '1/checkpoints'
         | 
| 177 | 
            +
             | 
| 178 | 
            +
              @property
         | 
| 179 | 
            +
              def tar_file(self):
         | 
| 180 | 
            +
                return f"{self.folder}.tar.gz"
         | 
| 181 | 
            +
             | 
| 154 182 | 
             
              def download_checkpoints(self):
         | 
| 155 183 | 
             
                if not self.config.get("checkpoints"):
         | 
| 156 184 | 
             
                  logger.info("No checkpoints specified in the config file")
         | 
| @@ -173,8 +201,7 @@ class ModelUploader: | |
| 173 201 | 
             
                    assert hf_token != 'hf_token', "The default 'hf_token' is not valid. Please provide a valid token or leave that field out of config.yaml if not needed."
         | 
| 174 202 | 
             
                  loader = HuggingFaceLoarder(repo_id=repo_id, token=hf_token)
         | 
| 175 203 |  | 
| 176 | 
            -
                   | 
| 177 | 
            -
                  success = loader.download_checkpoints(checkpoint_path)
         | 
| 204 | 
            +
                  success = loader.download_checkpoints(self.checkpoint_path)
         | 
| 178 205 |  | 
| 179 206 | 
             
                  if not success:
         | 
| 180 207 | 
             
                    logger.error(f"Failed to download checkpoints for model {repo_id}")
         | 
| @@ -207,7 +234,7 @@ class ModelUploader: | |
| 207 234 |  | 
| 208 235 | 
             
              def _get_model_version_proto(self):
         | 
| 209 236 |  | 
| 210 | 
            -
                 | 
| 237 | 
            +
                model_version_proto = resources_pb2.ModelVersion(
         | 
| 211 238 | 
             
                    pretrained_model_config=resources_pb2.PretrainedModelConfig(),
         | 
| 212 239 | 
             
                    inference_compute_info=self.inference_compute_info,
         | 
| 213 240 | 
             
                )
         | 
| @@ -216,31 +243,40 @@ class ModelUploader: | |
| 216 243 | 
             
                if model_type_id in self.CONCEPTS_REQUIRED_MODEL_TYPE:
         | 
| 217 244 |  | 
| 218 245 | 
             
                  loader = HuggingFaceLoarder()
         | 
| 219 | 
            -
                   | 
| 220 | 
            -
                  labels = loader.fetch_labels(checkpoint_path)
         | 
| 246 | 
            +
                  labels = loader.fetch_labels(self.checkpoint_path)
         | 
| 221 247 | 
             
                  # sort the concepts by id and then update the config file
         | 
| 222 248 | 
             
                  labels = sorted(labels.items(), key=lambda x: int(x[0]))
         | 
| 223 249 |  | 
| 224 250 | 
             
                  config_file = os.path.join(self.folder, 'config.yaml')
         | 
| 225 251 | 
             
                  self.hf_labels_to_config(labels, config_file)
         | 
| 226 252 |  | 
| 227 | 
            -
                   | 
| 228 | 
            -
             | 
| 253 | 
            +
                  model_version_proto.output_info.data.concepts.extend(
         | 
| 254 | 
            +
                      self._concepts_protos_from_concepts(labels))
         | 
| 255 | 
            +
                return model_version_proto
         | 
| 229 256 |  | 
| 230 | 
            -
              def upload_model_version(self):
         | 
| 257 | 
            +
              def upload_model_version(self, download_checkpoints):
         | 
| 231 258 | 
             
                file_path = f"{self.folder}.tar.gz"
         | 
| 232 259 | 
             
                logger.info(f"Will tar it into file: {file_path}")
         | 
| 233 260 |  | 
| 261 | 
            +
                if download_checkpoints:
         | 
| 262 | 
            +
                  tar_cmd = f"tar --exclude=*~ -czvf {self.tar_file} -C {self.folder} ."
         | 
| 263 | 
            +
                else:  # we don't want to send the checkpoints up even if they are in the folder.
         | 
| 264 | 
            +
                  logger.info(f"Skipping {self.checkpoint_path} in the tar file that is uploaded.")
         | 
| 265 | 
            +
                  tar_cmd = f"tar --exclude={self.checkpoint_suffix} --exclude=*~ -czvf {self.tar_file} -C {self.folder} ."
         | 
| 234 266 | 
             
                # Tar the folder
         | 
| 235 | 
            -
                 | 
| 267 | 
            +
                logger.debug(tar_cmd)
         | 
| 268 | 
            +
                os.system(tar_cmd)
         | 
| 236 269 | 
             
                logger.info("Tarring complete, about to start upload.")
         | 
| 237 270 |  | 
| 238 | 
            -
                 | 
| 271 | 
            +
                model_version_proto = self._get_model_version_proto()
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                file_size = os.path.getsize(self.tar_file)
         | 
| 274 | 
            +
                logger.info(f"Size of the tar is: {file_size} bytes")
         | 
| 239 275 |  | 
| 240 | 
            -
                 | 
| 276 | 
            +
                self.maybe_create_model()
         | 
| 241 277 |  | 
| 242 278 | 
             
                for response in self.client.STUB.PostModelVersionsUpload(
         | 
| 243 | 
            -
                    self.model_version_stream_upload_iterator( | 
| 279 | 
            +
                    self.model_version_stream_upload_iterator(model_version_proto, file_path),):
         | 
| 244 280 | 
             
                  percent_completed = 0
         | 
| 245 281 | 
             
                  if response.status.code == status_code_pb2.UPLOAD_IN_PROGRESS:
         | 
| 246 282 | 
             
                    percent_completed = response.status.percent_completed
         | 
| @@ -257,13 +293,18 @@ class ModelUploader: | |
| 257 293 | 
             
                if response.status.code != status_code_pb2.MODEL_BUILDING:
         | 
| 258 294 | 
             
                  logger.error(f"Failed to upload model version: {response}")
         | 
| 259 295 | 
             
                  return
         | 
| 260 | 
            -
                model_version_id = response.model_version_id
         | 
| 261 | 
            -
                logger.info(f"Created Model Version ID: {model_version_id}")
         | 
| 262 | 
            -
             | 
| 263 | 
            -
             | 
| 264 | 
            -
             | 
| 265 | 
            -
               | 
| 266 | 
            -
             | 
| 296 | 
            +
                self.model_version_id = response.model_version_id
         | 
| 297 | 
            +
                logger.info(f"Created Model Version ID: {self.model_version_id}")
         | 
| 298 | 
            +
                logger.info(f"Full url to that version is: {self.model_url}")
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                success = self.monitor_model_build()
         | 
| 301 | 
            +
                if success:  # cleanup the tar_file if it exists
         | 
| 302 | 
            +
                  if os.path.exists(self.tar_file):
         | 
| 303 | 
            +
                    logger.info(f"Cleaning up upload file: {self.tar_file}")
         | 
| 304 | 
            +
                    os.remove(self.tar_file)
         | 
| 305 | 
            +
             | 
| 306 | 
            +
              def model_version_stream_upload_iterator(self, model_version_proto, file_path):
         | 
| 307 | 
            +
                yield self.init_upload_model_version(model_version_proto, file_path)
         | 
| 267 308 | 
             
                with open(file_path, "rb") as f:
         | 
| 268 309 | 
             
                  file_size = os.path.getsize(file_path)
         | 
| 269 310 | 
             
                  chunk_size = int(127 * 1024 * 1024)  # 127MB chunk size
         | 
| @@ -293,51 +334,58 @@ class ModelUploader: | |
| 293 334 | 
             
                if read_so_far == file_size:
         | 
| 294 335 | 
             
                  logger.info("\nUpload complete!, waiting for model build...")
         | 
| 295 336 |  | 
| 296 | 
            -
              def init_upload_model_version(self,  | 
| 337 | 
            +
              def init_upload_model_version(self, model_version_proto, file_path):
         | 
| 297 338 | 
             
                file_size = os.path.getsize(file_path)
         | 
| 298 | 
            -
                logger.info(f"Uploading model version  | 
| 339 | 
            +
                logger.info(f"Uploading model version of model {self.model_proto.id}")
         | 
| 299 340 | 
             
                logger.info(f"Using file '{os.path.basename(file_path)}' of size: {file_size} bytes")
         | 
| 300 341 | 
             
                return service_pb2.PostModelVersionsUploadRequest(
         | 
| 301 342 | 
             
                    upload_config=service_pb2.PostModelVersionsUploadConfig(
         | 
| 302 343 | 
             
                        user_app_id=self.client.user_app_id,
         | 
| 303 344 | 
             
                        model_id=self.model_proto.id,
         | 
| 304 | 
            -
                        model_version= | 
| 345 | 
            +
                        model_version=model_version_proto,
         | 
| 305 346 | 
             
                        total_size=file_size,
         | 
| 306 347 | 
             
                        is_v3=self.is_v3,
         | 
| 307 348 | 
             
                    ))
         | 
| 308 349 |  | 
| 309 | 
            -
              def monitor_model_build(self | 
| 350 | 
            +
              def monitor_model_build(self):
         | 
| 310 351 | 
             
                st = time.time()
         | 
| 311 352 | 
             
                while True:
         | 
| 312 353 | 
             
                  resp = self.client.STUB.GetModelVersion(
         | 
| 313 354 | 
             
                      service_pb2.GetModelVersionRequest(
         | 
| 314 355 | 
             
                          user_app_id=self.client.user_app_id,
         | 
| 315 356 | 
             
                          model_id=self.model_proto.id,
         | 
| 316 | 
            -
                          version_id=model_version_id,
         | 
| 357 | 
            +
                          version_id=self.model_version_id,
         | 
| 317 358 | 
             
                      ))
         | 
| 318 359 | 
             
                  status_code = resp.model_version.status.code
         | 
| 319 360 | 
             
                  if status_code == status_code_pb2.MODEL_BUILDING:
         | 
| 320 361 | 
             
                    print(f"Model is building... (elapsed {time.time() - st:.1f}s)", end='\r', flush=True)
         | 
| 321 362 | 
             
                    time.sleep(1)
         | 
| 322 363 | 
             
                  elif status_code == status_code_pb2.MODEL_TRAINED:
         | 
| 323 | 
            -
                    logger.info("\nModel build complete!")
         | 
| 324 | 
            -
                    logger.info(
         | 
| 325 | 
            -
             | 
| 326 | 
            -
                    )
         | 
| 327 | 
            -
                    break
         | 
| 364 | 
            +
                    logger.info(f"\nModel build complete! (elapsed {time.time() - st:.1f}s)")
         | 
| 365 | 
            +
                    logger.info(f"Check out the model at {self.model_url}")
         | 
| 366 | 
            +
                    return True
         | 
| 328 367 | 
             
                  else:
         | 
| 329 368 | 
             
                    logger.info(
         | 
| 330 369 | 
             
                        f"\nModel build failed with status: {resp.model_version.status} and response {resp}")
         | 
| 331 | 
            -
                     | 
| 370 | 
            +
                    return False
         | 
| 332 371 |  | 
| 333 372 |  | 
| 334 373 | 
             
            def main(folder, download_checkpoints):
         | 
| 335 374 | 
             
              uploader = ModelUploader(folder)
         | 
| 336 375 | 
             
              if download_checkpoints:
         | 
| 337 376 | 
             
                uploader.download_checkpoints()
         | 
| 338 | 
            -
               | 
| 377 | 
            +
              if not args.skip_dockerfile:
         | 
| 378 | 
            +
                uploader.create_dockerfile()
         | 
| 379 | 
            +
              exists = uploader.check_model_exists()
         | 
| 380 | 
            +
              if exists:
         | 
| 381 | 
            +
                logger.info(
         | 
| 382 | 
            +
                    f"Model already exists at {uploader.model_url}, this upload will create a new version for it."
         | 
| 383 | 
            +
                )
         | 
| 384 | 
            +
              else:
         | 
| 385 | 
            +
                logger.info(f"New model will be created at {uploader.model_url} with it's first version.")
         | 
| 386 | 
            +
             | 
| 339 387 | 
             
              input("Press Enter to continue...")
         | 
| 340 | 
            -
              uploader.upload_model_version()
         | 
| 388 | 
            +
              uploader.upload_model_version(download_checkpoints)
         | 
| 341 389 |  | 
| 342 390 |  | 
| 343 391 | 
             
            if __name__ == "__main__":
         | 
| @@ -351,6 +399,12 @@ if __name__ == "__main__": | |
| 351 399 | 
             
                  help=
         | 
| 352 400 | 
             
                  'Flag to download checkpoints before uploading and including them in the tar file that is uploaded. Defaults to False, which will attempt to download them at docker build time.',
         | 
| 353 401 | 
             
              )
         | 
| 402 | 
            +
              parser.add_argument(
         | 
| 403 | 
            +
                  '--skip_dockerfile',
         | 
| 404 | 
            +
                  action='store_true',
         | 
| 405 | 
            +
                  help=
         | 
| 406 | 
            +
                  'Flag to skip generating a dockerfile so that you can manually edit an already created dockerfile.',
         | 
| 407 | 
            +
              )
         | 
| 354 408 | 
             
              args = parser.parse_args()
         | 
| 355 409 |  | 
| 356 410 | 
             
              main(args.model_path, args.download_checkpoints)
         | 
| @@ -58,13 +58,16 @@ class _BaseEvalResultHandler: | |
| 58 58 | 
             
              model: Model
         | 
| 59 59 | 
             
              eval_data: List[resources_pb2.EvalMetrics] = field(default_factory=list)
         | 
| 60 60 |  | 
| 61 | 
            -
              def evaluate_and_wait(self, dataset: Dataset):
         | 
| 61 | 
            +
              def evaluate_and_wait(self, dataset: Dataset, eval_info: dict = None):
         | 
| 62 62 | 
             
                from tqdm import tqdm
         | 
| 63 63 | 
             
                dataset_id = dataset.id
         | 
| 64 64 | 
             
                dataset_app_id = dataset.app_id
         | 
| 65 65 | 
             
                dataset_user_id = dataset.user_id
         | 
| 66 66 | 
             
                _ = self.model.evaluate(
         | 
| 67 | 
            -
                    dataset_id=dataset_id, | 
| 67 | 
            +
                    dataset_id=dataset_id,
         | 
| 68 | 
            +
                    dataset_app_id=dataset_app_id,
         | 
| 69 | 
            +
                    dataset_user_id=dataset_user_id,
         | 
| 70 | 
            +
                    eval_info=eval_info)
         | 
| 68 71 | 
             
                latest_eval = self.model.list_evaluations()[0]
         | 
| 69 72 | 
             
                excepted = 10
         | 
| 70 73 | 
             
                desc = f"Please wait for the evaluation process between model {self.get_model_name()} and dataset {dataset_user_id}/{dataset_app_id}/{dataset_id} to complete."
         | 
| @@ -83,7 +86,10 @@ class _BaseEvalResultHandler: | |
| 83 86 | 
             
                      f"Model has failed to evaluate \n {latest_eval.status}.\nPlease check your dataset inputs!"
         | 
| 84 87 | 
             
                  )
         | 
| 85 88 |  | 
| 86 | 
            -
              def find_eval_id(self, | 
| 89 | 
            +
              def find_eval_id(self,
         | 
| 90 | 
            +
                               datasets: List[Dataset] = [],
         | 
| 91 | 
            +
                               attempt_evaluate: bool = False,
         | 
| 92 | 
            +
                               eval_info: dict = None):
         | 
| 87 93 | 
             
                list_eval_outputs = self.model.list_evaluations()
         | 
| 88 94 | 
             
                self.eval_data = []
         | 
| 89 95 | 
             
                for dataset in datasets:
         | 
| @@ -117,7 +123,7 @@ class _BaseEvalResultHandler: | |
| 117 123 | 
             
                  # if not evaluated, but user wants to proceed it
         | 
| 118 124 | 
             
                  if not _is_found:
         | 
| 119 125 | 
             
                    if attempt_evaluate:
         | 
| 120 | 
            -
                      self.eval_data.append(self.evaluate_and_wait(dataset))
         | 
| 126 | 
            +
                      self.eval_data.append(self.evaluate_and_wait(dataset, eval_info=eval_info))
         | 
| 121 127 | 
             
                    # otherwise raise error
         | 
| 122 128 | 
             
                    else:
         | 
| 123 129 | 
             
                      raise Exception(
         | 
| @@ -53,6 +53,7 @@ class EvalResultCompare: | |
| 53 53 | 
             
                           models: Union[List[Model], List[str]],
         | 
| 54 54 | 
             
                           datasets: Union[Dataset, List[Dataset], str, List[str]],
         | 
| 55 55 | 
             
                           attempt_evaluate: bool = False,
         | 
| 56 | 
            +
                           eval_info: dict = None,
         | 
| 56 57 | 
             
                           auth_kwargs: dict = {}):
         | 
| 57 58 | 
             
                assert isinstance(models, list), ValueError("Expected list")
         | 
| 58 59 |  | 
| @@ -97,7 +98,7 @@ class EvalResultCompare: | |
| 97 98 | 
             
                    assert self.model_type == model_type, f"Can not compare when model types are different, {self.model_type} != {model_type}"
         | 
| 98 99 | 
             
                  m = make_handler_by_type(model_type)(model=model)
         | 
| 99 100 | 
             
                  logger.info(f"* {m.get_model_name(pretify=True)}")
         | 
| 100 | 
            -
                  m.find_eval_id(datasets=datasets, attempt_evaluate=attempt_evaluate)
         | 
| 101 | 
            +
                  m.find_eval_id(datasets=datasets, attempt_evaluate=attempt_evaluate, eval_info=eval_info)
         | 
| 101 102 | 
             
                  self._eval_handlers.append(m)
         | 
| 102 103 |  | 
| 103 104 | 
             
              @property
         | 
| @@ -1,4 +1,4 @@ | |
| 1 | 
            -
            clarifai/__init__.py,sha256= | 
| 1 | 
            +
            clarifai/__init__.py,sha256=Nt0sCLO5SfzhmX3Z2TYY4ZxvmbALrHAzZ1S3yAjDwgQ,23
         | 
| 2 2 | 
             
            clarifai/cli.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 3 3 | 
             
            clarifai/errors.py,sha256=RwzTajwds51wLD0MVlMC5kcpBnzRpreDLlazPSBZxrg,2605
         | 
| 4 4 | 
             
            clarifai/versions.py,sha256=jctnczzfGk_S3EnVqb2FjRKfSREkNmvNEwAAa_VoKiQ,222
         | 
| @@ -7,12 +7,12 @@ clarifai/client/app.py,sha256=6pckYme1urV2YJjLIYfeZ-vH0Z5YSQa51jzIMcEfwug,38342 | |
| 7 7 | 
             
            clarifai/client/base.py,sha256=hSHOqkXbSKyaRDeylMMnkhUHCAHhEqno4KI0CXGziBA,7536
         | 
| 8 8 | 
             
            clarifai/client/compute_cluster.py,sha256=lntZDLVDhS71Yj7mZrgq5uhnAuNPUnj48i3zMSuoUpk,8693
         | 
| 9 9 | 
             
            clarifai/client/dataset.py,sha256=oqp6ryg7IyxCZcItzownadYJKK0s1DtghHwITN71_6E,30160
         | 
| 10 | 
            -
            clarifai/client/deployment.py,sha256= | 
| 10 | 
            +
            clarifai/client/deployment.py,sha256=4gfvUvQY9adFS98B0vP9C5fR9OnDRV2JbUIdAkMymT8,2551
         | 
| 11 11 | 
             
            clarifai/client/input.py,sha256=cEVRytrMF1gCgwHLbXlSbPSEQN8uHpUAoKcCdyHO1pc,44406
         | 
| 12 12 | 
             
            clarifai/client/lister.py,sha256=03KGMvs5RVyYqxLsSrWhNc34I8kiF1Ph0NeyEwu7nMU,2082
         | 
| 13 | 
            -
            clarifai/client/model.py,sha256= | 
| 13 | 
            +
            clarifai/client/model.py,sha256=jNTyCxrME4vbU3Qw9nMSFwg4Ud8AuJrsmVYOMLBtbLI,83846
         | 
| 14 14 | 
             
            clarifai/client/module.py,sha256=FTkm8s9m-EaTKN7g9MnLhGJ9eETUfKG7aWZ3o1RshYs,4204
         | 
| 15 | 
            -
            clarifai/client/nodepool.py,sha256= | 
| 15 | 
            +
            clarifai/client/nodepool.py,sha256=DK8oqswjjrP6TqCqbw7Ge51Z7PxK3XmWZGLeUM3fd_A,10142
         | 
| 16 16 | 
             
            clarifai/client/search.py,sha256=GaPWN6JmTQGZaCHr6U1yv0zqR6wKFl7i9IVLg2ul1CI,14254
         | 
| 17 17 | 
             
            clarifai/client/user.py,sha256=0tcOk8_Yd1_ANj9E6sy9mz6s01V3qkmJS7pZVn_zUYo,17637
         | 
| 18 18 | 
             
            clarifai/client/workflow.py,sha256=Wm4Fry6lGx8T43sBUqRI7v7sAmuvq_4Jft3vSW8UUJU,10516
         | 
| @@ -41,7 +41,7 @@ clarifai/datasets/upload/loaders/README.md,sha256=aNRutSCTzLp2ruIZx74ZkN5AxpzwKO | |
| 41 41 | 
             
            clarifai/datasets/upload/loaders/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 42 42 | 
             
            clarifai/datasets/upload/loaders/coco_captions.py,sha256=YfuNXplbdoH8N9ph7RyN9MfJTtOcJBG4ie1ow6-mELA,1516
         | 
| 43 43 | 
             
            clarifai/datasets/upload/loaders/coco_detection.py,sha256=_I_yThw435KS9SH7zheBbJDK3zFgjTImBsES__ijjMk,2831
         | 
| 44 | 
            -
            clarifai/datasets/upload/loaders/imagenet_classification.py,sha256= | 
| 44 | 
            +
            clarifai/datasets/upload/loaders/imagenet_classification.py,sha256=i7W5F6FTB3LwLmhPgjZHmbCbS3l4LmjsuBFKtjxl1pU,1962
         | 
| 45 45 | 
             
            clarifai/datasets/upload/loaders/xview_detection.py,sha256=hk8cZdYZimm4KOaZvBjYcC6ikURZMn51xmn7pXZT3HE,6052
         | 
| 46 46 | 
             
            clarifai/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 47 47 | 
             
            clarifai/models/api.py,sha256=d3FQQlG0mNDLrfEvchqaVcq4Tgb_TqryNnJtwp3c7sE,10961
         | 
| @@ -55,14 +55,14 @@ clarifai/rag/rag.py,sha256=L10TcV9E0PF1aJ2Nn1z1x6WVoUoGxbKt20lQXg8ksqo,12594 | |
| 55 55 | 
             
            clarifai/rag/utils.py,sha256=yr1jAcbpws4vFGBqlAwPPE7v1DRba48g8gixLFw8OhQ,4070
         | 
| 56 56 | 
             
            clarifai/runners/__init__.py,sha256=3vr4RVvN1IRy2SxJpyycAAvrUBbH-mXR7pqUmu4w36A,412
         | 
| 57 57 | 
             
            clarifai/runners/server.py,sha256=CVLrv2DjzCvKVXcJ4SWvcFWUZq0bdlBmyEpfVlfgT2A,4902
         | 
| 58 | 
            -
            clarifai/runners/dockerfile_template/Dockerfile.cpu.template,sha256= | 
| 59 | 
            -
            clarifai/runners/dockerfile_template/Dockerfile.cuda.template,sha256= | 
| 58 | 
            +
            clarifai/runners/dockerfile_template/Dockerfile.cpu.template,sha256=B35jcpqWBP3ALa2WRtbtBg8uvDyqP_PWZnJtIeAnjT0,1222
         | 
| 59 | 
            +
            clarifai/runners/dockerfile_template/Dockerfile.cuda.template,sha256=TMqTZBN1exMYzjLotn17DO4Je0rg9pBapIuwdohwht8,3228
         | 
| 60 60 | 
             
            clarifai/runners/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 61 61 | 
             
            clarifai/runners/models/base_typed_model.py,sha256=OnAk08Lo2Y1fGiBc6JJ6UvJ8P435cTsikTNYDkStDpI,7790
         | 
| 62 62 | 
             
            clarifai/runners/models/model_class.py,sha256=9JSPAr4U4K7xI0kSl-q0mHB06zknm2OR-8XIgBCto94,1611
         | 
| 63 63 | 
             
            clarifai/runners/models/model_runner.py,sha256=3vzoastQxkGRDK8T9aojDsLNBb9A3IiKm6YmbFrE9S0,6241
         | 
| 64 64 | 
             
            clarifai/runners/models/model_servicer.py,sha256=L5AuqKDZrsKOnv_Fz1Ld4-nzqehltLTsYAS7NIclm1g,2880
         | 
| 65 | 
            -
            clarifai/runners/models/model_upload.py,sha256= | 
| 65 | 
            +
            clarifai/runners/models/model_upload.py,sha256=or1yUlBLOFM9gD3Jjg6Vc9zhpK9uqnRrp4B1bV5VCKM,15985
         | 
| 66 66 | 
             
            clarifai/runners/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 67 67 | 
             
            clarifai/runners/utils/data_handler.py,sha256=sxy9zlAgI6ETuxCQhUgEXAn2GCsaW1GxpK6GTaMne0g,6966
         | 
| 68 68 | 
             
            clarifai/runners/utils/data_utils.py,sha256=R1iQ82TuQ9JwxCJk8yEB1Lyb0BYVhVbWJI9YDi1zGOs,318
         | 
| @@ -76,16 +76,16 @@ clarifai/utils/logging.py,sha256=rhutBRQJLtkNRz8IErNCgbIpvtl2fQ3D2otYcGqd3-Q,115 | |
| 76 76 | 
             
            clarifai/utils/misc.py,sha256=ptjt1NtteDT0EhrPoyQ7mgWtvoAQ-XNncQaZvNHb0KI,2253
         | 
| 77 77 | 
             
            clarifai/utils/model_train.py,sha256=Mndqy5GNu7kjQHjDyNVyamL0hQFLGSHcWhOuPyOvr1w,8005
         | 
| 78 78 | 
             
            clarifai/utils/evaluation/__init__.py,sha256=PYkurUrXrGevByj7RFb6CoU1iC7fllyQSfnnlo9WnY8,69
         | 
| 79 | 
            -
            clarifai/utils/evaluation/helpers.py,sha256= | 
| 80 | 
            -
            clarifai/utils/evaluation/main.py,sha256= | 
| 79 | 
            +
            clarifai/utils/evaluation/helpers.py,sha256=aZeHLI7oSmU5YDWQp5GdkYW5qbHx37nV9xwunKTAwWM,18549
         | 
| 80 | 
            +
            clarifai/utils/evaluation/main.py,sha256=sQAuMk0lPclXCYvy_rS7rYteo2xh9Ju13VNvbyGt_VM,15779
         | 
| 81 81 | 
             
            clarifai/utils/evaluation/testset_annotation_parser.py,sha256=iZfLw6oR1qgJ3MHMbOZXcGBLu7btSDn0VqdiAzpIm4g,5002
         | 
| 82 82 | 
             
            clarifai/workflows/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 83 83 | 
             
            clarifai/workflows/export.py,sha256=vICRhIreqDSShxLKjHNM2JwzKsf1B4fdXB0ciMcA70k,1945
         | 
| 84 84 | 
             
            clarifai/workflows/utils.py,sha256=nGeB_yjVgUO9kOeKTg4OBBaBz-AwXI3m-huSVj-9W18,1924
         | 
| 85 85 | 
             
            clarifai/workflows/validate.py,sha256=yJq03MaJqi5AK3alKGJJBR89xmmjAQ31sVufJUiOqY8,2556
         | 
| 86 | 
            -
            clarifai-10. | 
| 87 | 
            -
            clarifai-10. | 
| 88 | 
            -
            clarifai-10. | 
| 89 | 
            -
            clarifai-10. | 
| 90 | 
            -
            clarifai-10. | 
| 91 | 
            -
            clarifai-10. | 
| 86 | 
            +
            clarifai-10.9.0.dist-info/LICENSE,sha256=mUqF_d12-qE2n41g7C5_sq-BMLOcj6CNN-jevr15YHU,555
         | 
| 87 | 
            +
            clarifai-10.9.0.dist-info/METADATA,sha256=oP8QYgz6MkknQDIQ9ky3lWJS_NwvSVp4nK4o8VIcH20,19479
         | 
| 88 | 
            +
            clarifai-10.9.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
         | 
| 89 | 
            +
            clarifai-10.9.0.dist-info/entry_points.txt,sha256=qZOr_MIPG0dBBE1zringDJS_wXNGTAA_SQ-zcbmDHOw,82
         | 
| 90 | 
            +
            clarifai-10.9.0.dist-info/top_level.txt,sha256=wUMdCQGjkxaynZ6nZ9FAnvBUCgp5RJUVFSy2j-KYo0s,9
         | 
| 91 | 
            +
            clarifai-10.9.0.dist-info/RECORD,,
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         | 
| 
            File without changes
         |