diff --git a/include/paimon/bucket/bucket_function_type.h b/include/paimon/bucket/bucket_function_type.h new file mode 100644 index 0000000..7e764ea --- /dev/null +++ b/include/paimon/bucket/bucket_function_type.h @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "paimon/defs.h" +#include "paimon/visibility.h" + +namespace paimon { + +/// Specifies the bucket function type for paimon bucket. +/// This determines how rows are assigned to buckets during data writing. +enum class BucketFunctionType { + /// The default bucket function which will use arithmetic: + /// bucket_id = abs(hash_bucket_binary_row % numBuckets) to get bucket. + DEFAULT = 1, + /// The modulus bucket function which will use modulus arithmetic: + /// bucket_id = floorMod(bucket_key_value, numBuckets) to get bucket. + /// Note: the bucket key must be a single field of INT or BIGINT datatype. + MOD = 2, + /// The hive bucket function which will use hive-compatible hash arithmetic to get bucket. + HIVE = 3 +}; + +/// Describes a field's type information needed for Hive hashing. +struct PAIMON_EXPORT HiveFieldInfo { + FieldType type; + int32_t precision = 0; // Used for DECIMAL type + int32_t scale = 0; // Used for DECIMAL type + + explicit HiveFieldInfo(FieldType t) : type(t) {} + HiveFieldInfo(FieldType t, int32_t p, int32_t s) : type(t), precision(p), scale(s) {} +}; + +} // namespace paimon diff --git a/include/paimon/bucket/bucket_id_calculator.h b/include/paimon/bucket/bucket_id_calculator.h new file mode 100644 index 0000000..5f94573 --- /dev/null +++ b/include/paimon/bucket/bucket_id_calculator.h @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include + +#include "paimon/bucket/bucket_function_type.h" +#include "paimon/memory/memory_pool.h" +#include "paimon/result.h" +#include "paimon/status.h" +#include "paimon/visibility.h" + +struct ArrowSchema; +struct ArrowArray; + +namespace paimon { +class BucketFunction; +class MemoryPool; + +/// Calculator for determining bucket ids based on the given bucket keys. +/// +/// @note `BucketIdCalculator` is compatible with the Java implementation and uses +/// hash-based distribution to ensure even data distribution across buckets. +class PAIMON_EXPORT BucketIdCalculator { + public: + /// Create `BucketIdCalculator` with default bucket function. + /// @param is_pk_table Whether this is for a primary key table. + /// @param num_buckets Number of buckets. + /// @param pool Memory pool for memory allocation. + static Result> Create( + bool is_pk_table, int32_t num_buckets, const std::shared_ptr& pool); + + /// Create `BucketIdCalculator` with a custom bucket function. + /// @param is_pk_table Whether this is for a primary key table. + /// @param num_buckets Number of buckets. + /// @param bucket_function The bucket function to use for bucket assignment. + /// @param pool Memory pool for memory allocation. + static Result> Create( + bool is_pk_table, int32_t num_buckets, std::unique_ptr bucket_function, + const std::shared_ptr& pool); + + /// Create `BucketIdCalculator` with MOD bucket function. + /// @param is_pk_table Whether this is for a primary key table. + /// @param num_buckets Number of buckets. + /// @param bucket_key_type The type of the single bucket key field. Must be INT or BIGINT. + /// @param pool Memory pool for memory allocation. + static Result> CreateMod( + bool is_pk_table, int32_t num_buckets, FieldType bucket_key_type, + const std::shared_ptr& pool); + + /// Create `BucketIdCalculator` with HIVE bucket function. + /// @param is_pk_table Whether this is for a primary key table. + /// @param num_buckets Number of buckets. + /// @param field_infos The detailed type info of all fields in the bucket key row. + /// @param pool Memory pool for memory allocation. + static Result> CreateHive( + bool is_pk_table, int32_t num_buckets, const std::vector& field_infos, + const std::shared_ptr& pool); + + /// Calculate bucket ids for the given bucket keys. + /// @param bucket_keys Arrow struct array containing the bucket key values. + /// @param bucket_schema Arrow schema describing the structure of bucket_keys. + /// @param bucket_ids Output array to store calculated bucket ids. + /// @note 1. bucket_keys is a struct array, the order of fields needs to be consistent with + /// "bucket-key" options in table schema. 2. bucket_keys and bucket_schema match each other. 3. + /// bucket_ids is allocated enough space, at least >= bucket_keys->length + Status CalculateBucketIds(ArrowArray* bucket_keys, ArrowSchema* bucket_schema, + int32_t* bucket_ids) const; + + /// Destructor + ~BucketIdCalculator(); + + private: + BucketIdCalculator(int32_t num_buckets, std::unique_ptr bucket_function, + const std::shared_ptr& pool); + + private: + int32_t num_buckets_; + std::unique_ptr bucket_function_; + std::shared_ptr pool_; +}; +} // namespace paimon diff --git a/src/paimon/core/bucket/bucket_function.h b/src/paimon/core/bucket/bucket_function.h new file mode 100644 index 0000000..f02f469 --- /dev/null +++ b/src/paimon/core/bucket/bucket_function.h @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace paimon { + +class BinaryRow; + +/// Abstract interface for bucket functions. +/// A bucket function determines which bucket a row should be assigned to. +class BucketFunction { + public: + virtual ~BucketFunction() = default; + + /// Compute the bucket for the given row. + /// @param row The binary row to compute the bucket for. + /// @param num_buckets The total number of buckets. + /// @return The bucket index (0-based). + virtual int32_t Bucket(const BinaryRow& row, int32_t num_buckets) const = 0; +}; + +} // namespace paimon diff --git a/src/paimon/core/bucket/bucket_id_calculator.cpp b/src/paimon/core/bucket/bucket_id_calculator.cpp new file mode 100644 index 0000000..6bede5a --- /dev/null +++ b/src/paimon/core/bucket/bucket_id_calculator.cpp @@ -0,0 +1,339 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "paimon/bucket/bucket_id_calculator.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/api.h" +#include "arrow/array/array_base.h" +#include "arrow/array/array_binary.h" +#include "arrow/array/array_decimal.h" +#include "arrow/array/array_nested.h" +#include "arrow/array/array_primitive.h" +#include "arrow/c/abi.h" +#include "arrow/c/bridge.h" +#include "arrow/c/helpers.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/decimal.h" +#include "fmt/format.h" +#include "paimon/common/data/binary_row.h" +#include "paimon/common/data/binary_row_writer.h" +#include "paimon/common/utils/arrow/status_utils.h" +#include "paimon/common/utils/date_time_utils.h" +#include "paimon/common/utils/scope_guard.h" +#include "paimon/core/bucket/bucket_function.h" +#include "paimon/core/bucket/default_bucket_function.h" +#include "paimon/core/bucket/hive_bucket_function.h" +#include "paimon/core/bucket/mod_bucket_function.h" +#include "paimon/data/decimal.h" +#include "paimon/data/timestamp.h" +#include "paimon/memory/memory_pool.h" +#include "paimon/result.h" + +namespace paimon { +namespace { +#define CHECK_AND_SET_NULL(typed_array, row_writer, row_id, col_id) \ + if (typed_array->IsNull(row_id)) { \ + row_writer->SetNullAt(col_id); \ + return; \ + } + +using WriteFunction = std::function; +static Result WriteBucketRow(int32_t col_id, + const std::shared_ptr& field) { + arrow::Type::type type = field->type()->id(); + switch (type) { + case arrow::Type::type::BOOL: { + const auto* typed_array = + arrow::internal::checked_cast(field.get()); + assert(typed_array); + WriteFunction writer_func = [col_id, typed_array](int32_t row_id, + BinaryRowWriter* row_writer) { + CHECK_AND_SET_NULL(typed_array, row_writer, row_id, col_id); + row_writer->WriteBoolean(col_id, typed_array->Value(row_id)); + }; + return writer_func; + } + case arrow::Type::type::INT8: { + const auto* typed_array = + arrow::internal::checked_cast(field.get()); + assert(typed_array); + WriteFunction writer_func = [col_id, typed_array](int32_t row_id, + BinaryRowWriter* row_writer) { + CHECK_AND_SET_NULL(typed_array, row_writer, row_id, col_id); + row_writer->WriteByte(col_id, typed_array->Value(row_id)); + }; + return writer_func; + } + case arrow::Type::type::INT16: { + const auto* typed_array = + arrow::internal::checked_cast(field.get()); + assert(typed_array); + WriteFunction writer_func = [col_id, typed_array](int32_t row_id, + BinaryRowWriter* row_writer) { + CHECK_AND_SET_NULL(typed_array, row_writer, row_id, col_id); + row_writer->WriteShort(col_id, typed_array->Value(row_id)); + }; + return writer_func; + } + case arrow::Type::type::INT32: { + const auto* typed_array = + arrow::internal::checked_cast(field.get()); + assert(typed_array); + WriteFunction writer_func = [col_id, typed_array](int32_t row_id, + BinaryRowWriter* row_writer) { + CHECK_AND_SET_NULL(typed_array, row_writer, row_id, col_id); + row_writer->WriteInt(col_id, typed_array->Value(row_id)); + }; + return writer_func; + } + case arrow::Type::type::INT64: { + const auto* typed_array = + arrow::internal::checked_cast(field.get()); + assert(typed_array); + WriteFunction writer_func = [col_id, typed_array](int32_t row_id, + BinaryRowWriter* row_writer) { + CHECK_AND_SET_NULL(typed_array, row_writer, row_id, col_id); + row_writer->WriteLong(col_id, typed_array->Value(row_id)); + }; + return writer_func; + } + case arrow::Type::type::FLOAT: { + const auto* typed_array = + arrow::internal::checked_cast(field.get()); + assert(typed_array); + WriteFunction writer_func = [col_id, typed_array](int32_t row_id, + BinaryRowWriter* row_writer) { + CHECK_AND_SET_NULL(typed_array, row_writer, row_id, col_id); + row_writer->WriteFloat(col_id, typed_array->Value(row_id)); + }; + return writer_func; + } + case arrow::Type::type::DOUBLE: { + const auto* typed_array = + arrow::internal::checked_cast(field.get()); + assert(typed_array); + WriteFunction writer_func = [col_id, typed_array](int32_t row_id, + BinaryRowWriter* row_writer) { + CHECK_AND_SET_NULL(typed_array, row_writer, row_id, col_id); + row_writer->WriteDouble(col_id, typed_array->Value(row_id)); + }; + return writer_func; + } + case arrow::Type::type::DATE32: { + const auto* typed_array = + arrow::internal::checked_cast(field.get()); + assert(typed_array); + WriteFunction writer_func = [col_id, typed_array](int32_t row_id, + BinaryRowWriter* row_writer) { + CHECK_AND_SET_NULL(typed_array, row_writer, row_id, col_id); + row_writer->WriteInt(col_id, typed_array->Value(row_id)); + }; + return writer_func; + } + case arrow::Type::type::STRING: { + const auto* typed_array = + arrow::internal::checked_cast(field.get()); + assert(typed_array); + WriteFunction writer_func = [col_id, typed_array](int32_t row_id, + BinaryRowWriter* row_writer) { + CHECK_AND_SET_NULL(typed_array, row_writer, row_id, col_id); + std::string_view value = typed_array->GetView(row_id); + row_writer->WriteStringView(col_id, value); + }; + return writer_func; + } + case arrow::Type::type::BINARY: { + const auto* typed_array = + arrow::internal::checked_cast(field.get()); + assert(typed_array); + WriteFunction writer_func = [col_id, typed_array](int32_t row_id, + BinaryRowWriter* row_writer) { + CHECK_AND_SET_NULL(typed_array, row_writer, row_id, col_id); + std::string_view value = typed_array->GetView(row_id); + row_writer->WriteStringView(col_id, value); + }; + return writer_func; + } + case arrow::Type::type::TIMESTAMP: { + auto timestamp_type = + arrow::internal::checked_pointer_cast(field->type()); + assert(timestamp_type); + int32_t precision = DateTimeUtils::GetPrecisionFromType(timestamp_type); + DateTimeUtils::TimeType time_type = + DateTimeUtils::GetTimeTypeFromArrowType(timestamp_type); + const auto* typed_array = + arrow::internal::checked_cast(field.get()); + assert(typed_array); + WriteFunction writer_func = [typed_array, col_id, precision, time_type]( + int32_t row_id, BinaryRowWriter* row_writer) { + if (typed_array->IsNull(row_id)) { + if (!Timestamp::IsCompact(precision)) { + row_writer->WriteTimestamp(col_id, std::nullopt, precision); + } else { + row_writer->SetNullAt(col_id); + } + return; + } + int64_t ts_value = typed_array->Value(row_id); + auto [milli, nano] = DateTimeUtils::TimestampConverter( + ts_value, time_type, DateTimeUtils::TimeType::MILLISECOND, + DateTimeUtils::TimeType::NANOSECOND); + row_writer->WriteTimestamp(col_id, Timestamp(milli, nano), precision); + }; + return writer_func; + } + case arrow::Type::type::DECIMAL: { + const auto* decimal_type = + arrow::internal::checked_cast(field->type().get()); + assert(decimal_type); + auto precision = decimal_type->precision(); + auto scale = decimal_type->scale(); + const auto* typed_array = + arrow::internal::checked_cast(field.get()); + assert(typed_array); + WriteFunction writer_func = [col_id, typed_array, precision, scale]( + int32_t row_id, BinaryRowWriter* row_writer) { + if (typed_array->IsNull(row_id)) { + if (!Decimal::IsCompact(precision)) { + row_writer->WriteDecimal(col_id, std::nullopt, precision); + } else { + row_writer->SetNullAt(col_id); + } + return; + } + arrow::Decimal128 decimal128(typed_array->GetValue(row_id)); + Decimal decimal(precision, scale, + static_cast( + static_cast( + static_cast(decimal128.high_bits())) + << 64 | + decimal128.low_bits())); + row_writer->WriteDecimal(col_id, decimal, precision); + }; + return writer_func; + } + default: + return Status::Invalid( + fmt::format("type {} not support in write bucket row", field->type()->ToString())); + } +} +} // namespace + +BucketIdCalculator::BucketIdCalculator(int32_t num_buckets, + std::unique_ptr bucket_function, + const std::shared_ptr& pool) + : num_buckets_(num_buckets), bucket_function_(std::move(bucket_function)), pool_(pool) {} + +BucketIdCalculator::~BucketIdCalculator() = default; + +Result> BucketIdCalculator::Create( + bool is_pk_table, int32_t num_buckets, const std::shared_ptr& pool) { + return Create(is_pk_table, num_buckets, std::make_unique(), pool); +} + +Result> BucketIdCalculator::Create( + bool is_pk_table, int32_t num_buckets, std::unique_ptr bucket_function, + const std::shared_ptr& pool) { + if (num_buckets == 0 || num_buckets < -2) { + return Status::Invalid("num buckets must be -1 or -2 or greater than 0"); + } + if (is_pk_table && num_buckets == -1) { + return Status::Invalid( + "DynamicBucketMode or CrossPartitionBucketMode cannot calculate bucket id in " + "primary key table"); + } + if (!is_pk_table && num_buckets == -2) { + return Status::Invalid("Append table not support PostponeBucketMode"); + } + return std::unique_ptr( + new BucketIdCalculator(num_buckets, std::move(bucket_function), pool)); +} + +Result> BucketIdCalculator::CreateMod( + bool is_pk_table, int32_t num_buckets, FieldType bucket_key_type, + const std::shared_ptr& pool) { + PAIMON_ASSIGN_OR_RAISE(auto mod_func, ModBucketFunction::Create(bucket_key_type)); + return Create(is_pk_table, num_buckets, std::move(mod_func), pool); +} + +Result> BucketIdCalculator::CreateHive( + bool is_pk_table, int32_t num_buckets, const std::vector& field_infos, + const std::shared_ptr& pool) { + PAIMON_ASSIGN_OR_RAISE(auto hive_func, HiveBucketFunction::Create(field_infos)); + return Create(is_pk_table, num_buckets, std::move(hive_func), pool); +} + +Status BucketIdCalculator::CalculateBucketIds(ArrowArray* bucket_keys, ArrowSchema* bucket_schema, + int32_t* bucket_ids) const { + ScopeGuard guard([bucket_keys, bucket_schema]() { + ArrowArrayRelease(bucket_keys); + ArrowSchemaRelease(bucket_schema); + }); + if (num_buckets_ == -1 || num_buckets_ == 1) { + memset(bucket_ids, 0, bucket_keys->length * sizeof(int32_t)); + return Status::OK(); + } + if (num_buckets_ == -2) { + for (int32_t i = 0; i < bucket_keys->length; i++) { + bucket_ids[i] = -2; + } + return Status::OK(); + } + + PAIMON_ASSIGN_OR_RAISE_FROM_ARROW(std::shared_ptr bucket_array, + arrow::ImportArray(bucket_keys, bucket_schema)); + const auto* struct_array = + arrow::internal::checked_cast(bucket_array.get()); + if (!struct_array) { + return Status::Invalid("bucket keys is not a struct array"); + } + std::vector write_functions; + int32_t num_fields = struct_array->num_fields(); + write_functions.reserve(num_fields); + for (int32_t col = 0; col < num_fields; col++) { + PAIMON_ASSIGN_OR_RAISE(WriteFunction write_func, + WriteBucketRow(col, struct_array->field(col))); + write_functions.push_back(std::move(write_func)); + } + + BinaryRow bucket_row(num_fields); + BinaryRowWriter row_writer(&bucket_row, /*initial_size=*/1024, pool_.get()); + for (int32_t row = 0; row < struct_array->length(); row++) { + row_writer.Reset(); + for (int32_t col = 0; col < num_fields; col++) { + write_functions[col](row, &row_writer); + } + row_writer.Complete(); + bucket_ids[row] = bucket_function_->Bucket(bucket_row, num_buckets_); + } + guard.Release(); + return Status::OK(); +} + +} // namespace paimon diff --git a/src/paimon/core/bucket/bucket_id_calculator_test.cpp b/src/paimon/core/bucket/bucket_id_calculator_test.cpp new file mode 100644 index 0000000..97284a6 --- /dev/null +++ b/src/paimon/core/bucket/bucket_id_calculator_test.cpp @@ -0,0 +1,471 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "paimon/bucket/bucket_id_calculator.h" + +#include +#include +#include + +#include "arrow/api.h" +#include "arrow/array/array_base.h" +#include "arrow/array/array_nested.h" +#include "arrow/array/array_primitive.h" +#include "arrow/c/abi.h" +#include "arrow/c/bridge.h" +#include "arrow/ipc/json_simple.h" +#include "arrow/util/checked_cast.h" +#include "gtest/gtest.h" +#include "paimon/common/utils/arrow/status_utils.h" +#include "paimon/common/utils/date_time_utils.h" +#include "paimon/core/bucket/default_bucket_function.h" +#include "paimon/core/bucket/mod_bucket_function.h" +#include "paimon/fs/local/local_file_system.h" +#include "paimon/testing/utils/testharness.h" + +namespace paimon::test { +class BucketIdCalculatorTest : public ::testing::Test { + public: + void SetUp() override {} + void TearDown() override {} + Result> CalculateBucketIds( + bool is_pk_table, int32_t num_buckets, const std::shared_ptr& bucket_schema, + const std::shared_ptr& bucket_array) const { + ::ArrowArray c_bucket_array; + EXPECT_TRUE(arrow::ExportArray(*bucket_array, &c_bucket_array).ok()); + ::ArrowSchema c_bucket_schema; + EXPECT_TRUE(arrow::ExportSchema(*bucket_schema, &c_bucket_schema).ok()); + std::vector bucket_ids(bucket_array->length()); + EXPECT_OK_AND_ASSIGN(auto bucket_id_cal, BucketIdCalculator::Create( + is_pk_table, num_buckets, GetDefaultPool())); + PAIMON_RETURN_NOT_OK(bucket_id_cal->CalculateBucketIds( + /*bucket_keys=*/&c_bucket_array, /*bucket_schema=*/&c_bucket_schema, + /*bucket_ids=*/bucket_ids.data())); + return bucket_ids; + } + + Result> CalculateBucketIds( + bool is_pk_table, int32_t num_buckets, const std::shared_ptr& bucket_schema, + const std::string& data_str) const { + PAIMON_ASSIGN_OR_RAISE_FROM_ARROW(auto bucket_array, + arrow::ipc::internal::json::ArrayFromJSON( + arrow::struct_(bucket_schema->fields()), data_str)); + return CalculateBucketIds(is_pk_table, num_buckets, bucket_schema, bucket_array); + } + + Result> CalculateBucketIds( + bool is_pk_table, int32_t num_buckets, std::unique_ptr bucket_function, + const std::shared_ptr& bucket_schema, + const std::shared_ptr& bucket_array) const { + ::ArrowArray c_bucket_array; + EXPECT_TRUE(arrow::ExportArray(*bucket_array, &c_bucket_array).ok()); + ::ArrowSchema c_bucket_schema; + EXPECT_TRUE(arrow::ExportSchema(*bucket_schema, &c_bucket_schema).ok()); + std::vector bucket_ids(bucket_array->length()); + EXPECT_OK_AND_ASSIGN(auto bucket_id_cal, BucketIdCalculator::Create( + is_pk_table, num_buckets, + std::move(bucket_function), GetDefaultPool())); + PAIMON_RETURN_NOT_OK(bucket_id_cal->CalculateBucketIds( + /*bucket_keys=*/&c_bucket_array, /*bucket_schema=*/&c_bucket_schema, + /*bucket_ids=*/bucket_ids.data())); + return bucket_ids; + } + + Result> CalculateBucketIds( + bool is_pk_table, int32_t num_buckets, std::unique_ptr bucket_function, + const std::shared_ptr& bucket_schema, const std::string& data_str) const { + PAIMON_ASSIGN_OR_RAISE_FROM_ARROW(auto bucket_array, + arrow::ipc::internal::json::ArrayFromJSON( + arrow::struct_(bucket_schema->fields()), data_str)); + return CalculateBucketIds(is_pk_table, num_buckets, std::move(bucket_function), + bucket_schema, bucket_array); + } +}; + +TEST_F(BucketIdCalculatorTest, TestCompatibleWithJava) { + // 5000 random records, first 12 column is record data, the last column is the bucket id + // calculated by Java FixedBucketRowKeyExtractor + std::string data_path = paimon::test::GetDataDir() + "/record_for_bucket_id.data"; + std::string content; + auto fs = std::make_unique(); + ASSERT_OK(fs->ReadFile(data_path, &content)); + content = content.substr(0, content.length() - 2); + content = "[" + content + "]"; + + arrow::FieldVector raw_bucket_fields = { + arrow::field("v0", arrow::boolean()), + arrow::field("v1", arrow::int8()), + arrow::field("v2", arrow::int16()), + arrow::field("v3", arrow::int32()), + arrow::field("v4", arrow::int64()), + arrow::field("v5", arrow::float32()), + arrow::field("v6", arrow::float64()), + arrow::field("v7", arrow::date32()), + arrow::field("v8", arrow::timestamp(arrow::TimeUnit::NANO)), + arrow::field("v9", arrow::decimal128(30, 20)), + arrow::field("v10", arrow::utf8()), + arrow::field("v11", arrow::binary())}; + auto bucket_schema = arrow::schema(raw_bucket_fields); + + arrow::FieldVector bucket_fields_with_id = bucket_schema->fields(); + bucket_fields_with_id.push_back(arrow::field("bucket_id", arrow::int32())); + auto bucket_array_with_id = std::dynamic_pointer_cast( + arrow::ipc::internal::json::ArrayFromJSON(arrow::struct_(bucket_fields_with_id), content) + .ValueOrDie()); + + // exclude bucket id array + arrow::ArrayVector bucket_fields(bucket_array_with_id->fields().begin(), + bucket_array_with_id->fields().end() - 1); + auto bucket_array = + arrow::StructArray::Make(bucket_fields, bucket_schema->fields()).ValueOrDie(); + + ASSERT_OK_AND_ASSIGN(std::vector result, + CalculateBucketIds(/*is_pk_table=*/true, /*num_buckets=*/12345, + bucket_schema, bucket_array)); + + auto bucket_id_array = arrow::internal::checked_cast( + bucket_array_with_id->field(bucket_schema->num_fields()).get()); + ASSERT_TRUE(bucket_id_array); + // test compatible with java + for (int32_t i = 0; i < bucket_array->length(); i++) { + ASSERT_EQ(bucket_id_array->Value(i), result[i]); + } +} + +TEST_F(BucketIdCalculatorTest, TestCompatibleWithJavaWithNull) { + // 5000 random records, first 13 column is record data, the last column is the bucket id + // calculated by Java FixedBucketRowKeyExtractor. Besides, the first row is all null. + std::string data_path = paimon::test::GetDataDir() + "/record_with_null_for_bucket_id.data"; + std::string content; + auto fs = std::make_unique(); + ASSERT_OK(fs->ReadFile(data_path, &content)); + content = content.substr(0, content.length() - 2); + content = "[" + content + "]"; + + arrow::FieldVector raw_bucket_fields = { + arrow::field("v0", arrow::boolean()), + arrow::field("v1", arrow::int8()), + arrow::field("v2", arrow::int16()), + arrow::field("v3", arrow::int32()), + arrow::field("v4", arrow::int64()), + arrow::field("v5", arrow::float32()), + arrow::field("v6", arrow::float64()), + arrow::field("v7", arrow::date32()), + arrow::field("v8", arrow::timestamp(arrow::TimeUnit::NANO)), + arrow::field("v9", arrow::decimal128(5, 2)), + arrow::field("v10", arrow::decimal128(30, 2)), + arrow::field("v11", arrow::utf8()), + arrow::field("v12", arrow::binary())}; + auto bucket_schema = arrow::schema(raw_bucket_fields); + + arrow::FieldVector bucket_fields_with_id = bucket_schema->fields(); + bucket_fields_with_id.push_back(arrow::field("bucket_id", arrow::int32())); + auto bucket_array_with_id = std::dynamic_pointer_cast( + arrow::ipc::internal::json::ArrayFromJSON(arrow::struct_(bucket_fields_with_id), content) + .ValueOrDie()); + + // exclude bucket id array + arrow::ArrayVector bucket_fields(bucket_array_with_id->fields().begin(), + bucket_array_with_id->fields().end() - 1); + auto bucket_array = + arrow::StructArray::Make(bucket_fields, bucket_schema->fields()).ValueOrDie(); + + ASSERT_OK_AND_ASSIGN(std::vector result, + CalculateBucketIds(/*is_pk_table=*/false, /*num_buckets=*/12345, + bucket_schema, bucket_array)); + + auto bucket_id_array = arrow::internal::checked_cast( + bucket_array_with_id->field(bucket_schema->num_fields()).get()); + ASSERT_TRUE(bucket_id_array); + // test compatible with java + for (int32_t i = 0; i < bucket_array->length(); i++) { + ASSERT_EQ(bucket_id_array->Value(i), result[i]); + } +} + +TEST_F(BucketIdCalculatorTest, TestCompatibleWithJavaWithTimestamp) { + // 5000 random records, first 8 column is record data, the last column is the bucket id + // calculated by Java FixedBucketRowKeyExtractor. Besides, the first row is all null. + std::string data_path = paimon::test::GetDataDir() + "record_with_timestamp_for_bucket_id.data"; + std::string content; + auto fs = std::make_unique(); + ASSERT_OK(fs->ReadFile(data_path, &content)); + content = content.substr(0, content.length() - 2); + content = "[" + content + "]"; + auto timezone = DateTimeUtils::GetLocalTimezoneName(); + arrow::FieldVector raw_bucket_fields = { + arrow::field("ts_sec", arrow::timestamp(arrow::TimeUnit::SECOND)), + arrow::field("ts_milli", arrow::timestamp(arrow::TimeUnit::MILLI)), + arrow::field("ts_micro", arrow::timestamp(arrow::TimeUnit::MICRO)), + arrow::field("ts_nano", arrow::timestamp(arrow::TimeUnit::NANO)), + arrow::field("ts_tz_sec", arrow::timestamp(arrow::TimeUnit::SECOND, timezone)), + arrow::field("ts_tz_milli", arrow::timestamp(arrow::TimeUnit::MILLI, timezone)), + arrow::field("ts_tz_micro", arrow::timestamp(arrow::TimeUnit::MICRO, timezone)), + arrow::field("ts_tz_nano", arrow::timestamp(arrow::TimeUnit::NANO, timezone)), + }; + auto bucket_schema = std::make_shared(raw_bucket_fields); + + arrow::FieldVector bucket_fields_with_id = bucket_schema->fields(); + bucket_fields_with_id.push_back(arrow::field("bucket_id", arrow::int32())); + auto bucket_array_with_id = std::dynamic_pointer_cast( + arrow::ipc::internal::json::ArrayFromJSON(arrow::struct_(bucket_fields_with_id), content) + .ValueOrDie()); + + // exclude bucket id array + arrow::ArrayVector bucket_fields(bucket_array_with_id->fields().begin(), + bucket_array_with_id->fields().end() - 1); + auto bucket_array = + arrow::StructArray::Make(bucket_fields, bucket_schema->fields()).ValueOrDie(); + + ASSERT_OK_AND_ASSIGN(std::vector result, + CalculateBucketIds(/*is_pk_table=*/false, /*num_buckets=*/12345, + bucket_schema, bucket_array)); + + auto bucket_id_array = arrow::internal::checked_cast( + bucket_array_with_id->field(bucket_schema->num_fields()).get()); + ASSERT_TRUE(bucket_id_array); + // test compatible with java + for (int32_t i = 0; i < bucket_array->length(); i++) { + ASSERT_EQ(bucket_id_array->Value(i), result[i]); + } +} + +TEST_F(BucketIdCalculatorTest, TestInvalidCase) { + { + // test invalid bucket id + ASSERT_NOK_WITH_MSG( + BucketIdCalculator::Create(/*is_pk_table=*/true, /*num_buckets=*/0, GetDefaultPool()), + "num buckets must be -1 or -2 or greater than 0"); + } + { + // test invalid bucket mode with pk table + ASSERT_NOK_WITH_MSG( + BucketIdCalculator::Create(/*is_pk_table=*/true, /*num_buckets=*/-1, GetDefaultPool()), + "DynamicBucketMode or CrossPartitionBucketMode cannot calculate bucket id"); + } + { + // test invalid bucket mode with append table + ASSERT_NOK_WITH_MSG( + BucketIdCalculator::Create(/*is_pk_table=*/false, /*num_buckets=*/-2, GetDefaultPool()), + "Append table not support PostponeBucketMode"); + } + { + // test invalid bucket_keys + auto bucket_schema = + arrow::schema(arrow::FieldVector({arrow::field("b0", arrow::int32())})); + auto bucket_array = + arrow::ipc::internal::json::ArrayFromJSON(arrow::int32(), "[10, 11, 12, 13]") + .ValueOrDie(); + ASSERT_NOK_WITH_MSG( + CalculateBucketIds(/*is_pk_table=*/false, 10, bucket_schema, bucket_array), + "ArrowArray struct has 0 children"); + } + { + // test invalid data type + auto bucket_schema = arrow::schema(arrow::FieldVector( + {arrow::field("b0", arrow::int32()), arrow::field("b1", arrow::list(arrow::int64()))})); + ASSERT_NOK_WITH_MSG( + CalculateBucketIds(/*is_pk_table=*/true, 10, bucket_schema, "[[10, [1, 1, 2]]]"), + "type list not support in write bucket row"); + } +} + +TEST_F(BucketIdCalculatorTest, TestUnawareBucket) { + auto bucket_schema = arrow::schema(arrow::FieldVector({arrow::field("b0", arrow::int32())})); + ASSERT_OK_AND_ASSIGN( + std::vector result, + CalculateBucketIds(/*is_pk_table=*/false, -1, bucket_schema, "[[10], [-1], [50]]")); + std::vector expected = {0, 0, 0}; + ASSERT_EQ(expected, result); +} + +TEST_F(BucketIdCalculatorTest, TestPostponeBucket) { + auto bucket_schema = arrow::schema(arrow::FieldVector({arrow::field("b0", arrow::int32())})); + ASSERT_OK_AND_ASSIGN( + std::vector result, + CalculateBucketIds(/*is_pk_table=*/true, -2, bucket_schema, "[[10], [-1], [50]]")); + std::vector expected = {-2, -2, -2}; + ASSERT_EQ(expected, result); +} + +TEST_F(BucketIdCalculatorTest, TestVariantType) { + arrow::FieldVector raw_bucket_fields = { + arrow::field("v0", arrow::boolean()), + arrow::field("v1", arrow::int8()), + arrow::field("v2", arrow::int16()), + arrow::field("v3", arrow::int32()), + arrow::field("v4", arrow::int64()), + arrow::field("v5", arrow::float32()), + arrow::field("v6", arrow::float64()), + arrow::field("v7", arrow::date32()), + arrow::field("v8", arrow::timestamp(arrow::TimeUnit::NANO)), + arrow::field("v9", arrow::decimal128(30, 20)), + arrow::field("v10", arrow::utf8()), + arrow::field("v11", arrow::binary())}; + auto bucket_schema = arrow::schema(raw_bucket_fields); + + auto bucket_array = std::dynamic_pointer_cast( + arrow::ipc::internal::json::ArrayFromJSON(arrow::struct_(bucket_schema->fields()), R"([ + [true, 10, 200, 65536, 123456789, 0.0, 0.0, 2000, -86399999999500, "2134.48690000000000000009", "olá mundo,你好世界。Two roads diverged in a wood, and I took the one less traveled by, And that has made all the difference.", "Alice"], + [false, -128, -32768, -2147483648, -9223372036854775808, -3.4028235E38, -1.7976931348623157E308, -719528, -9223372036854775808, "-999999999999999999.99999999999999999999", "Alice", "olá mundo,你好世界。Two roads diverged in a wood, and I took the one less traveled by, And that has made all the difference."], + [true, 127, 32767, 2147483647, 9223372036854775807, 3.4028235E38, 1.7976931348623157E308, 2932896, 9223372036854775807, "999999999999999999.99999999999999999999", "Alice", "olá mundo,你好世界。Two roads diverged in a wood, and I took the one less traveled by, And that has made all the difference."], + [true, 0, 0, 0, 0, 1.4E-45, 4.9E-324, 0, 0, "0.00000000000000000000", "Alice", "olá mundo,你好世界。Two roads diverged in a wood, and I took the one less traveled by, And that has made all the difference."] +])") + .ValueOrDie()); + ASSERT_OK_AND_ASSIGN( + std::vector result, + CalculateBucketIds(/*is_pk_table=*/true, 12345, bucket_schema, bucket_array)); + std::vector expected = {11275, 12272, 6549, 11795}; + ASSERT_EQ(expected, result); + // test calculate multiple times + ASSERT_OK_AND_ASSIGN( + std::vector result2, + CalculateBucketIds(/*is_pk_table=*/true, 12345, bucket_schema, bucket_array)); + ASSERT_EQ(expected, result2); +} + +TEST_F(BucketIdCalculatorTest, TestWithModBucketFunction) { + auto bucket_schema = arrow::schema(arrow::FieldVector({arrow::field("b0", arrow::int32())})); + ASSERT_OK_AND_ASSIGN(auto mod_func, ModBucketFunction::Create(FieldType::INT)); + ASSERT_OK_AND_ASSIGN( + std::vector result, + CalculateBucketIds(/*is_pk_table=*/true, /*num_buckets=*/10, std::move(mod_func), + bucket_schema, "[[10], [-1], [50], [-13], [0]]")); + // Java Math.floorMod semantics: + // floorMod(10, 10) = 0 + // floorMod(-1, 10) = 9 + // floorMod(50, 10) = 0 + // floorMod(-13, 10) = 7 + // floorMod(0, 10) = 0 + std::vector expected = {0, 9, 0, 7, 0}; + ASSERT_EQ(expected, result); +} + +TEST_F(BucketIdCalculatorTest, TestWithDefaultBucketFunctionExplicit) { + auto bucket_schema = arrow::schema(arrow::FieldVector({arrow::field("b0", arrow::int32())})); + // Calculate with explicit DefaultBucketFunction + auto default_func = std::make_unique(); + ASSERT_OK_AND_ASSIGN( + std::vector result_explicit, + CalculateBucketIds(/*is_pk_table=*/true, /*num_buckets=*/10, std::move(default_func), + bucket_schema, "[[10], [-1], [50], [-13], [0]]")); + // Calculate with default (no BucketFunction passed) + ASSERT_OK_AND_ASSIGN(std::vector result_default, + CalculateBucketIds(/*is_pk_table=*/true, /*num_buckets=*/10, bucket_schema, + "[[10], [-1], [50], [-13], [0]]")); + ASSERT_EQ(result_default, result_explicit); +} + +TEST_F(BucketIdCalculatorTest, TestCreateWithDefaultBucketFunction) { + auto bucket_schema = arrow::schema(arrow::FieldVector({arrow::field("b0", arrow::int32())})); + std::string data_str = "[[10], [-1], [50], [-13], [0]]"; + + // Calculate with explicit DefaultBucketFunction via Create + auto default_func = std::make_unique(); + ASSERT_OK_AND_ASSIGN(auto calc_explicit, + BucketIdCalculator::Create(/*is_pk_table=*/true, /*num_buckets=*/10, + std::move(default_func), GetDefaultPool())); + + // Calculate with the default Create (no BucketFunction) + ASSERT_OK_AND_ASSIGN( + auto calc_default, + BucketIdCalculator::Create(/*is_pk_table=*/true, /*num_buckets=*/10, GetDefaultPool())); + + auto bucket_array1 = + arrow::ipc::internal::json::ArrayFromJSON(arrow::struct_(bucket_schema->fields()), data_str) + .ValueOrDie(); + ::ArrowArray c_array1; + EXPECT_TRUE(arrow::ExportArray(*bucket_array1, &c_array1).ok()); + ::ArrowSchema c_schema1; + EXPECT_TRUE(arrow::ExportSchema(*bucket_schema, &c_schema1).ok()); + std::vector result_explicit(bucket_array1->length()); + ASSERT_OK(calc_explicit->CalculateBucketIds(&c_array1, &c_schema1, result_explicit.data())); + + auto bucket_array2 = + arrow::ipc::internal::json::ArrayFromJSON(arrow::struct_(bucket_schema->fields()), data_str) + .ValueOrDie(); + ::ArrowArray c_array2; + EXPECT_TRUE(arrow::ExportArray(*bucket_array2, &c_array2).ok()); + ::ArrowSchema c_schema2; + EXPECT_TRUE(arrow::ExportSchema(*bucket_schema, &c_schema2).ok()); + std::vector result_default(bucket_array2->length()); + ASSERT_OK(calc_default->CalculateBucketIds(&c_array2, &c_schema2, result_default.data())); + + ASSERT_EQ(result_default, result_explicit); +} + +TEST_F(BucketIdCalculatorTest, TestCreateWithModBucketFunction) { + auto bucket_schema = arrow::schema(arrow::FieldVector({arrow::field("b0", arrow::int32())})); + std::string data_str = "[[10], [-1], [50], [-13], [0]]"; + + // Calculate with CreateMod + ASSERT_OK_AND_ASSIGN(auto calc_mod, + BucketIdCalculator::CreateMod(/*is_pk_table=*/true, /*num_buckets=*/10, + FieldType::INT, GetDefaultPool())); + + // Calculate with explicit ModBucketFunction + ASSERT_OK_AND_ASSIGN(auto mod_func, ModBucketFunction::Create(FieldType::INT)); + ASSERT_OK_AND_ASSIGN(std::vector result_explicit, + CalculateBucketIds(/*is_pk_table=*/true, /*num_buckets=*/10, + std::move(mod_func), bucket_schema, data_str)); + + auto bucket_array = + arrow::ipc::internal::json::ArrayFromJSON(arrow::struct_(bucket_schema->fields()), data_str) + .ValueOrDie(); + ::ArrowArray c_array; + EXPECT_TRUE(arrow::ExportArray(*bucket_array, &c_array).ok()); + ::ArrowSchema c_schema; + EXPECT_TRUE(arrow::ExportSchema(*bucket_schema, &c_schema).ok()); + std::vector result_mod(bucket_array->length()); + ASSERT_OK(calc_mod->CalculateBucketIds(&c_array, &c_schema, result_mod.data())); + + ASSERT_EQ(result_explicit, result_mod); + // Verify expected values (Java Math.floorMod semantics) + std::vector expected = {0, 9, 0, 7, 0}; + ASSERT_EQ(expected, result_mod); +} + +TEST_F(BucketIdCalculatorTest, TestCreateWithHiveBucketFunction) { + auto bucket_schema = arrow::schema(arrow::FieldVector({arrow::field("b0", arrow::int32())})); + std::string data_str = "[[42], [0], [100]]"; + + std::vector field_infos = {HiveFieldInfo(FieldType::INT)}; + + // Calculate with CreateHive + ASSERT_OK_AND_ASSIGN(auto calc_hive, + BucketIdCalculator::CreateHive(/*is_pk_table=*/true, /*num_buckets=*/5, + field_infos, GetDefaultPool())); + + auto bucket_array = + arrow::ipc::internal::json::ArrayFromJSON(arrow::struct_(bucket_schema->fields()), data_str) + .ValueOrDie(); + ::ArrowArray c_array; + EXPECT_TRUE(arrow::ExportArray(*bucket_array, &c_array).ok()); + ::ArrowSchema c_schema; + EXPECT_TRUE(arrow::ExportSchema(*bucket_schema, &c_schema).ok()); + std::vector result(bucket_array->length()); + ASSERT_OK(calc_hive->CalculateBucketIds(&c_array, &c_schema, result.data())); + + // Verify all bucket ids are in valid range + for (auto bucket_id : result) { + ASSERT_GE(bucket_id, 0); + ASSERT_LT(bucket_id, 5); + } +} + +} // namespace paimon::test diff --git a/src/paimon/core/bucket/bucket_select_converter.cpp b/src/paimon/core/bucket/bucket_select_converter.cpp new file mode 100644 index 0000000..01bf860 --- /dev/null +++ b/src/paimon/core/bucket/bucket_select_converter.cpp @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "paimon/core/bucket/bucket_select_converter.h" + +#include +#include +#include +#include + +#include "arrow/type.h" +#include "arrow/util/checked_cast.h" +#include "fmt/format.h" +#include "paimon/common/data/binary_row.h" +#include "paimon/common/data/binary_row_writer.h" +#include "paimon/common/utils/date_time_utils.h" +#include "paimon/common/utils/field_type_utils.h" +#include "paimon/core/bucket/bucket_function.h" +#include "paimon/core/bucket/default_bucket_function.h" +#include "paimon/core/bucket/hive_bucket_function.h" +#include "paimon/core/bucket/mod_bucket_function.h" +#include "paimon/data/timestamp.h" +#include "paimon/memory/memory_pool.h" +#include "paimon/predicate/leaf_predicate.h" +#include "paimon/predicate/predicate.h" +#include "paimon/predicate/predicate_utils.h" + +namespace paimon { + +Result> BucketSelectConverter::Convert( + const std::shared_ptr& predicate, const std::vector& bucket_key_names, + const std::vector>& bucket_key_arrow_types, + BucketFunctionType bucket_function_type, int32_t num_buckets, MemoryPool* pool) { + assert(pool); + if (!predicate || bucket_key_names.empty() || num_buckets <= 0) { + return std::optional(std::nullopt); + } + + if (bucket_key_names.size() != bucket_key_arrow_types.size()) { + return Status::Invalid( + "bucket_key_names and bucket_key_arrow_types must have the same size"); + } + + // Derive FieldTypes from Arrow types + std::vector bucket_key_types; + bucket_key_types.reserve(bucket_key_arrow_types.size()); + for (const auto& arrow_type : bucket_key_arrow_types) { + PAIMON_ASSIGN_OR_RAISE(FieldType ft, FieldTypeUtils::ConvertToFieldType(arrow_type->id())); + bucket_key_types.push_back(ft); + } + + auto literals_opt = ExtractEqualLiterals(predicate, bucket_key_names); + if (!literals_opt.has_value()) { + return std::optional(std::nullopt); + } + + const auto& literals_map = literals_opt.value(); + auto num_fields = static_cast(bucket_key_names.size()); + + // Build a BinaryRow from the extracted literals + BinaryRow row(num_fields); + BinaryRowWriter writer(&row, /*initial_size=*/1024, pool); + writer.Reset(); + + for (int32_t i = 0; i < num_fields; i++) { + const auto& field_name = bucket_key_names[i]; + const auto& literal = literals_map.at(field_name); + PAIMON_RETURN_NOT_OK( + WriteLiteralToRow(i, literal, bucket_key_types[i], bucket_key_arrow_types[i], &writer)); + } + writer.Complete(); + + // Create the bucket function and compute the bucket + PAIMON_ASSIGN_OR_RAISE( + std::unique_ptr bucket_function, + CreateBucketFunction(bucket_function_type, bucket_key_types, bucket_key_arrow_types)); + int32_t bucket = bucket_function->Bucket(row, num_buckets); + return std::optional(bucket); +} + +std::optional> BucketSelectConverter::ExtractEqualLiterals( + const std::shared_ptr& predicate, const std::vector& bucket_key_names) { + std::set key_set(bucket_key_names.begin(), bucket_key_names.end()); + std::map result; + + auto splits = PredicateUtils::SplitAnd(predicate); + for (const auto& split : splits) { + auto* leaf = dynamic_cast(split.get()); + if (!leaf) { + continue; + } + // TODO(liangjie.liang): Support IN and OR predicates to enable multi-bucket pruning + if (leaf->GetFunction().GetType() != Function::Type::EQUAL) { + continue; + } + const auto& field_name = leaf->FieldName(); + if (key_set.find(field_name) == key_set.end()) { + continue; + } + const auto& literals = leaf->Literals(); + if (literals.size() != 1 || literals[0].IsNull()) { + continue; + } + // Only record the first EQUAL for each field + if (result.find(field_name) == result.end()) { + result.emplace(field_name, literals[0]); + } + } + + // Check all bucket key fields are covered + for (const auto& key_name : bucket_key_names) { + if (result.find(key_name) == result.end()) { + return std::nullopt; + } + } + return result; +} + +Status BucketSelectConverter::WriteLiteralToRow(int32_t pos, const Literal& literal, + FieldType field_type, + const std::shared_ptr& arrow_type, + BinaryRowWriter* writer) { + switch (field_type) { + case FieldType::BOOLEAN: + writer->WriteBoolean(pos, literal.GetValue()); + break; + case FieldType::TINYINT: + writer->WriteByte(pos, literal.GetValue()); + break; + case FieldType::SMALLINT: + writer->WriteShort(pos, literal.GetValue()); + break; + case FieldType::INT: + case FieldType::DATE: + writer->WriteInt(pos, literal.GetValue()); + break; + case FieldType::BIGINT: + writer->WriteLong(pos, literal.GetValue()); + break; + case FieldType::FLOAT: + writer->WriteFloat(pos, literal.GetValue()); + break; + case FieldType::DOUBLE: + writer->WriteDouble(pos, literal.GetValue()); + break; + case FieldType::STRING: + case FieldType::BINARY: { + auto value = literal.GetValue(); + writer->WriteStringView(pos, std::string_view{value}); + break; + } + case FieldType::TIMESTAMP: { + auto ts = literal.GetValue(); + auto timestamp_type = + arrow::internal::checked_pointer_cast(arrow_type); + int32_t precision = DateTimeUtils::GetPrecisionFromType(timestamp_type); + writer->WriteTimestamp(pos, ts, precision); + break; + } + case FieldType::DECIMAL: { + auto dec = literal.GetValue(); + const auto* decimal_type = + arrow::internal::checked_cast(arrow_type.get()); + int32_t precision = decimal_type->precision(); + writer->WriteDecimal(pos, dec, precision); + break; + } + default: + return Status::Invalid( + fmt::format("unsupported field type {} for bucket select conversion", + static_cast(field_type))); + } + return Status::OK(); +} + +Result> BucketSelectConverter::CreateBucketFunction( + BucketFunctionType type, const std::vector& bucket_key_types, + const std::vector>& bucket_key_arrow_types) { + switch (type) { + case BucketFunctionType::DEFAULT: + return std::unique_ptr(std::make_unique()); + case BucketFunctionType::MOD: { + if (bucket_key_types.size() != 1) { + return Status::Invalid("MOD bucket function requires exactly one bucket key field"); + } + return ModBucketFunction::Create(bucket_key_types[0]); + } + case BucketFunctionType::HIVE: { + std::vector field_infos; + field_infos.reserve(bucket_key_types.size()); + for (size_t i = 0; i < bucket_key_types.size(); i++) { + if (bucket_key_types[i] == FieldType::DECIMAL) { + const auto* decimal_type = + arrow::internal::checked_cast( + bucket_key_arrow_types[i].get()); + field_infos.emplace_back(bucket_key_types[i], decimal_type->precision(), + decimal_type->scale()); + } else { + field_infos.emplace_back(bucket_key_types[i]); + } + } + return HiveBucketFunction::Create(field_infos); + } + default: + return Status::Invalid( + fmt::format("unsupported bucket function type: {}", static_cast(type))); + } +} + +} // namespace paimon diff --git a/src/paimon/core/bucket/bucket_select_converter.h b/src/paimon/core/bucket/bucket_select_converter.h new file mode 100644 index 0000000..25eb31a --- /dev/null +++ b/src/paimon/core/bucket/bucket_select_converter.h @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "arrow/type_fwd.h" +#include "paimon/bucket/bucket_function_type.h" +#include "paimon/defs.h" +#include "paimon/predicate/literal.h" +#include "paimon/result.h" + +namespace paimon { + +class BinaryRowWriter; +class BucketFunction; +class MemoryPool; +class Predicate; + +/// Converts predicates on bucket key fields to a target bucket ID. +/// When all bucket key fields have EQUAL predicates, the converter computes +/// which bucket the data must reside in, enabling bucket pruning during scan. +class BucketSelectConverter { + public: + BucketSelectConverter() = delete; + ~BucketSelectConverter() = delete; + + /// Convert predicates to a target bucket ID. + /// @param predicate The predicate (possibly compound AND) to analyze. + /// @param bucket_key_names Ordered bucket key field names. + /// @param bucket_key_arrow_types Ordered Arrow data types for bucket key fields. + /// FieldType is derived from these automatically. + /// @param bucket_function_type The bucket function type (DEFAULT, MOD, HIVE). + /// @param num_buckets The total number of buckets. + /// @param pool Memory pool for BinaryRow construction. + /// @return The target bucket ID, or nullopt if predicates don't fully constrain all bucket + /// keys. + static Result> Convert( + const std::shared_ptr& predicate, + const std::vector& bucket_key_names, + const std::vector>& bucket_key_arrow_types, + BucketFunctionType bucket_function_type, int32_t num_buckets, MemoryPool* pool); + + private: + /// Extract single literal per bucket key field from EQUAL predicates. + /// Splits the predicate by AND and looks for EQUAL leaf predicates on bucket key fields. + /// @return A map from field name to literal, or nullopt if not all bucket keys are constrained. + static std::optional> ExtractEqualLiterals( + const std::shared_ptr& predicate, + const std::vector& bucket_key_names); + + /// Write a Literal value to a BinaryRowWriter at the given position. + static Status WriteLiteralToRow(int32_t pos, const Literal& literal, FieldType field_type, + const std::shared_ptr& arrow_type, + BinaryRowWriter* writer); + + /// Create the appropriate BucketFunction for the given type. + static Result> CreateBucketFunction( + BucketFunctionType type, const std::vector& bucket_key_types, + const std::vector>& bucket_key_arrow_types); +}; + +} // namespace paimon diff --git a/src/paimon/core/bucket/bucket_select_converter_test.cpp b/src/paimon/core/bucket/bucket_select_converter_test.cpp new file mode 100644 index 0000000..94c2f60 --- /dev/null +++ b/src/paimon/core/bucket/bucket_select_converter_test.cpp @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "paimon/core/bucket/bucket_select_converter.h" + +#include +#include + +#include "arrow/api.h" +#include "gtest/gtest.h" +#include "paimon/core/bucket/default_bucket_function.h" +#include "paimon/core/bucket/mod_bucket_function.h" +#include "paimon/data/decimal.h" +#include "paimon/data/timestamp.h" +#include "paimon/memory/memory_pool.h" +#include "paimon/predicate/literal.h" +#include "paimon/predicate/predicate_builder.h" +#include "paimon/testing/utils/binary_row_generator.h" +#include "paimon/testing/utils/testharness.h" + +namespace paimon::test { + +class BucketSelectConverterTest : public ::testing::Test { + protected: + std::shared_ptr pool_ = GetDefaultPool(); +}; + +TEST_F(BucketSelectConverterTest, SingleIntEqualDefault) { + int32_t num_buckets = 10; + Literal lit(static_cast(42)); + auto predicate = PredicateBuilder::Equal(0, "id", FieldType::INT, lit); + + ASSERT_OK_AND_ASSIGN(auto result, BucketSelectConverter::Convert( + predicate, {"id"}, {arrow::int32()}, + BucketFunctionType::DEFAULT, num_buckets, pool_.get())); + ASSERT_TRUE(result.has_value()); + + // Verify by computing the expected bucket manually + auto row = BinaryRowGenerator::GenerateRow({static_cast(42)}, pool_.get()); + DefaultBucketFunction func; + ASSERT_EQ(func.Bucket(row, num_buckets), result.value()); +} + +TEST_F(BucketSelectConverterTest, SingleStringEqualDefault) { + int32_t num_buckets = 8; + std::string val = "hello_world"; + Literal lit(FieldType::STRING, val.c_str(), val.size()); + auto predicate = PredicateBuilder::Equal(0, "name", FieldType::STRING, lit); + + ASSERT_OK_AND_ASSIGN(auto result, BucketSelectConverter::Convert( + predicate, {"name"}, {arrow::utf8()}, + BucketFunctionType::DEFAULT, num_buckets, pool_.get())); + ASSERT_TRUE(result.has_value()); + + // Verify + auto row = BinaryRowGenerator::GenerateRow({val}, pool_.get()); + DefaultBucketFunction func; + ASSERT_EQ(func.Bucket(row, num_buckets), result.value()); +} + +TEST_F(BucketSelectConverterTest, MultiKeyAndPredicate) { + int32_t num_buckets = 5; + Literal lit_id(static_cast(100)); + Literal lit_name(FieldType::STRING, "test", 4); + auto pred_id = PredicateBuilder::Equal(0, "id", FieldType::INT, lit_id); + auto pred_name = PredicateBuilder::Equal(1, "name", FieldType::STRING, lit_name); + ASSERT_OK_AND_ASSIGN(auto predicate, PredicateBuilder::And({pred_id, pred_name})); + + ASSERT_OK_AND_ASSIGN( + auto result, + BucketSelectConverter::Convert(predicate, {"id", "name"}, {arrow::int32(), arrow::utf8()}, + BucketFunctionType::DEFAULT, num_buckets, pool_.get())); + ASSERT_TRUE(result.has_value()); + + // Verify + auto row = BinaryRowGenerator::GenerateRow({static_cast(100), std::string("test")}, + pool_.get()); + DefaultBucketFunction func; + ASSERT_EQ(func.Bucket(row, num_buckets), result.value()); +} + +TEST_F(BucketSelectConverterTest, MissingBucketKeyReturnsNullopt) { + int32_t num_buckets = 5; + Literal lit(static_cast(42)); + auto predicate = PredicateBuilder::Equal(0, "id", FieldType::INT, lit); + + ASSERT_OK_AND_ASSIGN( + auto result, + BucketSelectConverter::Convert(predicate, {"id", "name"}, {arrow::int32(), arrow::utf8()}, + BucketFunctionType::DEFAULT, num_buckets, pool_.get())); + ASSERT_FALSE(result.has_value()); +} + +TEST_F(BucketSelectConverterTest, NonEqualPredicateReturnsNullopt) { + int32_t num_buckets = 5; + Literal lit(static_cast(42)); + auto predicate = PredicateBuilder::GreaterThan(0, "id", FieldType::INT, lit); + + ASSERT_OK_AND_ASSIGN(auto result, BucketSelectConverter::Convert( + predicate, {"id"}, {arrow::int32()}, + BucketFunctionType::DEFAULT, num_buckets, pool_.get())); + ASSERT_FALSE(result.has_value()); +} + +TEST_F(BucketSelectConverterTest, OrPredicateReturnsNullopt) { + int32_t num_buckets = 5; + Literal lit1(static_cast(1)); + Literal lit2(static_cast(2)); + auto pred1 = PredicateBuilder::Equal(0, "id", FieldType::INT, lit1); + auto pred2 = PredicateBuilder::Equal(0, "id", FieldType::INT, lit2); + ASSERT_OK_AND_ASSIGN(auto predicate, PredicateBuilder::Or({pred1, pred2})); + + ASSERT_OK_AND_ASSIGN(auto result, BucketSelectConverter::Convert( + predicate, {"id"}, {arrow::int32()}, + BucketFunctionType::DEFAULT, num_buckets, pool_.get())); + ASSERT_FALSE(result.has_value()); +} + +TEST_F(BucketSelectConverterTest, ModBucketFunction) { + int32_t num_buckets = 7; + Literal lit(static_cast(42)); + auto predicate = PredicateBuilder::Equal(0, "id", FieldType::INT, lit); + + ASSERT_OK_AND_ASSIGN(auto result, BucketSelectConverter::Convert( + predicate, {"id"}, {arrow::int32()}, + BucketFunctionType::MOD, num_buckets, pool_.get())); + ASSERT_TRUE(result.has_value()); + + // Verify: MOD function uses floorMod + auto row = BinaryRowGenerator::GenerateRow({static_cast(42)}, pool_.get()); + ASSERT_OK_AND_ASSIGN(auto mod_func, ModBucketFunction::Create(FieldType::INT)); + ASSERT_EQ(mod_func->Bucket(row, num_buckets), result.value()); +} + +TEST_F(BucketSelectConverterTest, NullLiteralReturnsNullopt) { + int32_t num_buckets = 5; + Literal lit(FieldType::INT); // null literal + auto predicate = PredicateBuilder::Equal(0, "id", FieldType::INT, lit); + + ASSERT_OK_AND_ASSIGN(auto result, BucketSelectConverter::Convert( + predicate, {"id"}, {arrow::int32()}, + BucketFunctionType::DEFAULT, num_buckets, pool_.get())); + ASSERT_FALSE(result.has_value()); +} + +TEST_F(BucketSelectConverterTest, DynamicBucketModeReturnsNullopt) { + Literal lit(static_cast(42)); + auto predicate = PredicateBuilder::Equal(0, "id", FieldType::INT, lit); + + ASSERT_OK_AND_ASSIGN( + auto result, BucketSelectConverter::Convert(predicate, {"id"}, {arrow::int32()}, + BucketFunctionType::DEFAULT, -1, pool_.get())); + ASSERT_FALSE(result.has_value()); +} + +TEST_F(BucketSelectConverterTest, NullPredicateReturnsNullopt) { + ASSERT_OK_AND_ASSIGN( + auto result, BucketSelectConverter::Convert(nullptr, {"id"}, {arrow::int32()}, + BucketFunctionType::DEFAULT, 5, pool_.get())); + ASSERT_FALSE(result.has_value()); +} + +TEST_F(BucketSelectConverterTest, BigintKeyDefault) { + int32_t num_buckets = 16; + Literal lit(static_cast(123456789L)); + auto predicate = PredicateBuilder::Equal(0, "user_id", FieldType::BIGINT, lit); + + ASSERT_OK_AND_ASSIGN(auto result, BucketSelectConverter::Convert( + predicate, {"user_id"}, {arrow::int64()}, + BucketFunctionType::DEFAULT, num_buckets, pool_.get())); + ASSERT_TRUE(result.has_value()); + + // Verify + auto row = BinaryRowGenerator::GenerateRow({static_cast(123456789L)}, pool_.get()); + DefaultBucketFunction func; + ASSERT_EQ(func.Bucket(row, num_buckets), result.value()); +} + +TEST_F(BucketSelectConverterTest, AndWithExtraPredicateStillWorks) { + // AND(EQUAL(id, 42), GREATER_THAN(value, 100)) + // Only id is bucket key, value is not — should still derive bucket from id + int32_t num_buckets = 5; + Literal lit_id(static_cast(42)); + Literal lit_val(static_cast(100)); + auto pred_id = PredicateBuilder::Equal(0, "id", FieldType::INT, lit_id); + auto pred_val = PredicateBuilder::GreaterThan(1, "value", FieldType::INT, lit_val); + ASSERT_OK_AND_ASSIGN(auto predicate, PredicateBuilder::And({pred_id, pred_val})); + + ASSERT_OK_AND_ASSIGN(auto result, BucketSelectConverter::Convert( + predicate, {"id"}, {arrow::int32()}, + BucketFunctionType::DEFAULT, num_buckets, pool_.get())); + ASSERT_TRUE(result.has_value()); + + auto row = BinaryRowGenerator::GenerateRow({static_cast(42)}, pool_.get()); + DefaultBucketFunction func; + ASSERT_EQ(func.Bucket(row, num_buckets), result.value()); +} + +TEST_F(BucketSelectConverterTest, TimestampMillisPrecision) { + // TIMESTAMP with millisecond precision (compact storage, precision=3) + int32_t num_buckets = 10; + Timestamp ts = Timestamp::FromEpochMillis(1700000000000L); + Literal lit(ts); + auto predicate = PredicateBuilder::Equal(0, "ts", FieldType::TIMESTAMP, lit); + + auto arrow_type = arrow::timestamp(arrow::TimeUnit::MILLI); + ASSERT_OK_AND_ASSIGN(auto result, BucketSelectConverter::Convert( + predicate, {"ts"}, {arrow_type}, + BucketFunctionType::DEFAULT, num_buckets, pool_.get())); + ASSERT_TRUE(result.has_value()); + + // Verify: precision=3 uses compact WriteTimestamp + auto row = BinaryRowGenerator::GenerateRow({TimestampType(ts, 3)}, pool_.get()); + DefaultBucketFunction func; + ASSERT_EQ(func.Bucket(row, num_buckets), result.value()); +} + +TEST_F(BucketSelectConverterTest, TimestampMicrosPrecision) { + // TIMESTAMP with microsecond precision (non-compact storage, precision=6) + int32_t num_buckets = 10; + Timestamp ts(1700000000000L, 123456); + Literal lit(ts); + auto predicate = PredicateBuilder::Equal(0, "ts", FieldType::TIMESTAMP, lit); + + auto arrow_type = arrow::timestamp(arrow::TimeUnit::MICRO); + ASSERT_OK_AND_ASSIGN(auto result, BucketSelectConverter::Convert( + predicate, {"ts"}, {arrow_type}, + BucketFunctionType::DEFAULT, num_buckets, pool_.get())); + ASSERT_TRUE(result.has_value()); + + // Verify: precision=6 uses non-compact WriteTimestamp (different layout than precision=3) + auto row = BinaryRowGenerator::GenerateRow({TimestampType(ts, 6)}, pool_.get()); + DefaultBucketFunction func; + ASSERT_EQ(func.Bucket(row, num_buckets), result.value()); +} + +TEST_F(BucketSelectConverterTest, DecimalKey) { + int32_t num_buckets = 10; + Decimal dec = Decimal::FromUnscaledLong(12345L, 10, 2); + Literal lit(dec); + auto predicate = PredicateBuilder::Equal(0, "amount", FieldType::DECIMAL, lit); + + auto arrow_type = arrow::decimal128(10, 2); + ASSERT_OK_AND_ASSIGN(auto result, BucketSelectConverter::Convert( + predicate, {"amount"}, {arrow_type}, + BucketFunctionType::DEFAULT, num_buckets, pool_.get())); + ASSERT_TRUE(result.has_value()); + + // Verify + auto row = BinaryRowGenerator::GenerateRow({dec}, pool_.get()); + DefaultBucketFunction func; + ASSERT_EQ(func.Bucket(row, num_buckets), result.value()); +} + +} // namespace paimon::test diff --git a/src/paimon/core/bucket/default_bucket_function.h b/src/paimon/core/bucket/default_bucket_function.h new file mode 100644 index 0000000..4b8f205 --- /dev/null +++ b/src/paimon/core/bucket/default_bucket_function.h @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include "paimon/common/data/binary_row.h" +#include "paimon/core/bucket/bucket_function.h" + +namespace paimon { + +/// Default bucket function that uses the hash code of the row to determine the bucket. +class DefaultBucketFunction : public BucketFunction { + public: + int32_t Bucket(const BinaryRow& row, int32_t num_buckets) const override { + return std::abs(row.HashCode() % num_buckets); + } +}; + +} // namespace paimon diff --git a/src/paimon/core/bucket/default_bucket_function_test.cpp b/src/paimon/core/bucket/default_bucket_function_test.cpp new file mode 100644 index 0000000..16d18ca --- /dev/null +++ b/src/paimon/core/bucket/default_bucket_function_test.cpp @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "paimon/core/bucket/default_bucket_function.h" + +#include "gtest/gtest.h" +#include "paimon/common/data/binary_row.h" +#include "paimon/common/data/binary_row_writer.h" +#include "paimon/memory/memory_pool.h" +#include "paimon/testing/utils/testharness.h" + +namespace paimon::test { + +TEST(DefaultBucketFunctionTest, TestBasicHashMod) { + auto pool = GetDefaultPool(); + DefaultBucketFunction func; + + // Create a row with a single INT field + BinaryRow row(1); + BinaryRowWriter writer(&row, 0, pool.get()); + writer.WriteInt(0, 42); + writer.Complete(); + + int32_t num_buckets = 5; + int32_t bucket = func.Bucket(row, num_buckets); + ASSERT_GE(bucket, 0); + ASSERT_LT(bucket, num_buckets); + + // Verify it matches the expected formula: abs(hashCode % numBuckets) + int32_t expected = std::abs(row.HashCode() % num_buckets); + ASSERT_EQ(expected, bucket); +} + +TEST(DefaultBucketFunctionTest, TestDifferentNumBuckets) { + auto pool = GetDefaultPool(); + DefaultBucketFunction func; + + BinaryRow row(1); + BinaryRowWriter writer(&row, 0, pool.get()); + writer.WriteInt(0, 100); + writer.Complete(); + + for (int32_t num_buckets = 1; num_buckets <= 10; num_buckets++) { + int32_t bucket = func.Bucket(row, num_buckets); + ASSERT_GE(bucket, 0); + ASSERT_LT(bucket, num_buckets); + ASSERT_EQ(std::abs(row.HashCode() % num_buckets), bucket); + } +} + +TEST(DefaultBucketFunctionTest, TestMultiFieldRow) { + auto pool = GetDefaultPool(); + DefaultBucketFunction func; + + BinaryRow row(3); + BinaryRowWriter writer(&row, 0, pool.get()); + writer.WriteInt(0, 1); + writer.WriteLong(1, 2); + writer.WriteInt(2, 3); + writer.Complete(); + + int32_t num_buckets = 7; + int32_t bucket = func.Bucket(row, num_buckets); + ASSERT_GE(bucket, 0); + ASSERT_LT(bucket, num_buckets); + ASSERT_EQ(std::abs(row.HashCode() % num_buckets), bucket); +} + +} // namespace paimon::test diff --git a/src/paimon/core/bucket/hive_bucket_function.cpp b/src/paimon/core/bucket/hive_bucket_function.cpp new file mode 100644 index 0000000..72f0f00 --- /dev/null +++ b/src/paimon/core/bucket/hive_bucket_function.cpp @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "paimon/core/bucket/hive_bucket_function.h" + +#include +#include +#include +#include + +#include "fmt/format.h" +#include "paimon/common/data/binary_row.h" +#include "paimon/common/utils/field_type_utils.h" +#include "paimon/core/bucket/hive_hasher.h" +#include "paimon/status.h" + +namespace paimon { +HiveBucketFunction::HiveBucketFunction(const std::vector& field_infos) + : field_infos_(field_infos) {} + +Result> HiveBucketFunction::Create( + const std::vector& field_types) { + std::vector field_infos; + field_infos.reserve(field_types.size()); + for (const auto& type : field_types) { + field_infos.emplace_back(type); + } + return Create(field_infos); +} + +Result> HiveBucketFunction::Create( + const std::vector& field_infos) { + if (field_infos.empty()) { + return Status::Invalid("HiveBucketFunction requires at least one field"); + } + for (const auto& info : field_infos) { + switch (info.type) { + case FieldType::BOOLEAN: + case FieldType::TINYINT: + case FieldType::SMALLINT: + case FieldType::INT: + case FieldType::BIGINT: + case FieldType::FLOAT: + case FieldType::DOUBLE: + case FieldType::STRING: + case FieldType::BINARY: + case FieldType::DECIMAL: + case FieldType::DATE: + break; + default: + return Status::Invalid(fmt::format("Unsupported type as Hive bucket key type: {}", + FieldTypeUtils::FieldTypeToString(info.type))); + } + } + return std::unique_ptr(new HiveBucketFunction(field_infos)); +} + +int32_t HiveBucketFunction::Bucket(const BinaryRow& row, int32_t num_buckets) const { + static constexpr int32_t SEED = 0; + int32_t hash = SEED; + for (int32_t i = 0; i < row.GetFieldCount(); i++) { + hash = (31 * hash) + ComputeHash(row, i); + } + return Mod(hash & std::numeric_limits::max(), num_buckets); +} + +int32_t HiveBucketFunction::ComputeHash(const BinaryRow& row, int32_t field_index) const { + if (row.IsNullAt(field_index)) { + return 0; + } + + const auto& info = field_infos_[field_index]; + switch (info.type) { + case FieldType::BOOLEAN: + return HiveHasher::HashInt(row.GetBoolean(field_index) ? 1 : 0); + case FieldType::TINYINT: + return HiveHasher::HashInt(static_cast(row.GetByte(field_index))); + case FieldType::SMALLINT: + return HiveHasher::HashInt(static_cast(row.GetShort(field_index))); + case FieldType::INT: + case FieldType::DATE: + return HiveHasher::HashInt(row.GetInt(field_index)); + case FieldType::BIGINT: + return HiveHasher::HashLong(row.GetLong(field_index)); + case FieldType::FLOAT: { + float float_value = row.GetFloat(field_index); + int32_t bits; + if (float_value == -0.0f) { + bits = 0; + } else { + std::memcpy(&bits, &float_value, sizeof(bits)); + } + return HiveHasher::HashInt(bits); + } + case FieldType::DOUBLE: { + double double_value = row.GetDouble(field_index); + int64_t bits; + if (double_value == -0.0) { + bits = 0L; + } else { + std::memcpy(&bits, &double_value, sizeof(bits)); + } + return HiveHasher::HashLong(bits); + } + case FieldType::STRING: { + std::string_view sv = row.GetStringView(field_index); + return HiveHasher::HashBytes(sv.data(), static_cast(sv.size())); + } + case FieldType::BINARY: { + std::string_view sv = row.GetStringView(field_index); + return HiveHasher::HashBytes(sv.data(), static_cast(sv.size())); + } + case FieldType::DECIMAL: { + Decimal decimal = row.GetDecimal(field_index, info.precision, info.scale); + return HiveHasher::HashDecimal(decimal); + } + default: + // This should never happen since Create() validates the types. + assert(false); + return 0; + } +} + +int32_t HiveBucketFunction::Mod(int32_t value, int32_t divisor) { + int32_t remainder = value % divisor; + if (remainder < 0) { + return (remainder + divisor) % divisor; + } + return remainder; +} + +} // namespace paimon diff --git a/src/paimon/core/bucket/hive_bucket_function.h b/src/paimon/core/bucket/hive_bucket_function.h new file mode 100644 index 0000000..82f0c40 --- /dev/null +++ b/src/paimon/core/bucket/hive_bucket_function.h @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include "paimon/bucket/bucket_function_type.h" +#include "paimon/core/bucket/bucket_function.h" +#include "paimon/result.h" + +namespace paimon { + +/// Hive-compatible bucket function. +/// This implements the same bucket assignment logic as Hive, using Hive's hash functions +/// to ensure compatibility between Paimon and Hive bucketed tables. +/// +/// The hash is computed by iterating over all fields in the row: +/// hash = (31 * hash) + computeHash(field_value) +/// Then the bucket is: (hash & INT32_MAX) % numBuckets +class HiveBucketFunction : public BucketFunction { + public: + /// Create a HiveBucketFunction with the given field types. + /// @param field_types The types of all fields in the bucket key row. + /// @return A Result containing the HiveBucketFunction or an error status. + static Result> Create( + const std::vector& field_types); + + /// Create a HiveBucketFunction with detailed field info (including decimal precision/scale). + /// @param field_infos The detailed type info of all fields in the bucket key row. + /// @return A Result containing the HiveBucketFunction or an error status. + static Result> Create( + const std::vector& field_infos); + + int32_t Bucket(const BinaryRow& row, int32_t num_buckets) const override; + + private: + explicit HiveBucketFunction(const std::vector& field_infos); + + /// Compute the Hive hash for a single field value. + int32_t ComputeHash(const BinaryRow& row, int32_t field_index) const; + + /// Mod operation that always returns non-negative result. + static int32_t Mod(int32_t value, int32_t divisor); + + std::vector field_infos_; +}; + +} // namespace paimon diff --git a/src/paimon/core/bucket/hive_bucket_function_test.cpp b/src/paimon/core/bucket/hive_bucket_function_test.cpp new file mode 100644 index 0000000..73a94c5 --- /dev/null +++ b/src/paimon/core/bucket/hive_bucket_function_test.cpp @@ -0,0 +1,453 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "paimon/core/bucket/hive_bucket_function.h" + +#include + +#include "gtest/gtest.h" +#include "paimon/common/data/binary_row.h" +#include "paimon/common/data/binary_row_writer.h" +#include "paimon/core/bucket/hive_hasher.h" +#include "paimon/memory/memory_pool.h" +#include "paimon/testing/utils/binary_row_generator.h" +#include "paimon/testing/utils/testharness.h" + +namespace paimon::test { + +class HiveBucketFunctionTest : public ::testing::Test { + protected: + /// Helper to create a BinaryRow with INT, STRING, BINARY, DECIMAL(10,4) fields. + /// Matches the Java test: toBinaryRow(rowType, 7, "hello", {1,2,3}, Decimal("12.3400", 10, 4)) + BinaryRow CreateMixedRow(int32_t int_val, const std::string& str_val, + const std::vector& binary_val, int64_t decimal_unscaled, + int32_t decimal_precision, int32_t decimal_scale) { + auto pool = GetDefaultPool(); + BinaryRow row(4); + BinaryRowWriter writer(&row, 0, pool.get()); + + // Field 0: INT + writer.WriteInt(0, int_val); + + // Field 1: STRING + writer.WriteStringView(1, std::string_view{str_val}); + + // Field 2: BINARY + writer.WriteStringView(2, std::string_view(binary_val.data(), binary_val.size())); + + // Field 3: DECIMAL (compact, precision <= 18) + writer.WriteDecimal( + 3, Decimal::FromUnscaledLong(decimal_unscaled, decimal_precision, decimal_scale), + decimal_precision); + + writer.Complete(); + return row; + } + + /// Helper to create a BinaryRow with all null fields. + BinaryRow CreateNullRow(int32_t num_fields) { + auto pool = GetDefaultPool(); + BinaryRow row(num_fields); + BinaryRowWriter writer(&row, 0, pool.get()); + for (int32_t i = 0; i < num_fields; i++) { + writer.SetNullAt(i); + } + writer.Complete(); + return row; + } + + BinaryRow CreateIntRow(int32_t value) { + auto pool = GetDefaultPool(); + return BinaryRowGenerator::GenerateRow({value}, pool.get()); + } + + BinaryRow CreateBooleanRow(bool value) { + auto pool = GetDefaultPool(); + return BinaryRowGenerator::GenerateRow({value}, pool.get()); + } + + BinaryRow CreateLongRow(int64_t value) { + auto pool = GetDefaultPool(); + return BinaryRowGenerator::GenerateRow({value}, pool.get()); + } + + BinaryRow CreateFloatRow(float value) { + auto pool = GetDefaultPool(); + return BinaryRowGenerator::GenerateRow({value}, pool.get()); + } + + BinaryRow CreateDoubleRow(double value) { + auto pool = GetDefaultPool(); + return BinaryRowGenerator::GenerateRow({value}, pool.get()); + } + + BinaryRow CreateStringRow(const std::string& value) { + auto pool = GetDefaultPool(); + return BinaryRowGenerator::GenerateRow({value}, pool.get()); + } +}; + +/// Test matching Java: testHiveBucketFunction +/// RowType: INT, STRING, BYTES, DECIMAL(10,4) +/// Values: 7, "hello", {1,2,3}, Decimal("12.3400", 10, 4) +TEST_F(HiveBucketFunctionTest, TestHiveBucketFunction) { + std::vector field_infos = { + HiveFieldInfo(FieldType::INT), + HiveFieldInfo(FieldType::STRING), + HiveFieldInfo(FieldType::BINARY), + HiveFieldInfo(FieldType::DECIMAL, 10, 4), + }; + ASSERT_OK_AND_ASSIGN(auto func, HiveBucketFunction::Create(field_infos)); + + // Decimal("12.3400", 10, 4) => unscaled = 123400 + BinaryRow row = CreateMixedRow(7, "hello", {1, 2, 3}, 123400, 10, 4); + + // Verify individual hash components: + // HiveHasher.hashBytes("hello") = 99162322 + ASSERT_EQ(99162322, HiveHasher::HashBytes("hello", 5)); + // HiveHasher.hashBytes({1,2,3}) = 1026 + ASSERT_EQ(1026, HiveHasher::HashBytes("\x01\x02\x03", 3)); + // BigDecimal("12.34").hashCode() = 1234 * 31 + 2 = 38256 + // (After normalizing "12.3400" -> "12.34", unscaled=1234, scale=2) + ASSERT_EQ(38256, HiveHasher::HashDecimal(Decimal::FromUnscaledLong(123400, 10, 4))); + + // expectedHash = 31*(31*(31*7 + 99162322) + 1026) + 38256 = 805989529 (with int32 overflow) + // bucket = (805989529 & INT32_MAX) % 8 = 1 + ASSERT_EQ(1, func->Bucket(row, 8)); +} + +/// Test matching Java: testHiveBucketFunctionWithNulls +TEST_F(HiveBucketFunctionTest, TestHiveBucketFunctionWithNulls) { + std::vector field_types = {FieldType::INT, FieldType::STRING}; + ASSERT_OK_AND_ASSIGN(auto func, HiveBucketFunction::Create(field_types)); + + BinaryRow row = CreateNullRow(2); + + // All nulls => hash = 0, bucket = 0 + ASSERT_EQ(0, func->Bucket(row, 4)); +} + +/// Test unsupported type returns error on Create +TEST_F(HiveBucketFunctionTest, TestUnsupportedType) { + // TIMESTAMP type should fail + std::vector field_types = {FieldType::TIMESTAMP}; + auto result = HiveBucketFunction::Create(field_types); + ASSERT_NOK_WITH_MSG(result.status(), "Unsupported type"); +} + +/// Test empty field types returns error +TEST_F(HiveBucketFunctionTest, TestEmptyFieldTypes) { + std::vector field_types = {}; + auto result = HiveBucketFunction::Create(field_types); + ASSERT_NOK_WITH_MSG(result.status(), "at least one field"); +} + +/// Test single INT field +TEST_F(HiveBucketFunctionTest, TestSingleIntField) { + std::vector field_types = {FieldType::INT}; + ASSERT_OK_AND_ASSIGN(auto func, HiveBucketFunction::Create(field_types)); + + // hash = 31*0 + 42 = 42, bucket = (42 & INT32_MAX) % 5 = 2 + ASSERT_EQ(2, func->Bucket(CreateIntRow(42), 5)); + // hash = 31*0 + 0 = 0, bucket = 0 + ASSERT_EQ(0, func->Bucket(CreateIntRow(0), 5)); +} + +/// Test BOOLEAN field +TEST_F(HiveBucketFunctionTest, TestBooleanField) { + std::vector field_types = {FieldType::BOOLEAN}; + ASSERT_OK_AND_ASSIGN(auto func, HiveBucketFunction::Create(field_types)); + + // true => hashInt(1) = 1, bucket = 1 % 4 = 1 + ASSERT_EQ(1, func->Bucket(CreateBooleanRow(true), 4)); + // false => hashInt(0) = 0, bucket = 0 + ASSERT_EQ(0, func->Bucket(CreateBooleanRow(false), 4)); +} + +/// Test BIGINT field +TEST_F(HiveBucketFunctionTest, TestBigintField) { + std::vector field_types = {FieldType::BIGINT}; + ASSERT_OK_AND_ASSIGN(auto func, HiveBucketFunction::Create(field_types)); + + // Java Long.hashCode(100L) = (int)(100 ^ (100 >>> 32)) = 100 + // bucket = 100 % 7 = 2 + ASSERT_EQ(2, func->Bucket(CreateLongRow(100L), 7)); +} + +/// Test FLOAT field with -0.0f +TEST_F(HiveBucketFunctionTest, TestFloatNegativeZero) { + std::vector field_types = {FieldType::FLOAT}; + ASSERT_OK_AND_ASSIGN(auto func, HiveBucketFunction::Create(field_types)); + + // -0.0f should be treated as 0 => hashInt(0) = 0 + ASSERT_EQ(func->Bucket(CreateFloatRow(0.0f), 5), func->Bucket(CreateFloatRow(-0.0f), 5)); +} + +/// Test DOUBLE field with -0.0 +TEST_F(HiveBucketFunctionTest, TestDoubleNegativeZero) { + std::vector field_types = {FieldType::DOUBLE}; + ASSERT_OK_AND_ASSIGN(auto func, HiveBucketFunction::Create(field_types)); + + // -0.0 should be treated as 0L => hashLong(0) = 0 + ASSERT_EQ(func->Bucket(CreateDoubleRow(0.0), 5), func->Bucket(CreateDoubleRow(-0.0), 5)); +} + +/// Test STRING field +TEST_F(HiveBucketFunctionTest, TestStringField) { + std::vector field_types = {FieldType::STRING}; + ASSERT_OK_AND_ASSIGN(auto func, HiveBucketFunction::Create(field_types)); + + // hashBytes("hello") = 99162322 + // bucket = (99162322 & INT32_MAX) % 10 = 99162322 % 10 = 2 + ASSERT_EQ(2, func->Bucket(CreateStringRow("hello"), 10)); +} + +/// Test different num_buckets produce valid results +TEST_F(HiveBucketFunctionTest, TestDifferentNumBuckets) { + std::vector field_types = {FieldType::INT}; + ASSERT_OK_AND_ASSIGN(auto func, HiveBucketFunction::Create(field_types)); + + for (int32_t num_buckets = 1; num_buckets <= 20; num_buckets++) { + int32_t bucket = func->Bucket(CreateIntRow(12345), num_buckets); + ASSERT_GE(bucket, 0); + ASSERT_LT(bucket, num_buckets); + } +} + +/// Test compatibility with Java HiveBucketFunction across multiple data types. +/// Expected values are computed from the Java implementation: +/// hash = 0 (seed) +/// for each field: hash = 31 * hash + computeHash(field) +/// bucket = (hash & INT32_MAX) % numBuckets +/// +/// Java computeHash per type: +/// BOOLEAN: hashInt(value ? 1 : 0) +/// INT/DATE: hashInt(value) [identity] +/// BIGINT: hashLong(value) = (int)(value ^ (value >>> 32)) +/// FLOAT: hashInt(Float.floatToIntBits(value)), -0.0f treated as 0 +/// DOUBLE: hashLong(Double.doubleToLongBits(value)), -0.0 treated as 0L +/// STRING/BINARY: hashBytes(bytes) +/// DECIMAL: BigDecimal.hashCode() after normalization +TEST_F(HiveBucketFunctionTest, TestCompatibleWithJava) { + auto pool = GetDefaultPool(); + const int32_t num_buckets = 128; + + // Case 1: Single INT field with various values + // Java: hash = 31*0 + hashInt(v) = v + // bucket = (v & INT32_MAX) % 128 + { + std::vector field_types = {FieldType::INT}; + ASSERT_OK_AND_ASSIGN(auto func, HiveBucketFunction::Create(field_types)); + + // hashInt(0) = 0, bucket = 0 + ASSERT_EQ(0, func->Bucket(CreateIntRow(0), num_buckets)); + // hashInt(1) = 1, bucket = 1 + ASSERT_EQ(1, func->Bucket(CreateIntRow(1), num_buckets)); + // hashInt(127) = 127, bucket = 127 + ASSERT_EQ(127, func->Bucket(CreateIntRow(127), num_buckets)); + // hashInt(128) = 128, bucket = 0 + ASSERT_EQ(0, func->Bucket(CreateIntRow(128), num_buckets)); + // hashInt(-1) = -1, (-1 & INT32_MAX) = 2147483647, 2147483647 % 128 = 127 + ASSERT_EQ(127, func->Bucket(CreateIntRow(-1), num_buckets)); + // hashInt(INT32_MIN) = -2147483648, (-2147483648 & INT32_MAX) = 0, bucket = 0 + ASSERT_EQ(0, func->Bucket(CreateIntRow(std::numeric_limits::min()), num_buckets)); + // hashInt(INT32_MAX) = 2147483647, (2147483647 & INT32_MAX) = 2147483647, % 128 = 127 + ASSERT_EQ(127, + func->Bucket(CreateIntRow(std::numeric_limits::max()), num_buckets)); + } + + // Case 2: Single BOOLEAN field + // Java: hashInt(true ? 1 : 0) + { + std::vector field_types = {FieldType::BOOLEAN}; + ASSERT_OK_AND_ASSIGN(auto func, HiveBucketFunction::Create(field_types)); + + // true => hashInt(1) = 1, bucket = 1 % 128 = 1 + ASSERT_EQ(1, func->Bucket(CreateBooleanRow(true), num_buckets)); + // false => hashInt(0) = 0, bucket = 0 + ASSERT_EQ(0, func->Bucket(CreateBooleanRow(false), num_buckets)); + } + + // Case 3: Single BIGINT field + // Java: hashLong(v) = (int)(v ^ (v >>> 32)) + { + std::vector field_types = {FieldType::BIGINT}; + ASSERT_OK_AND_ASSIGN(auto func, HiveBucketFunction::Create(field_types)); + + // hashLong(0) = 0, bucket = 0 + ASSERT_EQ(0, func->Bucket(CreateLongRow(0L), num_buckets)); + // hashLong(100) = (int)(100 ^ 0) = 100, bucket = 100 % 128 = 100 + ASSERT_EQ(100, func->Bucket(CreateLongRow(100L), num_buckets)); + // hashLong(4294967296L) = (int)(4294967296 ^ 1) = 1, bucket = 1 + // 4294967296L = 0x100000000, >>> 32 = 1, xor = 0x100000001, (int) = 1 + ASSERT_EQ(1, func->Bucket(CreateLongRow(4294967296L), num_buckets)); + // hashLong(LONG_MAX) = (int)(0x7FFFFFFFFFFFFFFF ^ 0x7FFFFFFF) = (int)0x7FFFFF80000000 + // = (int)(0x7FFFFFFF80000000) => low 32 bits = 0x80000000 = -2147483648 + // Actually: 0x7FFFFFFFFFFFFFFF ^ (0x7FFFFFFFFFFFFFFF >>> 32) + // = 0x7FFFFFFFFFFFFFFF ^ 0x7FFFFFFF + // = 0x7FFFFFFF80000000 + // (int) = 0x80000000 = -2147483648 + // (-2147483648 & INT32_MAX) = 0, bucket = 0 + ASSERT_EQ(0, func->Bucket(CreateLongRow(std::numeric_limits::max()), num_buckets)); + // hashLong(-1) = (int)(-1 ^ (0xFFFFFFFFFFFFFFFF >>> 32)) + // = (int)(-1 ^ 0xFFFFFFFF) = (int)(0) = 0 + ASSERT_EQ(0, func->Bucket(CreateLongRow(-1L), num_buckets)); + } + + // Case 4: Single FLOAT field + // Java: hashInt(Float.floatToIntBits(v)), -0.0f => 0 + { + std::vector field_types = {FieldType::FLOAT}; + ASSERT_OK_AND_ASSIGN(auto func, HiveBucketFunction::Create(field_types)); + + // 0.0f => bits = 0, hashInt(0) = 0 + ASSERT_EQ(0, func->Bucket(CreateFloatRow(0.0f), num_buckets)); + // -0.0f => treated as 0, hashInt(0) = 0 + ASSERT_EQ(0, func->Bucket(CreateFloatRow(-0.0f), num_buckets)); + // 1.0f => Float.floatToIntBits(1.0f) = 0x3F800000 = 1065353216 + // 1065353216 & INT32_MAX = 1065353216, % 128 = 0 + ASSERT_EQ(0, func->Bucket(CreateFloatRow(1.0f), num_buckets)); + // -1.0f => Float.floatToIntBits(-1.0f) = 0xBF800000 = -1082130432 + // (-1082130432 & INT32_MAX) = 1065353216, % 128 = 0 + ASSERT_EQ(0, func->Bucket(CreateFloatRow(-1.0f), num_buckets)); + } + + // Case 5: Single DOUBLE field + // Java: hashLong(Double.doubleToLongBits(v)), -0.0 => 0L + { + std::vector field_types = {FieldType::DOUBLE}; + ASSERT_OK_AND_ASSIGN(auto func, HiveBucketFunction::Create(field_types)); + + // 0.0 => bits = 0L, hashLong(0) = 0 + ASSERT_EQ(0, func->Bucket(CreateDoubleRow(0.0), num_buckets)); + // -0.0 => treated as 0L, hashLong(0) = 0 + ASSERT_EQ(0, func->Bucket(CreateDoubleRow(-0.0), num_buckets)); + // 1.0 => Double.doubleToLongBits(1.0) = 0x3FF0000000000000 = 4607182418800017408 + // hashLong = (int)(4607182418800017408 ^ (4607182418800017408 >>> 32)) + // = (int)(0x3FF0000000000000 ^ 0x3FF00000) + // = (int)(0x3FF000003FF00000) + // = (int)(0x3FF00000) = 1072693248 + // 1072693248 % 128 = 0 + ASSERT_EQ(0, func->Bucket(CreateDoubleRow(1.0), num_buckets)); + } + + // Case 6: Single STRING field + // Java: hashBytes(bytes) + { + std::vector field_types = {FieldType::STRING}; + ASSERT_OK_AND_ASSIGN(auto func, HiveBucketFunction::Create(field_types)); + + // hashBytes("hello") = 99162322 (verified in TestHiveBucketFunction) + // 99162322 & INT32_MAX = 99162322, % 128 = 82 + ASSERT_EQ(82, func->Bucket(CreateStringRow("hello"), num_buckets)); + // hashBytes("") = 0, bucket = 0 + ASSERT_EQ(0, func->Bucket(CreateStringRow(""), num_buckets)); + // hashBytes("a") = 97, bucket = 97 + ASSERT_EQ(97, func->Bucket(CreateStringRow("a"), num_buckets)); + } + + // Case 7: Single DATE field (same as INT) + // Java: hashInt(daysSinceEpoch) + { + std::vector field_infos = {HiveFieldInfo(FieldType::DATE)}; + ASSERT_OK_AND_ASSIGN(auto func, HiveBucketFunction::Create(field_infos)); + + // DATE is stored as int32 days since epoch, hashed same as INT + // date = 2000 (days), hashInt(2000) = 2000, 2000 % 128 = 80 + ASSERT_EQ(80, func->Bucket(CreateIntRow(2000), num_buckets)); + } + + // Case 8: Multi-field row (INT, STRING, BINARY, DECIMAL) + // This is the same as TestHiveBucketFunction but with num_buckets=128 + // Java step-by-step (all arithmetic in int32 with overflow): + // hash = 0 + // hash = 31*0 + hashInt(7) = 7 + // hash = 31*7 + hashBytes("hello") = 217 + 99162322 = 99162539 + // hash = 31*99162539 + hashBytes({1,2,3}) = int32(-1220928587) + 1026 = -1220927561 + // hash = 31*(-1220927561) + hashDecimal(12.3400) = int32(805951273) + 38256 = 805989529 + // bucket = (805989529 & INT32_MAX) % 128 = 25 + { + std::vector field_infos = { + HiveFieldInfo(FieldType::INT), + HiveFieldInfo(FieldType::STRING), + HiveFieldInfo(FieldType::BINARY), + HiveFieldInfo(FieldType::DECIMAL, 10, 4), + }; + ASSERT_OK_AND_ASSIGN(auto func, HiveBucketFunction::Create(field_infos)); + + BinaryRow row = CreateMixedRow(7, "hello", {1, 2, 3}, 123400, 10, 4); + // Already verified: func->Bucket(row, 8) == 1 + ASSERT_EQ(1, func->Bucket(row, 8)); + // With 128 buckets: 805989529 % 128 = 25 + ASSERT_EQ(25, func->Bucket(row, num_buckets)); + } + + // Case 9: All-null row + // Java: all nulls => hash = 0, bucket = 0 + { + std::vector field_infos = { + HiveFieldInfo(FieldType::INT), + HiveFieldInfo(FieldType::STRING), + HiveFieldInfo(FieldType::BIGINT), + }; + ASSERT_OK_AND_ASSIGN(auto func, HiveBucketFunction::Create(field_infos)); + + BinaryRow row = CreateNullRow(3); + ASSERT_EQ(0, func->Bucket(row, num_buckets)); + } + + // Case 10: Multi-field row with BOOLEAN, INT, BIGINT, FLOAT, DOUBLE, STRING + // Java step-by-step: + // field 0: BOOLEAN true => hashInt(1) = 1 + // field 1: INT 42 => hashInt(42) = 42 + // field 2: BIGINT 100 => hashLong(100) = 100 + // field 3: FLOAT 0.0f => hashInt(0) = 0 + // field 4: DOUBLE 0.0 => hashLong(0) = 0 + // field 5: STRING "a" => hashBytes("a") = 97 + // + // hash = 0 + // hash = 31*0 + 1 = 1 + // hash = 31*1 + 42 = 73 + // hash = 31*73 + 100 = 2363 + // hash = 31*2363 + 0 = 73253 + // hash = 31*73253 + 0 = 2270843 + // hash = 31*2270843 + 97 = 70396230 + // bucket = (70396230 & INT32_MAX) % 128 = 70396230 % 128 = 70 + { + std::vector field_infos = { + HiveFieldInfo(FieldType::BOOLEAN), HiveFieldInfo(FieldType::INT), + HiveFieldInfo(FieldType::BIGINT), HiveFieldInfo(FieldType::FLOAT), + HiveFieldInfo(FieldType::DOUBLE), HiveFieldInfo(FieldType::STRING), + }; + ASSERT_OK_AND_ASSIGN(auto func, HiveBucketFunction::Create(field_infos)); + + BinaryRow row(6); + BinaryRowWriter writer(&row, 0, pool.get()); + writer.WriteBoolean(0, true); + writer.WriteInt(1, 42); + writer.WriteLong(2, 100L); + writer.WriteFloat(3, 0.0f); + writer.WriteDouble(4, 0.0); + writer.WriteStringView(5, std::string_view("a")); + writer.Complete(); + + ASSERT_EQ(70, func->Bucket(row, num_buckets)); + } +} + +} // namespace paimon::test diff --git a/src/paimon/core/bucket/hive_hasher.h b/src/paimon/core/bucket/hive_hasher.h new file mode 100644 index 0000000..d86be40 --- /dev/null +++ b/src/paimon/core/bucket/hive_hasher.h @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include "paimon/data/decimal.h" + +namespace paimon { + +/// Hive-compatible hash utility functions. +/// This class provides hash functions that are compatible with Hive's ObjectInspectorUtils +/// hash implementation, ensuring consistent bucket assignment between Paimon C++ and Java. +class HiveHasher { + public: + /// Hash an int value (identity function, same as Hive). + static int32_t HashInt(int32_t input) { + return input; + } + + /// Hash a long value (same as Java's Long.hashCode). + static int32_t HashLong(int64_t input) { + return static_cast(input ^ (static_cast(input) >> 32)); + } + + /// Hash a byte array. + static int32_t HashBytes(const char* bytes, int32_t length) { + int32_t result = 0; + for (int32_t i = 0; i < length; i++) { + result = (result * 31) + static_cast(bytes[i]); + } + return result; + } + + /// Normalize a Decimal value for Hive-compatible hashing. + /// This implements the same logic as HiveHasher.normalizeDecimal in Java. + /// + /// The normalization process: + /// 1. Strip trailing zeros + /// 2. Check if integer digits exceed max precision (38) + /// 3. Limit scale to min(38, min(38 - intDigits, currentScale)) + /// 4. Round if necessary using HALF_UP + /// + /// @param decimal The decimal value to normalize. + /// @return The hash code of the normalized decimal, computed as Java BigDecimal.hashCode(). + static int32_t HashDecimal(const Decimal& decimal) { + // Java BigDecimal.hashCode() = unscaledValue.intValue() * 31 + scale + // For compact decimals (precision <= 18), we can use the long value directly. + // For non-compact decimals, we need to handle the 128-bit value. + + // First normalize: strip trailing zeros and limit scale + int32_t scale = decimal.Scale(); + auto value = decimal.Value(); + + // Strip trailing zeros + if (value == 0) { + // BigDecimal.ZERO.hashCode() = 0 * 31 + 0 = 0 + return 0; + } + + // Strip trailing zeros by dividing by 10 while remainder is 0 + while (scale > 0 && value != 0) { + auto quotient = value / 10; + auto remainder = value - quotient * 10; + if (remainder != 0) { + break; + } + value = quotient; + scale--; + } + + // After stripping, check if value is zero + if (value == 0) { + return 0; + } + + // Count integer digits + auto abs_value = value < 0 ? -value : value; + int32_t total_digits = 0; + auto temp = abs_value; + while (temp > 0) { + temp /= 10; + total_digits++; + } + int32_t int_digits = total_digits - scale; + + if (int_digits > HIVE_DECIMAL_MAX_PRECISION) { + // Overflow, return 0 (null equivalent) + return 0; + } + + int32_t max_scale = HIVE_DECIMAL_MAX_SCALE; + if (HIVE_DECIMAL_MAX_PRECISION - int_digits < max_scale) { + max_scale = HIVE_DECIMAL_MAX_PRECISION - int_digits; + } + if (scale < max_scale) { + max_scale = scale; + } + + if (scale > max_scale) { + // Need to round: scale down with HALF_UP rounding + int32_t scale_diff = scale - max_scale; + for (int32_t i = 0; i < scale_diff; i++) { + auto quotient = value / 10; + auto remainder = value - quotient * 10; + if (remainder < 0) remainder = -remainder; + if (remainder >= 5) { + value = quotient + (value < 0 ? -1 : 1); + } else { + value = quotient; + } + } + scale = max_scale; + + // Strip trailing zeros again after rounding + while (scale > 0 && value != 0) { + auto quotient = value / 10; + auto remainder = value - quotient * 10; + if (remainder != 0) { + break; + } + value = quotient; + scale--; + } + + if (value == 0) { + return 0; + } + } + + // Compute Java BigDecimal.hashCode(): + // hashCode = intValue(unscaledValue) * 31 + scale + // intValue() returns the low 32 bits of the value + auto int_value = static_cast(static_cast(value)); + return int_value * 31 + scale; + } + + private: + static constexpr int32_t HIVE_DECIMAL_MAX_PRECISION = 38; + static constexpr int32_t HIVE_DECIMAL_MAX_SCALE = 38; +}; + +} // namespace paimon diff --git a/src/paimon/core/bucket/mod_bucket_function.cpp b/src/paimon/core/bucket/mod_bucket_function.cpp new file mode 100644 index 0000000..375699e --- /dev/null +++ b/src/paimon/core/bucket/mod_bucket_function.cpp @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "paimon/core/bucket/mod_bucket_function.h" + +#include + +#include "fmt/format.h" +#include "paimon/common/data/binary_row.h" +#include "paimon/common/utils/field_type_utils.h" +#include "paimon/status.h" + +namespace paimon { + +namespace { + +/// Equivalent to Java's Math.floorMod semantics. +/// The result always has the same sign as the divisor (y), or is zero. +/// Works for both int32_t and int64_t as T. +template +inline int32_t FloorMod(T x, int32_t y) { + auto mod = static_cast(x) % static_cast(y); + // If the signs of mod and y differ and mod is not zero, adjust. + if ((mod ^ static_cast(y)) < 0 && mod != 0) { + mod += y; + } + return static_cast(mod); +} + +} // namespace + +ModBucketFunction::ModBucketFunction(FieldType bucket_key_type) + : bucket_key_type_(bucket_key_type) {} + +Result> ModBucketFunction::Create(FieldType bucket_key_type) { + if (bucket_key_type != FieldType::INT && bucket_key_type != FieldType::BIGINT) { + return Status::Invalid( + fmt::format("ModBucketFunction only supports INT or BIGINT bucket key type, but got {}", + FieldTypeUtils::FieldTypeToString(bucket_key_type))); + } + return std::unique_ptr(new ModBucketFunction(bucket_key_type)); +} + +int32_t ModBucketFunction::Bucket(const BinaryRow& row, int32_t num_buckets) const { + switch (bucket_key_type_) { + case FieldType::INT: + return FloorMod(row.GetInt(0), num_buckets); + case FieldType::BIGINT: + return FloorMod(row.GetLong(0), num_buckets); + default: + // This should never happen since Create() validates the type. + assert(false); + return 0; + } +} + +} // namespace paimon diff --git a/src/paimon/core/bucket/mod_bucket_function.h b/src/paimon/core/bucket/mod_bucket_function.h new file mode 100644 index 0000000..67747d1 --- /dev/null +++ b/src/paimon/core/bucket/mod_bucket_function.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include "paimon/core/bucket/bucket_function.h" +#include "paimon/defs.h" +#include "paimon/result.h" + +namespace paimon { + +/// Mod bucket function that uses modulo operation on the bucket key value. +/// The bucket key must be a single field of INT or BIGINT type. +/// This implements Java's Math.floorMod semantics for negative numbers. +class ModBucketFunction : public BucketFunction { + public: + /// Create a ModBucketFunction with the given bucket key type. + /// @param bucket_key_type The type of the single bucket key field. Must be INT or BIGINT. + /// @return A Result containing the ModBucketFunction or an error status. + static Result> Create(FieldType bucket_key_type); + + int32_t Bucket(const BinaryRow& row, int32_t num_buckets) const override; + + private: + explicit ModBucketFunction(FieldType bucket_key_type); + + FieldType bucket_key_type_; +}; + +} // namespace paimon diff --git a/src/paimon/core/bucket/mod_bucket_function_test.cpp b/src/paimon/core/bucket/mod_bucket_function_test.cpp new file mode 100644 index 0000000..0eec2da --- /dev/null +++ b/src/paimon/core/bucket/mod_bucket_function_test.cpp @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "paimon/core/bucket/mod_bucket_function.h" + +#include "gtest/gtest.h" +#include "paimon/memory/memory_pool.h" +#include "paimon/testing/utils/binary_row_generator.h" +#include "paimon/testing/utils/testharness.h" + +namespace paimon::test { + +namespace { + +BinaryRow CreateIntRow(int32_t value) { + auto pool = GetDefaultPool(); + return BinaryRowGenerator::GenerateRow({value}, pool.get()); +} + +BinaryRow CreateLongRow(int64_t value) { + auto pool = GetDefaultPool(); + return BinaryRowGenerator::GenerateRow({value}, pool.get()); +} + +} // namespace + +TEST(ModBucketFunctionTest, TestIntType) { + ASSERT_OK_AND_ASSIGN(auto func, ModBucketFunction::Create(FieldType::INT)); + + // 1 % 5 = 1 + ASSERT_EQ(1, func->Bucket(CreateIntRow(1), 5)); + // 7 % 5 = 2 + ASSERT_EQ(2, func->Bucket(CreateIntRow(7), 5)); + // -2 floorMod 5 = 3 (Java Math.floorMod(-2, 5) = 3) + ASSERT_EQ(3, func->Bucket(CreateIntRow(-2), 5)); +} + +TEST(ModBucketFunctionTest, TestBigintType) { + ASSERT_OK_AND_ASSIGN(auto func, ModBucketFunction::Create(FieldType::BIGINT)); + + // 8 % 5 = 3 + ASSERT_EQ(3, func->Bucket(CreateLongRow(8), 5)); + // 0 % 5 = 0 + ASSERT_EQ(0, func->Bucket(CreateLongRow(0), 5)); + // -3 floorMod 5 = 2 (Java Math.floorMod(-3L, 5) = 2) + ASSERT_EQ(2, func->Bucket(CreateLongRow(-3), 5)); +} + +TEST(ModBucketFunctionTest, TestUnsupportedTypes) { + { + // STRING type should fail + auto result = ModBucketFunction::Create(FieldType::STRING); + ASSERT_NOK_WITH_MSG(result.status(), "only supports INT or BIGINT"); + } + { + // FLOAT type should fail + auto result = ModBucketFunction::Create(FieldType::FLOAT); + ASSERT_NOK_WITH_MSG(result.status(), "only supports INT or BIGINT"); + } + { + // DOUBLE type should fail + auto result = ModBucketFunction::Create(FieldType::DOUBLE); + ASSERT_NOK_WITH_MSG(result.status(), "only supports INT or BIGINT"); + } +} + +TEST(ModBucketFunctionTest, TestIntEdgeCases) { + ASSERT_OK_AND_ASSIGN(auto func, ModBucketFunction::Create(FieldType::INT)); + + // 0 % 5 = 0 + ASSERT_EQ(0, func->Bucket(CreateIntRow(0), 5)); + // 5 % 5 = 0 + ASSERT_EQ(0, func->Bucket(CreateIntRow(5), 5)); + // -5 floorMod 5 = 0 + ASSERT_EQ(0, func->Bucket(CreateIntRow(-5), 5)); + // 1 % 1 = 0 + ASSERT_EQ(0, func->Bucket(CreateIntRow(1), 1)); +} + +TEST(ModBucketFunctionTest, TestBigintEdgeCases) { + ASSERT_OK_AND_ASSIGN(auto func, ModBucketFunction::Create(FieldType::BIGINT)); + + // Large value + ASSERT_EQ(3, func->Bucket(CreateLongRow(1000000003L), 5)); + // Negative large value: -1000000003 floorMod 5 = 2 + ASSERT_EQ(2, func->Bucket(CreateLongRow(-1000000003L), 5)); +} + +/// Large random compatibility test to ensure alignment with Java's Math.floorMod behavior. +/// The expected values are pre-computed using Java's Math.floorMod. +TEST(ModBucketFunctionTest, TestCompatibleWithJava) { + ASSERT_OK_AND_ASSIGN(auto int_func, ModBucketFunction::Create(FieldType::INT)); + ASSERT_OK_AND_ASSIGN(auto long_func, ModBucketFunction::Create(FieldType::BIGINT)); + + // Test INT type: pairs of (value, num_buckets) -> expected bucket (Java Math.floorMod) + // These values cover positive, negative, zero, edge cases, and large values. + struct IntTestCase { + int32_t value; + int32_t num_buckets; + int32_t expected; + }; + std::vector int_cases = { + {0, 10, 0}, + {1, 10, 1}, + {-1, 10, 9}, + {10, 10, 0}, + {-10, 10, 0}, + {11, 10, 1}, + {-11, 10, 9}, + {2147483647, 100, 47}, // INT32_MAX + {-2147483647, 100, 53}, // -(INT32_MAX) + {2147483647, 7, 1}, + {-2147483647, 7, 6}, + {123456789, 1000, 789}, + {-123456789, 1000, 211}, + {999, 1, 0}, + {-999, 1, 0}, + {42, 3, 0}, + {-42, 3, 0}, + {43, 3, 1}, + {-43, 3, 2}, + {100, 7, 2}, + {-100, 7, 5}, + }; + for (const auto& tc : int_cases) { + ASSERT_EQ(tc.expected, int_func->Bucket(CreateIntRow(tc.value), tc.num_buckets)) + << "INT floorMod(" << tc.value << ", " << tc.num_buckets << ")"; + } + + // Test BIGINT type: pairs of (value, num_buckets) -> expected bucket (Java Math.floorMod) + struct LongTestCase { + int64_t value; + int32_t num_buckets; + int32_t expected; + }; + std::vector long_cases = { + {0L, 10, 0}, + {1L, 10, 1}, + {-1L, 10, 9}, + {10L, 10, 0}, + {-10L, 10, 0}, + {9223372036854775807L, 100, 7}, // INT64_MAX + {-9223372036854775807L, 100, 93}, // -(INT64_MAX) + {9223372036854775807L, 7, 0}, + {-9223372036854775807L, 7, 0}, + {1234567890123456789L, 1000, 789}, + {-1234567890123456789L, 1000, 211}, + {100L, 7, 2}, + {-100L, 7, 5}, + {999999999999L, 13, 0}, + {-999999999999L, 13, 0}, + }; + for (const auto& tc : long_cases) { + ASSERT_EQ(tc.expected, long_func->Bucket(CreateLongRow(tc.value), tc.num_buckets)) + << "BIGINT floorMod(" << tc.value << ", " << tc.num_buckets << ")"; + } + + // Verify that all bucket results are in valid range [0, num_buckets) + for (int32_t num_buckets = 1; num_buckets <= 50; num_buckets++) { + for (int32_t v = -100; v <= 100; v++) { + int32_t bucket = int_func->Bucket(CreateIntRow(v), num_buckets); + ASSERT_GE(bucket, 0); + ASSERT_LT(bucket, num_buckets); + } + } +} + +} // namespace paimon::test