pytrilogy 0.0.2.58__py3-none-any.whl → 0.0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. {pytrilogy-0.0.2.58.dist-info → pytrilogy-0.0.3.0.dist-info}/METADATA +9 -2
  2. pytrilogy-0.0.3.0.dist-info/RECORD +99 -0
  3. {pytrilogy-0.0.2.58.dist-info → pytrilogy-0.0.3.0.dist-info}/WHEEL +1 -1
  4. trilogy/__init__.py +2 -2
  5. trilogy/core/enums.py +1 -7
  6. trilogy/core/env_processor.py +17 -5
  7. trilogy/core/environment_helpers.py +11 -25
  8. trilogy/core/exceptions.py +4 -0
  9. trilogy/core/functions.py +695 -261
  10. trilogy/core/graph_models.py +10 -10
  11. trilogy/core/internal.py +11 -2
  12. trilogy/core/models/__init__.py +0 -0
  13. trilogy/core/models/author.py +2110 -0
  14. trilogy/core/models/build.py +1845 -0
  15. trilogy/core/models/build_environment.py +151 -0
  16. trilogy/core/models/core.py +370 -0
  17. trilogy/core/models/datasource.py +297 -0
  18. trilogy/core/models/environment.py +696 -0
  19. trilogy/core/models/execute.py +931 -0
  20. trilogy/core/optimization.py +14 -16
  21. trilogy/core/optimizations/base_optimization.py +1 -1
  22. trilogy/core/optimizations/inline_constant.py +6 -6
  23. trilogy/core/optimizations/inline_datasource.py +17 -11
  24. trilogy/core/optimizations/predicate_pushdown.py +17 -16
  25. trilogy/core/processing/concept_strategies_v3.py +180 -145
  26. trilogy/core/processing/graph_utils.py +1 -1
  27. trilogy/core/processing/node_generators/basic_node.py +19 -18
  28. trilogy/core/processing/node_generators/common.py +50 -44
  29. trilogy/core/processing/node_generators/filter_node.py +26 -13
  30. trilogy/core/processing/node_generators/group_node.py +26 -21
  31. trilogy/core/processing/node_generators/group_to_node.py +11 -8
  32. trilogy/core/processing/node_generators/multiselect_node.py +60 -43
  33. trilogy/core/processing/node_generators/node_merge_node.py +76 -38
  34. trilogy/core/processing/node_generators/rowset_node.py +57 -36
  35. trilogy/core/processing/node_generators/select_helpers/datasource_injection.py +27 -34
  36. trilogy/core/processing/node_generators/select_merge_node.py +161 -64
  37. trilogy/core/processing/node_generators/select_node.py +13 -13
  38. trilogy/core/processing/node_generators/union_node.py +12 -11
  39. trilogy/core/processing/node_generators/unnest_node.py +9 -7
  40. trilogy/core/processing/node_generators/window_node.py +19 -16
  41. trilogy/core/processing/nodes/__init__.py +21 -18
  42. trilogy/core/processing/nodes/base_node.py +82 -66
  43. trilogy/core/processing/nodes/filter_node.py +19 -13
  44. trilogy/core/processing/nodes/group_node.py +50 -35
  45. trilogy/core/processing/nodes/merge_node.py +45 -36
  46. trilogy/core/processing/nodes/select_node_v2.py +53 -39
  47. trilogy/core/processing/nodes/union_node.py +5 -7
  48. trilogy/core/processing/nodes/unnest_node.py +7 -11
  49. trilogy/core/processing/nodes/window_node.py +9 -4
  50. trilogy/core/processing/utility.py +103 -75
  51. trilogy/core/query_processor.py +65 -47
  52. trilogy/core/statements/__init__.py +0 -0
  53. trilogy/core/statements/author.py +413 -0
  54. trilogy/core/statements/build.py +0 -0
  55. trilogy/core/statements/common.py +30 -0
  56. trilogy/core/statements/execute.py +42 -0
  57. trilogy/dialect/base.py +146 -106
  58. trilogy/dialect/common.py +9 -10
  59. trilogy/dialect/duckdb.py +1 -1
  60. trilogy/dialect/enums.py +4 -2
  61. trilogy/dialect/presto.py +1 -1
  62. trilogy/dialect/sql_server.py +1 -1
  63. trilogy/executor.py +44 -32
  64. trilogy/hooks/base_hook.py +6 -4
  65. trilogy/hooks/query_debugger.py +110 -93
  66. trilogy/parser.py +1 -1
  67. trilogy/parsing/common.py +303 -64
  68. trilogy/parsing/parse_engine.py +263 -617
  69. trilogy/parsing/render.py +50 -26
  70. trilogy/scripts/trilogy.py +2 -1
  71. pytrilogy-0.0.2.58.dist-info/RECORD +0 -87
  72. trilogy/core/models.py +0 -4960
  73. {pytrilogy-0.0.2.58.dist-info → pytrilogy-0.0.3.0.dist-info}/LICENSE.md +0 -0
  74. {pytrilogy-0.0.2.58.dist-info → pytrilogy-0.0.3.0.dist-info}/entry_points.txt +0 -0
  75. {pytrilogy-0.0.2.58.dist-info → pytrilogy-0.0.3.0.dist-info}/top_level.txt +0 -0
trilogy/core/models.py DELETED
@@ -1,4960 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import difflib
4
- import hashlib
5
- import os
6
- from abc import ABC
7
- from collections import UserDict, UserList, defaultdict
8
- from datetime import date, datetime
9
- from enum import Enum
10
- from functools import cached_property
11
- from pathlib import Path
12
- from typing import (
13
- Annotated,
14
- Any,
15
- Callable,
16
- Dict,
17
- Generic,
18
- ItemsView,
19
- List,
20
- Never,
21
- Optional,
22
- Self,
23
- Sequence,
24
- Set,
25
- Tuple,
26
- Type,
27
- TypeVar,
28
- Union,
29
- ValuesView,
30
- get_args,
31
- )
32
-
33
- from lark.tree import Meta
34
- from pydantic import (
35
- BaseModel,
36
- ConfigDict,
37
- Field,
38
- ValidationInfo,
39
- computed_field,
40
- field_validator,
41
- )
42
- from pydantic.functional_validators import PlainValidator
43
- from pydantic_core import core_schema
44
-
45
- from trilogy.constants import (
46
- CONFIG,
47
- DEFAULT_NAMESPACE,
48
- ENV_CACHE_NAME,
49
- MagicConstants,
50
- logger,
51
- )
52
- from trilogy.core.constants import (
53
- ALL_ROWS_CONCEPT,
54
- CONSTANT_DATASET,
55
- INTERNAL_NAMESPACE,
56
- PERSISTED_CONCEPT_PREFIX,
57
- )
58
- from trilogy.core.enums import (
59
- BooleanOperator,
60
- ComparisonOperator,
61
- ConceptSource,
62
- DatePart,
63
- FunctionClass,
64
- FunctionType,
65
- Granularity,
66
- InfiniteFunctionArgs,
67
- IOType,
68
- JoinType,
69
- Modifier,
70
- Ordering,
71
- Purpose,
72
- PurposeLineage,
73
- SelectFiltering,
74
- ShowCategory,
75
- SourceType,
76
- WindowOrder,
77
- WindowType,
78
- )
79
- from trilogy.core.exceptions import (
80
- InvalidSyntaxException,
81
- UndefinedConceptException,
82
- )
83
- from trilogy.utility import unique
84
-
85
# Tag for this module's log output.
# NOTE(review): presumably prepended to logger messages; usage is not visible here.
LOGGER_PREFIX = "[MODELS]"

# Generic type parameters. VT parameterizes ListWrapper and KT/VT parameterize
# MapWrapper further down; LT has no use in the visible part of the file.
KT = TypeVar("KT")
VT = TypeVar("VT")
LT = TypeVar("LT")
90
-
91
-
92
def is_compatible_datatype(left, right):
    """Return True when two datatypes can be used interchangeably.

    Rules, checked in order:

    * if either side is ``DataType.UNKNOWN`` no assumption can be made, so
      the pair is treated as compatible;
    * equal types are compatible;
    * ``NUMERIC``, ``FLOAT`` and ``INTEGER`` are mutually compatible.

    Every other combination is incompatible.

    Args:
        left: first datatype (a ``DataType`` member or a type object).
        right: second datatype.

    Returns:
        bool: whether the two types are compatible.
    """
    # for unknown types, we can't make any assumptions
    if right == DataType.UNKNOWN or left == DataType.UNKNOWN:
        return True
    if left == right:
        return True
    # Tuple membership compares with ``==`` and never hashes the operands.
    # Building ``{left, right}`` would raise TypeError for unhashable type
    # objects (e.g. non-frozen pydantic models); here they simply fail the
    # membership test and the function reports "incompatible".
    numeric_family = (DataType.NUMERIC, DataType.FLOAT, DataType.INTEGER)
    return left in numeric_family and right in numeric_family
105
-
106
-
107
def get_version():
    """Return the ``__version__`` attribute of the ``trilogy`` package."""
    # Imported at call time rather than at module import time.
    # NOTE(review): presumably to avoid a circular import - confirm.
    import trilogy

    return trilogy.__version__
111
-
112
-
113
def address_with_namespace(address: str, namespace: str) -> str:
    """Return ``address`` re-rooted under ``namespace``.

    When the first dotted segment of ``address`` is ``DEFAULT_NAMESPACE``
    that segment is replaced by ``namespace``; otherwise ``namespace`` is
    prefixed onto the full address.

    Args:
        address: dotted concept address.
        namespace: namespace to place the address under.

    Returns:
        The namespaced address string.
    """
    # partition() never raises, unlike split(".", 1)[1], which fails with
    # IndexError when the address is the bare default namespace (no dot).
    head, dot, remainder = address.partition(".")
    if dot and head == DEFAULT_NAMESPACE:
        return f"{namespace}.{remainder}"
    return f"{namespace}.{address}"
117
-
118
-
119
def get_concept_arguments(expr) -> List["Concept"]:
    """Collect the concepts referenced by ``expr``.

    A ``Concept`` yields itself. Expression nodes that expose
    ``concept_arguments`` yield a fresh copy of that list. Anything else
    yields an empty list.
    """
    if isinstance(expr, Concept):
        return [expr]

    expression_nodes = (
        Comparison,
        Conditional,
        Function,
        Parenthetical,
        AggregateWrapper,
        CaseWhen,
        CaseElse,
    )
    if isinstance(expr, expression_nodes):
        # copy so callers never receive the node's own list
        return list(expr.concept_arguments)
    return []
138
-
139
-
140
# Every object that may appear where a datatype is expected. The members are
# forward references (strings) because the classes are defined later in the
# module.
ALL_TYPES = Union[
    "DataType", "MapType", "ListType", "NumericType", "StructType", "Concept"
]

# Union of expression/model types named as namespace-aware.
# NOTE(review): presumably each implements ``with_namespace`` - confirm
# against the class definitions.
NAMESPACED_TYPES = Union[
    "WindowItem",
    "FilterItem",
    "Conditional",
    "Comparison",
    "Concept",
    "CaseWhen",
    "CaseElse",
    "Function",
    "AggregateWrapper",
    "Parenthetical",
]
156
-
157
-
158
class Namespaced(ABC):
    """Interface for objects that can be rebuilt under another namespace."""

    def with_namespace(self, namespace: str):
        """Return this object placed under ``namespace``.

        Placeholder in the base class; subclasses are expected to override.

        Raises:
            NotImplementedError: always, when not overridden.
        """
        raise NotImplementedError()
161
-
162
-
163
class Mergeable(ABC):
    """Interface for objects that can swap one concept for another."""

    def with_merge(self, source: Concept, target: Concept, modifiers: List[Modifier]):
        """Return a copy in which ``source`` is replaced by ``target``.

        Placeholder in the base class.

        Raises:
            NotImplementedError: always, when not overridden.
        """
        raise NotImplementedError()

    def hydrate_missing(self, concepts: EnvironmentConceptDict):
        """Default hook: nothing to hydrate, so hand back the same object."""
        unchanged = self
        return unchanged
169
-
170
-
171
class ConceptArgs(ABC):
    """Interface for expression nodes that reference concepts."""

    @property
    def concept_arguments(self) -> List["Concept"]:
        """Concepts referenced by this node; must be provided by subclasses.

        Raises:
            NotImplementedError: always, when not overridden.
        """
        raise NotImplementedError()

    @property
    def existence_arguments(self) -> list[tuple["Concept", ...]]:
        """Tuples of concepts used in existence checks; none by default."""
        return list()

    @property
    def row_arguments(self) -> List["Concept"]:
        """Row-level concept arguments; defaults to ``concept_arguments``."""
        referenced = self.concept_arguments
        return referenced
183
-
184
-
185
class SelectContext(ABC):
    """Interface for objects that can be re-evaluated inside a select."""

    def with_select_context(
        self,
        local_concepts: dict[str, Concept],
        grain: Grain,
        environment: Environment,
    ) -> Any:
        """Return a version of this object bound to the given select context.

        Placeholder in the base class.

        Raises:
            NotImplementedError: always, when not overridden.
        """
        raise NotImplementedError()
193
-
194
-
195
class ConstantInlineable(ABC):
    """Interface for objects that accept an ``inline_concept`` call."""

    def inline_concept(self, concept: Concept):
        """Placeholder; subclasses supply the inlining behaviour.

        Raises:
            NotImplementedError: always, when not overridden.
        """
        raise NotImplementedError()
198
-
199
-
200
class HasUUID(ABC):
    """Mixin deriving an identifier from the object's string form."""

    @property
    def uuid(self) -> str:
        """Hex MD5 digest of ``str(self)`` encoded with the default codec."""
        digest = hashlib.md5()
        digest.update(str(self).encode())
        return digest.hexdigest()
204
-
205
-
206
class SelectTypeMixin(BaseModel):
    """Shared WHERE/HAVING bookkeeping for select-like statements."""

    where_clause: Union["WhereClause", None] = Field(default=None)
    having_clause: Union["HavingClause", None] = Field(default=None)

    @property
    def output_components(self) -> List[Concept]:
        """Concepts produced by the statement; subclasses must provide it."""
        raise NotImplementedError()

    @property
    def implicit_where_clause_selections(self) -> List[Concept]:
        """Non-constant WHERE arguments that are not part of the output.

        Returns an empty list when there is no where clause or when every
        non-constant row argument of the where clause is already selected.
        """
        if not self.where_clause:
            return []
        referenced = {
            str(arg.address)
            for arg in self.where_clause.row_arguments
            if not arg.derivation == PurposeLineage.CONSTANT
        }
        selected = {str(out.address) for out in self.output_components}
        missing = referenced - selected
        if not missing:
            return []
        return [
            arg for arg in self.where_clause.row_arguments if str(arg.address) in missing
        ]

    @property
    def where_clause_category(self) -> SelectFiltering:
        """Classify the filter: NONE, IMPLICIT (filters on unselected
        concepts) or EXPLICIT."""
        if not self.where_clause:
            return SelectFiltering.NONE
        if self.implicit_where_clause_selections:
            return SelectFiltering.IMPLICIT
        return SelectFiltering.EXPLICIT
240
-
241
-
242
class DataType(Enum):
    """Enumeration of the datatypes known to the model layer.

    Each member's value is the lowercase string name of the type.
    """

    # PRIMITIVES
    STRING = "string"
    BOOL = "bool"
    MAP = "map"
    LIST = "list"
    NUMBER = "number"
    FLOAT = "float"
    NUMERIC = "numeric"
    INTEGER = "int"
    BIGINT = "bigint"
    DATE = "date"
    DATETIME = "datetime"
    TIMESTAMP = "timestamp"
    ARRAY = "array"
    DATE_PART = "date_part"
    STRUCT = "struct"
    NULL = "null"

    # GRANULAR
    UNIX_SECONDS = "unix_seconds"

    # PARSING
    UNKNOWN = "unknown"

    @property
    def data_type(self):
        """Return the member itself, mirroring the ``data_type`` property
        exposed by the composite type classes (NumericType, ListType, ...)."""
        return self
270
-
271
-
272
class NumericType(BaseModel):
    """Numeric type carrying a precision and a scale (defaults 20 and 5)."""

    precision: int = 20
    scale: int = 5

    @property
    def data_type(self):
        """Always ``DataType.NUMERIC``."""
        return DataType.NUMERIC

    @property
    def value(self):
        """String value of the underlying ``DataType`` member."""
        base_type = self.data_type
        return base_type.value
283
-
284
-
285
class ListType(BaseModel):
    """Immutable (frozen) list type parameterized by its element type."""

    model_config = ConfigDict(frozen=True)
    type: ALL_TYPES

    def __str__(self) -> str:
        return f"ListType<{self.type}>"

    @property
    def data_type(self):
        """Always ``DataType.LIST``."""
        return DataType.LIST

    @property
    def value(self):
        """String value of the underlying ``DataType`` member."""
        base_type = self.data_type
        return base_type.value

    @property
    def value_data_type(
        self,
    ) -> DataType | StructType | MapType | ListType | NumericType:
        """Element datatype; a ``Concept`` element resolves to its datatype."""
        element = self.type
        return element.datatype if isinstance(element, Concept) else element
307
-
308
-
309
class MapType(BaseModel):
    """Map type described by a key type and a value type."""

    key_type: DataType
    value_type: ALL_TYPES

    @property
    def data_type(self):
        """Always ``DataType.MAP``."""
        return DataType.MAP

    @property
    def value(self):
        """String value of the underlying ``DataType`` member."""
        base_type = self.data_type
        return base_type.value

    @property
    def value_data_type(
        self,
    ) -> DataType | StructType | MapType | ListType | NumericType:
        """Datatype of the map values; a ``Concept`` resolves to its datatype."""
        held = self.value_type
        return held.datatype if isinstance(held, Concept) else held

    @property
    def key_data_type(
        self,
    ) -> DataType | StructType | MapType | ListType | NumericType:
        """Datatype of the map keys; a ``Concept`` resolves to its datatype."""
        held = self.key_type
        return held.datatype if isinstance(held, Concept) else held
336
-
337
-
338
class StructType(BaseModel):
    """Struct type: an ordered list of field types plus a name -> field map."""

    fields: List[ALL_TYPES]
    fields_map: Dict[str, Concept | int | float | str]

    @property
    def data_type(self):
        """Always ``DataType.STRUCT``."""
        return DataType.STRUCT

    @property
    def value(self):
        """String value of the underlying ``DataType`` member."""
        base_type = self.data_type
        return base_type.value
349
-
350
-
351
class ListWrapper(Generic[VT], UserList):
    """Used to distinguish parsed list objects from other lists.

    Besides the list contents, the wrapper records the datatype of its
    elements in ``self.type``.
    """

    def __init__(self, *args, type: DataType, **kwargs):
        super().__init__(*args, **kwargs)
        self.type = type

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: Callable[[Any], core_schema.CoreSchema]
    ) -> core_schema.CoreSchema:
        """Validate as a plain list first, then wrap via ``validate``."""
        args = get_args(source_type)
        if args:
            schema = handler(List[args])  # type: ignore
        else:
            schema = handler(List)
        return core_schema.no_info_after_validator_function(cls.validate, schema)

    @classmethod
    def validate(cls, v):
        """Wrap an already-validated list, inferring its element type.

        The type is inferred from the first element. An empty list has no
        element to inspect, so it is tagged ``DataType.UNKNOWN`` instead of
        failing with IndexError on ``v[0]``.
        """
        element_type = arg_to_datatype(v[0]) if v else DataType.UNKNOWN
        return cls(v, type=element_type)
372
-
373
-
374
class MapWrapper(Generic[KT, VT], UserDict):
    """Used to distinguish parsed map objects from other dicts.

    Besides the mapping contents, the wrapper records the datatype of its
    keys and values in ``self.key_type`` / ``self.value_type``.
    """

    def __init__(self, *args, key_type: DataType, value_type: DataType, **kwargs):
        super().__init__(*args, **kwargs)
        self.key_type = key_type
        self.value_type = value_type

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: Callable[[Any], core_schema.CoreSchema]
    ) -> core_schema.CoreSchema:
        """Validate as a plain dict first, then wrap via ``validate``."""
        args = get_args(source_type)
        if args:
            schema = handler(Dict[args])  # type: ignore
        else:
            schema = handler(Dict)
        return core_schema.no_info_after_validator_function(cls.validate, schema)

    @classmethod
    def validate(cls, v):
        """Wrap an already-validated dict, inferring key and value types.

        Types are inferred from the last key and last value. An empty
        mapping has nothing to inspect, so both types fall back to
        ``DataType.UNKNOWN`` instead of raising IndexError from ``pop()``.
        """
        if not v:
            return cls(v, key_type=DataType.UNKNOWN, value_type=DataType.UNKNOWN)
        return cls(
            v,
            key_type=arg_to_datatype(list(v.keys()).pop()),
            value_type=arg_to_datatype(list(v.values()).pop()),
        )
400
-
401
-
402
class Metadata(BaseModel):
    """Metadata container object attached to a concept.

    TODO: support arbitrary tags
    """

    # Free-text description, if any.
    description: Optional[str] = None
    # Line number associated with the concept, if known.
    # NOTE(review): presumably the line in the parsed source text - confirm.
    line_number: Optional[int] = None
    # How the concept came to exist; defaults to ConceptSource.MANUAL.
    concept_source: ConceptSource = ConceptSource.MANUAL
