clarifai-11.1.7rc2-py3-none-any.whl → clarifai-11.2.0-py3-none-any.whl

This diff shows the contents of publicly released package versions as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (155)
  1. clarifai/__init__.py +1 -1
  2. clarifai/cli/base.py +16 -3
  3. clarifai/cli/model.py +0 -25
  4. clarifai/client/model.py +393 -157
  5. clarifai/runners/__init__.py +7 -2
  6. clarifai/runners/dockerfile_template/Dockerfile.template +0 -3
  7. clarifai/runners/models/model_builder.py +11 -38
  8. clarifai/runners/models/model_class.py +28 -262
  9. clarifai/runners/models/model_run_locally.py +80 -15
  10. clarifai/runners/models/model_runner.py +0 -2
  11. clarifai/runners/models/model_servicer.py +2 -11
  12. clarifai/runners/utils/data_handler.py +210 -271
  13. {clarifai-11.1.7rc2.dist-info → clarifai-11.2.0.dist-info}/METADATA +17 -5
  14. clarifai-11.2.0.dist-info/RECORD +101 -0
  15. {clarifai-11.1.7rc2.dist-info → clarifai-11.2.0.dist-info}/WHEEL +1 -1
  16. clarifai/__pycache__/__init__.cpython-310.pyc +0 -0
  17. clarifai/__pycache__/__init__.cpython-39.pyc +0 -0
  18. clarifai/__pycache__/errors.cpython-310.pyc +0 -0
  19. clarifai/__pycache__/versions.cpython-310.pyc +0 -0
  20. clarifai/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  21. clarifai/cli/__pycache__/base.cpython-310.pyc +0 -0
  22. clarifai/cli/__pycache__/base_cli.cpython-310.pyc +0 -0
  23. clarifai/cli/__pycache__/compute_cluster.cpython-310.pyc +0 -0
  24. clarifai/cli/__pycache__/deployment.cpython-310.pyc +0 -0
  25. clarifai/cli/__pycache__/model.cpython-310.pyc +0 -0
  26. clarifai/cli/__pycache__/model_cli.cpython-310.pyc +0 -0
  27. clarifai/cli/__pycache__/nodepool.cpython-310.pyc +0 -0
  28. clarifai/client/__pycache__/__init__.cpython-310.pyc +0 -0
  29. clarifai/client/__pycache__/__init__.cpython-39.pyc +0 -0
  30. clarifai/client/__pycache__/app.cpython-310.pyc +0 -0
  31. clarifai/client/__pycache__/app.cpython-39.pyc +0 -0
  32. clarifai/client/__pycache__/base.cpython-310.pyc +0 -0
  33. clarifai/client/__pycache__/compute_cluster.cpython-310.pyc +0 -0
  34. clarifai/client/__pycache__/dataset.cpython-310.pyc +0 -0
  35. clarifai/client/__pycache__/deployment.cpython-310.pyc +0 -0
  36. clarifai/client/__pycache__/input.cpython-310.pyc +0 -0
  37. clarifai/client/__pycache__/lister.cpython-310.pyc +0 -0
  38. clarifai/client/__pycache__/model.cpython-310.pyc +0 -0
  39. clarifai/client/__pycache__/module.cpython-310.pyc +0 -0
  40. clarifai/client/__pycache__/nodepool.cpython-310.pyc +0 -0
  41. clarifai/client/__pycache__/search.cpython-310.pyc +0 -0
  42. clarifai/client/__pycache__/user.cpython-310.pyc +0 -0
  43. clarifai/client/__pycache__/workflow.cpython-310.pyc +0 -0
  44. clarifai/client/auth/__pycache__/__init__.cpython-310.pyc +0 -0
  45. clarifai/client/auth/__pycache__/helper.cpython-310.pyc +0 -0
  46. clarifai/client/auth/__pycache__/register.cpython-310.pyc +0 -0
  47. clarifai/client/auth/__pycache__/stub.cpython-310.pyc +0 -0
  48. clarifai/client/cli/__init__.py +0 -0
  49. clarifai/client/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  50. clarifai/client/cli/__pycache__/base_cli.cpython-310.pyc +0 -0
  51. clarifai/client/cli/__pycache__/model_cli.cpython-310.pyc +0 -0
  52. clarifai/client/cli/base_cli.py +0 -88
  53. clarifai/client/cli/model_cli.py +0 -29
  54. clarifai/client/model_client.py +0 -447
  55. clarifai/constants/__pycache__/base.cpython-310.pyc +0 -0
  56. clarifai/constants/__pycache__/dataset.cpython-310.pyc +0 -0
  57. clarifai/constants/__pycache__/input.cpython-310.pyc +0 -0
  58. clarifai/constants/__pycache__/model.cpython-310.pyc +0 -0
  59. clarifai/constants/__pycache__/rag.cpython-310.pyc +0 -0
  60. clarifai/constants/__pycache__/search.cpython-310.pyc +0 -0
  61. clarifai/constants/__pycache__/workflow.cpython-310.pyc +0 -0
  62. clarifai/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  63. clarifai/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
  64. clarifai/datasets/export/__pycache__/__init__.cpython-310.pyc +0 -0
  65. clarifai/datasets/export/__pycache__/__init__.cpython-39.pyc +0 -0
  66. clarifai/datasets/export/__pycache__/inputs_annotations.cpython-310.pyc +0 -0
  67. clarifai/datasets/upload/__pycache__/__init__.cpython-310.pyc +0 -0
  68. clarifai/datasets/upload/__pycache__/__init__.cpython-39.pyc +0 -0
  69. clarifai/datasets/upload/__pycache__/base.cpython-310.pyc +0 -0
  70. clarifai/datasets/upload/__pycache__/features.cpython-310.pyc +0 -0
  71. clarifai/datasets/upload/__pycache__/image.cpython-310.pyc +0 -0
  72. clarifai/datasets/upload/__pycache__/multimodal.cpython-310.pyc +0 -0
  73. clarifai/datasets/upload/__pycache__/text.cpython-310.pyc +0 -0
  74. clarifai/datasets/upload/__pycache__/utils.cpython-310.pyc +0 -0
  75. clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-39.pyc +0 -0
  76. clarifai/models/__pycache__/__init__.cpython-39.pyc +0 -0
  77. clarifai/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  78. clarifai/rag/__pycache__/__init__.cpython-310.pyc +0 -0
  79. clarifai/rag/__pycache__/__init__.cpython-39.pyc +0 -0
  80. clarifai/rag/__pycache__/rag.cpython-310.pyc +0 -0
  81. clarifai/rag/__pycache__/rag.cpython-39.pyc +0 -0
  82. clarifai/rag/__pycache__/utils.cpython-310.pyc +0 -0
  83. clarifai/runners/__pycache__/__init__.cpython-310.pyc +0 -0
  84. clarifai/runners/__pycache__/__init__.cpython-39.pyc +0 -0
  85. clarifai/runners/dockerfile_template/Dockerfile.cpu.template +0 -31
  86. clarifai/runners/dockerfile_template/Dockerfile.cuda.template +0 -42
  87. clarifai/runners/dockerfile_template/Dockerfile.nim +0 -71
  88. clarifai/runners/models/__pycache__/__init__.cpython-310.pyc +0 -0
  89. clarifai/runners/models/__pycache__/__init__.cpython-39.pyc +0 -0
  90. clarifai/runners/models/__pycache__/base_typed_model.cpython-310.pyc +0 -0
  91. clarifai/runners/models/__pycache__/base_typed_model.cpython-39.pyc +0 -0
  92. clarifai/runners/models/__pycache__/model_class.cpython-310.pyc +0 -0
  93. clarifai/runners/models/__pycache__/model_run_locally.cpython-310-pytest-7.1.2.pyc +0 -0
  94. clarifai/runners/models/__pycache__/model_run_locally.cpython-310.pyc +0 -0
  95. clarifai/runners/models/__pycache__/model_runner.cpython-310.pyc +0 -0
  96. clarifai/runners/models/__pycache__/model_upload.cpython-310.pyc +0 -0
  97. clarifai/runners/models/model_class_refract.py +0 -80
  98. clarifai/runners/models/model_upload.py +0 -607
  99. clarifai/runners/models/temp.py +0 -25
  100. clarifai/runners/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  101. clarifai/runners/utils/__pycache__/__init__.cpython-38.pyc +0 -0
  102. clarifai/runners/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  103. clarifai/runners/utils/__pycache__/buffered_stream.cpython-310.pyc +0 -0
  104. clarifai/runners/utils/__pycache__/buffered_stream.cpython-38.pyc +0 -0
  105. clarifai/runners/utils/__pycache__/buffered_stream.cpython-39.pyc +0 -0
  106. clarifai/runners/utils/__pycache__/const.cpython-310.pyc +0 -0
  107. clarifai/runners/utils/__pycache__/constants.cpython-310.pyc +0 -0
  108. clarifai/runners/utils/__pycache__/constants.cpython-38.pyc +0 -0
  109. clarifai/runners/utils/__pycache__/constants.cpython-39.pyc +0 -0
  110. clarifai/runners/utils/__pycache__/data_handler.cpython-310.pyc +0 -0
  111. clarifai/runners/utils/__pycache__/data_handler.cpython-38.pyc +0 -0
  112. clarifai/runners/utils/__pycache__/data_handler.cpython-39.pyc +0 -0
  113. clarifai/runners/utils/__pycache__/data_utils.cpython-310.pyc +0 -0
  114. clarifai/runners/utils/__pycache__/data_utils.cpython-38.pyc +0 -0
  115. clarifai/runners/utils/__pycache__/data_utils.cpython-39.pyc +0 -0
  116. clarifai/runners/utils/__pycache__/grpc_server.cpython-310.pyc +0 -0
  117. clarifai/runners/utils/__pycache__/grpc_server.cpython-38.pyc +0 -0
  118. clarifai/runners/utils/__pycache__/grpc_server.cpython-39.pyc +0 -0
  119. clarifai/runners/utils/__pycache__/health.cpython-310.pyc +0 -0
  120. clarifai/runners/utils/__pycache__/health.cpython-38.pyc +0 -0
  121. clarifai/runners/utils/__pycache__/health.cpython-39.pyc +0 -0
  122. clarifai/runners/utils/__pycache__/loader.cpython-310.pyc +0 -0
  123. clarifai/runners/utils/__pycache__/logging.cpython-310.pyc +0 -0
  124. clarifai/runners/utils/__pycache__/logging.cpython-38.pyc +0 -0
  125. clarifai/runners/utils/__pycache__/logging.cpython-39.pyc +0 -0
  126. clarifai/runners/utils/__pycache__/stream_source.cpython-310.pyc +0 -0
  127. clarifai/runners/utils/__pycache__/stream_source.cpython-39.pyc +0 -0
  128. clarifai/runners/utils/__pycache__/url_fetcher.cpython-310.pyc +0 -0
  129. clarifai/runners/utils/__pycache__/url_fetcher.cpython-38.pyc +0 -0
  130. clarifai/runners/utils/__pycache__/url_fetcher.cpython-39.pyc +0 -0
  131. clarifai/runners/utils/data_handler_refract.py +0 -213
  132. clarifai/runners/utils/data_types.py +0 -427
  133. clarifai/runners/utils/logger.py +0 -0
  134. clarifai/runners/utils/method_signatures.py +0 -472
  135. clarifai/runners/utils/serializers.py +0 -222
  136. clarifai/schema/__pycache__/search.cpython-310.pyc +0 -0
  137. clarifai/urls/__pycache__/helper.cpython-310.pyc +0 -0
  138. clarifai/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  139. clarifai/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  140. clarifai/utils/__pycache__/cli.cpython-310.pyc +0 -0
  141. clarifai/utils/__pycache__/constants.cpython-310.pyc +0 -0
  142. clarifai/utils/__pycache__/logging.cpython-310.pyc +0 -0
  143. clarifai/utils/__pycache__/misc.cpython-310.pyc +0 -0
  144. clarifai/utils/__pycache__/model_train.cpython-310.pyc +0 -0
  145. clarifai/utils/evaluation/__pycache__/__init__.cpython-39.pyc +0 -0
  146. clarifai/utils/evaluation/__pycache__/main.cpython-39.pyc +0 -0
  147. clarifai/workflows/__pycache__/__init__.cpython-310.pyc +0 -0
  148. clarifai/workflows/__pycache__/__init__.cpython-39.pyc +0 -0
  149. clarifai/workflows/__pycache__/export.cpython-310.pyc +0 -0
  150. clarifai/workflows/__pycache__/utils.cpython-310.pyc +0 -0
  151. clarifai/workflows/__pycache__/validate.cpython-310.pyc +0 -0
  152. clarifai-11.1.7rc2.dist-info/RECORD +0 -237
  153. {clarifai-11.1.7rc2.dist-info → clarifai-11.2.0.dist-info}/entry_points.txt +0 -0
  154. {clarifai-11.1.7rc2.dist-info → clarifai-11.2.0.dist-info/licenses}/LICENSE +0 -0
  155. {clarifai-11.1.7rc2.dist-info → clarifai-11.2.0.dist-info}/top_level.txt +0 -0
