uncertainty-engine-types 0.0.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- uncertainty_engine_types-0.0.2/PKG-INFO +126 -0
- uncertainty_engine_types-0.0.2/README.md +108 -0
- uncertainty_engine_types-0.0.2/pyproject.toml +30 -0
- uncertainty_engine_types-0.0.2/uncertainty_engine_types/__init__.py +42 -0
- uncertainty_engine_types-0.0.2/uncertainty_engine_types/execution_error.py +3 -0
- uncertainty_engine_types-0.0.2/uncertainty_engine_types/graph.py +17 -0
- uncertainty_engine_types-0.0.2/uncertainty_engine_types/handle.py +15 -0
- uncertainty_engine_types-0.0.2/uncertainty_engine_types/llm.py +111 -0
- uncertainty_engine_types-0.0.2/uncertainty_engine_types/message.py +9 -0
- uncertainty_engine_types-0.0.2/uncertainty_engine_types/model.py +39 -0
- uncertainty_engine_types-0.0.2/uncertainty_engine_types/node_info.py +30 -0
- uncertainty_engine_types-0.0.2/uncertainty_engine_types/sensor_designer.py +19 -0
- uncertainty_engine_types-0.0.2/uncertainty_engine_types/sql.py +84 -0
- uncertainty_engine_types-0.0.2/uncertainty_engine_types/tabular_data.py +13 -0
- uncertainty_engine_types-0.0.2/uncertainty_engine_types/token.py +6 -0
- uncertainty_engine_types-0.0.2/uncertainty_engine_types/vector_store.py +230 -0
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: uncertainty-engine-types
|
|
3
|
+
Version: 0.0.2
|
|
4
|
+
Summary: Common type definitions for the Uncertainty Engine
|
|
5
|
+
Author: Freddy Wordingham
|
|
6
|
+
Author-email: freddy@digilab.ai
|
|
7
|
+
Requires-Python: >=3.10,<4.0
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
13
|
+
Requires-Dist: pandas (>=2.2.3,<3.0.0)
|
|
14
|
+
Requires-Dist: pydantic (>=2.10.5,<3.0.0)
|
|
15
|
+
Requires-Dist: typeguard (>=2.13.3,<2.14.0)
|
|
16
|
+
Description-Content-Type: text/markdown
|
|
17
|
+
|
|
18
|
+
<div style="color: black; background-color: #edf497;">
|
|
19
|
+
<img src="./assets/images/uncertainty-engine-logo.png">
|
|
20
|
+
</div>
|
|
21
|
+
|
|
22
|
+
# Types
|
|
23
|
+
|
|
24
|
+
Common type definitions for the Uncertainty Engine.
|
|
25
|
+
This library should be used by other packages to ensure consistency in the types used across the Uncertainty Engine.
|
|
26
|
+
|
|
27
|
+
## Overview
|
|
28
|
+
|
|
29
|
+
### Execution & Error Handling
|
|
30
|
+
|
|
31
|
+
- **ExecutionError**
|
|
32
|
+
Exception raised to indicate execution errors.
|
|
33
|
+
|
|
34
|
+
### Graph & Node Types
|
|
35
|
+
|
|
36
|
+
- **Graph**
|
|
37
|
+
Represents a collection of nodes and their connections.
|
|
38
|
+
- **NodeElement**
|
|
39
|
+
Defines a node with a type and associated inputs.
|
|
40
|
+
- **NodeId**
|
|
41
|
+
A unique identifier for nodes.
|
|
42
|
+
- **SourceHandle** & **TargetHandle**
|
|
43
|
+
Strings used to reference node connections.
|
|
44
|
+
|
|
45
|
+
### Node Handles
|
|
46
|
+
|
|
47
|
+
- **Handle**
|
|
48
|
+
Represents a node handle in the format `node.handle` and validates this structure.
|
|
49
|
+
|
|
50
|
+
### Large Language Models (LLMs)
|
|
51
|
+
|
|
52
|
+
- **LLM**
|
|
53
|
+
  Abstract base class for large language models.
