langchain-timbr 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,187 @@
1
+ from typing import Optional, Union, Dict, Any
2
+ from langchain.chains.base import Chain
3
+ from langchain.llms.base import LLM
4
+
5
+ from ..utils.general import parse_list, to_integer, to_boolean
6
+ from ..utils.timbr_llm_utils import generate_sql
7
+ from ..utils.timbr_utils import validate_sql
8
+
9
+
10
class ValidateTimbrSqlChain(Chain):
    """
    LangChain chain for validating SQL queries against Timbr knowledge graph schemas.

    The chain sends the given SQL to Timbr for validation (``validate_sql``). When Timbr
    rejects it, the LLM is asked to generate a new query for the original prompt
    (``generate_sql`` with ``should_validate_sql=True``). The rejected SQL and the
    validation error are appended to the prompt note so the LLM can take them into
    account. Timbr is reached via URL and token.
    """

    def __init__(
        self,
        llm: LLM,
        url: str,
        token: str,
        ontology: str,
        schema: Optional[str] = 'dtimbr',
        concept: Optional[str] = None,
        retries: Optional[int] = 3,
        concepts_list: Optional[Union[list[str], str]] = None,
        views_list: Optional[Union[list[str], str]] = None,
        include_logic_concepts: Optional[bool] = False,
        include_tags: Optional[Union[list[str], str]] = None,
        exclude_properties: Optional[Union[list[str], str]] = ['entity_id', 'entity_type', 'entity_label'],
        max_limit: Optional[int] = 500,
        note: Optional[str] = '',
        db_is_case_sensitive: Optional[bool] = False,
        graph_depth: Optional[int] = 1,
        verify_ssl: Optional[bool] = True,
        is_jwt: Optional[bool] = False,
        jwt_tenant_id: Optional[str] = None,
        conn_params: Optional[dict] = None,
        debug: Optional[bool] = False,
        **kwargs,
    ):
        """
        :param llm: An LLM instance or a function that takes a prompt string and returns the LLM’s response
        :param url: Timbr server url
        :param token: Timbr password or token value
        :param ontology: The name of the ontology/knowledge graph
        :param schema: The name of the schema to query
        :param concept: The name of the concept to query
        :param retries: The maximum number of retries to attempt
        :param concepts_list: Optional specific concept options to query
        :param views_list: Optional specific view options to query
        :param include_logic_concepts: Optional boolean to include logic concepts (concepts without unique properties which only inherits from an upper level concept with filter logic) in the query.
        :param include_tags: Optional specific concepts & properties tag options to use in the query (Disabled by default. Use '*' to enable all tags or a string represents a list of tags divided by commas (e.g. 'tag1,tag2')
        :param exclude_properties: Optional specific properties to exclude from the query (entity_id, entity_type & entity_label by default).
        :param max_limit: Maximum number of rows to query
        :param note: Optional additional note to extend our llm prompt
        :param db_is_case_sensitive: Whether the database is case sensitive (default is False).
        :param graph_depth: Maximum number of relationship hops to traverse from the source concept during schema exploration (default is 1).
        :param verify_ssl: Whether to verify SSL certificates (default is True).
        :param is_jwt: Whether to use JWT authentication (default is False).
        :param jwt_tenant_id: JWT tenant ID for multi-tenant environments (required when is_jwt=True).
        :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
        :param debug: Debug flag forwarded to ``generate_sql`` when the SQL has to be regenerated (default is False).
        :param kwargs: Additional arguments to pass to the base ``Chain`` constructor

        ## Example
        ```
        validate_timbr_sql_chain = ValidateTimbrSqlChain(
            url=<url>,
            token=<token>,
            llm=<llm or timbr_llm_wrapper instance>,
            ontology=<ontology_name>,
            schema=<schema_name>,
            concept=<concept_name>,
            retries=<retries_number>,
            concepts_list=<concepts>,
            views_list=<views>,
            include_tags=<tags>,
            note=<note>,
        )

        return validate_timbr_sql_chain.invoke({ "prompt": question, "sql": <latest_query_to_validate> }).get("sql", "")
        ```
        """
        super().__init__(**kwargs)
        self._llm = llm
        self._url = url
        self._token = token
        self._ontology = ontology
        self._schema = schema
        self._concept = concept
        self._retries = retries
        self._concepts_list = parse_list(concepts_list)
        self._views_list = parse_list(views_list)
        self._include_logic_concepts = to_boolean(include_logic_concepts)
        self._include_tags = parse_list(include_tags)
        # The default ``exclude_properties`` list is evaluated once and shared by every
        # call; copy it so this instance (and anything it hands the list to) can never
        # alias -- and accidentally mutate -- the shared default.
        if isinstance(exclude_properties, list):
            exclude_properties = list(exclude_properties)
        self._exclude_properties = parse_list(exclude_properties)
        self._max_limit = to_integer(max_limit)
        self._note = note
        self._db_is_case_sensitive = to_boolean(db_is_case_sensitive)
        self._graph_depth = to_integer(graph_depth)
        self._verify_ssl = to_boolean(verify_ssl)
        self._is_jwt = to_boolean(is_jwt)
        self._jwt_tenant_id = jwt_tenant_id
        self._debug = to_boolean(debug)
        self._conn_params = conn_params or {}


    @property
    def usage_metadata_key(self) -> str:
        """Output key under which LLM usage metadata is returned."""
        return "validate_sql_usage_metadata"


    @property
    def input_keys(self) -> list:
        """Keys required in the ``inputs`` dict passed to the chain."""
        return ["prompt", "sql"]


    @property
    def output_keys(self) -> list:
        """Keys present in the dict returned by ``_call``."""
        return [
            "sql",
            "schema",
            "concept",
            "is_sql_valid",
            "error",
            self.usage_metadata_key,
        ]


    def _get_conn_params(self) -> dict:
        """
        Build the Timbr connection parameters for a request.

        Entries from the user-supplied ``conn_params`` are merged last, so they override
        the standard keys when names collide.
        """
        return {
            "url": self._url,
            "token": self._token,
            "ontology": self._ontology,
            "verify_ssl": self._verify_ssl,
            "is_jwt": self._is_jwt,
            "jwt_tenant_id": self._jwt_tenant_id,
            **self._conn_params,
        }


    def _call(self, inputs: Dict[str, Any], run_manager=None) -> Dict[str, Any]:
        """
        Validate ``inputs["sql"]`` with Timbr and regenerate it from ``inputs["prompt"]`` if invalid.

        :param inputs: Mapping containing the ``"prompt"`` and ``"sql"`` keys.
        :param run_manager: Accepted for LangChain compatibility; not used by this chain.
        :return: Dict with the keys listed in ``output_keys``. When the original SQL is
            valid, it is returned unchanged with empty usage metadata.
        """
        usage_metadata = {}
        sql = inputs["sql"]
        prompt = inputs["prompt"]
        schema = self._schema
        concept = self._concept

        is_sql_valid, error = validate_sql(sql, self._get_conn_params())
        if not is_sql_valid:
            prompt_extension = self._note + '\n' if self._note else ""
            generate_res = generate_sql(
                question=prompt,
                llm=self._llm,
                conn_params=self._get_conn_params(),
                schema=schema,
                concept=concept,
                concepts_list=self._concepts_list,
                views_list=self._views_list,
                include_logic_concepts=self._include_logic_concepts,
                include_tags=self._include_tags,
                exclude_properties=self._exclude_properties,
                should_validate_sql=True,
                retries=self._retries,
                max_limit=self._max_limit,
                # Feed the rejected SQL and Timbr's error back to the LLM.
                note=f"{prompt_extension}The original SQL query (`{sql}`) was invalid with this error from query {error}. Please take this in consideration while generating the query.",
                db_is_case_sensitive=self._db_is_case_sensitive,
                graph_depth=self._graph_depth,
                debug=self._debug,
            )
            sql = generate_res.get("sql", "")
            schema = generate_res.get("schema", self._schema)
            concept = generate_res.get("concept", self._concept)
            usage_metadata.update(generate_res.get("usage_metadata", {}))
            is_sql_valid = generate_res.get("is_sql_valid")
            error = generate_res.get("error")

        return {
            "sql": sql,
            "schema": schema,
            "concept": concept,
            "is_sql_valid": is_sql_valid,
            "error": error,
            self.usage_metadata_key: usage_metadata,
        }