409
-
410
-
411
class Concept(Mergeable, Namespaced, SelectContext, BaseModel):
    """A named, typed semantic element, optionally derived from a lineage.

    A concept is addressed as ``<namespace>.<name>``. Its ``lineage`` (when
    set) is the expression it is derived from, and ``grain`` holds the
    addresses it is defined at.
    """

    name: str
    datatype: DataType | ListType | StructType | MapType | NumericType
    purpose: Purpose
    metadata: Metadata = Field(
        default_factory=lambda: Metadata(description=None, line_number=None),
        validate_default=True,
    )
    # Expression this concept is derived from; None when there is none.
    lineage: Optional[
        Union[
            Function,
            WindowItem,
            FilterItem,
            AggregateWrapper,
            RowsetItem,
            MultiSelectStatement,
        ]
    ] = None
    namespace: Optional[str] = Field(default=DEFAULT_NAMESPACE, validate_default=True)
    # Addresses of the key concepts, when any are recorded.
    keys: Optional[set[str]] = None
    # Default is None but validate_default routes it through parse_grain,
    # which always produces a Grain.
    grain: "Grain" = Field(default=None, validate_default=True)  # type: ignore
    modifiers: List[Modifier] = Field(default_factory=list)  # type: ignore
    # Other addresses recorded for this concept (see with_merge).
    pseudonyms: set[str] = Field(default_factory=set)
    _address_cache: str | None = None

    def __init__(self, **data):
        super().__init__(**data)

    def duplicate(self) -> Concept:
        """Return a deep copy of this concept."""
        return self.model_copy(deep=True)

    def __hash__(self):
        # NOTE(review): the hash includes lineage and keys, which __eq__
        # ignores, so two concepts that compare equal may hash differently.
        return hash(
            f"{self.name}+{self.datatype}+ {self.purpose} + {str(self.lineage)} + {self.namespace} + {str(self.grain)} + {str(self.keys)}"
        )

    def __repr__(self):
        base = f"{self.address}@{self.grain}"
        return base

    @property
    def is_aggregate(self):
        """True when the lineage is an aggregate function, either directly
        or wrapped in an AggregateWrapper."""
        if (
            self.lineage
            and isinstance(self.lineage, Function)
            and self.lineage.operator in FunctionClass.AGGREGATE_FUNCTIONS.value
        ):
            return True
        if (
            self.lineage
            and isinstance(self.lineage, AggregateWrapper)
            and self.lineage.function.operator
            in FunctionClass.AGGREGATE_FUNCTIONS.value
        ):
            return True
        return False

    def with_merge(self, source: Self, target: Self, modifiers: List[Modifier]) -> Self:
        """Return a copy with ``source`` replaced by ``target``.

        If this concept *is* ``source``, the result is ``target`` at this
        concept's (merged) grain, with this address recorded as a pseudonym.
        Otherwise the replacement is propagated into lineage, grain and keys.
        """
        if self.address == source.address:
            new = target.with_grain(self.grain.with_merge(source, target, modifiers))
            new.pseudonyms.add(self.address)
            return new
        return self.__class__(
            name=self.name,
            datatype=self.datatype,
            purpose=self.purpose,
            metadata=self.metadata,
            lineage=(
                self.lineage.with_merge(source, target, modifiers)
                if self.lineage
                else None
            ),
            grain=self.grain.with_merge(source, target, modifiers),
            namespace=self.namespace,
            keys=(
                set(x if x != source.address else target.address for x in self.keys)
                if self.keys
                else None
            ),
            modifiers=self.modifiers,
            pseudonyms=self.pseudonyms,
        )

    @field_validator("namespace", mode="plain")
    @classmethod
    def namespace_validation(cls, v):
        """Replace a missing/empty namespace with DEFAULT_NAMESPACE."""
        return v or DEFAULT_NAMESPACE

    @field_validator("metadata", mode="before")
    @classmethod
    def metadata_validation(cls, v):
        """Replace missing metadata with an empty Metadata object."""
        v = v or Metadata()
        return v

    @field_validator("purpose", mode="after")
    @classmethod
    def purpose_validation(cls, v):
        """Reject ``Purpose.AUTO``.

        Raises:
            ValueError: if the purpose is ``Purpose.AUTO``.
        """
        if v == Purpose.AUTO:
            raise ValueError("Cannot set purpose to AUTO")
        return v

    @field_validator("grain", mode="before")
    @classmethod
    def parse_grain(cls, v, info: ValidationInfo) -> Grain:
        """Normalize the ``grain`` input into a Grain.

        Reads previously validated fields from ``info.data``, so it depends
        on name, purpose, lineage and namespace being declared before grain.

        Raises:
            SyntaxError: if the value is not a Grain, Concept, dict or falsy.
        """
        # this is silly - rethink how we do grains
        values = info.data
        if not v and values.get("purpose", None) == Purpose.KEY:
            # a key with no explicit grain is its own grain
            v = Grain(
                components={
                    f'{values.get("namespace", DEFAULT_NAMESPACE)}.{values["name"]}'
                }
            )
        elif (
            "lineage" in values
            and isinstance(values["lineage"], AggregateWrapper)
            and values["lineage"].by
        ):
            # an aggregate with an explicit "by" takes that as its grain,
            # overriding any grain that was passed in
            v = Grain(components={c.address for c in values["lineage"].by})
        elif not v:
            v = Grain(components=set())
        elif isinstance(v, Grain):
            pass
        elif isinstance(v, Concept):
            v = Grain(components={v.address})
        elif isinstance(v, dict):
            v = Grain.model_validate(v)
        else:
            raise SyntaxError(f"Invalid grain {v} for concept {values['name']}")
        return v

    def __eq__(self, other: object):
        """Compare by name, datatype, purpose, namespace and grain.

        A string compares equal when it matches this concept's address.
        Lineage and keys are not part of the comparison.
        """
        if isinstance(other, str):
            if self.address == other:
                return True
        if not isinstance(other, Concept):
            return False
        return (
            self.name == other.name
            and self.datatype == other.datatype
            and self.purpose == other.purpose
            and self.namespace == other.namespace
            and self.grain == other.grain
            # and self.keys == other.keys
        )

    def __str__(self):
        grain = str(self.grain) if self.grain else "Grain<>"
        return f"{self.namespace}.{self.name}@{grain}"

    @cached_property
    def address(self) -> str:
        """``<namespace>.<name>``; cached until set_name clears it."""
        return f"{self.namespace}.{self.name}"

    def set_name(self, name: str):
        """Rename the concept and drop the cached address, if one exists."""
        self.name = name
        try:
            del self.address
        except AttributeError:
            # address was never computed, nothing to invalidate
            pass

    @property
    def output(self) -> "Concept":
        return self

    @property
    def safe_address(self) -> str:
        """Address with dots replaced by underscores; the default namespace
        is omitted."""
        if self.namespace == DEFAULT_NAMESPACE:
            return self.name.replace(".", "_")
        elif self.namespace:
            return f"{self.namespace.replace('.','_')}_{self.name.replace('.','_')}"
        return self.name.replace(".", "_")

    def with_namespace(self, namespace: str) -> Self:
        """Return a copy placed under ``namespace``.

        Returns ``self`` unchanged when already in that namespace. A
        non-default existing namespace is kept as a suffix
        (``namespace + "." + old``); lineage, grain, keys and pseudonyms are
        re-namespaced as well.
        """
        if namespace == self.namespace:
            return self
        return self.__class__(
            name=self.name,
            datatype=self.datatype,
            purpose=self.purpose,
            metadata=self.metadata,
            lineage=self.lineage.with_namespace(namespace) if self.lineage else None,
            grain=(
                self.grain.with_namespace(namespace)
                if self.grain
                else Grain(components=set())
            ),
            namespace=(
                namespace + "." + self.namespace
                if self.namespace
                and self.namespace != DEFAULT_NAMESPACE
                and self.namespace != namespace
                else namespace
            ),
            keys=(
                set([address_with_namespace(x, namespace) for x in self.keys])
                if self.keys
                else None
            ),
            modifiers=self.modifiers,
            pseudonyms={address_with_namespace(v, namespace) for v in self.pseudonyms},
        )

    def with_select_context(
        self, local_concepts: dict[str, Concept], grain: Grain, environment: Environment
    ) -> Concept:
        """Return a copy evaluated in the context of a select.

        The lineage is rebound when it supports select context. A bare
        aggregate function is wrapped in an AggregateWrapper grouped by the
        select grain (when that grain has components), which also becomes
        the concept's grain and keys.
        """
        new_lineage = self.lineage
        if isinstance(self.lineage, SelectContext):
            new_lineage = self.lineage.with_select_context(
                local_concepts=local_concepts, grain=grain, environment=environment
            )
        final_grain = self.grain or grain
        keys = self.keys if self.keys else None
        if self.is_aggregate and isinstance(new_lineage, Function) and grain.components:
            grain_components = [environment.concepts[c] for c in grain.components]
            new_lineage = AggregateWrapper(function=new_lineage, by=grain_components)
            final_grain = grain
            keys = set(grain.components)
        elif (
            self.is_aggregate and not keys and isinstance(new_lineage, AggregateWrapper)
        ):
            keys = set([x.address for x in new_lineage.by])

        return self.__class__(
            name=self.name,
            datatype=self.datatype,
            purpose=self.purpose,
            metadata=self.metadata,
            lineage=new_lineage,
            grain=final_grain,
            namespace=self.namespace,
            keys=keys,
            modifiers=self.modifiers,
            # a select needs to always defer to the environment for pseudonyms
            # TODO: evaluate if this should be cached
            pseudonyms=(environment.concepts.get(self.address) or self).pseudonyms,
        )

    def with_grain(self, grain: Optional["Grain"] = None) -> Self:
        """Return a copy at ``grain`` (an empty Grain when none is given)."""
        return self.__class__(
            name=self.name,
            datatype=self.datatype,
            purpose=self.purpose,
            metadata=self.metadata,
            lineage=self.lineage,
            grain=grain if grain else Grain(components=set()),
            namespace=self.namespace,
            keys=self.keys,
            modifiers=self.modifiers,
            pseudonyms=self.pseudonyms,
        )

    @property
    def _with_default_grain(self) -> Self:
        """Return a copy whose grain is the default for its purpose.

        KEY -> itself; PROPERTY -> its keys plus the sources of concept
        arguments in its lineage; METRIC -> empty grain; CONSTANT -> itself
        unless its derivation is CONSTANT, in which case the current grain is
        kept; anything else keeps the current grain.
        """
        if self.purpose == Purpose.KEY:
            # we need to make this abstract
            grain = Grain(components={self.address})
        elif self.purpose == Purpose.PROPERTY:
            components = []
            if self.keys:
                components = [*self.keys]
            if self.lineage:
                for item in self.lineage.arguments:
                    if isinstance(item, Concept):
                        components += [x.address for x in item.sources]
            # TODO: set synonyms
            grain = Grain(
                components=set([x for x in components]),
            )  # synonym_set=generate_concept_synonyms(components))
        elif self.purpose == Purpose.METRIC:
            grain = Grain()
        elif self.purpose == Purpose.CONSTANT:
            if self.derivation != PurposeLineage.CONSTANT:
                grain = Grain(components={self.address})
            else:
                grain = self.grain
        else:
            grain = self.grain  # type: ignore
        return self.__class__(
            name=self.name,
            datatype=self.datatype,
            purpose=self.purpose,
            metadata=self.metadata,
            lineage=self.lineage,
            grain=grain,
            keys=self.keys,
            namespace=self.namespace,
            modifiers=self.modifiers,
            pseudonyms=self.pseudonyms,
        )

    def with_default_grain(self) -> "Concept":
        """Method form of the ``_with_default_grain`` property."""
        return self._with_default_grain

    @property
    def sources(self) -> List["Concept"]:
        """All concepts this one is derived from, found by walking the
        lineage arguments recursively (nested functions included).

        Raises:
            SyntaxError: if the concept appears among its own arguments.
        """
        if self.lineage:
            output: List[Concept] = []

            def get_sources(
                expr: Union[
                    Function,
                    WindowItem,
                    FilterItem,
                    AggregateWrapper,
                    RowsetItem,
                    MultiSelectStatement,
                ],
                output: List[Concept],
            ):
                for item in expr.arguments:
                    if isinstance(item, Concept):
                        if item.address == self.address:
                            raise SyntaxError(
                                f"Concept {self.address} references itself"
                            )
                        output.append(item)
                        # include the argument's own (transitive) sources
                        output += item.sources
                    elif isinstance(item, Function):
                        get_sources(item, output)

            get_sources(self.lineage, output)
            return output
        return []

    @property
    def concept_arguments(self) -> List[Concept]:
        """Concept arguments of the lineage; empty without a lineage."""
        return self.lineage.concept_arguments if self.lineage else []

    @property
    def input(self):
        """This concept followed by all of its sources."""
        return [self] + self.sources

    @property
    def derivation(self) -> PurposeLineage:
        """Classify how the concept is derived, based on its lineage.

        Branch order matters: lineage class first, then function operator,
        then whether a function depends only on constants. A concept with no
        lineage is CONSTANT when its purpose is CONSTANT, otherwise ROOT.
        """
        if self.lineage and isinstance(self.lineage, WindowItem):
            return PurposeLineage.WINDOW
        elif self.lineage and isinstance(self.lineage, FilterItem):
            return PurposeLineage.FILTER
        elif self.lineage and isinstance(self.lineage, AggregateWrapper):
            return PurposeLineage.AGGREGATE
        elif self.lineage and isinstance(self.lineage, RowsetItem):
            return PurposeLineage.ROWSET
        elif self.lineage and isinstance(self.lineage, MultiSelectStatement):
            return PurposeLineage.MULTISELECT
        elif (
            self.lineage
            and isinstance(self.lineage, Function)
            and self.lineage.operator in FunctionClass.AGGREGATE_FUNCTIONS.value
        ):
            return PurposeLineage.AGGREGATE
        elif (
            self.lineage
            and isinstance(self.lineage, Function)
            and self.lineage.operator == FunctionType.UNNEST
        ):
            return PurposeLineage.UNNEST
        elif (
            self.lineage
            and isinstance(self.lineage, Function)
            and self.lineage.operator == FunctionType.UNION
        ):
            return PurposeLineage.UNION
        elif (
            self.lineage
            and isinstance(self.lineage, Function)
            and self.lineage.operator in FunctionClass.SINGLE_ROW.value
        ):
            return PurposeLineage.CONSTANT

        elif self.lineage and isinstance(self.lineage, Function):
            # a function of no concepts, or of constants only, is constant
            if not self.lineage.concept_arguments:
                return PurposeLineage.CONSTANT
            elif all(
                [
                    x.derivation == PurposeLineage.CONSTANT
                    for x in self.lineage.concept_arguments
                ]
            ):
                return PurposeLineage.CONSTANT
            return PurposeLineage.BASIC
        elif self.purpose == Purpose.CONSTANT:
            return PurposeLineage.CONSTANT
        return PurposeLineage.ROOT

    @property
    def granularity(self) -> Granularity:
        """Used to determine if concepts need to be included in grain
        calculations.

        SINGLE_ROW for constants, for aggregates whose grain components all
        end with ALL_ROWS_CONCEPT, for the internal all-rows concept, and for
        derivations whose concept arguments are all single-row; MULTI_ROW
        otherwise.
        """
        if self.derivation == PurposeLineage.CONSTANT:
            # constants are a single row
            return Granularity.SINGLE_ROW
        elif self.derivation == PurposeLineage.AGGREGATE:
            # if it's an aggregate grouped over all rows
            # there is only one row left and it's fine to cross_join
            if all([x.endswith(ALL_ROWS_CONCEPT) for x in self.grain.components]):
                return Granularity.SINGLE_ROW
            # any other aggregate falls through to MULTI_ROW below
        elif self.namespace == INTERNAL_NAMESPACE and self.name == ALL_ROWS_CONCEPT:
            return Granularity.SINGLE_ROW
        elif (
            self.lineage
            and isinstance(self.lineage, Function)
            and self.lineage.operator in (FunctionType.UNNEST, FunctionType.UNION)
        ):
            return Granularity.MULTI_ROW
        elif self.lineage and all(
            [
                x.granularity == Granularity.SINGLE_ROW
                for x in self.lineage.concept_arguments
            ]
        ):

            return Granularity.SINGLE_ROW
        return Granularity.MULTI_ROW

    def with_filter(
        self,
        condition: "Conditional | Comparison | Parenthetical",
        environment: Environment | None = None,
    ) -> "Concept":
        """Return a filtered variant of this concept.

        If the concept is already a filter on exactly ``condition`` it is
        returned unchanged. Otherwise a new concept named
        ``<name>_filter_<hash>`` is built with a FilterItem lineage, and it
        is registered in ``environment`` when one is given.
        """
        from trilogy.utility import string_to_hash

        if self.lineage and isinstance(self.lineage, FilterItem):
            if self.lineage.where.conditional == condition:
                return self
        hash = string_to_hash(self.name + str(condition))
        new = Concept(
            name=f"{self.name}_filter_{hash}",
            datatype=self.datatype,
            purpose=self.purpose,
            metadata=self.metadata,
            lineage=FilterItem(content=self, where=WhereClause(conditional=condition)),
            keys=(self.keys if self.purpose == Purpose.PROPERTY else None),
            grain=self.grain if self.grain else Grain(components=set()),
            namespace=self.namespace,
            modifiers=self.modifiers,
            pseudonyms=self.pseudonyms,
        )
        if environment:
            environment.add_concept(new)
        return new
851
-
852
-
853
class ConceptRef(BaseModel):
    """Reference to a concept by address, with an optional line number."""

    address: str
    line_no: int | None = None

    def hydrate(self, environment: Environment) -> Concept:
        """Look the address up in ``environment.concepts``.

        The line number is forwarded as the second ``__getitem__`` argument.
        """
        lookup = environment.concepts.__getitem__
        return lookup(self.address, self.line_no)
859
-
860
-
861
class Grain(Namespaced, BaseModel):
    """A set of concept addresses, optionally qualified by a where clause.

    Equality, subset and disjointness checks consider the components only.
    """

    components: set[str] = Field(default_factory=set)
    where_clause: Optional["WhereClause"] = None

    def with_merge(self, source: Concept, target: Concept, modifiers: List[Modifier]):
        """Return a grain with ``source``'s address swapped for ``target``'s.

        The result carries no where clause.
        """
        swapped = {
            target.address if existing == source.address else existing
            for existing in self.components
        }
        return Grain(components=swapped)

    @classmethod
    def from_concepts(
        cls,
        concepts: List[Concept],
        environment: Environment | None = None,
        where_clause: WhereClause | None = None,
    ) -> "Grain":
        """Build a grain from concepts, reduced by
        ``concepts_to_grain_concepts``."""
        from trilogy.parsing.common import concepts_to_grain_concepts

        grain_concepts = concepts_to_grain_concepts(concepts, environment=environment)
        addresses = {item.address for item in grain_concepts}
        return Grain(components=addresses, where_clause=where_clause)

    def with_namespace(self, namespace: str) -> "Grain":
        """Return a copy with components and where clause under ``namespace``."""
        moved_components = {
            address_with_namespace(c, namespace) for c in self.components
        }
        moved_where = None
        if self.where_clause:
            moved_where = self.where_clause.with_namespace(namespace)
        return Grain(components=moved_components, where_clause=moved_where)

    @field_validator("components", mode="before")
    def component_validator(cls, v, info: ValidationInfo):
        """Normalize the components input to a set of address strings.

        A list may mix Concept / ConceptRef objects (reduced to their
        address) with plain values; any other input is used as given.

        Raises:
            ValueError: if the result is not a set, or holds non-strings.
        """
        if isinstance(v, list):
            output = {
                item.address if isinstance(item, (Concept, ConceptRef)) else item
                for item in v
            }
        else:
            output = v
        if not isinstance(output, set):
            raise ValueError(f"Invalid grain component {output}, is not set")
        if any(not isinstance(x, str) for x in output):
            raise ValueError(f"Invalid component {output}")
        return output

    def __add__(self, other: "Grain") -> "Grain":
        """Union the components; differing where clauses are AND-ed."""
        mine, theirs = self.where_clause, other.where_clause
        if not theirs:
            combined = mine
        elif not mine:
            combined = theirs
        elif theirs == mine:
            combined = mine
        else:
            combined = WhereClause(
                conditional=Conditional(
                    left=mine.conditional,
                    right=theirs.conditional,
                    operator=BooleanOperator.AND,
                )
            )
        merged = self.components.union(other.components)
        return Grain(components=merged, where_clause=combined)

    def __sub__(self, other: "Grain") -> "Grain":
        """Remove ``other``'s components; this grain's where clause is kept."""
        remaining = self.components.difference(other.components)
        return Grain(components=remaining, where_clause=self.where_clause)

    @property
    def abstract(self):
        """True for an empty grain or one made only of all-rows concepts."""
        if not self.components:
            return True
        return all(c.endswith(ALL_ROWS_CONCEPT) for c in self.components)

    def __eq__(self, other: object):
        """Compare components against a Grain or a list of Concepts."""
        if isinstance(other, list):
            if any(not isinstance(c, Concept) for c in other):
                return False
            return self.components == {c.address for c in other}
        if isinstance(other, Grain):
            return self.components == other.components
        return False

    def issubset(self, other: "Grain"):
        mine = self.components
        return mine.issubset(other.components)

    def union(self, other: "Grain"):
        """Union of components; only this grain's where clause is kept."""
        combined = self.components.union(other.components)
        return Grain(components=combined, where_clause=self.where_clause)

    def isdisjoint(self, other: "Grain"):
        mine = self.components
        return mine.isdisjoint(other.components)

    def intersection(self, other: "Grain") -> "Grain":
        """Shared components; the result carries no where clause."""
        shared = self.components.intersection(other.components)
        return Grain(components=shared)

    def __str__(self):
        if self.abstract:
            text = "Grain<Abstract>"
        else:
            text = "Grain<" + ",".join(sorted(self.components)) + ">"
        if self.where_clause:
            text += f"|{str(self.where_clause)}"
        return text

    def __radd__(self, other) -> "Grain":
        """Support ``sum()``, whose initial left operand is ``0``."""
        return self if other == 0 else self.__add__(other)
