Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,42 @@ class DataQuanta[Out: ClassTag](val operator: ElementaryOperator, outputIndex: I
udfLoad: LoadProfileEstimator = null): DataQuanta[NewOut] =
mapPartitionsJava(toSerializablePartitionFunction(udf), selectivity, udfLoad)


/**
  * Feeds this instance (input 0) and the given labels (input 1) into a
  * [[LogisticRegressionOperator]].
  *
  * @param labels       [[DataQuanta]] containing the label values (0.0 or 1.0)
  * @param fitIntercept whether to fit an intercept term
  * @return a new [[DataQuanta]] instance containing the trained [[LogisticRegressionModel]]
  */
def trainLogisticRegression(labels: DataQuanta[java.lang.Double], fitIntercept: Boolean): DataQuanta[LogisticRegressionModel] = {
  // Distinct local name: avoids shadowing the `operator` member of the enclosing class.
  val regressionOperator = new LogisticRegressionOperator(fitIntercept)
  // Wire the features (this) to input slot 0 and the labels to input slot 1.
  this.connectTo(regressionOperator, 0)
  labels.connectTo(regressionOperator, 1)
  // The declared return type differs from the operator type, so an implicit
  // operator-to-DataQuanta conversion in scope performs the wrapping here.
  regressionOperator
}

/**
  * Feeds this instance (input 0) and the given labels (input 1) into a
  * [[TimeSeriesDecisionTreeRegressionOperator]], which performs decision tree
  * regression over time series data using lag-based features.
  *
  * @param labels       [[DataQuanta]] containing the target values to predict
  * @param lag          the number of previous time steps to use as input features
  * @param maxDepth     the maximum depth of the decision tree
  * @param minInstances the minimum number of instances per node
  * @return a new [[DataQuanta]] instance containing the predicted values
  */
def trainTimeSeriesDecisionTree(
    labels: DataQuanta[java.lang.Double],
    lag: Int,
    maxDepth: Int,
    minInstances: Int
): DataQuanta[java.lang.Double] = {
  // Distinct local name: avoids shadowing the `operator` member of the enclosing class.
  val regressionOperator =
    new TimeSeriesDecisionTreeRegressionOperator(lag, maxDepth, minInstances)
  // Wire the time series (this) to input slot 0 and the labels to input slot 1.
  this.connectTo(regressionOperator, 0)
  labels.connectTo(regressionOperator, 1)
  // The declared return type differs from the operator type, so an implicit
  // operator-to-DataQuanta conversion in scope performs the wrapping here.
  regressionOperator
}



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.wayang.api.graph.{Edge, EdgeDataQuantaBuilder, EdgeDataQuantaB
import org.apache.wayang.api.util.{DataQuantaBuilderCache, TypeTrap}
import org.apache.wayang.basic.data.{Record, Tuple2 => RT2}
import org.apache.wayang.basic.model.{DLModel, Model, LogisticRegressionModel}
import org.apache.wayang.basic.operators.{DLTrainingOperator, GlobalReduceOperator, LocalCallbackSink, MapOperator, SampleOperator, LogisticRegressionOperator}
import org.apache.wayang.basic.operators.{DLTrainingOperator, GlobalReduceOperator, LocalCallbackSink, MapOperator, SampleOperator, LogisticRegressionOperator,TimeSeriesDecisionTreeRegressionOperator}
import org.apache.wayang.commons.util.profiledb.model.Experiment
import org.apache.wayang.core.function.FunctionDescriptor.{SerializableBiFunction, SerializableBinaryOperator, SerializableFunction, SerializableIntUnaryOperator, SerializablePredicate}
import org.apache.wayang.core.optimizer.ProbabilisticDoubleInterval
Expand Down Expand Up @@ -291,10 +291,48 @@ trait DataQuantaBuilder[+This <: DataQuantaBuilder[_, Out], Out] extends Logging
option: DLTrainingOperator.Option) =
new DLTrainingDataQuantaBuilder(this, that, model, option)


/**
  * Feed the built [[DataQuanta]] of this and the given instance into a
  * [[org.apache.wayang.basic.operators.LogisticRegressionOperator]].
  * This instance supplies the features, `that` supplies the labels.
  *
  * @param that         the [[DataQuantaBuilder]] containing the label values (0.0 or 1.0)
  * @param fitIntercept whether to include an intercept term in the model (defaults to `true`)
  * @return a [[LogisticRegressionDataQuantaBuilder]] for the trained [[LogisticRegressionModel]]
  */
