lbug 0.12.3-dev.14 → 0.12.3-dev.15

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. package/lbug-source/CMakeLists.txt +1 -1
  2. package/lbug-source/src/include/optimizer/count_rel_table_optimizer.h +49 -0
  3. package/lbug-source/src/include/optimizer/logical_operator_visitor.h +6 -0
  4. package/lbug-source/src/include/planner/operator/logical_operator.h +1 -0
  5. package/lbug-source/src/include/planner/operator/scan/logical_count_rel_table.h +84 -0
  6. package/lbug-source/src/include/processor/operator/physical_operator.h +1 -0
  7. package/lbug-source/src/include/processor/operator/scan/count_rel_table.h +62 -0
  8. package/lbug-source/src/include/processor/plan_mapper.h +2 -0
  9. package/lbug-source/src/optimizer/CMakeLists.txt +1 -0
  10. package/lbug-source/src/optimizer/count_rel_table_optimizer.cpp +217 -0
  11. package/lbug-source/src/optimizer/logical_operator_visitor.cpp +6 -0
  12. package/lbug-source/src/optimizer/optimizer.cpp +6 -0
  13. package/lbug-source/src/planner/operator/logical_operator.cpp +2 -0
  14. package/lbug-source/src/planner/operator/scan/CMakeLists.txt +1 -0
  15. package/lbug-source/src/planner/operator/scan/logical_count_rel_table.cpp +24 -0
  16. package/lbug-source/src/processor/map/CMakeLists.txt +1 -0
  17. package/lbug-source/src/processor/map/map_count_rel_table.cpp +55 -0
  18. package/lbug-source/src/processor/map/plan_mapper.cpp +3 -0
  19. package/lbug-source/src/processor/operator/physical_operator.cpp +2 -0
  20. package/lbug-source/src/processor/operator/scan/CMakeLists.txt +1 -0
  21. package/lbug-source/src/processor/operator/scan/count_rel_table.cpp +137 -0
  22. package/lbug-source/test/optimizer/optimizer_test.cpp +46 -0
  23. package/lbug-source/tools/benchmark/count_rel_table.benchmark +5 -0
  24. package/lbug-source/tools/shell/embedded_shell.cpp +11 -0
  25. package/lbug-source/tools/shell/linenoise.cpp +3 -3
  26. package/lbug-source/tools/shell/test/test_shell_basics.py +12 -0
  27. package/package.json +1 -1
  28. package/prebuilt/lbugjs-darwin-arm64.node +0 -0
  29. package/prebuilt/lbugjs-linux-arm64.node +0 -0
  30. package/prebuilt/lbugjs-linux-x64.node +0 -0
  31. package/prebuilt/lbugjs-win32-x64.node +0 -0
@@ -1,6 +1,6 @@
1
1
  cmake_minimum_required(VERSION 3.15)
2
2
 
3
- project(Lbug VERSION 0.12.3.14 LANGUAGES CXX C)
3
+ project(Lbug VERSION 0.12.3.15 LANGUAGES CXX C)
4
4
 
5
5
  option(SINGLE_THREADED "Single-threaded mode" FALSE)
6
6
  if(SINGLE_THREADED)
@@ -0,0 +1,49 @@
1
+ #pragma once
2
+
3
+ #include "logical_operator_visitor.h"
4
+ #include "planner/operator/logical_plan.h"
5
+
6
+ namespace lbug {
7
+ namespace main {
8
+ class ClientContext;
9
+ }
10
+
11
+ namespace optimizer {
12
+
13
+ /**
14
+ * This optimizer detects patterns where we're counting all rows from a single rel table
15
+ * without any filters, and replaces the scan + aggregate with one COUNT_REL_TABLE operator.
16
+ *
17
+ * Pattern detected:
18
+ * AGGREGATE (COUNT_STAR only, not DISTINCT, no keys) →
19
+ * PROJECTION (zero or more; each empty or projecting only aggregate expressions) →
20
+ * EXTEND (single-labeled rel, direction other than BOTH, no properties) →
21
+ * SCAN_NODE_TABLE (no properties)
22
+ *
23
+ * This pattern is replaced with:
24
+ * COUNT_REL_TABLE (new childless operator whose only output is the count expression)
25
+ */
26
+ class CountRelTableOptimizer : public LogicalOperatorVisitor {
27
+ public:
28
+ explicit CountRelTableOptimizer(main::ClientContext* context) : context{context} {}
29
+
30
+ void rewrite(planner::LogicalPlan* plan); // Entry point: rewrites the plan starting at its last operator.
31
+
32
+ private:
33
+ std::shared_ptr<planner::LogicalOperator> visitOperator( // Bottom-up traversal; returns the (possibly replaced) operator.
34
+ const std::shared_ptr<planner::LogicalOperator>& op);
35
+
36
+ std::shared_ptr<planner::LogicalOperator> visitAggregateReplace( // Returns a COUNT_REL_TABLE operator when the pattern matches, otherwise op itself.
37
+ std::shared_ptr<planner::LogicalOperator> op) override;
38
+
39
+ // Check if the aggregate is a simple, non-DISTINCT COUNT(*) with no keys
40
+ bool isSimpleCountStar(planner::LogicalOperator* op) const;
41
+
42
+ // Check if the plan below aggregate matches the pattern for optimization
43
+ bool canOptimize(planner::LogicalOperator* aggregate) const;
44
+
45
+ main::ClientContext* context; // Non-owning pointer supplied by the caller of the constructor.
46
+ };
47
+
48
+ } // namespace optimizer
49
+ } // namespace lbug
@@ -39,6 +39,12 @@ protected:
39
39
  return op;
40
40
  }
41
41
 
42
+ virtual void visitCountRelTable(planner::LogicalOperator* /*op*/) {}
43
+ virtual std::shared_ptr<planner::LogicalOperator> visitCountRelTableReplace(
44
+ std::shared_ptr<planner::LogicalOperator> op) {
45
+ return op;
46
+ }
47
+
42
48
  virtual void visitDelete(planner::LogicalOperator* /*op*/) {}