991
-
992
-
993
- class EnvironmentConceptDict(dict):
994
- def __init__(self, *args, **kwargs) -> None:
995
- super().__init__(self, *args, **kwargs)
996
- self.undefined: dict[str, UndefinedConcept] = {}
997
- self.fail_on_missing: bool = True
998
- self.populate_default_concepts()
999
-
1000
- def duplicate(self) -> "EnvironmentConceptDict":
1001
- new = EnvironmentConceptDict()
1002
- new.update({k: v.duplicate() for k, v in self.items()})
1003
- new.undefined = self.undefined
1004
- new.fail_on_missing = self.fail_on_missing
1005
- return new
1006
-
1007
- def populate_default_concepts(self):
1008
- from trilogy.core.internal import DEFAULT_CONCEPTS
1009
-
1010
- for concept in DEFAULT_CONCEPTS.values():
1011
- self[concept.address] = concept
1012
-
1013
- def values(self) -> ValuesView[Concept]: # type: ignore
1014
- return super().values()
1015
-
1016
- def get(self, key: str, default: Concept | None = None) -> Concept | None: # type: ignore
1017
- try:
1018
- return self.__getitem__(key)
1019
- except UndefinedConceptException:
1020
- return default
1021
-
1022
- def raise_undefined(
1023
- self, key: str, line_no: int | None = None, file: Path | str | None = None
1024
- ) -> Never:
1025
-
1026
- matches = self._find_similar_concepts(key)
1027
- message = f"Undefined concept: {key}."
1028
- if matches:
1029
- message += f" Suggestions: {matches}"
1030
-
1031
- if line_no:
1032
- if file:
1033
- raise UndefinedConceptException(
1034
- f"{file}: {line_no}: " + message, matches
1035
- )
1036
- raise UndefinedConceptException(f"line: {line_no}: " + message, matches)
1037
- raise UndefinedConceptException(message, matches)
1038
-
1039
- def __getitem__(
1040
- self, key: str, line_no: int | None = None, file: Path | None = None
1041
- ) -> Concept | UndefinedConcept:
1042
- try:
1043
- return super(EnvironmentConceptDict, self).__getitem__(key)
1044
- except KeyError:
1045
- if "." in key and key.split(".", 1)[0] == DEFAULT_NAMESPACE:
1046
- return self.__getitem__(key.split(".", 1)[1], line_no)
1047
- if DEFAULT_NAMESPACE + "." + key in self:
1048
- return self.__getitem__(DEFAULT_NAMESPACE + "." + key, line_no)
1049
- if not self.fail_on_missing:
1050
- if "." in key:
1051
- ns, rest = key.rsplit(".", 1)
1052
- else:
1053
- ns = DEFAULT_NAMESPACE
1054
- rest = key
1055
- if key in self.undefined:
1056
- return self.undefined[key]
1057
- undefined = UndefinedConcept(
1058
- name=rest,
1059
- line_no=line_no,
1060
- datatype=DataType.UNKNOWN,
1061
- purpose=Purpose.UNKNOWN,
1062
- namespace=ns,
1063
- )
1064
- self.undefined[key] = undefined
1065
- return undefined
1066
- self.raise_undefined(key, line_no, file)
1067
-
1068
- def _find_similar_concepts(self, concept_name: str):
1069
- def strip_local(input: str):
1070
- if input.startswith(f"{DEFAULT_NAMESPACE}."):
1071
- return input[len(DEFAULT_NAMESPACE) + 1 :]
1072
- return input
1073
-
1074
- matches = difflib.get_close_matches(
1075
- strip_local(concept_name), [strip_local(x) for x in self.keys()]
1076
- )
1077
- return matches
1078
-
1079
- def items(self) -> ItemsView[str, Concept]: # type: ignore
1080
- return super().items()
1081
-
1082
-
1083
- class RawColumnExpr(BaseModel):
1084
- text: str
1085
-
1086
-
1087
- class ColumnAssignment(BaseModel):
1088
- alias: str | RawColumnExpr | Function
1089
- concept: Concept
1090
- modifiers: List[Modifier] = Field(default_factory=list)
1091
-
1092
- @property
1093
- def is_complete(self) -> bool:
1094
- return Modifier.PARTIAL not in self.modifiers
1095
-
1096
- @property
1097
- def is_nullable(self) -> bool:
1098
- return Modifier.NULLABLE in self.modifiers
1099
-
1100
- def with_namespace(self, namespace: str) -> "ColumnAssignment":
1101
- return ColumnAssignment(
1102
- alias=(
1103
- self.alias.with_namespace(namespace)
1104
- if isinstance(self.alias, Function)
1105
- else self.alias
1106
- ),
1107
- concept=self.concept.with_namespace(namespace),
1108
- modifiers=self.modifiers,
1109
- )
1110
-
1111
- def with_merge(
1112
- self, source: Concept, target: Concept, modifiers: List[Modifier]
1113
- ) -> "ColumnAssignment":
1114
- return ColumnAssignment(
1115
- alias=self.alias,
1116
- concept=self.concept.with_merge(source, target, modifiers),
1117
- modifiers=(
1118
- modifiers if self.concept.address == source.address else self.modifiers
1119
- ),
1120
- )
1121
-
1122
-
1123
- class Statement(BaseModel):
1124
- pass
1125
-
1126
-
1127
- class LooseConceptList(BaseModel):
1128
- concepts: List[Concept]
1129
-
1130
- @cached_property
1131
- def addresses(self) -> set[str]:
1132
- return {s.address for s in self.concepts}
1133
-
1134
- @classmethod
1135
- def validate(cls, v):
1136
- return cls(v)
1137
-
1138
- def __str__(self) -> str:
1139
- return f"lcl{str(self.addresses)}"
1140
-
1141
- def __iter__(self):
1142
- return iter(self.concepts)
1143
-
1144
- def __eq__(self, other):
1145
- if not isinstance(other, LooseConceptList):
1146
- return False
1147
- return self.addresses == other.addresses
1148
-
1149
- def issubset(self, other):
1150
- if not isinstance(other, LooseConceptList):
1151
- return False
1152
- return self.addresses.issubset(other.addresses)
1153
-
1154
- def __contains__(self, other):
1155
- if isinstance(other, str):
1156
- return other in self.addresses
1157
- if not isinstance(other, Concept):
1158
- return False
1159
- return other.address in self.addresses
1160
-
1161
- def difference(self, other):
1162
- if not isinstance(other, LooseConceptList):
1163
- return False
1164
- return self.addresses.difference(other.addresses)
1165
-
1166
- def isdisjoint(self, other):
1167
- if not isinstance(other, LooseConceptList):
1168
- return False
1169
- return self.addresses.isdisjoint(other.addresses)
1170
-
1171
-
1172
- class Function(Mergeable, Namespaced, SelectContext, BaseModel):
1173
- operator: FunctionType
1174
- arg_count: int = Field(default=1)
1175
- output_datatype: DataType | ListType | StructType | MapType | NumericType
1176
- output_purpose: Purpose
1177
- valid_inputs: Optional[
1178
- Union[
1179
- Set[DataType | ListType | StructType | MapType | NumericType],
1180
- List[Set[DataType | ListType | StructType | MapType | NumericType]],
1181
- ]
1182
- ] = None
1183
- arguments: Sequence[
1184
- Union[
1185
- Concept,
1186
- "AggregateWrapper",
1187
- "Function",
1188
- int,
1189
- float,
1190
- str,
1191
- date,
1192
- datetime,
1193
- MapWrapper[Any, Any],
1194
- DataType,
1195
- ListType,
1196
- MapType,
1197
- NumericType,
1198
- DatePart,
1199
- "Parenthetical",
1200
- CaseWhen,
1201
- "CaseElse",
1202
- list,
1203
- ListWrapper[Any],
1204
- WindowItem,
1205
- ]
1206
- ]
1207
-
1208
- def __repr__(self):
1209
- return f'{self.operator.value}({",".join([str(a) for a in self.arguments])})'
1210
-
1211
- def __str__(self):
1212
- return self.__repr__()
1213
-
1214
- @property
1215
- def datatype(self):
1216
- return self.output_datatype
1217
-
1218
- def with_select_context(
1219
- self, local_concepts: dict[str, Concept], grain: Grain, environment: Environment
1220
- ) -> Function:
1221
- base = Function(
1222
- operator=self.operator,
1223
- arguments=[
1224
- (
1225
- c.with_select_context(local_concepts, grain, environment)
1226
- if isinstance(
1227
- c,
1228
- SelectContext,
1229
- )
1230
- else c
1231
- )
1232
- for c in self.arguments
1233
- ],
1234
- output_datatype=self.output_datatype,
1235
- output_purpose=self.output_purpose,
1236
- valid_inputs=self.valid_inputs,
1237
- arg_count=self.arg_count,
1238
- )
1239
- return base
1240
-
1241
- @field_validator("arguments")
1242
- @classmethod
1243
- def parse_arguments(cls, v, info: ValidationInfo):
1244
- from trilogy.parsing.exceptions import ParseError
1245
-
1246
- values = info.data
1247
- arg_count = len(v)
1248
- target_arg_count = values["arg_count"]
1249
- operator_name = values["operator"].name
1250
- # surface right error
1251
- if "valid_inputs" not in values:
1252
- return v
1253
- valid_inputs = values["valid_inputs"]
1254
- if not arg_count <= target_arg_count:
1255
- if target_arg_count != InfiniteFunctionArgs:
1256
- raise ParseError(
1257
- f"Incorrect argument count to {operator_name} function, expects"
1258
- f" {target_arg_count}, got {arg_count}"
1259
- )
1260
- # if all arguments can be any of the set type
1261
- # turn this into an array for validation
1262
- if isinstance(valid_inputs, set):
1263
- valid_inputs = [valid_inputs for _ in v]
1264
- elif not valid_inputs:
1265
- return v
1266
- for idx, arg in enumerate(v):
1267
- if (
1268
- isinstance(arg, Concept)
1269
- and arg.datatype.data_type not in valid_inputs[idx]
1270
- ):
1271
- if arg.datatype != DataType.UNKNOWN:
1272
- raise TypeError(
1273
- f"Invalid input datatype {arg.datatype.data_type} passed into position {idx}"
1274
- f" for {operator_name} from concept {arg.name}, valid is {valid_inputs[idx]}"
1275
- )
1276
- if (
1277
- isinstance(arg, Function)
1278
- and arg.output_datatype not in valid_inputs[idx]
1279
- ):
1280
- if arg.output_datatype != DataType.UNKNOWN:
1281
- raise TypeError(
1282
- f"Invalid input datatype {arg.output_datatype} passed into"
1283
- f" {operator_name} from function {arg.operator.name}"
1284
- )
1285
- # check constants
1286
- comparisons: List[Tuple[Type, DataType]] = [
1287
- (str, DataType.STRING),
1288
- (int, DataType.INTEGER),
1289
- (float, DataType.FLOAT),
1290
- (bool, DataType.BOOL),
1291
- (DatePart, DataType.DATE_PART),
1292
- ]
1293
- for ptype, dtype in comparisons:
1294
- if isinstance(arg, ptype) and dtype in valid_inputs[idx]:
1295
- # attempt to exit early to avoid checking all types
1296
- break
1297
- elif isinstance(arg, ptype):
1298
- raise TypeError(
1299
- f"Invalid {dtype} constant passed into {operator_name} {arg}, expecting one of {valid_inputs[idx]}"
1300
- )
1301
- return v
1302
-
1303
- def with_namespace(self, namespace: str) -> "Function":
1304
- return Function(
1305
- operator=self.operator,
1306
- arguments=[
1307
- (
1308
- c.with_namespace(namespace)
1309
- if isinstance(
1310
- c,
1311
- Namespaced,
1312
- )
1313
- else c
1314
- )
1315
- for c in self.arguments
1316
- ],
1317
- output_datatype=self.output_datatype,
1318
- output_purpose=self.output_purpose,
1319
- valid_inputs=self.valid_inputs,
1320
- arg_count=self.arg_count,
1321
- )
1322
-
1323
- def with_merge(
1324
- self, source: Concept, target: Concept, modifiers: List[Modifier]
1325
- ) -> "Function":
1326
- return Function(
1327
- operator=self.operator,
1328
- arguments=[
1329
- (
1330
- c.with_merge(source, target, modifiers)
1331
- if isinstance(
1332
- c,
1333
- Mergeable,
1334
- )
1335
- else c
1336
- )
1337
- for c in self.arguments
1338
- ],
1339
- output_datatype=self.output_datatype,
1340
- output_purpose=self.output_purpose,
1341
- valid_inputs=self.valid_inputs,
1342
- arg_count=self.arg_count,
1343
- )
1344
-
1345
- @property
1346
- def concept_arguments(self) -> List[Concept]:
1347
- base = []
1348
- for arg in self.arguments:
1349
- base += get_concept_arguments(arg)
1350
- return base
1351
-
1352
- @property
1353
- def output_grain(self):
1354
- # aggregates have an abstract grain
1355
- base_grain = Grain(components=[])
1356
- if self.operator in FunctionClass.AGGREGATE_FUNCTIONS.value:
1357
- return base_grain
1358
- # scalars have implicit grain of all arguments
1359
- for input in self.concept_arguments:
1360
- base_grain += input.grain
1361
- return base_grain
1362
-
1363
-
1364
- class ConceptTransform(Namespaced, BaseModel):
1365
- function: Function | FilterItem | WindowItem | AggregateWrapper
1366
- output: Concept
1367
- modifiers: List[Modifier] = Field(default_factory=list)
1368
-
1369
- @property
1370
- def input(self) -> List[Concept]:
1371
- return [v for v in self.function.arguments if isinstance(v, Concept)]
1372
-
1373
- def with_merge(self, source: Concept, target: Concept, modifiers: List[Modifier]):
1374
- return ConceptTransform(
1375
- function=self.function.with_merge(source, target, modifiers),
1376
- output=self.output.with_merge(source, target, modifiers),
1377
- modifiers=self.modifiers + modifiers,
1378
- )
1379
-
1380
- def with_namespace(self, namespace: str) -> "ConceptTransform":
1381
- return ConceptTransform(
1382
- function=self.function.with_namespace(namespace),
1383
- output=self.output.with_namespace(namespace),
1384
- modifiers=self.modifiers,
1385
- )
1386
-
1387
-
1388
- class Window(BaseModel):
1389
- count: int
1390
- window_order: WindowOrder
1391
-
1392
- def __str__(self):
1393
- return f"Window<{self.window_order}>"
1394
-
1395
-
1396
- class WindowItemOver(BaseModel):
1397
- contents: List[Concept]
1398
-
1399
-
1400
- class WindowItemOrder(BaseModel):
1401
- contents: List["OrderItem"]
1402
-
1403
-
1404
- class WindowItem(Mergeable, Namespaced, SelectContext, BaseModel):
1405
- type: WindowType
1406
- content: Concept
1407
- order_by: List["OrderItem"]
1408
- over: List["Concept"] = Field(default_factory=list)
1409
- index: Optional[int] = None
1410
-
1411
- def __repr__(self) -> str:
1412
- return f"{self.type}({self.content} {self.index}, {self.over}, {self.order_by})"
1413
-
1414
- def with_merge(
1415
- self, source: Concept, target: Concept, modifiers: List[Modifier]
1416
- ) -> "WindowItem":
1417
- return WindowItem(
1418
- type=self.type,
1419
- content=self.content.with_merge(source, target, modifiers),
1420
- over=[x.with_merge(source, target, modifiers) for x in self.over],
1421
- order_by=[x.with_merge(source, target, modifiers) for x in self.order_by],
1422
- index=self.index,
1423
- )
1424
-
1425
- def with_namespace(self, namespace: str) -> "WindowItem":
1426
- return WindowItem(
1427
- type=self.type,
1428
- content=self.content.with_namespace(namespace),
1429
- over=[x.with_namespace(namespace) for x in self.over],
1430
- order_by=[x.with_namespace(namespace) for x in self.order_by],
1431
- index=self.index,
1432
- )
1433
-
1434
- def with_select_context(
1435
- self, local_concepts: dict[str, Concept], grain: Grain, environment: Environment
1436
- ) -> "WindowItem":
1437
- return WindowItem(
1438
- type=self.type,
1439
- content=self.content.with_select_context(
1440
- local_concepts, grain, environment
1441
- ),
1442
- over=[
1443
- x.with_select_context(local_concepts, grain, environment)
1444
- for x in self.over
1445
- ],
1446
- order_by=[
1447
- x.with_select_context(local_concepts, grain, environment)
1448
- for x in self.order_by
1449
- ],
1450
- index=self.index,
1451
- )
1452
-
1453
- @property
1454
- def concept_arguments(self) -> List[Concept]:
1455
- return self.arguments
1456
-
1457
- @property
1458
- def arguments(self) -> List[Concept]:
1459
- output = [self.content]
1460
- for order in self.order_by:
1461
- output += [order.output]
1462
- for item in self.over:
1463
- output += [item]
1464
- return output
1465
-
1466
- @property
1467
- def output(self) -> Concept:
1468
- if isinstance(self.content, ConceptTransform):
1469
- return self.content.output
1470
- return self.content
1471
-
1472
- @output.setter
1473
- def output(self, value):
1474
- if isinstance(self.content, ConceptTransform):
1475
- self.content.output = value
1476
- else:
1477
- self.content = value
1478
-
1479
- @property
1480
- def input(self) -> List[Concept]:
1481
- base = self.content.input
1482
- for v in self.order_by:
1483
- base += v.input
1484
- for c in self.over:
1485
- base += c.input
1486
- return base
1487
-
1488
- @property
1489
- def output_datatype(self):
1490
- return self.content.datatype
1491
-
1492
- @property
1493
- def output_purpose(self):
1494
- return Purpose.PROPERTY
1495
-
1496
-
1497
- class FilterItem(Namespaced, SelectContext, BaseModel):
1498
- content: Concept
1499
- where: "WhereClause"
1500
-
1501
- def __str__(self):
1502
- return f"<Filter: {str(self.content)} where {str(self.where)}>"
1503
-
1504
- def with_merge(
1505
- self, source: Concept, target: Concept, modifiers: List[Modifier]
1506
- ) -> "FilterItem":
1507
- return FilterItem(
1508
- content=source.with_merge(source, target, modifiers),
1509
- where=self.where.with_merge(source, target, modifiers),
1510
- )
1511
-
1512
- def with_namespace(self, namespace: str) -> "FilterItem":
1513
- return FilterItem(
1514
- content=self.content.with_namespace(namespace),
1515
- where=self.where.with_namespace(namespace),
1516
- )
1517
-
1518
- def with_select_context(
1519
- self, local_concepts: dict[str, Concept], grain: Grain, environment: Environment
1520
- ) -> FilterItem:
1521
- return FilterItem(
1522
- content=self.content.with_select_context(
1523
- local_concepts, grain, environment
1524
- ),
1525
- where=self.where.with_select_context(local_concepts, grain, environment),
1526
- )
1527
-
1528
- @property
1529
- def arguments(self) -> List[Concept]:
1530
- output = [self.content]
1531
- output += self.where.input
1532
- return output
1533
-
1534
- @property
1535
- def output(self) -> Concept:
1536
- if isinstance(self.content, ConceptTransform):
1537
- return self.content.output
1538
- return self.content
1539
-
1540
- @output.setter
1541
- def output(self, value):
1542
- if isinstance(self.content, ConceptTransform):
1543
- self.content.output = value
1544
- else:
1545
- self.content = value
1546
-
1547
- @property
1548
- def input(self) -> List[Concept]:
1549
- base = self.content.input
1550
- base += self.where.input
1551
- return base
1552
-
1553
- @property
1554
- def output_datatype(self):
1555
- return self.content.datatype
1556
-
1557
- @property
1558
- def output_purpose(self):
1559
- return self.content.purpose
1560
-
1561
- @property
1562
- def concept_arguments(self):
1563
- return [self.content] + self.where.concept_arguments
1564
-
1565
-
1566
- class SelectItem(Mergeable, Namespaced, BaseModel):
1567
- content: Union[Concept, ConceptTransform]
1568
- modifiers: List[Modifier] = Field(default_factory=list)
1569
-
1570
- @property
1571
- def output(self) -> Concept:
1572
- if isinstance(self.content, ConceptTransform):
1573
- return self.content.output
1574
- elif isinstance(self.content, WindowItem):
1575
- return self.content.output
1576
- return self.content
1577
-
1578
- @property
1579
- def input(self) -> List[Concept]:
1580
- return self.content.input
1581
-
1582
- def with_merge(
1583
- self, source: Concept, target: Concept, modifiers: List[Modifier]
1584
- ) -> "SelectItem":
1585
- return SelectItem(
1586
- content=self.content.with_merge(source, target, modifiers),
1587
- modifiers=modifiers,
1588
- )
1589
-
1590
- def with_namespace(self, namespace: str) -> "SelectItem":
1591
- return SelectItem(
1592
- content=self.content.with_namespace(namespace),
1593
- modifiers=self.modifiers,
1594
- )
1595
-
1596
-
1597
- class OrderItem(Mergeable, SelectContext, Namespaced, BaseModel):
1598
- expr: Concept
1599
- order: Ordering
1600
-
1601
- def with_namespace(self, namespace: str) -> "OrderItem":
1602
- return OrderItem(expr=self.expr.with_namespace(namespace), order=self.order)
1603
-
1604
- def with_select_context(
1605
- self, local_concepts: dict[str, Concept], grain: Grain, environment: Environment
1606
- ) -> "OrderItem":
1607
- return OrderItem(
1608
- expr=self.expr.with_select_context(
1609
- local_concepts, grain, environment=environment
1610
- ),
1611
- order=self.order,
1612
- )
1613
-
1614
- def with_merge(
1615
- self, source: Concept, target: Concept, modifiers: List[Modifier]
1616
- ) -> "OrderItem":
1617
- return OrderItem(
1618
- expr=source.with_merge(source, target, modifiers), order=self.order
1619
- )
1620
-
1621
- @property
1622
- def input(self):
1623
- return self.expr.input
1624
-
1625
- @property
1626
- def output(self):
1627
- return self.expr.output
1628
-
1629
-
1630
- class OrderBy(SelectContext, Mergeable, Namespaced, BaseModel):
1631
- items: List[OrderItem]
1632
-
1633
- def with_namespace(self, namespace: str) -> "OrderBy":
1634
- return OrderBy(items=[x.with_namespace(namespace) for x in self.items])
1635
-
1636
- def with_merge(
1637
- self, source: Concept, target: Concept, modifiers: List[Modifier]
1638
- ) -> "OrderBy":
1639
- return OrderBy(
1640
- items=[x.with_merge(source, target, modifiers) for x in self.items]
1641
- )
1642
-
1643
- def with_select_context(self, local_concepts, grain, environment):
1644
- return OrderBy(
1645
- items=[
1646
- x.with_select_context(local_concepts, grain, environment)
1647
- for x in self.items
1648
- ]
1649
- )
1650
-
1651
- @property
1652
- def concept_arguments(self):
1653
- return [x.expr for x in self.items]
1654
-
1655
-
1656
- class RawSQLStatement(BaseModel):
1657
- text: str
1658
- meta: Optional[Metadata] = Field(default_factory=lambda: Metadata())
1659
-
1660
-
1661
- class SelectStatement(HasUUID, Mergeable, Namespaced, SelectTypeMixin, BaseModel):
1662
- selection: List[SelectItem]
1663
- order_by: Optional[OrderBy] = None
1664
- limit: Optional[int] = None
1665
- meta: Metadata = Field(default_factory=lambda: Metadata())
1666
- local_concepts: Annotated[
1667
- EnvironmentConceptDict, PlainValidator(validate_concepts)
1668
- ] = Field(default_factory=EnvironmentConceptDict)
1669
- grain: Grain = Field(default_factory=Grain)
1670
-
1671
- @classmethod
1672
- def from_inputs(
1673
- cls,
1674
- environment: Environment,
1675
- selection: List[SelectItem],
1676
- order_by: OrderBy | None = None,
1677
- limit: int | None = None,
1678
- meta: Metadata | None = None,
1679
- where_clause: WhereClause | None = None,
1680
- having_clause: HavingClause | None = None,
1681
- ) -> "SelectStatement":
1682
-
1683
- output = SelectStatement(
1684
- selection=selection,
1685
- where_clause=where_clause,
1686
- having_clause=having_clause,
1687
- limit=limit,
1688
- order_by=order_by,
1689
- meta=meta or Metadata(),
1690
- )
1691
- for parse_pass in [
1692
- 1,
1693
- 2,
1694
- ]:
1695
- # the first pass will result in all concepts being defined
1696
- # the second will get grains appropriately
1697
- # eg if someone does sum(x)->a, b+c -> z - we don't know if Z is a key to group by or an aggregate
1698
- # until after the first pass, and so don't know the grain of a
1699
-
1700
- if parse_pass == 1:
1701
- grain = Grain.from_concepts(
1702
- [
1703
- x.content
1704
- for x in output.selection
1705
- if isinstance(x.content, Concept)
1706
- ],
1707
- where_clause=output.where_clause,
1708
- )
1709
- if parse_pass == 2:
1710
- grain = Grain.from_concepts(
1711
- output.output_components, where_clause=output.where_clause
1712
- )
1713
- output.grain = grain
1714
- pass_grain = Grain() if parse_pass == 1 else grain
1715
- for item in selection:
1716
- # we don't know the grain of an aggregate at assignment time
1717
- # so rebuild at this point in the tree
1718
- # TODO: simplify
1719
- if isinstance(item.content, ConceptTransform):
1720
- new_concept = item.content.output.with_select_context(
1721
- output.local_concepts,
1722
- # the first pass grain will be incorrect
1723
- pass_grain,
1724
- environment=environment,
1725
- )
1726
- output.local_concepts[new_concept.address] = new_concept
1727
- item.content.output = new_concept
1728
- if parse_pass == 2 and CONFIG.select_as_definition:
1729
- environment.add_concept(new_concept)
1730
- elif isinstance(item.content, UndefinedConcept):
1731
- environment.concepts.raise_undefined(
1732
- item.content.address,
1733
- line_no=item.content.metadata.line_number,
1734
- file=environment.env_file_path,
1735
- )
1736
- elif isinstance(item.content, Concept):
1737
- # Sometimes cached values here don't have the latest info
1738
- # but we can't just use environment, as it might not have the right grain.
1739
- item.content = item.content.with_select_context(
1740
- output.local_concepts,
1741
- pass_grain,
1742
- environment=environment,
1743
- )
1744
- output.local_concepts[item.content.address] = item.content
1745
-
1746
- if order_by:
1747
- output.order_by = order_by.with_select_context(
1748
- local_concepts=output.local_concepts,
1749
- grain=output.grain,
1750
- environment=environment,
1751
- )
1752
- if output.having_clause:
1753
- output.having_clause = output.having_clause.with_select_context(
1754
- local_concepts=output.local_concepts,
1755
- grain=output.grain,
1756
- environment=environment,
1757
- )
1758
- output.validate_syntax(environment)
1759
- return output
1760
-
1761
- def validate_syntax(self, environment: Environment):
1762
- if self.where_clause:
1763
- for x in self.where_clause.concept_arguments:
1764
- if isinstance(x, UndefinedConcept):
1765
- environment.concepts.raise_undefined(
1766
- x.address, x.metadata.line_number
1767
- )
1768
- all_in_output = [x.address for x in self.output_components]
1769
- if self.where_clause:
1770
- for concept in self.where_clause.concept_arguments:
1771
- if (
1772
- concept.lineage
1773
- and isinstance(concept.lineage, Function)
1774
- and concept.lineage.operator
1775
- in FunctionClass.AGGREGATE_FUNCTIONS.value
1776
- ):
1777
- if concept.address in self.locally_derived:
1778
- raise SyntaxError(
1779
- f"Cannot reference an aggregate derived in the select ({concept.address}) in the same statement where clause; move to the HAVING clause instead; Line: {self.meta.line_number}"
1780
- )
1781
-
1782
- if (
1783
- concept.lineage
1784
- and isinstance(concept.lineage, AggregateWrapper)
1785
- and concept.lineage.function.operator
1786
- in FunctionClass.AGGREGATE_FUNCTIONS.value
1787
- ):
1788
- if concept.address in self.locally_derived:
1789
- raise SyntaxError(
1790
- f"Cannot reference an aggregate derived in the select ({concept.address}) in the same statement where clause; move to the HAVING clause instead; Line: {self.meta.line_number}"
1791
- )
1792
- if self.having_clause:
1793
- self.having_clause.hydrate_missing(self.local_concepts)
1794
- for concept in self.having_clause.concept_arguments:
1795
- if concept.address not in [x.address for x in self.output_components]:
1796
- raise SyntaxError(
1797
- f"Cannot reference a column ({concept.address}) that is not in the select projection in the HAVING clause, move to WHERE; Line: {self.meta.line_number}"
1798
- )
1799
- if self.order_by:
1800
- for concept in self.order_by.concept_arguments:
1801
- if concept.address not in all_in_output:
1802
- raise SyntaxError(
1803
- f"Cannot order by a column that is not in the output projection; {self.meta.line_number}"
1804
- )
1805
-
1806
- def __str__(self):
1807
- from trilogy.parsing.render import render_query
1808
-
1809
- return render_query(self)
1810
-
1811
- @field_validator("selection", mode="before")
1812
- @classmethod
1813
- def selection_validation(cls, v):
1814
- new = []
1815
- for item in v:
1816
- if isinstance(item, (Concept, ConceptTransform)):
1817
- new.append(SelectItem(content=item))
1818
- else:
1819
- new.append(item)
1820
- return new
1821
-
1822
- def with_merge(
1823
- self, source: Concept, target: Concept, modifiers: List[Modifier]
1824
- ) -> "SelectStatement":
1825
- return SelectStatement(
1826
- selection=[x.with_merge(source, target, modifiers) for x in self.selection],
1827
- order_by=(
1828
- self.order_by.with_merge(source, target, modifiers)
1829
- if self.order_by
1830
- else None
1831
- ),
1832
- limit=self.limit,
1833
- )
1834
-
1835
- @property
1836
- def locally_derived(self) -> set[str]:
1837
- locally_derived: set[str] = set()
1838
- for item in self.selection:
1839
- if isinstance(item.content, ConceptTransform):
1840
- locally_derived.add(item.content.output.address)
1841
- return locally_derived
1842
-
1843
- @property
1844
- def input_components(self) -> List[Concept]:
1845
- output = set()
1846
- output_list = []
1847
- for item in self.selection:
1848
- for concept in item.input:
1849
- if concept.name in output:
1850
- continue
1851
- output.add(concept.name)
1852
- output_list.append(concept)
1853
- if self.where_clause:
1854
- for concept in self.where_clause.input:
1855
- if concept.name in output:
1856
- continue
1857
- output.add(concept.name)
1858
- output_list.append(concept)
1859
-
1860
- return output_list
1861
-
1862
- @property
1863
- def output_components(self) -> List[Concept]:
1864
- output = []
1865
- for item in self.selection:
1866
- if isinstance(item, Concept):
1867
- output.append(item)
1868
- else:
1869
- output.append(item.output)
1870
- return output
1871
-
1872
- @property
1873
- def hidden_components(self) -> set[str]:
1874
- output = set()
1875
- for item in self.selection:
1876
- if isinstance(item, SelectItem) and Modifier.HIDDEN in item.modifiers:
1877
- output.add(item.output.address)
1878
- return output
1879
-
1880
- @property
1881
- def all_components(self) -> List[Concept]:
1882
- return self.input_components + self.output_components
1883
-
1884
- def to_datasource(
1885
- self,
1886
- namespace: str,
1887
- name: str,
1888
- address: Address,
1889
- grain: Grain | None = None,
1890
- ) -> Datasource:
1891
- if self.where_clause or self.having_clause:
1892
- modifiers = [Modifier.PARTIAL]
1893
- else:
1894
- modifiers = []
1895
- columns = [
1896
- # TODO: replace hardcoded replacement here
1897
- # if the concept is a locally derived concept, it cannot ever be partial
1898
- # but if it's a concept pulled in from upstream and we have a where clause, it should be partial
1899
- ColumnAssignment(
1900
- alias=(
1901
- c.name.replace(".", "_")
1902
- if c.namespace == DEFAULT_NAMESPACE
1903
- else c.address.replace(".", "_")
1904
- ),
1905
- concept=c,
1906
- modifiers=modifiers if c.address not in self.locally_derived else [],
1907
- )
1908
- for c in self.output_components
1909
- ]
1910
-
1911
- condition = None
1912
- if self.where_clause:
1913
- condition = self.where_clause.conditional
1914
- if self.having_clause:
1915
- if condition:
1916
- condition = self.having_clause.conditional + condition
1917
- else:
1918
- condition = self.having_clause.conditional
1919
-
1920
- new_datasource = Datasource(
1921
- name=name,
1922
- address=address,
1923
- grain=grain or self.grain,
1924
- columns=columns,
1925
- namespace=namespace,
1926
- non_partial_for=WhereClause(conditional=condition) if condition else None,
1927
- )
1928
- for column in columns:
1929
- column.concept = column.concept.with_grain(new_datasource.grain)
1930
- return new_datasource
1931
-
1932
- def with_namespace(self, namespace: str) -> "SelectStatement":
1933
- return SelectStatement(
1934
- selection=[c.with_namespace(namespace) for c in self.selection],
1935
- where_clause=(
1936
- self.where_clause.with_namespace(namespace)
1937
- if self.where_clause
1938
- else None
1939
- ),
1940
- order_by=self.order_by.with_namespace(namespace) if self.order_by else None,
1941
- limit=self.limit,
1942
- )
1943
-
1944
-
1945
- class CopyStatement(BaseModel):
1946
- target: str
1947
- target_type: IOType
1948
- meta: Optional[Metadata] = Field(default_factory=lambda: Metadata())
1949
- select: SelectStatement
1950
-
1951
-
1952
- class AlignItem(Namespaced, BaseModel):
1953
- alias: str
1954
- concepts: List[Concept]
1955
- namespace: Optional[str] = Field(default=DEFAULT_NAMESPACE, validate_default=True)
1956
-
1957
- @computed_field # type: ignore
1958
- @cached_property
1959
- def concepts_lcl(self) -> LooseConceptList:
1960
- return LooseConceptList(concepts=self.concepts)
1961
-
1962
- def with_namespace(self, namespace: str) -> "AlignItem":
1963
- return AlignItem(
1964
- alias=self.alias,
1965
- concepts=[c.with_namespace(namespace) for c in self.concepts],
1966
- namespace=namespace,
1967
- )
1968
-
1969
- def gen_concept(self, parent: MultiSelectStatement):
1970
- datatypes = set([c.datatype for c in self.concepts])
1971
- purposes = set([c.purpose for c in self.concepts])
1972
- if len(datatypes) > 1:
1973
- raise InvalidSyntaxException(
1974
- f"Datatypes do not align for merged statements {self.alias}, have {datatypes}"
1975
- )
1976
- if len(purposes) > 1:
1977
- purpose = Purpose.KEY
1978
- else:
1979
- purpose = list(purposes)[0]
1980
- new = Concept(
1981
- name=self.alias,
1982
- datatype=datatypes.pop(),
1983
- purpose=purpose,
1984
- lineage=parent,
1985
- namespace=parent.namespace,
1986
- )
1987
- return new
1988
-
1989
-
1990
- class AlignClause(Namespaced, BaseModel):
1991
- items: List[AlignItem]
1992
-
1993
- def with_namespace(self, namespace: str) -> "AlignClause":
1994
- return AlignClause(items=[x.with_namespace(namespace) for x in self.items])
1995
-
1996
-
1997
- class MultiSelectStatement(HasUUID, SelectTypeMixin, Mergeable, Namespaced, BaseModel):
1998
- selects: List[SelectStatement]
1999
- align: AlignClause
2000
- namespace: str
2001
- order_by: Optional[OrderBy] = None
2002
- limit: Optional[int] = None
2003
- meta: Optional[Metadata] = Field(default_factory=lambda: Metadata())
2004
- local_concepts: Annotated[
2005
- EnvironmentConceptDict, PlainValidator(validate_concepts)
2006
- ] = Field(default_factory=EnvironmentConceptDict)
2007
-
2008
- def __repr__(self):
2009
- return "MultiSelect<" + " MERGE ".join([str(s) for s in self.selects]) + ">"
2010
-
2011
- @property
2012
- def arguments(self) -> List[Concept]:
2013
- output = []
2014
- for select in self.selects:
2015
- output += select.input_components
2016
- return unique(output, "address")
2017
-
2018
- @property
2019
- def concept_arguments(self) -> List[Concept]:
2020
- output = []
2021
- for select in self.selects:
2022
- output += select.input_components
2023
- if self.where_clause:
2024
- output += self.where_clause.concept_arguments
2025
- return unique(output, "address")
2026
-
2027
- def with_merge(
2028
- self, source: Concept, target: Concept, modifiers: List[Modifier]
2029
- ) -> "MultiSelectStatement":
2030
- new = MultiSelectStatement(
2031
- selects=[s.with_merge(source, target, modifiers) for s in self.selects],
2032
- align=self.align,
2033
- namespace=self.namespace,
2034
- order_by=(
2035
- self.order_by.with_merge(source, target, modifiers)
2036
- if self.order_by
2037
- else None
2038
- ),
2039
- limit=self.limit,
2040
- meta=self.meta,
2041
- where_clause=(
2042
- self.where_clause.with_merge(source, target, modifiers)
2043
- if self.where_clause
2044
- else None
2045
- ),
2046
- )
2047
- return new
2048
-
2049
- def get_merge_concept(self, check: Concept):
2050
- for item in self.align.items:
2051
- if check in item.concepts_lcl:
2052
- return item.gen_concept(self)
2053
- return None
2054
-
2055
- def with_namespace(self, namespace: str) -> "MultiSelectStatement":
2056
- return MultiSelectStatement(
2057
- selects=[c.with_namespace(namespace) for c in self.selects],
2058
- align=self.align.with_namespace(namespace),
2059
- namespace=namespace,
2060
- order_by=self.order_by.with_namespace(namespace) if self.order_by else None,
2061
- limit=self.limit,
2062
- meta=self.meta,
2063
- where_clause=(
2064
- self.where_clause.with_namespace(namespace)
2065
- if self.where_clause
2066
- else None
2067
- ),
2068
- local_concepts=EnvironmentConceptDict(
2069
- {k: v.with_namespace(namespace) for k, v in self.local_concepts.items()}
2070
- ),
2071
- )
2072
-
2073
- @property
2074
- def grain(self):
2075
- base = Grain()
2076
- for select in self.selects:
2077
- base += select.grain
2078
- return base
2079
-
2080
- @computed_field # type: ignore
2081
- @cached_property
2082
- def derived_concepts(self) -> List[Concept]:
2083
- output = []
2084
- for item in self.align.items:
2085
- output.append(item.gen_concept(self))
2086
- return output
2087
-
2088
- def find_source(self, concept: Concept, cte: CTE | UnionCTE) -> Concept:
2089
- for x in self.align.items:
2090
- if concept.name == x.alias:
2091
- for c in x.concepts:
2092
- if c.address in cte.output_lcl:
2093
- return c
2094
- raise SyntaxError(
2095
- f"Could not find upstream map for multiselect {str(concept)} on cte ({cte})"
2096
- )
2097
-
2098
- @property
2099
- def output_components(self) -> List[Concept]:
2100
- output = self.derived_concepts
2101
- for select in self.selects:
2102
- output += select.output_components
2103
- return unique(output, "address")
2104
-
2105
- @computed_field # type: ignore
2106
- @cached_property
2107
- def hidden_components(self) -> set[str]:
2108
- output: set[str] = set()
2109
- for select in self.selects:
2110
- output = output.union(select.hidden_components)
2111
- return output
2112
-
2113
-
2114
- class Address(BaseModel):
2115
- location: str
2116
- is_query: bool = False
2117
- quoted: bool = False
2118
-
2119
-
2120
- class Query(BaseModel):
2121
- text: str
2122
-
2123
-
2124
- def safe_concept(v: Union[Dict, Concept]) -> Concept:
2125
- if isinstance(v, dict):
2126
- return Concept.model_validate(v)
2127
- return v
2128
-
2129
-
2130
- class GrainWindow(BaseModel):
2131
- window: Window
2132
- sort_concepts: List[Concept]
2133
-
2134
- def __str__(self):
2135
- return (
2136
- "GrainWindow<"
2137
- + ",".join([c.address for c in self.sort_concepts])
2138
- + f":{str(self.window)}>"
2139
- )
2140
-
2141
-
2142
- def safe_grain(v) -> Grain:
2143
- if isinstance(v, dict):
2144
- return Grain.model_validate(v)
2145
- elif isinstance(v, Grain):
2146
- return v
2147
- elif not v:
2148
- return Grain(components=set())
2149
- else:
2150
- raise ValueError(f"Invalid input type to safe_grain {type(v)}")
2151
-
2152
-
2153
- class DatasourceMetadata(BaseModel):
2154
- freshness_concept: Concept | None
2155
- partition_fields: List[Concept] = Field(default_factory=list)
2156
- line_no: int | None = None
2157
-
2158
-
2159
- class MergeStatementV2(HasUUID, Namespaced, BaseModel):
2160
- sources: list[Concept]
2161
- targets: dict[str, Concept]
2162
- source_wildcard: str | None = None
2163
- target_wildcard: str | None = None
2164
- modifiers: List[Modifier] = Field(default_factory=list)
2165
-
2166
- def with_namespace(self, namespace: str) -> "MergeStatementV2":
2167
- new = MergeStatementV2(
2168
- sources=[x.with_namespace(namespace) for x in self.sources],
2169
- targets={k: v.with_namespace(namespace) for k, v in self.targets.items()},
2170
- modifiers=self.modifiers,
2171
- )
2172
- return new
2173
-
2174
-
2175
- class Datasource(HasUUID, Namespaced, BaseModel):
2176
- name: str
2177
- columns: List[ColumnAssignment]
2178
- address: Union[Address, str]
2179
- grain: Grain = Field(
2180
- default_factory=lambda: Grain(components=set()), validate_default=True
2181
- )
2182
- namespace: Optional[str] = Field(default=DEFAULT_NAMESPACE, validate_default=True)
2183
- metadata: DatasourceMetadata = Field(
2184
- default_factory=lambda: DatasourceMetadata(freshness_concept=None)
2185
- )
2186
- where: Optional[WhereClause] = None
2187
- non_partial_for: Optional[WhereClause] = None
2188
-
2189
- def duplicate(self) -> Datasource:
2190
- return self.model_copy(deep=True)
2191
-
2192
- @property
2193
- def hidden_concepts(self) -> List[Concept]:
2194
- return []
2195
-
2196
- def merge_concept(
2197
- self, source: Concept, target: Concept, modifiers: List[Modifier]
2198
- ):
2199
- original = [c for c in self.columns if c.concept.address == source.address]
2200
- early_exit_check = [
2201
- c for c in self.columns if c.concept.address == target.address
2202
- ]
2203
- if early_exit_check:
2204
- return None
2205
- if len(original) != 1:
2206
- raise ValueError(
2207
- f"Expected exactly one column to merge, got {len(original)} for {source.address}, {[x.alias for x in original]}"
2208
- )
2209
- # map to the alias with the modifier, and the original
2210
- self.columns = [
2211
- c.with_merge(source, target, modifiers)
2212
- for c in self.columns
2213
- if c.concept.address != source.address
2214
- ] + original
2215
- self.grain = self.grain.with_merge(source, target, modifiers)
2216
- self.where = (
2217
- self.where.with_merge(source, target, modifiers) if self.where else None
2218
- )
2219
-
2220
- self.add_column(target, original[0].alias, modifiers)
2221
-
2222
- @property
2223
- def identifier(self) -> str:
2224
- if not self.namespace or self.namespace == DEFAULT_NAMESPACE:
2225
- return self.name
2226
- return f"{self.namespace}.{self.name}"
2227
-
2228
- @property
2229
- def safe_identifier(self) -> str:
2230
- return self.identifier.replace(".", "_")
2231
-
2232
- @property
2233
- def condition(self):
2234
- return None
2235
-
2236
- @property
2237
- def output_lcl(self) -> LooseConceptList:
2238
- return LooseConceptList(concepts=self.output_concepts)
2239
-
2240
- @property
2241
- def can_be_inlined(self) -> bool:
2242
- if isinstance(self.address, Address) and self.address.is_query:
2243
- return False
2244
- # for x in self.columns:
2245
- # if not isinstance(x.alias, str):
2246
- # return False
2247
- return True
2248
-
2249
- @property
2250
- def non_partial_concept_addresses(self) -> set[str]:
2251
- return set([c.address for c in self.full_concepts])
2252
-
2253
- @field_validator("namespace", mode="plain")
2254
- @classmethod
2255
- def namespace_validation(cls, v):
2256
- return v or DEFAULT_NAMESPACE
2257
-
2258
- @field_validator("address")
2259
- @classmethod
2260
- def address_enforcement(cls, v):
2261
- if isinstance(v, str):
2262
- v = Address(location=v)
2263
- return v
2264
-
2265
- @field_validator("grain", mode="before")
2266
- @classmethod
2267
- def grain_enforcement(cls, v: Grain, info: ValidationInfo):
2268
- grain: Grain = safe_grain(v)
2269
- return grain
2270
-
2271
- def add_column(
2272
- self,
2273
- concept: Concept,
2274
- alias: str | RawColumnExpr | Function,
2275
- modifiers: List[Modifier] | None = None,
2276
- ):
2277
- self.columns.append(
2278
- ColumnAssignment(alias=alias, concept=concept, modifiers=modifiers or [])
2279
- )
2280
-
2281
- def __add__(self, other):
2282
- if not other == self:
2283
- raise ValueError(
2284
- "Attempted to add two datasources that are not identical, this is not a valid operation"
2285
- )
2286
- return self
2287
-
2288
- def __repr__(self):
2289
- return f"Datasource<{self.identifier}@<{self.grain}>"
2290
-
2291
- def __str__(self):
2292
- return self.__repr__()
2293
-
2294
- def __hash__(self):
2295
- return self.identifier.__hash__()
2296
-
2297
- def with_namespace(self, namespace: str):
2298
- new_namespace = (
2299
- namespace + "." + self.namespace
2300
- if self.namespace and self.namespace != DEFAULT_NAMESPACE
2301
- else namespace
2302
- )
2303
- return Datasource(
2304
- name=self.name,
2305
- namespace=new_namespace,
2306
- grain=self.grain.with_namespace(namespace),
2307
- address=self.address,
2308
- columns=[c.with_namespace(namespace) for c in self.columns],
2309
- where=self.where.with_namespace(namespace) if self.where else None,
2310
- )
2311
-
2312
- @property
2313
- def concepts(self) -> List[Concept]:
2314
- return [c.concept for c in self.columns]
2315
-
2316
- @property
2317
- def group_required(self):
2318
- return False
2319
-
2320
- @property
2321
- def full_concepts(self) -> List[Concept]:
2322
- return [c.concept for c in self.columns if Modifier.PARTIAL not in c.modifiers]
2323
-
2324
- @property
2325
- def nullable_concepts(self) -> List[Concept]:
2326
- return [c.concept for c in self.columns if Modifier.NULLABLE in c.modifiers]
2327
-
2328
- @property
2329
- def output_concepts(self) -> List[Concept]:
2330
- return self.concepts
2331
-
2332
- @property
2333
- def partial_concepts(self) -> List[Concept]:
2334
- return [c.concept for c in self.columns if Modifier.PARTIAL in c.modifiers]
2335
-
2336
- def get_alias(
2337
- self, concept: Concept, use_raw_name: bool = True, force_alias: bool = False
2338
- ) -> Optional[str | RawColumnExpr] | Function:
2339
- # 2022-01-22
2340
- # this logic needs to be refined.
2341
- # if concept.lineage:
2342
- # # return None
2343
- for x in self.columns:
2344
- if x.concept == concept or x.concept.with_grain(concept.grain) == concept:
2345
- if use_raw_name:
2346
- return x.alias
2347
- return concept.safe_address
2348
- existing = [str(c.concept.with_grain(self.grain)) for c in self.columns]
2349
- raise ValueError(
2350
- f"{LOGGER_PREFIX} Concept {concept} not found on {self.identifier}; have"
2351
- f" {existing}."
2352
- )
2353
-
2354
- @property
2355
- def safe_location(self) -> str:
2356
- if isinstance(self.address, Address):
2357
- return self.address.location
2358
- return self.address
2359
-
2360
-
2361
- class UnnestJoin(BaseModel):
2362
- concepts: list[Concept]
2363
- parent: Function
2364
- alias: str = "unnest"
2365
- rendering_required: bool = True
2366
-
2367
- def __hash__(self):
2368
- return self.safe_identifier.__hash__()
2369
-
2370
- @property
2371
- def safe_identifier(self) -> str:
2372
- return self.alias + "".join([str(s.address) for s in self.concepts])
2373
-
2374
-
2375
- class InstantiatedUnnestJoin(BaseModel):
2376
- concept_to_unnest: Concept
2377
- alias: str = "unnest"
2378
-
2379
-
2380
- class ConceptPair(BaseModel):
2381
- left: Concept
2382
- right: Concept
2383
- existing_datasource: Union[Datasource, "QueryDatasource"]
2384
- modifiers: List[Modifier] = Field(default_factory=list)
2385
-
2386
- @property
2387
- def is_partial(self):
2388
- return Modifier.PARTIAL in self.modifiers
2389
-
2390
- @property
2391
- def is_nullable(self):
2392
- return Modifier.NULLABLE in self.modifiers
2393
-
2394
-
2395
- class CTEConceptPair(ConceptPair):
2396
- cte: CTE
2397
-
2398
-
2399
- class BaseJoin(BaseModel):
2400
- right_datasource: Union[Datasource, "QueryDatasource"]
2401
- join_type: JoinType
2402
- concepts: Optional[List[Concept]] = None
2403
- left_datasource: Optional[Union[Datasource, "QueryDatasource"]] = None
2404
- concept_pairs: list[ConceptPair] | None = None
2405
-
2406
- def __init__(self, **data: Any):
2407
- super().__init__(**data)
2408
- if (
2409
- self.left_datasource
2410
- and self.left_datasource.identifier == self.right_datasource.identifier
2411
- ):
2412
- raise SyntaxError(
2413
- f"Cannot join a dataself to itself, joining {self.left_datasource} and"
2414
- f" {self.right_datasource}"
2415
- )
2416
- final_concepts = []
2417
-
2418
- # if we have a list of concept pairs
2419
- if self.concept_pairs:
2420
- return
2421
- if self.concepts == []:
2422
- return
2423
- assert self.left_datasource and self.right_datasource
2424
- for concept in self.concepts or []:
2425
- include = True
2426
- for ds in [self.left_datasource, self.right_datasource]:
2427
- synonyms = []
2428
- for c in ds.output_concepts:
2429
- synonyms += list(c.pseudonyms)
2430
- if (
2431
- concept.address not in [c.address for c in ds.output_concepts]
2432
- and concept.address not in synonyms
2433
- ):
2434
- raise SyntaxError(
2435
- f"Invalid join, missing {concept} on {ds.name}, have"
2436
- f" {[c.address for c in ds.output_concepts]}"
2437
- )
2438
- if include:
2439
- final_concepts.append(concept)
2440
- if not final_concepts and self.concepts:
2441
- # if one datasource only has constants
2442
- # we can join on 1=1
2443
- for ds in [self.left_datasource, self.right_datasource]:
2444
- # single rows
2445
- if all(
2446
- [
2447
- c.granularity == Granularity.SINGLE_ROW
2448
- for c in ds.output_concepts
2449
- ]
2450
- ):
2451
- self.concepts = []
2452
- return
2453
- # if everything is at abstract grain, we can skip joins
2454
- if all([c.grain.abstract for c in ds.output_concepts]):
2455
- self.concepts = []
2456
- return
2457
-
2458
- left_keys = [c.address for c in self.left_datasource.output_concepts]
2459
- right_keys = [c.address for c in self.right_datasource.output_concepts]
2460
- match_concepts = [c.address for c in self.concepts]
2461
- raise SyntaxError(
2462
- "No mutual join keys found between"
2463
- f" {self.left_datasource.identifier} and"
2464
- f" {self.right_datasource.identifier}, left_keys {left_keys},"
2465
- f" right_keys {right_keys},"
2466
- f" provided join concepts {match_concepts}"
2467
- )
2468
- self.concepts = final_concepts
2469
-
2470
- @property
2471
- def unique_id(self) -> str:
2472
- return str(self)
2473
-
2474
- @property
2475
- def input_concepts(self) -> List[Concept]:
2476
- base = []
2477
- if self.concept_pairs:
2478
- for pair in self.concept_pairs:
2479
- base += [pair.left, pair.right]
2480
- elif self.concepts:
2481
- base += self.concepts
2482
- return base
2483
-
2484
- def __str__(self):
2485
- if self.concept_pairs:
2486
- return (
2487
- f"{self.join_type.value} {self.right_datasource.name} on"
2488
- f" {','.join([str(k.existing_datasource.name) + '.'+ str(k.left)+'='+str(k.right) for k in self.concept_pairs])}"
2489
- )
2490
- return (
2491
- f"{self.join_type.value} {self.right_datasource.name} on"
2492
- f" {','.join([str(k) for k in self.concepts])}"
2493
- )
2494
-
2495
-
2496
- class QueryDatasource(BaseModel):
2497
- input_concepts: List[Concept]
2498
- output_concepts: List[Concept]
2499
- datasources: List[Union[Datasource, "QueryDatasource"]]
2500
- source_map: Dict[str, Set[Union[Datasource, "QueryDatasource", "UnnestJoin"]]]
2501
-
2502
- grain: Grain
2503
- joins: List[BaseJoin | UnnestJoin]
2504
- limit: Optional[int] = None
2505
- condition: Optional[Union["Conditional", "Comparison", "Parenthetical"]] = Field(
2506
- default=None
2507
- )
2508
- filter_concepts: List[Concept] = Field(default_factory=list)
2509
- source_type: SourceType = SourceType.SELECT
2510
- partial_concepts: List[Concept] = Field(default_factory=list)
2511
- hidden_concepts: set[str] = Field(default_factory=set)
2512
- nullable_concepts: List[Concept] = Field(default_factory=list)
2513
- join_derived_concepts: List[Concept] = Field(default_factory=list)
2514
- force_group: bool | None = None
2515
- existence_source_map: Dict[str, Set[Union[Datasource, "QueryDatasource"]]] = Field(
2516
- default_factory=dict
2517
- )
2518
-
2519
- def __repr__(self):
2520
- return f"{self.identifier}@<{self.grain}>"
2521
-
2522
- @property
2523
- def safe_identifier(self):
2524
- return self.identifier.replace(".", "_")
2525
-
2526
- @property
2527
- def non_partial_concept_addresses(self) -> List[str]:
2528
- return [
2529
- c.address
2530
- for c in self.output_concepts
2531
- if c.address not in [z.address for z in self.partial_concepts]
2532
- ]
2533
-
2534
- @field_validator("joins")
2535
- @classmethod
2536
- def validate_joins(cls, v):
2537
- unique_pairs = set()
2538
- for join in v:
2539
- if not isinstance(join, BaseJoin):
2540
- continue
2541
- pairing = str(join)
2542
- if pairing in unique_pairs:
2543
- raise SyntaxError(f"Duplicate join {str(join)}")
2544
- unique_pairs.add(pairing)
2545
- return v
2546
-
2547
- @field_validator("input_concepts")
2548
- @classmethod
2549
- def validate_inputs(cls, v):
2550
- return unique(v, "address")
2551
-
2552
- @field_validator("output_concepts")
2553
- @classmethod
2554
- def validate_outputs(cls, v):
2555
- return unique(v, "address")
2556
-
2557
- @field_validator("source_map")
2558
- @classmethod
2559
- def validate_source_map(cls, v: dict, info: ValidationInfo):
2560
- values = info.data
2561
- for key in ("input_concepts", "output_concepts"):
2562
- if not values.get(key):
2563
- continue
2564
- concept: Concept
2565
- for concept in values[key]:
2566
- if (
2567
- concept.address not in v
2568
- and not any(x in v for x in concept.pseudonyms)
2569
- and CONFIG.validate_missing
2570
- ):
2571
- raise SyntaxError(
2572
- f"On query datasource missing source map for {concept.address} on {key}, have {v}"
2573
- )
2574
- return v
2575
-
2576
- def __str__(self):
2577
- return self.__repr__()
2578
-
2579
- def __hash__(self):
2580
- return (self.identifier).__hash__()
2581
-
2582
- @property
2583
- def concepts(self):
2584
- return self.output_concepts
2585
-
2586
- @property
2587
- def name(self):
2588
- return self.identifier
2589
-
2590
- @property
2591
- def group_required(self) -> bool:
2592
- if self.force_group is True:
2593
- return True
2594
- if self.force_group is False:
2595
- return False
2596
- if self.source_type:
2597
- if self.source_type in [
2598
- SourceType.FILTER,
2599
- ]:
2600
- return False
2601
- elif self.source_type in [
2602
- SourceType.GROUP,
2603
- ]:
2604
- return True
2605
- return False
2606
-
2607
- def __add__(self, other) -> "QueryDatasource":
2608
- # these are syntax errors to avoid being caught by current
2609
- if not isinstance(other, QueryDatasource):
2610
- raise SyntaxError("Can only merge two query datasources")
2611
- if not other.grain == self.grain:
2612
- raise SyntaxError(
2613
- "Can only merge two query datasources with identical grain"
2614
- )
2615
- if not self.group_required == other.group_required:
2616
- raise SyntaxError(
2617
- "can only merge two datasources if the group required flag is the same"
2618
- )
2619
- if not self.join_derived_concepts == other.join_derived_concepts:
2620
- raise SyntaxError(
2621
- "can only merge two datasources if the join derived concepts are the same"
2622
- )
2623
- if not self.force_group == other.force_group:
2624
- raise SyntaxError(
2625
- "can only merge two datasources if the force_group flag is the same"
2626
- )
2627
- logger.debug(
2628
- f"{LOGGER_PREFIX} merging {self.name} with"
2629
- f" {[c.address for c in self.output_concepts]} concepts and"
2630
- f" {other.name} with {[c.address for c in other.output_concepts]} concepts"
2631
- )
2632
-
2633
- merged_datasources: dict[str, Union[Datasource, "QueryDatasource"]] = {}
2634
-
2635
- for ds in [*self.datasources, *other.datasources]:
2636
- if ds.safe_identifier in merged_datasources:
2637
- merged_datasources[ds.safe_identifier] = (
2638
- merged_datasources[ds.safe_identifier] + ds
2639
- )
2640
- else:
2641
- merged_datasources[ds.safe_identifier] = ds
2642
-
2643
- final_source_map: defaultdict[
2644
- str, Set[Union[Datasource, "QueryDatasource", "UnnestJoin"]]
2645
- ] = defaultdict(set)
2646
-
2647
- # add our sources
2648
- for key in self.source_map:
2649
- final_source_map[key] = self.source_map[key].union(
2650
- other.source_map.get(key, set())
2651
- )
2652
- # add their sources
2653
- for key in other.source_map:
2654
- if key not in final_source_map:
2655
- final_source_map[key] = other.source_map[key]
2656
-
2657
- # if a ds was merged (to combine columns), we need to update the source map
2658
- # to use the merged item
2659
- for k, v in final_source_map.items():
2660
- final_source_map[k] = set(
2661
- merged_datasources.get(x.safe_identifier, x) for x in list(v)
2662
- )
2663
- self_hidden: set[str] = self.hidden_concepts or set()
2664
- other_hidden: set[str] = other.hidden_concepts or set()
2665
- # hidden is the minimum overlapping set
2666
- hidden = self_hidden.intersection(other_hidden)
2667
- qds = QueryDatasource(
2668
- input_concepts=unique(
2669
- self.input_concepts + other.input_concepts, "address"
2670
- ),
2671
- output_concepts=unique(
2672
- self.output_concepts + other.output_concepts, "address"
2673
- ),
2674
- source_map=final_source_map,
2675
- datasources=list(merged_datasources.values()),
2676
- grain=self.grain,
2677
- joins=unique(self.joins + other.joins, "unique_id"),
2678
- condition=(
2679
- self.condition + other.condition
2680
- if self.condition and other.condition
2681
- else self.condition or other.condition
2682
- ),
2683
- source_type=self.source_type,
2684
- partial_concepts=unique(
2685
- self.partial_concepts + other.partial_concepts, "address"
2686
- ),
2687
- join_derived_concepts=self.join_derived_concepts,
2688
- force_group=self.force_group,
2689
- hidden_concepts=hidden,
2690
- )
2691
-
2692
- return qds
2693
-
2694
- @property
2695
- def identifier(self) -> str:
2696
- filters = abs(hash(str(self.condition))) if self.condition else ""
2697
- grain = "_".join([str(c).replace(".", "_") for c in self.grain.components])
2698
- return (
2699
- "_join_".join([d.identifier for d in self.datasources])
2700
- + (f"_at_{grain}" if grain else "_at_abstract")
2701
- + (f"_filtered_by_{filters}" if filters else "")
2702
- )
2703
-
2704
- def get_alias(
2705
- self,
2706
- concept: Concept,
2707
- use_raw_name: bool = False,
2708
- force_alias: bool = False,
2709
- source: str | None = None,
2710
- ):
2711
- for x in self.datasources:
2712
- # query datasources should be referenced by their alias, always
2713
- force_alias = isinstance(x, QueryDatasource)
2714
- #
2715
- use_raw_name = isinstance(x, Datasource) and not force_alias
2716
- if source and x.safe_identifier != source:
2717
- continue
2718
- try:
2719
- return x.get_alias(
2720
- concept.with_grain(self.grain),
2721
- use_raw_name,
2722
- force_alias=force_alias,
2723
- )
2724
- except ValueError as e:
2725
- from trilogy.constants import logger
2726
-
2727
- logger.debug(e)
2728
- continue
2729
- existing = [c.with_grain(self.grain) for c in self.output_concepts]
2730
- if concept in existing:
2731
- return concept.name
2732
-
2733
- existing_str = [str(c) for c in existing]
2734
- datasources = [ds.identifier for ds in self.datasources]
2735
- raise ValueError(
2736
- f"{LOGGER_PREFIX} Concept {str(concept)} not found on {self.identifier};"
2737
- f" have {existing_str} from {datasources}."
2738
- )
2739
-
2740
- @property
2741
- def safe_location(self):
2742
- return self.identifier
2743
-
2744
-
2745
- class Comment(BaseModel):
2746
- text: str
2747
-
2748
-
2749
- class CTE(BaseModel):
2750
- name: str
2751
- source: "QueryDatasource"
2752
- output_columns: List[Concept]
2753
- source_map: Dict[str, list[str]]
2754
- grain: Grain
2755
- base: bool = False
2756
- group_to_grain: bool = False
2757
- existence_source_map: Dict[str, list[str]] = Field(default_factory=dict)
2758
- parent_ctes: List[Union["CTE", "UnionCTE"]] = Field(default_factory=list)
2759
- joins: List[Union["Join", "InstantiatedUnnestJoin"]] = Field(default_factory=list)
2760
- condition: Optional[Union["Conditional", "Comparison", "Parenthetical"]] = None
2761
- partial_concepts: List[Concept] = Field(default_factory=list)
2762
- nullable_concepts: List[Concept] = Field(default_factory=list)
2763
- join_derived_concepts: List[Concept] = Field(default_factory=list)
2764
- hidden_concepts: set[str] = Field(default_factory=set)
2765
- order_by: Optional[OrderBy] = None
2766
- limit: Optional[int] = None
2767
- base_name_override: Optional[str] = None
2768
- base_alias_override: Optional[str] = None
2769
-
2770
- @property
2771
- def identifier(self):
2772
- return self.name
2773
-
2774
- @property
2775
- def safe_identifier(self):
2776
- return self.name
2777
-
2778
- @computed_field # type: ignore
2779
- @property
2780
- def output_lcl(self) -> LooseConceptList:
2781
- return LooseConceptList(concepts=self.output_columns)
2782
-
2783
- @field_validator("output_columns")
2784
- def validate_output_columns(cls, v):
2785
- return unique(v, "address")
2786
-
2787
- def inline_constant(self, concept: Concept):
2788
- if not concept.derivation == PurposeLineage.CONSTANT:
2789
- return False
2790
- if not isinstance(concept.lineage, Function):
2791
- return False
2792
- if not concept.lineage.operator == FunctionType.CONSTANT:
2793
- return False
2794
- # remove the constant
2795
- removed: set = set()
2796
- if concept.address in self.source_map:
2797
- removed = removed.union(self.source_map[concept.address])
2798
- del self.source_map[concept.address]
2799
-
2800
- if self.condition:
2801
- self.condition = self.condition.inline_constant(concept)
2802
-
2803
- # if we've entirely removed the need to join to someplace to get the concept
2804
- # drop the join as well.
2805
- for removed_cte in removed:
2806
- still_required = any(
2807
- [
2808
- removed_cte in x
2809
- for x in self.source_map.values()
2810
- or self.existence_source_map.values()
2811
- ]
2812
- )
2813
- if not still_required:
2814
- self.joins = [
2815
- join
2816
- for join in self.joins
2817
- if not isinstance(join, Join)
2818
- or (
2819
- isinstance(join, Join)
2820
- and (
2821
- join.right_cte.name != removed_cte
2822
- and any(
2823
- [
2824
- x.cte.name != removed_cte
2825
- for x in (join.joinkey_pairs or [])
2826
- ]
2827
- )
2828
- )
2829
- )
2830
- ]
2831
- for join in self.joins:
2832
- if isinstance(join, UnnestJoin) and concept in join.concepts:
2833
- join.rendering_required = False
2834
-
2835
- self.parent_ctes = [
2836
- x for x in self.parent_ctes if x.name != removed_cte
2837
- ]
2838
- if removed_cte == self.base_name_override:
2839
- candidates = [x.name for x in self.parent_ctes]
2840
- self.base_name_override = candidates[0] if candidates else None
2841
- self.base_alias_override = candidates[0] if candidates else None
2842
- return True
2843
-
2844
- @property
2845
- def comment(self) -> str:
2846
- base = f"Target: {str(self.grain)}. Group: {self.group_to_grain}"
2847
- base += f" Source: {self.source.source_type}."
2848
- if self.parent_ctes:
2849
- base += f" References: {', '.join([x.name for x in self.parent_ctes])}."
2850
- if self.joins:
2851
- base += f"\n-- Joins: {', '.join([str(x) for x in self.joins])}."
2852
- if self.partial_concepts:
2853
- base += (
2854
- f"\n-- Partials: {', '.join([str(x) for x in self.partial_concepts])}."
2855
- )
2856
- base += f"\n-- Source Map: {self.source_map}."
2857
- base += f"\n-- Output: {', '.join([str(x) for x in self.output_columns])}."
2858
- if self.source.input_concepts:
2859
- base += f"\n-- Inputs: {', '.join([str(x) for x in self.source.input_concepts])}."
2860
- if self.hidden_concepts:
2861
- base += f"\n-- Hidden: {', '.join([str(x) for x in self.hidden_concepts])}."
2862
- if self.nullable_concepts:
2863
- base += (
2864
- f"\n-- Nullable: {', '.join([str(x) for x in self.nullable_concepts])}."
2865
- )
2866
-
2867
- return base
2868
-
2869
- def inline_parent_datasource(self, parent: CTE, force_group: bool = False) -> bool:
2870
- qds_being_inlined = parent.source
2871
- ds_being_inlined = qds_being_inlined.datasources[0]
2872
- if not isinstance(ds_being_inlined, Datasource):
2873
- return False
2874
- if any(
2875
- [
2876
- x.safe_identifier == ds_being_inlined.safe_identifier
2877
- for x in self.source.datasources
2878
- ]
2879
- ):
2880
- return False
2881
-
2882
- self.source.datasources = [
2883
- ds_being_inlined,
2884
- *[
2885
- x
2886
- for x in self.source.datasources
2887
- if x.safe_identifier != qds_being_inlined.safe_identifier
2888
- ],
2889
- ]
2890
- # need to identify this before updating joins
2891
- if self.base_name == parent.name:
2892
- self.base_name_override = ds_being_inlined.safe_location
2893
- self.base_alias_override = ds_being_inlined.safe_identifier
2894
-
2895
- for join in self.joins:
2896
- if isinstance(join, InstantiatedUnnestJoin):
2897
- continue
2898
- if (
2899
- join.left_cte
2900
- and join.left_cte.safe_identifier == parent.safe_identifier
2901
- ):
2902
- join.inline_cte(parent)
2903
- if join.joinkey_pairs:
2904
- for pair in join.joinkey_pairs:
2905
- if pair.cte and pair.cte.safe_identifier == parent.safe_identifier:
2906
- join.inline_cte(parent)
2907
- if join.right_cte.safe_identifier == parent.safe_identifier:
2908
- join.inline_cte(parent)
2909
- for k, v in self.source_map.items():
2910
- if isinstance(v, list):
2911
- self.source_map[k] = [
2912
- (
2913
- ds_being_inlined.safe_identifier
2914
- if x == parent.safe_identifier
2915
- else x
2916
- )
2917
- for x in v
2918
- ]
2919
- elif v == parent.safe_identifier:
2920
- self.source_map[k] = [ds_being_inlined.safe_identifier]
2921
-
2922
- # zip in any required values for lookups
2923
- for k in ds_being_inlined.output_lcl.addresses:
2924
- if k in self.source_map and self.source_map[k]:
2925
- continue
2926
- self.source_map[k] = [ds_being_inlined.safe_identifier]
2927
- self.parent_ctes = [
2928
- x for x in self.parent_ctes if x.safe_identifier != parent.safe_identifier
2929
- ]
2930
- if force_group:
2931
- self.group_to_grain = True
2932
- return True
2933
-
2934
- def __add__(self, other: "CTE" | UnionCTE):
2935
- if isinstance(other, UnionCTE):
2936
- raise ValueError("cannot merge CTE and union CTE")
2937
- logger.info('Merging two copies of CTE "%s"', self.name)
2938
- if not self.grain == other.grain:
2939
- error = (
2940
- "Attempting to merge two ctes of different grains"
2941
- f" {self.name} {other.name} grains {self.grain} {other.grain}| {self.group_to_grain} {other.group_to_grain}| {self.output_lcl} {other.output_lcl}"
2942
- )
2943
- raise ValueError(error)
2944
- if not self.condition == other.condition:
2945
- error = (
2946
- "Attempting to merge two ctes with different conditions"
2947
- f" {self.name} {other.name} conditions {self.condition} {other.condition}"
2948
- )
2949
- raise ValueError(error)
2950
- mutually_hidden = set()
2951
- for concept in self.hidden_concepts:
2952
- if concept in other.hidden_concepts:
2953
- mutually_hidden.add(concept)
2954
- self.partial_concepts = unique(
2955
- self.partial_concepts + other.partial_concepts, "address"
2956
- )
2957
- self.parent_ctes = merge_ctes(self.parent_ctes + other.parent_ctes)
2958
-
2959
- self.source_map = {**self.source_map, **other.source_map}
2960
-
2961
- self.output_columns = unique(
2962
- self.output_columns + other.output_columns, "address"
2963
- )
2964
- self.joins = unique(self.joins + other.joins, "unique_id")
2965
- self.partial_concepts = unique(
2966
- self.partial_concepts + other.partial_concepts, "address"
2967
- )
2968
- self.join_derived_concepts = unique(
2969
- self.join_derived_concepts + other.join_derived_concepts, "address"
2970
- )
2971
-
2972
- self.source.source_map = {**self.source.source_map, **other.source.source_map}
2973
- self.source.output_concepts = unique(
2974
- self.source.output_concepts + other.source.output_concepts, "address"
2975
- )
2976
- self.nullable_concepts = unique(
2977
- self.nullable_concepts + other.nullable_concepts, "address"
2978
- )
2979
- self.hidden_concepts = mutually_hidden
2980
- self.existence_source_map = {
2981
- **self.existence_source_map,
2982
- **other.existence_source_map,
2983
- }
2984
- return self
2985
-
2986
- @property
2987
- def relevant_base_ctes(self):
2988
- return self.parent_ctes
2989
-
2990
- @property
2991
- def is_root_datasource(self) -> bool:
2992
- return (
2993
- len(self.source.datasources) == 1
2994
- and isinstance(self.source.datasources[0], Datasource)
2995
- and not self.source.datasources[0].name == CONSTANT_DATASET
2996
- )
2997
-
2998
- @property
2999
- def base_name(self) -> str:
3000
- if self.base_name_override:
3001
- return self.base_name_override
3002
- # if this cte selects from a single datasource, select right from it
3003
- if self.is_root_datasource:
3004
- return self.source.datasources[0].safe_location
3005
-
3006
- # if we have multiple joined CTEs, pick the base
3007
- # as the root
3008
- elif len(self.source.datasources) == 1 and len(self.parent_ctes) == 1:
3009
- return self.parent_ctes[0].name
3010
- elif self.relevant_base_ctes:
3011
- return self.relevant_base_ctes[0].name
3012
- return self.source.name
3013
-
3014
- @property
3015
- def quote_address(self) -> bool:
3016
- if self.is_root_datasource:
3017
- candidate = self.source.datasources[0]
3018
- if isinstance(candidate, Datasource) and isinstance(
3019
- candidate.address, Address
3020
- ):
3021
- return candidate.address.quoted
3022
- return False
3023
-
3024
- @property
3025
- def base_alias(self) -> str:
3026
- if self.base_alias_override:
3027
- return self.base_alias_override
3028
- if self.is_root_datasource:
3029
- return self.source.datasources[0].identifier
3030
- elif self.relevant_base_ctes:
3031
- return self.relevant_base_ctes[0].name
3032
- elif self.parent_ctes:
3033
- return self.parent_ctes[0].name
3034
- return self.name
3035
-
3036
- def get_concept(self, address: str) -> Concept | None:
3037
- for cte in self.parent_ctes:
3038
- if address in cte.output_columns:
3039
- match = [x for x in cte.output_columns if x.address == address].pop()
3040
- if match:
3041
- return match
3042
-
3043
- for array in [self.source.input_concepts, self.source.output_concepts]:
3044
- match_list = [x for x in array if x.address == address]
3045
- if match_list:
3046
- return match_list.pop()
3047
- match_list = [x for x in self.output_columns if x.address == address]
3048
- if match_list:
3049
- return match_list.pop()
3050
- return None
3051
-
3052
- def get_alias(self, concept: Concept, source: str | None = None) -> str:
3053
- for cte in self.parent_ctes:
3054
- if concept.address in cte.output_columns:
3055
- if source and source != cte.name:
3056
- continue
3057
- return concept.safe_address
3058
-
3059
- try:
3060
- source = self.source.get_alias(concept, source=source)
3061
-
3062
- if not source:
3063
- raise ValueError("No source found")
3064
- return source
3065
- except ValueError as e:
3066
- return f"INVALID_ALIAS: {str(e)}"
3067
-
3068
- @property
3069
- def group_concepts(self) -> List[Concept]:
3070
- def check_is_not_in_group(c: Concept):
3071
- if len(self.source_map.get(c.address, [])) > 0:
3072
- return False
3073
- if c.derivation == PurposeLineage.ROWSET:
3074
- assert isinstance(c.lineage, RowsetItem)
3075
- return check_is_not_in_group(c.lineage.content)
3076
- if c.derivation == PurposeLineage.CONSTANT:
3077
- return True
3078
- if c.purpose == Purpose.METRIC:
3079
- return True
3080
-
3081
- if c.derivation == PurposeLineage.BASIC and c.lineage:
3082
- if all([check_is_not_in_group(x) for x in c.lineage.concept_arguments]):
3083
- return True
3084
- if (
3085
- isinstance(c.lineage, Function)
3086
- and c.lineage.operator == FunctionType.GROUP
3087
- ):
3088
- return check_is_not_in_group(c.lineage.concept_arguments[0])
3089
- return False
3090
-
3091
- return (
3092
- unique(
3093
- [c for c in self.output_columns if not check_is_not_in_group(c)],
3094
- "address",
3095
- )
3096
- if self.group_to_grain
3097
- else []
3098
- )
3099
-
3100
- @property
3101
- def render_from_clause(self) -> bool:
3102
- if (
3103
- all([c.derivation == PurposeLineage.CONSTANT for c in self.output_columns])
3104
- and not self.parent_ctes
3105
- and not self.group_to_grain
3106
- ):
3107
- return False
3108
- # if we don't need to source any concepts from anywhere
3109
- # render without from
3110
- # most likely to happen from inlining constants
3111
- if not any([v for v in self.source_map.values()]):
3112
- return False
3113
- if (
3114
- len(self.source.datasources) == 1
3115
- and self.source.datasources[0].name == CONSTANT_DATASET
3116
- ):
3117
- return False
3118
- return True
3119
-
3120
- @property
3121
- def sourced_concepts(self) -> List[Concept]:
3122
- return [c for c in self.output_columns if c.address in self.source_map]
3123
-
3124
-
3125
- class UnionCTE(BaseModel):
3126
- name: str
3127
- source: QueryDatasource
3128
- parent_ctes: list[CTE | UnionCTE]
3129
- internal_ctes: list[CTE | UnionCTE]
3130
- output_columns: List[Concept]
3131
- grain: Grain
3132
- operator: str = "UNION ALL"
3133
- order_by: Optional[OrderBy] = None
3134
- limit: Optional[int] = None
3135
- hidden_concepts: set[str] = Field(default_factory=set)
3136
- partial_concepts: list[Concept] = Field(default_factory=list)
3137
- existence_source_map: Dict[str, list[str]] = Field(default_factory=dict)
3138
-
3139
- @computed_field # type: ignore
3140
- @property
3141
- def output_lcl(self) -> LooseConceptList:
3142
- return LooseConceptList(concepts=self.output_columns)
3143
-
3144
- def get_alias(self, concept: Concept, source: str | None = None) -> str:
3145
- for cte in self.parent_ctes:
3146
- if concept.address in cte.output_columns:
3147
- if source and source != cte.name:
3148
- continue
3149
- return concept.safe_address
3150
- return "INVALID_ALIAS"
3151
-
3152
- def get_concept(self, address: str) -> Concept | None:
3153
- for cte in self.internal_ctes:
3154
- if address in cte.output_columns:
3155
- match = [x for x in cte.output_columns if x.address == address].pop()
3156
- return match
3157
-
3158
- match_list = [x for x in self.output_columns if x.address == address]
3159
- if match_list:
3160
- return match_list.pop()
3161
- return None
3162
-
3163
- @property
3164
- def source_map(self):
3165
- return {x.address: [] for x in self.output_columns}
3166
-
3167
- @property
3168
- def condition(self):
3169
- return None
3170
-
3171
- @condition.setter
3172
- def condition(self, value):
3173
- raise NotImplementedError
3174
-
3175
- @property
3176
- def safe_identifier(self):
3177
- return self.name
3178
-
3179
- @property
3180
- def group_to_grain(self) -> bool:
3181
- return False
3182
-
3183
- def __add__(self, other):
3184
- if not isinstance(other, UnionCTE) or not other.name == self.name:
3185
- raise SyntaxError("Cannot merge union CTEs")
3186
- return self
3187
-
3188
-
3189
- def merge_ctes(ctes: List[CTE | UnionCTE]) -> List[CTE | UnionCTE]:
3190
- final_ctes_dict: Dict[str, CTE | UnionCTE] = {}
3191
- # merge CTEs
3192
- for cte in ctes:
3193
- if cte.name not in final_ctes_dict:
3194
- final_ctes_dict[cte.name] = cte
3195
- else:
3196
- final_ctes_dict[cte.name] = final_ctes_dict[cte.name] + cte
3197
-
3198
- final_ctes = list(final_ctes_dict.values())
3199
- return final_ctes
3200
-
3201
-
3202
- class CompiledCTE(BaseModel):
3203
- name: str
3204
- statement: str
3205
-
3206
-
3207
- class JoinKey(BaseModel):
3208
- concept: Concept
3209
-
3210
- def __str__(self):
3211
- return str(self.concept)
3212
-
3213
-
3214
- class Join(BaseModel):
3215
- right_cte: CTE
3216
- jointype: JoinType
3217
- left_cte: CTE | None = None
3218
- joinkey_pairs: List[CTEConceptPair] | None = None
3219
- inlined_ctes: set[str] = Field(default_factory=set)
3220
-
3221
- def inline_cte(self, cte: CTE):
3222
- self.inlined_ctes.add(cte.name)
3223
-
3224
- def get_name(self, cte: CTE):
3225
- if cte.identifier in self.inlined_ctes:
3226
- return cte.source.datasources[0].safe_identifier
3227
- return cte.safe_identifier
3228
-
3229
- @property
3230
- def right_name(self) -> str:
3231
- if self.right_cte.identifier in self.inlined_ctes:
3232
- return self.right_cte.source.datasources[0].safe_identifier
3233
- return self.right_cte.safe_identifier
3234
-
3235
- @property
3236
- def right_ref(self) -> str:
3237
- if self.right_cte.identifier in self.inlined_ctes:
3238
- return f"{self.right_cte.source.datasources[0].safe_location} as {self.right_cte.source.datasources[0].safe_identifier}"
3239
- return self.right_cte.safe_identifier
3240
-
3241
- @property
3242
- def unique_id(self) -> str:
3243
- return str(self)
3244
-
3245
- def __str__(self):
3246
- if self.joinkey_pairs:
3247
- return (
3248
- f"{self.jointype.value} join"
3249
- f" {self.right_name} on"
3250
- f" {','.join([k.cte.name + '.'+str(k.left.address)+'='+str(k.right.address) for k in self.joinkey_pairs])}"
3251
- )
3252
- elif self.left_cte:
3253
- return (
3254
- f"{self.jointype.value} JOIN {self.left_cte.name} and"
3255
- f" {self.right_name} on {','.join([str(k) for k in self.joinkey_pairs])}"
3256
- )
3257
- return f"{self.jointype.value} JOIN {self.right_name} on {','.join([str(k) for k in self.joinkey_pairs])}"
3258
-
3259
-
3260
- class UndefinedConcept(Concept, Mergeable, Namespaced):
3261
- model_config = ConfigDict(arbitrary_types_allowed=True)
3262
- name: str
3263
- line_no: int | None = None
3264
- datatype: DataType | ListType | StructType | MapType | NumericType = (
3265
- DataType.UNKNOWN
3266
- )
3267
- purpose: Purpose = Purpose.UNKNOWN
3268
-
3269
- def with_select_context(
3270
- self,
3271
- local_concepts: dict[str, Concept],
3272
- grain: Grain,
3273
- environment: Environment,
3274
- ) -> "Concept":
3275
- if self.address in local_concepts:
3276
- rval = local_concepts[self.address]
3277
- rval = rval.with_select_context(local_concepts, grain, environment)
3278
- return rval
3279
- if environment.concepts.fail_on_missing:
3280
- environment.concepts.raise_undefined(self.address, line_no=self.line_no)
3281
- return self
3282
-
3283
-
3284
- class EnvironmentDatasourceDict(dict):
3285
- def __init__(self, *args, **kwargs) -> None:
3286
- super().__init__(self, *args, **kwargs)
3287
-
3288
- def __getitem__(self, key: str) -> Datasource:
3289
- try:
3290
- return super(EnvironmentDatasourceDict, self).__getitem__(key)
3291
- except KeyError:
3292
- if DEFAULT_NAMESPACE + "." + key in self:
3293
- return self.__getitem__(DEFAULT_NAMESPACE + "." + key)
3294
- if "." in key and key.split(".", 1)[0] == DEFAULT_NAMESPACE:
3295
- return self.__getitem__(key.split(".", 1)[1])
3296
- raise
3297
-
3298
- def values(self) -> ValuesView[Datasource]: # type: ignore
3299
- return super().values()
3300
-
3301
- def items(self) -> ItemsView[str, Datasource]: # type: ignore
3302
- return super().items()
3303
-
3304
- def duplicate(self) -> "EnvironmentDatasourceDict":
3305
- new = EnvironmentDatasourceDict()
3306
- new.update({k: v.duplicate() for k, v in self.items()})
3307
- return new
3308
-
3309
-
3310
- class ImportStatement(HasUUID, BaseModel):
3311
- alias: str
3312
- path: Path
3313
- # environment: Union["Environment", None] = None
3314
- # TODO: this might result in a lot of duplication
3315
- # environment:"Environment"
3316
-
3317
-
3318
- class EnvironmentOptions(BaseModel):
3319
- allow_duplicate_declaration: bool = True
3320
-
3321
-
3322
- def validate_concepts(v) -> EnvironmentConceptDict:
3323
- if isinstance(v, EnvironmentConceptDict):
3324
- return v
3325
- elif isinstance(v, dict):
3326
- return EnvironmentConceptDict(
3327
- **{x: Concept.model_validate(y) for x, y in v.items()}
3328
- )
3329
- raise ValueError
3330
-
3331
-
3332
- def validate_datasources(v) -> EnvironmentDatasourceDict:
3333
- if isinstance(v, EnvironmentDatasourceDict):
3334
- return v
3335
- elif isinstance(v, dict):
3336
- return EnvironmentDatasourceDict(
3337
- **{x: Datasource.model_validate(y) for x, y in v.items()}
3338
- )
3339
- raise ValueError
3340
-
3341
-
3342
- class Environment(BaseModel):
3343
- model_config = ConfigDict(arbitrary_types_allowed=True, strict=False)
3344
-
3345
- concepts: Annotated[EnvironmentConceptDict, PlainValidator(validate_concepts)] = (
3346
- Field(default_factory=EnvironmentConceptDict)
3347
- )
3348
- datasources: Annotated[
3349
- EnvironmentDatasourceDict, PlainValidator(validate_datasources)
3350
- ] = Field(default_factory=EnvironmentDatasourceDict)
3351
- functions: Dict[str, Function] = Field(default_factory=dict)
3352
- data_types: Dict[str, DataType] = Field(default_factory=dict)
3353
- imports: Dict[str, list[ImportStatement]] = Field(
3354
- default_factory=lambda: defaultdict(list) # type: ignore
3355
- )
3356
- namespace: str = DEFAULT_NAMESPACE
3357
- working_path: str | Path = Field(default_factory=lambda: os.getcwd())
3358
- environment_config: EnvironmentOptions = Field(default_factory=EnvironmentOptions)
3359
- version: str = Field(default_factory=get_version)
3360
- cte_name_map: Dict[str, str] = Field(default_factory=dict)
3361
- materialized_concepts: set[str] = Field(default_factory=set)
3362
- alias_origin_lookup: Dict[str, Concept] = Field(default_factory=dict)
3363
- # TODO: support freezing environments to avoid mutation
3364
- frozen: bool = False
3365
- env_file_path: Path | None = None
3366
-
3367
- def freeze(self):
3368
- self.frozen = True
3369
-
3370
- def thaw(self):
3371
- self.frozen = False
3372
-
3373
- def duplicate(self):
3374
- return Environment.model_construct(
3375
- datasources=self.datasources.duplicate(),
3376
- concepts=self.concepts.duplicate(),
3377
- functions=dict(self.functions),
3378
- data_types=dict(self.data_types),
3379
- imports=dict(self.imports),
3380
- namespace=self.namespace,
3381
- working_path=self.working_path,
3382
- environment_config=self.environment_config,
3383
- version=self.version,
3384
- cte_name_map=dict(self.cte_name_map),
3385
- materialized_concepts=set(self.materialized_concepts),
3386
- alias_origin_lookup={
3387
- k: v.duplicate() for k, v in self.alias_origin_lookup.items()
3388
- },
3389
- )
3390
-
3391
- def __init__(self, **data):
3392
- super().__init__(**data)
3393
- concept = Concept(
3394
- name="_env_working_path",
3395
- namespace=self.namespace,
3396
- lineage=Function(
3397
- operator=FunctionType.CONSTANT,
3398
- arguments=[str(self.working_path)],
3399
- output_datatype=DataType.STRING,
3400
- output_purpose=Purpose.CONSTANT,
3401
- ),
3402
- datatype=DataType.STRING,
3403
- purpose=Purpose.CONSTANT,
3404
- )
3405
- self.add_concept(concept)
3406
-
3407
- # def freeze(self):
3408
- # self.frozen = True
3409
-
3410
- # def thaw(self):
3411
- # self.frozen = False
3412
-
3413
- @classmethod
3414
- def from_file(cls, path: str | Path) -> "Environment":
3415
- if isinstance(path, str):
3416
- path = Path(path)
3417
- with open(path, "r") as f:
3418
- read = f.read()
3419
- return Environment(working_path=path.parent, env_file_path=path).parse(read)[0]
3420
-
3421
- @classmethod
3422
- def from_string(cls, input: str) -> "Environment":
3423
- return Environment().parse(input)[0]
3424
-
3425
- @classmethod
3426
- def from_cache(cls, path) -> Optional["Environment"]:
3427
- with open(path, "r") as f:
3428
- read = f.read()
3429
- base = cls.model_validate_json(read)
3430
- version = get_version()
3431
- if base.version != version:
3432
- return None
3433
- return base
3434
-
3435
- def to_cache(self, path: Optional[str | Path] = None) -> Path:
3436
- if not path:
3437
- ppath = Path(self.working_path) / ENV_CACHE_NAME
3438
- else:
3439
- ppath = Path(path)
3440
- with open(ppath, "w") as f:
3441
- f.write(self.model_dump_json())
3442
- return ppath
3443
-
3444
- def gen_concept_list_caches(self) -> None:
3445
- concrete_addresses = set()
3446
- for datasource in self.datasources.values():
3447
- for concept in datasource.output_concepts:
3448
- concrete_addresses.add(concept.address)
3449
- self.materialized_concepts = set(
3450
- [
3451
- c.address
3452
- for c in self.concepts.values()
3453
- if c.address in concrete_addresses
3454
- ]
3455
- + [
3456
- c.address
3457
- for c in self.alias_origin_lookup.values()
3458
- if c.address in concrete_addresses
3459
- ],
3460
- )
3461
-
3462
- def validate_concept(self, new_concept: Concept, meta: Meta | None = None):
3463
- lookup = new_concept.address
3464
- existing: Concept = self.concepts.get(lookup) # type: ignore
3465
- if not existing:
3466
- return
3467
-
3468
- def handle_persist():
3469
- deriv_lookup = (
3470
- f"{existing.namespace}.{PERSISTED_CONCEPT_PREFIX}_{existing.name}"
3471
- )
3472
-
3473
- alt_source = self.alias_origin_lookup.get(deriv_lookup)
3474
- if not alt_source:
3475
- return None
3476
- # if the new concept binding has no lineage
3477
- # nothing to cause us to think a persist binding
3478
- # needs to be invalidated
3479
- if not new_concept.lineage:
3480
- return existing
3481
- if str(alt_source.lineage) == str(new_concept.lineage):
3482
- logger.info(
3483
- f"Persisted concept {existing.address} matched redeclaration, keeping current persistence binding."
3484
- )
3485
- return existing
3486
- logger.warning(
3487
- f"Persisted concept {existing.address} lineage {str(alt_source.lineage)} did not match redeclaration {str(new_concept.lineage)}, overwriting and invalidating persist binding."
3488
- )
3489
- for k, datasource in self.datasources.items():
3490
- if existing.address in datasource.output_concepts:
3491
- datasource.columns = [
3492
- x
3493
- for x in datasource.columns
3494
- if x.concept.address != existing.address
3495
- ]
3496
- return None
3497
-
3498
- if existing and self.environment_config.allow_duplicate_declaration:
3499
- if existing.metadata.concept_source == ConceptSource.PERSIST_STATEMENT:
3500
- return handle_persist()
3501
- return
3502
- elif existing.metadata:
3503
- if existing.metadata.concept_source == ConceptSource.PERSIST_STATEMENT:
3504
- return handle_persist()
3505
- # if the existing concept is auto derived, we can overwrite it
3506
- if existing.metadata.concept_source == ConceptSource.AUTO_DERIVED:
3507
- return None
3508
- elif meta and existing.metadata:
3509
- raise ValueError(
3510
- f"Assignment to concept '{lookup}' on line {meta.line} is a duplicate"
3511
- f" declaration; '{lookup}' was originally defined on line"
3512
- f" {existing.metadata.line_number}"
3513
- )
3514
- elif existing.metadata:
3515
- raise ValueError(
3516
- f"Assignment to concept '{lookup}' is a duplicate declaration;"
3517
- f" '{lookup}' was originally defined on line"
3518
- f" {existing.metadata.line_number}"
3519
- )
3520
- raise ValueError(
3521
- f"Assignment to concept '{lookup}' is a duplicate declaration;"
3522
- )
3523
-
3524
- def add_import(
3525
- self, alias: str, source: Environment, imp_stm: ImportStatement | None = None
3526
- ):
3527
- if self.frozen:
3528
- raise ValueError("Environment is frozen, cannot add imports")
3529
- exists = False
3530
- existing = self.imports[alias]
3531
- if imp_stm:
3532
- if any(
3533
- [x.path == imp_stm.path and x.alias == imp_stm.alias for x in existing]
3534
- ):
3535
- exists = True
3536
- else:
3537
- if any(
3538
- [x.path == source.working_path and x.alias == alias for x in existing]
3539
- ):
3540
- exists = True
3541
- imp_stm = ImportStatement(alias=alias, path=Path(source.working_path))
3542
- same_namespace = alias == self.namespace
3543
-
3544
- if not exists:
3545
- self.imports[alias].append(imp_stm)
3546
- # we can't exit early
3547
- # as there may be new concepts
3548
- for k, concept in source.concepts.items():
3549
- # skip internal namespace
3550
- if INTERNAL_NAMESPACE in concept.address:
3551
- continue
3552
- if same_namespace:
3553
- new = self.add_concept(concept, _ignore_cache=True)
3554
- else:
3555
- new = self.add_concept(
3556
- concept.with_namespace(alias), _ignore_cache=True
3557
- )
3558
-
3559
- k = address_with_namespace(k, alias)
3560
- # set this explicitly, to handle aliasing
3561
- self.concepts[k] = new
3562
-
3563
- for _, datasource in source.datasources.items():
3564
- if same_namespace:
3565
- self.add_datasource(datasource, _ignore_cache=True)
3566
- else:
3567
- self.add_datasource(
3568
- datasource.with_namespace(alias), _ignore_cache=True
3569
- )
3570
- for key, val in source.alias_origin_lookup.items():
3571
- if same_namespace:
3572
- self.alias_origin_lookup[key] = val
3573
- else:
3574
- self.alias_origin_lookup[address_with_namespace(key, alias)] = (
3575
- val.with_namespace(alias)
3576
- )
3577
-
3578
- self.gen_concept_list_caches()
3579
- return self
3580
-
3581
- def add_file_import(
3582
- self, path: str | Path, alias: str, env: Environment | None = None
3583
- ):
3584
- if self.frozen:
3585
- raise ValueError("Environment is frozen, cannot add imports")
3586
- from trilogy.parsing.parse_engine import (
3587
- PARSER,
3588
- ParseToObjects,
3589
- gen_cache_lookup,
3590
- )
3591
-
3592
- if isinstance(path, str):
3593
- if path.endswith(".preql"):
3594
- path = path.rsplit(".", 1)[0]
3595
- if "." not in path:
3596
- target = Path(self.working_path, path)
3597
- else:
3598
- target = Path(self.working_path, *path.split("."))
3599
- target = target.with_suffix(".preql")
3600
- else:
3601
- target = path
3602
- if not env:
3603
- parse_address = gen_cache_lookup(str(target), alias, str(self.working_path))
3604
- try:
3605
- with open(target, "r", encoding="utf-8") as f:
3606
- text = f.read()
3607
- nenv = Environment(
3608
- working_path=target.parent,
3609
- )
3610
- nenv.concepts.fail_on_missing = False
3611
- nparser = ParseToObjects(
3612
- environment=Environment(
3613
- working_path=target.parent,
3614
- ),
3615
- parse_address=parse_address,
3616
- token_address=target,
3617
- )
3618
- nparser.set_text(text)
3619
- nparser.transform(PARSER.parse(text))
3620
- nparser.hydrate_missing()
3621
-
3622
- except Exception as e:
3623
- raise ImportError(
3624
- f"Unable to import file {target.parent}, parsing error: {e}"
3625
- )
3626
- env = nparser.environment
3627
- imps = ImportStatement(alias=alias, path=target)
3628
- self.add_import(alias, source=env, imp_stm=imps)
3629
- return imps
3630
-
3631
- def parse(
3632
- self, input: str, namespace: str | None = None, persist: bool = False
3633
- ) -> Tuple[Environment, list]:
3634
- from trilogy import parse
3635
- from trilogy.core.query_processor import process_persist
3636
-
3637
- if namespace:
3638
- new = Environment()
3639
- _, queries = new.parse(input)
3640
- self.add_import(namespace, new)
3641
- return self, queries
3642
- _, queries = parse(input, self)
3643
- generatable = [
3644
- x
3645
- for x in queries
3646
- if isinstance(
3647
- x,
3648
- (
3649
- SelectStatement,
3650
- PersistStatement,
3651
- MultiSelectStatement,
3652
- ShowStatement,
3653
- ),
3654
- )
3655
- ]
3656
- while generatable:
3657
- t = generatable.pop(0)
3658
- if isinstance(t, PersistStatement) and persist:
3659
- processed = process_persist(self, t)
3660
- self.add_datasource(processed.datasource)
3661
- return self, queries
3662
-
3663
- def add_concept(
3664
- self,
3665
- concept: Concept,
3666
- meta: Meta | None = None,
3667
- force: bool = False,
3668
- add_derived: bool = True,
3669
- _ignore_cache: bool = False,
3670
- ):
3671
- if self.frozen:
3672
- raise ValueError("Environment is frozen, cannot add concepts")
3673
- if not force:
3674
- existing = self.validate_concept(concept, meta=meta)
3675
- if existing:
3676
- concept = existing
3677
- self.concepts[concept.address] = concept
3678
- from trilogy.core.environment_helpers import generate_related_concepts
3679
-
3680
- generate_related_concepts(concept, self, meta=meta, add_derived=add_derived)
3681
- if not _ignore_cache:
3682
- self.gen_concept_list_caches()
3683
- return concept
3684
-
3685
- def add_datasource(
3686
- self,
3687
- datasource: Datasource,
3688
- meta: Meta | None = None,
3689
- _ignore_cache: bool = False,
3690
- ):
3691
- if self.frozen:
3692
- raise ValueError("Environment is frozen, cannot add datasource")
3693
- self.datasources[datasource.identifier] = datasource
3694
-
3695
- eligible_to_promote_roots = datasource.non_partial_for is None
3696
- # mark this as canonical source
3697
- for current_concept in datasource.output_concepts:
3698
- if not eligible_to_promote_roots:
3699
- continue
3700
-
3701
- current_derivation = current_concept.derivation
3702
- # TODO: refine this section;
3703
- # too hacky for maintainability
3704
- if current_derivation not in (PurposeLineage.ROOT, PurposeLineage.CONSTANT):
3705
- persisted = f"{PERSISTED_CONCEPT_PREFIX}_" + current_concept.name
3706
- # override the current concept source to reflect that it's now coming from a datasource
3707
- if (
3708
- current_concept.metadata.concept_source
3709
- != ConceptSource.PERSIST_STATEMENT
3710
- ):
3711
- new_concept = current_concept.model_copy(deep=True)
3712
- new_concept.set_name(persisted)
3713
- self.add_concept(
3714
- new_concept, meta=meta, force=True, _ignore_cache=True
3715
- )
3716
- current_concept.metadata.concept_source = (
3717
- ConceptSource.PERSIST_STATEMENT
3718
- )
3719
- # remove the associated lineage
3720
- # to make this a root for discovery purposes
3721
- # as it now "exists" in a table
3722
- current_concept.lineage = None
3723
- current_concept = current_concept.with_default_grain()
3724
- self.add_concept(
3725
- current_concept, meta=meta, force=True, _ignore_cache=True
3726
- )
3727
- self.merge_concept(new_concept, current_concept, [])
3728
- else:
3729
- self.add_concept(current_concept, meta=meta, _ignore_cache=True)
3730
- if not _ignore_cache:
3731
- self.gen_concept_list_caches()
3732
- return datasource
3733
-
3734
- def delete_datasource(
3735
- self,
3736
- address: str,
3737
- meta: Meta | None = None,
3738
- ) -> bool:
3739
- if self.frozen:
3740
- raise ValueError("Environment is frozen, cannot delete datsources")
3741
- if address in self.datasources:
3742
- del self.datasources[address]
3743
- self.gen_concept_list_caches()
3744
- return True
3745
- return False
3746
-
3747
- def merge_concept(
3748
- self,
3749
- source: Concept,
3750
- target: Concept,
3751
- modifiers: List[Modifier],
3752
- force: bool = False,
3753
- ) -> bool:
3754
- if self.frozen:
3755
- raise ValueError("Environment is frozen, cannot merge concepts")
3756
- replacements = {}
3757
-
3758
- # exit early if we've run this
3759
- if source.address in self.alias_origin_lookup and not force:
3760
- if self.concepts[source.address] == target:
3761
- return False
3762
- self.alias_origin_lookup[source.address] = source
3763
- for k, v in self.concepts.items():
3764
- if v.address == target.address:
3765
- v.pseudonyms.add(source.address)
3766
-
3767
- if v.address == source.address:
3768
- replacements[k] = target
3769
- v.pseudonyms.add(target.address)
3770
- # we need to update keys and grains of all concepts
3771
- else:
3772
- replacements[k] = v.with_merge(source, target, modifiers)
3773
- self.concepts.update(replacements)
3774
-
3775
- for k, ds in self.datasources.items():
3776
- if source.address in ds.output_lcl:
3777
- ds.merge_concept(source, target, modifiers=modifiers)
3778
- return True
3779
-
3780
-
3781
- class LazyEnvironment(Environment):
3782
- """Variant of environment to defer parsing of a path
3783
- until relevant attributes accessed."""
3784
-
3785
- load_path: Path
3786
- loaded: bool = False
3787
-
3788
- def __getattribute__(self, name):
3789
- if name in (
3790
- "load_path",
3791
- "loaded",
3792
- "working_path",
3793
- "model_config",
3794
- "model_fields",
3795
- "model_post_init",
3796
- ) or name.startswith("_"):
3797
- return super().__getattribute__(name)
3798
- if not self.loaded:
3799
- logger.info(
3800
- f"lazily evaluating load path {self.load_path} to access {name}"
3801
- )
3802
- from trilogy import parse
3803
-
3804
- env = Environment(working_path=str(self.working_path))
3805
- with open(self.load_path, "r") as f:
3806
- parse(f.read(), env)
3807
- self.loaded = True
3808
- self.datasources = env.datasources
3809
- self.concepts = env.concepts
3810
- self.imports = env.imports
3811
- return super().__getattribute__(name)
3812
-
3813
-
3814
- class Comparison(
3815
- ConceptArgs, Mergeable, Namespaced, ConstantInlineable, SelectContext, BaseModel
3816
- ):
3817
- left: Union[
3818
- int,
3819
- str,
3820
- float,
3821
- list,
3822
- bool,
3823
- datetime,
3824
- date,
3825
- Function,
3826
- Concept,
3827
- "Conditional",
3828
- DataType,
3829
- "Comparison",
3830
- "Parenthetical",
3831
- MagicConstants,
3832
- WindowItem,
3833
- AggregateWrapper,
3834
- ]
3835
- right: Union[
3836
- int,
3837
- str,
3838
- float,
3839
- list,
3840
- bool,
3841
- date,
3842
- datetime,
3843
- Concept,
3844
- Function,
3845
- "Conditional",
3846
- DataType,
3847
- "Comparison",
3848
- "Parenthetical",
3849
- MagicConstants,
3850
- WindowItem,
3851
- AggregateWrapper,
3852
- TupleWrapper,
3853
- ]
3854
- operator: ComparisonOperator
3855
-
3856
- def hydrate_missing(self, concepts: EnvironmentConceptDict):
3857
- if isinstance(self.left, UndefinedConcept) and self.left.address in concepts:
3858
- self.left = concepts[self.left.address]
3859
- if isinstance(self.right, UndefinedConcept) and self.right.address in concepts:
3860
- self.right = concepts[self.right.address]
3861
- if isinstance(self.left, Mergeable):
3862
- self.left.hydrate_missing(concepts)
3863
- if isinstance(self.right, Mergeable):
3864
- self.right.hydrate_missing(concepts)
3865
- return self
3866
-
3867
- def __init__(self, *args, **kwargs) -> None:
3868
- super().__init__(*args, **kwargs)
3869
- if self.operator in (ComparisonOperator.IS, ComparisonOperator.IS_NOT):
3870
- if self.right != MagicConstants.NULL and DataType.BOOL != arg_to_datatype(
3871
- self.right
3872
- ):
3873
- raise SyntaxError(
3874
- f"Cannot use {self.operator.value} with non-null or boolean value {self.right}"
3875
- )
3876
- elif self.operator in (ComparisonOperator.IN, ComparisonOperator.NOT_IN):
3877
- right = arg_to_datatype(self.right)
3878
- if not isinstance(self.right, Concept) and not isinstance(right, ListType):
3879
- raise SyntaxError(
3880
- f"Cannot use {self.operator.value} with non-list type {right} in {str(self)}"
3881
- )
3882
-
3883
- elif isinstance(right, ListType) and not is_compatible_datatype(
3884
- arg_to_datatype(self.left), right.value_data_type
3885
- ):
3886
- raise SyntaxError(
3887
- f"Cannot compare {arg_to_datatype(self.left)} and {right} with operator {self.operator} in {str(self)}"
3888
- )
3889
- elif isinstance(self.right, Concept) and not is_compatible_datatype(
3890
- arg_to_datatype(self.left), arg_to_datatype(self.right)
3891
- ):
3892
- raise SyntaxError(
3893
- f"Cannot compare {arg_to_datatype(self.left)} and {arg_to_datatype(self.right)} with operator {self.operator} in {str(self)}"
3894
- )
3895
- else:
3896
- if not is_compatible_datatype(
3897
- arg_to_datatype(self.left), arg_to_datatype(self.right)
3898
- ):
3899
- raise SyntaxError(
3900
- f"Cannot compare {arg_to_datatype(self.left)} and {arg_to_datatype(self.right)} of different types with operator {self.operator} in {str(self)}"
3901
- )
3902
-
3903
- def __add__(self, other):
3904
- if other is None:
3905
- return self
3906
- if not isinstance(other, (Comparison, Conditional, Parenthetical)):
3907
- raise ValueError("Cannot add Comparison to non-Comparison")
3908
- if other == self:
3909
- return self
3910
- return Conditional(left=self, right=other, operator=BooleanOperator.AND)
3911
-
3912
- def __repr__(self):
3913
- return f"{str(self.left)} {self.operator.value} {str(self.right)}"
3914
-
3915
- def __str__(self):
3916
- return self.__repr__()
3917
-
3918
- def __eq__(self, other):
3919
- if not isinstance(other, Comparison):
3920
- return False
3921
- return (
3922
- self.left == other.left
3923
- and self.right == other.right
3924
- and self.operator == other.operator
3925
- )
3926
-
3927
- def inline_constant(self, constant: Concept):
3928
- assert isinstance(constant.lineage, Function)
3929
- new_val = constant.lineage.arguments[0]
3930
- if isinstance(self.left, ConstantInlineable):
3931
- new_left = self.left.inline_constant(constant)
3932
- elif self.left == constant:
3933
- new_left = new_val
3934
- else:
3935
- new_left = self.left
3936
-
3937
- if isinstance(self.right, ConstantInlineable):
3938
- new_right = self.right.inline_constant(constant)
3939
- elif self.right == constant:
3940
- new_right = new_val
3941
- else:
3942
- new_right = self.right
3943
-
3944
- return self.__class__(
3945
- left=new_left,
3946
- right=new_right,
3947
- operator=self.operator,
3948
- )
3949
-
3950
- def with_merge(self, source: Concept, target: Concept, modifiers: List[Modifier]):
3951
- return self.__class__(
3952
- left=(
3953
- self.left.with_merge(source, target, modifiers)
3954
- if isinstance(self.left, Mergeable)
3955
- else self.left
3956
- ),
3957
- right=(
3958
- self.right.with_merge(source, target, modifiers)
3959
- if isinstance(self.right, Mergeable)
3960
- else self.right
3961
- ),
3962
- operator=self.operator,
3963
- )
3964
-
3965
- def with_namespace(self, namespace: str):
3966
- return self.__class__(
3967
- left=(
3968
- self.left.with_namespace(namespace)
3969
- if isinstance(self.left, Namespaced)
3970
- else self.left
3971
- ),
3972
- right=(
3973
- self.right.with_namespace(namespace)
3974
- if isinstance(self.right, Namespaced)
3975
- else self.right
3976
- ),
3977
- operator=self.operator,
3978
- )
3979
-
3980
- def with_select_context(
3981
- self, local_concepts: dict[str, Concept], grain: Grain, environment: Environment
3982
- ):
3983
- return self.__class__(
3984
- left=(
3985
- self.left.with_select_context(local_concepts, grain, environment)
3986
- if isinstance(self.left, SelectContext)
3987
- else self.left
3988
- ),
3989
- # the right side does NOT need to inherit select grain
3990
- right=(
3991
- self.right.with_select_context(local_concepts, grain, environment)
3992
- if isinstance(self.right, SelectContext)
3993
- else self.right
3994
- ),
3995
- operator=self.operator,
3996
- )
3997
-
3998
- @property
3999
- def input(self) -> List[Concept]:
4000
- output: List[Concept] = []
4001
- if isinstance(self.left, (Concept,)):
4002
- output += [self.left]
4003
- if isinstance(
4004
- self.left, (Comparison, SubselectComparison, Conditional, Parenthetical)
4005
- ):
4006
- output += self.left.input
4007
- if isinstance(self.left, FilterItem):
4008
- output += self.left.concept_arguments
4009
- if isinstance(self.left, Function):
4010
- output += self.left.concept_arguments
4011
-
4012
- if isinstance(self.right, (Concept,)):
4013
- output += [self.right]
4014
- if isinstance(
4015
- self.right, (Comparison, SubselectComparison, Conditional, Parenthetical)
4016
- ):
4017
- output += self.right.input
4018
- if isinstance(self.right, FilterItem):
4019
- output += self.right.concept_arguments
4020
- if isinstance(self.right, Function):
4021
- output += self.right.concept_arguments
4022
- return output
4023
-
4024
- @property
4025
- def concept_arguments(self) -> List[Concept]:
4026
- """Return concepts directly referenced in where clause"""
4027
- output = []
4028
- output += get_concept_arguments(self.left)
4029
- output += get_concept_arguments(self.right)
4030
- return output
4031
-
4032
- @property
4033
- def row_arguments(self) -> List[Concept]:
4034
- output = []
4035
- if isinstance(self.left, ConceptArgs):
4036
- output += self.left.row_arguments
4037
- else:
4038
- output += get_concept_arguments(self.left)
4039
- if isinstance(self.right, ConceptArgs):
4040
- output += self.right.row_arguments
4041
- else:
4042
- output += get_concept_arguments(self.right)
4043
- return output
4044
-
4045
- @property
4046
- def existence_arguments(self) -> List[Tuple[Concept, ...]]:
4047
- """Return concepts directly referenced in where clause"""
4048
- output: List[Tuple[Concept, ...]] = []
4049
- if isinstance(self.left, ConceptArgs):
4050
- output += self.left.existence_arguments
4051
- if isinstance(self.right, ConceptArgs):
4052
- output += self.right.existence_arguments
4053
- return output
4054
-
4055
-
4056
- class SubselectComparison(Comparison):
4057
- def __eq__(self, other):
4058
- if not isinstance(other, SubselectComparison):
4059
- return False
4060
-
4061
- comp = (
4062
- self.left == other.left
4063
- and self.right == other.right
4064
- and self.operator == other.operator
4065
- )
4066
- return comp
4067
-
4068
- @property
4069
- def row_arguments(self) -> List[Concept]:
4070
- return get_concept_arguments(self.left)
4071
-
4072
- @property
4073
- def existence_arguments(self) -> list[tuple["Concept", ...]]:
4074
- return [tuple(get_concept_arguments(self.right))]
4075
-
4076
- def with_select_context(
4077
- self,
4078
- local_concepts: dict[str, Concept],
4079
- grain: Grain,
4080
- environment: Environment,
4081
- ):
4082
- # there's no need to pass the select grain through to a subselect comparison on the right
4083
- return self.__class__(
4084
- left=(
4085
- self.left.with_select_context(local_concepts, grain, environment)
4086
- if isinstance(self.left, SelectContext)
4087
- else self.left
4088
- ),
4089
- right=self.right,
4090
- operator=self.operator,
4091
- )
4092
-
4093
-
4094
- class CaseWhen(Namespaced, SelectContext, BaseModel):
4095
- comparison: Conditional | SubselectComparison | Comparison
4096
- expr: "Expr"
4097
-
4098
- def __str__(self):
4099
- return self.__repr__()
4100
-
4101
- def __repr__(self):
4102
- return f"WHEN {str(self.comparison)} THEN {str(self.expr)}"
4103
-
4104
- @property
4105
- def concept_arguments(self):
4106
- return get_concept_arguments(self.comparison) + get_concept_arguments(self.expr)
4107
-
4108
- def with_namespace(self, namespace: str) -> CaseWhen:
4109
- return CaseWhen(
4110
- comparison=self.comparison.with_namespace(namespace),
4111
- expr=(
4112
- self.expr.with_namespace(namespace)
4113
- if isinstance(
4114
- self.expr,
4115
- Namespaced,
4116
- )
4117
- else self.expr
4118
- ),
4119
- )
4120
-
4121
- def with_select_context(
4122
- self, local_concepts: dict[str, Concept], grain: Grain, environment: Environment
4123
- ) -> CaseWhen:
4124
- return CaseWhen(
4125
- comparison=self.comparison.with_select_context(
4126
- local_concepts, grain, environment
4127
- ),
4128
- expr=(
4129
- (self.expr.with_select_context(local_concepts, grain, environment))
4130
- if isinstance(self.expr, SelectContext)
4131
- else self.expr
4132
- ),
4133
- )
4134
-
4135
-
4136
- class CaseElse(Namespaced, SelectContext, BaseModel):
4137
- expr: "Expr"
4138
- # this ensures that it's easily differentiable from CaseWhen
4139
- discriminant: ComparisonOperator = ComparisonOperator.ELSE
4140
-
4141
- @property
4142
- def concept_arguments(self):
4143
- return get_concept_arguments(self.expr)
4144
-
4145
- def with_select_context(
4146
- self,
4147
- local_concepts: dict[str, Concept],
4148
- grain: Grain,
4149
- environment: Environment,
4150
- ):
4151
- return CaseElse(
4152
- discriminant=self.discriminant,
4153
- expr=(
4154
- self.expr.with_select_context(local_concepts, grain, environment)
4155
- if isinstance(
4156
- self.expr,
4157
- SelectContext,
4158
- )
4159
- else self.expr
4160
- ),
4161
- )
4162
-
4163
- def with_namespace(self, namespace: str) -> CaseElse:
4164
- return CaseElse(
4165
- discriminant=self.discriminant,
4166
- expr=(
4167
- self.expr.with_namespace(namespace)
4168
- if isinstance(
4169
- self.expr,
4170
- Namespaced,
4171
- )
4172
- else self.expr
4173
- ),
4174
- )
4175
-
4176
-
4177
- class Conditional(
4178
- Mergeable, ConceptArgs, Namespaced, ConstantInlineable, SelectContext, BaseModel
4179
- ):
4180
- left: Union[
4181
- int,
4182
- str,
4183
- float,
4184
- list,
4185
- bool,
4186
- MagicConstants,
4187
- Concept,
4188
- Comparison,
4189
- "Conditional",
4190
- "Parenthetical",
4191
- Function,
4192
- FilterItem,
4193
- ]
4194
- right: Union[
4195
- int,
4196
- str,
4197
- float,
4198
- list,
4199
- bool,
4200
- MagicConstants,
4201
- Concept,
4202
- Comparison,
4203
- "Conditional",
4204
- "Parenthetical",
4205
- Function,
4206
- FilterItem,
4207
- ]
4208
- operator: BooleanOperator
4209
-
4210
- def __add__(self, other) -> "Conditional":
4211
- if other is None:
4212
- return self
4213
- elif str(other) == str(self):
4214
- return self
4215
- elif isinstance(other, (Comparison, Conditional, Parenthetical)):
4216
- return Conditional(left=self, right=other, operator=BooleanOperator.AND)
4217
- raise ValueError(f"Cannot add {self.__class__} and {type(other)}")
4218
-
4219
- def __str__(self):
4220
- return self.__repr__()
4221
-
4222
- def __repr__(self):
4223
- return f"{str(self.left)} {self.operator.value} {str(self.right)}"
4224
-
4225
- def __eq__(self, other):
4226
- if not isinstance(other, Conditional):
4227
- return False
4228
- return (
4229
- self.left == other.left
4230
- and self.right == other.right
4231
- and self.operator == other.operator
4232
- )
4233
-
4234
- def inline_constant(self, constant: Concept) -> "Conditional":
4235
- assert isinstance(constant.lineage, Function)
4236
- new_val = constant.lineage.arguments[0]
4237
- if isinstance(self.left, ConstantInlineable):
4238
- new_left = self.left.inline_constant(constant)
4239
- elif self.left == constant:
4240
- new_left = new_val
4241
- else:
4242
- new_left = self.left
4243
-
4244
- if isinstance(self.right, ConstantInlineable):
4245
- new_right = self.right.inline_constant(constant)
4246
- elif self.right == constant:
4247
- new_right = new_val
4248
- else:
4249
- new_right = self.right
4250
-
4251
- if self.right == constant:
4252
- new_right = new_val
4253
-
4254
- return Conditional(
4255
- left=new_left,
4256
- right=new_right,
4257
- operator=self.operator,
4258
- )
4259
-
4260
- def with_namespace(self, namespace: str) -> "Conditional":
4261
- return Conditional(
4262
- left=(
4263
- self.left.with_namespace(namespace)
4264
- if isinstance(self.left, Namespaced)
4265
- else self.left
4266
- ),
4267
- right=(
4268
- self.right.with_namespace(namespace)
4269
- if isinstance(self.right, Namespaced)
4270
- else self.right
4271
- ),
4272
- operator=self.operator,
4273
- )
4274
-
4275
- def with_merge(
4276
- self, source: Concept, target: Concept, modifiers: List[Modifier]
4277
- ) -> "Conditional":
4278
- return Conditional(
4279
- left=(
4280
- self.left.with_merge(source, target, modifiers)
4281
- if isinstance(self.left, Mergeable)
4282
- else self.left
4283
- ),
4284
- right=(
4285
- self.right.with_merge(source, target, modifiers)
4286
- if isinstance(self.right, Mergeable)
4287
- else self.right
4288
- ),
4289
- operator=self.operator,
4290
- )
4291
-
4292
- def with_select_context(
4293
- self, local_concepts: dict[str, Concept], grain: Grain, environment: Environment
4294
- ):
4295
- return Conditional(
4296
- left=(
4297
- self.left.with_select_context(local_concepts, grain, environment)
4298
- if isinstance(self.left, SelectContext)
4299
- else self.left
4300
- ),
4301
- right=(
4302
- self.right.with_select_context(local_concepts, grain, environment)
4303
- if isinstance(self.right, SelectContext)
4304
- else self.right
4305
- ),
4306
- operator=self.operator,
4307
- )
4308
-
4309
- @property
4310
- def input(self) -> List[Concept]:
4311
- """Return concepts directly referenced in where clause"""
4312
- output = []
4313
-
4314
- for x in (self.left, self.right):
4315
- if isinstance(x, Concept):
4316
- output += x.input
4317
- elif isinstance(x, (Comparison, Conditional)):
4318
- output += x.input
4319
- elif isinstance(x, (Function, Parenthetical, FilterItem)):
4320
- output += x.concept_arguments
4321
- return output
4322
-
4323
- @property
4324
- def concept_arguments(self) -> List[Concept]:
4325
- """Return concepts directly referenced in where clause"""
4326
- output = []
4327
- output += get_concept_arguments(self.left)
4328
- output += get_concept_arguments(self.right)
4329
- return output
4330
-
4331
- @property
4332
- def row_arguments(self) -> List[Concept]:
4333
- output = []
4334
- if isinstance(self.left, ConceptArgs):
4335
- output += self.left.row_arguments
4336
- else:
4337
- output += get_concept_arguments(self.left)
4338
- if isinstance(self.right, ConceptArgs):
4339
- output += self.right.row_arguments
4340
- else:
4341
- output += get_concept_arguments(self.right)
4342
- return output
4343
-
4344
- @property
4345
- def existence_arguments(self) -> list[tuple["Concept", ...]]:
4346
- output = []
4347
- if isinstance(self.left, ConceptArgs):
4348
- output += self.left.existence_arguments
4349
- if isinstance(self.right, ConceptArgs):
4350
- output += self.right.existence_arguments
4351
- return output
4352
-
4353
- def decompose(self):
4354
- chunks = []
4355
- if self.operator == BooleanOperator.AND:
4356
- for val in [self.left, self.right]:
4357
- if isinstance(val, Conditional):
4358
- chunks.extend(val.decompose())
4359
- else:
4360
- chunks.append(val)
4361
- else:
4362
- chunks.append(self)
4363
- return chunks
4364
-
4365
-
4366
- class AggregateWrapper(Mergeable, Namespaced, SelectContext, BaseModel):
4367
- function: Function
4368
- by: List[Concept] = Field(default_factory=list)
4369
-
4370
- def __str__(self):
4371
- grain_str = [str(c) for c in self.by] if self.by else "abstract"
4372
- return f"{str(self.function)}<{grain_str}>"
4373
-
4374
- @property
4375
- def datatype(self):
4376
- return self.function.datatype
4377
-
4378
- @property
4379
- def concept_arguments(self) -> List[Concept]:
4380
- return self.function.concept_arguments + self.by
4381
-
4382
- @property
4383
- def output_datatype(self):
4384
- return self.function.output_datatype
4385
-
4386
- @property
4387
- def output_purpose(self):
4388
- return self.function.output_purpose
4389
-
4390
- @property
4391
- def arguments(self):
4392
- return self.function.arguments
4393
-
4394
- def with_merge(self, source: Concept, target: Concept, modifiers: List[Modifier]):
4395
- return AggregateWrapper(
4396
- function=self.function.with_merge(source, target, modifiers=modifiers),
4397
- by=(
4398
- [c.with_merge(source, target, modifiers) for c in self.by]
4399
- if self.by
4400
- else []
4401
- ),
4402
- )
4403
-
4404
- def with_namespace(self, namespace: str) -> "AggregateWrapper":
4405
- return AggregateWrapper(
4406
- function=self.function.with_namespace(namespace),
4407
- by=[c.with_namespace(namespace) for c in self.by] if self.by else [],
4408
- )
4409
-
4410
- def with_select_context(
4411
- self, local_concepts: dict[str, Concept], grain: Grain, environment: Environment
4412
- ) -> AggregateWrapper:
4413
- if not self.by:
4414
- by = [environment.concepts[c] for c in grain.components]
4415
- else:
4416
- by = [
4417
- x.with_select_context(local_concepts, grain, environment)
4418
- for x in self.by
4419
- ]
4420
- parent = self.function.with_select_context(local_concepts, grain, environment)
4421
- return AggregateWrapper(function=parent, by=by)
4422
-
4423
-
4424
- class WhereClause(Mergeable, ConceptArgs, Namespaced, SelectContext, BaseModel):
4425
- conditional: Union[SubselectComparison, Comparison, Conditional, "Parenthetical"]
4426
-
4427
- def __repr__(self):
4428
- return str(self.conditional)
4429
-
4430
- @property
4431
- def input(self) -> List[Concept]:
4432
- return self.conditional.input
4433
-
4434
- @property
4435
- def concept_arguments(self) -> List[Concept]:
4436
- return self.conditional.concept_arguments
4437
-
4438
- @property
4439
- def row_arguments(self) -> List[Concept]:
4440
- return self.conditional.row_arguments
4441
-
4442
- @property
4443
- def existence_arguments(self) -> list[tuple["Concept", ...]]:
4444
- return self.conditional.existence_arguments
4445
-
4446
- def with_merge(self, source: Concept, target: Concept, modifiers: List[Modifier]):
4447
- return WhereClause(
4448
- conditional=self.conditional.with_merge(source, target, modifiers)
4449
- )
4450
-
4451
- def with_namespace(self, namespace: str) -> WhereClause:
4452
- return WhereClause(conditional=self.conditional.with_namespace(namespace))
4453
-
4454
- def with_select_context(
4455
- self, local_concepts: dict[str, Concept], grain: Grain, environment: Environment
4456
- ) -> WhereClause:
4457
- return self.__class__(
4458
- conditional=self.conditional.with_select_context(
4459
- local_concepts, grain, environment
4460
- )
4461
- )
4462
-
4463
- @property
4464
- def components(self):
4465
- from trilogy.core.processing.utility import decompose_condition
4466
-
4467
- return decompose_condition(self.conditional)
4468
-
4469
- @property
4470
- def is_scalar(self):
4471
- from trilogy.core.processing.utility import is_scalar_condition
4472
-
4473
- return is_scalar_condition(self.conditional)
4474
-
4475
-
4476
- class HavingClause(WhereClause):
4477
- pass
4478
-
4479
- def hydrate_missing(self, concepts: EnvironmentConceptDict):
4480
- self.conditional.hydrate_missing(concepts)
4481
-
4482
- def with_select_context(
4483
- self, local_concepts: dict[str, Concept], grain: Grain, environment: Environment
4484
- ) -> HavingClause:
4485
- return HavingClause(
4486
- conditional=self.conditional.with_select_context(
4487
- local_concepts, grain, environment
4488
- )
4489
- )
4490
-
4491
-
4492
- class MaterializedDataset(BaseModel):
4493
- address: Address
4494
-
4495
-
4496
- # TODO: combine with CTEs
4497
- # CTE contains procesed query?
4498
- # or CTE references CTE?
4499
-
4500
-
4501
- class ProcessedQuery(BaseModel):
4502
- output_columns: List[Concept]
4503
- ctes: List[CTE | UnionCTE]
4504
- base: CTE | UnionCTE
4505
- joins: List[Join]
4506
- grain: Grain
4507
- hidden_columns: set[str] = Field(default_factory=set)
4508
- limit: Optional[int] = None
4509
- where_clause: Optional[WhereClause] = None
4510
- having_clause: Optional[HavingClause] = None
4511
- order_by: Optional[OrderBy] = None
4512
- local_concepts: Annotated[
4513
- EnvironmentConceptDict, PlainValidator(validate_concepts)
4514
- ] = Field(default_factory=EnvironmentConceptDict)
4515
-
4516
-
4517
- class PersistQueryMixin(BaseModel):
4518
- output_to: MaterializedDataset
4519
- datasource: Datasource
4520
- # base:Dataset
4521
-
4522
-
4523
- class ProcessedQueryPersist(ProcessedQuery, PersistQueryMixin):
4524
- pass
4525
-
4526
-
4527
- class CopyQueryMixin(BaseModel):
4528
- target: str
4529
- target_type: IOType
4530
- # base:Dataset
4531
-
4532
-
4533
- class ProcessedCopyStatement(ProcessedQuery, CopyQueryMixin):
4534
- pass
4535
-
4536
-
4537
- class ProcessedShowStatement(BaseModel):
4538
- output_columns: List[Concept]
4539
- output_values: List[Union[Concept, Datasource, ProcessedQuery]]
4540
-
4541
-
4542
- class ProcessedRawSQLStatement(BaseModel):
4543
- text: str
4544
-
4545
-
4546
- class Limit(BaseModel):
4547
- count: int
4548
-
4549
-
4550
- class ConceptDeclarationStatement(HasUUID, BaseModel):
4551
- concept: Concept
4552
-
4553
-
4554
- class ConceptDerivation(BaseModel):
4555
- concept: Concept
4556
-
4557
-
4558
- class RowsetDerivationStatement(HasUUID, Namespaced, BaseModel):
4559
- name: str
4560
- select: SelectStatement | MultiSelectStatement
4561
- namespace: str
4562
-
4563
- def __repr__(self):
4564
- return f"RowsetDerivation<{str(self.select)}>"
4565
-
4566
- def __str__(self):
4567
- return self.__repr__()
4568
-
4569
- @property
4570
- def derived_concepts(self) -> List[Concept]:
4571
- output: list[Concept] = []
4572
- orig: dict[str, Concept] = {}
4573
- for orig_concept in self.select.output_components:
4574
- name = orig_concept.name
4575
- if isinstance(orig_concept.lineage, FilterItem):
4576
- if orig_concept.lineage.where == self.select.where_clause:
4577
- name = orig_concept.lineage.content.name
4578
-
4579
- new_concept = Concept(
4580
- name=name,
4581
- datatype=orig_concept.datatype,
4582
- purpose=orig_concept.purpose,
4583
- lineage=RowsetItem(
4584
- content=orig_concept, where=self.select.where_clause, rowset=self
4585
- ),
4586
- grain=orig_concept.grain,
4587
- # TODO: add proper metadata
4588
- metadata=Metadata(concept_source=ConceptSource.CTE),
4589
- namespace=(
4590
- f"{self.name}.{orig_concept.namespace}"
4591
- if orig_concept.namespace != self.namespace
4592
- else self.name
4593
- ),
4594
- keys=orig_concept.keys,
4595
- )
4596
- orig[orig_concept.address] = new_concept
4597
- output.append(new_concept)
4598
- default_grain = Grain.from_concepts([*output])
4599
- # remap everything to the properties of the rowset
4600
- for x in output:
4601
- if x.keys:
4602
- if all([k in orig for k in x.keys]):
4603
- x.keys = set([orig[k].address if k in orig else k for k in x.keys])
4604
- else:
4605
- # TODO: fix this up
4606
- x.keys = set()
4607
- for x in output:
4608
- if all([c in orig for c in x.grain.components]):
4609
- x.grain = Grain(
4610
- components={orig[c].address for c in x.grain.components}
4611
- )
4612
- else:
4613
- x.grain = default_grain
4614
- return output
4615
-
4616
- @property
4617
- def arguments(self) -> List[Concept]:
4618
- return self.select.output_components
4619
-
4620
- def with_namespace(self, namespace: str) -> "RowsetDerivationStatement":
4621
- return RowsetDerivationStatement(
4622
- name=self.name,
4623
- select=self.select.with_namespace(namespace),
4624
- namespace=namespace,
4625
- )
4626
-
4627
-
4628
- class RowsetItem(Mergeable, Namespaced, BaseModel):
4629
- content: Concept
4630
- rowset: RowsetDerivationStatement
4631
- where: Optional["WhereClause"] = None
4632
-
4633
- def __repr__(self):
4634
- return (
4635
- f"<Rowset<{self.rowset.name}>: {str(self.content)} where {str(self.where)}>"
4636
- )
4637
-
4638
- def __str__(self):
4639
- return self.__repr__()
4640
-
4641
- def with_merge(self, source: Concept, target: Concept, modifiers: List[Modifier]):
4642
- return RowsetItem(
4643
- content=self.content.with_merge(source, target, modifiers),
4644
- rowset=self.rowset,
4645
- where=(
4646
- self.where.with_merge(source, target, modifiers) if self.where else None
4647
- ),
4648
- )
4649
-
4650
- def with_namespace(self, namespace: str) -> "RowsetItem":
4651
- return RowsetItem(
4652
- content=self.content.with_namespace(namespace),
4653
- where=self.where.with_namespace(namespace) if self.where else None,
4654
- rowset=self.rowset.with_namespace(namespace),
4655
- )
4656
-
4657
- @property
4658
- def arguments(self) -> List[Concept]:
4659
- output = [self.content]
4660
- if self.where:
4661
- output += self.where.input
4662
- return output
4663
-
4664
- @property
4665
- def output(self) -> Concept:
4666
- if isinstance(self.content, ConceptTransform):
4667
- return self.content.output
4668
- return self.content
4669
-
4670
- @output.setter
4671
- def output(self, value):
4672
- if isinstance(self.content, ConceptTransform):
4673
- self.content.output = value
4674
- else:
4675
- self.content = value
4676
-
4677
- @property
4678
- def input(self) -> List[Concept]:
4679
- base = self.content.input
4680
- if self.where:
4681
- base += self.where.input
4682
- return base
4683
-
4684
- @property
4685
- def output_datatype(self):
4686
- return self.content.datatype
4687
-
4688
- @property
4689
- def output_purpose(self):
4690
- return self.content.purpose
4691
-
4692
- @property
4693
- def concept_arguments(self):
4694
- if self.where:
4695
- return [self.content] + self.where.concept_arguments
4696
- return [self.content]
4697
-
4698
-
4699
- class Parenthetical(
4700
- ConceptArgs, Mergeable, Namespaced, ConstantInlineable, SelectContext, BaseModel
4701
- ):
4702
- content: "Expr"
4703
-
4704
- def __str__(self):
4705
- return self.__repr__()
4706
-
4707
- def __add__(self, other) -> Union["Parenthetical", "Conditional"]:
4708
- if other is None:
4709
- return self
4710
- elif isinstance(other, (Comparison, Conditional, Parenthetical)):
4711
- return Conditional(left=self, right=other, operator=BooleanOperator.AND)
4712
- raise ValueError(f"Cannot add {self.__class__} and {type(other)}")
4713
-
4714
- def __repr__(self):
4715
- return f"({str(self.content)})"
4716
-
4717
- def with_namespace(self, namespace: str):
4718
- return Parenthetical(
4719
- content=(
4720
- self.content.with_namespace(namespace)
4721
- if isinstance(self.content, Namespaced)
4722
- else self.content
4723
- )
4724
- )
4725
-
4726
- def with_merge(self, source: Concept, target: Concept, modifiers: List[Modifier]):
4727
- return Parenthetical(
4728
- content=(
4729
- self.content.with_merge(source, target, modifiers)
4730
- if isinstance(self.content, Mergeable)
4731
- else self.content
4732
- )
4733
- )
4734
-
4735
- def with_select_context(
4736
- self, local_concepts: dict[str, Concept], grain: Grain, environment: Environment
4737
- ):
4738
- return Parenthetical(
4739
- content=(
4740
- self.content.with_select_context(local_concepts, grain, environment)
4741
- if isinstance(self.content, SelectContext)
4742
- else self.content
4743
- )
4744
- )
4745
-
4746
- def inline_constant(self, concept: Concept):
4747
- return Parenthetical(
4748
- content=(
4749
- self.content.inline_constant(concept)
4750
- if isinstance(self.content, ConstantInlineable)
4751
- else self.content
4752
- )
4753
- )
4754
-
4755
- @property
4756
- def concept_arguments(self) -> List[Concept]:
4757
- base: List[Concept] = []
4758
- x = self.content
4759
- if hasattr(x, "concept_arguments"):
4760
- base += x.concept_arguments
4761
- elif isinstance(x, Concept):
4762
- base.append(x)
4763
- return base
4764
-
4765
- @property
4766
- def row_arguments(self) -> List[Concept]:
4767
- if isinstance(self.content, ConceptArgs):
4768
- return self.content.row_arguments
4769
- return self.concept_arguments
4770
-
4771
- @property
4772
- def existence_arguments(self) -> list[tuple["Concept", ...]]:
4773
- if isinstance(self.content, ConceptArgs):
4774
- return self.content.existence_arguments
4775
- return []
4776
-
4777
- @property
4778
- def input(self):
4779
- base = []
4780
- x = self.content
4781
- if hasattr(x, "input"):
4782
- base += x.input
4783
- return base
4784
-
4785
-
4786
- class TupleWrapper(Generic[VT], tuple):
4787
- """Used to distinguish parsed tuple objects from other tuples"""
4788
-
4789
- def __init__(self, val, type: DataType, **kwargs):
4790
- super().__init__()
4791
- self.type = type
4792
- self.val = val
4793
-
4794
- def __getnewargs__(self):
4795
- return (self.val, self.type)
4796
-
4797
- def __new__(cls, val, type: DataType, **kwargs):
4798
- return super().__new__(cls, tuple(val))
4799
- # self.type = type
4800
-
4801
- @classmethod
4802
- def __get_pydantic_core_schema__(
4803
- cls, source_type: Any, handler: Callable[[Any], core_schema.CoreSchema]
4804
- ) -> core_schema.CoreSchema:
4805
- args = get_args(source_type)
4806
- if args:
4807
- schema = handler(Tuple[args]) # type: ignore
4808
- else:
4809
- schema = handler(Tuple)
4810
- return core_schema.no_info_after_validator_function(cls.validate, schema)
4811
-
4812
- @classmethod
4813
- def validate(cls, v):
4814
- return cls(v, type=arg_to_datatype(v[0]))
4815
-
4816
-
4817
- class PersistStatement(HasUUID, BaseModel):
4818
- datasource: Datasource
4819
- select: SelectStatement
4820
- meta: Optional[Metadata] = Field(default_factory=lambda: Metadata())
4821
-
4822
- @property
4823
- def identifier(self):
4824
- return self.datasource.identifier
4825
-
4826
- @property
4827
- def address(self):
4828
- return self.datasource.address
4829
-
4830
-
4831
- class ShowStatement(BaseModel):
4832
- content: SelectStatement | PersistStatement | ShowCategory
4833
-
4834
-
4835
- Expr = (
4836
- bool
4837
- | MagicConstants
4838
- | int
4839
- | str
4840
- | float
4841
- | list
4842
- | WindowItem
4843
- | FilterItem
4844
- | Concept
4845
- | Comparison
4846
- | Conditional
4847
- | Parenthetical
4848
- | Function
4849
- | AggregateWrapper
4850
- )
4851
-
4852
-
4853
- Concept.model_rebuild()
4854
- Grain.model_rebuild()
4855
- WindowItem.model_rebuild()
4856
- WindowItemOrder.model_rebuild()
4857
- FilterItem.model_rebuild()
4858
- Comparison.model_rebuild()
4859
- Conditional.model_rebuild()
4860
- Parenthetical.model_rebuild()
4861
- WhereClause.model_rebuild()
4862
- ImportStatement.model_rebuild()
4863
- CaseWhen.model_rebuild()
4864
- CaseElse.model_rebuild()
4865
- SelectStatement.model_rebuild()
4866
- CTE.model_rebuild()
4867
- BaseJoin.model_rebuild()
4868
- QueryDatasource.model_rebuild()
4869
- ProcessedQuery.model_rebuild()
4870
- ProcessedQueryPersist.model_rebuild()
4871
- InstantiatedUnnestJoin.model_rebuild()
4872
- UndefinedConcept.model_rebuild()
4873
- Function.model_rebuild()
4874
- Grain.model_rebuild()
4875
-
4876
-
4877
- def list_to_wrapper(args):
4878
- types = [arg_to_datatype(arg) for arg in args]
4879
- assert len(set(types)) == 1
4880
- return ListWrapper(args, type=types[0])
4881
-
4882
-
4883
- def tuple_to_wrapper(args):
4884
- types = [arg_to_datatype(arg) for arg in args]
4885
- assert len(set(types)) == 1
4886
- return TupleWrapper(args, type=types[0])
4887
-
4888
-
4889
def dict_to_map_wrapper(arg):
    """Wrap a dict in a MapWrapper tagged with its key and value datatypes.

    Key and value datatypes are resolved with ``arg_to_datatype``. All keys
    must share one datatype and all values must share one datatype.

    Raises:
        AssertionError: if the dict is empty, or if its keys or its values
            resolve to more than one datatype.
    """
    key_types = [arg_to_datatype(key) for key in arg.keys()]
    value_types = [arg_to_datatype(value) for value in arg.values()]
    assert len(set(key_types)) == 1, "map keys must share a single datatype"
    # Values are validated independently of keys; checking only the keys
    # would let mixed-type values through, labelled with the first value's type.
    assert len(set(value_types)) == 1, "map values must share a single datatype"
    return MapWrapper(arg, key_type=key_types[0], value_type=value_types[0])