def trainLogisticRegression(that: DataQuantaBuilder[_, java.lang.Double], fitIntercept: Boolean = true): LogisticRegressionDataQuantaBuilder = {
  // Unchecked cast: the element type of this builder is not statically known to be
  // Array[Double], so a mismatch is not detected at this point.
  val featureBuilder = this.asInstanceOf[DataQuantaBuilder[_, Array[Double]]]
  new LogisticRegressionDataQuantaBuilder(featureBuilder, that, fitIntercept)
}


/**
  * Feed the built [[DataQuanta]] of this and the given instance into a
  * [[org.apache.wayang.basic.operators.TimeSeriesDecisionTreeRegressionOperator]].
  * The operator is wrapped in a [[CustomOperatorDataQuantaBuilder]] that takes this
  * instance as its first input and `that` as its second input.
  *
  * @param that         the [[DataQuantaBuilder]] containing the label values
  * @param lag          the number of previous time steps to use as input features
  * @param maxDepth     the maximum depth of the decision tree
  * @param minInstances the minimum number of instances per node in the tree
  * @return a [[DataQuantaBuilder]] containing the predicted output values
  */
def trainTimeSeriesDecisionTree(
    that: DataQuantaBuilder[_, java.lang.Double],
    lag: Int,
    maxDepth: Int,
    minInstances: Int
): DataQuantaBuilder[_, java.lang.Double] = {
  val regressionOperator =
    new TimeSeriesDecisionTreeRegressionOperator(lag, maxDepth, minInstances)
  // The operator's output slot 0 is exposed by the resulting builder.
  val outputIndex = 0
  new CustomOperatorDataQuantaBuilder[java.lang.Double](
    regressionOperator,
    outputIndex,
    new DataQuantaBuilderCache,
    this,
    that
  )
}