|
|
54
|
+
- **OpenAILLM**
|
|
55
|
+
LLM implementation using OpenAI.
|
|
56
|
+
- **OllamaLLM**
|
|
57
|
+
LLM implementation using Ollama.
|
|
58
|
+
- **LLMProvider**
|
|
59
|
+
Enum listing supported LLM providers.
|
|
60
|
+
- **LLMManager**
|
|
61
|
+
Manages connections to LLMs based on the chosen provider and configuration.
|
|
62
|
+
|
|
63
|
+
### Messaging
|
|
64
|
+
|
|
65
|
+
- **Message**
|
|
66
|
+
Represents a message with a role and content, used for interactions with LLMs.
|
|
67
|
+
|
|
68
|
+
### TwinLab Models
|
|
69
|
+
|
|
70
|
+
- **TwinLabModel**
|
|
71
|
+
Represents a model configuration including metadata.
|
|
72
|
+
- **save_model**
|
|
73
|
+
Function to persist a model configuration.
|
|
74
|
+
|
|
75
|
+
### Node Metadata
|
|
76
|
+
|
|
77
|
+
- **NodeInputInfo**
|
|
78
|
+
Describes the properties of a node's input.
|
|
79
|
+
- **NodeOutputInfo**
|
|
80
|
+
Describes the properties of a node's output.
|
|
81
|
+
- **NodeInfo**
|
|
82
|
+
Aggregates metadata for a node, including inputs and outputs.
|
|
83
|
+
|
|
84
|
+
### Sensor Design
|
|
85
|
+
|
|
86
|
+
- **SensorDesigner**
|
|
87
|
+
Defines sensor configuration and provides functionality to load sensor data.
|
|
88
|
+
- **save_sensor_designer**
|
|
89
|
+
Function to persist a sensor designer configuration.
|
|
90
|
+
|
|
91
|
+
### SQL Database Types
|
|
92
|
+
|
|
93
|
+
- **SQLDatabase**
|
|
94
|
+
Abstract base class for executing SQL queries.
|
|
95
|
+
- **PostgreSQL**
|
|
96
|
+
Implementation of SQLDatabase for PostgreSQL.
|
|
97
|
+
- **SQLKind**
|
|
98
|
+
Enum listing supported SQL database types.
|
|
99
|
+
- **SQLManager**
|
|
100
|
+
Manages connections and operations for SQL databases.
|
|
101
|
+
|
|
102
|
+
### Tabular Data
|
|
103
|
+
|
|
104
|
+
- **TabularData**
|
|
105
|
+
Represents CSV-based data and includes functionality to load it into a pandas DataFrame.
|
|
106
|
+
|
|
107
|
+
### Token Types
|
|
108
|
+
|
|
109
|
+
- **Token**
|
|
110
|
+
Enum representing token types, such as TRAINING and STANDARD.
|
|
111
|
+
|
|
112
|
+
### Vector Stores
|
|
113
|
+
|
|
114
|
+
- **VectorStoreConnection**
|
|
115
|
+
Abstract base class for vector store operations.
|
|
116
|
+
- **WeaviateVectorStoreConnection**
|
|
117
|
+
Implements a connection to a Weaviate vector store.
|
|
118
|
+
- **VectorStoreProvider**
|
|
119
|
+
Enum for supported vector store providers.
|
|
120
|
+
- **VectorStoreManager**
|
|
121
|
+
Manages connections to vector stores.
|
|
122
|
+
- **get_persistent_vector_store**
|
|
123
|
+
Function to establish a persistent connection to a Weaviate vector store.
|
|
124
|
+
- **get_embedding_function**
|
|
125
|
+
Retrieves an embedding function based on configuration, supporting both HuggingFace and OpenAI options.
|
|
126
|
+
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
<div style="color: black; background-color: #edf497;">
|
|
2
|
+
<img src="./assets/images/uncertainty-engine-logo.png">
|
|
3
|
+
</div>
|
|
4
|
+
|
|
5
|
+
# Types
|
|
6
|
+
|
|
7
|
+
Common type definitions for the Uncertainty Engine.
|
|
8
|
+
This library should be used by other packages to ensure consistency in the types used across the Uncertainty Engine.
|
|
9
|
+
|
|
10
|
+
## Overview
|
|
11
|
+
|
|
12
|
+
### Execution & Error Handling
|
|
13
|
+
|
|
14
|
+
- **ExecutionError**
|
|
15
|
+
Exception raised to indicate execution errors.
|
|
16
|
+
|
|
17
|
+
### Graph & Node Types
|
|
18
|
+
|
|
19
|
+
- **Graph**
|
|
20
|
+
Represents a collection of nodes and their connections.
|
|
21
|
+
- **NodeElement**
|
|
22
|
+
Defines a node with a type and associated inputs.
|
|
23
|
+
- **NodeId**
|
|
24
|
+
A unique identifier for nodes.
|
|
25
|
+
- **SourceHandle** & **TargetHandle**
|
|
26
|
+
Strings used to reference node connections.
|
|
27
|
+
|
|
28
|
+
### Node Handles
|
|
29
|
+
|
|
30
|
+
- **Handle**
|
|
31
|
+
Represents a node handle in the format `node.handle` and validates this structure.
|
|
32
|
+
|
|
33
|
+
### Large Language Models (LLMs)
|
|
34
|
+
|
|
35
|
+
- **LLM**
|
|
36
|
+
  Abstract base class for large language models.
