clarifai 11.4.1__py3-none-any.whl → 11.4.3rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (142) hide show
  1. clarifai/__init__.py +1 -1
  2. clarifai/__pycache__/__init__.cpython-312.pyc +0 -0
  3. clarifai/__pycache__/__init__.cpython-39.pyc +0 -0
  4. clarifai/__pycache__/errors.cpython-312.pyc +0 -0
  5. clarifai/__pycache__/errors.cpython-39.pyc +0 -0
  6. clarifai/__pycache__/versions.cpython-312.pyc +0 -0
  7. clarifai/__pycache__/versions.cpython-39.pyc +0 -0
  8. clarifai/cli/__pycache__/__init__.cpython-312.pyc +0 -0
  9. clarifai/cli/__pycache__/base.cpython-312.pyc +0 -0
  10. clarifai/cli/__pycache__/compute_cluster.cpython-312.pyc +0 -0
  11. clarifai/cli/__pycache__/deployment.cpython-312.pyc +0 -0
  12. clarifai/cli/__pycache__/model.cpython-312.pyc +0 -0
  13. clarifai/cli/__pycache__/nodepool.cpython-312.pyc +0 -0
  14. clarifai/cli/base.py +8 -0
  15. clarifai/cli/model.py +6 -6
  16. clarifai/client/__pycache__/__init__.cpython-312.pyc +0 -0
  17. clarifai/client/__pycache__/__init__.cpython-39.pyc +0 -0
  18. clarifai/client/__pycache__/app.cpython-312.pyc +0 -0
  19. clarifai/client/__pycache__/app.cpython-39.pyc +0 -0
  20. clarifai/client/__pycache__/base.cpython-312.pyc +0 -0
  21. clarifai/client/__pycache__/base.cpython-39.pyc +0 -0
  22. clarifai/client/__pycache__/compute_cluster.cpython-312.pyc +0 -0
  23. clarifai/client/__pycache__/dataset.cpython-312.pyc +0 -0
  24. clarifai/client/__pycache__/deployment.cpython-312.pyc +0 -0
  25. clarifai/client/__pycache__/input.cpython-312.pyc +0 -0
  26. clarifai/client/__pycache__/lister.cpython-312.pyc +0 -0
  27. clarifai/client/__pycache__/model.cpython-312.pyc +0 -0
  28. clarifai/client/__pycache__/model_client.cpython-312.pyc +0 -0
  29. clarifai/client/__pycache__/module.cpython-312.pyc +0 -0
  30. clarifai/client/__pycache__/nodepool.cpython-312.pyc +0 -0
  31. clarifai/client/__pycache__/runner.cpython-312.pyc +0 -0
  32. clarifai/client/__pycache__/search.cpython-312.pyc +0 -0
  33. clarifai/client/__pycache__/user.cpython-312.pyc +0 -0
  34. clarifai/client/__pycache__/workflow.cpython-312.pyc +0 -0
  35. clarifai/client/auth/__pycache__/__init__.cpython-312.pyc +0 -0
  36. clarifai/client/auth/__pycache__/__init__.cpython-39.pyc +0 -0
  37. clarifai/client/auth/__pycache__/helper.cpython-312.pyc +0 -0
  38. clarifai/client/auth/__pycache__/helper.cpython-39.pyc +0 -0
  39. clarifai/client/auth/__pycache__/register.cpython-312.pyc +0 -0
  40. clarifai/client/auth/__pycache__/register.cpython-39.pyc +0 -0
  41. clarifai/client/auth/__pycache__/stub.cpython-312.pyc +0 -0
  42. clarifai/client/auth/__pycache__/stub.cpython-39.pyc +0 -0
  43. clarifai/client/dataset.py +6 -0
  44. clarifai/constants/__pycache__/base.cpython-312.pyc +0 -0
  45. clarifai/constants/__pycache__/base.cpython-39.pyc +0 -0
  46. clarifai/constants/__pycache__/dataset.cpython-312.pyc +0 -0
  47. clarifai/constants/__pycache__/input.cpython-312.pyc +0 -0
  48. clarifai/constants/__pycache__/model.cpython-312.pyc +0 -0
  49. clarifai/constants/__pycache__/rag.cpython-312.pyc +0 -0
  50. clarifai/constants/__pycache__/search.cpython-312.pyc +0 -0
  51. clarifai/constants/__pycache__/workflow.cpython-312.pyc +0 -0
  52. clarifai/datasets/__pycache__/__init__.cpython-312.pyc +0 -0
  53. clarifai/datasets/export/__pycache__/__init__.cpython-312.pyc +0 -0
  54. clarifai/datasets/export/__pycache__/inputs_annotations.cpython-312.pyc +0 -0
  55. clarifai/datasets/upload/__pycache__/__init__.cpython-312.pyc +0 -0
  56. clarifai/datasets/upload/__pycache__/base.cpython-312.pyc +0 -0
  57. clarifai/datasets/upload/__pycache__/features.cpython-312.pyc +0 -0
  58. clarifai/datasets/upload/__pycache__/image.cpython-312.pyc +0 -0
  59. clarifai/datasets/upload/__pycache__/multimodal.cpython-312.pyc +0 -0
  60. clarifai/datasets/upload/__pycache__/text.cpython-312.pyc +0 -0
  61. clarifai/datasets/upload/__pycache__/utils.cpython-312.pyc +0 -0
  62. clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-312.pyc +0 -0
  63. clarifai/datasets/upload/loaders/__pycache__/coco_detection.cpython-312.pyc +0 -0
  64. clarifai/datasets/upload/loaders/__pycache__/imagenet_classification.cpython-312.pyc +0 -0
  65. clarifai/modules/__pycache__/__init__.cpython-312.pyc +0 -0
  66. clarifai/modules/__pycache__/css.cpython-312.pyc +0 -0
  67. clarifai/rag/__pycache__/__init__.cpython-312.pyc +0 -0
  68. clarifai/rag/__pycache__/rag.cpython-312.pyc +0 -0
  69. clarifai/rag/__pycache__/utils.cpython-312.pyc +0 -0
  70. clarifai/runners/__pycache__/__init__.cpython-312.pyc +0 -0
  71. clarifai/runners/__pycache__/__init__.cpython-39.pyc +0 -0
  72. clarifai/runners/__pycache__/server.cpython-312.pyc +0 -0
  73. clarifai/runners/models/__pycache__/__init__.cpython-312.pyc +0 -0
  74. clarifai/runners/models/__pycache__/__init__.cpython-39.pyc +0 -0
  75. clarifai/runners/models/__pycache__/base_typed_model.cpython-312.pyc +0 -0
  76. clarifai/runners/models/__pycache__/mcp_class.cpython-312.pyc +0 -0
  77. clarifai/runners/models/__pycache__/model_builder.cpython-312.pyc +0 -0
  78. clarifai/runners/models/__pycache__/model_builder.cpython-39.pyc +0 -0
  79. clarifai/runners/models/__pycache__/model_class.cpython-312.pyc +0 -0
  80. clarifai/runners/models/__pycache__/model_run_locally.cpython-312.pyc +0 -0
  81. clarifai/runners/models/__pycache__/model_runner.cpython-312.pyc +0 -0
  82. clarifai/runners/models/__pycache__/model_servicer.cpython-312.pyc +0 -0
  83. clarifai/runners/models/__pycache__/test_model_builder.cpython-312-pytest-8.3.5.pyc +0 -0
  84. clarifai/runners/models/base_typed_model.py +238 -0
  85. clarifai/runners/models/example_mcp_server.py +44 -0
  86. clarifai/runners/models/mcp_class.py +143 -0
  87. clarifai/runners/models/mcp_class.py~ +149 -0
  88. clarifai/runners/models/model_builder.py +167 -38
  89. clarifai/runners/models/model_class.py +5 -22
  90. clarifai/runners/models/model_run_locally.py +0 -4
  91. clarifai/runners/models/test_model_builder.py +89 -0
  92. clarifai/runners/models/visual_classifier_class.py +75 -0
  93. clarifai/runners/models/visual_detector_class.py +79 -0
  94. clarifai/runners/utils/__pycache__/__init__.cpython-312.pyc +0 -0
  95. clarifai/runners/utils/__pycache__/code_script.cpython-312.pyc +0 -0
  96. clarifai/runners/utils/__pycache__/const.cpython-312.pyc +0 -0
  97. clarifai/runners/utils/__pycache__/data_handler.cpython-312.pyc +0 -0
  98. clarifai/runners/utils/__pycache__/data_types.cpython-312.pyc +0 -0
  99. clarifai/runners/utils/__pycache__/data_utils.cpython-312.pyc +0 -0
  100. clarifai/runners/utils/__pycache__/loader.cpython-312.pyc +0 -0
  101. clarifai/runners/utils/__pycache__/method_signatures.cpython-312.pyc +0 -0
  102. clarifai/runners/utils/__pycache__/serializers.cpython-312.pyc +0 -0
  103. clarifai/runners/utils/__pycache__/url_fetcher.cpython-312.pyc +0 -0
  104. clarifai/runners/utils/code_script.py +41 -44
  105. clarifai/runners/utils/const.py +15 -0
  106. clarifai/runners/utils/data_handler.py +231 -0
  107. clarifai/runners/utils/data_types/__pycache__/__init__.cpython-312.pyc +0 -0
  108. clarifai/runners/utils/data_types/__pycache__/data_types.cpython-312.pyc +0 -0
  109. clarifai/runners/utils/data_utils.py +33 -5
  110. clarifai/runners/utils/loader.py +23 -2
  111. clarifai/runners/utils/method_signatures.py +4 -4
  112. clarifai/schema/__pycache__/search.cpython-312.pyc +0 -0
  113. clarifai/urls/__pycache__/helper.cpython-312.pyc +0 -0
  114. clarifai/utils/__pycache__/__init__.cpython-312.pyc +0 -0
  115. clarifai/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  116. clarifai/utils/__pycache__/cli.cpython-312.pyc +0 -0
  117. clarifai/utils/__pycache__/config.cpython-312.pyc +0 -0
  118. clarifai/utils/__pycache__/constants.cpython-312.pyc +0 -0
  119. clarifai/utils/__pycache__/constants.cpython-39.pyc +0 -0
  120. clarifai/utils/__pycache__/logging.cpython-312.pyc +0 -0
  121. clarifai/utils/__pycache__/logging.cpython-39.pyc +0 -0
  122. clarifai/utils/__pycache__/misc.cpython-312.pyc +0 -0
  123. clarifai/utils/__pycache__/misc.cpython-39.pyc +0 -0
  124. clarifai/utils/__pycache__/model_train.cpython-312.pyc +0 -0
  125. clarifai/utils/__pycache__/protobuf.cpython-312.pyc +0 -0
  126. clarifai/utils/config.py +19 -0
  127. clarifai/utils/config.py~ +145 -0
  128. clarifai/utils/evaluation/__pycache__/__init__.cpython-312.pyc +0 -0
  129. clarifai/utils/evaluation/__pycache__/helpers.cpython-312.pyc +0 -0
  130. clarifai/utils/evaluation/__pycache__/main.cpython-312.pyc +0 -0
  131. clarifai/utils/logging.py +22 -5
  132. clarifai/workflows/__pycache__/__init__.cpython-312.pyc +0 -0
  133. clarifai/workflows/__pycache__/export.cpython-312.pyc +0 -0
  134. clarifai/workflows/__pycache__/utils.cpython-312.pyc +0 -0
  135. clarifai/workflows/__pycache__/validate.cpython-312.pyc +0 -0
  136. {clarifai-11.4.1.dist-info → clarifai-11.4.3rc1.dist-info}/METADATA +2 -14
  137. clarifai-11.4.3rc1.dist-info/RECORD +230 -0
  138. {clarifai-11.4.1.dist-info → clarifai-11.4.3rc1.dist-info}/WHEEL +1 -1
  139. clarifai-11.4.1.dist-info/RECORD +0 -109
  140. {clarifai-11.4.1.dist-info/licenses → clarifai-11.4.3rc1.dist-info}/LICENSE +0 -0
  141. {clarifai-11.4.1.dist-info → clarifai-11.4.3rc1.dist-info}/entry_points.txt +0 -0
  142. {clarifai-11.4.1.dist-info → clarifai-11.4.3rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,143 @@
1
+ """Base class for creating Model Context Protocol (MCP) servers."""
2
+
3
+ import asyncio
4
+ import json
5
+ from typing import Any
6
+
7
+ from fastmcp import Client, FastMCP # use fastmcp v2 not the built in mcp
8
+ from mcp import types
9
+ from mcp.shared.exceptions import McpError
10
+
11
+ from clarifai.runners.models.model_class import ModelClass
12
+
13
+ # class MCPServerProvider(abc.ABC):
14
+ # """
15
+ # Base class for creating Model Context Protocol (MCP) servers.
16
+
17
+ # This class provides a base implementation of the MCP server, including
18
+ # methods for handling requests and responses, as well as error handling and
19
+ # logging.
20
+
21
+ # Attributes:
22
+ # _server: The FastMCP server instance.
23
+ # _tools: List of tools available in the server.
24
+ # _resources: List of resources available in the server.
25
+ # _prompts: List of prompts available in the server.
26
+
27
+ # Methods:
28
+ # get_server(): Returns the FastMCP server instance.
29
+ # mcp_transport(msg): Handles incoming messages and sends them to the FastMCP server.
30
+ # """
31
+
32
+ # @abc.abstractmethod
33
+ # def get_server(self) -> FastMCP:
34
+ # """Required method for each subclass to implement to return the FastMCP server to use."""
35
+ # if self._server is None:
36
+ # raise ValueError("Server not initialized")
37
+ # return self._server
38
+
39
+
40
class MCPModelClass(ModelClass):
    """Base class for wrapping FastMCP servers as a model running in Clarifai.

    This handles all the transport between the API and the MCP server. Simply subclass this and
    implement the get_server() method to return the FastMCP server instance. The server is then
    used to handle all the requests and responses.
    """

    def load_model(self):
        # In-memory transport provided in fastmcp v2 so we can easily use the client functions.
        self.client = Client(self.get_server())

    def get_server(self) -> FastMCP:
        """Required method for each subclass to implement to return the FastMCP server to use."""
        raise NotImplementedError("Subclasses must implement get_server() method")

    @ModelClass.method
    def mcp_transport(self, msg: str) -> str:
        """Forward one serialized JSON-RPC message to the FastMCP server and return its response.

        Messages whose ``method`` starts with ``notifications/`` are sent as notifications;
        every other message is validated and sent as a client request.

        Arguments:
          msg: The incoming message in serialized JSON-RPC format from an MCP client.

        Returns:
          str: The serialized response, or an empty string when the call produced no response.

        Raises:
          ValueError: If ``msg`` does not decode to a JSON object.
          NotImplementedError: If the request type has no known result type.
        """

        async def send_notification(client_message: types.ClientNotification) -> Any:
            async with self.client:
                # Strip the jsonrpc field since send_notification sets it again.
                client_message = types.ClientNotification.model_validate(
                    client_message.model_dump(
                        by_alias=True, mode="json", exclude_none=True, exclude={"jsonrpc"}
                    )
                )
                try:
                    return await self.client.session.send_notification(client_message)
                except McpError as e:
                    # NOTE(review): notifications carry no id; confirm JSONRPCError accepts a
                    # missing id in the pinned mcp version.
                    return types.JSONRPCError(jsonrpc="2.0", error=e.error)

        async def send_request(client_message: types.ClientRequest, request_id: Any) -> Any:
            async with self.client:
                # Strip the jsonrpc and id fields as send_request sets them again too.
                client_message = types.ClientRequest.model_validate(
                    client_message.model_dump(
                        by_alias=True, mode="json", exclude_none=True, exclude={"jsonrpc", "id"}
                    )
                )
                root = client_message.root

                # Initialization goes through the session's dedicated entry point.
                if isinstance(root, types.InitializeRequest):
                    return await self.client.session.initialize()

                # Result type that send_request should parse the reply into, per request type.
                result_types = (
                    (types.PingRequest, types.EmptyResult),
                    (types.SetLevelRequest, types.EmptyResult),
                    (types.ListResourcesRequest, types.ListResourcesResult),
                    (types.ListResourceTemplatesRequest, types.ListResourceTemplatesResult),
                    (types.ReadResourceRequest, types.ReadResourceResult),
                    (types.SubscribeRequest, types.EmptyResult),
                    (types.UnsubscribeRequest, types.EmptyResult),
                    (types.ListPromptsRequest, types.ListPromptsResult),
                    (types.GetPromptRequest, types.GetPromptResult),
                    (types.CompleteRequest, types.CompleteResult),
                    (types.ListToolsRequest, types.ListToolsResult),
                    (types.CallToolRequest, types.CallToolResult),
                )
                result_type = next(
                    (result for request, result in result_types if isinstance(root, request)),
                    None,
                )
                if result_type is None:
                    # The method name lives on the wrapped request, not on the root model.
                    raise NotImplementedError(f"Method {root.method} not implemented")

                try:
                    return await self.client.session.send_request(client_message, result_type)
                except McpError as e:
                    return types.JSONRPCError(jsonrpc="2.0", id=request_id, error=e.error)

        # The message coming here is the generic request. We look at its method to decide
        # whether it is a notification or a request before parsing it further.
        # Note(zeiler): unfortunately the pydantic types in mcp/types.py are not consistent.
        # The JSONRPCRequest are supposed to have an id but the InitializeRequest
        # does not have it.
        d = json.loads(msg)
        if not isinstance(d, dict):
            raise ValueError("mcp_transport expects a single JSON-RPC message object")

        # A missing or non-string method is treated as a request so that model validation
        # reports the problem instead of an AttributeError on None.
        method = d.get('method')
        if isinstance(method, str) and method.startswith("notifications/"):
            client_message = types.ClientNotification.model_validate(d)
            response = asyncio.run(send_notification(client_message))
        else:
            client_message = types.ClientRequest.model_validate(d)
            response = asyncio.run(send_request(client_message, request_id=d.get('id', "")))

        if response is None:
            return ""
        # Return as a serialized json string.
        return response.model_dump_json(by_alias=True, exclude_none=True)
@@ -0,0 +1,149 @@
1
+ """Base class for creating Model Context Protocol (MCP) servers."""
2
+
3
+ import asyncio
4
+ import json
5
+ from typing import Any
6
+
7
+ from fastmcp import Client, FastMCP # use fastmcp v2 not the built in mcp
8
+ from mcp import types
9
+ from mcp.shared.exceptions import McpError
10
+
11
+ from clarifai.runners.models.model_class import ModelClass
12
+
13
+ # class MCPServerProvider(abc.ABC):
14
+ # """
15
+ # Base class for creating Model Context Protocol (MCP) servers.
16
+
17
+ # This class provides a base implementation of the MCP server, including
18
+ # methods for handling requests and responses, as well as error handling and
19
+ # logging.
20
+
21
+ # Attributes:
22
+ # _server: The FastMCP server instance.
23
+ # _tools: List of tools available in the server.
24
+ # _resources: List of resources available in the server.
25
+ # _prompts: List of prompts available in the server.
26
+
27
+ # Methods:
28
+ # get_server(): Returns the FastMCP server instance.
29
+ # mcp_transport(msg): Handles incoming messages and sends them to the FastMCP server.
30
+ # """
31
+
32
+ # @abc.abstractmethod
33
+ # def get_server(self) -> FastMCP:
34
+ # """Required method for each subclass to implement to return the FastMCP server to use."""
35
+ # if self._server is None:
36
+ # raise ValueError("Server not initialized")
37
+ # return self._server
38
+
39
+
40
class MCPModelClass(ModelClass):
    """Base class for wrapping FastMCP servers as a model running in Clarifai.

    This handles all the transport between the API and the MCP server. Simply subclass this and
    implement the get_server() method to return the FastMCP server instance. The server is then
    used to handle all the requests and responses.

    The former ``MCPServerProvider`` base is only present as commented-out code in this file, so
    it is not listed as a base class here (referencing it would fail at import time).
    """

    def load_model(self):
        # In-memory transport provided in fastmcp v2 so we can easily use the client functions.
        self.client = Client(self.get_server())

    def get_server(self) -> FastMCP:
        """Return the FastMCP server to use.

        Raises:
          ValueError: If no server has been stored on ``self._server``.
        """
        # Nothing in this class assigns _server, so treat a missing attribute like None.
        server = getattr(self, "_server", None)
        if server is None:
            raise ValueError("Server not initialized")
        return server

    @ModelClass.method
    def mcp_transport(self, msg: str) -> str:
        """Forward one serialized JSON-RPC message to the FastMCP server and return its response.

        Messages whose ``method`` starts with ``notifications/`` are sent as notifications;
        every other message is validated and sent as a client request.

        Arguments:
          msg: The incoming message to be handled in serialized JSONRPC format from an MCP client.

        Returns:
          str: The response to the incoming message in serialized JSONRPC format, or an empty
            string when the call produced no response.

        Raises:
          ValueError: If ``msg`` does not decode to a JSON object.
          NotImplementedError: If the request type has no known result type.
        """

        async def send_notification(client_message: types.ClientNotification) -> Any:
            async with self.client:
                # Strip the jsonrpc field since send_notification sets it again.
                client_message = types.ClientNotification.model_validate(
                    client_message.model_dump(
                        by_alias=True, mode="json", exclude_none=True, exclude={"jsonrpc"}
                    )
                )
                try:
                    return await self.client.session.send_notification(client_message)
                except McpError as e:
                    # NOTE(review): notifications carry no id; confirm JSONRPCError accepts a
                    # missing id in the pinned mcp version.
                    return types.JSONRPCError(jsonrpc="2.0", error=e.error)

        async def send_request(client_message: types.ClientRequest, request_id: Any) -> Any:
            async with self.client:
                # Strip the jsonrpc and id fields as send_request sets them again too.
                client_message = types.ClientRequest.model_validate(
                    client_message.model_dump(
                        by_alias=True, mode="json", exclude_none=True, exclude={"jsonrpc", "id"}
                    )
                )
                root = client_message.root

                # Initialization goes through the session's dedicated entry point.
                if isinstance(root, types.InitializeRequest):
                    return await self.client.session.initialize()

                # Result type that send_request should parse the reply into, per request type.
                result_types = (
                    (types.PingRequest, types.EmptyResult),
                    (types.SetLevelRequest, types.EmptyResult),
                    (types.ListResourcesRequest, types.ListResourcesResult),
                    (types.ListResourceTemplatesRequest, types.ListResourceTemplatesResult),
                    (types.ReadResourceRequest, types.ReadResourceResult),
                    (types.SubscribeRequest, types.EmptyResult),
                    (types.UnsubscribeRequest, types.EmptyResult),
                    (types.ListPromptsRequest, types.ListPromptsResult),
                    (types.GetPromptRequest, types.GetPromptResult),
                    (types.CompleteRequest, types.CompleteResult),
                    (types.ListToolsRequest, types.ListToolsResult),
                    (types.CallToolRequest, types.CallToolResult),
                )
                result_type = next(
                    (result for request, result in result_types if isinstance(root, request)),
                    None,
                )
                if result_type is None:
                    # The method name lives on the wrapped request, not on the root model.
                    raise NotImplementedError(f"Method {root.method} not implemented")

                try:
                    return await self.client.session.send_request(client_message, result_type)
                except McpError as e:
                    return types.JSONRPCError(jsonrpc="2.0", id=request_id, error=e.error)

        # The message coming here is the generic request. We look at its method to decide
        # whether it is a notification or a request before parsing it further.
        # Note(zeiler): unfortunately the pydantic types in mcp/types.py are not consistent.
        # The JSONRPCRequest are supposed to have an id but the InitializeRequest
        # does not have it.
        d = json.loads(msg)
        if not isinstance(d, dict):
            raise ValueError("mcp_transport expects a single JSON-RPC message object")

        # A missing or non-string method is treated as a request so that model validation
        # reports the problem instead of an AttributeError on None.
        method = d.get('method')
        if isinstance(method, str) and method.startswith("notifications/"):
            client_message = types.ClientNotification.model_validate(d)
            response = asyncio.run(send_notification(client_message))
        else:
            client_message = types.ClientRequest.model_validate(d)
            response = asyncio.run(send_request(client_message, request_id=d.get('id', "")))

        if response is None:
            return ""
        # Return as a serialized json string.
        return response.model_dump_json(by_alias=True, exclude_none=True)
@@ -14,15 +14,17 @@ import yaml
14
14
  from clarifai_grpc.grpc.api import resources_pb2, service_pb2
15
15
  from clarifai_grpc.grpc.api.status import status_code_pb2
16
16
  from google.protobuf import json_format
17
- from rich import print
18
- from rich.markup import escape
19
17
 
20
18
  from clarifai.client.base import BaseClient
21
19
  from clarifai.runners.models.model_class import ModelClass
22
20
  from clarifai.runners.utils.const import (
21
+ AMD_PYTHON_BASE_IMAGE,
22
+ AMD_VLLM_BASE_IMAGE,
23
23
  AVAILABLE_PYTHON_IMAGES,
24
24
  AVAILABLE_TORCH_IMAGES,
25
25
  CONCEPTS_REQUIRED_MODEL_TYPE,
26
+ DEFAULT_AMD_GPU_VERSION,
27
+ DEFAULT_AMD_TORCH_VERSION,
26
28
  DEFAULT_DOWNLOAD_CHECKPOINT_WHEN,
27
29
  DEFAULT_PYTHON_VERSION,
28
30
  DEFAULT_RUNTIME_DOWNLOAD_PATH,
@@ -43,13 +45,6 @@ dependencies = [
43
45
  ]
44
46
 
45
47
 
46
- def _clear_line(n: int = 1) -> None:
47
- LINE_UP = '\033[1A' # Move cursor up one line
48
- LINE_CLEAR = '\x1b[2K' # Clear the entire line
49
- for _ in range(n):
50
- print(LINE_UP, end=LINE_CLEAR, flush=True)
51
-
52
-
53
48
  def is_related(object_class, main_class):
54
49
  # Check if the object_class is a subclass of main_class
55
50
  if issubclass(object_class, main_class):
@@ -361,13 +356,23 @@ class ModelBuilder:
361
356
  if self.config.get("checkpoints"):
362
357
  loader_type, _, hf_token, _, _, _ = self._validate_config_checkpoints()
363
358
 
364
- if loader_type == "huggingface" and hf_token:
365
- is_valid_token = HuggingFaceLoader.validate_hftoken(hf_token)
366
- if not is_valid_token:
359
+ if loader_type == "huggingface":
360
+ is_valid_token = hf_token and HuggingFaceLoader.validate_hftoken(hf_token)
361
+ if not is_valid_token and hf_token:
362
+ logger.info(
363
+ "Continuing without Hugging Face token for validating config in model builder."
364
+ )
365
+
366
+ has_repo_access = HuggingFaceLoader.validate_hf_repo_access(
367
+ repo_id=self.config.get("checkpoints", {}).get("repo_id"),
368
+ token=hf_token if is_valid_token else None,
369
+ )
370
+
371
+ if not has_repo_access:
367
372
  logger.error(
368
- "Invalid Hugging Face token provided in the config file, this might cause issues with downloading the restricted model checkpoints."
373
+ f"Invalid Hugging Face repo access for repo {self.config.get('checkpoints').get('repo_id')}. Please check your repo and try again."
369
374
  )
370
- logger.info("Continuing without Hugging Face token")
375
+ sys.exit("Token does not have access to HuggingFace repo , exiting.")
371
376
 
372
377
  num_threads = self.config.get("num_threads")
373
378
  if num_threads or num_threads == 0:
@@ -405,11 +410,11 @@ class ModelBuilder:
405
410
  signatures = {method.name: method.signature for method in method_info.values()}
406
411
  return signatures_to_yaml(signatures)
407
412
 
408
- def get_method_signatures(self):
413
+ def get_method_signatures(self, mocking=True):
409
414
  """
410
415
  Returns the method signatures for the model class.
411
416
  """
412
- model_class = self.load_model_class(mocking=True)
417
+ model_class = self.load_model_class(mocking=mocking)
413
418
  method_info = model_class._get_method_info()
414
419
  signatures = [method.signature for method in method_info.values()]
415
420
  return signatures
@@ -532,6 +537,30 @@ class ModelBuilder:
532
537
  dependencies_version[dependency] = version if version else None
533
538
  return dependencies_version
534
539
 
540
def _is_amd(self):
    """
    Check if the model is configured for AMD accelerators.

    Reads ``inference_compute_info.accelerator_type`` from ``self.config`` and looks for the
    substrings "amd" and "nvidia" (case-insensitive) in each entry. A missing or empty
    section, and an ``accelerator_type`` given as a single string, are both accepted.

    Returns:
        bool: True if any accelerator entry mentions AMD, otherwise False.

    Raises:
        Exception: If both AMD and NVIDIA accelerators are specified.
    """
    # `or {}` / `or []` also cover keys that are present but set to null in the YAML.
    inference_compute_info = self.config.get('inference_compute_info') or {}
    accelerators = inference_compute_info.get('accelerator_type') or []
    # A bare string would otherwise be iterated character by character and never match.
    if isinstance(accelerators, str):
        accelerators = [accelerators]

    is_amd_gpu = False
    is_nvidia_gpu = False
    for accelerator in accelerators:
        name = accelerator.lower()
        if 'amd' in name:
            is_amd_gpu = True
        elif 'nvidia' in name:
            is_nvidia_gpu = True

    if is_amd_gpu and is_nvidia_gpu:
        raise Exception(
            "Both AMD and NVIDIA GPUs are specified in the config file, please use only one type of GPU."
        )
    if is_amd_gpu:
        logger.info("Using AMD base image to build the Docker image and upload the model")
    elif is_nvidia_gpu:
        logger.info("Using NVIDIA base image to build the Docker image and upload the model")
    return is_amd_gpu
563
+
535
564
  def create_dockerfile(self):
536
565
  dockerfile_template = os.path.join(
537
566
  os.path.dirname(os.path.dirname(__file__)),
@@ -562,30 +591,85 @@ class ModelBuilder:
562
591
  )
563
592
  python_version = DEFAULT_PYTHON_VERSION
564
593
 
565
- # This is always the final image used for runtime.
566
- final_image = PYTHON_BASE_IMAGE.format(python_version=python_version)
567
- downloader_image = PYTHON_BASE_IMAGE.format(python_version=python_version)
568
-
569
594
  # Parse the requirements.txt file to determine the base image
570
595
  dependencies = self._parse_requirements()
571
- if 'torch' in dependencies and dependencies['torch']:
572
- torch_version = dependencies['torch']
573
-
574
- # Sort in reverse so that newer cuda versions come first and are preferred.
575
- for image in sorted(AVAILABLE_TORCH_IMAGES, reverse=True):
576
- if torch_version in image and f'py{python_version}' in image:
577
- # like cu124, rocm6.3, etc.
578
- gpu_version = image.split('-')[-1]
579
- final_image = TORCH_BASE_IMAGE.format(
580
- torch_version=torch_version,
581
- python_version=python_version,
582
- gpu_version=gpu_version,
596
+
597
+ is_amd_gpu = self._is_amd()
598
+ if is_amd_gpu:
599
+ final_image = AMD_PYTHON_BASE_IMAGE.format(python_version=python_version)
600
+ downloader_image = AMD_PYTHON_BASE_IMAGE.format(python_version=python_version)
601
+ if 'vllm' in dependencies:
602
+ if python_version != DEFAULT_PYTHON_VERSION:
603
+ raise Exception(
604
+ f"vLLM is not supported with Python version {python_version}, please use Python version {DEFAULT_PYTHON_VERSION} in your config.yaml"
583
605
  )
606
+ torch_version = dependencies.get('torch', None)
607
+ if 'torch' in dependencies:
608
+ if python_version != DEFAULT_PYTHON_VERSION:
609
+ raise Exception(
610
+ f"torch is not supported with Python version {python_version}, please use Python version {DEFAULT_PYTHON_VERSION} in your config.yaml"
611
+ )
612
+ if not torch_version:
613
+ logger.info(
614
+ f"torch version not found in requirements.txt, using the default version {DEFAULT_AMD_TORCH_VERSION}"
615
+ )
616
+ torch_version = DEFAULT_AMD_TORCH_VERSION
617
+ if torch_version not in [DEFAULT_AMD_TORCH_VERSION]:
618
+ raise Exception(
619
+ f"torch version {torch_version} not supported, please use one of the following versions: {DEFAULT_AMD_TORCH_VERSION} in your requirements.txt"
620
+ )
621
+ python_version = DEFAULT_PYTHON_VERSION
622
+ gpu_version = DEFAULT_AMD_GPU_VERSION
623
+ final_image = AMD_VLLM_BASE_IMAGE.format(
624
+ torch_version=torch_version,
625
+ python_version=python_version,
626
+ gpu_version=gpu_version,
627
+ )
628
+ logger.info("Using vLLM base image to build the Docker image")
629
+ elif 'torch' in dependencies:
630
+ torch_version = dependencies['torch']
631
+ if python_version != DEFAULT_PYTHON_VERSION:
632
+ raise Exception(
633
+ f"torch is not supported with Python version {python_version}, please use Python version {DEFAULT_PYTHON_VERSION} in your config.yaml"
634
+ )
635
+ if not torch_version:
584
636
  logger.info(
585
- f"Using Torch version {torch_version} base image to build the Docker image"
637
+ f"torch version not found in requirements.txt, using the default version {DEFAULT_AMD_TORCH_VERSION}"
586
638
  )
587
- break
588
-
639
+ torch_version = DEFAULT_AMD_TORCH_VERSION
640
+ if torch_version not in [DEFAULT_AMD_TORCH_VERSION]:
641
+ raise Exception(
642
+ f"torch version {torch_version} not supported, please use one of the following versions: {DEFAULT_AMD_TORCH_VERSION} in your requirements.txt"
643
+ )
644
+ python_version = DEFAULT_PYTHON_VERSION
645
+ gpu_version = DEFAULT_AMD_GPU_VERSION
646
+ final_image = TORCH_BASE_IMAGE.format(
647
+ torch_version=torch_version,
648
+ python_version=python_version,
649
+ gpu_version=gpu_version,
650
+ )
651
+ logger.info(
652
+ f"Using Torch version {torch_version} base image to build the Docker image"
653
+ )
654
+ else:
655
+ final_image = PYTHON_BASE_IMAGE.format(python_version=python_version)
656
+ downloader_image = PYTHON_BASE_IMAGE.format(python_version=python_version)
657
+ if 'torch' in dependencies and dependencies['torch']:
658
+ torch_version = dependencies['torch']
659
+ # Sort in reverse so that newer cuda versions come first and are preferred.
660
+ for image in sorted(AVAILABLE_TORCH_IMAGES, reverse=True):
661
+ if torch_version in image and f'py{python_version}' in image:
662
+ # like cu124, rocm6.3, etc.
663
+ gpu_version = image.split('-')[-1]
664
+ final_image = TORCH_BASE_IMAGE.format(
665
+ torch_version=torch_version,
666
+ python_version=python_version,
667
+ gpu_version=gpu_version,
668
+ )
669
+ logger.info(
670
+ f"Using Torch version {torch_version} base image to build the Docker image"
671
+ )
672
+ break
589
673
  if 'clarifai' not in dependencies:
590
674
  raise Exception(
591
675
  f"clarifai not found in requirements.txt, please add clarifai to the requirements.txt file with a fixed version. Current version is clarifai=={CLIENT_VERSION}"
@@ -835,7 +919,6 @@ class ModelBuilder:
835
919
  percent_completed = response.status.percent_completed
836
920
  details = response.status.details
837
921
 
838
- _clear_line()
839
922
  print(
840
923
  f"Status: {response.status.description}, Progress: {percent_completed}% - {details} ",
841
924
  f"request_id: {response.status.req_id}",
@@ -849,7 +932,48 @@ class ModelBuilder:
849
932
  logger.info(f"Created Model Version ID: {self.model_version_id}")
850
933
  logger.info(f"Full url to that version is: {self.model_url}")
851
934
  try:
852
- self.monitor_model_build()
935
+ is_uploaded = self.monitor_model_build()
936
+ if is_uploaded:
937
+ # Provide an mcp client config
938
+ if model_type_id == "mcp":
939
+ snippet = (
940
+ """
941
+ import asyncio
942
+ import os
943
+ from fastmcp import Client
944
+ from fastmcp.client.transports import StreamableHttpTransport
945
+
946
+ transport = StreamableHttpTransport(url="%s/mcp",
947
+ headers={"Authorization": "Bearer " + os.environ["CLARIFAI_PAT"]})
948
+
949
+ async def main():
950
+ async with Client(transport) as client:
951
+ tools = await client.list_tools()
952
+ print(f"Available tools: {tools}")
953
+ result = await client.call_tool(tools[0].name, {"a": 5, "b": 3})
954
+ print(f"Result: {result[0].text}")
955
+
956
+ if __name__ == "__main__":
957
+ asyncio.run(main())
958
+ """
959
+ % self.model_url
960
+ )
961
+ else: # python code to run the model.
962
+ from clarifai.runners.utils import code_script
963
+
964
+ method_signatures = self.get_method_signatures()
965
+ snippet = code_script.generate_client_script(
966
+ method_signatures,
967
+ user_id=self.client.user_app_id.user_id,
968
+ app_id=self.client.user_app_id.app_id,
969
+ model_id=self.model_proto.id,
970
+ )
971
+ logger.info("""\n
972
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
973
+ # Here is a code snippet to use this model:
974
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
975
+ """)
976
+ logger.info(snippet)
853
977
  finally:
854
978
  if os.path.exists(self.tar_file):
855
979
  logger.debug(f"Cleaning up upload file: {self.tar_file}")
@@ -933,7 +1057,12 @@ class ModelBuilder:
933
1057
  for log_entry in logs.log_entries:
934
1058
  if log_entry.url not in seen_logs:
935
1059
  seen_logs.add(log_entry.url)
936
- logger.info(f"{escape(log_entry.message.strip())}")
1060
+ log_entry_msg = re.sub(
1061
+ r"(\\*)(\[[a-z#/@][^[]*?])",
1062
+ lambda m: f"{m.group(1)}{m.group(1)}\\{m.group(2)}",
1063
+ log_entry.message.strip(),
1064
+ )
1065
+ logger.info(log_entry_msg)
937
1066
  if status_code == status_code_pb2.MODEL_BUILDING:
938
1067
  print(
939
1068
  f"Model is building... (elapsed {time.time() - st:.1f}s)", end='\r', flush=True
@@ -9,7 +9,6 @@ from typing import Any, Dict, Iterator, List
9
9
 
10
10
  from clarifai_grpc.grpc.api import resources_pb2, service_pb2
11
11
  from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
12
- from google.protobuf import json_format
13
12
 
14
13
  from clarifai.runners.utils import data_types
15
14
  from clarifai.runners.utils.data_utils import DataConverter
@@ -100,7 +99,6 @@ class ModelClass(ABC):
100
99
  try:
101
100
  # TODO add method name field to proto
102
101
  method_name = 'predict'
103
- inference_params = get_inference_params(request)
104
102
  if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata:
105
103
  method_name = request.inputs[0].data.metadata['_method_name']
106
104
  if (
@@ -124,7 +122,7 @@ class ModelClass(ABC):
124
122
  input.data.CopyFrom(new_data)
125
123
  # convert inputs to python types
126
124
  inputs = self._convert_input_protos_to_python(
127
- request.inputs, inference_params, signature.input_fields, python_param_types
125
+ request.inputs, signature.input_fields, python_param_types
128
126
  )
129
127
  if len(inputs) == 1:
130
128
  inputs = inputs[0]
@@ -163,7 +161,6 @@ class ModelClass(ABC):
163
161
  ) -> Iterator[service_pb2.MultiOutputResponse]:
164
162
  try:
165
163
  method_name = 'generate'
166
- inference_params = get_inference_params(request)
167
164
  if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata:
168
165
  method_name = request.inputs[0].data.metadata['_method_name']
169
166
  method = getattr(self, method_name)
@@ -180,7 +177,7 @@ class ModelClass(ABC):
180
177
  )
181
178
  input.data.CopyFrom(new_data)
182
179
  inputs = self._convert_input_protos_to_python(
183
- request.inputs, inference_params, signature.input_fields, python_param_types
180
+ request.inputs, signature.input_fields, python_param_types
184
181
  )
185
182
  if len(inputs) == 1:
186
183
  inputs = inputs[0]
@@ -226,7 +223,6 @@ class ModelClass(ABC):
226
223
  assert len(request.inputs) == 1, "Streaming requires exactly one input"
227
224
 
228
225
  method_name = 'stream'
229
- inference_params = get_inference_params(request)
230
226
  if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata:
231
227
  method_name = request.inputs[0].data.metadata['_method_name']
232
228
  method = getattr(self, method_name)
@@ -251,7 +247,7 @@ class ModelClass(ABC):
251
247
  input.data.CopyFrom(new_data)
252
248
  # convert all inputs for the first request, including the first stream value
253
249
  inputs = self._convert_input_protos_to_python(
254
- request.inputs, inference_params, signature.input_fields, python_param_types
250
+ request.inputs, signature.input_fields, python_param_types
255
251
  )
256
252
  kwargs = inputs[0]
257
253
 
@@ -264,7 +260,7 @@ class ModelClass(ABC):
264
260
  # subsequent streaming items contain only the streaming input
265
261
  for request in request_iterator:
266
262
  item = self._convert_input_protos_to_python(
267
- request.inputs, inference_params, [stream_sig], python_param_types
263
+ request.inputs, [stream_sig], python_param_types
268
264
  )
269
265
  item = item[0][stream_argname]
270
266
  yield item
@@ -297,13 +293,12 @@ class ModelClass(ABC):
297
293
  def _convert_input_protos_to_python(
298
294
  self,
299
295
  inputs: List[resources_pb2.Input],
300
- inference_params: dict,
301
296
  variables_signature: List[resources_pb2.ModelTypeField],
302
297
  python_param_types,
303
298
  ) -> List[Dict[str, Any]]:
304
299
  result = []
305
300
  for input in inputs:
306
- kwargs = deserialize(input.data, variables_signature, inference_params)
301
+ kwargs = deserialize(input.data, variables_signature)
307
302
  # dynamic cast to annotated types
308
303
  for k, v in kwargs.items():
309
304
  if k not in python_param_types:
@@ -374,18 +369,6 @@ class ModelClass(ABC):
374
369
  return method_info
375
370
 
376
371
 
377
- # Helper function to get the inference params
378
- def get_inference_params(request) -> dict:
379
- """Get the inference params from the request."""
380
- inference_params = {}
381
- if request.model.model_version.id != "":
382
- output_info = request.model.model_version.output_info
383
- output_info = json_format.MessageToDict(output_info, preserving_proto_field_name=True)
384
- if "params" in output_info:
385
- inference_params = output_info["params"]
386
- return inference_params
387
-
388
-
389
372
  class _MethodInfo:
390
373
  def __init__(self, method):
391
374
  self.name = method.__name__