pyoptinterface 0.3.0__cp312-abi3-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyoptinterface/__init__.py +68 -0
- pyoptinterface/_src/__init__.py +0 -0
- pyoptinterface/_src/aml.py +83 -0
- pyoptinterface/_src/attributes.py +129 -0
- pyoptinterface/_src/codegen_c.py +301 -0
- pyoptinterface/_src/codegen_llvm.py +506 -0
- pyoptinterface/_src/constraint_bridge.py +49 -0
- pyoptinterface/_src/copt.py +524 -0
- pyoptinterface/_src/copt_model_ext.pyd +0 -0
- pyoptinterface/_src/copt_model_ext.pyi +195 -0
- pyoptinterface/_src/core_ext.pyd +0 -0
- pyoptinterface/_src/core_ext.pyi +450 -0
- pyoptinterface/_src/cpp_graph_iter.py +39 -0
- pyoptinterface/_src/cppad_interface_ext.pyd +0 -0
- pyoptinterface/_src/cppad_interface_ext.pyi +158 -0
- pyoptinterface/_src/dylib.py +13 -0
- pyoptinterface/_src/function_tracing.py +342 -0
- pyoptinterface/_src/gurobi.py +724 -0
- pyoptinterface/_src/gurobi_model_ext.pyd +0 -0
- pyoptinterface/_src/gurobi_model_ext.pyi +222 -0
- pyoptinterface/_src/highs.py +439 -0
- pyoptinterface/_src/highs_model_ext.pyd +0 -0
- pyoptinterface/_src/highs_model_ext.pyi +195 -0
- pyoptinterface/_src/ipopt.py +729 -0
- pyoptinterface/_src/ipopt_model_ext.pyd +0 -0
- pyoptinterface/_src/ipopt_model_ext.pyi +179 -0
- pyoptinterface/_src/jit_c.py +76 -0
- pyoptinterface/_src/jit_llvm.py +31 -0
- pyoptinterface/_src/mosek.py +500 -0
- pyoptinterface/_src/mosek_model_ext.pyd +0 -0
- pyoptinterface/_src/mosek_model_ext.pyi +178 -0
- pyoptinterface/_src/nleval_ext.pyd +0 -0
- pyoptinterface/_src/nleval_ext.pyi +67 -0
- pyoptinterface/_src/nlexpr_ext.pyd +0 -0
- pyoptinterface/_src/nlexpr_ext.pyi +154 -0
- pyoptinterface/_src/solver_common.py +107 -0
- pyoptinterface/_src/tcc_interface_ext.pyd +0 -0
- pyoptinterface/_src/tcc_interface_ext.pyi +24 -0
- pyoptinterface/_src/tupledict.py +132 -0
- pyoptinterface/copt.py +18 -0
- pyoptinterface/gurobi.py +15 -0
- pyoptinterface/highs.py +4 -0
- pyoptinterface/ipopt.py +8 -0
- pyoptinterface/mosek.py +16 -0
- pyoptinterface/nlfunc.py +17 -0
- pyoptinterface-0.3.0.dist-info/METADATA +152 -0
- pyoptinterface-0.3.0.dist-info/RECORD +49 -0
- pyoptinterface-0.3.0.dist-info/WHEEL +5 -0
- pyoptinterface-0.3.0.dist-info/licenses/LICENSE.md +383 -0
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from pyoptinterface._src.core_ext import (
|
|
2
|
+
VariableIndex,
|
|
3
|
+
ConstraintIndex,
|
|
4
|
+
ExprBuilder,
|
|
5
|
+
VariableDomain,
|
|
6
|
+
ConstraintSense,
|
|
7
|
+
ConstraintType,
|
|
8
|
+
SOSType,
|
|
9
|
+
ObjectiveSense,
|
|
10
|
+
ScalarAffineFunction,
|
|
11
|
+
ScalarQuadraticFunction,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
from pyoptinterface._src.attributes import (
|
|
15
|
+
VariableAttribute,
|
|
16
|
+
ModelAttribute,
|
|
17
|
+
TerminationStatusCode,
|
|
18
|
+
ResultStatusCode,
|
|
19
|
+
ConstraintAttribute,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
from pyoptinterface._src.tupledict import (
|
|
23
|
+
tupledict,
|
|
24
|
+
make_tupledict,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
from pyoptinterface._src.aml import make_nd_variable, quicksum, quicksum_
|
|
28
|
+
|
|
29
|
+
# Aliases of ConstraintSense members, exported at package level (they are
# listed in __all__) as shorthand for the sense of a constraint.
# Each alias is followed by a string literal that documentation tools can
# attach to it, so keep every string directly after its assignment.
Eq = ConstraintSense.Equal
"""Alias of `ConstraintSense.Equal` for equality constraints.
"""
Leq = ConstraintSense.LessEqual
"""Alias of `ConstraintSense.LessEqual` for less-than-or-equal-to constraints.
"""
Geq = ConstraintSense.GreaterEqual
"""Alias of `ConstraintSense.GreaterEqual` for greater-than-or-equal-to constraints.
"""
In = ConstraintSense.Within
"""Alias of `ConstraintSense.Within` for range constraints.
"""
|
|
42
|
+
|
|
43
|
+
# Public API of the package, grouped by the submodule each name comes from.
# The order of the entries is kept stable on purpose.
__all__ = [
    # from pyoptinterface._src.core_ext
    "VariableIndex", "ConstraintIndex", "ExprBuilder",
    "VariableDomain", "ConstraintSense", "ConstraintType", "SOSType",
    "ObjectiveSense", "ScalarAffineFunction", "ScalarQuadraticFunction",
    # from pyoptinterface._src.attributes
    "VariableAttribute", "ModelAttribute", "TerminationStatusCode",
    "ResultStatusCode", "ConstraintAttribute",
    # from pyoptinterface._src.tupledict
    "tupledict", "make_tupledict",
    # from pyoptinterface._src.aml
    "make_nd_variable", "quicksum", "quicksum_",
    # ConstraintSense aliases defined in this module
    "Eq", "Leq", "Geq", "In",
]
|
|
File without changes
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from .core_ext import ExprBuilder
|
|
2
|
+
from .tupledict import make_tupledict
|
|
3
|
+
|
|
4
|
+
from collections.abc import Collection
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def make_nd_variable(
    model, *coords: Collection, domain=None, lb=None, ub=None, name=None, start=None
):
    """Create one variable per index produced by ``make_tupledict(*coords)``.

    Each of ``domain``, ``lb``, ``ub`` and ``start`` is forwarded as a keyword
    argument to ``model.add_variable`` only when it is not ``None``. When
    ``name`` is given, each variable is named ``name`` followed by ``str()``
    of the tuple of positional arguments the rule was called with.

    Returns:
        The tupledict returned by ``make_tupledict`` using the variable-creating
        function as its ``rule``.
    """
    # Keep only the options the caller actually supplied, in a fixed order.
    optional = (("domain", domain), ("lb", lb), ("ub", ub), ("start", start))
    options = {key: value for key, value in optional if value is not None}

    def create(*index):
        if name is not None:
            # f"{tuple}" formats exactly like str(tuple), e.g. "x(0, 1)".
            options["name"] = f"{name}{index}"
        return model.add_variable(**options)

    return make_tupledict(*coords, rule=create)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# def make_nd_variable_batch(
|
|
31
|
+
# model,
|
|
32
|
+
# *coords: Collection,
|
|
33
|
+
# domain=None,
|
|
34
|
+
# lb=None,
|
|
35
|
+
# ub=None,
|
|
36
|
+
# name=None,
|
|
37
|
+
# ):
|
|
38
|
+
# assert model.supports_batch_add_variables()
|
|
39
|
+
#
|
|
40
|
+
# kw_args = dict()
|
|
41
|
+
# if domain is not None:
|
|
42
|
+
# kw_args["domain"] = domain
|
|
43
|
+
# if lb is not None:
|
|
44
|
+
# kw_args["lb"] = lb
|
|
45
|
+
# if ub is not None:
|
|
46
|
+
# kw_args["ub"] = ub
|
|
47
|
+
#
|
|
48
|
+
# N = math.prod(len(c) for c in coords)
|
|
49
|
+
#
|
|
50
|
+
# start_vi = model.add_variables(N, **kw_args)
|
|
51
|
+
# start_index = start_vi.index
|
|
52
|
+
#
|
|
53
|
+
# kvs = []
|
|
54
|
+
# assert len(coords) > 0
|
|
55
|
+
# for i, coord in enumerate(product(*coords)):
|
|
56
|
+
# coord = tuple(flatten_tuple(coord))
|
|
57
|
+
# value = VariableIndex(start_index + i)
|
|
58
|
+
# if len(coord) == 1:
|
|
59
|
+
# coord = coord[0]
|
|
60
|
+
# if value is not None:
|
|
61
|
+
# kvs.append((coord, value))
|
|
62
|
+
#
|
|
63
|
+
# suffix = str(coord)
|
|
64
|
+
# if name is not None:
|
|
65
|
+
# model.set_variable_name(value, f"{name}{suffix}")
|
|
66
|
+
# return tupledict(kvs)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def quicksum_(expr: ExprBuilder, terms, f=None):
    """Add every term of *terms* into *expr*; returns None.

    If *terms* is a ``dict`` its values are summed, otherwise *terms* is
    iterated directly. When *f* is truthy it is applied to each term first.

    The result is only visible through *expr*, so this relies on
    ``expr += term`` updating *expr* in place. Presumably ExprBuilder
    implements ``__iadd__`` that way; confirm in core_ext.
    """
    items = terms.values() if isinstance(terms, dict) else terms
    if f:
        items = map(f, items)
    for term in items:
        expr += term
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def quicksum(terms, f=None):
    """Return a fresh ExprBuilder holding the sum of *terms*.

    *terms* and *f* have the same meaning as in ``quicksum_``: dict values
    are summed, and *f* (if truthy) maps each term before it is added.
    """
    total = ExprBuilder()
    quicksum_(total, terms, f)
    return total
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
from enum import Enum, auto
|
|
2
|
+
from .core_ext import VariableDomain, ObjectiveSense
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class VariableAttribute(Enum):
    """Identifiers of per-variable attributes.

    The Python type of each attribute's value is recorded in
    ``var_attr_type_map``.
    """

    # Values are assigned by auto() and therefore depend on declaration
    # order; append new members at the end if numeric values are persisted
    # or compared anywhere (not visible here).
    Value = auto()
    LowerBound = auto()
    UpperBound = auto()
    Domain = auto()
    PrimalStart = auto()
    Name = auto()
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Python type of the value held by each VariableAttribute.
# Entry order matches the original listing (all float attributes first).
var_attr_type_map = {
    **dict.fromkeys(
        (
            VariableAttribute.Value,
            VariableAttribute.LowerBound,
            VariableAttribute.UpperBound,
            VariableAttribute.PrimalStart,
        ),
        float,
    ),
    VariableAttribute.Domain: VariableDomain,
    VariableAttribute.Name: str,
}
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ModelAttribute(Enum):
    """Identifiers of model-wide attributes.

    The Python type of each attribute's value is recorded in
    ``model_attr_type_map``. Values are assigned by auto(), so they depend
    on declaration order. Commented-out members are not defined in this
    version.
    """

    # ModelLike API
    # NumberOfConstraints = auto()
    # NumberOfVariables = auto()
    Name = auto()
    ObjectiveSense = auto()

    # AbstractOptimizer API
    DualStatus = auto()
    PrimalStatus = auto()
    RawStatusString = auto()
    TerminationStatus = auto()
    BarrierIterations = auto()
    DualObjectiveValue = auto()
    NodeCount = auto()
    NumberOfThreads = auto()
    ObjectiveBound = auto()
    ObjectiveValue = auto()
    RelativeGap = auto()
    Silent = auto()
    SimplexIterations = auto()
    SolverName = auto()
    SolverVersion = auto()
    SolveTimeSec = auto()
    TimeLimitSec = auto()
    # ObjectiveLimit = auto()
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class ResultStatusCode(Enum):
    """Status of a primal or dual result.

    This is the value type of ``ModelAttribute.PrimalStatus`` and
    ``ModelAttribute.DualStatus`` in ``model_attr_type_map``. The member
    names appear to mirror MathOptInterface's ResultStatusCode (assumption,
    not verifiable here). Values are assigned by auto() in declaration order.
    """

    NO_SOLUTION = auto()
    FEASIBLE_POINT = auto()
    NEARLY_FEASIBLE_POINT = auto()
    INFEASIBLE_POINT = auto()
    INFEASIBILITY_CERTIFICATE = auto()
    NEARLY_INFEASIBILITY_CERTIFICATE = auto()
    REDUCTION_CERTIFICATE = auto()
    NEARLY_REDUCTION_CERTIFICATE = auto()
    UNKNOWN_RESULT_STATUS = auto()
    OTHER_RESULT_STATUS = auto()
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class TerminationStatusCode(Enum):
    """Reason the solver stopped.

    This is the value type of ``ModelAttribute.TerminationStatus`` in
    ``model_attr_type_map``. Values are assigned by auto() in declaration
    order.
    """

    OPTIMIZE_NOT_CALLED = auto()
    # solved / proven outcomes
    OPTIMAL = auto()
    INFEASIBLE = auto()
    DUAL_INFEASIBLE = auto()
    LOCALLY_SOLVED = auto()
    LOCALLY_INFEASIBLE = auto()
    INFEASIBLE_OR_UNBOUNDED = auto()
    # outcomes satisfied only to relaxed tolerances
    ALMOST_OPTIMAL = auto()
    ALMOST_INFEASIBLE = auto()
    ALMOST_DUAL_INFEASIBLE = auto()
    ALMOST_LOCALLY_SOLVED = auto()
    # stopped because a limit was reached
    ITERATION_LIMIT = auto()
    TIME_LIMIT = auto()
    NODE_LIMIT = auto()
    SOLUTION_LIMIT = auto()
    MEMORY_LIMIT = auto()
    OBJECTIVE_LIMIT = auto()
    NORM_LIMIT = auto()
    OTHER_LIMIT = auto()
    # stopped because of a problem
    SLOW_PROGRESS = auto()
    NUMERICAL_ERROR = auto()
    INVALID_MODEL = auto()
    INVALID_OPTION = auto()
    INTERRUPTED = auto()
    OTHER_ERROR = auto()
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
# Python type of the value held by each ModelAttribute.
# Entry order matches the declaration order of ModelAttribute.
model_attr_type_map = {
    # ModelLike API
    ModelAttribute.Name: str,
    ModelAttribute.ObjectiveSense: ObjectiveSense,
    # AbstractOptimizer API: status reporting
    **dict.fromkeys(
        (ModelAttribute.DualStatus, ModelAttribute.PrimalStatus), ResultStatusCode
    ),
    ModelAttribute.RawStatusString: str,
    ModelAttribute.TerminationStatus: TerminationStatusCode,
    # AbstractOptimizer API: statistics and parameters
    ModelAttribute.BarrierIterations: int,
    ModelAttribute.DualObjectiveValue: float,
    ModelAttribute.NodeCount: int,
    ModelAttribute.NumberOfThreads: int,
    **dict.fromkeys(
        (
            ModelAttribute.ObjectiveBound,
            ModelAttribute.ObjectiveValue,
            ModelAttribute.RelativeGap,
        ),
        float,
    ),
    ModelAttribute.Silent: bool,
    ModelAttribute.SimplexIterations: int,
    **dict.fromkeys((ModelAttribute.SolverName, ModelAttribute.SolverVersion), str),
    **dict.fromkeys((ModelAttribute.SolveTimeSec, ModelAttribute.TimeLimitSec), float),
}
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class ConstraintAttribute(Enum):
    """Identifiers of per-constraint attributes.

    The Python type of each attribute's value is recorded in
    ``constraint_attr_type_map``. Values are assigned by auto() in
    declaration order. Commented-out members are not defined in this version.
    """

    Name = auto()
    # PrimalStart = auto()
    # DualStart = auto()
    Primal = auto()
    Dual = auto()
    # BasisStatus = auto()
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
# Python type of the value held by each ConstraintAttribute
# (order: Name, Primal, Dual).
constraint_attr_type_map = {
    ConstraintAttribute.Name: str,
    **dict.fromkeys((ConstraintAttribute.Primal, ConstraintAttribute.Dual), float),
}
|
|
@@ -0,0 +1,301 @@
|
|
|
1
|
+
from .cppad_interface_ext import (
|
|
2
|
+
graph_op,
|
|
3
|
+
)
|
|
4
|
+
from .cpp_graph_iter import cpp_graph_iterator
|
|
5
|
+
|
|
6
|
+
from typing import IO
|
|
7
|
+
|
|
8
|
+
# C spelling of each supported CppAD graph operator.
# Entries whose value is in {"+", "-", "*", "/"} are emitted as infix (binary)
# or prefix (unary "-") operators; every other entry is emitted as a call to a
# function of that name, declared or defined in the C prelude.
op2name = {
    graph_op.abs: "fabs",
    graph_op.acos: "acos",
    graph_op.asin: "asin",
    graph_op.atan: "atan",
    graph_op.cos: "cos",
    graph_op.exp: "exp",
    graph_op.log: "log",
    graph_op.pow: "pow",
    graph_op.sign: "sign",
    graph_op.sin: "sin",
    graph_op.sqrt: "sqrt",
    graph_op.tan: "tan",
    graph_op.add: "+",
    graph_op.sub: "-",
    graph_op.mul: "*",
    graph_op.div: "/",
    # azmul is "absolute zero" multiplication: the result is 0 whenever the
    # left operand is 0, even if the right one is inf/nan. Plain `*` would give
    # NaN in that case, so call the azmul helper defined in the C prelude.
    graph_op.azmul: "azmul",
    graph_op.neg: "-",
}
|
|
28
|
+
|
|
29
|
+
# C comparison operator used for each CppAD conditional-expression operator.
# The generated code has the form `left OP right ? if_true : if_false`.
compare_ops_string = {
    graph_op.cexp_eq: "==",
    graph_op.cexp_le: "<=",
    graph_op.cexp_lt: "<",
}

# Set of conditional-expression operators, taken from the keys of the table
# above so the two cannot drift apart.
compare_ops = set(compare_ops_string)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def generate_csrc_prelude(io: IO[str]):
    """Write the fixed C prelude that precedes the generated functions.

    The prelude does the following:
      - includes <stddef.h>;
      - defines the ``float_point_t`` typedef;
      - declares the math functions named in ``op2name`` (fabs, acos, asin,
        atan, cos, exp, log, sin, sqrt, tan, pow);
      - defines the helper functions ``azmul`` and ``sign``.
    """
    prelude = """// includes
#include <stddef.h>

// typedefs
typedef double float_point_t;

// declare mathematical functions
#define UNARY(f) extern float_point_t f(float_point_t x)
#define BINARY(f) extern float_point_t f(float_point_t x, float_point_t y)

// unary functions
UNARY(fabs);
UNARY(acos);
UNARY(asin);
UNARY(atan);
UNARY(cos);
UNARY(exp);
UNARY(log);
UNARY(sin);
UNARY(sqrt);
UNARY(tan);

// binary functions
BINARY(pow);

// externals
// azmul
float_point_t azmul(float_point_t x, float_point_t y)
{
if( x == 0.0 ) return 0.0;
return x * y;
}

// sign
float_point_t sign(float_point_t x)
{
if( x > 0.0 ) return 1.0;
if( x == 0.0 ) return 0.0;
return -1.0;
}
"""
    io.write(prelude)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _c_double_literal(value) -> str:
    """Render a graph constant as a C ``double`` constant expression.

    Finite values use Python's round-trip ``str`` form, which is a valid C
    floating literal (e.g. ``1.5``, ``1e-05``). ``str`` would give ``inf`` /
    ``nan`` for non-finite values, which are not C literals, so those are
    written as constant divisions that fold to infinity / NaN under IEEE 754
    arithmetic.
    """
    if value != value:  # only NaN compares unequal to itself
        return "(0.0 / 0.0)"
    if value == float("inf"):
        return "(1.0 / 0.0)"
    if value == float("-inf"):
        return "(-1.0 / 0.0)"
    return f"{value}"


def generate_csrc_from_graph(
    io: IO[str],
    graph_obj,
    name: str,
    np: int = 0,
    hessian_lagrange: bool = False,
    nw: int = 0,
    indirect_x: bool = False,
    indirect_p: bool = False,
    indirect_w: bool = False,
    indirect_y: bool = False,
    add_y: bool = False,
):
    """Write a C function ``name`` that evaluates the operator graph ``graph_obj``.

    The generated function evaluates every operator of the graph, in
    iteration order, into a local array ``v``. It then stores the dependent
    values into ``y``, or adds them to ``y`` when ``add_y`` is set.

    Args:
        io: Text stream that receives the C source.
        graph_obj: Graph exposing ``n_dynamic_ind``, ``n_variable_ind``,
            ``n_constant``, ``n_dependent``, ``constant_vec_get(i)`` and
            ``dependent_vec_get(i)``, traversable with ``cpp_graph_iterator``.
        name: Name of the generated C function.
        np: Number of leading independent nodes read from the ``p`` argument.
            The ``p`` argument is emitted only when ``np > 0``.
        hessian_lagrange: When true, the ``nw`` independent nodes after the
            parameters are read from an extra ``w`` argument.
        nw: Number of independent nodes read from ``w``.
        indirect_x, indirect_p, indirect_w, indirect_y: When true, access the
            corresponding array through an index array (``xi``, ``pi``,
            ``wi``, ``yi``) instead of directly.
        add_y: Emit ``+=`` instead of ``=`` when writing dependents.

    Raises:
        ValueError: If an operator has no C translation, or a translated
            operator has an unsupported number of arguments.
    """
    n_dynamic_ind = graph_obj.n_dynamic_ind
    n_variable_ind = graph_obj.n_variable_ind
    n_constant = graph_obj.n_constant
    n_dependent = graph_obj.n_dependent

    # Node numbering of the graph and the C expression each range maps to.
    #
    # Simple case
    # 0 -> dummy
    # [1, 1 + np) -> *p
    # [1 + np, 1 + n_dynamic_ind + n_variable_ind) -> *x
    # [1 + n_dynamic_ind + n_variable_ind, 1 + n_dynamic_ind + n_variable_ind + n_constant) -> c[...]
    # [1 + n_dynamic_ind + n_variable_ind + n_constant, ...) -> v[...]
    #
    # Hessian Lagrangian case
    # 0 -> dummy
    # [1, 1 + np) -> *p
    # [1 + np, 1 + np + nw) -> *w
    # [1 + np + nw, 1 + n_dynamic_ind + n_variable_ind) -> *x
    # [1 + n_dynamic_ind + n_variable_ind, 1 + n_dynamic_ind + n_variable_ind + n_constant) -> c[...]
    # [1 + n_dynamic_ind + n_variable_ind + n_constant, ...) -> v[...]

    # Every operator result gets one slot in the local array v.
    n_node = sum(
        graph_iter.n_result for graph_iter in cpp_graph_iterator(graph_obj)
    )

    has_parameter = np > 0

    function_args_signature = ["const float_point_t* x"]
    if has_parameter:
        function_args_signature.append("const float_point_t* p")
    if hessian_lagrange:
        function_args_signature.append("const float_point_t* w")
    function_args_signature.append("float_point_t* y")
    # Index arrays follow the data arrays. pi / wi are only emitted when the
    # matching data array is part of the signature.
    if indirect_x:
        function_args_signature.append("const size_t* xi")
    if has_parameter and indirect_p:
        function_args_signature.append("const size_t* pi")
    if hessian_lagrange and indirect_w:
        function_args_signature.append("const size_t* wi")
    if indirect_y:
        function_args_signature.append("const size_t* yi")

    function_args = ", ".join(function_args_signature)

    function_prototype = f"""
void {name}(
{function_args}
)
"""
    io.write(function_prototype)

    nx = n_dynamic_ind + n_variable_ind - np
    if hessian_lagrange:
        nx -= nw
    ny = n_dependent
    io.write(
        f"""{{
// begin function body

// size checks
// const size_t nx = {nx};
// const size_t np = {np};
// const size_t ny = {ny};
"""
    )
    if hessian_lagrange:
        io.write(f" // const size_t nw = {nw};\n")

    # ISO C forbids zero-length arrays and empty initializer lists, so v and c
    # are declared only when they have elements. When an array is empty, no
    # graph node maps into it, so get_node_name never references it.
    if n_node > 0:
        io.write(
            f"""
// declare variables
float_point_t v[{n_node}];
"""
        )

    nc = n_constant
    if nc > 0:
        cs_str = ", ".join(
            _c_double_literal(graph_obj.constant_vec_get(i)) for i in range(nc)
        )
        io.write(
            f"""
// constants
// set c[i] for i = 0, ..., nc-1
// nc = {nc}
static const float_point_t c[{nc}] = {{
{cs_str}
}};
"""
        )

    n_result_node = n_node
    io.write(
        f"""
// result nodes
// set v[i] for i = 0, ..., n_result_node-1
// n_result_node = {n_result_node}
"""
    )

    def get_node_name(node: int) -> str:
        """Return the C expression that reads graph node ``node``."""
        if node < 1:
            raise ValueError(f"Invalid node: {node}")
        if node < 1 + np:
            index = node - 1
            if indirect_p:
                return f"p[pi[{index}]]"
            else:
                return f"p[{index}]"
        elif node < 1 + n_dynamic_ind + n_variable_ind:
            if hessian_lagrange:
                if node < 1 + np + nw:
                    index = node - 1 - np
                    if indirect_w:
                        return f"w[wi[{index}]]"
                    else:
                        return f"w[{index}]"
                else:
                    index = node - 1 - np - nw
                    if indirect_x:
                        return f"x[xi[{index}]]"
                    else:
                        return f"x[{index}]"
            else:
                index = node - 1 - np
                if indirect_x:
                    return f"x[xi[{index}]]"
                else:
                    return f"x[{index}]"
        elif node < 1 + n_dynamic_ind + n_variable_ind + n_constant:
            index = node - 1 - n_dynamic_ind - n_variable_ind
            return f"c[{index}]"
        else:
            node = node - 1 - n_dynamic_ind - n_variable_ind - n_constant
            assert node < n_node
            return f"v[{node}]"

    infix_operators = {"+", "-", "*", "/"}

    result_node = 0
    for graph_iter in cpp_graph_iterator(graph_obj):
        op = graph_iter.op
        n_result = graph_iter.n_result
        args = graph_iter.args

        # The v indexing below assumes exactly one result per operator.
        assert n_result == 1

        n_arg = len(args)

        op_name = op2name.get(op)
        if op_name is not None:
            if n_arg == 1:
                arg1 = get_node_name(args[0])
                io.write(f" v[{result_node}] = {op_name}({arg1});\n")
            elif n_arg == 2:
                arg1 = get_node_name(args[0])
                arg2 = get_node_name(args[1])
                if op_name in infix_operators:
                    io.write(f" v[{result_node}] = {arg1} {op_name} {arg2};\n")
                else:
                    io.write(f" v[{result_node}] = {op_name}({arg1}, {arg2});\n")
            else:
                message = f"Unknown n_arg: {n_arg} for op_enum: {op}\n"
                raise ValueError(message)
        elif op in compare_ops:
            cmp_op = compare_ops_string[op]
            assert n_arg == 4
            predicate = get_node_name(args[0])
            target = get_node_name(args[1])
            then_value = get_node_name(args[2])
            else_value = get_node_name(args[3])
            io.write(
                f" v[{result_node}] = {predicate} {cmp_op} {target} ? {then_value} : {else_value};\n"
            )
        else:
            message = f"Unknown name for op_enum: {op}\n"
            debug_context = f"name: {name}\nfull graph:\n{str(graph_obj)}"
            raise ValueError(message + debug_context)

        result_node += n_result

    io.write(
        """
// dependent variables
// set y[i] for i = 0, ny-1
"""
    )

    assign_op = "+=" if add_y else "="
    for i in range(ny):
        node_name = get_node_name(graph_obj.dependent_vec_get(i))
        target = f"y[yi[{i}]]" if indirect_y else f"y[{i}]"
        io.write(f" {target} {assign_op} {node_name};\n")

    io.write(
        """
// end function body
}
"""
    )
|