4896
-
4897
-
4898
def merge_datatypes(
    inputs: list[DataType | ListType | StructType | MapType | NumericType],
) -> DataType | ListType | StructType | MapType | NumericType:
    """Pick one datatype to represent a list of input datatypes.

    Stop-gap in place of a full matrix of allowed datatype conversions.
    Rules, checked in order:

    1. A single input is returned unchanged.
    2. Exactly {INTEGER, FLOAT} merges to FLOAT.
    3. Exactly {INTEGER, NUMERIC} merges to NUMERIC.
    4. If at least one input is a NumericType instance and every input is
       either a NumericType instance or one of INTEGER / FLOAT / NUMERIC,
       the first NumericType instance is returned.
    5. Otherwise the first input is returned.
    """
    if len(inputs) == 1:
        return inputs[0]

    distinct = set(inputs)
    if distinct == {DataType.INTEGER, DataType.FLOAT}:
        return DataType.FLOAT
    if distinct == {DataType.INTEGER, DataType.NUMERIC}:
        return DataType.NUMERIC

    plain_numeric = (DataType.INTEGER, DataType.FLOAT, DataType.NUMERIC)
    parameterized = [item for item in inputs if isinstance(item, NumericType)]
    if parameterized and all(
        isinstance(item, NumericType) or item in plain_numeric for item in inputs
    ):
        return parameterized[0]

    # No rule applied: fall back to the first input.
    return inputs[0]