@@ -0,0 +1,13 @@
1
+ from .identify_concept_node import IdentifyConceptNode
2
+ from .generate_timbr_sql_node import GenerateTimbrSqlNode
3
+ from .validate_timbr_query_node import ValidateSemanticSqlNode
4
+ from .execute_timbr_query_node import ExecuteSemanticQueryNode
5
+ from .generate_response_node import GenerateResponseNode
6
+
7
# Public API of the nodes package: the LangGraph node wrappers around the Timbr chains.
__all__ = [
    "IdentifyConceptNode", "GenerateTimbrSqlNode", "ValidateSemanticSqlNode",
    "ExecuteSemanticQueryNode", "GenerateResponseNode",
]
@@ -0,0 +1,108 @@
1
+ from typing import Optional, Union
2
+ from langchain.llms.base import LLM
3
+ from langgraph.graph import StateGraph
4
+
5
+ from ..langchain.execute_timbr_query_chain import ExecuteTimbrQueryChain
6
+
7
+
8
class ExecuteSemanticQueryNode:
    """
    Node that wraps ExecuteTimbrQueryChain functionality.

    ``run`` reads the user prompt from the graph state and invokes the wrapped chain with
    it. The prompt is the content of the last entry in ``state.messages`` when available;
    otherwise it is the ``"prompt"`` key or attribute.

    NOTE(review): this node was previously described as expecting the SQL (and optionally
    the concept) in the payload, but ``run`` only forwards the prompt to the chain.
    Confirm which inputs ExecuteTimbrQueryChain actually needs.
    """
    def __init__(
        self,
        llm: LLM,
        url: Optional[str] = None,
        token: Optional[str] = None,
        ontology: Optional[str] = None,
        schema: Optional[str] = None,
        concept: Optional[str] = None,
        concepts_list: Optional[Union[list[str], str]] = None,
        views_list: Optional[Union[list[str], str]] = None,
        include_logic_concepts: Optional[bool] = False,
        include_tags: Optional[Union[list[str], str]] = None,
        exclude_properties: Optional[Union[list[str], str]] = ['entity_id', 'entity_type', 'entity_label'],
        should_validate_sql: Optional[bool] = False,
        retries: Optional[int] = 3,
        max_limit: Optional[int] = 500,
        retry_if_no_results: Optional[bool] = False,
        no_results_max_retries: Optional[int] = 2,
        note: Optional[str] = '',
        db_is_case_sensitive: Optional[bool] = False,
        graph_depth: Optional[int] = 1,
        verify_ssl: Optional[bool] = True,
        is_jwt: Optional[bool] = False,
        jwt_tenant_id: Optional[str] = None,
        conn_params: Optional[dict] = None,
        debug: Optional[bool] = False,
        **kwargs,
    ):
        """
        :param llm: An LLM instance or a function that takes a prompt string and returns the LLM’s response
        :param url: Timbr server url
        :param token: Timbr password or token value
        :param ontology: The name of the ontology/knowledge graph
        :param schema: The name of the schema to query
        :param concept: The name of the concept to query
        :param concepts_list: Optional specific concept options to query
        :param views_list: Optional specific view options to query
        :param include_logic_concepts: Optional boolean to include logic concepts (concepts without unique properties which only inherits from an upper level concept with filter logic) in the query.
        :param include_tags: Optional specific concepts & properties tag options to use in the query (Disabled by default. Use '*' to enable all tags or a string represents a list of tags divided by commas (e.g. 'tag1,tag2')
        :param exclude_properties: Optional specific properties to exclude from the query (entity_id, entity_type & entity_label by default).
        :param should_validate_sql: Whether to validate the SQL before executing it
        :param retries: Number of retry attempts if the generated SQL is invalid
        :param max_limit: Maximum number of rows to return
        :param retry_if_no_results: Whether to infer the result value from the SQL query. If the query won't return any rows, it will try to re-generate the SQL query then re-run it.
        :param no_results_max_retries: Number of retry attempts to infer the result value from the SQL query
        :param note: Optional additional note to extend our llm prompt
        :param db_is_case_sensitive: Whether the database is case sensitive (default is False).
        :param graph_depth: Maximum number of relationship hops to traverse from the source concept during schema exploration (default is 1).
        :param verify_ssl: Whether to verify SSL certificates (default is True).
        :param is_jwt: Whether to use JWT authentication (default is False).
        :param jwt_tenant_id: JWT tenant ID for multi-tenant environments (required when is_jwt=True).
        :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
        :param debug: Debug flag forwarded to ExecuteTimbrQueryChain (default is False).
        :param kwargs: Additional arguments forwarded to ExecuteTimbrQueryChain.
        """
        self.chain = ExecuteTimbrQueryChain(
            llm=llm,
            url=url,
            token=token,
            ontology=ontology,
            schema=schema,
            concept=concept,
            concepts_list=concepts_list,
            views_list=views_list,
            include_logic_concepts=include_logic_concepts,
            include_tags=include_tags,
            exclude_properties=exclude_properties,
            should_validate_sql=should_validate_sql,
            retries=retries,
            max_limit=max_limit,
            retry_if_no_results=retry_if_no_results,
            no_results_max_retries=no_results_max_retries,
            note=note,
            db_is_case_sensitive=db_is_case_sensitive,
            graph_depth=graph_depth,
            verify_ssl=verify_ssl,
            is_jwt=is_jwt,
            jwt_tenant_id=jwt_tenant_id,
            conn_params=conn_params,
            debug=debug,
            **kwargs,
        )


    def run(self, state: StateGraph) -> dict:
        """
        Extract the prompt from ``state`` and invoke the wrapped chain with it.

        :param state: Graph state. This is either an object exposing ``messages`` (whose
            last item has ``content``), or a mapping/object providing a ``prompt`` value.
        :return: The wrapped chain's output dict (presumably the query rows along with the
            SQL and concept -- see ExecuteTimbrQueryChain).
        """
        try:
            prompt = state.messages[-1].content if state.messages[-1] else None
        except Exception:
            # No usable ``messages`` (e.g. a plain dict state, or an empty message list):
            # fall back to an explicit prompt. ``except Exception`` (rather than a bare
            # ``except``) lets KeyboardInterrupt/SystemExit propagate.
            getter = getattr(state, "get", None)
            prompt = getter('prompt', None) if callable(getter) else getattr(state, 'prompt', None)

        return self.chain.invoke({ "prompt": prompt })


    def __call__(self, state: dict) -> dict:
        """Allow the node to be used directly as a LangGraph node callable."""
        return self.run(state)
