Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
package org.apache.gravitino.catalog.doris.utils;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
Expand All @@ -38,14 +40,15 @@
import org.apache.gravitino.rel.expressions.transforms.Transforms;
import org.apache.gravitino.rel.partitions.ListPartition;
import org.apache.gravitino.rel.partitions.Partition;
import org.apache.gravitino.rel.partitions.Partitions;
import org.apache.gravitino.rel.partitions.RangePartition;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class DorisUtils {
private static final Logger LOGGER = LoggerFactory.getLogger(DorisUtils.class);
private static final Pattern PARTITION_INFO_PATTERN =
Pattern.compile("PARTITION BY \\b(LIST|RANGE)\\b\\((.+)\\)");
Pattern.compile("PARTITION BY\\s+\\b(LIST|RANGE)\\b\\s*\\(([^)]+)\\)");

private static final Pattern DISTRIBUTION_INFO_PATTERN =
Pattern.compile(
Expand Down Expand Up @@ -98,23 +101,30 @@ public static Map<String, String> extractPropertiesFromSql(String createTableSql

public static Optional<Transform> extractPartitionInfoFromSql(String createTableSql) {
try {
String[] lines = createTableSql.split("\n");
for (String line : lines) {
Matcher matcher = PARTITION_INFO_PATTERN.matcher(line.trim());
if (matcher.matches()) {
String partitionType = matcher.group(1);
String partitionInfoString = matcher.group(2);
String[] columns =
Arrays.stream(partitionInfoString.split(", "))
.map(s -> s.substring(1, s.length() - 1))
.toArray(String[]::new);
if (LIST_PARTITION.equals(partitionType)) {
String[][] filedNames =
Arrays.stream(columns).map(s -> new String[] {s}).toArray(String[][]::new);
return Optional.of(Transforms.list(filedNames));
} else if (RANGE_PARTITION.equals(partitionType)) {
return Optional.of(Transforms.range(new String[] {columns[0]}));
// Merge all lines to handle multi-line partition definitions
String mergedSql = createTableSql.replace('\n', ' ');
Matcher matcher = PARTITION_INFO_PATTERN.matcher(mergedSql);
if (matcher.find()) {
String partitionType = matcher.group(1);
String partitionInfoString = matcher.group(2);
String[] columns =
Arrays.stream(partitionInfoString.split(","))
.map(String::trim)
.map(
s ->
(s.startsWith("`") && s.endsWith("`")) ? s.substring(1, s.length() - 1) : s)
.toArray(String[]::new);
if (LIST_PARTITION.equals(partitionType)) {
String[][] fieldNames =
Arrays.stream(columns).map(s -> new String[] {s}).toArray(String[][]::new);
// Try to extract partition assignments
ListPartition[] assignments = extractListPartitionAssignments(mergedSql);
if (assignments.length > 0) {
return Optional.of(Transforms.list(fieldNames, assignments));
}
return Optional.of(Transforms.list(fieldNames));
} else if (RANGE_PARTITION.equals(partitionType)) {
return Optional.of(Transforms.range(new String[] {columns[0]}));
}
}
return Optional.empty();
Expand All @@ -124,6 +134,107 @@ public static Optional<Transform> extractPartitionInfoFromSql(String createTable
}
}

/**
 * Matches the header of one list-partition definition, up to and including the opening
 * parenthesis of its {@code VALUES IN (} clause. Group 1 holds a back-quoted partition name, group
 * 2 an unquoted one.
 */
private static final Pattern LIST_PARTITION_HEADER_PATTERN =
    Pattern.compile("PARTITION\\s+(?:`([^`]+)`|(\\w+))\\s+VALUES\\s+IN\\s*\\(");

/**
 * Extracts every {@code PARTITION <name> VALUES IN (...)} assignment from a CREATE TABLE
 * statement.
 *
 * <p>The content of each {@code VALUES IN (...)} clause is located by walking to the balanced
 * closing parenthesis rather than by regex, so multi-column partitions such as {@code VALUES IN
 * (("a", 1), ("b", 2))} are captured whole. Parentheses inside double-quoted values are ignored
 * while balancing.
 *
 * <p>This is best-effort: a definition whose parentheses never balance is skipped, and any
 * exception is logged and results in an empty array.
 *
 * @param mergedSql the CREATE TABLE statement as a single line of text
 * @return the list partitions found, in order of appearance; empty if none were found or parsing
 *     failed
 */
private static ListPartition[] extractListPartitionAssignments(String mergedSql) {
  try {
    List<ListPartition> partitions = new ArrayList<>();
    Matcher matcher = LIST_PARTITION_HEADER_PATTERN.matcher(mergedSql);
    while (matcher.find()) {
      String partitionName = matcher.group(1) != null ? matcher.group(1) : matcher.group(2);
      // The header pattern ends with the opening '(' of VALUES IN, so it sits at end() - 1.
      int openPos = matcher.end() - 1;
      int closePos = findClosingParen(mergedSql, openPos);
      if (closePos < 0) {
        // Unbalanced definition: skip it and keep looking for further partitions.
        continue;
      }
      String outerContent = mergedSql.substring(openPos + 1, closePos).trim();
      Literal<?>[][] values = parseListValues(outerContent);
      partitions.add(Partitions.list(partitionName, values, null));
    }
    return partitions.toArray(new ListPartition[0]);
  } catch (Exception e) {
    LOGGER.warn("Failed to extract list partition assignments", e);
    return new ListPartition[0];
  }
}

/**
 * Finds the parenthesis that closes the one at {@code openPos}.
 *
 * <p>Parentheses inside double-quoted strings are not counted. Inside quotes a backslash skips
 * the following character (assumes MySQL-style escaping in the SQL text - TODO confirm against
 * Doris SHOW CREATE TABLE output).
 *
 * @param text the text to scan
 * @param openPos index of the opening parenthesis in {@code text}
 * @return index of the matching closing parenthesis, or -1 if the text ends first
 */
private static int findClosingParen(String text, int openPos) {
  int depth = 0;
  boolean inQuotes = false;
  for (int i = openPos; i < text.length(); i++) {
    char c = text.charAt(i);
    if (inQuotes) {
      if (c == '\\') {
        i++; // skip the escaped character
      } else if (c == '"') {
        inQuotes = false;
      }
    } else if (c == '"') {
      inQuotes = true;
    } else if (c == '(') {
      depth++;
    } else if (c == ')' && --depth == 0) {
      return i;
    }
  }
  return -1;
}

/**
 * Parses the content of a {@code VALUES IN (...)} clause into a 2-D literal array, one row per
 * partition value.
 *
 * <ul>
 *   <li>Single-column {@code "a", "b"} → {@code [["a"], ["b"]]}
 *   <li>Multi-column {@code ("a", 1), ("b", 2)} → {@code [["a", "1"], ["b", "2"]]}
 * </ul>
 *
 * <p>Every value becomes a string literal after one pair of enclosing double quotes is removed.
 * Commas and parentheses inside double-quoted values do not split a value. Empty content yields
 * an empty array. In the multi-column form, items that are not wrapped in parentheses are
 * ignored.
 *
 * @param outerContent the trimmed text between the outer parentheses of {@code VALUES IN (...)}
 * @return one row per partition value; each row holds one literal per partition column
 */
private static Literal<?>[][] parseListValues(String outerContent) {
  // A leading '(' means each top-level item is itself a tuple of column values.
  boolean multiColumn = outerContent.startsWith("(");
  List<Literal<?>[]> rows = new ArrayList<>();
  for (String item : splitTopLevelCommas(outerContent)) {
    if (!multiColumn) {
      rows.add(new Literal<?>[] {Literals.stringLiteral(stripQuotes(item))});
    } else if (item.length() >= 2 && item.startsWith("(") && item.endsWith(")")) {
      String tupleContent = item.substring(1, item.length() - 1);
      rows.add(
          splitTopLevelCommas(tupleContent).stream()
              .map(DorisUtils::stripQuotes)
              .map(Literals::stringLiteral)
              .toArray(Literal<?>[]::new));
    }
  }
  return rows.toArray(new Literal<?>[0][]);
}

/**
 * Splits {@code text} on commas that are outside double quotes and outside parentheses.
 *
 * <p>Each token is trimmed and trailing blank tokens are dropped, so blank input yields an empty
 * list. Inside quotes a backslash keeps the following character as part of the value (assumes
 * MySQL-style escaping in the SQL text - TODO confirm against Doris SHOW CREATE TABLE output).
 *
 * @param text the text to split
 * @return the trimmed tokens in order of appearance
 */
private static List<String> splitTopLevelCommas(String text) {
  List<String> tokens = new ArrayList<>();
  StringBuilder current = new StringBuilder();
  boolean inQuotes = false;
  int depth = 0;
  for (int i = 0; i < text.length(); i++) {
    char c = text.charAt(i);
    if (inQuotes) {
      current.append(c);
      if (c == '\\' && i + 1 < text.length()) {
        current.append(text.charAt(++i));
      } else if (c == '"') {
        inQuotes = false;
      }
      continue;
    }
    if (c == ',' && depth == 0) {
      tokens.add(current.toString().trim());
      current.setLength(0);
      continue;
    }
    if (c == '"') {
      inQuotes = true;
    } else if (c == '(') {
      depth++;
    } else if (c == ')' && depth > 0) {
      depth--;
    }
    current.append(c);
  }
  tokens.add(current.toString().trim());
  // Drop trailing blanks so that empty input or a dangling comma adds no bogus value.
  int size = tokens.size();
  while (size > 0 && tokens.get(size - 1).isEmpty()) {
    tokens.remove(--size);
  }
  return tokens;
}

/**
 * Removes one pair of enclosing double quotes from {@code s}.
 *
 * <p>The string is returned unchanged unless it is at least two characters long and both begins
 * and ends with {@code "}. Only the outermost pair is removed; no unescaping is performed.
 *
 * @param s the token to unquote; must not be null
 * @return {@code s} without its enclosing double quotes, or {@code s} itself if it is not quoted
 */
private static String stripQuotes(String s) {
  int last = s.length() - 1;
  boolean quoted = last >= 1 && s.charAt(0) == '"' && s.charAt(last) == '"';
  return quoted ? s.substring(1, last) : s;
}

/**
* Generate sql fragment that create partition in Apache Doris.
*
Expand Down
Loading