|
|
37
|
+
- **OpenAILLM**
|
|
38
|
+
LLM implementation using OpenAI.
|
|
39
|
+
- **OllamaLLM**
|
|
40
|
+
LLM implementation using Ollama.
|
|
41
|
+
- **LLMProvider**
|
|
42
|
+
Enum listing supported LLM providers.
|
|
43
|
+
- **LLMManager**
|
|
44
|
+
Manages connections to LLMs based on the chosen provider and configuration.
|
|
45
|
+
|
|
46
|
+
### Messaging
|
|
47
|
+
|
|
48
|
+
- **Message**
|
|
49
|
+
Represents a message with a role and content, used for interactions with LLMs.
|
|
50
|
+
|
|
51
|
+
### TwinLab Models
|
|
52
|
+
|
|
53
|
+
- **TwinLabModel**
|
|
54
|
+
Represents a model configuration including metadata.
|
|
55
|
+
- **save_model**
|
|
56
|
+
Function to persist a model configuration.
|
|
57
|
+
|
|
58
|
+
### Node Metadata
|
|
59
|
+
|
|
60
|
+
- **NodeInputInfo**
|
|
61
|
+
Describes the properties of a node's input.
|
|
62
|
+
- **NodeOutputInfo**
|
|
63
|
+
Describes the properties of a node's output.
|
|
64
|
+
- **NodeInfo**
|
|
65
|
+
Aggregates metadata for a node, including inputs and outputs.
|
|
66
|
+
|
|
67
|
+
### Sensor Design
|
|
68
|
+
|
|
69
|
+
- **SensorDesigner**
|
|
70
|
+
Defines sensor configuration and provides functionality to load sensor data.
|
|
71
|
+
- **save_sensor_designer**
|
|
72
|
+
Function to persist a sensor designer configuration.
|
|
73
|
+
|
|
74
|
+
### SQL Database Types
|
|
75
|
+
|
|
76
|
+
- **SQLDatabase**
|
|
77
|
+
Abstract base class for executing SQL queries.
|
|
78
|
+
- **PostgreSQL**
|
|
79
|
+
Implementation of SQLDatabase for PostgreSQL.
|
|
80
|
+
- **SQLKind**
|
|
81
|
+
Enum listing supported SQL database types.
|
|
82
|
+
- **SQLManager**
|
|
83
|
+
Manages connections and operations for SQL databases.
|
|
84
|
+
|
|
85
|
+
### Tabular Data
|
|
86
|
+
|
|
87
|
+
- **TabularData**
|
|
88
|
+
Represents CSV-based data and includes functionality to load it into a pandas DataFrame.
|
|
89
|
+
|
|
90
|
+
### Token Types
|
|
91
|
+
|
|
92
|
+
- **Token**
|
|
93
|
+
Enum representing token types, such as TRAINING and STANDARD.
|
|
94
|
+
|
|
95
|
+
### Vector Stores
|
|
96
|
+
|
|
97
|
+
- **VectorStoreConnection**
|
|
98
|
+
Abstract base class for vector store operations.
|
|
99
|
+
- **WeaviateVectorStoreConnection**
|
|
100
|
+
Implements a connection to a Weaviate vector store.
|
|
101
|
+
- **VectorStoreProvider**
|
|
102
|
+
Enum for supported vector store providers.
|
|
103
|
+
- **VectorStoreManager**
|
|
104
|
+
Manages connections to vector stores.
|
|
105
|
+
- **get_persistent_vector_store**
|
|
106
|
+
Function to establish a persistent connection to a Weaviate vector store.
|
|
107
|
+
- **get_embedding_function**
|
|
108
|
+
Retrieves an embedding function based on configuration, supporting both HuggingFace and OpenAI options.
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "uncertainty-engine-types"
|
|
3
|
+
version = "0.0.2"
|
|
4
|
+
description = "Common type definitions for the Uncertainty Engine"
|
|
5
|
+
authors = [
|
|
6
|
+
{ name = "Freddy Wordingham", email = "freddy@digilab.ai" },
|
|
7
|
+
{ name = "Jasper Cantwell", email = "jasper@digilab.ai" },
|
|
8
|
+
{ name = "Jamie Donald-McCann", email = "jamie.donald-mccann@digilab.ai" },
|
|
9
|
+
]
|
|
10
|
+
readme = "README.md"
|
|
11
|
+
|
|
12
|
+
[tool.poetry.dependencies]
|
|
13
|
+
pandas = "^2.2.3"
|
|
14
|
+
pydantic = "^2.10.5"
|
|
15
|
+
python = ">=3.10,<4.0"
|
|
16
|
+
typeguard = ">=2.13.3,<2.14.0"
|
|
17
|
+
|
|
18
|
+
[tool.poetry.group.dev.dependencies]
|
|
19
|
+
black = "^24.10.0"
|
|
20
|
+
flake8 = "^7.1.1"
|
|
21
|
+
mypy = "^1.13.0"
|
|
22
|
+
pylint = "^3.3.2"
|
|
23
|
+
pytest = "^8.3.4"
|
|
24
|
+
pytest-cov = "^6.0.0"
|
|
25
|
+
pytest-mock = "^3.14.0"
|
|
26
|
+
radon = "^6.0.1"
|
|
27
|
+
|
|
28
|
+
[build-system]
|
|
29
|
+
requires = ["poetry-core>=2.0.0,<3.0.0"]
|
|
30
|
+
build-backend = "poetry.core.masonry.api"
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""Public API for the uncertainty_engine_types package.