108
+
@@ -0,0 +1,59 @@
1
+ from typing import Optional, Union
2
+ from langchain.llms.base import LLM
3
+ from langgraph.graph import StateGraph
4
+
5
+ from ..langchain import GenerateAnswerChain
6
+
7
+
8
class GenerateResponseNode:
    """
    Node that wraps GenerateAnswerChain functionality, which generates an answer based on a given prompt and rows of data.
    It uses the LLM to build a human-readable answer.

    The Timbr URL and token (plus SSL/JWT/connection options) are passed through to
    GenerateAnswerChain.
    """
    def __init__(
        self,
        llm: LLM,
        url: str,
        token: str,
        verify_ssl: Optional[bool] = True,
        is_jwt: Optional[bool] = False,
        jwt_tenant_id: Optional[str] = None,
        conn_params: Optional[dict] = None,
        debug: Optional[bool] = False,
        **kwargs,
    ):
        """
        :param llm: An LLM instance or a function that takes a prompt string and returns the LLM’s response
        :param url: Timbr server url
        :param token: Timbr password or token value
        :param verify_ssl: Whether to verify SSL certificates (default is True).
        :param is_jwt: Whether to use JWT authentication (default is False).
        :param jwt_tenant_id: JWT tenant ID for multi-tenant environments (required when is_jwt=True).
        :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
        :param debug: Debug flag forwarded to GenerateAnswerChain (default is False).
        :param kwargs: Additional arguments forwarded to GenerateAnswerChain.
        """
        self.chain = GenerateAnswerChain(
            llm=llm,
            url=url,
            token=token,
            verify_ssl=verify_ssl,
            is_jwt=is_jwt,
            jwt_tenant_id=jwt_tenant_id,
            conn_params=conn_params,
            debug=debug,
            **kwargs,
        )


    def run(self, state: dict) -> dict:
        """
        Read ``prompt``, ``rows`` and ``sql`` from ``state`` and invoke the wrapped chain.

        Mapping states (anything with a callable ``get``, e.g. a dict) are read with
        ``get``. Other states are read by attribute, matching the attribute-style states
        accepted by the sibling nodes. Missing values default to ``""``.

        :param state: Graph state holding the prompt, the query rows and the SQL.
        :return: The wrapped chain's output dict.
        """
        getter = getattr(state, "get", None)
        if callable(getter):
            sql = getter("sql", "")
            rows = getter("rows", "")
            prompt = getter("prompt", "")
        else:
            sql = getattr(state, "sql", "")
            rows = getattr(state, "rows", "")
            prompt = getattr(state, "prompt", "")

        return self.chain.invoke({ "prompt": prompt, "rows": rows, "sql": sql })


    def __call__(self, state: dict) -> dict:
        """Allow the node to be used directly as a LangGraph node callable."""
        return self.run(state)
