glre 0.33.0 → 0.34.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "glre",
3
- "version": "0.33.0",
3
+ "version": "0.34.0",
4
4
  "author": "tseijp",
5
5
  "license": "MIT",
6
6
  "private": false,
package/src/node/code.ts CHANGED
@@ -15,9 +15,9 @@ import {
15
15
  parseUniformHead,
16
16
  } from './parse'
17
17
  import { getBluiltin, getOperator, formatConversions, safeEventCall, getEventFun, initNodeContext } from './utils'
18
- import type { NodeContext, NodeProxy, X } from './types'
18
+ import type { Constants, NodeContext, X } from './types'
19
19
 
20
- export const code = (target: X, c?: NodeContext | null): string => {
20
+ export const code = <T extends Constants>(target: X<T>, c?: NodeContext | null): string => {
21
21
  if (!c) c = {}
22
22
  initNodeContext(c)
23
23
  if (is.arr(target)) return parseArray(target, c)
@@ -25,6 +25,8 @@ export const code = (target: X, c?: NodeContext | null): string => {
25
25
  if (is.num(target)) {
26
26
  const ret = `${target}`
27
27
  if (ret.includes('.')) return ret
28
+ // Check if this number should be an integer based on the inferred type
29
+ // For now, keep the original behavior to maintain compatibility
28
30
  return ret + '.0'
29
31
  }
30
32
  if (is.bol(target)) return target ? 'true' : 'false'
@@ -39,8 +41,8 @@ export const code = (target: X, c?: NodeContext | null): string => {
39
41
  if (type === 'member') return `${code(y, c)}.${code(x, c)}`
40
42
  if (type === 'ternary')
41
43
  return c.isWebGL
42
- ? `(${code(x, c)} ? ${code(y, c)} : ${code(z, c)})`
43
- : `select(${code(z, c)}, ${code(y, c)}, ${code(x, c)})`
44
+ ? `(${code(z, c)} ? ${code(x, c)} : ${code(y, c)})`
45
+ : `select(${code(x, c)}, ${code(y, c)}, ${code(z, c)})`
44
46
  if (type === 'conversion') return `${formatConversions(x, c)}(${parseArray(children.slice(1), c)})`
45
47
  if (type === 'operator') {
46
48
  if (x === 'not' || x === 'bitNot') return `!${code(y, c)}`
@@ -59,8 +61,8 @@ export const code = (target: X, c?: NodeContext | null): string => {
59
61
  if (type === 'return') return `return ${code(x, c)};`
60
62
  if (type === 'loop')
61
63
  return c.isWebGL
62
- ? `for (int i = 0; i < ${x}; i += 1) {\n${code(y, c)}\n}`
63
- : `for (var i: i32 = 0; i < ${x}; i++) {\n${code(y, c)}\n}`
64
+ ? `for (int i = 0; i < ${code(x, c)}; i += 1) {\n${code(y, c)}\n}`
65
+ : `for (var i: i32 = 0; i < ${code(x, c)}; i++) {\n${code(y, c)}\n}`
64
66
  if (type === 'if') return parseIf(c, x, y, children)
65
67
  if (type === 'switch') return parseSwitch(c, x, children)
66
68
  if (type === 'declare') return parseDeclare(c, x, y)
@@ -70,7 +72,7 @@ export const code = (target: X, c?: NodeContext | null): string => {
70
72
  }
71
73
  if (type === 'struct') {
72
74
  if (!c.code?.headers.has(id)) c.code?.headers.set(id, parseStructHead(c, id, fields))
73
- return parseStruct(c, id, (x as NodeProxy).props.id, fields, initialValues)
75
+ return parseStruct(c, id, x.props.id, fields, initialValues)
74
76
  }
75
77
  /**
76
78
  * headers
package/src/node/const.ts CHANGED
@@ -74,11 +74,19 @@ export const OPERATORS = {
74
74
 
75
75
  export const OPERATOR_KEYS = Object.keys(OPERATORS) as (keyof typeof OPERATORS)[]
76
76
 
77
- export const SCALAR_RETURN_FUNCTIONS = ['dot', 'distance', 'length', 'lengthSq', 'determinant', 'luminance'] as const
78
-
79
- export const BOOL_RETURN_FUNCTIONS = ['all', 'any'] as const
80
-
81
- export const PRESERVE_TYPE_FUNCTIONS = [
77
+ // All shader functions (type inference now handled by inferFrom)
78
+ export const FUNCTIONS = [
79
+ // Float return functions
80
+ 'dot',
81
+ 'distance',
82
+ 'length',
83
+ 'lengthSq',
84
+ 'determinant',
85
+ 'luminance',
86
+ // Bool return functions
87
+ 'all',
88
+ 'any',
89
+ // Component-wise functions (preserve input type)
82
90
  'abs',
83
91
  'sign',
84
92
  'floor',
@@ -92,6 +100,12 @@ export const PRESERVE_TYPE_FUNCTIONS = [
92
100
  'asin',
93
101
  'acos',
94
102
  'atan',
103
+ 'sinh',
104
+ 'cosh',
105
+ 'tanh',
106
+ 'asinh',
107
+ 'acosh',
108
+ 'atanh',
95
109
  'exp',
96
110
  'exp2',
97
111
  'log',
@@ -106,43 +120,38 @@ export const PRESERVE_TYPE_FUNCTIONS = [
106
120
  'dFdx',
107
121
  'dFdy',
108
122
  'fwidth',
109
- ] as const
110
-
111
- export const VEC3_RETURN_FUNCTIONS = ['cross'] as const
112
-
113
- export const FIRST_ARG_TYPE_FUNCTIONS = ['reflect', 'refract'] as const
114
-
115
- export const HIGHEST_TYPE_FUNCTIONS = ['min', 'max', 'mix', 'clamp', 'step', 'smoothstep'] as const
116
-
117
- export const VEC4_RETURN_FUNCTIONS = ['texture', 'textureLod', 'textureSize', 'cubeTexture'] as const
118
-
119
- export const ADDITIONAL_FUNCTIONS = [
120
- 'atan2',
121
123
  'degrees',
124
+ 'radians',
125
+ // Vector functions
126
+ 'cross',
127
+ 'reflect',
128
+ 'refract',
129
+ // Multi-argument functions
130
+ 'min',
131
+ 'max',
132
+ 'mix',
133
+ 'clamp',
134
+ 'step',
135
+ 'smoothstep',
136
+ 'pow',
137
+ 'atan2',
138
+ // Texture functions
139
+ 'texture',
140
+ 'textureLod',
141
+ 'textureSize',
142
+ 'cubeTexture',
143
+ // Utility functions
122
144
  'faceforward',
123
145
  'bitcast',
124
146
  'cbrt',
125
147
  'difference',
126
148
  'equals',
127
- 'pow',
128
149
  'pow2',
129
150
  'pow3',
130
151
  'pow4',
131
- 'radians',
132
152
  'transformDirection',
133
153
  ] as const
134
154
 
135
- export const FUNCTIONS = [
136
- ...SCALAR_RETURN_FUNCTIONS,
137
- ...BOOL_RETURN_FUNCTIONS,
138
- ...PRESERVE_TYPE_FUNCTIONS,
139
- ...VEC3_RETURN_FUNCTIONS,
140
- ...FIRST_ARG_TYPE_FUNCTIONS,
141
- ...HIGHEST_TYPE_FUNCTIONS,
142
- ...VEC4_RETURN_FUNCTIONS,
143
- ...ADDITIONAL_FUNCTIONS,
144
- ] as const
145
-
146
155
  export const COMPONENT_COUNT_TO_TYPE = {
147
156
  1: 'float',
148
157
  2: 'vec2',
@@ -152,6 +161,24 @@ export const COMPONENT_COUNT_TO_TYPE = {
152
161
  16: 'mat4',
153
162
  } as const
154
163
 
164
+ // Function return type mapping for method chaining
165
+ export const FUNCTION_RETURN_TYPES = {
166
+ // Always return vec4
167
+ texture: 'vec4',
168
+ cubeTexture: 'vec4',
169
+ textureSize: 'vec4',
170
+ // Always return float
171
+ length: 'float',
172
+ lengthSq: 'float',
173
+ distance: 'float',
174
+ dot: 'float',
175
+ // Always return bool
176
+ all: 'bool',
177
+ any: 'bool',
178
+ // Always return vec3
179
+ cross: 'vec3',
180
+ } as const
181
+
155
182
  export const BUILTIN_TYPES = {
156
183
  // WGSL builtin variables
157
184
  position: 'vec4',
package/src/node/index.ts CHANGED
@@ -2,7 +2,7 @@ import { is } from '../utils/helpers'
2
2
  import { code } from './code'
3
3
  import { builtin, conversion as c, function_ as f, uniform as u } from './node'
4
4
  import { hex2rgb, sortHeadersByDependencies } from './utils'
5
- import type { NodeContext, X } from './types'
5
+ import type { Constants as C, NodeContext, X, Vec2, Float } from './types'
6
6
  export * from './code'
7
7
  export * from './node'
8
8
  export * from './scope'
@@ -84,26 +84,26 @@ export const fragment = (x: X, c: NodeContext) => {
84
84
  }
85
85
 
86
86
  // Builtin Variables
87
- export const position = builtin('position')
88
- export const vertexIndex = builtin('vertex_index')
89
- export const instanceIndex = builtin('instance_index')
90
- export const frontFacing = builtin('front_facing')
91
- export const fragDepth = builtin('frag_depth')
92
- export const sampleIndex = builtin('sample_index')
93
- export const sampleMask = builtin('sample_mask')
94
- export const pointCoord = builtin('point_coord')
87
+ export const position = builtin<'vec4'>('position')
88
+ export const vertexIndex = builtin<'uint'>('vertex_index')
89
+ export const instanceIndex = builtin<'uint'>('instance_index')
90
+ export const frontFacing = builtin<'bool'>('front_facing')
91
+ export const fragDepth = builtin<'float'>('frag_depth')
92
+ export const sampleIndex = builtin<'uint'>('sample_index')
93
+ export const sampleMask = builtin<'uint'>('sample_mask')
94
+ export const pointCoord = builtin<'vec2'>('point_coord')
95
95
 
96
96
  // TSL Compatible Builtin Variables
97
- export const normalLocal = builtin('normalLocal')
98
- export const normalWorld = builtin('normalWorld')
99
- export const normalView = builtin('normalView')
100
- export const positionLocal = builtin('position')
101
- export const positionWorld = builtin('positionWorld')
102
- export const positionView = builtin('positionView')
103
- export const screenCoordinate = builtin('screenCoordinate')
104
- export const screenUV = builtin('screenUV')
105
-
106
- // Type constructors
97
+ export const positionLocal = builtin<'vec3'>('position')
98
+ export const positionWorld = builtin<'vec3'>('positionWorld')
99
+ export const positionView = builtin<'vec3'>('positionView')
100
+ export const normalLocal = builtin<'vec3'>('normalLocal')
101
+ export const normalWorld = builtin<'vec3'>('normalWorld')
102
+ export const normalView = builtin<'vec3'>('normalView')
103
+ export const screenCoordinate = builtin<'vec2'>('screenCoordinate')
104
+ export const screenUV = builtin<'vec2'>('screenUV')
105
+
106
+ // Type constructors with proper type inference
107
107
  export const float = (x?: X) => c('float', x)
108
108
  export const int = (x?: X) => c('int', x)
109
109
  export const uint = (x?: X) => c('uint', x)
@@ -130,80 +130,92 @@ export const color = (r?: X, g?: X, b?: X) => {
130
130
  return vec3(r, g, b)
131
131
  }
132
132
 
133
- // Default uniforms
134
- export const iResolution = u(vec2(), 'iResolution')
135
- export const iMouse = u(vec2(), 'iMouse')
136
- export const iTime = u(float(), 'iTime')
137
- export const uv = () => position.xy.div(iResolution)
138
-
139
- // Texture Functions
140
- export const texture = (x: X, y: X, z?: X) => f('texture', x, y, z)
141
- export const cubeTexture = (x: X, y: X, z?: X) => f('cubeTexture', x, y, z)
142
- export const textureSize = (x: X, y?: X) => f('textureSize', x, y)
143
-
144
- // Math Functions
145
- export const abs = (x: X) => f('abs', x)
146
- export const acos = (x: X) => f('acos', x)
147
- export const all = (x: X) => f('all', x)
148
- export const any = (x: X) => f('any', x)
149
- export const asin = (x: X) => f('asin', x)
150
- export const atan = (y: X, x?: X) => (x !== undefined ? f('atan', y, x) : f('atan', y))
151
- export const atan2 = (y: X, x: X) => f('atan', y, x)
152
- export const bitcast = (x: X, y: X) => f('bitcast', x, y)
153
- export const cbrt = (x: X) => f('cbrt', x)
154
- export const ceil = (x: X) => f('ceil', x)
155
- export const clamp = (x: X, min: X, max: X) => f('clamp', x, min, max)
156
- export const cos = (x: X) => f('cos', x)
157
- export const cross = (x: X, y: X) => f('cross', x, y)
158
- export const dFdx = (p: X) => f('dFdx', p)
159
- export const dFdy = (p: X) => f('dFdy', p)
160
- export const degrees = (radians: X) => f('degrees', radians)
161
- export const difference = (x: X, y: X) => f('difference', x, y)
162
- export const distance = (x: X, y: X) => f('distance', x, y)
163
- export const dot = (x: X, y: X) => f('dot', x, y)
164
- export const equals = (x: X, y: X) => f('equals', x, y)
165
- export const exp = (x: X) => f('exp', x)
166
- export const exp2 = (x: X) => f('exp2', x)
167
- export const faceforward = (N: X, I: X, Nref: X) => f('faceforward', N, I, Nref)
168
- export const floor = (x: X) => f('floor', x)
169
- export const fract = (x: X) => f('fract', x)
170
- export const fwidth = (x: X) => f('fwidth', x)
171
- export const inverseSqrt = (x: X) => f('inverseSqrt', x)
172
- export const length = (x: X) => f('length', x)
173
- export const lengthSq = (x: X) => f('lengthSq', x)
174
- export const log = (x: X) => f('log', x)
175
- export const log2 = (x: X) => f('log2', x)
176
- export const max = (x: X, y: X) => f('max', x, y)
177
- export const min = (x: X, y: X) => f('min', x, y)
178
- export const mix = (x: X, y: X, a: X) => f('mix', x, y, a)
179
- export const negate = (x: X) => f('negate', x)
180
- export const normalize = (x: X) => f('normalize', x)
181
- export const oneMinus = (x: X) => f('oneMinus', x)
182
- export const pow = (x: X, y: X) => f('pow', x, y)
183
- export const pow2 = (x: X) => f('pow2', x)
184
- export const pow3 = (x: X) => f('pow3', x)
185
- export const pow4 = (x: X) => f('pow4', x)
186
- export const radians = (degrees: X) => f('radians', degrees)
187
- export const reciprocal = (x: X) => f('reciprocal', x)
188
- export const reflect = (I: X, N: X) => f('reflect', I, N)
189
- export const refract = (I: X, N: X, eta: X) => f('refract', I, N, eta)
190
- export const round = (x: X) => f('round', x)
191
- export const saturate = (x: X) => f('saturate', x)
192
- export const sign = (x: X) => f('sign', x)
193
- export const sin = (x: X) => f('sin', x)
194
- export const smoothstep = (e0: X, e1: X, x: X) => f('smoothstep', e0, e1, x)
195
- export const sqrt = (x: X) => f('sqrt', x)
196
- export const step = (edge: X, x: X) => f('step', edge, x)
197
- export const tan = (x: X) => f('tan', x)
198
- export const transformDirection = (dir: X, matrix: X) => f('transformDirection', dir, matrix)
199
- export const trunc = (x: X) => f('trunc', x)
200
-
201
- // // Struct functions
202
- // export const struct = (fields: Record<string, X>) => {
203
- // const id = getId()
204
- // const structNode = node('struct', { id, fields })
205
- // // Create constructor function
206
- // const constructor = () => node('struct', { id: getId(), type: id })
207
- // Object.assign(constructor, structNode)
208
- // return constructor as any
209
- // }
133
+ // Default uniforms with proper typing
134
+ export const iResolution: Vec2 = u(vec2(), 'iResolution')
135
+ export const iMouse: Vec2 = u(vec2(), 'iMouse')
136
+ export const iTime: Float = u(float(), 'iTime')
137
+ export const uv = position.xy.div(iResolution)
138
+
139
+ // Texture Functions - always return vec4
140
+ export const texture = (x: X, y: X, z?: X) => f<'vec4'>('texture', x, y, z)
141
+ export const cubeTexture = (x: X, y: X, z?: X) => f<'vec4'>('cubeTexture', x, y, z)
142
+ export const textureSize = (x: X, y?: X) => f<'vec4'>('textureSize', x, y)
143
+
144
+ // Functions that always return float regardless of input
145
+ export const length = (x: X) => f<'float'>('length', x)
146
+ export const lengthSq = (x: X) => f<'float'>('lengthSq', x)
147
+ export const distance = (x: X, y: X) => f<'float'>('distance', x, y)
148
+ export const dot = (x: X, y: X) => f<'float'>('dot', x, y)
149
+
150
+ // Functions that always return bool
151
+ export const all = <T extends C>(x: X<T>) => f<'bool'>('all', x)
152
+ export const any = <T extends C>(x: X<T>) => f<'bool'>('any', x)
153
+
154
+ // Functions that always return vec3 (cross product only works with vec3)
155
+ export const cross = (x: X<'vec3'>, y: X<'vec3'>) => f<'vec3'>('cross', x, y)
156
+
157
+ // Component-wise functions - preserve input type (T -> T)
158
+ export const abs = <T extends C>(x: X<T>) => f<T>('abs', x)
159
+ export const sign = <T extends C>(x: X<T>) => f<T>('sign', x)
160
+ export const floor = <T extends C>(x: X<T>) => f<T>('floor', x)
161
+ export const ceil = <T extends C>(x: X<T>) => f<T>('ceil', x)
162
+ export const round = <T extends C>(x: X<T>) => f<T>('round', x)
163
+ export const fract = <T extends C>(x: X<T>) => f<T>('fract', x)
164
+ export const trunc = <T extends C>(x: X<T>) => f<T>('trunc', x)
165
+ export const sin = <T extends C>(x: X<T>) => f<T>('sin', x)
166
+ export const cos = <T extends C>(x: X<T>) => f<T>('cos', x)
167
+ export const tan = <T extends C>(x: X<T>) => f<T>('tan', x)
168
+ export const asin = <T extends C>(x: X<T>) => f<T>('asin', x)
169
+ export const acos = <T extends C>(x: X<T>) => f<T>('acos', x)
170
+ export const atan = <T extends C>(x: X<T>) => f<T>('atan', x)
171
+ export const sinh = <T extends C>(x: X<T>) => f<T>('sinh', x)
172
+ export const cosh = <T extends C>(x: X<T>) => f<T>('cosh', x)
173
+ export const tanh = <T extends C>(x: X<T>) => f<T>('tanh', x)
174
+ export const asinh = <T extends C>(x: X<T>) => f<T>('asinh', x)
175
+ export const acosh = <T extends C>(x: X<T>) => f<T>('acosh', x)
176
+ export const atanh = <T extends C>(x: X<T>) => f<T>('atanh', x)
177
+ export const exp = <T extends C>(x: X<T>) => f<T>('exp', x)
178
+ export const exp2 = <T extends C>(x: X<T>) => f<T>('exp2', x)
179
+ export const log = <T extends C>(x: X<T>) => f<T>('log', x)
180
+ export const log2 = <T extends C>(x: X<T>) => f<T>('log2', x)
181
+ export const sqrt = <T extends C>(x: X<T>) => f<T>('sqrt', x)
182
+ export const inverseSqrt = <T extends C>(x: X<T>) => f<T>('inverseSqrt', x)
183
+ export const normalize = <T extends C>(x: X<T>) => f<T>('normalize', x)
184
+ export const oneMinus = <T extends C>(x: X<T>) => f<T>('oneMinus', x)
185
+ export const saturate = <T extends C>(x: X<T>) => f<T>('saturate', x)
186
+ export const negate = <T extends C>(x: X<T>) => f<T>('negate', x)
187
+ export const reciprocal = <T extends C>(x: X<T>) => f<T>('reciprocal', x)
188
+ export const dFdx = <T extends C>(x: X<T>) => f<T>('dFdx', x)
189
+ export const dFdy = <T extends C>(x: X<T>) => f<T>('dFdy', x)
190
+ export const fwidth = <T extends C>(x: X<T>) => f<T>('fwidth', x)
191
+ export const degrees = <T extends C>(x: X<T>) => f<T>('degrees', x)
192
+ export const radians = <T extends C>(x: X<T>) => f<T>('radians', x)
193
+
194
+ // Functions where first argument determines return type
195
+ export const reflect = <T extends C>(I: X<T>, N: X) => f<T>('reflect', I, N)
196
+ export const refract = <T extends C>(I: X<T>, N: X, eta: X) => f<T>('refract', I, N, eta)
197
+
198
+ // Functions with highest priority type among arguments (using first arg for simplicity)
199
+ export const min = <T extends C>(x: X<T>, y: X) => f<T>('min', x, y)
200
+ export const max = <T extends C>(x: X<T>, y: X) => f<T>('max', x, y)
201
+ export const mix = <T extends C>(x: X<T>, y: X, a: X) => f<T>('mix', x, y, a)
202
+ export const clamp = <T extends C>(x: X<T>, minVal: X, maxVal: X) => f<T>('clamp', x, minVal, maxVal)
203
+ export const step = <T extends C>(edge: X, x: X<T>) => f<T>('step', edge, x)
204
+ export const smoothstep = <T extends C>(e0: X, e1: X, x: X<T>) => f<T>('smoothstep', e0, e1, x)
205
+
206
+ // Two-argument functions with highest priority type
207
+ export const atan2 = <T extends C>(y: X<T>, x: X) => f<T>('atan2', y, x)
208
+ export const pow = <T extends C>(x: X<T>, y: X) => f<T>('pow', x, y)
209
+
210
+ // Component-wise power functions
211
+ export const pow2 = <T extends C>(x: X<T>) => f<T>('pow2', x)
212
+ export const pow3 = <T extends C>(x: X<T>) => f<T>('pow3', x)
213
+ export const pow4 = <T extends C>(x: X<T>) => f<T>('pow4', x)
214
+
215
+ // Utility functions
216
+ export const bitcast = <T extends C>(x: X<T>, y: X) => f<T>('bitcast', x, y)
217
+ export const cbrt = <T extends C>(x: X<T>) => f<T>('cbrt', x)
218
+ export const difference = <T extends C>(x: X<T>, y: X) => f<T>('difference', x, y)
219
+ export const equals = (x: X, y: X) => f<'bool'>('equals', x, y)
220
+ export const faceforward = <T extends C>(N: X<T>, I: X, Nref: X) => f<T>('faceforward', N, I, Nref)
221
+ export const transformDirection = <T extends C>(dir: X<T>, matrix: X) => f<T>('transformDirection', dir, matrix)
package/src/node/infer.ts CHANGED
@@ -1,108 +1,91 @@
1
1
  import { is } from '../utils/helpers'
2
2
  import {
3
- BOOL_RETURN_FUNCTIONS,
4
3
  BUILTIN_TYPES,
5
4
  COMPARISON_OPERATORS,
6
5
  COMPONENT_COUNT_TO_TYPE,
7
- CONSTANTS,
8
- FIRST_ARG_TYPE_FUNCTIONS,
9
- HIGHEST_TYPE_FUNCTIONS,
6
+ FUNCTION_RETURN_TYPES,
10
7
  LOGICAL_OPERATORS,
11
- PRESERVE_TYPE_FUNCTIONS,
12
- SCALAR_RETURN_FUNCTIONS,
13
- VEC3_RETURN_FUNCTIONS,
14
- VEC4_RETURN_FUNCTIONS,
15
8
  } from './const'
16
9
  import { isConstants, isNodeProxy, isSwizzle } from './utils'
17
- import type { Constants, NodeContext, NodeProxy, X } from './types'
10
+ import type { Constants as C, NodeContext, NodeProxy, X } from './types'
18
11
 
19
- const getHighestPriorityType = (args: X[], c: NodeContext) => {
20
- return args.reduce((highest, current) => {
21
- const currentType = infer(current, c)
22
- const highestPriority = CONSTANTS.indexOf(highest as any)
23
- const currentPriority = CONSTANTS.indexOf(currentType as any)
24
- return currentPriority > highestPriority ? currentType : highest
25
- }, 'float') as Constants
12
+ const inferBuiltin = <T extends C>(id: string | undefined) => {
13
+ return BUILTIN_TYPES[id as keyof typeof BUILTIN_TYPES] as T
26
14
  }
27
15
 
28
- const inferBuiltin = (id: string | undefined): Constants => {
29
- return BUILTIN_TYPES[id as keyof typeof BUILTIN_TYPES]!
16
+ // Unified logic with types.ts InferOperator type
17
+ const inferOperator = <T extends C>(L: T, R: T, op: string): T => {
18
+ if (COMPARISON_OPERATORS.includes(op as any) || LOGICAL_OPERATORS.includes(op as any)) return 'bool' as T
19
+ if (L === R) return L
20
+ // broadcast
21
+ if (L === 'float' || L === 'int') return R
22
+ if (R === 'float' || R === 'int') return L
23
+ // mat * vec → vec
24
+ if (L === 'mat4' && R === 'vec4') return R
25
+ if (L === 'mat3' && R === 'vec3') return R
26
+ if (L === 'mat2' && R === 'vec2') return R
27
+ // vec * mat → vec
28
+ if (L === 'vec4' && R === 'mat4') return L
29
+ if (L === 'vec3' && R === 'mat3') return L
30
+ if (L === 'vec2' && R === 'mat2') return L
31
+ return L
30
32
  }
31
33
 
32
- const inferFunction = (funcName: string, args: X[], c: NodeContext): Constants => {
33
- const firstArgType = args.length > 0 ? infer(args[0], c) : 'float'
34
- if (FIRST_ARG_TYPE_FUNCTIONS.includes(funcName as any)) return firstArgType
35
- if (SCALAR_RETURN_FUNCTIONS.includes(funcName as any)) return 'float'
36
- if (BOOL_RETURN_FUNCTIONS.includes(funcName as any)) return 'bool'
37
- if (PRESERVE_TYPE_FUNCTIONS.includes(funcName as any)) return firstArgType
38
- if (VEC3_RETURN_FUNCTIONS.includes(funcName as any)) return 'vec3'
39
- if (VEC4_RETURN_FUNCTIONS.includes(funcName as any)) return 'vec4'
40
- if (HIGHEST_TYPE_FUNCTIONS.includes(funcName as any)) return getHighestPriorityType(args, c)
41
- return firstArgType
34
+ export const inferPrimitiveType = <T extends C>(x: X) => {
35
+ if (is.bol(x)) return 'bool' as T
36
+ if (is.str(x)) return 'texture' as T
37
+ if (is.num(x)) return 'float' as T // @TODO FIX: Number.isInteger(x) ? 'int' : 'float'
38
+ if (is.arr(x)) return COMPONENT_COUNT_TO_TYPE[x.length as keyof typeof COMPONENT_COUNT_TO_TYPE] as T
39
+ return 'float' as T
42
40
  }
43
41
 
44
- const inferOperator = (leftType: string, rightType: string, op: string): Constants => {
45
- if (COMPARISON_OPERATORS.includes(op as any)) return 'bool'
46
- if (LOGICAL_OPERATORS.includes(op as any)) return 'bool'
47
- if (leftType === rightType) return leftType as Constants
48
- if (leftType.includes('vec') && !rightType.includes('vec')) return leftType as Constants
49
- if (rightType.includes('vec') && !leftType.includes('vec')) return rightType as Constants
50
- const leftPriority = CONSTANTS.indexOf(leftType as any)
51
- const rightPriority = CONSTANTS.indexOf(rightType as any)
52
- return (leftPriority >= rightPriority ? leftType : rightType) as Constants
42
+ const inferFromCount = <T extends C>(count: number) => {
43
+ return COMPONENT_COUNT_TO_TYPE[count as keyof typeof COMPONENT_COUNT_TO_TYPE] as T
53
44
  }
54
45
 
55
- export const inferPrimitiveType = (x: any): Constants => {
56
- if (is.bol(x)) return 'bool'
57
- if (is.str(x)) return 'texture'
58
- if (is.num(x)) return Number.isInteger(x) ? 'int' : 'float'
59
- if (is.arr(x)) return COMPONENT_COUNT_TO_TYPE[x.length as keyof typeof COMPONENT_COUNT_TO_TYPE] || 'float'
60
- return 'float'
61
- }
62
-
63
- const inferFromCount = (count: number): Constants => {
64
- return COMPONENT_COUNT_TO_TYPE[count as keyof typeof COMPONENT_COUNT_TO_TYPE]!
65
- }
66
-
67
- const inferFromArray = (arr: X[], c: NodeContext): Constants => {
68
- if (arr.length === 0) return 'void'
46
+ const inferFromArray = <T extends C>(arr: X<T>[], c: NodeContext) => {
47
+ if (arr.length === 0) return 'void' as T
69
48
  const [x] = arr
70
- if (is.str(x)) return x as Constants // for struct
49
+ if (is.str(x)) return x as T // for struct
71
50
  const ret = infer(x, c)
72
51
  for (const x of arr.slice(1))
73
52
  if (ret !== infer(x, c)) throw new Error(`glre node system error: defined scope return mismatch`)
74
53
  return ret
75
54
  }
76
55
 
77
- export const inferImpl = (target: NodeProxy, c: NodeContext): Constants => {
56
+ export const inferFunction = <T extends C>(x: X) => {
57
+ return FUNCTION_RETURN_TYPES[x as keyof typeof FUNCTION_RETURN_TYPES] as T
58
+ }
59
+
60
+ export const inferImpl = <T extends C>(target: NodeProxy<T>, c: NodeContext): T => {
78
61
  const { type, props } = target
79
- const { id, children = [], layout, inferFrom } = props
62
+ const { id, children = [], inferFrom, layout } = props
80
63
  const [x, y, z] = children
81
- if (type === 'conversion') return x as Constants
82
- if (type === 'operator') return inferOperator(infer(y, c), infer(z, c), x as string)
83
- if (type === 'function') return inferFunction(x as string, children.slice(1), c)
64
+ if (type === 'conversion') return x
65
+ if (type === 'operator') return inferOperator(infer(y, c), infer(z, c), x)
84
66
  if (type === 'ternary') return inferOperator(infer(y, c), infer(z, c), 'add')
85
67
  if (type === 'builtin') return inferBuiltin(id)
86
- if (type === 'define' && isConstants(layout?.type)) return layout?.type
68
+ if (type === 'function') return inferFunction(x) || infer(y, c)
69
+ if (type === 'define' && isConstants(layout?.type)) return layout?.type as T
87
70
  if (type === 'attribute' && is.arr(x) && c.gl?.count) return inferFromCount(x.length / c.gl.count)
88
71
  if (type === 'member') {
89
72
  if (isSwizzle(x)) return inferFromCount(x.length)
90
73
  if (isNodeProxy(y) && is.str(x)) {
91
- const field = y.props.fields?.[x] // for variable node of struct member
74
+ const field = (y as any).props.fields?.[x] // for variable node of struct member
92
75
  if (field) return infer(field, c)
93
76
  }
94
- return 'float' // fallback @TODO FIX
77
+ return 'float' as T // fallback @TODO FIX
95
78
  }
96
79
  if (inferFrom) return inferFromArray(inferFrom, c)
97
- return infer(x, c)
80
+ return infer(x, c) // for uniform
98
81
  }
99
82
 
100
- export const infer = (target: X, c?: NodeContext | null): Constants => {
83
+ export const infer = <T extends C>(target: X<T>, c?: NodeContext | null): T => {
101
84
  if (!c) c = {}
102
85
  if (!isNodeProxy(target)) return inferPrimitiveType(target)
103
86
  if (is.arr(target)) return inferFromCount(target.length)
104
- if (!c.infers) c.infers = new WeakMap<NodeProxy, Constants>()
105
- if (c.infers.has(target)) return c.infers.get(target)!
87
+ if (!c.infers) c.infers = new WeakMap<NodeProxy<T>, C>()
88
+ if (c.infers.has(target)) return c.infers.get(target) as T
106
89
  const ret = inferImpl(target, c)
107
90
  c.infers.set(target, ret)
108
91
  return ret