davidkhala.ai 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- davidkhala/ai/agent/dify/api/__init__.py +2 -2
- davidkhala/ai/agent/dify/api/app.py +10 -6
- davidkhala/ai/agent/dify/api/knowledge/__init__.py +0 -0
- davidkhala/ai/agent/dify/api/knowledge/chunk.py +14 -0
- davidkhala/ai/agent/dify/api/knowledge/dataset.py +82 -0
- davidkhala/ai/agent/dify/api/knowledge/document.py +42 -0
- davidkhala/ai/agent/dify/api/knowledge/model.py +139 -0
- davidkhala/ai/agent/dify/{ops/console → console}/__init__.py +7 -1
- davidkhala/ai/agent/dify/console/knowledge/__init__.py +0 -0
- davidkhala/ai/agent/dify/console/knowledge/dataset.py +61 -0
- davidkhala/ai/agent/dify/console/knowledge/pipeline.py +127 -0
- davidkhala/ai/agent/dify/{ops/console → console}/plugin.py +20 -6
- davidkhala/ai/agent/dify/console/session.py +50 -0
- davidkhala/ai/agent/dify/db/orm.py +65 -0
- davidkhala/ai/agent/dify/model/__init__.py +7 -0
- davidkhala/ai/agent/dify/{model.py → model/knowledge.py} +1 -12
- davidkhala/ai/agent/dify/{ops/db/orm.py → model/workflow.py} +24 -62
- davidkhala/ai/agent/langgraph.py +1 -1
- davidkhala/ai/ali/dashscope.py +15 -14
- davidkhala/ai/anthropic/__init__.py +6 -0
- davidkhala/ai/api/__init__.py +6 -19
- davidkhala/ai/api/openrouter.py +14 -10
- davidkhala/ai/api/siliconflow.py +2 -4
- davidkhala/ai/atlas/__init__.py +24 -0
- davidkhala/ai/mistral/__init__.py +2 -20
- davidkhala/ai/mistral/agent.py +50 -0
- davidkhala/ai/mistral/ai.py +40 -0
- davidkhala/ai/mistral/file.py +38 -0
- davidkhala/ai/mistral/ocr.py +46 -0
- davidkhala/ai/model/__init__.py +11 -27
- davidkhala/ai/model/chat.py +60 -4
- davidkhala/ai/model/embed.py +8 -0
- davidkhala/ai/model/garden.py +9 -0
- davidkhala/ai/openai/__init__.py +9 -33
- davidkhala/ai/openai/azure.py +51 -0
- davidkhala/ai/openai/native.py +2 -3
- davidkhala/ai/openrouter/__init__.py +24 -13
- {davidkhala_ai-0.2.1.dist-info → davidkhala_ai-0.2.2.dist-info}/METADATA +8 -6
- davidkhala_ai-0.2.2.dist-info/RECORD +65 -0
- davidkhala/ai/agent/dify/api/knowledge.py +0 -191
- davidkhala/ai/agent/dify/ops/__init__.py +0 -1
- davidkhala/ai/agent/dify/ops/console/knowledge.py +0 -158
- davidkhala/ai/agent/dify/ops/console/session.py +0 -32
- davidkhala/ai/huggingface/BAAI.py +0 -10
- davidkhala/ai/huggingface/__init__.py +0 -21
- davidkhala/ai/huggingface/inference.py +0 -13
- davidkhala_ai-0.2.1.dist-info/RECORD +0 -53
- /davidkhala/ai/agent/dify/{ops/db → db}/__init__.py +0 -0
- /davidkhala/ai/agent/dify/{ops/db → db}/app.py +0 -0
- /davidkhala/ai/agent/dify/{ops/db → db}/knowledge.py +0 -0
- /davidkhala/ai/agent/dify/{ops/db → db}/sys.py +0 -0
- {davidkhala_ai-0.2.1.dist-info → davidkhala_ai-0.2.2.dist-info}/WHEEL +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from pydantic import BaseModel
|
|
1
|
+
from pydantic import BaseModel
|
|
2
2
|
|
|
3
3
|
from davidkhala.ai.agent.dify.const import IndexingStatus
|
|
4
4
|
|
|
@@ -18,14 +18,3 @@ class Dataset(BaseModel):
|
|
|
18
18
|
id: str
|
|
19
19
|
name: str
|
|
20
20
|
description: str
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
class JsonData(BaseModel):
|
|
24
|
-
data: list
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
class NodeOutput(BaseModel):
|
|
28
|
-
"""Schema for Output of a Dify node"""
|
|
29
|
-
text: str
|
|
30
|
-
files: list
|
|
31
|
-
json_: list[JsonData] = Field(alias="json") # avoid conflict with .json()
|
|
@@ -1,66 +1,34 @@
|
|
|
1
|
-
import json
|
|
2
1
|
from enum import Enum
|
|
3
|
-
from typing import Any,
|
|
2
|
+
from typing import Protocol, Literal, Any, Optional
|
|
4
3
|
|
|
5
|
-
from pydantic import BaseModel
|
|
6
|
-
from sqlalchemy import Column, String, Text, JSON, TIMESTAMP, func
|
|
7
|
-
from sqlalchemy.dialects.postgresql import UUID
|
|
8
|
-
from sqlalchemy.orm import declarative_base
|
|
4
|
+
from pydantic import BaseModel, Field
|
|
9
5
|
|
|
10
|
-
Base = declarative_base()
|
|
11
6
|
|
|
12
7
|
|
|
13
|
-
class DifyBase(Base):
|
|
14
|
-
__abstract__ = True # keyword for SQLAlchemy
|
|
15
|
-
id = Column(UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4())
|
|
16
8
|
|
|
9
|
+
class NodeProtocol(Protocol):
|
|
10
|
+
id:str
|
|
11
|
+
datasource_type: str
|
|
17
12
|
|
|
18
|
-
class AppModelConfig(DifyBase):
|
|
19
|
-
__tablename__ = "app_model_configs"
|
|
20
|
-
__table_args__ = {"schema": "public"}
|
|
21
13
|
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
created_at = Column(TIMESTAMP, nullable=False, server_default=func.current_timestamp())
|
|
29
|
-
updated_at = Column(TIMESTAMP, nullable=False, server_default=func.current_timestamp())
|
|
30
|
-
|
|
31
|
-
opening_statement = Column(Text)
|
|
32
|
-
suggested_questions = Column(Text)
|
|
33
|
-
suggested_questions_after_answer = Column(Text)
|
|
34
|
-
more_like_this = Column(Text)
|
|
35
|
-
model = Column(Text)
|
|
36
|
-
user_input_form = Column(Text)
|
|
37
|
-
pre_prompt = Column(Text)
|
|
38
|
-
agent_mode = Column(Text)
|
|
39
|
-
speech_to_text = Column(Text)
|
|
40
|
-
sensitive_word_avoidance = Column(Text)
|
|
41
|
-
retriever_resource = Column(Text)
|
|
14
|
+
class Position(BaseModel):
|
|
15
|
+
x: float
|
|
16
|
+
y: float
|
|
17
|
+
class Viewport(Position):
|
|
18
|
+
zoom: float
|
|
42
19
|
|
|
43
|
-
|
|
44
|
-
|
|
20
|
+
class JsonData(BaseModel):
|
|
21
|
+
data: list
|
|
45
22
|
|
|
46
|
-
chat_prompt_config = Column(Text)
|
|
47
|
-
completion_prompt_config = Column(Text)
|
|
48
|
-
dataset_configs = Column(Text)
|
|
49
|
-
external_data_tools = Column(Text)
|
|
50
|
-
file_upload = Column(Text)
|
|
51
|
-
text_to_speech = Column(Text)
|
|
52
23
|
|
|
53
|
-
|
|
54
|
-
|
|
24
|
+
class NodeOutput(BaseModel):
|
|
25
|
+
"""Schema for Output of a Dify node"""
|
|
26
|
+
text: str
|
|
27
|
+
files: list
|
|
28
|
+
json_: list[JsonData] = Field(alias="json") # avoid conflict with .json()
|
|
55
29
|
|
|
56
|
-
def __repr__(self):
|
|
57
|
-
return f"<AppModelConfig(id={self.id}, app_id={self.app_id}, provider={self.provider}, model_id={self.model_id})>"
|
|
58
30
|
|
|
59
31
|
|
|
60
|
-
class Position(BaseModel):
|
|
61
|
-
x: float
|
|
62
|
-
y: float
|
|
63
|
-
|
|
64
32
|
|
|
65
33
|
class NodeData(BaseModel):
|
|
66
34
|
class Type(str, Enum):
|
|
@@ -104,7 +72,6 @@ class NodeData(BaseModel):
|
|
|
104
72
|
embedding_model: str | None = None
|
|
105
73
|
embedding_model_provider: str | None = None
|
|
106
74
|
|
|
107
|
-
|
|
108
75
|
class Node(BaseModel):
|
|
109
76
|
@property
|
|
110
77
|
def datasource_type(self): return self.data.provider_type
|
|
@@ -117,9 +84,13 @@ class Node(BaseModel):
|
|
|
117
84
|
positionAbsolute: Position | None = None
|
|
118
85
|
width: float | None = None
|
|
119
86
|
height: float | None = None
|
|
120
|
-
selected: bool
|
|
121
|
-
|
|
87
|
+
selected: bool | None = False
|
|
122
88
|
|
|
89
|
+
class EdgeData(BaseModel):
|
|
90
|
+
sourceType: str | None = None
|
|
91
|
+
targetType: str | None = None
|
|
92
|
+
isInIteration: bool | None = False
|
|
93
|
+
isInLoop: bool | None = False
|
|
123
94
|
class Edge(BaseModel):
|
|
124
95
|
id: str
|
|
125
96
|
type: str
|
|
@@ -127,16 +98,10 @@ class Edge(BaseModel):
|
|
|
127
98
|
target: str
|
|
128
99
|
sourceHandle: str | None = None
|
|
129
100
|
targetHandle: str | None = None
|
|
130
|
-
data:
|
|
101
|
+
data: EdgeData | None = None
|
|
131
102
|
zIndex: int | None = None
|
|
132
103
|
|
|
133
104
|
|
|
134
|
-
class Viewport(BaseModel):
|
|
135
|
-
x: float
|
|
136
|
-
y: float
|
|
137
|
-
zoom: float
|
|
138
|
-
|
|
139
|
-
|
|
140
105
|
class Graph(BaseModel):
|
|
141
106
|
nodes: list[Node]
|
|
142
107
|
edges: list[Edge]
|
|
@@ -146,6 +111,3 @@ class Graph(BaseModel):
|
|
|
146
111
|
def datasources(self):
|
|
147
112
|
return [node for node in self.nodes if node.data.type == NodeData.Type.SOURCE]
|
|
148
113
|
|
|
149
|
-
@staticmethod
|
|
150
|
-
def convert(*records: list[dict]):
|
|
151
|
-
return [{**record, "graph": Graph(**json.loads(record["graph"]))} for record in records]
|
davidkhala/ai/agent/langgraph.py
CHANGED
davidkhala/ai/ali/dashscope.py
CHANGED
|
@@ -4,7 +4,9 @@ from http import HTTPStatus
|
|
|
4
4
|
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
|
|
5
5
|
|
|
6
6
|
from dashscope import Generation, TextEmbedding
|
|
7
|
-
from davidkhala.ai.model import
|
|
7
|
+
from davidkhala.ai.model import ClientProtocol
|
|
8
|
+
from davidkhala.ai.model.embed import EmbeddingAware
|
|
9
|
+
from davidkhala.ai.model.chat import MessageDict, ChatAware
|
|
8
10
|
|
|
9
11
|
|
|
10
12
|
class ModelEnum(str, Enum):
|
|
@@ -16,7 +18,7 @@ class ModelEnum(str, Enum):
|
|
|
16
18
|
EMBED = TextEmbedding.Models.text_embedding_v4
|
|
17
19
|
|
|
18
20
|
|
|
19
|
-
class API(
|
|
21
|
+
class API(ChatAware, EmbeddingAware, ClientProtocol):
|
|
20
22
|
"""
|
|
21
23
|
Unsupported to use international base_url "https://dashscope-intl.aliyuncs.com"
|
|
22
24
|
"""
|
|
@@ -24,27 +26,26 @@ class API(AbstractClient):
|
|
|
24
26
|
def __init__(self, api_key):
|
|
25
27
|
super().__init__()
|
|
26
28
|
self.api_key = api_key
|
|
27
|
-
self.model: ModelEnum|None = None
|
|
29
|
+
self.model: ModelEnum | None = None
|
|
30
|
+
|
|
28
31
|
def as_embeddings(self, model=ModelEnum.EMBED):
|
|
29
32
|
super().as_embeddings(model)
|
|
30
33
|
|
|
31
34
|
@staticmethod
|
|
32
|
-
def _on_response(response:DashScopeAPIResponse):
|
|
35
|
+
def _on_response(response: DashScopeAPIResponse):
|
|
33
36
|
if response.status_code == HTTPStatus.OK:
|
|
34
37
|
return response.output
|
|
35
38
|
else:
|
|
36
39
|
raise Exception(response)
|
|
37
40
|
|
|
38
|
-
|
|
39
41
|
def chat(self, user_prompt: str, **kwargs):
|
|
40
42
|
|
|
41
43
|
if not self.messages:
|
|
42
44
|
kwargs['prompt'] = user_prompt
|
|
43
45
|
else:
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
]
|
|
46
|
+
cloned = list(self.messages)
|
|
47
|
+
cloned.append(MessageDict(role='user', content=user_prompt))
|
|
48
|
+
kwargs['messages'] = cloned
|
|
48
49
|
# prompt 和 messages 是互斥的参数:如果你使用了 messages,就不要再传 prompt
|
|
49
50
|
r = Generation.call(
|
|
50
51
|
self.model,
|
|
@@ -53,11 +54,11 @@ class API(AbstractClient):
|
|
|
53
54
|
)
|
|
54
55
|
return API._on_response(r)
|
|
55
56
|
|
|
56
|
-
def encode(self, *_input: str)-> list[list[float]]:
|
|
57
|
-
r= TextEmbedding.call(
|
|
58
|
-
self.model,list(_input),
|
|
59
|
-
api_key=
|
|
57
|
+
def encode(self, *_input: str) -> list[list[float]]:
|
|
58
|
+
r = TextEmbedding.call(
|
|
59
|
+
self.model, list(_input),
|
|
60
|
+
api_key=self.api_key,
|
|
60
61
|
)
|
|
61
62
|
r = API._on_response(r)
|
|
62
63
|
|
|
63
|
-
return [item['embedding'] for item in r['embeddings']]
|
|
64
|
+
return [item['embedding'] for item in r['embeddings']]
|
davidkhala/ai/api/__init__.py
CHANGED
|
@@ -1,35 +1,22 @@
|
|
|
1
1
|
import datetime
|
|
2
|
-
from abc import abstractmethod
|
|
3
2
|
|
|
4
3
|
from davidkhala.utils.http_request import Request
|
|
5
4
|
|
|
6
|
-
from davidkhala.ai.model import
|
|
5
|
+
from davidkhala.ai.model.chat import ChatAware
|
|
6
|
+
from davidkhala.ai.model.garden import GardenAlike
|
|
7
7
|
|
|
8
8
|
|
|
9
|
-
class API(
|
|
9
|
+
class API(ChatAware, Request, GardenAlike):
|
|
10
10
|
def __init__(self, api_key: str, base_url: str):
|
|
11
|
-
|
|
12
|
-
Request.__init__(self,{
|
|
11
|
+
ChatAware.__init__(self)
|
|
12
|
+
Request.__init__(self, {
|
|
13
13
|
"bearer": api_key
|
|
14
14
|
})
|
|
15
15
|
self.base_url = base_url + '/v1'
|
|
16
16
|
|
|
17
|
-
@property
|
|
18
|
-
@abstractmethod
|
|
19
|
-
def free_models(self) -> list[str]:
|
|
20
|
-
...
|
|
21
|
-
|
|
22
17
|
def chat(self, *user_prompt: str, **kwargs):
|
|
23
|
-
messages = [
|
|
24
|
-
*self.messages,
|
|
25
|
-
*[{
|
|
26
|
-
"role": "user",
|
|
27
|
-
"content": _
|
|
28
|
-
} for _ in user_prompt],
|
|
29
|
-
]
|
|
30
|
-
|
|
31
18
|
json = {
|
|
32
|
-
"messages":
|
|
19
|
+
"messages": self.messages_from(*user_prompt),
|
|
33
20
|
**kwargs,
|
|
34
21
|
}
|
|
35
22
|
|
davidkhala/ai/api/openrouter.py
CHANGED
|
@@ -5,16 +5,20 @@ from davidkhala.utils.http_request import default_on_response
|
|
|
5
5
|
from requests import Response
|
|
6
6
|
|
|
7
7
|
from davidkhala.ai.api import API
|
|
8
|
+
from davidkhala.ai.model.chat import CompareChatAware
|
|
8
9
|
|
|
9
10
|
|
|
10
|
-
class OpenRouter(API):
|
|
11
|
+
class OpenRouter(API, CompareChatAware):
|
|
12
|
+
|
|
11
13
|
@property
|
|
12
14
|
def free_models(self) -> list[str]:
|
|
13
|
-
|
|
15
|
+
l = list(
|
|
14
16
|
map(lambda model: model['id'],
|
|
15
17
|
filter(lambda model: model['id'].endswith(':free'), self.list_models())
|
|
16
18
|
)
|
|
17
19
|
)
|
|
20
|
+
l.append('openrouter/free')
|
|
21
|
+
return l
|
|
18
22
|
|
|
19
23
|
@staticmethod
|
|
20
24
|
def on_response(response: requests.Response):
|
|
@@ -30,8 +34,7 @@ class OpenRouter(API):
|
|
|
30
34
|
derived_response.raise_for_status()
|
|
31
35
|
return r
|
|
32
36
|
|
|
33
|
-
def __init__(self, api_key: str,
|
|
34
|
-
|
|
37
|
+
def __init__(self, api_key: str, **kwargs):
|
|
35
38
|
super().__init__(api_key, 'https://openrouter.ai/api')
|
|
36
39
|
|
|
37
40
|
if 'leaderboard' in kwargs and type(kwargs['leaderboard']) is dict:
|
|
@@ -39,8 +42,6 @@ class OpenRouter(API):
|
|
|
39
42
|
'url'] # Site URL for rankings on openrouter.ai.
|
|
40
43
|
self.options["headers"]["X-Title"] = kwargs['leaderboard'][
|
|
41
44
|
'name'] # Site title for rankings on openrouter.ai.
|
|
42
|
-
self.models = models
|
|
43
|
-
|
|
44
45
|
self.on_response = OpenRouter.on_response
|
|
45
46
|
self.retry = True
|
|
46
47
|
|
|
@@ -54,14 +55,17 @@ class OpenRouter(API):
|
|
|
54
55
|
else:
|
|
55
56
|
raise
|
|
56
57
|
|
|
58
|
+
def as_chat(self, *models: str, sys_prompt: str = None):
|
|
59
|
+
CompareChatAware.as_chat(self, *models, sys_prompt=sys_prompt)
|
|
60
|
+
|
|
57
61
|
def chat(self, *user_prompt: str, **kwargs):
|
|
58
|
-
if self.
|
|
59
|
-
kwargs["models"] = self.
|
|
62
|
+
if self._models:
|
|
63
|
+
kwargs["models"] = self._models
|
|
60
64
|
else:
|
|
61
65
|
kwargs["model"] = self.model
|
|
62
66
|
|
|
63
67
|
r = super().chat(*user_prompt, **kwargs)
|
|
64
68
|
|
|
65
|
-
if self.
|
|
66
|
-
assert r['model'] in self.
|
|
69
|
+
if self._models:
|
|
70
|
+
assert r['model'] in self._models
|
|
67
71
|
return r
|
davidkhala/ai/api/siliconflow.py
CHANGED
|
@@ -34,11 +34,9 @@ class SiliconFlow(API):
|
|
|
34
34
|
|
|
35
35
|
def __init__(self, api_key: str):
|
|
36
36
|
super().__init__(api_key, 'https://api.siliconflow.cn')
|
|
37
|
-
self.options['timeout'] = 50
|
|
38
37
|
|
|
39
|
-
def chat(self, *user_prompt: str
|
|
40
|
-
|
|
41
|
-
return super().chat(*user_prompt, **kwargs)
|
|
38
|
+
def chat(self, *user_prompt: str):
|
|
39
|
+
return super().chat(*user_prompt, model=self.model, timeout=50)
|
|
42
40
|
|
|
43
41
|
def encode(self, *_input: str) -> list[list[float]]:
|
|
44
42
|
json = {
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import voyageai
|
|
2
|
+
|
|
3
|
+
from davidkhala.ai.model import SDKProtocol
|
|
4
|
+
from davidkhala.ai.model.embed import EmbeddingAware
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Client(EmbeddingAware, SDKProtocol):
|
|
8
|
+
def __init__(self, api_key):
|
|
9
|
+
self.client = voyageai.Client(
|
|
10
|
+
api_key=api_key, # Or use VOYAGE_API_KEY environment variable
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
def as_embeddings(self, model: str = 'voyage-4'):
|
|
14
|
+
"""
|
|
15
|
+
:param model: see in https://www.mongodb.com/docs/voyageai/models/#choosing-a-model
|
|
16
|
+
"""
|
|
17
|
+
super().as_embeddings(model)
|
|
18
|
+
|
|
19
|
+
def encode(self, *_input: str) -> list[list[float]]:
|
|
20
|
+
result = self.client.embed(
|
|
21
|
+
texts=list(_input),
|
|
22
|
+
model=self.model
|
|
23
|
+
)
|
|
24
|
+
return result.embeddings
|
|
@@ -1,17 +1,11 @@
|
|
|
1
|
-
|
|
1
|
+
from mistralai import Mistral
|
|
2
2
|
|
|
3
|
-
from davidkhala.ai.model import AbstractClient
|
|
4
|
-
from mistralai import Mistral, ChatCompletionResponse, ResponseFormat
|
|
5
|
-
from davidkhala.ai.model.chat import on_response
|
|
6
3
|
|
|
7
|
-
class Client
|
|
8
|
-
n = 1
|
|
4
|
+
class Client:
|
|
9
5
|
|
|
10
6
|
def __init__(self, api_key: str):
|
|
11
7
|
self.api_key = api_key
|
|
12
8
|
self.client = Mistral(api_key=api_key)
|
|
13
|
-
self.model = "mistral-large-latest"
|
|
14
|
-
self.messages = []
|
|
15
9
|
|
|
16
10
|
def __enter__(self):
|
|
17
11
|
self.client.__enter__()
|
|
@@ -19,15 +13,3 @@ class Client(AbstractClient):
|
|
|
19
13
|
|
|
20
14
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
21
15
|
return self.client.__exit__(exc_type, exc_val, exc_tb)
|
|
22
|
-
|
|
23
|
-
def chat(self, *user_prompt, **kwargs):
|
|
24
|
-
response: ChatCompletionResponse = self.client.chat.complete(
|
|
25
|
-
model=self.model,
|
|
26
|
-
messages=[
|
|
27
|
-
*self.messages,
|
|
28
|
-
*[{"content": m, "role": "user"} for m in user_prompt]
|
|
29
|
-
], stream=False, response_format=ResponseFormat(type='text'),
|
|
30
|
-
n=self.n,
|
|
31
|
-
)
|
|
32
|
-
|
|
33
|
-
return on_response(response, self.n)
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from typing import Literal, Union
|
|
2
|
+
|
|
3
|
+
from mistralai import Agent, ToolExecutionEntry, FunctionCallEntry, MessageOutputEntry, AgentHandoffEntry
|
|
4
|
+
|
|
5
|
+
from davidkhala.ai.mistral import Client as MistralClient
|
|
6
|
+
from davidkhala.ai.model.chat import messages_from
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Agents(MistralClient):
|
|
10
|
+
def __init__(self, api_key):
|
|
11
|
+
super().__init__(api_key)
|
|
12
|
+
self.instructions: str | None = None
|
|
13
|
+
self.model = None
|
|
14
|
+
|
|
15
|
+
def as_chat(self, model="mistral-large-latest", sys_prompt: str = None):
|
|
16
|
+
self.model = model
|
|
17
|
+
if sys_prompt is not None:
|
|
18
|
+
self.instructions = sys_prompt
|
|
19
|
+
|
|
20
|
+
def create(self, name,
|
|
21
|
+
*,
|
|
22
|
+
web_search: Literal["web_search", "web_search_premium"] = None
|
|
23
|
+
) -> Agent:
|
|
24
|
+
"""
|
|
25
|
+
:param name:
|
|
26
|
+
:param web_search:
|
|
27
|
+
"web_search_premium": beyond search engine, add news provider as source
|
|
28
|
+
:return:
|
|
29
|
+
"""
|
|
30
|
+
tools = []
|
|
31
|
+
if web_search:
|
|
32
|
+
tools.append({"type": web_search})
|
|
33
|
+
agent = self.client.beta.agents.create(
|
|
34
|
+
model=self.model,
|
|
35
|
+
name=name,
|
|
36
|
+
tools=tools,
|
|
37
|
+
instructions=self.instructions
|
|
38
|
+
)
|
|
39
|
+
return agent
|
|
40
|
+
|
|
41
|
+
def chat(self, agent_id: str, *user_prompt: str) -> tuple[
|
|
42
|
+
list[Union[ToolExecutionEntry, FunctionCallEntry, MessageOutputEntry, AgentHandoffEntry]],
|
|
43
|
+
str
|
|
44
|
+
]:
|
|
45
|
+
response = self.client.beta.conversations.start(
|
|
46
|
+
agent_id=agent_id,
|
|
47
|
+
inputs=messages_from(*user_prompt)
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
return response.outputs, response.conversation_id
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# https://github.com/mistralai/client-python
|
|
2
|
+
|
|
3
|
+
from mistralai import ResponseFormat
|
|
4
|
+
|
|
5
|
+
from davidkhala.ai.mistral import Client as MistralClient
|
|
6
|
+
from davidkhala.ai.model.embed import EmbeddingAware
|
|
7
|
+
from davidkhala.ai.model.chat import on_response, ChatAware
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Client(ChatAware, EmbeddingAware, MistralClient):
|
|
11
|
+
def __init__(self, api_key: str):
|
|
12
|
+
ChatAware.__init__(self)
|
|
13
|
+
MistralClient.__init__(self, api_key)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def as_chat(self, model="mistral-large-latest", sys_prompt: str = None):
|
|
17
|
+
super().as_chat(model, sys_prompt)
|
|
18
|
+
|
|
19
|
+
def as_embeddings(self, model="mistral-embed"):
|
|
20
|
+
super().as_embeddings(model)
|
|
21
|
+
|
|
22
|
+
def chat(self, *user_prompt, **kwargs):
|
|
23
|
+
response = self.client.chat.complete(
|
|
24
|
+
model=self.model,
|
|
25
|
+
messages=self.messages_from(*user_prompt), stream=False, response_format=ResponseFormat(type='text'),
|
|
26
|
+
n=self.n,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
return on_response(response, self.n)
|
|
30
|
+
|
|
31
|
+
def encode(self, *_input: str) -> list[list[float]]:
|
|
32
|
+
res = self.client.embeddings.create(
|
|
33
|
+
model=self.model,
|
|
34
|
+
inputs=_input,
|
|
35
|
+
)
|
|
36
|
+
return [d.embedding for d in res.data]
|
|
37
|
+
|
|
38
|
+
@property
|
|
39
|
+
def models(self) -> list[str]:
|
|
40
|
+
return [_.id for _ in self.client.models.list().data]
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
from mistralai import UploadFileOut, FileSchema, ListFilesOut
|
|
4
|
+
|
|
5
|
+
from davidkhala.ai.mistral import Client as MistralClient
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Client(MistralClient):
|
|
9
|
+
def upload(self, path: Path, file_name=None) -> str:
|
|
10
|
+
"""
|
|
11
|
+
specific schema is required
|
|
12
|
+
- for [Text & Vision Fine-tuning](https://docs.mistral.ai/capabilities/finetuning/text_vision_finetuning)
|
|
13
|
+
- for [Classifier Factory](https://docs.mistral.ai/capabilities/finetuning/classifier_factory)
|
|
14
|
+
:param path:
|
|
15
|
+
:param file_name:
|
|
16
|
+
:return:
|
|
17
|
+
"""
|
|
18
|
+
if not file_name:
|
|
19
|
+
file_name = path.name
|
|
20
|
+
assert file_name.endswith(".jsonl"), "Data must be stored in JSON Lines (.jsonl) files"
|
|
21
|
+
with open(path, "rb") as content:
|
|
22
|
+
res: UploadFileOut = self.client.files.upload(file={
|
|
23
|
+
"file_name": file_name,
|
|
24
|
+
"content": content
|
|
25
|
+
})
|
|
26
|
+
return res.id
|
|
27
|
+
|
|
28
|
+
def paginate_files(self, page=0, size=100) -> ListFilesOut:
|
|
29
|
+
return self.client.files.list(page=page, page_size=size)
|
|
30
|
+
|
|
31
|
+
def ls(self, page_size=100) -> list[FileSchema]:
|
|
32
|
+
has_next = True
|
|
33
|
+
result = []
|
|
34
|
+
while has_next:
|
|
35
|
+
page = self.paginate_files(size=page_size)
|
|
36
|
+
has_next = page.total == page_size
|
|
37
|
+
result.extend(page.data)
|
|
38
|
+
return result
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import json
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
from davidkhala.ml.ocr.interface import FieldProperties as BaseFieldProperties
|
|
6
|
+
from mistralai import ImageURLChunk, ResponseFormat, JSONSchema
|
|
7
|
+
|
|
8
|
+
from davidkhala.ai.mistral import Client as MistralClient
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class FieldProperties(BaseFieldProperties):
|
|
12
|
+
description: str = ""
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Client(MistralClient):
|
|
16
|
+
def process(self, file: Path, schema: dict[str, FieldProperties] = None) -> list[dict]|dict:
|
|
17
|
+
"""
|
|
18
|
+
Allowed formats are JPEG, PNG, WEBP, GIF, MPO, HEIF, AVIF, BMP, TIFF
|
|
19
|
+
"""
|
|
20
|
+
with open(file, "rb") as f:
|
|
21
|
+
content = base64.b64encode(f.read()).decode('utf-8')
|
|
22
|
+
options = {}
|
|
23
|
+
if schema:
|
|
24
|
+
required = [k for k, _ in schema.items() if _.required]
|
|
25
|
+
properties = {k: {'type': v.type, 'description': v.description} for k, v in schema.items()}
|
|
26
|
+
options['document_annotation_format'] = ResponseFormat(
|
|
27
|
+
type='json_schema',
|
|
28
|
+
json_schema=JSONSchema(
|
|
29
|
+
name='-',
|
|
30
|
+
schema_definition={
|
|
31
|
+
"required": required,
|
|
32
|
+
"properties": properties
|
|
33
|
+
},
|
|
34
|
+
strict=True
|
|
35
|
+
)
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
ocr_response = self.client.ocr.process(
|
|
39
|
+
model="mistral-ocr-latest",
|
|
40
|
+
document=ImageURLChunk(image_url=f"data:image/jpeg;base64,{content}"),
|
|
41
|
+
include_image_base64=True,
|
|
42
|
+
**options,
|
|
43
|
+
)
|
|
44
|
+
if schema:
|
|
45
|
+
return json.loads(ocr_response.document_annotation)
|
|
46
|
+
return [{'markdown': page.markdown, 'images': page.images} for page in ocr_response.pages]
|
davidkhala/ai/model/__init__.py
CHANGED
|
@@ -1,44 +1,28 @@
|
|
|
1
|
-
from
|
|
2
|
-
from typing import Protocol, TypedDict
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
class MessageDict(TypedDict):
|
|
6
|
-
content: str | list
|
|
7
|
-
role: str
|
|
1
|
+
from typing import Protocol, Any
|
|
8
2
|
|
|
9
3
|
|
|
10
4
|
class ClientProtocol(Protocol):
|
|
11
5
|
api_key: str
|
|
12
6
|
base_url: str
|
|
13
|
-
model: str | None
|
|
14
|
-
messages: list[MessageDict] | None
|
|
15
|
-
|
|
16
7
|
|
|
17
|
-
class AbstractClient(ABC, ClientProtocol):
|
|
18
8
|
|
|
9
|
+
class ModelAware:
|
|
19
10
|
def __init__(self):
|
|
20
|
-
self.model = None
|
|
21
|
-
self.messages = []
|
|
11
|
+
self.model: str | None = None
|
|
22
12
|
|
|
23
|
-
def as_chat(self, model: str, sys_prompt: str = None):
|
|
24
|
-
self.model = model
|
|
25
|
-
if sys_prompt is not None:
|
|
26
|
-
self.messages = [MessageDict(role='system', content=sys_prompt)]
|
|
27
13
|
|
|
28
|
-
|
|
29
|
-
|
|
14
|
+
class SDKProtocol(Protocol):
|
|
15
|
+
client: Any
|
|
30
16
|
|
|
31
|
-
def chat(self, *user_prompt, **kwargs):
|
|
32
|
-
...
|
|
33
17
|
|
|
34
|
-
|
|
35
|
-
|
|
18
|
+
class Connectable:
|
|
19
|
+
def connect(self) -> bool: ...
|
|
36
20
|
|
|
37
|
-
def
|
|
38
|
-
...
|
|
21
|
+
def close(self): ...
|
|
39
22
|
|
|
40
|
-
def
|
|
41
|
-
|
|
23
|
+
def __enter__(self):
|
|
24
|
+
assert self.connect()
|
|
25
|
+
return self
|
|
42
26
|
|
|
43
27
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
44
28
|
self.close()
|