Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions be/src/vec/exprs/lambda_function/lambda_function_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class LambdaFunctionFactory;

void register_function_array_map(LambdaFunctionFactory& factory);
void register_function_array_filter(LambdaFunctionFactory& factory);
void register_function_array_sort(LambdaFunctionFactory& factory);

class LambdaFunctionFactory {
using Creator = std::function<LambdaFunctionPtr()>;
Expand Down Expand Up @@ -62,6 +63,7 @@ class LambdaFunctionFactory {
std::call_once(oc, []() {
register_function_array_map(instance);
register_function_array_filter(instance);
register_function_array_sort(instance);
});
return instance;
}
Expand Down
274 changes: 274 additions & 0 deletions be/src/vec/exprs/lambda_function/varray_sort_function.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <glog/logging.h>

#include "common/status.h"
#include "vec/columns/column.h"
#include "vec/columns/column_array.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_varbinary.h"
#include "vec/columns/column_vector.h"
#include "vec/common/assert_cast.h"
#include "vec/core/block.h"
#include "vec/core/column_with_type_and_name.h"
#include "vec/data_types/data_type.h"
#include "vec/exprs/lambda_function/lambda_function.h"
#include "vec/exprs/lambda_function/lambda_function_factory.h"
#include "vec/exprs/vexpr.h"
#include "vec/utils/util.hpp"

namespace doris::vectorized {
#include "common/compile_check_begin.h"

class VExprContext;

using ConstColumnVariant =
std::variant<const ColumnUInt8*, const ColumnInt8*, const ColumnInt16*, const ColumnInt32*,
const ColumnInt64*, const ColumnInt128*, const ColumnFloat32*,
const ColumnFloat64*, const ColumnString*, const ColumnVarbinary*,
const ColumnArray*, const ColumnIPv4*, const ColumnIPv6*,
const ColumnDecimal32*, const ColumnDecimal64*, const ColumnDecimal128V2*,
const ColumnDecimal128V3*, const ColumnDecimal256*, const ColumnDate*,
const ColumnDateTime*, const ColumnDateV2*, const ColumnDateTimeV2*,
const ColumnTime*, const ColumnTimeV2*>;

template <typename T>
struct is_column_vector : std::false_type {};

template <PrimitiveType T>
struct is_column_vector<ColumnVector<T>> : std::true_type {};

template <typename T>
inline constexpr bool is_column_vector_v = is_column_vector<T>::value;

class ArraySortFunction : public LambdaFunction {
ENABLE_FACTORY_CREATOR(ArraySortFunction);

public:
~ArraySortFunction() override = default;
static constexpr auto name = "array_sort";

static LambdaFunctionPtr create() { return std::make_shared<ArraySortFunction>(); }

std::string get_name() const override { return name; }

Status execute(VExprContext* context, const vectorized::Block* block, ColumnPtr& result_column,
const DataTypePtr& result_type, const VExprSPtrs& children) const override {
///* array_sort(lambda, arg) *///

DCHECK_EQ(children.size(), 2);

// 1. get data, we need to obtain this actual data and type.
ColumnPtr column_ptr;
RETURN_IF_ERROR(children[1]->execute_column(context, block, column_ptr));
DataTypePtr type_ptr = children[1]->execute_type(block);

auto column = column_ptr->convert_to_full_column_if_const();

auto input_rows = column->size();

auto arg_type = type_ptr;
auto arg_column = column;

ColumnPtr outside_null_map = nullptr;

if (arg_column->is_nullable()) {
arg_column = assert_cast<const ColumnNullable*>(column.get())->get_nested_column_ptr();
outside_null_map = assert_cast<const ColumnNullable*>(column.get())
->get_null_map_column_ptr()
->assume_mutable();
arg_type = assert_cast<const DataTypeNullable*>(type_ptr.get())->get_nested_type();
}

const auto& col_array = assert_cast<const ColumnArray&>(*arg_column);
const auto& off_data =
assert_cast<const ColumnArray::ColumnOffsets&>(col_array.get_offsets_column())
.get_data();

const auto& nested_nullable_column =
assert_cast<const ColumnNullable&>(*col_array.get_data_ptr());

auto pType = assert_cast<const DataTypeArray*>(arg_type.get())
->get_nested_type()
->get_primitive_type();

// Get the actual type data based on PrimitiveType.
ConstColumnVariant src_data;
RETURN_IF_ERROR(
get_data_from_type(pType, nested_nullable_column.get_nested_column(), src_data));

const auto& src_nullmap = nested_nullable_column.get_null_map_column();

const auto& col_type = assert_cast<const DataTypeArray&>(*arg_type);

// 2. prepare a lambda_block for lambda execution
auto element_size = nested_nullable_column.size();
IColumn::Permutation permutation(element_size);
for (size_t i = 0; i < element_size; ++i) {
permutation[i] = i;
}

/**
* suppose the data_type is nullable(int). The first two rows are the parameter columns, and the
* last row is the result column(type: tinyint). every column's size is 1. the lambda_block is
* row data nullmap type
* 0 10 0 nullable(int)
* 1 20 1 nullable(int)
* 2 1/-1/0 ... tinyint
* The size of a column is always 1; we only need to use it to store the specific values ​​in the array for comparison.
*/
Block lambda_block;
for (int i = 0; i <= 2; i++) {
lambda_block.insert(vectorized::ColumnWithTypeAndName(
nested_nullable_column.clone_empty(), col_type.get_nested_type(), "temp"));
}

MutableColumnPtr temp_data[2];
NullMap* temp_nullmap_data[2] = {nullptr, nullptr};
for (int i = 0; i < 2; i++) {
auto* temp_column = assert_cast<ColumnNullable*>(
lambda_block.get_by_position(i).column->assume_mutable().get());
temp_data[i] = temp_column->get_nested_column_ptr();
auto& null_map_col = assert_cast<ColumnUInt8&>(temp_column->get_null_map_column());
temp_nullmap_data[i] = &null_map_col.get_data();
temp_data[i]->resize(1);
temp_nullmap_data[i]->resize(1);
};

int lambda_res_id = 2;

// 3. sort array by executing lambda function
// During the sorting process, the parameter columns of lambda_block are first populated using prepare_lambda_input,
// and then the lambda function is executed to obtain the result.
std::visit(
[&](auto* data) {
using ColumnType = std::decay_t<decltype(*data)>;
ColumnType* data_vec[2] = {assert_cast<ColumnType*>(temp_data[0].get()),
assert_cast<ColumnType*>(temp_data[1].get())};

// If columnType is ColumnVector<T>, use `get_data()[0]`;
// otherwise, need to clear it first, and then use `insert_from`.
auto prepare_lambda_input = [&](size_t i, size_t cid) {
if (src_nullmap.get_data()[i]) {
(*temp_nullmap_data[cid])[0] = 1;
} else {
(*temp_nullmap_data[cid])[0] = 0;
if constexpr (is_column_vector_v<ColumnType>) {
data_vec[cid]->get_data()[0] = data->get_data()[i];
} else {
data_vec[cid]->clear();
data_vec[cid]->insert_from(*data, i);
}
}
};

for (int row = 0; row < input_rows; ++row) {
auto start = off_data[row - 1];
auto end = off_data[row];
std::sort(&permutation[start], &permutation[end], [&](size_t i, size_t j) {
prepare_lambda_input(i, 0);
prepare_lambda_input(j, 1);
auto status =
children[0]->execute(context, &lambda_block, &lambda_res_id);
if (!status.ok()) [[unlikely]] {
throw Status::InternalError(
"when execute array_sort lambda function: {}",
status.to_string());
}

// raw_res_col maybe columnVector or ColumnConst
ColumnPtr raw_res_col =
lambda_block.get_by_position(lambda_res_id).column;
ColumnPtr full_res_col = raw_res_col->convert_to_full_column_if_const();

// only -1, 0, 1
long cmp = assert_cast<const ColumnInt8*>(full_res_col.get())
->get_data()[0];

return cmp < 0;
});
}
},
src_data);

// 4. set the result to result_column
ColumnWithTypeAndName result_arr;
if (result_type->is_nullable()) {
result_column = ColumnNullable::create(
ColumnArray::create(nested_nullable_column.permute(permutation, 0),
col_array.get_offsets_ptr()),
outside_null_map);

} else {
DCHECK(!column->is_nullable());
result_column = ColumnArray::create(nested_nullable_column.permute(permutation, 0),
col_array.get_offsets_ptr());
}

return Status::OK();
}

#define DISPATCH_PRIMITIVE_TYPE(TYPE, COLUMN_CLASS) \
case TYPE: \
column_variant = &assert_cast<const COLUMN_CLASS&>(column); \
break;

Status get_data_from_type(PrimitiveType pType, const IColumn& column,
ConstColumnVariant& column_variant) const {
switch (pType) {
DISPATCH_PRIMITIVE_TYPE(TYPE_BOOLEAN, ColumnUInt8)
DISPATCH_PRIMITIVE_TYPE(TYPE_TINYINT, ColumnInt8)
DISPATCH_PRIMITIVE_TYPE(TYPE_SMALLINT, ColumnInt16)
DISPATCH_PRIMITIVE_TYPE(TYPE_INT, ColumnInt32)
DISPATCH_PRIMITIVE_TYPE(TYPE_BIGINT, ColumnInt64)
DISPATCH_PRIMITIVE_TYPE(TYPE_LARGEINT, ColumnInt128)
DISPATCH_PRIMITIVE_TYPE(TYPE_FLOAT, ColumnFloat32)
DISPATCH_PRIMITIVE_TYPE(TYPE_DOUBLE, ColumnFloat64)
DISPATCH_PRIMITIVE_TYPE(TYPE_CHAR, ColumnString)
DISPATCH_PRIMITIVE_TYPE(TYPE_STRING, ColumnString)
DISPATCH_PRIMITIVE_TYPE(TYPE_VARCHAR, ColumnString)
DISPATCH_PRIMITIVE_TYPE(TYPE_VARBINARY, ColumnVarbinary)
DISPATCH_PRIMITIVE_TYPE(TYPE_ARRAY, ColumnArray)
DISPATCH_PRIMITIVE_TYPE(TYPE_IPV4, ColumnIPv4)
DISPATCH_PRIMITIVE_TYPE(TYPE_IPV6, ColumnIPv6)
DISPATCH_PRIMITIVE_TYPE(TYPE_DECIMAL32, ColumnDecimal32)
DISPATCH_PRIMITIVE_TYPE(TYPE_DECIMAL64, ColumnDecimal64)
DISPATCH_PRIMITIVE_TYPE(TYPE_DECIMAL128I, ColumnDecimal128V3)
DISPATCH_PRIMITIVE_TYPE(TYPE_DECIMALV2, ColumnDecimal128V2)
DISPATCH_PRIMITIVE_TYPE(TYPE_DECIMAL256, ColumnDecimal256)
DISPATCH_PRIMITIVE_TYPE(TYPE_DATE, ColumnDate)
DISPATCH_PRIMITIVE_TYPE(TYPE_DATETIME, ColumnDateTime)
DISPATCH_PRIMITIVE_TYPE(TYPE_DATEV2, ColumnDateV2)
DISPATCH_PRIMITIVE_TYPE(TYPE_DATETIMEV2, ColumnDateTimeV2)
DISPATCH_PRIMITIVE_TYPE(TYPE_TIME, ColumnTime)
DISPATCH_PRIMITIVE_TYPE(TYPE_TIMEV2, ColumnTimeV2)
default:
return Status::InternalError("Unsupported type in array_sort");
}
return Status::OK();
}

#undef DISPATCH_PRIMITIVE_TYPE
};

void register_function_array_sort(doris::vectorized::LambdaFunctionFactory& factory) {
factory.register_function<ArraySortFunction>();
}

#include "common/compile_check_end.h"
} // namespace doris::vectorized
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
import org.apache.doris.nereids.trees.expressions.functions.combinator.UnionCombinator;
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMap;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySort;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DictGet;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DictGetMany;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt;
Expand Down Expand Up @@ -571,6 +572,61 @@ public Expr visitArrayMap(ArrayMap arrayMap, PlanTranslatorContext context) {
return functionCallExpr;
}

@Override
public Expr visitArraySort(ArraySort arraySort, PlanTranslatorContext context) {
if (!(arraySort.child(0) instanceof Lambda)) {
return visitScalarFunction(arraySort, context);
}
Lambda lambda = (Lambda) arraySort.child(0);
List<Expr> arguments = new ArrayList<>(arraySort.children().size());
arguments.add(null);

// Construct the first column
ArrayItemReference arrayItemReference = lambda.getLambdaArgument(0);
String argName = arrayItemReference.getName();
Expr expr = arrayItemReference.getArrayExpression().accept(this, context);
arguments.add(expr);
ColumnRefExpr column = new ColumnRefExpr();
column.setName(argName);
column.setColumnId(0);
column.setNullable(true);
column.setType(((ArrayType) expr.getType()).getItemType());
context.addExprIdColumnRefPair(arrayItemReference.getExprId(), column);

// the second column here will not be used; it's just a placeholder.
arrayItemReference = lambda.getLambdaArgument(1);
ColumnRefExpr column2 = new ColumnRefExpr(column);
column2.setColumnId(1);
context.addExprIdColumnRefPair(arrayItemReference.getExprId(), column2);

List<Type> argTypes = arraySort.getArguments().stream()
.map(Expression::getDataType)
.map(DataType::toCatalogDataType)
.collect(Collectors.toList());
// two slots are same, we only need one
lambda.getLambdaArguments().stream().skip(1)
.map(ArrayItemReference::getArrayExpression)
.map(Expression::getDataType)
.map(DataType::toCatalogDataType)
.forEach(argTypes::add);
NullableMode nullableMode = arraySort.nullable()
? NullableMode.ALWAYS_NULLABLE
: NullableMode.ALWAYS_NOT_NULLABLE;
Type itemType = ((ArrayType) arguments.get(1).getType()).getItemType();
org.apache.doris.catalog.Function catalogFunction = new Function(
new FunctionName(arraySort.getName()), argTypes,
ArrayType.create(itemType, true),
true, true, nullableMode);

// create catalog FunctionCallExpr without analyze again
Expr lambdaBody = visitLambda(lambda, context);
arguments.set(0, lambdaBody);
LambdaFunctionCallExpr functionCallExpr = new LambdaFunctionCallExpr(catalogFunction,
new FunctionParams(false, arguments));
functionCallExpr.setNullableFromNereids(arraySort.nullable());
return functionCallExpr;
}

@Override
public Expr visitDictGet(DictGet dictGet, PlanTranslatorContext context) {
List<Expr> arguments = dictGet.getArguments().stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ private UnboundFunction processHighOrderFunction(UnboundFunction unboundFunction
// bindLambdaFunction
Lambda lambda = (Lambda) unboundFunction.children().get(0);
Expression lambdaFunction = lambda.getLambdaFunction();
List<ArrayItemReference> arrayItemReferences = lambda.makeArguments(subChildren);
List<ArrayItemReference> arrayItemReferences = lambda.makeArguments(unboundFunction.getName(), subChildren);

List<Slot> boundedSlots = arrayItemReferences.stream()
.map(ArrayItemReference::toSlot)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMatchAll;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMatchAny;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayReverseSplit;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySort;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySortBy;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySplit;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt;
Expand Down Expand Up @@ -247,6 +248,17 @@ public Void visitArrayMap(ArrayMap arrayMap, CollectorContext context) {
return visit(arrayMap, context);
}

@Override
public Void visitArraySort(ArraySort arraySort, CollectorContext context) {
// ARRAY_SORT(lambda, <arr>)

Expression argument = arraySort.getArgument(0);
if ((argument instanceof Lambda)) {
return collectArrayPathInLambda((Lambda) argument, context);
}
return visit(arraySort, context);
}

@Override
public Void visitArrayCount(ArrayCount arrayCount, CollectorContext context) {
// ARRAY_COUNT(<lambda>, <arr>[, ... ])
Expand Down
Loading