@@ -28,7 +28,7 @@ import org.apache.wayang.api.graph.{Edge, EdgeDataQuantaBuilder, EdgeDataQuantaB
2828import org .apache .wayang .api .util .{DataQuantaBuilderCache , TypeTrap }
2929import org .apache .wayang .basic .data .{Record , Tuple2 => RT2 }
3030import org .apache .wayang .basic .model .{DLModel , Model , LogisticRegressionModel ,DecisionTreeRegressionModel }
31- import org .apache .wayang .basic .operators .{DLTrainingOperator , GlobalReduceOperator , LocalCallbackSink , MapOperator , SampleOperator , LogisticRegressionOperator ,DecisionTreeRegressionOperator , LinearSVCOperator }
31+ import org .apache .wayang .basic .operators .{DLTrainingOperator , GlobalReduceOperator , JoinOperator , LocalCallbackSink , MapOperator , ReduceByOperator , SampleOperator , SortOperator , LogisticRegressionOperator ,DecisionTreeRegressionOperator , LinearSVCOperator }
3232import org .apache .wayang .commons .util .profiledb .model .Experiment
3333import org .apache .wayang .core .api .spatial .{SpatialGeometry , SpatialPredicate }
3434import org .apache .wayang .core .function .FunctionDescriptor .{SerializableBiFunction , SerializableBinaryOperator , SerializableFunction , SerializableIntUnaryOperator , SerializablePredicate }
@@ -1020,6 +1020,10 @@ class SortDataQuantaBuilder[T, Key](inputDataQuanta: DataQuantaBuilder[_, T],
10201020 /** [[LoadEstimator ]] to estimate the RAM load of the [[keyUdf ]]. */
10211021 private var keyUdfRamEstimator : LoadEstimator = _
10221022
1023+ /** SQL column name and sort direction for the sort key; both remain `null` until `withSqlUdf` is called. */
1024+ private var sqlColumnName : String = _
1025+ private var sqlDirection : String = _
1026+
10231027
10241028 // Try to infer the type classes from the UDFs.
10251029 locally {
@@ -1060,8 +1064,27 @@ class SortDataQuantaBuilder[T, Key](inputDataQuanta: DataQuantaBuilder[_, T],
10601064 this
10611065 }
10621066
1063- override protected def build =
1064- applyTargetPlatforms(inputDataQuanta.dataQuanta().sortJava(keyUdf)(this .keyTag), this .getTargetPlatforms())
/**
 * Record a SQL description of the sort key on this builder.
 *
 * The values are only stored here; `build` passes them on to the key descriptor
 * of the sort operator.
 *
 * @param columnName SQL column to sort by
 * @param direction  SQL sort direction, e.g. `ASC` or `DESC`
 * @return this instance
 */
def withSqlUdf(columnName: String, direction: String) = {
  this.sqlDirection = direction
  this.sqlColumnName = columnName
  this
}
1079+
/**
 * Build the sorted data quanta and, when a SQL column name has been set,
 * register the column and direction on the sort operator's key descriptor.
 * Target platforms are applied last.
 */
override protected def build = {
  val sorted = inputDataQuanta.dataQuanta().sortJava(keyUdf)(this.keyTag)
  // A null column name means that no SQL implementation was requested.
  Option(this.sqlColumnName).foreach { column =>
    val sortOperator = sorted.operator.asInstanceOf[SortOperator[T, Key]]
    sortOperator.getKeyDescriptor.withSqlImplementation(column, this.sqlDirection)
  }
  applyTargetPlatforms(sorted, this.getTargetPlatforms())
}
10651088
10661089}
10671090
@@ -1283,6 +1306,10 @@ class ReduceByDataQuantaBuilder[Key, T](inputDataQuanta: DataQuantaBuilder[_, T]
12831306 /** [[LoadProfileEstimator ]] to estimate the [[LoadProfile ]] of the [[udf ]]. */
12841307 private var udfLoadProfileEstimator : LoadProfileEstimator = _
12851308
1309+ /** SQL implementations of the grouping key and of the reduction; both remain `null` until `withSqlUdfs` is called. */
1310+ private var keySqlUdf : String = _
1311+ private var reduceSqlUdf : String = _
1312+
12861313 // TODO: Add these estimators.
12871314 // /** [[LoadEstimator]] to estimate the CPU load of the [[keyUdf]]. */
12881315 // private var keyUdfCpuEstimator: LoadEstimator = _
@@ -1322,7 +1349,29 @@ class ReduceByDataQuantaBuilder[Key, T](inputDataQuanta: DataQuantaBuilder[_, T]
13221349 this
13231350 }
13241351
1325- override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().reduceByKeyJava(keyUdf, udf, this .udfLoadProfileEstimator), this .getTargetPlatforms())
/**
 * Record SQL implementations of the grouping key and of the reduction.
 *
 * The values are only stored here; `build` passes them on to the descriptors
 * of the reduce-by operator.
 *
 * @param keySqlUdf    SQL grouping column
 * @param reduceSqlUdf SQL aggregate expression
 * @return this instance
 */
def withSqlUdfs(keySqlUdf: String, reduceSqlUdf: String) = {
  this.reduceSqlUdf = reduceSqlUdf
  this.keySqlUdf = keySqlUdf
  this
}
1364+
/**
 * Build the reduce-by data quanta and, when a key SQL implementation has been
 * set, register the SQL implementations on the operator's key and reduce
 * descriptors. Target platforms are applied last.
 */
override protected def build = {
  val reduced = inputDataQuanta.dataQuanta()
    .reduceByKeyJava(keyUdf, udf, this.udfLoadProfileEstimator)
  // Only the key SQL decides whether any SQL implementation is attached;
  // the reduce SQL is forwarded as-is, even if it is still null.
  Option(this.keySqlUdf).foreach { keySql =>
    val reduceByOperator = reduced.operator.asInstanceOf[ReduceByOperator[T, Key]]
    // NOTE(review): the key SQL is passed for both arguments — confirm this is intended.
    reduceByOperator.getKeyDescriptor.withSqlImplementation(keySql, keySql)
    reduceByOperator.getReduceDescriptor.withSqlImplementation(this.reduceSqlUdf)
  }
  applyTargetPlatforms(reduced, this.getTargetPlatforms())
}
13261375}
13271376
13281377/**
@@ -1402,6 +1451,9 @@ class GlobalReduceDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T],
14021451 /** [[LoadProfileEstimator ]] to estimate the [[LoadProfile ]] of the [[udf ]]. */
14031452 private var udfLoadProfileEstimator : LoadProfileEstimator = _
14041453
1454+ /** SQL implementation of the reduction; remains `null` until `withSqlUdf` is called. */
1455+ private var sqlUdf : String = _
1456+
14051457 // Try to infer the type classes from the udf.
14061458 locally {
14071459 val parameters = ReflectionUtils .getTypeParameters(udf.getClass, classOf [SerializableBinaryOperator [_]])
@@ -1422,7 +1474,25 @@ class GlobalReduceDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T],
14221474 this
14231475 }
14241476
1425- override protected def build = applyTargetPlatforms(inputDataQuanta.dataQuanta().reduceJava(udf, this .udfLoadProfileEstimator), this .getTargetPlatforms())
/**
 * Record a SQL implementation of the reduction.
 *
 * The value is only stored here; `build` passes it on to the reduce descriptor
 * of the global-reduce operator.
 *
 * @param sqlUdf SQL aggregate expression
 * @return this instance
 */
