Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

package org.apache.wayang.api.sql.calcite.converter;

import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;

Expand Down Expand Up @@ -50,9 +49,18 @@ Operator visit(final WayangFilter wayangRelNode) {
}

/** for quick sanity check **/
public static final EnumSet<SqlKind> SUPPORTED_OPS = EnumSet.of(SqlKind.AND, SqlKind.OR, SqlKind.NOT,
SqlKind.EQUALS, SqlKind.NOT_EQUALS,
SqlKind.LESS_THAN, SqlKind.GREATER_THAN,
SqlKind.GREATER_THAN_OR_EQUAL, SqlKind.LESS_THAN_OR_EQUAL, SqlKind.LIKE, SqlKind.IS_NOT_NULL, SqlKind.IS_NULL);
protected static final EnumSet<SqlKind> SUPPORTED_OPS = EnumSet.of(
SqlKind.AND,
SqlKind.OR,
SqlKind.NOT,
SqlKind.EQUALS,
SqlKind.NOT_EQUALS,
SqlKind.LESS_THAN,
SqlKind.GREATER_THAN,
SqlKind.GREATER_THAN_OR_EQUAL,
SqlKind.LESS_THAN_OR_EQUAL,
SqlKind.LIKE,
SqlKind.IS_NOT_NULL,
SqlKind.IS_NULL);

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
package org.apache.wayang.api.sql.calcite.converter;

import java.util.List;
import java.util.stream.Collectors;

