Skip to content

Commit af38a3c

Browse files
committed
Respect Parquet Logical type when reading statistics
1 parent e1c9c7d commit af38a3c

File tree

3 files changed

+151
-11
lines changed

3 files changed

+151
-11
lines changed

lib/trino-parquet/src/main/java/io/trino/parquet/predicate/TupleDomainParquetPredicate.java

Lines changed: 64 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,9 @@
2828
import io.trino.spi.predicate.SortedRangeSet;
2929
import io.trino.spi.predicate.TupleDomain;
3030
import io.trino.spi.predicate.ValueSet;
31+
import io.trino.spi.type.DecimalConversions;
3132
import io.trino.spi.type.DecimalType;
33+
import io.trino.spi.type.Decimals;
3234
import io.trino.spi.type.Int128;
3335
import io.trino.spi.type.TimestampType;
3436
import io.trino.spi.type.Type;
@@ -48,6 +50,7 @@
4850
import org.apache.parquet.io.ParquetDecodingException;
4951
import org.apache.parquet.io.api.Binary;
5052
import org.apache.parquet.schema.LogicalTypeAnnotation;
53+
import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation;
5154
import org.apache.parquet.schema.LogicalTypeAnnotation.TimestampLogicalTypeAnnotation;
5255
import org.apache.parquet.schema.PrimitiveType;
5356
import org.joda.time.DateTimeZone;
@@ -68,10 +71,12 @@
6871
import static io.trino.parquet.ParquetTimestampUtils.decodeInt96Timestamp;
6972
import static io.trino.parquet.ParquetTypeUtils.getShortDecimalValue;
7073
import static io.trino.parquet.predicate.PredicateUtils.isStatisticsOverflow;
74+
import static io.trino.parquet.reader.ColumnReaderFactory.isDecimalRescaled;
7175
import static io.trino.plugin.base.type.TrinoTimestampEncoderFactory.createTimestampEncoder;
7276
import static io.trino.spi.type.BigintType.BIGINT;
7377
import static io.trino.spi.type.BooleanType.BOOLEAN;
7478
import static io.trino.spi.type.DateType.DATE;
79+
import static io.trino.spi.type.Decimals.longTenToNth;
7580
import static io.trino.spi.type.DoubleType.DOUBLE;
7681
import static io.trino.spi.type.IntegerType.INTEGER;
7782
import static io.trino.spi.type.RealType.REAL;
@@ -377,11 +382,8 @@ private static Domain getDomain(
377382
SortedRangeSet.Builder rangesBuilder = SortedRangeSet.builder(type, minimums.size());
378383
if (decimalType.isShort()) {
379384
for (int i = 0; i < minimums.size(); i++) {
380-
Object min = minimums.get(i);
381-
Object max = maximums.get(i);
382-
383-
long minValue = min instanceof Slice minSlice ? getShortDecimalValue(minSlice.getBytes()) : asLong(min);
384-
long maxValue = max instanceof Slice maxSlice ? getShortDecimalValue(maxSlice.getBytes()) : asLong(max);
385+
long minValue = getShortDecimal(minimums.get(i), decimalType, column);
386+
long maxValue = getShortDecimal(maximums.get(i), decimalType, column);
385387

386388
if (isStatisticsOverflow(type, minValue, maxValue)) {
387389
return Domain.create(ValueSet.all(type), hasNullValue);
@@ -392,11 +394,8 @@ private static Domain getDomain(
392394
}
393395
else {
394396
for (int i = 0; i < minimums.size(); i++) {
395-
Object min = minimums.get(i);
396-
Object max = maximums.get(i);
397-
398-
Int128 minValue = min instanceof Slice minSlice ? Int128.fromBigEndian(minSlice.getBytes()) : Int128.valueOf(asLong(min));
399-
Int128 maxValue = max instanceof Slice maxSlice ? Int128.fromBigEndian(maxSlice.getBytes()) : Int128.valueOf(asLong(max));
397+
Int128 minValue = getLongDecimal(minimums.get(i), decimalType, column);
398+
Int128 maxValue = getLongDecimal(maximums.get(i), decimalType, column);
400399

401400
rangesBuilder.addRangeInclusive(minValue, maxValue);
402401
}
@@ -494,6 +493,61 @@ private static Domain getDomain(
494493
return Domain.create(ValueSet.all(type), hasNullValue);
495494
}
496495

496+
private static long getShortDecimal(Object value, DecimalType columnType, ColumnDescriptor column)
497+
{
498+
LogicalTypeAnnotation annotation = column.getPrimitiveType().getLogicalTypeAnnotation();
499+
500+
if (annotation instanceof DecimalLogicalTypeAnnotation decimalAnnotation) {
501+
if (isDecimalRescaled(decimalAnnotation, columnType)) {
502+
if (decimalAnnotation.getPrecision() <= Decimals.MAX_SHORT_PRECISION) {
503+
long rescale = longTenToNth(Math.abs(columnType.getScale() - decimalAnnotation.getScale()));
504+
return DecimalConversions.shortToShortCast(
505+
value instanceof Slice slice ? getShortDecimalValue(slice.getBytes()) : asLong(value),
506+
decimalAnnotation.getPrecision(),
507+
decimalAnnotation.getScale(),
508+
columnType.getPrecision(),
509+
columnType.getScale(),
510+
rescale,
511+
rescale / 2);
512+
}
513+
Int128 int128Representation = value instanceof Slice minSlice ? Int128.fromBigEndian(minSlice.getBytes()) : Int128.valueOf(asLong(value));
514+
return DecimalConversions.longToShortCast(
515+
int128Representation,
516+
decimalAnnotation.getPrecision(),
517+
decimalAnnotation.getScale(),
518+
columnType.getPrecision(),
519+
columnType.getScale());
520+
}
521+
}
522+
return value instanceof Slice slice ? getShortDecimalValue(slice.getBytes()) : asLong(value);
523+
}
524+
525+
private static Int128 getLongDecimal(Object value, DecimalType columnType, ColumnDescriptor column)
526+
{
527+
LogicalTypeAnnotation annotation = column.getPrimitiveType().getLogicalTypeAnnotation();
528+
529+
if (annotation instanceof DecimalLogicalTypeAnnotation decimalAnnotation) {
530+
if (isDecimalRescaled(decimalAnnotation, columnType)) {
531+
if (decimalAnnotation.getPrecision() <= Decimals.MAX_SHORT_PRECISION) {
532+
return DecimalConversions.shortToLongCast(
533+
value instanceof Slice slice ? getShortDecimalValue(slice.getBytes()) : asLong(value),
534+
decimalAnnotation.getPrecision(),
535+
decimalAnnotation.getScale(),
536+
columnType.getPrecision(),
537+
columnType.getScale());
538+
}
539+
Int128 int128Representation = value instanceof Slice slice ? Int128.fromBigEndian(slice.getBytes()) : Int128.valueOf(asLong(value));
540+
return DecimalConversions.longToLongCast(
541+
int128Representation,
542+
decimalAnnotation.getPrecision(),
543+
decimalAnnotation.getScale(),
544+
columnType.getPrecision(),
545+
columnType.getScale());
546+
}
547+
}
548+
return value instanceof Slice slice ? Int128.fromBigEndian(slice.getBytes()) : Int128.valueOf(asLong(value));
549+
}
550+
497551
@VisibleForTesting
498552
public static Domain getDomain(
499553
Type type,

lib/trino-parquet/src/main/java/io/trino/parquet/reader/ColumnReaderFactory.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -330,7 +330,7 @@ public Optional<Boolean> visit(UUIDLogicalTypeAnnotation uuidLogicalType)
330330
.orElse(FALSE);
331331
}
332332

333-
private static boolean isDecimalRescaled(DecimalLogicalTypeAnnotation decimalAnnotation, DecimalType trinoType)
333+
public static boolean isDecimalRescaled(DecimalLogicalTypeAnnotation decimalAnnotation, DecimalType trinoType)
334334
{
335335
return decimalAnnotation.getPrecision() != trinoType.getPrecision()
336336
|| decimalAnnotation.getScale() != trinoType.getScale();

lib/trino-parquet/src/test/java/io/trino/parquet/TestTupleDomainParquetPredicate.java

Lines changed: 86 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -101,6 +101,7 @@
101101
import static java.util.Collections.singletonList;
102102
import static java.util.Collections.singletonMap;
103103
import static java.util.concurrent.TimeUnit.MILLISECONDS;
104+
import static org.apache.parquet.schema.LogicalTypeAnnotation.decimalType;
104105
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
105106
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY;
106107
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY;
@@ -276,6 +277,43 @@ public void testShortDecimalWithNoScale()
276277
.withMessage("Malformed Parquet file. Corrupted statistics for column \"[] required int32 ShortDecimalColumnWithNoScale\": [min: 100, max: 10, num_nulls: 0] [testFile]");
277278
}
278279

280+
@Test
281+
public void testShortDecimalWithLongDecimalAnnotation()
282+
throws Exception
283+
{
284+
ColumnDescriptor columnDescriptor = createColumnDescriptor(FIXED_LEN_BYTE_ARRAY, decimalType(2, 38), "ShortDecimalColumnWithDecimalAnnotation");
285+
BigInteger maximum = new BigInteger("12345");
286+
287+
Type type = createDecimalType(5, 2);
288+
assertThat(getDomain(columnDescriptor, type, 0, null, ID, UTC)).isEqualTo(all(type));
289+
assertThat(getDomain(columnDescriptor, type, 10, binaryColumnStats(maximum, maximum), ID, UTC)).isEqualTo(singleValue(type, 12345L));
290+
291+
assertThat(getDomain(columnDescriptor, type, 10, binaryColumnStats(0L, 100L), ID, UTC)).isEqualTo(create(ValueSet.ofRanges(range(type, 0L, true, 100L, true)), false));
292+
assertThat(getDomain(columnDescriptor, type, 10, intColumnStats(0, 100), ID, UTC)).isEqualTo(create(ValueSet.ofRanges(range(type, 0L, true, 100L, true)), false));
293+
294+
type = createDecimalType(15, 2);
295+
assertThat(getDomain(columnDescriptor, type, 0, null, ID, UTC)).isEqualTo(all(type));
296+
assertThat(getDomain(columnDescriptor, type, 10, binaryColumnStats(maximum, maximum), ID, UTC)).isEqualTo(singleValue(type, 12345L));
297+
298+
assertThat(getDomain(columnDescriptor, type, 10, binaryColumnStats(0L, 100L), ID, UTC)).isEqualTo(create(ValueSet.ofRanges(range(type, 0L, true, 100L, true)), false));
299+
assertThat(getDomain(columnDescriptor, type, 10, intColumnStats(0, 100), ID, UTC)).isEqualTo(create(ValueSet.ofRanges(range(type, 0L, true, 100L, true)), false));
300+
301+
Type typeWithDifferentScale = createDecimalType(5, 1);
302+
assertThat(getDomain(columnDescriptor, typeWithDifferentScale, 0, null, ID, UTC)).isEqualTo(all(typeWithDifferentScale));
303+
304+
assertThat(getDomain(columnDescriptor, typeWithDifferentScale, 10, longColumnStats(10012L, 10012L), ID, UTC)).isEqualTo(singleValue(typeWithDifferentScale, 1001L));
305+
306+
// Test that statistics overflowing the size of the type are not used
307+
assertThat(getDomain(columnDescriptor, typeWithDifferentScale, 10, longColumnStats(100012L, 100012L), ID, UTC)).isEqualTo(singleValue(typeWithDifferentScale, 10001L));
308+
309+
assertThat(getDomain(columnDescriptor, typeWithDifferentScale, 10, longColumnStats(0L, 100L), ID, UTC)).isEqualTo(create(ValueSet.ofRanges(range(typeWithDifferentScale, 0L, true, 10L, true)), false));
310+
311+
// fail on higher precision values
312+
assertThatExceptionOfType(ParquetCorruptionException.class)
313+
.isThrownBy(() -> getDomain(columnDescriptor, createDecimalType(4, 2), 10, binaryColumnStats(maximum, maximum), ID, UTC))
314+
.withMessage("Malformed Parquet file. Corrupted statistics for column \"[] required fixed_len_byte_array(0) ShortDecimalColumnWithDecimalAnnotation (DECIMAL(38,2))\": [min: 0x00000000000000000000000000003039, max: 0x00000000000000000000000000003039, num_nulls: 0] [testFile]");
315+
}
316+
279317
@Test
280318
public void testLongDecimal()
281319
throws Exception
@@ -319,6 +357,49 @@ public void testLongDecimalWithNoScale()
319357
.withMessage("Malformed Parquet file. Corrupted statistics for column \"[] required fixed_len_byte_array(0) LongDecimalColumnWithNoScale\": [min: 0x00000000000000000000000000000064, max: 0x0000000000000000000000000000000A, num_nulls: 0] [testFile]");
320358
}
321359

360+
@Test
361+
public void testLongDecimalWithShortDecimalAnnotation()
362+
throws Exception
363+
{
364+
ColumnDescriptor columnDescriptor = createColumnDescriptor(INT32, decimalType(2, 5), "ShortDecimalColumn");
365+
DecimalType type = createDecimalType(20, 2);
366+
367+
assertThat(getDomain(columnDescriptor, type, 0, null, ID, UTC)).isEqualTo(all(type));
368+
assertThat(getDomain(columnDescriptor, type, 10, longColumnStats(10012L, 10012L), ID, UTC)).isEqualTo(singleValue(type, Int128.valueOf(10012L)));
369+
370+
assertThat(getDomain(columnDescriptor, type, 10, longColumnStats(0L, 10012L), ID, UTC)).isEqualTo(create(ValueSet.ofRanges(range(type, Int128.valueOf(0L), true, Int128.valueOf(10012L), true)), false));
371+
assertThat(getDomain(columnDescriptor, type, 10, longColumnStats(0, 100L), ID, UTC)).isEqualTo(create(ValueSet.ofRanges(range(type, Int128.valueOf(0L), true, Int128.valueOf(100L), true)), false));
372+
373+
// fail on corrupted statistics
374+
assertThatExceptionOfType(ParquetCorruptionException.class)
375+
.isThrownBy(() -> getDomain(columnDescriptor, type, 10, longColumnStats(100L, 10L), ID, UTC))
376+
.withMessage("Malformed Parquet file. Corrupted statistics for column \"[] required int32 ShortDecimalColumn (DECIMAL(5,2))\": [min: 100, max: 10, num_nulls: 0] [testFile]");
377+
}
378+
379+
@Test
380+
public void testLongDecimalWithInt64DecimalAnnotation()
381+
throws Exception
382+
{
383+
ColumnDescriptor columnDescriptor = createColumnDescriptor(INT64, decimalType(2, 5), "ShortDecimalColumn");
384+
DecimalType type = createDecimalType(20, 2);
385+
BigInteger maximum = new BigInteger("12345");
386+
387+
Int128 zero = Int128.ZERO;
388+
Int128 hundred = Int128.valueOf(100L);
389+
Int128 max = Int128.valueOf(maximum);
390+
391+
assertThat(getDomain(columnDescriptor, type, 0, null, ID, UTC)).isEqualTo(all(type));
392+
assertThat(getDomain(columnDescriptor, type, 10, longColumnStats(maximum.longValue(), maximum.longValue()), ID, UTC)).isEqualTo(singleValue(type, max));
393+
394+
assertThat(getDomain(columnDescriptor, type, 10, longColumnStats(0L, 100L), ID, UTC)).isEqualTo(create(ValueSet.ofRanges(range(type, zero, true, hundred, true)), false));
395+
assertThat(getDomain(columnDescriptor, type, 10, longColumnStats(0, 100), ID, UTC)).isEqualTo(create(ValueSet.ofRanges(range(type, zero, true, hundred, true)), false));
396+
397+
// fail on corrupted statistics
398+
assertThatExceptionOfType(ParquetCorruptionException.class)
399+
.isThrownBy(() -> getDomain(columnDescriptor, type, 10, longColumnStats(100L, 10L), ID, UTC))
400+
.withMessage("Malformed Parquet file. Corrupted statistics for column \"[] required int64 ShortDecimalColumn (DECIMAL(5,2))\": [min: 100, max: 10, num_nulls: 0] [testFile]");
401+
}
402+
322403
@Test
323404
public void testDouble()
324405
throws Exception
@@ -744,6 +825,11 @@ private ColumnDescriptor createColumnDescriptor(PrimitiveTypeName typeName, Stri
744825
return new ColumnDescriptor(new String[] {}, new PrimitiveType(REQUIRED, typeName, columnName), 0, 0);
745826
}
746827

828+
private ColumnDescriptor createColumnDescriptor(PrimitiveTypeName typeName, LogicalTypeAnnotation typeAnnotation, String columnName)
829+
{
830+
return new ColumnDescriptor(new String[] {}, new PrimitiveType(REQUIRED, typeName, columnName).withLogicalTypeAnnotation(typeAnnotation), 0, 0);
831+
}
832+
747833
private TupleDomain<ColumnDescriptor> getEffectivePredicate(ColumnDescriptor column, VarcharType type, Slice value)
748834
{
749835
ColumnDescriptor predicateColumn = new ColumnDescriptor(column.getPath(), column.getPrimitiveType(), 0, 0);

0 commit comments

Comments
 (0)