4917
-
4918
-
4919
def arg_to_datatype(arg) -> DataType | ListType | StructType | MapType | NumericType:
    """Resolve the datatype of an expression argument.

    Model objects report their own datatype (or that of the content they
    wrap), Python literals map onto the matching DataType member, and
    list/tuple/map wrappers yield ListType / MapType descriptors.

    The checks run in a fixed order that must be preserved: ``bool`` is
    tested before ``int`` and ``datetime`` before ``date`` because, in
    Python, the former is a subclass of the latter in each pair.

    Raises:
        ValueError: for a MagicConstants value other than NULL, or for any
            argument whose type is not handled below.
    """
    if isinstance(arg, Function):
        return arg.output_datatype
    if isinstance(arg, MagicConstants):
        if arg == MagicConstants.NULL:
            return DataType.NULL
        raise ValueError(f"Cannot parse arg datatype for arg of type {arg}")
    if isinstance(arg, Concept):
        return arg.datatype

    # Python literals.
    if isinstance(arg, bool):
        return DataType.BOOL
    if isinstance(arg, int):
        return DataType.INTEGER
    if isinstance(arg, str):
        return DataType.STRING
    if isinstance(arg, float):
        return DataType.FLOAT

    # A NumericType instance is already a datatype; pass it through.
    if isinstance(arg, NumericType):
        return arg
    if isinstance(arg, ListWrapper):
        return ListType(type=arg.type)
    if isinstance(arg, AggregateWrapper):
        return arg.function.output_datatype
    if isinstance(arg, Parenthetical):
        # Resolve through the parenthesised content.
        return arg_to_datatype(arg.content)
    if isinstance(arg, TupleWrapper):
        return ListType(type=arg.type)
    if isinstance(arg, WindowItem):
        # RANK and ROW_NUMBER windows resolve to INTEGER; any other window
        # type takes the datatype of its content.
        if arg.type in (WindowType.RANK, WindowType.ROW_NUMBER):
            return DataType.INTEGER
        return arg_to_datatype(arg.content)
    if isinstance(arg, list):
        # A raw list goes through list_to_wrapper to obtain its element type.
        wrapped = list_to_wrapper(arg)
        return ListType(type=wrapped.type)
    if isinstance(arg, MapWrapper):
        return MapType(key_type=arg.key_type, value_type=arg.value_type)
    if isinstance(arg, datetime):
        return DataType.DATETIME
    if isinstance(arg, date):
        return DataType.DATE

    raise ValueError(f"Cannot parse arg datatype for arg of raw type {type(arg)}")