import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
Expand Down Expand Up @@ -55,7 +54,7 @@ Operator visit(final WayangJoin wayangRelNode) {
final List<Integer> keys = call.getOperands().stream()
.map(RexInputRef.class::cast)
.map(RexInputRef::getIndex)
.collect(Collectors.toList());
.toList();

assert (keys.size() == 2) : "Amount of keys found in join was not 2, got: " + keys.size();

Expand All @@ -78,7 +77,7 @@ Operator visit(final WayangJoin wayangRelNode) {
childOpRight.connectTo(0, join, 1);

// Join returns Tuple2 - map to a Record
final MapOperator<Tuple2<Record, Record>, Record> mapOperator = new MapOperator<Tuple2<Record, Record>, Record>(
final MapOperator<Tuple2<Record, Record>, Record> mapOperator = new MapOperator<>(
new JoinFlattenResult(),
ReflectionUtils.specify(Tuple2.class),
Record.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@
*/
package org.apache.wayang.api.sql.calcite.converter.functions;

import java.util.Arrays;
import java.math.BigDecimal;
import java.util.List;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.runtime.SqlFunctions;
Expand All @@ -31,12 +30,12 @@

public class AggregateFunction
implements FunctionDescriptor.SerializableBinaryOperator<Record> {
final List<SqlKind> aggregateKinds;
private final List<SqlKind> aggregateKinds;

public AggregateFunction(final List<AggregateCall> aggregateCalls) {
this.aggregateKinds = aggregateCalls.stream()
.map(call -> call.getAggregation().getKind())
.collect(Collectors.toList());
.map(call -> call.getAggregation().getKind())
.toList();
}

@Override
Expand All @@ -56,15 +55,15 @@ public Record apply(final Record record1, final Record record2) {

switch (kind) {
case SUM:
resValues[counter] = this.castAndMap(field1, field2, null, Long::sum, Integer::sum, Double::sum);
resValues[counter] = this.castAndMap(field1, field2, null, Long::sum, Integer::sum, Double::sum, BigDecimal::add);
break;
case MIN:
resValues[counter] = this.castAndMap(field1, field2, SqlFunctions::least, SqlFunctions::least,
SqlFunctions::least, SqlFunctions::least);
SqlFunctions::least, SqlFunctions::least, SqlFunctions::least);
break;
case MAX:
resValues[counter] = this.castAndMap(field1, field2, SqlFunctions::greatest, SqlFunctions::greatest,
SqlFunctions::greatest, SqlFunctions::greatest);
SqlFunctions::greatest, SqlFunctions::greatest, SqlFunctions::greatest);
break;
case COUNT:
// since aggregates inject an extra column for counting before,
Expand All @@ -76,9 +75,7 @@ public Record apply(final Record record1, final Record record2) {
resValues[counter] = count;
break;
case AVG:
assert (field1 instanceof Integer && field2 instanceof Integer)
: "Expected to find integers for count but found: " + field1 + " and " + field2;
final Object avg = Integer.class.cast(field1) + Integer.class.cast(field2);
final Object avg = this.castAndMap(field1, field2, null, Long::sum, Integer::sum, Double::sum, BigDecimal::add);

resValues[counter] = avg;

Expand All @@ -95,6 +92,7 @@ public Record apply(final Record record1, final Record record2) {
return new Record(resValues);
}


/**
* Handles casts for the record class for each interior type.
*
Expand All @@ -110,7 +108,8 @@ private Object castAndMap(final Object a, final Object b,
final BiFunction<String, String, Object> stringMap,
final BiFunction<Long, Long, Object> longMap,
final BiFunction<Integer, Integer, Object> integerMap,
final BiFunction<Double, Double, Object> doubleMap) {
final BiFunction<Double, Double, Object> doubleMap,
final BiFunction<BigDecimal, BigDecimal, Object> bigDecimalMap) {
// support operations between null and any
// class
if ((a == null || b == null) || (a.getClass() == b.getClass())) {
Expand All @@ -122,19 +121,16 @@ private Object castAndMap(final Object a, final Object b,
// force .getClass() to be safe so
// we can pass null objects to
// .apply methods.
switch (aWrapped.orElse(bWrapped.orElse("")).getClass().getSimpleName()) {
case "String":
return stringMap.apply((String) a, (String) b);
case "Long":
return longMap.apply((Long) a, (Long) b);
case "Integer":
return integerMap.apply((Integer) a, (Integer) b);
case "Double":
return doubleMap.apply((Double) a, (Double) b);
default:
throw new IllegalStateException("Unsupported operation between: " + aWrapped.getClass().toString()
+ " and: " + bWrapped.getClass().toString());
}
return switch (aWrapped.orElse(bWrapped.orElse("")).getClass().getSimpleName()) {
case "String" -> stringMap.apply((String) a, (String) b);
case "Long" -> longMap.apply((Long) a, (Long) b);
case "Integer" -> integerMap.apply((Integer) a, (Integer) b);
case "Double" -> doubleMap.apply((Double) a, (Double) b);
case "BigDecimal" -> bigDecimalMap.apply((BigDecimal) a, (BigDecimal) b);
default -> throw new IllegalStateException("Unsupported operation between: "
+ aWrapped.getClass().toString()
+ " and: " + bWrapped.getClass().toString());
};
}
throw new IllegalStateException("Unsupported operation between: " + a.getClass().getSimpleName() + " and: "
+ b.getClass().getSimpleName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,33 @@
package org.apache.wayang.api.sql.calcite.converter.functions;

import java.io.Serializable;
import java.math.BigDecimal;
import java.util.Calendar;
import java.util.List;
import java.util.stream.Collectors;

import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;

import org.apache.calcite.util.Sarg;
import org.apache.wayang.basic.data.Record;
import org.apache.wayang.core.function.FunctionDescriptor.SerializableFunction;

import com.google.common.collect.ImmutableRangeSet;

/**
* AST of the {@link RexCall} arithmetic, composed into serializable nodes;
* {@link Call}, {@link InputRef}, {@link Literal}
*/
interface CallTreeFactory<Input, Output> extends Serializable {
public default Node<Output> fromRexNode(final RexNode node) {
if (node instanceof RexCall) {
return new Call<>((RexCall) node, this);
} else if (node instanceof RexInputRef) {
return new InputRef<>((RexInputRef) node);
} else if (node instanceof RexLiteral) {
return new Literal<>((RexLiteral) node);
interface CallTreeFactory extends Serializable {
public default Node fromRexNode(final RexNode node) {
if (node instanceof final RexCall call) {
return new Call(call, this);
} else if (node instanceof final RexInputRef inputRef) {
return new InputRef(inputRef);
} else if (node instanceof final RexLiteral literal) {
return new Literal(literal);
} else {
throw new UnsupportedOperationException("Unsupported RexNode in filter condition: " + node);
}
Expand All @@ -55,50 +58,66 @@ public default Node<Output> fromRexNode(final RexNode node) {
* @return a serializable function of +, -, * or /
* @throws UnsupportedOperationException on unrecognized {@link SqlKind}
*/
public SerializableFunction<List<Output>, Output> deriveOperation(SqlKind kind);
public SerializableFunction<List<Object>, Object> deriveOperation(SqlKind kind);
}

interface Node<Output> extends Serializable {
public Output evaluate(final Record record);
interface Node extends Serializable {
public Object evaluate(final Record rec);
}

class Call<Input, Output> implements Node<Output> {
final List<Node<Output>> operands;
final SerializableFunction<List<Output>, Output> operation;
class Call implements Node {
private final List<Node> operands;
final SerializableFunction<List<Object>, Object> operation;

protected Call(final RexCall call, final CallTreeFactory<Input, Output> tree) {
operands = call.getOperands().stream().map(tree::fromRexNode).collect(Collectors.toList());
protected Call(final RexCall call, final CallTreeFactory tree) {
operands = call.getOperands().stream().map(tree::fromRexNode).toList();
operation = tree.deriveOperation(call.getKind());
}

@Override
public Output evaluate(final Record record) {
return operation.apply(operands.stream().map(op -> op.evaluate(record)).collect(Collectors.toList()));
public Object evaluate(final Record rec) {
return operation.apply(
operands.stream()
.map(op -> op.evaluate(rec))
.toList());
}
}

class Literal<Output> implements Node<Output> {
final Output value;
class Literal implements Node {
final Serializable value;

Literal(final RexLiteral literal) {
value = (Output) literal.getValue2();
value = switch (literal.getTypeName()) {
case DATE -> literal.getValueAs(Calendar.class);
case INTEGER -> literal.getValueAs(Double.class);
case INTERVAL_DAY -> literal.getValueAs(BigDecimal.class).doubleValue();
case DECIMAL -> literal.getValueAs(BigDecimal.class).doubleValue();
case CHAR -> literal.getValueAs(String.class);
case SARG -> {
final Sarg<?> sarg = literal.getValueAs(Sarg.class);
assert sarg.rangeSet instanceof Serializable : "Sarg RangeSet was not serializable.";
yield (ImmutableRangeSet<?>) sarg.rangeSet;
}
default -> throw new UnsupportedOperationException(
"Literal conversion to Java not implemented, type: " + literal.getTypeName());
};
}

@Override
public Output evaluate(final Record record) {
public Object evaluate(final Record rec) {
return value;
}
}

class InputRef<Output> implements Node<Output> {
class InputRef implements Node {
private final int key;

InputRef(final RexInputRef inputRef) {
this.key = inputRef.getIndex();
}

@Override
public Output evaluate(final Record record) {
return (Output) record.getField(key);
public Object evaluate(final Record rec) {
return rec.getField(key);
}
}
Loading
Loading