Re-exports the common type definitions so consumers can import everything
from the package root.
"""

from .execution_error import ExecutionError
from .graph import Graph, NodeElement, NodeId, SourceHandle, TargetHandle
from .handle import Handle
from .llm import LLM, LLMManager, LLMProvider
from .message import Message
from .model import TwinLabModel, save_model
from .node_info import NodeInfo, NodeInputInfo, NodeOutputInfo
from .sensor_designer import SensorDesigner, save_sensor_designer
from .sql import SQLDatabase, SQLKind, SQLManager
from .tabular_data import TabularData
from .token import Token
from .vector_store import VectorStoreConnection, VectorStoreManager, VectorStoreProvider

# Bug fix: "Message" was previously listed twice; the duplicate is removed.
__all__ = [
    "ExecutionError",
    "Graph",
    "Handle",
    "LLM",
    "LLMManager",
    "LLMProvider",
    "Message",
    "NodeElement",
    "NodeId",
    "NodeInfo",
    "NodeInputInfo",
    "NodeOutputInfo",
    "save_model",
    "save_sensor_designer",
    "SensorDesigner",
    "SourceHandle",
    "SQLDatabase",
    "SQLKind",
    "SQLManager",
    "TabularData",
    "TargetHandle",
    "Token",
    "TwinLabModel",
    "VectorStoreConnection",
    "VectorStoreManager",
    "VectorStoreProvider",
]
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from pydantic import BaseModel

from .handle import Handle


# Type aliases that make the Graph/NodeElement signatures self-documenting.
# NodeId uniquely names a node within a Graph; TargetHandle/SourceHandle name
# the input/output sockets referenced by a connection.
NodeId = str
TargetHandle = str
SourceHandle = str


class NodeElement(BaseModel):
    """A single node in a graph: its type plus the handles feeding its inputs."""

    # Node type discriminator (free-form string; semantics defined by consumers).
    type: str
    # Maps each input socket name to the upstream handle it is wired to.
    # Pydantic deep-copies field defaults, so the shared-mutable-{} pitfall
    # does not apply here.
    inputs: dict[TargetHandle, Handle] = {}


class Graph(BaseModel):
    """A collection of nodes keyed by their unique NodeId."""

    nodes: dict[NodeId, NodeElement]
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from pydantic import BaseModel


class Handle(BaseModel):
    """A node handle parsed from a ``node.handle`` string.

    Validates that the string contains exactly one dot, then splits it into
    the node name and handle name.
    """

    node_name: str
    node_handle: str

    def __init__(self, handle_str: str):
        # NOTE(review): overriding __init__ with a positional string replaces
        # pydantic's normal keyword constructor; code paths that re-invoke
        # __init__ with field kwargs (e.g. some copy/validate flows) would
        # break — confirm this is intended.
        if handle_str.count(".") != 1:
            raise ValueError(
                "Handle string must contain exactly one dot ('.') separating node and handle"
            )

        node_name, node_handle = handle_str.split(".")
        super().__init__(node_name=node_name, node_handle=node_handle)
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
from typeguard import typechecked
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@typechecked
class LLM(ABC):
    """Abstract base class for large language model backends.

    Subclasses supply a Langchain chat model via ``get_langchain_llm()``;
    ``run_query()`` uses it to answer a single prompt.
    """

    @property
    def temperature(self) -> float:
        """Sampling temperature in [0.0, 1.0]."""
        return self._temperature

    @temperature.setter
    def temperature(self, value: float) -> None:
        # Validate on assignment so an invalid temperature fails fast.
        if not 0.0 <= value <= 1.0:
            raise ValueError("Temperature must be between 0.0 and 1.0")
        self._temperature = value

    # Bug fix: the abstract signature previously declared a `query` parameter,
    # but no subclass accepts one and run_query() calls this with no
    # arguments. The return annotation "LLM" was also wrong — implementations
    # return a Langchain chat model, not an LLM subclass.
    @abstractmethod
    def get_langchain_llm(self):
        """
        Return a chat model object from Langchain according to the provider type.
        """
        pass

    def run_query(self, query: str) -> str:
        """
        Call the LLM with a query and return the response content.
        """
        llm = self.get_langchain_llm()
        return llm.invoke(query).content
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@typechecked
class OpenAILLM(LLM):
    """LLM implementation backed by OpenAI's chat API (via langchain-openai)."""

    def __init__(self, api_key: str, model: str, temperature: float = 0.0):
        self.api_key = api_key
        self.model = model
        # Goes through the LLM.temperature setter, which range-checks the value.
        self.temperature = temperature

    def get_langchain_llm(self):
        # Imported lazily so langchain-openai is only required when used.
        from langchain_openai import ChatOpenAI

        return ChatOpenAI(
            model=self.model,
            temperature=self.temperature,
            api_key=self.api_key,
        )
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@typechecked
class OllamaLLM(LLM):
    """LLM implementation backed by an Ollama server (via langchain-ollama)."""

    def __init__(self, url: str, model: str, temperature: float = 0.0):
        # Base URL of the Ollama server.
        self.url = url
        self.model = model
        # Goes through the LLM.temperature setter, which range-checks the value.
        self.temperature = temperature

    def get_langchain_llm(self):
        # Imported lazily so langchain-ollama is only required when used.
        from langchain_ollama import ChatOllama

        return ChatOllama(
            base_url=self.url,
            model=self.model,
            temperature=self.temperature,
        )
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@typechecked
|
|
73
|
+
class LLMProvider(Enum):
|
|
74
|
+
OPENAI = "openai"
|
|
75
|
+
OLLAMA = "ollama"
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@typechecked
class LLMManager(BaseModel):
    """
    Connection manager for large language models (LLMs).
    """

    # Base URL of the model server (used by the Ollama provider only).
    url: str
    # One of the LLMProvider values, e.g. "openai" or "ollama".
    provider: str
    model: str
    temperature: float = 0.0
    # Required for OpenAI; unused by Ollama.
    api_key: Optional[str] = None

    def connect(self) -> LLM:
        """
        Connect to the LLM.

        Returns:
            LLM: A provider-specific LLM instance.

        Raises:
            ValueError: If the provider is unknown, or OpenAI is selected
                without an API key.
        """

        match self.provider:
            case LLMProvider.OPENAI.value:
                if self.api_key is None:
                    raise ValueError("API key is required for OpenAI LLM")
                return OpenAILLM(
                    api_key=self.api_key,
                    model=self.model,
                    temperature=self.temperature,
                )
            case LLMProvider.OLLAMA.value:
                return OllamaLLM(
                    url=self.url,
                    model=self.model,
                    temperature=self.temperature,
                )
            case _:
                raise ValueError(f"Unknown LLM provider: {self.provider}")
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from typing import Any, Dict, Tuple
|
|
2
|
+
import json
|
|
3
|
+
import tempfile
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, ConfigDict
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TwinLabModel(BaseModel):
    """Serialized TwinLab model configuration plus metadata."""

    # String identifier resolved via twinlab_models.model_type_from_str.
    model_type: str
    config: dict
    metadata: dict

    # Allow field names starting with "model_" (pydantic reserves that prefix).
    model_config = ConfigDict(protected_namespaces=())

    def load_model(self) -> Tuple[Any, Dict]:
        """Reconstruct the TwinLab model object and its metadata from this config.

        Returns:
            Tuple[Any, Dict]: The loaded model object and its metadata.
        """

        # Imported lazily so twinlab_models is only required when used.
        from twinlab_models.models import model_type_from_str  # type: ignore

        model_type = self.model_type
        tl_model = model_type_from_str(model_type)

        # TwinLab's load() reads from a file path, so round-trip the config
        # through a temporary JSON file.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as f:
            # Write the JSON to the temporary file
            json.dump(self.model_dump(), f)
            f.flush()
            tl_model, meta_data = tl_model.load(f.name)

        return tl_model, meta_data
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# TODO: Should this be a method of the type?
def save_model(model, meta_data: dict) -> TwinLabModel:
    """Persist a TwinLab model object as a TwinLabModel configuration.

    Args:
        model: A TwinLab model object exposing ``save(path, meta_data)``.
        meta_data (dict): Metadata to store alongside the model.

    Returns:
        TwinLabModel: The serialized model configuration.
    """

    # model.save() writes JSON to the temp file's path; re-read it here.
    # NOTE(review): reopening a NamedTemporaryFile by name while it is open
    # is POSIX-only behavior — confirm Windows is not a target.
    with tempfile.NamedTemporaryFile(mode="r", suffix=".json") as f:
        model.save(f.name, meta_data)
        f.seek(0)
        config = json.load(f)

    return TwinLabModel(**config)
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from typing import Any, Optional
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class NodeInputInfo(BaseModel):
    """Metadata describing a single input of a node."""

    # Human-readable display name.
    label: str
    # Type name of the accepted value (free-form string).
    type: str
    description: str
    required: bool = True
    # Whether the value is set directly in the node (vs. wired from upstream).
    set_in_node: bool = True
    default: Optional[Any] = None
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class NodeOutputInfo(BaseModel):
    """Metadata describing a single output of a node."""

    # Human-readable display name.
    label: str
    # Type name of the produced value (free-form string).
    type: str
    description: str
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class NodeInfo(BaseModel):
    """Aggregated metadata for a node: identity, docs, cost, and its I/O."""

    id: str
    version: int
    label: str
    category: str
    description: str
    long_description: str
    # Cost in tokens to execute the node — presumably; TODO confirm unit.
    cost: int
    inputs: dict[str, NodeInputInfo]
    # Pydantic deep-copies field defaults, so the shared-{} pitfall does not apply.
    outputs: dict[str, NodeOutputInfo] = {}
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
import json
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SensorDesigner(BaseModel):
    """Serialized Bayesian experimental design (BED) configuration."""

    # JSON-compatible dict produced by BED.to_json (see save_sensor_designer).
    bed: dict

    def load_sensor_designer(self):
        """Reconstruct the BED object from the stored configuration."""

        # Imported lazily so twinlab_bed is only required when used.
        from twinlab_bed.BED import BED

        # BED.from_json expects a JSON string, so re-serialize the dict.
        bed = BED.from_json(json.dumps(self.bed))

        return bed
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def save_sensor_designer(bed) -> SensorDesigner:
    """Persist a BED object as a SensorDesigner configuration."""
    return SensorDesigner(bed=json.loads(bed.to_json()))
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
from typeguard import typechecked
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class SQLDatabase(ABC):
    """Abstract base class for executing SQL queries against a database."""

    @abstractmethod
    def execute(self, query: str) -> list[tuple[Any, ...]]:
        """
        Execute the given query.

        Parameters:
            query (str): The SQL query to execute.

        Returns:
            list[tuple[Any, ...]]: The rows returned by the query.
        """
        pass
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class PostgreSQL(SQLDatabase):
    """SQLDatabase implementation for PostgreSQL (via psycopg2)."""

    @typechecked
    def __init__(
        self, host: str, username: str, password: str, port: int, database: str
    ):
        # Imported lazily so psycopg2 is only required when used.
        import psycopg2

        # NOTE(review): the connection/cursor are never closed and writes are
        # never committed — confirm this class is intended for read-only use
        # with process-lifetime connections.
        self.connection = psycopg2.connect(
            host=host, user=username, password=password, port=port, database=database
        )
        self.cursor = self.connection.cursor()

    @typechecked
    def execute(self, query: str) -> list[tuple[Any, ...]]:
        """
        Execute the given query.

        Parameters:
            query (str): The SQL query to execute.

        Returns:
            list[tuple[Any, ...]]: The rows returned by the query.
        """

        self.cursor.execute(query)
        return self.cursor.fetchall()
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class SQLKind(Enum):
    """Supported SQL database kinds; values match the stored configuration."""

    POSTGRES = "POSTGRES"
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class SQLManager(BaseModel):
    """
    Connection manager for SQL databases.
    """

    # Which database implementation to connect to.
    kind: SQLKind
    host: str
    username: str
    password: str
    port: int
    database: str

    @typechecked
    def connect(self) -> SQLDatabase:
        """
        Connect to the SQL database.

        Returns:
            SQLDatabase: A connected, kind-specific database wrapper.

        Raises:
            ValueError: If the configured kind is not supported.
        """

        match self.kind:
            case SQLKind.POSTGRES:
                return PostgreSQL(
                    host=self.host,
                    username=self.username,
                    password=self.password,
                    port=self.port,
                    database=self.database,
                )
            case _:
                raise ValueError("Unsupported SQL kind")
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import Optional, Literal, List, Dict
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
from typeguard import typechecked
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class VectorStoreProvider(Enum):
    """Supported vector store providers; values match VectorStoreManager.provider."""

    WEAVIATE = "weaviate"
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class VectorStoreConnection(ABC):
    """Abstract base class for vector store operations."""

    @abstractmethod
    def ingest(self, texts, metadatas):
        """Add texts (with per-text metadata) to the store; return their IDs."""
        pass

    @abstractmethod
    def retrieve(self, query, k):
        """Return the k most relevant documents for the query."""
        pass

    @abstractmethod
    def get_vector_store(self):
        """Return the underlying vector store object."""
        pass

    @abstractmethod
    def close(self):
        """Release the underlying client/connection."""
        pass
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@typechecked
class WeaviateVectorStoreConnection(VectorStoreConnection):
    """VectorStoreConnection backed by a Weaviate deployment (via langchain-weaviate)."""

    def __init__(
        self,
        host: str,
        port: str,
        collection: str,
        embedding_type: str,
        embedding_model: str,
        embedding_api_key: str,
    ):
        # Connection/config parameters are kept for introspection; the actual
        # store is created eagerly below.
        self.host = host
        self.port = port
        self.collection = collection
        self.embedding_type = embedding_type
        self.embedding_model = embedding_model
        self.embedding_api_key = embedding_api_key
        self.vector_store = get_persistent_vector_store(
            host, port, collection, embedding_type, embedding_model, embedding_api_key
        )

    def ingest(self, texts: List[str], metadatas: List[Dict]) -> List[str]:
        """
        Ingest a list of texts into the vector store.

        Args:
            texts (List[str]): The texts to ingest.
            metadatas (List[Dict]): The metadata associated with each text.

        Returns:
            List[str]: The IDs of the ingested texts.

        """

        return self.vector_store.add_texts(texts=texts, metadatas=metadatas)

    def retrieve(self, query: Optional[str], k: int) -> List[Dict]:
        """
        Retrieve the k most relevant documents to a query, using maximal
        marginal relevance search.

        Args:
            query (Optional[str]): The query to retrieve documents for.
            k (int): The number of documents to retrieve.

        Returns:
            List[Dict]: The retrieved documents, each as a dict with
                "content" and "metadata" keys.

        """

        docs = self.vector_store.max_marginal_relevance_search(query=query, k=k)
        docs_list = [
            {"content": doc.page_content, "metadata": doc.metadata} for doc in docs
        ]
        return docs_list

    def get_vector_store(self):
        """
        Get the underlying vector store.

        Returns:
            WeaviateVectorStore: The underlying Langchain vector store.

        """

        return self.vector_store

    def close(self):
        """Close the Weaviate client held by the vector store."""
        # Reaches into a private attribute of the Langchain wrapper — no
        # public close API is used here.
        self.vector_store._client.close()
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class VectorStoreManager(BaseModel):
    """
    Connection manager for a vector store.
    """

    # One of the VectorStoreProvider values, e.g. "weaviate".
    provider: str
    host: str
    port: str = "8080"
    collection: str = "DefaultCollection"
    # Embedding backend: "huggingface" or "openai" (see get_embedding_function).
    embedding_type: str
    embedding_model: str
    embedding_api_key: str

    @typechecked
    def connect(self) -> WeaviateVectorStoreConnection:
        """
        Connect to the vector store.

        Returns:
            VectorStoreConnection: The vector store connection.

        Raises:
            ValueError: If the provider is unknown.
        """

        match self.provider:
            case VectorStoreProvider.WEAVIATE.value:
                return WeaviateVectorStoreConnection(
                    host=self.host,
                    port=self.port,
                    collection=self.collection,
                    embedding_type=self.embedding_type,
                    embedding_model=self.embedding_model,
                    embedding_api_key=self.embedding_api_key,
                )
            case _:
                raise ValueError(f"Unknown vector store provider: {self.provider}")
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
@typechecked
def get_persistent_vector_store(
    host: str,
    port: str,
    collection: str,
    embedding_type: str,
    embedding_model: str,
    embedding_api_key: Optional[str] = None,
):
    """
    Get a database client connected to a deployed Weaviate vector store.

    Args:
        host (str): Host of the Weaviate deployment (used for HTTP and gRPC).
        port (str): HTTP port of the deployment.
        collection (str): Collection name; created if it does not exist.
        embedding_type (str): "huggingface" or "openai" (see get_embedding_function).
        embedding_model (str): Embedding model name.
        embedding_api_key (Optional[str]): API key for OpenAI embeddings.

    Returns:
        WeaviateVectorStore: A connected Langchain vector store.

    Raises:
        ValueError: If connecting to Weaviate or initializing the store fails.
    """

    # Imported lazily so the Weaviate stack is only required when used.
    from langchain_weaviate import WeaviateVectorStore
    from weaviate.connect import ConnectionParams
    import weaviate

    try:
        client = weaviate.WeaviateClient(
            connection_params=ConnectionParams.from_params(
                http_host=host,
                http_port=port,
                http_secure=False,
                grpc_host=host,
                # gRPC port is Weaviate's default; not currently configurable here.
                grpc_port="50051",
                grpc_secure=False,
            ),
            skip_init_checks=True,
        )
        client.connect()
    except Exception as e:
        # Bug fix: chain the original exception (`from e`) so the underlying
        # connection failure is preserved — consistent with
        # get_embedding_function's error handling.
        raise ValueError(f"Failed to connect to Weaviate: {e}") from e

    embedding_function = get_embedding_function(
        embedding_type, embedding_model, embedding_api_key
    )

    # check if a collection exists and makes one if it doesn't
    if not client.collections.exists(collection):
        client.collections.create(collection)

    try:
        vector_store = WeaviateVectorStore(
            client=client,
            index_name=collection,
            text_key="text",
            embedding=embedding_function,
        )
    except Exception as e:
        raise ValueError(f"Failed to initialize Weaviate vector store: {e}") from e

    return vector_store
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
@typechecked
def get_embedding_function(
    embedding_type: Literal["huggingface", "openai"],
    embedding_model: str,
    embedding_api_key: Optional[str] = None,
):
    """
    Get an embedding function based on the specified type and configuration.

    Args:
        embedding_type: Which embedding backend to use.
        embedding_model (str): Model name for the chosen backend.
        embedding_api_key (Optional[str]): API key; required for OpenAI only.

    Returns:
        A Langchain embeddings object for the chosen backend.

    Raises:
        ValueError: If the backend fails to initialize, the type is unknown,
            or OpenAI is requested without an API key.
    """

    # Imported lazily so the embedding backends are only required when used.
    from langchain_huggingface import HuggingFaceEmbeddings
    from langchain_openai import OpenAIEmbeddings

    match embedding_type:
        case "huggingface":
            try:
                return HuggingFaceEmbeddings(model_name=embedding_model)
            except Exception as e:
                raise ValueError(
                    f"Failed to initialize HuggingFace embedding model: {str(e)}"
                ) from e

        case "openai":
            if not embedding_api_key:
                raise ValueError("OpenAI embeddings require an API key")
            try:
                return OpenAIEmbeddings(
                    model=embedding_model, api_key=embedding_api_key
                )
            except Exception as e:
                raise ValueError(
                    f"Failed to initialize OpenAI embedding model: {str(e)}"
                ) from e

        case _:
            # Unreachable when @typechecked enforces the Literal at runtime,
            # but kept as a defensive fallback.
            raise ValueError(
                f"Embedding type must be one of ['huggingface', 'openai']. Got {embedding_type}"
            )
|