clarifai/client/cli/base_cli.py
@@ -1,88 +0,0 @@
-import click
-import os
-import yaml
-
-@click.group()
-@click.pass_context
-def cli(ctx):
-  """Clarifai CLI"""
-  ctx.ensure_object(dict)
-  config_path = 'config.yaml'
-  if os.path.exists(config_path):
-    ctx.obj = _from_yaml(config_path)
-    print("Loaded config from file.")
-    print(f"Config: {ctx.obj}")
-  else:
-    ctx.obj = {}
-
-def _from_yaml(filename: str):
-  try:
-    with open(filename, 'r') as f:
-      return yaml.safe_load(f)
-  except yaml.YAMLError as e:
-    click.echo(f"Error reading YAML file: {e}", err=True)
-    return {}
-
-def _dump_yaml(data, filename: str):
-  try:
-    with open(filename, 'w') as f:
-      yaml.dump(data, f)
-  except Exception as e:
-    click.echo(f"Error writing YAML file: {e}", err=True)
-
-def _set_base_url(env):
-  environments = {'prod': 'https://api.clarifai.com', 'staging': 'https://api-staging.clarifai.com', 'dev': 'https://api-dev.clarifai.com'}
-  return environments.get(env, 'https://api.clarifai.com')
-
-
-@cli.command()
-@click.option('--config', type=click.Path(), required=False, help='Path to the config file')
-@click.option('-e', '--env', required=False, help='Environment', type=click.Choice(['prod', 'staging', 'dev']))
-@click.option('--user_id', required=False, help='User ID')
-@click.pass_context
-def login(ctx, config, env, user_id):
-  """Login command to set PAT and other configurations."""
-
-  if config and os.path.exists(config):
-    ctx.obj = _from_yaml(config)
-
-  if 'pat' in ctx.obj:
-    os.environ["CLARIFAI_PAT"] = ctx.obj['pat']
-    click.echo("Loaded PAT from config file.")
-  elif 'CLARIFAI_PAT' in os.environ:
-    ctx.obj['pat'] = os.environ["CLARIFAI_PAT"]
-    click.echo("Loaded PAT from environment variable.")
-  else:
-    _pat = click.prompt("Get your PAT from https://clarifai.com/settings/security and pass it here", type=str)
-    os.environ["CLARIFAI_PAT"] = _pat
-    ctx.obj['pat'] = _pat
-    click.echo("PAT saved successfully.")
-
-  if user_id:
-    ctx.obj['user_id'] = user_id
-    os.environ["CLARIFAI_USER_ID"] = ctx.obj['user_id']
-  elif 'user_id' in ctx.obj or 'CLARIFAI_USER_ID' in os.environ:
-    ctx.obj['user_id'] = ctx.obj.get('user_id', os.environ["CLARIFAI_USER_ID"])
-    os.environ["CLARIFAI_USER_ID"] = ctx.obj['user_id']
-
-  if env:
-    ctx.obj['env'] = env
-    ctx.obj['base_url'] = _set_base_url(env)
-    os.environ["CLARIFAI_API_BASE"] = ctx.obj['base_url']
-  elif 'env' in ctx.obj:
-    ctx.obj['env'] = ctx.obj.get('env', "prod")
-    ctx.obj['base_url'] = _set_base_url(ctx.obj['env'])
-    os.environ["CLARIFAI_API_BASE"] = ctx.obj['base_url']
-  elif 'CLARIFAI_API_BASE' in os.environ:
-    ctx.obj['base_url'] = os.environ["CLARIFAI_API_BASE"]
-
-  _dump_yaml(ctx.obj, 'config.yaml')
-
-  click.echo("Login successful.")
-
-# Import the model CLI commands to register them
-from clarifai.client.cli.model_cli import model  # Ensure this is the correct import path
-
-
-if __name__ == '__main__':
-  cli()
clarifai/client/cli/model_cli.py
@@ -1,29 +0,0 @@
-import click
-from clarifai.client.cli.base_cli import cli
-
-@cli.group()
-def model():
-  """Manage models: upload, test locally"""
-  pass
-
-@model.command()
-@click.argument('model_path', type=click.Path(exists=True))
-@click.option('--download_checkpoints', is_flag=True, help='Flag to download checkpoints before uploading and including them in the tar file that is uploaded. Defaults to False, which will attempt to download them at docker build time.')
-@click.option('--skip_dockerfile', is_flag=True, help='Flag to skip generating a dockerfile so that you can manually edit an already created dockerfile.')
-def upload(model_path, download_checkpoints, skip_dockerfile):
-  """Upload a model to Clarifai."""
-  from clarifai.runners.models import model_upload
-
-  model_upload.main(model_path, download_checkpoints, skip_dockerfile)
-
-@model.command()
-@click.argument('model_path', type=click.Path(exists=True))
-def test_locally(model_path):
-  """Test model locally."""
-  try:
-    from clarifai.runners.models import run_test_locally
-    run_test_locally.main(model_path)
-    click.echo(f"Model tested locally from {model_path}.")
-  except Exception as e:
-    click.echo(f"Failed to test model locally: {e}", err=True)
-
clarifai/client/model_client.py
@@ -1,447 +0,0 @@
-import time
-from typing import Any, Dict, Iterator, List
-
-from clarifai_grpc.grpc.api import resources_pb2, service_pb2
-from clarifai_grpc.grpc.api.status import status_code_pb2
-
-from clarifai.constants.model import MAX_MODEL_PREDICT_INPUTS
-from clarifai.errors import UserError
-from clarifai.runners.utils.method_signatures import (CompatibilitySerializer, deserialize,
-                                                      get_stream_from_signature, serialize,
-                                                      signatures_from_json)
-from clarifai.utils.logging import logger
-from clarifai.utils.misc import BackoffIterator, status_is_retryable
-
-
-class ModelClient:
-  '''
-  Client for calling model predict, generate, and stream methods.
-  '''
-
-  def __init__(self, stub, request_template: service_pb2.PostModelOutputsRequest = None):
-    '''
-    Initialize the model client.
-
-    Args:
-      stub: The gRPC stub for the model.
-      request_template: The template for the request to send to the model, including
-        common fields like model_id, model_version, cluster, etc.
-    '''
-    self.STUB = stub
-    self.request_template = request_template or service_pb2.PostModelOutputsRequest()
-    self._method_signatures = None
-    self._defined = False
-
-  def fetch(self):
-    '''
-    Fetch function signature definitions from the model and define the functions in the client
-    '''
-    if self._defined:
-      return
-    try:
-      self._fetch_signatures()
-      self._define_functions()
-    finally:
-      self._defined = True
-
-  def __getattr__(self, name):
-    if not self._defined:
-      self.fetch()
-    return self.__getattribute__(name)
-
-  def _fetch_signatures(self):
-    '''
-    Fetch the method signatures from the model.
-
-    Returns:
-      Dict: The method signatures.
-    '''
-    #request = resources_pb2.GetModelSignaturesRequest()
-    #response = self.stub.GetModelSignatures(request)
-    #self._method_signatures = json.loads(response.signatures)  # or define protos
-    # TODO this could use a new endpoint to get the signatures
-    # for local grpc models, we'll also have to add the endpoint to the model servicer
-    # for now we'll just use the predict endpoint with a special method name
-
-    request = service_pb2.PostModelOutputsRequest()
-    request.CopyFrom(self.request_template)
-    # request.model.model_version.output_info.params['_method_name'] = '_GET_SIGNATURES'
-    inp = request.inputs.add()  # empty input for this method
-    inp.data.parts.add()  # empty part for this input
-    inp.data.metadata['_method_name'] = '_GET_SIGNATURES'
-    start_time = time.time()
-    backoff_iterator = BackoffIterator(10)
-    while True:
-      response = self.STUB.PostModelOutputs(request)
-      if status_is_retryable(
-          response.status.code) and time.time() - start_time < 60 * 10:  # 10 minutes
-        logger.info(f"Retrying model info fetch with response {response.status!r}")
-        time.sleep(next(backoff_iterator))
-        continue
-      break
-    if (response.status.code == status_code_pb2.INPUT_UNSUPPORTED_FORMAT or
-        (response.status.code == status_code_pb2.SUCCESS and
-         response.outputs[0].data.text.raw == '')):
-      # return codes/values from older models that don't support _GET_SIGNATURES
-      self._method_signatures = {}
-      self._define_compatability_functions()
-      return
-    if response.status.code != status_code_pb2.SUCCESS:
-      raise Exception(f"Model failed with response {response!r}")
-    self._method_signatures = signatures_from_json(response.outputs[0].data.text.raw)
-
-  def _define_functions(self):
-    '''
-    Define the functions based on the method signatures.
-    '''
-    for method_name, method_signature in self._method_signatures.items():
-      # define the function in this client instance
-      if resources_pb2.RunnerMethodType.Name(method_signature.method_type) == 'UNARY_UNARY':
-        call_func = self._predict
-      elif resources_pb2.RunnerMethodType.Name(method_signature.method_type) == 'UNARY_STREAMING':
-        call_func = self._generate
-      elif resources_pb2.RunnerMethodType.Name(
-          method_signature.method_type) == 'STREAMING_STREAMING':
-        call_func = self._stream
-      else:
-        raise ValueError(f"Unknown method type {method_signature.method_type}")
-
-      # method argnames, in order, collapsing nested keys to corresponding user function args
-      method_argnames = []
-      for var in method_signature.input_fields:
-        outer = var.name.split('.', 1)[0]
-        if outer in method_argnames:
-          continue
-        method_argnames.append(outer)
-
-      def bind_f(method_name, method_argnames, call_func):
-
-        def f(*args, **kwargs):
-          if len(args) > len(method_argnames):
-            raise TypeError(
-                f"{method_name}() takes {len(method_argnames)} positional arguments but {len(args)} were given"
-            )
-          for name, arg in zip(method_argnames, args):  # handle positional with zip shortest
-            if name in kwargs:
-              raise TypeError(f"Multiple values for argument {name}")
-            kwargs[name] = arg
-          return call_func(kwargs, method_name)