59
+
@@ -0,0 +1,98 @@
1
+ from typing import Optional, Union
2
+ from langchain.llms.base import LLM
3
+ from langgraph.graph import StateGraph
4
+
5
+ from ..langchain.generate_timbr_sql_chain import GenerateTimbrSqlChain
6
+
7
class GenerateTimbrSqlNode:
    """
    Node that wraps GenerateTimbrSqlChain functionality.

    ``run`` takes the prompt from the content of the last entry in ``state.messages``
    when the state exposes messages. Otherwise it uses the ``"prompt"`` key or attribute.
    The prompt is then passed to the wrapped chain.
    """
    def __init__(
        self,
        llm: LLM,
        url: Optional[str] = None,
        token: Optional[str] = None,
        ontology: Optional[str] = None,
        schema: Optional[str] = None,
        concept: Optional[str] = None,
        concepts_list: Optional[Union[list[str], str]] = None,
        views_list: Optional[Union[list[str], str]] = None,
        include_logic_concepts: Optional[bool] = False,
        include_tags: Optional[Union[list[str], str]] = None,
        exclude_properties: Optional[Union[list[str], str]] = ['entity_id', 'entity_type', 'entity_label'],
        should_validate_sql: Optional[bool] = False,
        retries: Optional[int] = 3,
        max_limit: Optional[int] = 500,
        note: Optional[str] = '',
        db_is_case_sensitive: Optional[bool] = False,
        graph_depth: Optional[int] = 1,
        verify_ssl: Optional[bool] = True,
        is_jwt: Optional[bool] = False,
        jwt_tenant_id: Optional[str] = None,
        conn_params: Optional[dict] = None,
        debug: Optional[bool] = False,
        **kwargs,
    ):
        """
        :param llm: An LLM instance or a function that takes a prompt string and returns the LLM’s response
        :param url: Timbr server url
        :param token: Timbr password or token value
        :param ontology: The name of the ontology/knowledge graph
        :param schema: The name of the schema to query
        :param concept: The name of the concept to query
        :param concepts_list: Optional specific concept options to query
        :param views_list: Optional specific view options to query
        :param include_logic_concepts: Optional boolean to include logic concepts (concepts without unique properties which only inherits from an upper level concept with filter logic) in the query.
        :param include_tags: Optional specific concepts & properties tag options to use in the query (Disabled by default. Use '*' to enable all tags or a string represents a list of tags divided by commas (e.g. 'tag1,tag2')
        :param exclude_properties: Optional specific properties to exclude from the query (entity_id, entity_type & entity_label by default).
        :param should_validate_sql: Whether to validate the SQL before executing it
        :param retries: Number of retry attempts if the generated SQL is invalid
        :param max_limit: Maximum number of rows to query
        :param note: Optional additional note to extend our llm prompt
        :param db_is_case_sensitive: Whether the database is case sensitive (default is False).
        :param graph_depth: Maximum number of relationship hops to traverse from the source concept during schema exploration (default is 1).
        :param verify_ssl: Whether to verify SSL certificates (default is True).
        :param is_jwt: Whether to use JWT authentication (default: False)
        :param jwt_tenant_id: Tenant ID for JWT authentication when using multi-tenant setup
        :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
        :param debug: Debug flag forwarded to GenerateTimbrSqlChain (default: False)
        :param kwargs: Additional arguments forwarded to GenerateTimbrSqlChain.
        """
        self.chain = GenerateTimbrSqlChain(
            llm=llm,
            url=url,
            token=token,
            ontology=ontology,
            schema=schema,
            concept=concept,
            concepts_list=concepts_list,
            views_list=views_list,
            include_logic_concepts=include_logic_concepts,
            include_tags=include_tags,
            exclude_properties=exclude_properties,
            should_validate_sql=should_validate_sql,
            retries=retries,
            max_limit=max_limit,
            note=note,
            db_is_case_sensitive=db_is_case_sensitive,
            graph_depth=graph_depth,
            verify_ssl=verify_ssl,
            is_jwt=is_jwt,
            jwt_tenant_id=jwt_tenant_id,
            conn_params=conn_params,
            debug=debug,
            **kwargs,
        )


    def run(self, state: StateGraph) -> dict:
        """
        Extract the prompt from ``state`` and invoke the wrapped chain with it.

        :param state: Graph state. This is either an object exposing ``messages`` (whose
            last item has ``content``), or a mapping/object providing a ``prompt`` value.
            An empty ``messages`` list yields a ``None`` prompt.
        :return: The wrapped chain's output dict.
        """
        try:
            prompt = state.messages[-1].content if (state.messages and state.messages[-1]) else None
        except Exception:
            # No ``messages`` attribute (e.g. a plain dict state): fall back to an explicit
            # prompt. ``except Exception`` (rather than a bare ``except``) lets
            # KeyboardInterrupt/SystemExit propagate.
            getter = getattr(state, "get", None)
            prompt = getter('prompt', None) if callable(getter) else getattr(state, 'prompt', None)

        return self.chain.invoke({ "prompt": prompt })


    def __call__(self, state: dict) -> dict:
        """Allow the node to be used directly as a LangGraph node callable."""
        return self.run(state)