def withSqlUdf(sqlUdf: String) = {
  this.sqlUdf = sqlUdf
  this
}
1487+
/**
 * Build the globally reduced data quanta and, when a SQL implementation has
 * been set, register it on the operator's reduce descriptor. Target platforms
 * are applied last.
 */
override protected def build = {
  val reduced = inputDataQuanta.dataQuanta().reduceJava(udf, this.udfLoadProfileEstimator)
  // A null SQL UDF means that no SQL implementation was requested.
  Option(this.sqlUdf).foreach { sql =>
    val reduceOperator = reduced.operator.asInstanceOf[GlobalReduceOperator[T]]
    reduceOperator.getReduceDescriptor.withSqlImplementation(sql)
  }
  applyTargetPlatforms(reduced, this.getTargetPlatforms())
}
14261496
14271497}
14281498
@@ -1490,6 +1560,12 @@ class JoinDataQuantaBuilder[In0, In1, Key](inputDataQuanta0: DataQuantaBuilder[_
14901560 /** [[LoadEstimator ]] to estimate the RAM load of the [[keyUdf1 ]]. */
14911561 private var keyUdf1RamEstimator : LoadEstimator = _
14921562
1563+ /** Table names and SQL key expressions for both join inputs; all remain `null` until `withSqlUdfs` is called. */
1564+ private var keyUdf0TableName : String = _
1565+ private var keyUdf0SqlUdf : String = _
1566+ private var keyUdf1TableName : String = _
1567+ private var keyUdf1SqlUdf : String = _
1568+
14931569 // Try to infer the type classes from the UDFs.
14941570 locally {
14951571 val parameters = ReflectionUtils .getTypeParameters(keyUdf0.getClass, classOf [SerializableFunction [_, _]])
@@ -1568,6 +1644,22 @@ class JoinDataQuantaBuilder[In0, In1, Key](inputDataQuanta0: DataQuantaBuilder[_
15681644 this
15691645 }
15701646
/**
 * Record SQL implementations of both join keys.
 *
 * The values are only stored here; `build` passes each table name together with
 * its key SQL to the matching key descriptor of the join operator.
 *
 * @param thisTableName table name paired with the first input's key
 * @param thisKeySqlUdf SQL implementation of the first input's key
 * @param thatTableName table name paired with the second input's key
 * @param thatKeySqlUdf SQL implementation of the second input's key
 * @return this instance
 */
def withSqlUdfs(thisTableName: String,
                thisKeySqlUdf: String,
                thatTableName: String,
                thatKeySqlUdf: String) = {
  // First input.
  this.keyUdf0SqlUdf = thisKeySqlUdf
  this.keyUdf0TableName = thisTableName
  // Second input.
  this.keyUdf1SqlUdf = thatKeySqlUdf
  this.keyUdf1TableName = thatTableName
  this
}
1662+
15711663 /**
15721664 * Assemble the joined elements to new elements.
15731665 *
@@ -1579,8 +1671,16 @@ class JoinDataQuantaBuilder[In0, In1, Key](inputDataQuanta0: DataQuantaBuilder[_
15791671 override def apply (joinTuple : RT2 [In0 , In1 ]): NewOut = udf.apply(joinTuple.field0, joinTuple.field1)
15801672 })
15811673
1582- override protected def build =
1583- applyTargetPlatforms(inputDataQuanta0.dataQuanta().joinJava(keyUdf0, inputDataQuanta1.dataQuanta(), keyUdf1)(inputDataQuanta1.classTag, this .keyTag), this .getTargetPlatforms())
/**
 * Build the joined data quanta and, when a SQL implementation of the first key
 * has been set, register table name and key SQL on both key descriptors of the
 * join operator. Target platforms are applied last.
 */
override protected def build = {
  val joined = inputDataQuanta0.dataQuanta()
    .joinJava(keyUdf0, inputDataQuanta1.dataQuanta(), keyUdf1)(inputDataQuanta1.classTag, this.keyTag)
  // Only the first key's SQL decides whether SQL implementations are attached;
  // the remaining values are forwarded as-is, even if they are still null.
  Option(this.keyUdf0SqlUdf).foreach { firstKeySql =>
    val joinOperator = joined.operator.asInstanceOf[JoinOperator[In0, In1, Key]]
    joinOperator.getKeyDescriptor0.withSqlImplementation(this.keyUdf0TableName, firstKeySql)
    joinOperator.getKeyDescriptor1.withSqlImplementation(this.keyUdf1TableName, this.keyUdf1SqlUdf)
  }
  applyTargetPlatforms(joined, this.getTargetPlatforms())
}
15841684
15851685}
15861686
0 commit comments