-        return f
-
-      # need to bind method_name to the value, not the mutating loop variable
-      f = bind_f(method_name, method_argnames, call_func)
-
-      # set names, annotations and docstrings
-      f.__name__ = method_name
-      f.__qualname__ = f'{self.__class__.__name__}.{method_name}'
-      # input_annotations = json.loads(method_signature.annotations_json)
-      # return_annotation = input_annotations.pop('return', None)
-      # sig = inspect.signature(f).replace(
-      #     parameters=[
-      #         inspect.Parameter(k, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=v)
-      #         for k, v in input_annotations.items()
-      #     ],
-      #     return_annotation=return_annotation,
-      # )
-      # f.__signature__ = sig
-      f.__doc__ = method_signature.description
-      setattr(self, method_name, f)
-
-  def _define_compatability_functions(self):
-
-    serializer = CompatibilitySerializer()
-
-    def predict(input: Any) -> Any:
-      proto = resources_pb2.Input()
-      serializer.serialize(proto.data, input)
-      # always use text.raw for compat
-      if proto.data.string_value:
-        proto.data.text.raw = proto.data.string_value
-        proto.data.string_value = ''
-      response = self._predict_by_proto([proto])
-      if response.status.code != status_code_pb2.SUCCESS:
-        raise Exception(f"Model predict failed with response {response!r}")
-      response_data = response.outputs[0].data
-      if response_data.text.raw:
-        response_data.string_value = response_data.text.raw
-        response_data.text.raw = ''
-      return serializer.deserialize(response_data)
-
-    self.predict = predict
-
-  def _predict(
-      self,
-      inputs,  # TODO set up functions according to fetched signatures?
-      method_name: str = 'predict',
-  ) -> Any:
-    input_signature = self._method_signatures[method_name].input_fields
-    output_signature = self._method_signatures[method_name].output_fields
-
-    batch_input = True
-    if isinstance(inputs, dict):
-      inputs = [inputs]
-      batch_input = False
-
-    proto_inputs = []
-    for input in inputs:
-      proto = resources_pb2.Input()
-      serialize(input, input_signature, proto.data)
-      proto_inputs.append(proto)
-
-    response = self._predict_by_proto(proto_inputs, method_name)
-    #print(response)
-
-    outputs = []
-    for output in response.outputs:
-      outputs.append(deserialize(output.data, output_signature, is_output=True))
-    if batch_input:
-      return outputs
-    return outputs[0]
-
-  def _predict_by_proto(
-      self,
-      inputs: List[resources_pb2.Input],
-      method_name: str = None,
-      inference_params: Dict = None,
-      output_config: Dict = None,
-  ) -> service_pb2.MultiOutputResponse:
-    """Predicts the model based on the given inputs.
-
-    Args:
-      inputs (List[resources_pb2.Input]): The inputs to predict.
-      method_name (str): The remote method name to call.
-      inference_params (Dict): Inference parameters to override.
-      output_config (Dict): Output configuration to override.
-
-    Returns:
-      service_pb2.MultiOutputResponse: The prediction response(s).
-    """
-    if not isinstance(inputs, list):
-      raise UserError('Invalid inputs, inputs must be a list of Input objects.')
-    if len(inputs) > MAX_MODEL_PREDICT_INPUTS:
-      raise UserError(f"Too many inputs. Max is {MAX_MODEL_PREDICT_INPUTS}.")
-
-    request = service_pb2.PostModelOutputsRequest()
-    request.CopyFrom(self.request_template)
-
-    request.inputs.extend(inputs)
-
-    if method_name:
-      # TODO put in new proto field?
-      for inp in request.inputs:
-        inp.data.metadata['_method_name'] = method_name
-    if inference_params:
-      request.model.model_version.output_info.params.update(inference_params)
-    if output_config:
-      request.model.model_version.output_info.output_config.MergeFrom(
-          resources_pb2.OutputConfig(**output_config))
-
-    start_time = time.time()
-    backoff_iterator = BackoffIterator(10)
-    while True:
-      response = self.STUB.PostModelOutputs(request)
-      if status_is_retryable(
-          response.status.code) and time.time() - start_time < 60 * 10:  # 10 minutes
-        logger.info(f"Model predict failed with response {response!r}")
-        time.sleep(next(backoff_iterator))
-        continue
-
-      if response.status.code != status_code_pb2.SUCCESS:
-        raise Exception(f"Model predict failed with response {response!r}")
-      break
-
-    return response
-
-  def _generate(
-      self,
-      inputs,  # TODO set up functions according to fetched signatures?
-      method_name: str = 'generate',
-  ) -> Any:
-    input_signature = self._method_signatures[method_name].input_fields
-    output_signature = self._method_signatures[method_name].output_fields
-
-    batch_input = True
-    if isinstance(inputs, dict):
-      inputs = [inputs]
-      batch_input = False
-
-    proto_inputs = []
-    for input in inputs:
-      proto = resources_pb2.Input()
-      serialize(input, input_signature, proto.data)
-      proto_inputs.append(proto)
-
-    response_stream = self._generate_by_proto(proto_inputs, method_name)
-    #print(response)
-
-    for response in response_stream:
-      outputs = []
-      for output in response.outputs:
-        outputs.append(deserialize(output.data, output_signature, is_output=True))
-      if batch_input:
-        yield outputs
-      yield outputs[0]
-
-  def _generate_by_proto(
-      self,
-      inputs: List[resources_pb2.Input],
-      method_name: str = None,
-      inference_params: Dict = {},
-      output_config: Dict = {},
-  ):
-    """Generate the stream output on model based on the given inputs.
-
-    Args:
-      inputs (list[Input]): The inputs to generate, must be less than 128.
-      method_name (str): The remote method name to call.
-      inference_params (dict): The inference params to override.
-      output_config (dict): The output config to override.
-    """
-    if not isinstance(inputs, list):
-      raise UserError('Invalid inputs, inputs must be a list of Input objects.')
-    if len(inputs) > MAX_MODEL_PREDICT_INPUTS:
-      raise UserError(f"Too many inputs. Max is {MAX_MODEL_PREDICT_INPUTS}."
-                     )  # TODO Use Chunker for inputs len > 128
-
-    request = service_pb2.PostModelOutputsRequest()
-    request.CopyFrom(self.request_template)
-
-    request.inputs.extend(inputs)
-
-    if method_name:
-      # TODO put in new proto field?
-      for inp in request.inputs:
-        inp.data.metadata['_method_name'] = method_name
-    if inference_params:
-      request.model.model_version.output_info.params.update(inference_params)
-    if output_config:
-      request.model.model_version.output_info.output_config.MergeFromDict(output_config)
-
-    start_time = time.time()
-    backoff_iterator = BackoffIterator(10)
-    started = False
-    while not started:
-      stream_response = self.STUB.GenerateModelOutputs(request)
-      try:
-        response = next(stream_response)  # get the first response
-      except StopIteration:
-        raise Exception("Model Generate failed with no response")
-      if status_is_retryable(response.status.code) and \
-          time.time() - start_time < 60 * 10:
-        logger.info("Model is still deploying, please wait...")
-        time.sleep(next(backoff_iterator))
-        continue
-      if response.status.code != status_code_pb2.SUCCESS:
-        raise Exception(f"Model Generate failed with response {response.status!r}")
-      started = True
-
-    yield response  # yield the first response
-
-    for response in stream_response:
-      if response.status.code != status_code_pb2.SUCCESS:
-        raise Exception(f"Model Generate failed with response {response.status!r}")
-      yield response
-
-  def _stream(
-      self,
-      inputs,
-      method_name: str = 'stream',
-  ) -> Any:
-    input_signature = self._method_signatures[method_name].input_fields
-    output_signature = self._method_signatures[method_name].output_fields
-
-    if isinstance(inputs, list):
-      assert len(inputs) == 1, 'streaming methods do not support batched calls'
-      inputs = inputs[0]
-    assert isinstance(inputs, dict)
-    kwargs = inputs
-
-    # find the streaming vars in the input signature, and the streaming input python param
-    stream_sig = get_stream_from_signature(input_signature)
-    if stream_sig is None:
-      raise ValueError("Streaming method must have a Stream input")
-    stream_argname = stream_sig.name
-
-    # get the streaming input generator from the user-provided function arg values
-    user_inputs_generator = kwargs.pop(stream_argname)
-
-    def _input_proto_stream():
-      # first item contains all the inputs and the first stream item
-      proto = resources_pb2.Input()
-      try:
-        item = next(user_inputs_generator)
-      except StopIteration:
-        return  # no items to stream
-      kwargs[stream_argname] = item
-      serialize(kwargs, input_signature, proto.data)
-
-      yield proto
-
-      # subsequent items are just the stream items
-      for item in user_inputs_generator:
-        proto = resources_pb2.Input()
-        serialize({stream_argname: item}, [stream_sig], proto.data)
-        yield proto
-
-    response_stream = self._stream_by_proto(_input_proto_stream(), method_name)
-    #print(response)
-
-    for response in response_stream:
-      assert len(response.outputs) == 1, 'streaming methods must have exactly one output'
-      yield deserialize(response.outputs[0].data, output_signature, is_output=True)
-
-  def _req_iterator(self,
-                    input_iterator: Iterator[List[resources_pb2.Input]],
-                    method_name: str = None,
-                    inference_params: Dict = {},
-                    output_config: Dict = {}):
-    request = service_pb2.PostModelOutputsRequest()
-    request.CopyFrom(self.request_template)
-    if inference_params:
-      request.model.model_version.output_info.params.update(inference_params)
-    if output_config:
-      request.model.model_version.output_info.output_config.MergeFromDict(output_config)
-    for inputs in input_iterator:
-      req = service_pb2.PostModelOutputsRequest()
-      req.CopyFrom(request)
-      if isinstance(inputs, list):
-        req.inputs.extend(inputs)
-      else:
-        req.inputs.append(inputs)
-      # TODO: put into new proto field?
-      for inp in req.inputs:
-        inp.data.metadata['_method_name'] = method_name
-      yield req
-
-  def _stream_by_proto(self,
-                       inputs: Iterator[List[resources_pb2.Input]],
-                       method_name: str = None,
-                       inference_params: Dict = {},
-                       output_config: Dict = {}):
-    """Generate the stream output on model based on the given stream of inputs.
-    """
-    # if not isinstance(inputs, Iterator[List[Input]]):
-    #   raise UserError('Invalid inputs, inputs must be a iterator of list of Input objects.')
-
-    request = self._req_iterator(inputs, method_name, inference_params, output_config)
-
-    start_time = time.time()
-    backoff_iterator = BackoffIterator(10)
-    generation_started = False
-    while True:
-      if generation_started:
-        break
-      stream_response = self.STUB.StreamModelOutputs(request)
-      for response in stream_response:
-        if status_is_retryable(response.status.code) and \
-            time.time() - start_time < 60 * 10:
-          logger.info("Model is still deploying, please wait...")
-          time.sleep(next(backoff_iterator))
-          break
-        if response.status.code != status_code_pb2.SUCCESS:
-          raise Exception(f"Model Predict failed with response {response.status!r}")
-        else:
-          if not generation_started:
-            generation_started = True
-          yield response