/**
Expand Down Expand Up @@ -1779,8 +1817,8 @@ class FakeDataQuantaBuilder[T](_dataQuanta: DataQuanta[T])(implicit javaPlanBuil
/**
* [[DataQuantaBuilder]] implementation for [[org.apache.wayang.basic.operators.LogisticRegressionOperator]]s.
*
* @param inputDataQuanta0 [[DataQuantaBuilder]] providing the features
* @param inputDataQuanta1 [[DataQuantaBuilder]] providing the labels
*/
class LogisticRegressionDataQuantaBuilder(inputDataQuanta0: DataQuantaBuilder[_, Array[Double]],
inputDataQuanta1: DataQuantaBuilder[_, java.lang.Double],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/


package org.apache.wayang.basic.operators;

import org.apache.wayang.core.api.Configuration;
import org.apache.wayang.core.optimizer.cardinality.CardinalityEstimator;
import org.apache.wayang.core.plan.wayangplan.BinaryToUnaryOperator;
import org.apache.wayang.core.types.DataSetType;

import java.util.Optional;

/**
 * A time series regression operator based on a decision tree regressor.
 * <p>
 * Input 0 carries the time series as {@code double[]} values, input 1 carries {@link Double}
 * labels, and the single output carries the predicted {@link Double} values. Lag features are
 * expected to be generated by the platform-specific implementation from the configured
 * lag window size.
 */
public class TimeSeriesDecisionTreeRegressionOperator extends BinaryToUnaryOperator<double[], Double, Double> {

    /** Number of preceding time steps used as features for each prediction. */
    private final int lagWindowSize;

    /** Maximum depth of the decision tree. */
    private final int maxDepth;

    /** Minimum number of instances required per tree node. */
    private final int minInstancesPerNode;

    /**
     * Creates a new instance.
     *
     * @param lagWindowSize       number of preceding time steps used as input features
     * @param maxDepth            maximum depth of the decision tree
     * @param minInstancesPerNode minimum number of instances per tree node
     */
    public TimeSeriesDecisionTreeRegressionOperator(int lagWindowSize, int maxDepth, int minInstancesPerNode) {
        super(
                DataSetType.createDefaultUnchecked(double[].class), // input 0: time series
                DataSetType.createDefaultUnchecked(Double.class),   // input 1: labels
                // Output: predicted values. This was previously declared as DataSetType.none(),
                // which contradicts the Double values that this operator produces.
                DataSetType.createDefaultUnchecked(Double.class),
                false                                               // no broadcast inputs
        );
        this.lagWindowSize = lagWindowSize;
        this.maxDepth = maxDepth;
        this.minInstancesPerNode = minInstancesPerNode;
    }

    /**
     * Copies an instance, including its slot types and all regression parameters.
     *
     * @param that the instance to be copied
     */
    public TimeSeriesDecisionTreeRegressionOperator(TimeSeriesDecisionTreeRegressionOperator that) {
        super(that);
        this.lagWindowSize = that.lagWindowSize;
        this.maxDepth = that.maxDepth;
        this.minInstancesPerNode = that.minInstancesPerNode;
    }

    /**
     * @return the number of preceding time steps used as input features
     */
    public int getLagWindowSize() {
        return lagWindowSize;
    }

    /**
     * @return the maximum depth of the decision tree
     */
    public int getMaxDepth() {
        return maxDepth;
    }

    /**
     * @return the minimum number of instances per tree node
     */
    public int getMinInstancesPerNode() {
        return minInstancesPerNode;
    }

    /**
     * Delegates to the superclass; no operator-specific cardinality estimator is provided.
     */
    @Override
    public Optional<CardinalityEstimator> createCardinalityEstimator(int outputIndex, Configuration configuration) {
        return super.createCardinalityEstimator(outputIndex, configuration);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ public class Mappings {
new DecisionTreeClassificationMapping(),
new ModelTransformMapping(),
new LogisticRegressionMapping(),
new TimeSeriesDecisionTreeRegressionMapping(),
new PredictMapping()
);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/



package org.apache.wayang.spark.mapping.ml;

import org.apache.wayang.basic.operators.TimeSeriesDecisionTreeRegressionOperator;
import org.apache.wayang.core.mapping.*;
import org.apache.wayang.spark.operators.ml.SparkTimeSeriesDecisionTreeRegressionOperator;
import org.apache.wayang.spark.platform.SparkPlatform;

import java.util.Collection;
import java.util.Collections;

/**
 * Mapping from {@link TimeSeriesDecisionTreeRegressionOperator} to
 * {@link SparkTimeSeriesDecisionTreeRegressionOperator} on the {@link SparkPlatform}.
 */
@SuppressWarnings("unchecked")
public class TimeSeriesDecisionTreeRegressionMapping implements Mapping {

    /**
     * @return the single {@link PlanTransformation} that replaces a matched
     * {@link TimeSeriesDecisionTreeRegressionOperator} with its Spark counterpart
     */
    @Override
    public Collection<PlanTransformation> getTransformations() {
        final SubplanPattern pattern = this.createSubplanPattern();
        final ReplacementSubplanFactory replacementFactory = this.createReplacementSubplanFactory();
        final PlanTransformation transformation =
                new PlanTransformation(pattern, replacementFactory, SparkPlatform.getInstance());
        return Collections.singleton(transformation);
    }

    /**
     * Builds a single-operator pattern around a template
     * {@link TimeSeriesDecisionTreeRegressionOperator}.
     */
    private SubplanPattern createSubplanPattern() {
        // The template's parameter values (3, 5, 2) are example values only; the replacement
        // is built from the matched operator, so its own parameters are the ones carried over.
        return SubplanPattern.createSingleton(new OperatorPattern(
                "timeSeriesDecisionTreeRegression",
                new TimeSeriesDecisionTreeRegressionOperator(3, 5, 2),
                false
        ));
    }

    /**
     * Builds the factory that copies the matched operator into a
     * {@link SparkTimeSeriesDecisionTreeRegressionOperator} and assigns it the given epoch.
     */
    private ReplacementSubplanFactory createReplacementSubplanFactory() {
        return new ReplacementSubplanFactory.OfSingleOperators<TimeSeriesDecisionTreeRegressionOperator>(
                (matched, epoch) -> new SparkTimeSeriesDecisionTreeRegressionOperator(matched).at(epoch)
        );
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/


package org.apache.wayang.spark.operators.ml;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.*;
import org.apache.wayang.basic.operators.TimeSeriesDecisionTreeRegressionOperator;
import org.apache.wayang.core.optimizer.OptimizationContext;
import org.apache.wayang.core.plan.wayangplan.ExecutionOperator;
import org.apache.wayang.core.platform.ChannelDescriptor;
import org.apache.wayang.core.platform.ChannelInstance;
import org.apache.wayang.core.platform.lineage.ExecutionLineageNode;
import org.apache.wayang.core.util.Tuple;
import org.apache.wayang.java.channels.CollectionChannel;
import org.apache.wayang.spark.channels.RddChannel;
import org.apache.wayang.spark.execution.SparkExecutor;
import org.apache.wayang.spark.operators.SparkExecutionOperator;

import java.util.*;

/**
 * Spark implementation of the {@link TimeSeriesDecisionTreeRegressionOperator}.
 * <p>
 * Every {@code double[]} element of input 0 is treated as one time series. For each series, a
 * sliding window of {@code lagWindowSize} consecutive values forms a feature vector and the value
 * that follows the window forms its label. A {@link DecisionTreeRegressor} is fitted on these
 * rows and the fitted model's predictions for the same rows are emitted as output.
 */
public class SparkTimeSeriesDecisionTreeRegressionOperator
        extends TimeSeriesDecisionTreeRegressionOperator
        implements SparkExecutionOperator {

    private static final String FEATURES = "features";
    private static final String LABEL = "label";

    /** Schema of the training rows: a feature vector and its numeric label. */
    private static final StructType SCHEMA = new StructType(new StructField[]{
            new StructField(FEATURES, new VectorUDT(), false, Metadata.empty()),
            new StructField(LABEL, DataTypes.DoubleType, false, Metadata.empty())
    });

    /**
     * Creates a new instance.
     *
     * @param lagWindowSize       number of preceding time steps used as input features
     * @param maxDepth            maximum depth of the decision tree
     * @param minInstancesPerNode minimum number of instances per tree node
     */
    public SparkTimeSeriesDecisionTreeRegressionOperator(int lagWindowSize, int maxDepth, int minInstancesPerNode) {
        super(lagWindowSize, maxDepth, minInstancesPerNode);
    }

    /**
     * Copies the parameters of the given operator.
     *
     * @param that the operator to be copied
     */
    public SparkTimeSeriesDecisionTreeRegressionOperator(TimeSeriesDecisionTreeRegressionOperator that) {
        super(that);
    }

    /**
     * Turns each series into (lag window, next value) training rows.
     * A series with at most {@code lag} values contributes no rows.
     *
     * @param seriesRdd the time series, one {@code double[]} per series
     * @param lag       the number of preceding values that make up one feature vector
     * @return a {@link Dataset} following {@link #SCHEMA}
     */
    private static Dataset<Row> createLaggedData(JavaRDD<double[]> seriesRdd, int lag) {
        JavaRDD<Row> rows = seriesRdd.flatMap(series -> {
            List<Row> result = new ArrayList<>();
            for (int i = lag; i < series.length; i++) {
                // Features are series[i - lag, i); the label is the value right after the window.
                double[] window = Arrays.copyOfRange(series, i - lag, i);
                double label = series[i];
                result.add(RowFactory.create(Vectors.dense(window), label));
            }
            return result.iterator();
        });

        return SparkSession.builder().getOrCreate().createDataFrame(rows, SCHEMA);
    }

    @Override
    public List<ChannelDescriptor> getSupportedInputChannels(int index) {
        return Arrays.asList(RddChannel.UNCACHED_DESCRIPTOR, RddChannel.CACHED_DESCRIPTOR);
    }

    @Override
    public List<ChannelDescriptor> getSupportedOutputChannels(int index) {
        return Collections.singletonList(CollectionChannel.DESCRIPTOR);
    }

    /**
     * Fits the decision tree on the lagged training rows and writes the predictions for those
     * rows to the output {@link CollectionChannel}.
     * <p>
     * Only {@code inputs[0]} is consumed. The labels in {@code inputs[1]} are not read, because
     * the training labels are derived from the time series itself.
     */
    @Override
    public Tuple<Collection<ExecutionLineageNode>, Collection<ChannelInstance>> evaluate(
            ChannelInstance[] inputs,
            ChannelInstance[] outputs,
            SparkExecutor sparkExecutor,
            OptimizationContext.OperatorContext operatorContext) {

        final RddChannel.Instance seriesInput = (RddChannel.Instance) inputs[0];
        final CollectionChannel.Instance output = (CollectionChannel.Instance) outputs[0];

        final JavaRDD<double[]> timeSeriesRdd = seriesInput.provideRdd(); // one double[] per series
        final Dataset<Row> trainingData = createLaggedData(timeSeriesRdd, this.getLagWindowSize());

        final DecisionTreeRegressor regressor = new DecisionTreeRegressor()
                .setLabelCol(LABEL)
                .setFeaturesCol(FEATURES)
                .setMaxDepth(this.getMaxDepth())
                .setMinInstancesPerNode(this.getMinInstancesPerNode());

        final DecisionTreeRegressionModel model = regressor.fit(trainingData);

        // Select the prediction column by name instead of relying on its position in the
        // transformed schema. Only the column name is read from the model, so the model
        // itself is not captured by the lambda below.
        final String predictionCol = model.getPredictionCol();
        final List<Double> predictedValues = model.transform(trainingData)
                .select(predictionCol)
                .toJavaRDD()
                .map(row -> row.getDouble(0))
                .collect();

        output.accept(predictedValues);

        return ExecutionOperator.modelLazyExecution(inputs, outputs, operatorContext);
    }

    // NOTE(review): evaluate() runs fit() and collect() eagerly, yet this reports no action and
    // evaluate() models lazy execution - confirm that this is the intended convention.
    @Override
    public boolean containsAction() {
        return false;
    }
}
Loading
Loading