@@ -0,0 +1,78 @@
1
+ from typing import Optional, Union
2
+ from langchain.llms.base import LLM
3
+ from langgraph.graph import StateGraph
4
+
5
+ from ..langchain.identify_concept_chain import IdentifyTimbrConceptChain
6
+
7
+
8
class IdentifyConceptNode:
    """
    Node that wraps IdentifyTimbrConceptChain functionality.

    ``run`` takes the prompt from the content of the last entry in ``state.messages``
    when available. Otherwise it uses the ``"prompt"`` key or attribute. The prompt is
    then passed to the wrapped chain.
    """
    def __init__(
        self,
        llm: LLM,
        url: Optional[str] = None,
        token: Optional[str] = None,
        ontology: Optional[str] = None,
        concepts_list: Optional[Union[list[str], str]] = None,
        views_list: Optional[Union[list[str], str]] = None,
        include_logic_concepts: Optional[bool] = False,
        include_tags: Optional[Union[list[str], str]] = None,
        should_validate: Optional[bool] = False,
        retries: Optional[int] = 3,
        note: Optional[str] = None,
        verify_ssl: Optional[bool] = True,
        is_jwt: Optional[bool] = False,
        jwt_tenant_id: Optional[str] = None,
        conn_params: Optional[dict] = None,
        debug: Optional[bool] = False,
        **kwargs,
    ):
        """
        :param llm: An LLM instance or a function that takes a prompt string and returns the LLM’s response
        :param url: Timbr server url
        :param token: Timbr password or token value
        :param ontology: The name of the ontology/knowledge graph
        :param concepts_list: Optional specific concept options to query
        :param views_list: Optional specific view options to query
        :param include_logic_concepts: Optional boolean to include logic concepts (concepts without unique properties which only inherits from an upper level concept with filter logic) in the query.
        :param include_tags: Optional specific concepts & properties tag options to use in the query (Disabled by default. Use '*' to enable all tags or a string represents a list of tags divided by commas (e.g. 'tag1,tag2')
        :param should_validate: Whether to validate the identified concept before returning it
        :param retries: Number of retry attempts if the identified concept is invalid
        :param note: Optional additional note to extend our llm prompt
        :param verify_ssl: Whether to verify SSL certificates
        :param is_jwt: Whether to use JWT authentication (default: False)
        :param jwt_tenant_id: Tenant ID for JWT authentication when using multi-tenant setup
        :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
        :param debug: Debug flag forwarded to IdentifyTimbrConceptChain (default: False)
        :param kwargs: Additional arguments forwarded to IdentifyTimbrConceptChain.
        """
        self.chain = IdentifyTimbrConceptChain(
            llm=llm,
            url=url,
            token=token,
            ontology=ontology,
            concepts_list=concepts_list,
            views_list=views_list,
            include_logic_concepts=include_logic_concepts,
            include_tags=include_tags,
            should_validate=should_validate,
            retries=retries,
            note=note,
            verify_ssl=verify_ssl,
            is_jwt=is_jwt,
            jwt_tenant_id=jwt_tenant_id,
            conn_params=conn_params,
            debug=debug,
            **kwargs,
        )


    def run(self, state: StateGraph) -> dict:
        """
        Extract the prompt from ``state`` and invoke the wrapped chain with it.

        :param state: Graph state. This is either an object exposing ``messages`` (whose
            last item has ``content``), or a mapping/object providing a ``prompt`` value.
        :return: The wrapped chain's output dict.
        """
        try:
            prompt = state.messages[-1].content if state.messages[-1] else None
        except Exception:
            # No usable ``messages`` (e.g. a plain dict state, or an empty message list):
            # fall back to an explicit prompt. ``except Exception`` (rather than a bare
            # ``except``) lets KeyboardInterrupt/SystemExit propagate.
            getter = getattr(state, "get", None)
            prompt = getter('prompt', None) if callable(getter) else getattr(state, 'prompt', None)

        return self.chain.invoke({ "prompt": prompt })


    def __call__(self, state: dict) -> dict:
        """Allow the node to be used directly as a LangGraph node callable."""
        return self.run(state)