43
49
  virtual std::shared_ptr<planner::LogicalOperator> visitDeleteReplace(
44
50
  std::shared_ptr<planner::LogicalOperator> op) {
@@ -17,6 +17,7 @@ enum class LogicalOperatorType : uint8_t {
17
17
  ATTACH_DATABASE,
18
18
  COPY_FROM,
19
19
  COPY_TO,
20
+ COUNT_REL_TABLE,
20
21
  CREATE_MACRO,
21
22
  CREATE_SEQUENCE,
22
23
  CREATE_TABLE,
@@ -0,0 +1,84 @@
1
+ #pragma once
2
+
3
+ #include "binder/expression/expression.h"
4
+ #include "binder/expression/node_expression.h"
5
+ #include "catalog/catalog_entry/rel_group_catalog_entry.h"
6
+ #include "common/enums/extend_direction.h"
7
+ #include "planner/operator/logical_operator.h"
8
+
9
+ namespace lbug {
10
+ namespace planner {
11
+
12
+ struct LogicalCountRelTablePrintInfo final : OPPrintInfo { // Print info for LogicalCountRelTable: table name plus the count expression.
13
+ std::string relTableName; // Name shown after "Table: " in toString().
14
+ std::shared_ptr<binder::Expression> countExpr; // Expression whose toString() is shown after "Count: ".
15
+
16
+ LogicalCountRelTablePrintInfo(std::string relTableName,
17
+ std::shared_ptr<binder::Expression> countExpr)
18
+ : relTableName{std::move(relTableName)}, countExpr{std::move(countExpr)} {}
19
+
20
+ std::string toString() const override { // Dereferences countExpr, so it must be non-null.
21
+ return "Table: " + relTableName + ", Count: " + countExpr->toString();
22
+ }
23
+
24
+ std::unique_ptr<OPPrintInfo> copy() const override { // The copy shares the same countExpr (shared_ptr copy).
25
+ return std::make_unique<LogicalCountRelTablePrintInfo>(relTableName, countExpr);
26
+ }
27
+ };
28
+
29
+ /**
30
+ * LogicalCountRelTable is a logical operator that stands for "the number of rels in a
31
+ * rel group for one extend direction". Its cardinality is fixed to 1 (a single output row).
32
+ *
33
+ * This operator is created by CountRelTableOptimizer when it detects:
34
+ * COUNT(*) over a single rel table with no filters
35
+ */
36
+ class LogicalCountRelTable final : public LogicalOperator {
37
+ static constexpr LogicalOperatorType type_ = LogicalOperatorType::COUNT_REL_TABLE;
38
+
39
+ public:
40
+ LogicalCountRelTable(catalog::RelGroupCatalogEntry* relGroupEntry, // stored as a raw, non-owning pointer
41
+ std::vector<common::table_id_t> relTableIDs,
42
+ std::vector<common::table_id_t> boundNodeTableIDs,
43
+ std::shared_ptr<binder::NodeExpression> boundNode, common::ExtendDirection direction,
44
+ std::shared_ptr<binder::Expression> countExpr)
45
+ : LogicalOperator{type_}, relGroupEntry{relGroupEntry}, relTableIDs{std::move(relTableIDs)},
46
+ boundNodeTableIDs{std::move(boundNodeTableIDs)}, boundNode{std::move(boundNode)},
47
+ direction{direction}, countExpr{std::move(countExpr)} {
48
+ cardinality = 1; // Always returns exactly one row
49
+ }
50
+
51
+ void computeFactorizedSchema() override;
52
+ void computeFlatSchema() override;
53
+
54
+ std::string getExpressionsForPrinting() const override { return countExpr->toString(); }
55
+
56
+ catalog::RelGroupCatalogEntry* getRelGroupEntry() const { return relGroupEntry; }
57
+ const std::vector<common::table_id_t>& getRelTableIDs() const { return relTableIDs; }
58
+ const std::vector<common::table_id_t>& getBoundNodeTableIDs() const {
59
+ return boundNodeTableIDs;
60
+ }
61
+ std::shared_ptr<binder::NodeExpression> getBoundNode() const { return boundNode; }
62
+ common::ExtendDirection getDirection() const { return direction; }
63
+ std::shared_ptr<binder::Expression> getCountExpr() const { return countExpr; }
64
+
65
+ std::unique_ptr<OPPrintInfo> getPrintInfo() const override { // Uses the rel group's name as the table name.
66
+ return std::make_unique<LogicalCountRelTablePrintInfo>(relGroupEntry->getName(), countExpr);
67
+ }
68
+
69
+ std::unique_ptr<LogicalOperator> copy() override { // Copies the ID vectors; shares boundNode, countExpr and relGroupEntry.
70
+ return std::make_unique<LogicalCountRelTable>(relGroupEntry, relTableIDs, boundNodeTableIDs,
71
+ boundNode, direction, countExpr);
72
+ }
73
+
74
+ private:
75
+ catalog::RelGroupCatalogEntry* relGroupEntry;
76
+ std::vector<common::table_id_t> relTableIDs;
77
+ std::vector<common::table_id_t> boundNodeTableIDs;
78
+ std::shared_ptr<binder::NodeExpression> boundNode;
79
+ common::ExtendDirection direction;
80
+ std::shared_ptr<binder::Expression> countExpr; // The only expression placed in this operator's schema.
81
+ };
82
+
83
+ } // namespace planner
84
+ } // namespace lbug
@@ -22,6 +22,7 @@ enum class PhysicalOperatorType : uint8_t {
22
22
  ATTACH_DATABASE,
23
23
  BATCH_INSERT,
24
24
  COPY_TO,
25
+ COUNT_REL_TABLE,
25
26
  CREATE_MACRO,
26
27
  CREATE_SEQUENCE,
27
28
  CREATE_TABLE,
@@ -0,0 +1,62 @@
1
+ #pragma once
2
+
3
+ #include "common/enums/rel_direction.h"
4
+ #include "processor/operator/physical_operator.h"
5
+ #include "storage/table/node_table.h"
6
+ #include "storage/table/rel_table.h"
7
+
8
+ namespace lbug {
9
+ namespace processor {
10
+
11
+ struct CountRelTablePrintInfo final : OPPrintInfo { // Print info for the physical CountRelTable operator.
12
+ std::string relTableName; // Name shown after "Table: " in toString().
13
+
14
+ explicit CountRelTablePrintInfo(std::string relTableName)
15
+ : relTableName{std::move(relTableName)} {}
16
+
17
+ std::string toString() const override { return "Table: " + relTableName; }
18
+
19
+ std::unique_ptr<OPPrintInfo> copy() const override { // Returns an independent copy carrying the same name.
20
+ return std::make_unique<CountRelTablePrintInfo>(relTableName);
21
+ }
22
+ };
23
+
24
+ /**
25
+ * CountRelTable is a non-parallel source operator that outputs, as a single value, the
26
+ * number of rels stored in the given rel tables for one direction.
27
+ * Per-execution state (countVector, hasExecuted, totalCount) is reset in initLocalStateInternal.
28
+ */
29
+ class CountRelTable final : public PhysicalOperator {
30
+ static constexpr PhysicalOperatorType type_ = PhysicalOperatorType::COUNT_REL_TABLE;
31
+
32
+ public:
33
+ CountRelTable(std::vector<storage::NodeTable*> nodeTables,
34
+ std::vector<storage::RelTable*> relTables, common::RelDataDirection direction,
35
+ DataPos countOutputPos, physical_op_id id, std::unique_ptr<OPPrintInfo> printInfo)
36
+ : PhysicalOperator{type_, id, std::move(printInfo)}, nodeTables{std::move(nodeTables)},
37
+ relTables{std::move(relTables)}, direction{direction}, countOutputPos{countOutputPos} {}
38
+
39
+ bool isSource() const override { return true; }
40
+ bool isParallel() const override { return false; }
41
+
42
+ void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;
43
+
44
+ bool getNextTuplesInternal(ExecutionContext* context) override;
45
+
46
+ std::unique_ptr<PhysicalOperator> copy() override { // Copies configuration only; execution state starts from the defaults below.
47
+ return std::make_unique<CountRelTable>(nodeTables, relTables, direction, countOutputPos, id,
48
+ printInfo->copy());
49
+ }
50
+
51
+ private:
52
+ std::vector<storage::NodeTable*> nodeTables; // Non-owning table pointers passed in by the caller.
53
+ std::vector<storage::RelTable*> relTables; // Non-owning; every table in this list is counted.
54
+ common::RelDataDirection direction;
55
+ DataPos countOutputPos; // Position of the output count vector in the ResultSet.
56
+ common::ValueVector* countVector = nullptr; // Resolved from the ResultSet in initLocalStateInternal.
57
+ bool hasExecuted = false; // Default-initialized so a fresh or copied operator never reads garbage.
58
+ common::row_idx_t totalCount = 0;
59
+ };
60
+
61
+ } // namespace processor
62
+ } // namespace lbug
@@ -90,6 +90,8 @@ public:
90
90
  std::unique_ptr<PhysicalOperator> mapCopyRelFrom(
91
91
  const planner::LogicalOperator* logicalOperator);
92
92
  std::unique_ptr<PhysicalOperator> mapCopyTo(const planner::LogicalOperator* logicalOperator);
93
+ std::unique_ptr<PhysicalOperator> mapCountRelTable(
94
+ const planner::LogicalOperator* logicalOperator);
93
95
  std::unique_ptr<PhysicalOperator> mapCreateMacro(
94
96
  const planner::LogicalOperator* logicalOperator);
95
97
  std::unique_ptr<PhysicalOperator> mapCreateSequence(
@@ -4,6 +4,7 @@ add_library(lbug_optimizer
4
4
  agg_key_dependency_optimizer.cpp
5
5
  cardinality_updater.cpp
6
6
  correlated_subquery_unnest_solver.cpp
7
+ count_rel_table_optimizer.cpp
7
8
  factorization_rewriter.cpp
8
9
  filter_push_down_optimizer.cpp
9
10
  logical_operator_collector.cpp
@@ -0,0 +1,217 @@
1
+ #include "optimizer/count_rel_table_optimizer.h"
2
+
3
+ #include "binder/expression/aggregate_function_expression.h"
4
+ #include "binder/expression/node_expression.h"
5
+ #include "catalog/catalog_entry/node_table_id_pair.h"
6
+ #include "function/aggregate/count_star.h"
7
+ #include "main/client_context.h"
8
+ #include "planner/operator/extend/logical_extend.h"
9
+ #include "planner/operator/logical_aggregate.h"
10
+ #include "planner/operator/logical_projection.h"
11
+ #include "planner/operator/scan/logical_count_rel_table.h"
12
+ #include "planner/operator/scan/logical_scan_node_table.h"
13
+
14
+ using namespace lbug::common;
15
+ using namespace lbug::planner;
16
+ using namespace lbug::binder;
17
+ using namespace lbug::catalog;
18
+
19
+ namespace lbug {
20
+ namespace optimizer {
21
+
22
+ void CountRelTableOptimizer::rewrite(LogicalPlan* plan) { // Entry point: rewrites the plan starting from its last operator.
23
+ visitOperator(plan->getLastOperator()); // NOTE(review): the returned operator is discarded, so a replaced root would not be stored back into the plan; confirm the root is never the matched AGGREGATE.
24
+ }
25
+
26
+ std::shared_ptr<LogicalOperator> CountRelTableOptimizer::visitOperator(
27
+ const std::shared_ptr<LogicalOperator>& op) {
28
+ // bottom-up traversal: rewrite every child first and re-attach whatever it returns
29
+ for (auto i = 0u; i < op->getNumChildren(); ++i) {
30
+ op->setChild(i, visitOperator(op->getChild(i)));
31
+ }
32
+ auto result = visitOperatorReplaceSwitch(op); // dispatches on operator type; AGGREGATE reaches visitAggregateReplace
33
+ result->computeFlatSchema(); // recomputed for every visited operator, replaced or not
34
+ return result;
35
+ }
36
+
37
+ bool CountRelTableOptimizer::isSimpleCountStar(LogicalOperator* op) const { // True only for a key-less AGGREGATE with one non-DISTINCT COUNT_STAR.
38
+ if (op->getOperatorType() != LogicalOperatorType::AGGREGATE) {
39
+ return false;
40
+ }
41
+ auto& aggregate = op->constCast<LogicalAggregate>();
42
+
43
+ // Must have no keys (i.e., a simple aggregate without GROUP BY)
44
+ if (aggregate.hasKeys()) {
45
+ return false;
46
+ }
47
+
48
+ // Must have exactly one aggregate expression
49
+ auto aggregates = aggregate.getAggregates();
50
+ if (aggregates.size() != 1) {
51
+ return false;
52
+ }
53
+
54
+ // Must be COUNT_STAR (matched by function name)
55
+ auto& aggExpr = aggregates[0];
56
+ if (aggExpr->expressionType != ExpressionType::AGGREGATE_FUNCTION) {
57
+ return false;
58
+ }
59
+ auto& aggFuncExpr = aggExpr->constCast<AggregateFunctionExpression>();
60
+ if (aggFuncExpr.getFunction().name != function::CountStarFunction::name) {
61
+ return false;
62
+ }
63
+
64
+ // COUNT_STAR should not be DISTINCT (conceptually it doesn't make sense)
65
+ if (aggFuncExpr.isDistinct()) {
66
+ return false;
67
+ }
68
+
69
+ return true;
70
+ }
71
+
72
+ bool CountRelTableOptimizer::canOptimize(LogicalOperator* aggregate) const { // Checks the subtree below `aggregate`; the aggregate itself is not inspected here.
73
+ // Pattern we're looking for:
74
+ // AGGREGATE (COUNT_STAR, no keys)
75
+ // -> PROJECTION (empty expressions or pass-through)
76
+ // -> EXTEND (single rel table, no properties scanned)
77
+ // -> SCAN_NODE_TABLE (no properties scanned)
78
+ //
79
+ // Note: The projection between aggregate and extend might be empty or
80
+ // just projecting the count expression. Any number of stacked projections is accepted.
81
+
82
+ auto* current = aggregate->getChild(0).get();
83
+
84
+ // Skip any projections (they should be empty or just for count)
85
+ while (current->getOperatorType() == LogicalOperatorType::PROJECTION) {
86
+ auto& proj = current->constCast<LogicalProjection>();
87
+ // Empty projection is okay, it's just a passthrough
88
+ if (!proj.getExpressionsToProject().empty()) {
89
+ // If projection has expressions, they should all be aggregate expressions
90
+ // (which means they're just passing through the count)
91
+ for (auto& expr : proj.getExpressionsToProject()) {
92
+ if (expr->expressionType != ExpressionType::AGGREGATE_FUNCTION) {
93
+ return false;
94
+ }
95
+ }
96
+ }
97
+ current = current->getChild(0).get();
98
+ }
99
+
100
+ // Now we should have EXTEND; any other operator (e.g. a filter) rejects the pattern
101
+ if (current->getOperatorType() != LogicalOperatorType::EXTEND) {
102
+ return false;
103
+ }
104
+ auto& extend = current->constCast<LogicalExtend>();
105
+
106
+ // Don't optimize for undirected edges (BOTH direction) - the query pattern
107
+ // (a)-[e]-(b) generates a plan that scans both directions, and optimizing
108
+ // this would require special handling to avoid double counting.
109
+ if (extend.getDirection() == ExtendDirection::BOTH) {
110
+ return false;
111
+ }
112
+
113
+ // The rel should be a single table (not multi-labeled)
114
+ auto rel = extend.getRel();
115
+ if (rel->isMultiLabeled()) {
116
+ return false;
117
+ }
118
+
119
+ // Check if we're scanning any properties (we can only optimize when no properties needed)
120
+ if (!extend.getProperties().empty()) {
121
+ return false;
122
+ }
123
+
124
+ // The child of extend should be SCAN_NODE_TABLE
125
+ auto* extendChild = current->getChild(0).get();
126
+ if (extendChild->getOperatorType() != LogicalOperatorType::SCAN_NODE_TABLE) {
127
+ return false;
128
+ }
129
+ auto& scanNode = extendChild->constCast<LogicalScanNodeTable>();
130
+
131
+ // Check if node scan has any properties (we can only optimize when no properties needed)
132
+ if (!scanNode.getProperties().empty()) {
133
+ return false;
134
+ }
135
+
136
+ return true;
137
+ }
138
+
139
+ std::shared_ptr<LogicalOperator> CountRelTableOptimizer::visitAggregateReplace(
140
+ std::shared_ptr<LogicalOperator> op) { // Returns op unchanged unless the whole COUNT(*) pattern matches.
141
+ if (!isSimpleCountStar(op.get())) {
142
+ return op;
143
+ }
144
+
145
+ if (!canOptimize(op.get())) {
146
+ return op;
147
+ }
148
+
149
+ // Find the EXTEND operator (canOptimize already verified it sits below the projections)
150
+ auto* current = op->getChild(0).get();
151
+ while (current->getOperatorType() == LogicalOperatorType::PROJECTION) {
152
+ current = current->getChild(0).get();
153
+ }
154
+
155
+ KU_ASSERT(current->getOperatorType() == LogicalOperatorType::EXTEND);
156
+ auto& extend = current->constCast<LogicalExtend>();
157
+ auto rel = extend.getRel();
158
+ auto boundNode = extend.getBoundNode();
159
+ auto nbrNode = extend.getNbrNode();
160
+
161
+ // Get the rel group entry (exactly one entry is expected for the rel)
162
+ KU_ASSERT(rel->getNumEntries() == 1);
163
+ auto* relGroupEntry = rel->getEntry(0)->ptrCast<RelGroupCatalogEntry>();
164
+
165
+ // Determine the source and destination node table IDs based on extend direction.
166
+ // If extendFromSourceNode() is true, then boundNode is the source and nbrNode is the destination.
167
+ // If extendFromSourceNode() is false, then boundNode is the destination and nbrNode is the source.
168
+ auto boundNodeTableIDs = boundNode->getTableIDsSet();
169
+ auto nbrNodeTableIDs = nbrNode->getTableIDsSet();
170
+
171
+ // Get only the rel table IDs that match the specific node table ID pairs in the query.
172
+ // A rel table connects a specific (srcTableID, dstTableID) pair.
173
+ std::vector<table_id_t> relTableIDs;
174
+ for (auto& info : relGroupEntry->getRelEntryInfos()) {
175
+ table_id_t srcTableID = info.nodePair.srcTableID;
176
+ table_id_t dstTableID = info.nodePair.dstTableID;
177
+
178
+ bool matches = false;
179
+ if (extend.extendFromSourceNode()) {
180
+ // boundNode is src, nbrNode is dst
181
+ matches =
182
+ boundNodeTableIDs.contains(srcTableID) && nbrNodeTableIDs.contains(dstTableID);
183
+ } else {
184
+ // boundNode is dst, nbrNode is src
185
+ matches =
186
+ boundNodeTableIDs.contains(dstTableID) && nbrNodeTableIDs.contains(srcTableID);
187
+ }
188
+
189
+ if (matches) {
190
+ relTableIDs.push_back(info.oid);
191
+ }
192
+ }
193
+
194
+ // If no matching rel tables, don't optimize (shouldn't happen for valid queries)
195
+ if (relTableIDs.empty()) {
196
+ return op;
197
+ }
198
+
199
+ // Get the count expression from the original aggregate so the new operator outputs the same expression
200
+ auto& aggregate = op->constCast<LogicalAggregate>();
201
+ auto countExpr = aggregate.getAggregates()[0];
202
+
203
+ // Get the bound node table IDs as a vector
204
+ std::vector<table_id_t> boundNodeTableIDsVec(boundNodeTableIDs.begin(),
205
+ boundNodeTableIDs.end());
206
+
207
+ // Create the new COUNT_REL_TABLE operator; it has no children, so the matched subtree is dropped
208
+ auto countRelTable =
209
+ std::make_shared<LogicalCountRelTable>(relGroupEntry, std::move(relTableIDs),
210
+ std::move(boundNodeTableIDsVec), boundNode, extend.getDirection(), countExpr);
211
+ countRelTable->computeFlatSchema();
212
+
213
+ return countRelTable;
214
+ }
215
+
216
+ } // namespace optimizer
217
+ } // namespace lbug
@@ -19,6 +19,9 @@ void LogicalOperatorVisitor::visitOperatorSwitch(LogicalOperator* op) {
19
19
  case LogicalOperatorType::COPY_TO: {
20
20
  visitCopyTo(op);
21
21
  } break;
22
+ case LogicalOperatorType::COUNT_REL_TABLE: {
23
+ visitCountRelTable(op);
24
+ } break;
22
25
  case LogicalOperatorType::DELETE: {
23
26
  visitDelete(op);
24
27
  } break;
@@ -108,6 +111,9 @@ std::shared_ptr<LogicalOperator> LogicalOperatorVisitor::visitOperatorReplaceSwi
108
111
  case LogicalOperatorType::COPY_TO: {
109
112
  return visitCopyToReplace(op);
110
113
  }
114
+ case LogicalOperatorType::COUNT_REL_TABLE: {
115
+ return visitCountRelTableReplace(op);
116
+ }
111
117
  case LogicalOperatorType::DELETE: {
112
118
  return visitDeleteReplace(op);
113
119
  }
@@ -5,6 +5,7 @@
5
5
  #include "optimizer/agg_key_dependency_optimizer.h"
6
6
  #include "optimizer/cardinality_updater.h"
7
7
  #include "optimizer/correlated_subquery_unnest_solver.h"
8
+ #include "optimizer/count_rel_table_optimizer.h"
8
9
  #include "optimizer/factorization_rewriter.h"
9
10
  #include "optimizer/filter_push_down_optimizer.h"
10
11
  #include "optimizer/limit_push_down_optimizer.h"
@@ -32,6 +33,11 @@ void Optimizer::optimize(planner::LogicalPlan* plan, main::ClientContext* contex
32
33
  auto removeUnnecessaryJoinOptimizer = RemoveUnnecessaryJoinOptimizer();
33
34
  removeUnnecessaryJoinOptimizer.rewrite(plan);
34
35
 
36
+ // CountRelTableOptimizer should be applied early before other optimizations
37
+ // that might change the plan structure.
38
+ auto countRelTableOptimizer = CountRelTableOptimizer(context);
39
+ countRelTableOptimizer.rewrite(plan);
40
+
35
41
  auto filterPushDownOptimizer = FilterPushDownOptimizer(context);
36
42
  filterPushDownOptimizer.rewrite(plan);
37
43
 
@@ -22,6 +22,8 @@ std::string LogicalOperatorUtils::logicalOperatorTypeToString(LogicalOperatorTyp
22
22
  return "COPY_FROM";
23
23
  case LogicalOperatorType::COPY_TO:
24
24
  return "COPY_TO";
25
+ case LogicalOperatorType::COUNT_REL_TABLE:
26
+ return "COUNT_REL_TABLE";
25
27
  case LogicalOperatorType::CREATE_MACRO:
26
28
  return "CREATE_MACRO";
27
29
  case LogicalOperatorType::CREATE_SEQUENCE:
@@ -1,5 +1,6 @@
1
1
  add_library(lbug_planner_scan
2
2
  OBJECT
3
+ logical_count_rel_table.cpp
3
4
  logical_expressions_scan.cpp
4
5
  logical_index_look_up.cpp
5
6
  logical_scan_node_table.cpp)
@@ -0,0 +1,24 @@
1
+ #include "planner/operator/scan/logical_count_rel_table.h"
2
+
3
+ namespace lbug {
4
+ namespace planner {
5
+
6
+ void LogicalCountRelTable::computeFactorizedSchema() {
7
+ createEmptySchema();
8
+ // Only output the count expression, in one group that is marked single-state.
9
+ // This operator is a source - it has no child in the logical plan.
10
+ // The bound node is kept on the operator but is not added to the schema.
11
+ auto groupPos = schema->createGroup();
12
+ schema->insertToGroupAndScope(countExpr, groupPos);
13
+ schema->setGroupAsSingleState(groupPos);
14
+ }
15
+
16
+ void LogicalCountRelTable::computeFlatSchema() {
17
+ createEmptySchema();
18
+ // For flat schema, create a single group holding only the count expression (not marked single-state here).
19
+ auto groupPos = schema->createGroup();
20
+ schema->insertToGroupAndScope(countExpr, groupPos);
21
+ }
22
+
23
+ } // namespace planner
24
+ } // namespace lbug
@@ -7,6 +7,7 @@ add_library(lbug_processor_mapper
7
7
  map_acc_hash_join.cpp
8
8
  map_accumulate.cpp
9
9
  map_aggregate.cpp
10
+ map_count_rel_table.cpp
10
11
  map_standalone_call.cpp
11
12
  map_table_function_call.cpp
12
13
  map_copy_to.cpp
@@ -0,0 +1,55 @@
1
+ #include "planner/operator/scan/logical_count_rel_table.h"
2
+ #include "processor/operator/scan/count_rel_table.h"
3
+ #include "processor/plan_mapper.h"
4
+ #include "storage/storage_manager.h"
5
+
6
+ using namespace lbug::common;
7
+ using namespace lbug::planner;
8
+ using namespace lbug::storage;
9
+
10
+ namespace lbug {
11
+ namespace processor {
12
+
13
+ std::unique_ptr<PhysicalOperator> PlanMapper::mapCountRelTable(
14
+ const LogicalOperator* logicalOperator) { // Expects a LogicalCountRelTable; builds the physical CountRelTable for it.
15
+ auto& logicalCountRelTable = logicalOperator->constCast<LogicalCountRelTable>();
16
+ auto outSchema = logicalCountRelTable.getSchema();
17
+
18
+ auto storageManager = StorageManager::Get(*clientContext);
19
+
20
+ // Resolve the bound node table IDs to storage tables (handed to CountRelTable as-is)
21
+ std::vector<NodeTable*> nodeTables;
22
+ for (auto tableID : logicalCountRelTable.getBoundNodeTableIDs()) {
23
+ nodeTables.push_back(storageManager->getTable(tableID)->ptrCast<NodeTable>());
24
+ }
25
+
26
+ // Get the rel tables
27
+ std::vector<RelTable*> relTables;
28
+ for (auto tableID : logicalCountRelTable.getRelTableIDs()) {
29
+ relTables.push_back(storageManager->getTable(tableID)->ptrCast<RelTable>());
30
+ }
31
+
32
+ // Determine rel data direction from extend direction
33
+ auto extendDirection = logicalCountRelTable.getDirection();
34
+ RelDataDirection relDirection;
35
+ if (extendDirection == ExtendDirection::FWD) {
36
+ relDirection = RelDataDirection::FWD;
37
+ } else if (extendDirection == ExtendDirection::BWD) {
38
+ relDirection = RelDataDirection::BWD;
39
+ } else {
40
+ // For BOTH, we'll scan FWD (shouldn't reach here as optimizer filters BOTH)
41
+ relDirection = RelDataDirection::FWD;
42
+ }
43
+
44
+ // Get the output position for the count expression
45
+ auto countOutputPos = getDataPos(*logicalCountRelTable.getCountExpr(), *outSchema);
46
+
47
+ auto printInfo = std::make_unique<CountRelTablePrintInfo>(
48
+ logicalCountRelTable.getRelGroupEntry()->getName());
49
+
50
+ return std::make_unique<CountRelTable>(std::move(nodeTables), std::move(relTables),
51
+ relDirection, countOutputPos, getOperatorID(), std::move(printInfo));
52
+ }
53
+
54
+ } // namespace processor
55
+ } // namespace lbug
@@ -62,6 +62,9 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapOperator(const LogicalOperator*
62
62
  case LogicalOperatorType::COPY_TO: {
63
63
  physicalOperator = mapCopyTo(logicalOperator);
64
64
  } break;
65
+ case LogicalOperatorType::COUNT_REL_TABLE: {
66
+ physicalOperator = mapCountRelTable(logicalOperator);
67
+ } break;
65
68
  case LogicalOperatorType::CREATE_MACRO: {
66
69
  physicalOperator = mapCreateMacro(logicalOperator);
67
70
  } break;
@@ -27,6 +27,8 @@ std::string PhysicalOperatorUtils::operatorTypeToString(PhysicalOperatorType ope
27
27
  return "BATCH_INSERT";
28
28
  case PhysicalOperatorType::COPY_TO:
29
29
  return "COPY_TO";
30
+ case PhysicalOperatorType::COUNT_REL_TABLE:
31
+ return "COUNT_REL_TABLE";
30
32
  case PhysicalOperatorType::CREATE_MACRO:
31
33
  return "CREATE_MACRO";
32
34
  case PhysicalOperatorType::CREATE_SEQUENCE:
@@ -1,5 +1,6 @@
1
1
  add_library(lbug_processor_operator_scan
2
2
  OBJECT
3
+ count_rel_table.cpp
3
4
  primary_key_scan_node_table.cpp
4
5
  scan_multi_rel_tables.cpp
5
6
  scan_node_table.cpp
@@ -0,0 +1,137 @@
1
+ #include "processor/operator/scan/count_rel_table.h"
2
+
3
+ #include "common/system_config.h"
4
+ #include "main/client_context.h"
5
+ #include "main/database.h"
6
+ #include "processor/execution_context.h"
7
+ #include "storage/buffer_manager/memory_manager.h"
8
+ #include "storage/local_storage/local_rel_table.h"
9
+ #include "storage/local_storage/local_storage.h"
10
+ #include "storage/table/column.h"
11
+ #include "storage/table/column_chunk_data.h"
12
+ #include "storage/table/csr_chunked_node_group.h"
13
+ #include "storage/table/csr_node_group.h"
14
+ #include "storage/table/rel_table_data.h"
15
+ #include "transaction/transaction.h"
16
+
17
+ using namespace lbug::common;
18
+ using namespace lbug::storage;
19
+ using namespace lbug::transaction;
20
+
21
+ namespace lbug {
22
+ namespace processor {
23
+
24
+ void CountRelTable::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* /*context*/) {
25
+ countVector = resultSet->getValueVector(countOutputPos).get(); // non-owning; the ResultSet keeps the vector alive
26
+ hasExecuted = false; // reset so getNextTuplesInternal produces its single result
27
+ totalCount = 0; // reset the running total before counting
28
+ }
29
+
30
+ // Count rels by using CSR metadata, accounting for deletions and uncommitted data.
31
+ // Returns true exactly once, after writing the total into countVector; false on later calls.
32
+ bool CountRelTable::getNextTuplesInternal(ExecutionContext* context) {
33
+ if (hasExecuted) {
34
+ return false;
35
+ }
36
+
37
+ auto transaction = Transaction::Get(*context->clientContext);
38
+ auto* memoryManager = context->clientContext->getDatabase()->getMemoryManager();
39
+
40
+ for (auto* relTable : relTables) {
41
+ // Get the RelTableData for the specified direction
42
+ auto* relTableData = relTable->getDirectedTableData(direction);
43
+ auto numNodeGroups = relTableData->getNumNodeGroups();
44
+ auto* csrLengthColumn = relTableData->getCSRLengthColumn();
45
+
46
+ // For each node group in the rel table
47
+ for (node_group_idx_t nodeGroupIdx = 0; nodeGroupIdx < numNodeGroups; nodeGroupIdx++) {
48
+ auto* nodeGroup = relTableData->getNodeGroup(nodeGroupIdx);
49
+ if (!nodeGroup) {
50
+ continue;
51
+ }
52
+
53
+ auto& csrNodeGroup = nodeGroup->cast<CSRNodeGroup>();
54
+
55
+ // Count from persistent (checkpointed) data
56
+ if (auto* persistentGroup = csrNodeGroup.getPersistentChunkedGroup()) {
57
+ // Sum the actual relationship lengths from the CSR header instead of using
58
+ // getNumRows() which includes dummy rows added for CSR offset array gaps
59
+ auto& csrPersistentGroup = persistentGroup->cast<ChunkedCSRNodeGroup>();
60
+ auto& csrHeader = csrPersistentGroup.getCSRHeader();
61
+
62
+ // Get the number of nodes in this CSR header. An empty header contributes no
63
+ // persistent rels, but we must not `continue` here: that would also skip the
64
+ // in-memory chunked groups of this node group, which are counted further below.
65
+ auto numNodes = csrHeader.length->getNumValues();
66
+ if (numNodes > 0) {
67
+ // Create an in-memory chunk to scan the CSR length column into
68
+ auto lengthChunk =
69
+ ColumnChunkFactory::createColumnChunkData(*memoryManager, LogicalType::UINT64(),
70
+ false /*enableCompression*/, StorageConfig::NODE_GROUP_SIZE,
71
+ ResidencyState::IN_MEMORY, false /*initializeToZero*/);
72
+
73
+ // Initialize scan state and scan the length column from disk
74
+ ChunkState chunkState;
75
+ csrHeader.length->initializeScanState(chunkState, csrLengthColumn);
76
+ csrLengthColumn->scan(chunkState, lengthChunk.get(), 0 /*offsetInChunk*/, numNodes);
77
+
78
+ // Sum all the lengths
79
+ auto* lengthData = reinterpret_cast<const uint64_t*>(lengthChunk->getData());
80
+ row_idx_t groupRelCount = 0;
81
+ for (offset_t i = 0; i < numNodes; ++i) {
82
+ groupRelCount += lengthData[i];
83
+ }
84
+ totalCount += groupRelCount;
85
+
86
+ // Subtract deletions from persistent data. NOTE(review): the range passed is [0, groupRelCount), but the group may also hold gap rows (see above); confirm no deleted rel lies past that range.
87
+ if (persistentGroup->hasVersionInfo()) {
88
+ auto numDeletions =
89
+ persistentGroup->getNumDeletions(transaction, 0, groupRelCount);
90
+ totalCount -= numDeletions;
91
+ }
92
+ }
93
+ }
94
+
95
+ // Count in-memory committed data (not yet checkpointed)
96
+ // This data is stored in chunkedGroups within the NodeGroup
97
+ auto numChunkedGroups = csrNodeGroup.getNumChunkedGroups();
98
+ for (node_group_idx_t i = 0; i < numChunkedGroups; i++) {
99
+ auto* chunkedGroup = csrNodeGroup.getChunkedNodeGroup(i);
100
+ if (chunkedGroup) {
101
+ auto numRows = chunkedGroup->getNumRows();
102
+ totalCount += numRows;
103
+ // Subtract deletions from in-memory committed data
104
+ if (chunkedGroup->hasVersionInfo()) {
105
+ auto numDeletions = chunkedGroup->getNumDeletions(transaction, 0, numRows);
106
+ totalCount -= numDeletions;
107
+ }
108
+ }
109
+ }
110
+ }
111
+
112
+ // Add uncommitted insertions from local storage
113
+ if (transaction->isWriteTransaction()) {
114
+ if (auto* localTable =
115
+ transaction->getLocalStorage()->getLocalTable(relTable->getTableID())) {
116
+ auto& localRelTable = localTable->cast<LocalRelTable>();
117
+ // Count entries in the CSR index for this direction.
118
+ // We can't use getNumTotalRows() because it includes deleted rows.
119
+ auto& csrIndex = localRelTable.getCSRIndex(direction);
120
+ for (const auto& [nodeOffset, rowIndices] : csrIndex) {
121
+ totalCount += rowIndices.size();
122
+ }
123
+ }
124
+ }
125
+ }
126
+
127
+ hasExecuted = true;
128
+
129
+ // Write the count to the output vector (single value)
130
+ countVector->state->getSelVectorUnsafe().setToUnfiltered(1);
131
+ countVector->setValue<int64_t>(0, static_cast<int64_t>(totalCount));
132
+
133
+ return true;
134
+ }
135
+
136
+ } // namespace processor
137
+ } // namespace lbug
@@ -1,5 +1,6 @@
1
1
  #include "graph_test/private_graph_test.h"
2
2
  #include "planner/operator/logical_plan_util.h"
3
+ #include "planner/operator/scan/logical_count_rel_table.h"
3
4
  #include "test_runner/test_runner.h"
4
5
 
5
6
  namespace lbug {
@@ -17,6 +18,19 @@ public:
17
18
  std::unique_ptr<planner::LogicalPlan> getRoot(const std::string& query) {
18
19
  return TestRunner::getLogicalPlan(query, *conn);
19
20
  }
21
+
22
+ // Helper to check if a specific operator type exists in the plan
23
+ static bool hasOperatorType(planner::LogicalOperator* op, planner::LogicalOperatorType type) {
24
+ if (op->getOperatorType() == type) {
25
+ return true;
26
+ }
27
+ for (auto i = 0u; i < op->getNumChildren(); ++i) {
28
+ if (hasOperatorType(op->getChild(i).get(), type)) {
29
+ return true;
30
+ }
31
+ }
32
+ return false;
33
+ }
20
34
  };
21
35
 
22
36
  TEST_F(OptimizerTest, JoinHint) {
@@ -211,5 +225,37 @@ TEST_F(OptimizerTest, SubqueryHint) {
211
225
  ASSERT_STREQ(getEncodedPlan(q6).c_str(), "Filter()HJ(a._ID){S(a)}{E(a)Filter()S(b)}");
212
226
  }
213
227
 
228
+ TEST_F(OptimizerTest, CountRelTableOptimizer) {
229
+ // Test that COUNT(*) over a single rel table is optimized to COUNT_REL_TABLE
230
+ auto q1 = "MATCH (a:person)-[e:knows]->(b:person) RETURN COUNT(*);";
231
+ auto plan1 = getRoot(q1);
232
+ ASSERT_TRUE(hasOperatorType(plan1->getLastOperator().get(),
233
+ planner::LogicalOperatorType::COUNT_REL_TABLE));
234
+ // Verify the query returns the correct result
235
+ auto result1 = conn->query(q1);
236
+ ASSERT_TRUE(result1->isSuccess());
237
+ ASSERT_EQ(result1->getNumTuples(), 1);
238
+ auto tuple1 = result1->getNext();
239
+ ASSERT_EQ(tuple1->getValue(0)->getValue<int64_t>(), 14);
240
+
241
+ // Test that COUNT(*) with GROUP BY is NOT optimized (has keys)
242
+ auto q2 = "MATCH (a:person)-[e:knows]->(b:person) RETURN a.fName, COUNT(*);";
243
+ auto plan2 = getRoot(q2);
244
+ ASSERT_FALSE(hasOperatorType(plan2->getLastOperator().get(),
245
+ planner::LogicalOperatorType::COUNT_REL_TABLE));
246
+
247
+ // Test that COUNT(*) with WHERE clause is NOT optimized (has filter)
248
+ auto q3 = "MATCH (a:person)-[e:knows]->(b:person) WHERE a.ID > 0 RETURN COUNT(*);";
249
+ auto plan3 = getRoot(q3);
250
+ ASSERT_FALSE(hasOperatorType(plan3->getLastOperator().get(),
251
+ planner::LogicalOperatorType::COUNT_REL_TABLE));
252
+
253
+ // Test that COUNT(DISTINCT ...) is NOT optimized
254
+ auto q4 = "MATCH (a:person)-[e:knows]->(b:person) RETURN COUNT(DISTINCT a);";
255
+ auto plan4 = getRoot(q4);
256
+ ASSERT_FALSE(hasOperatorType(plan4->getLastOperator().get(),
257
+ planner::LogicalOperatorType::COUNT_REL_TABLE));
258
+ }
259
+
214
260
  } // namespace testing
215
261
  } // namespace lbug
@@ -0,0 +1,5 @@
1
+ -NAME count_rel_table
2
+ -PRERUN CREATE NODE TABLE account(ID INT64 PRIMARY KEY); CREATE REL TABLE follows(FROM account TO account); COPY account FROM "dataset/snap/amazon0601/parquet/amazon-nodes.parquet"; COPY follows FROM "dataset/snap/amazon0601/parquet/amazon-edges.parquet";
3
+ -QUERY MATCH ()-[r:follows]->() RETURN COUNT(*)
4
+ ---- 1
5
+ 3387388
@@ -543,6 +543,17 @@ std::vector<std::unique_ptr<QueryResult>> EmbeddedShell::processInput(std::strin
543
543
  historyLine = input;
544
544
  return queryResults;
545
545
  }
546
+ // Normalize trailing semicolons
547
+ if (!unicodeInput.empty() && unicodeInput.back() == ';') {
548
+ // trim trailing ;
549
+ while (!unicodeInput.empty() && unicodeInput.back() == ';') {
550
+ unicodeInput.pop_back();
551
+ }
552
+ if (unicodeInput.empty()) {
553
+ return queryResults;
554
+ }
555
+ unicodeInput += ';';
556
+ }
546
557
  // process shell commands
547
558
  if (!continueLine && unicodeInput[0] == ':') {
548
559
  processShellCommands(unicodeInput);
@@ -3306,7 +3306,7 @@ bool cypherComplete(const char* z) {
3306
3306
  /* Token: */
3307
3307
  /* State: ** SEMI WS OTHER */
3308
3308
  /* 0 INVALID: */ {
3309
- 1,
3309
+ 2,
3310
3310
  0,
3311
3311
  2,
3312
3312
  },
@@ -3553,8 +3553,8 @@ static int linenoiseEdit(int stdin_fd, int stdout_fd, char* buf, size_t buflen,
3553
3553
  // check if this forms a complete Cypher statement or not or if enter is pressed in
3554
3554
  // the middle of a line
3555
3555
  l.buf[l.len] = '\0';
3556
- if (l.buf[0] != ':' &&
3557
- (l.pos != l.len || linenoiseAllWhitespace(l.buf) || !cypherComplete(l.buf))) {
3556
+ if (l.buf[0] != ':' && l.pos == l.len && !linenoiseAllWhitespace(l.buf) &&
3557
+ !cypherComplete(l.buf)) {
3558
3558
  if (linenoiseEditInsertMulti(&l, "\r\n")) {
3559
3559
  return -1;
3560
3560
  }
@@ -289,6 +289,18 @@ def test_shell_auto_completion(temp_db) -> None:
289
289
  assert test.shell_process.expect_exact(["(0 tuples)", pexpect.EOF]) == 0
290
290
 
291
291
 
292
def test_double_semicolon(temp_db) -> None:
    """Run a query terminated by two semicolons and expect the usual empty result.

    The schema is created first; the final MATCH ends in ``;;`` and the shell
    output is expected to contain ``(0 tuples)``.
    """
    statements = (
        "CREATE NODE TABLE User(name STRING, PRIMARY KEY(name));",
        "CREATE REL TABLE Follows(FROM User TO User, since INT64);",
        # The doubled trailing semicolon is the case under test.
        'MATCH (a:User)-[f:Follows]->(b:User) RETURN a.name, b.name, f.since;;',
    )
    shell = ShellTest().add_argument(temp_db)
    for stmt in statements:
        shell = shell.statement(stmt)
    outcome = shell.run()
    outcome.check_stdout("(0 tuples)")
302
+
303
+
292
304
  def test_shell_unicode_input(temp_db) -> None:
293
305
  test = (
294
306
  ShellTest()
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "lbug",
3
- "version": "0.12.3-dev.14",
3
+ "version": "0.12.3-dev.15",
4
4
  "description": "An in-process property graph database management system built for query speed and scalability.",
5
5
  "main": "index.js",
6
6
  "module": "./index.mjs",
Binary file
Binary file
Binary file
Binary file