78
+
@@ -0,0 +1,100 @@
1
+ from typing import Optional, Union
2
+ from langchain.llms.base import LLM
3
+ from langgraph.graph import StateGraph
4
+
5
+ from ..langchain.validate_timbr_sql_chain import ValidateTimbrSqlChain
6
+
7
+
8
class ValidateSemanticSqlNode:
    """
    Node that wraps ValidateTimbrSqlChain functionality.

    Reads both ``sql`` and ``prompt`` from the state; a missing value is passed as ``None``.
    Returns the output of ValidateTimbrSqlChain, whose keys are "sql", "schema",
    "concept", "is_sql_valid", "error" and the chain's usage-metadata key.
    """
    def __init__(
        self,
        llm: LLM,
        url: str,
        token: str,
        ontology: str,
        schema: Optional[str] = None,
        concept: Optional[str] = None,
        retries: Optional[int] = 3,
        concepts_list: Optional[Union[list[str], str]] = None,
        views_list: Optional[Union[list[str], str]] = None,
        include_logic_concepts: Optional[bool] = False,
        include_tags: Optional[Union[list[str], str]] = None,
        exclude_properties: Optional[Union[list[str], str]] = ['entity_id', 'entity_type', 'entity_label'],
        max_limit: Optional[int] = 500,
        note: Optional[str] = None,
        db_is_case_sensitive: Optional[bool] = False,
        graph_depth: Optional[int] = 1,
        verify_ssl: Optional[bool] = True,
        is_jwt: Optional[bool] = False,
        jwt_tenant_id: Optional[str] = None,
        conn_params: Optional[dict] = None,
        debug: Optional[bool] = False,
        **kwargs,
    ):
        """
        :param llm: An LLM instance or a function that takes a prompt string and returns the LLM’s response
        :param url: Timbr server url
        :param token: Timbr password or token value
        :param ontology: The name of the ontology/knowledge graph
        :param schema: The name of the schema to query
        :param concept: The name of the concept to query
        :param retries: The maximum number of retries to attempt
        :param concepts_list: Optional specific concept options to query
        :param views_list: Optional specific view options to query
        :param include_logic_concepts: Optional boolean to include logic concepts (concepts without unique properties which only inherits from an upper level concept with filter logic) in the query.
        :param include_tags: Optional specific concepts & properties tag options to use in the query (Disabled by default. Use '*' to enable all tags or a string represents a list of tags divided by commas (e.g. 'tag1,tag2')
        :param exclude_properties: Optional specific properties to exclude from the query (entity_id, entity_type & entity_label by default).
        :param max_limit: Maximum number of rows to query
        :param note: Optional additional note to extend our llm prompt
        :param db_is_case_sensitive: Whether the database is case sensitive (default is False).
        :param graph_depth: Maximum number of relationship hops to traverse from the source concept during schema exploration (default is 1).
        :param verify_ssl: Whether to verify SSL certificates (default is True).
        :param is_jwt: Whether to use JWT authentication (default: False)
        :param jwt_tenant_id: Tenant ID for JWT authentication when using multi-tenant setup
        :param conn_params: Extra Timbr connection parameters sent with every request (e.g., 'x-api-impersonate-user').
        :param debug: Debug flag forwarded to ValidateTimbrSqlChain (default: False)
        :param kwargs: Additional arguments forwarded to ValidateTimbrSqlChain.
        """
        self.chain = ValidateTimbrSqlChain(
            llm=llm,
            url=url,
            token=token,
            ontology=ontology,
            schema=schema,
            concept=concept,
            retries=retries,
            concepts_list=concepts_list,
            views_list=views_list,
            include_logic_concepts=include_logic_concepts,
            include_tags=include_tags,
            exclude_properties=exclude_properties,
            max_limit=max_limit,
            note=note,
            db_is_case_sensitive=db_is_case_sensitive,
            graph_depth=graph_depth,
            verify_ssl=verify_ssl,
            is_jwt=is_jwt,
            jwt_tenant_id=jwt_tenant_id,
            conn_params=conn_params,
            debug=debug,
            **kwargs,
        )


    def run(self, state: StateGraph) -> dict:
        """
        Read ``sql`` and ``prompt`` from ``state`` and invoke the wrapped chain.

        :param state: Graph state exposing ``sql``/``prompt`` either as attributes or via a
            mapping-style ``get``.
        :return: The wrapped chain's output dict.
        """
        try:
            sql = state.sql
            prompt = state.prompt
        except Exception:
            # Attribute access failed (e.g. a plain dict state, or an object missing one of
            # the attributes). Fall back to ``get`` when available, otherwise to attributes
            # with a ``None`` default. ``except Exception`` (rather than a bare ``except``)
            # lets KeyboardInterrupt/SystemExit propagate.
            getter = getattr(state, "get", None)
            if callable(getter):
                sql = getter('sql', None)
                prompt = getter('prompt', None)
            else:
                sql = getattr(state, 'sql', None)
                prompt = getattr(state, 'prompt', None)

        return self.chain.invoke({ "sql": sql, "prompt": prompt })


    def __call__(self, payload: dict) -> dict:
        """Allow the node to be used directly as a LangGraph node callable."""
        return self.run(payload)
100
+