Skip to content

Commit 0e3cfa5

Browse files
committed
data stream join operators for flink platform & code quality
1 parent 08b3ea0 commit 0e3cfa5

3 files changed

Lines changed: 207 additions & 18 deletions

File tree

wayang-platforms/wayang-flink/src/main/java/org/apache/wayang/flink/compiler/FunctionCompiler.java

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public class FunctionCompiler {
5858
* @param <O> output type of the transformation
5959
* @return a compiled function
6060
*/
61-
public <I, O> MapFunction<I, O> compile(TransformationDescriptor<I, O> descriptor) {
61+
public static <I, O> MapFunction<I, O> compile(TransformationDescriptor<I, O> descriptor) {
6262
// This is a dummy method but shows the intention of having something compilable in the descriptors.
6363
Function<I, O> function = descriptor.getJavaImplementation();
6464
return (MapFunction<I, O>) i -> function.apply(i);
@@ -72,7 +72,7 @@ public <I, O> MapFunction<I, O> compile(TransformationDescriptor<I, O> descripto
7272
* @param <O> output type of the transformation
7373
* @return a compiled function
7474
*/
75-
public <I, O> FlatMapFunction<I, O> compile(FunctionDescriptor.SerializableFunction<I, Iterable<O>> flatMapDescriptor) {
75+
public static <I, O> FlatMapFunction<I, O> compile(FunctionDescriptor.SerializableFunction<I, Iterable<O>> flatMapDescriptor) {
7676
return (t, collector) -> flatMapDescriptor.apply(t).forEach(collector::collect);
7777
}
7878

@@ -83,7 +83,7 @@ public <I, O> FlatMapFunction<I, O> compile(FunctionDescriptor.SerializableFunct
8383
* @param <T> input/output type of the transformation
8484
* @return a compiled function
8585
*/
86-
public <T> ReduceFunction<T> compile(ReduceDescriptor<T> descriptor) {
86+
public static <T> ReduceFunction<T> compile(ReduceDescriptor<T> descriptor) {
8787
// This is a dummy method but shows the intention of having something compilable in the descriptors.
8888
BiFunction<T, T, T> reduce_function = descriptor.getJavaImplementation();
8989
return new ReduceFunction<T>() {
@@ -94,26 +94,26 @@ public T reduce(T t, T t1) throws Exception {
9494
};
9595
}
9696

97-
public <T> FilterFunction<T> compile(PredicateDescriptor.SerializablePredicate<T> predicateDescriptor) {
98-
return t -> predicateDescriptor.test(t);
97+
public static <T> FilterFunction<T> compile(PredicateDescriptor.SerializablePredicate<T> predicateDescriptor) {
98+
return predicateDescriptor::test;
9999
}
100100

101101

102-
public <T> OutputFormat<T> compile(ConsumerDescriptor.SerializableConsumer<T> consumerDescriptor) {
102+
public static <T> OutputFormat<T> compile(ConsumerDescriptor.SerializableConsumer<T> consumerDescriptor) {
103103
return new OutputFormatConsumer<T>(consumerDescriptor);
104104
}
105105

106106

107-
public <T, K> KeySelector<T, K> compileKeySelector(TransformationDescriptor<T, K> descriptor){
107+
public static <T, K> KeySelector<T, K> compileKeySelector(TransformationDescriptor<T, K> descriptor){
108108
return new KeySelectorFunction<T, K>(descriptor);
109109
}
110110

111-
public <T0, T1, O> CoGroupFunction<T0, T1, O> compileCoGroup(){
111+
public static <T0, T1, O> CoGroupFunction<T0, T1, O> compileCoGroup(){
112112
return new FlinkCoGroupFunction<T0, T1, O>();
113113
}
114114

115115

116-
public <T> TextOutputFormat.TextFormatter<T> compileOutput(TransformationDescriptor<T, String> formattingDescriptor) {
116+
public static <T> TextOutputFormat.TextFormatter<T> compileOutput(TransformationDescriptor<T, String> formattingDescriptor) {
117117
Function<T, String> format = formattingDescriptor.getJavaImplementation();
118118
return new TextOutputFormat.TextFormatter<T>(){
119119

@@ -132,7 +132,7 @@ public String format(T value) {
132132
* @param <O> output type of the transformation
133133
* @return a compiled function
134134
*/
135-
public <I, O> MapPartitionFunction<I, O> compile(MapPartitionsDescriptor<I, O> descriptor){
135+
public static <I, O> MapPartitionFunction<I, O> compile(MapPartitionsDescriptor<I, O> descriptor){
136136
Function<Iterable<I>, Iterable<O>> function = descriptor.getJavaImplementation();
137137
return new MapPartitionFunction<I, O>() {
138138
@Override
@@ -146,13 +146,12 @@ public void mapPartition(Iterable<I> iterable, Collector<O> collector) throws Ex
146146
};
147147
}
148148

149-
public <T> WayangConvergenceCriterion compile(PredicateDescriptor<Collection<T>> descriptor){
150-
FunctionDescriptor.SerializablePredicate<Collection<T>> predicate = descriptor.getJavaImplementation();
151-
return new WayangConvergenceCriterion(predicate);
149+
public static <T> WayangConvergenceCriterion<T> compile(PredicateDescriptor<Collection<T>> descriptor){
150+
return new WayangConvergenceCriterion<T>(descriptor.getJavaImplementation());
152151
}
153152

154153

155-
public <I, O> RichFlatMapFunction<I, O> compile(FunctionDescriptor.ExtendedSerializableFunction<I, Iterable<O>> flatMapDescriptor, FlinkExecutionContext exe) {
154+
public static <I, O> RichFlatMapFunction<I, O> compile(FunctionDescriptor.ExtendedSerializableFunction<I, Iterable<O>> flatMapDescriptor, FlinkExecutionContext exe) {
156155

157156
return new RichFlatMapFunction<I, O>() {
158157
@Override
@@ -168,7 +167,7 @@ public void flatMap(I value, Collector<O> out) throws Exception {
168167
}
169168

170169

171-
public <I, O> RichMapFunction<I, O> compile(TransformationDescriptor<I, O> mapDescriptor, FlinkExecutionContext fex ) {
170+
public static <I, O> RichMapFunction<I, O> compile(TransformationDescriptor<I, O> mapDescriptor, FlinkExecutionContext fex ) {
172171

173172
FunctionDescriptor.ExtendedSerializableFunction<I, O> map = (FunctionDescriptor.ExtendedSerializableFunction) mapDescriptor.getJavaImplementation();
174173
return new RichMapFunction<I, O>() {
@@ -186,7 +185,7 @@ public void open(Configuration parameters) throws Exception {
186185

187186

188187

189-
public <I, O> RichMapPartitionFunction<I, O> compile(MapPartitionsDescriptor<I, O> descriptor, FlinkExecutionContext fex){
188+
public static <I, O> RichMapPartitionFunction<I, O> compile(MapPartitionsDescriptor<I, O> descriptor, FlinkExecutionContext fex){
190189
FunctionDescriptor.ExtendedSerializableFunction<Iterable<I>, Iterable<O>> function =
191190
(FunctionDescriptor.ExtendedSerializableFunction<Iterable<I>, Iterable<O>>)
192191
descriptor.getJavaImplementation();
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.wayang.flink.operators;
20+
21+
import java.time.Duration;
22+
import java.util.Arrays;
23+
import java.util.Collection;
24+
import java.util.List;
25+
26+
import org.apache.flink.api.common.eventtime.WatermarkStrategy;
27+
import org.apache.flink.api.common.functions.JoinFunction;
28+
import org.apache.flink.streaming.api.datastream.DataStream;
29+
import org.apache.flink.streaming.api.windowing.assigners.TumblingEventTimeWindows;
30+
31+
import org.apache.wayang.basic.data.Tuple2;
32+
import org.apache.wayang.basic.function.ProjectionDescriptor;
33+
import org.apache.wayang.basic.operators.JoinOperator;
34+
import org.apache.wayang.core.function.FunctionDescriptor.SerializableFunction;
35+
import org.apache.wayang.core.optimizer.OptimizationContext.OperatorContext;
36+
import org.apache.wayang.core.plan.wayangplan.ExecutionOperator;
37+
import org.apache.wayang.core.platform.ChannelDescriptor;
38+
import org.apache.wayang.core.platform.ChannelInstance;
39+
import org.apache.wayang.core.platform.lineage.ExecutionLineageNode;
40+
import org.apache.wayang.core.util.Tuple;
41+
import org.apache.wayang.flink.channels.DataSetChannel;
42+
import org.apache.wayang.flink.channels.DataStreamChannel;
43+
import org.apache.wayang.flink.compiler.FunctionCompiler;
44+
import org.apache.wayang.flink.execution.FlinkExecutor;
45+
46+
public class FlinkDataStreamJoinOperator<I0, I1, K> extends JoinOperator<I0, I1, K> implements FlinkExecutionOperator {
47+
class Joiner implements JoinFunction<I0, I1, Tuple2<I0, I1>> {
48+
@Override
49+
public Tuple2<I0, I1> join(final I0 first, final I1 second) throws Exception {
50+
return new Tuple2<>(first, second);
51+
}
52+
}
53+
final WatermarkStrategy<I0> leftWatermarkStrategy;
54+
final WatermarkStrategy<I1> rightWatermarkStrategy;
55+
56+
final Duration duration;
57+
58+
public FlinkDataStreamJoinOperator(final ProjectionDescriptor<I0, K> descriptor0,
59+
final ProjectionDescriptor<I1, K> descriptor1) {
60+
this(descriptor0, descriptor1,
61+
WatermarkStrategy
62+
.<I0>forMonotonousTimestamps()
63+
.withTimestampAssigner((e, ts) -> 0L),
64+
WatermarkStrategy
65+
.<I1>forMonotonousTimestamps()
66+
.withTimestampAssigner((e, ts) -> 0L),
67+
Duration.ofDays(365));
68+
;
69+
}
70+
71+
public FlinkDataStreamJoinOperator(final ProjectionDescriptor<I0, K> descriptor0,
72+
final ProjectionDescriptor<I1, K> descriptor1, final WatermarkStrategy<I0> leftWatermarkStrategy,
73+
final WatermarkStrategy<I1> rightWatermarkStrategy, final Duration duration) {
74+
super(descriptor0, descriptor1);
75+
this.leftWatermarkStrategy = leftWatermarkStrategy;
76+
this.rightWatermarkStrategy = rightWatermarkStrategy;
77+
this.duration = duration;
78+
}
79+
80+
public FlinkDataStreamJoinOperator(final SerializableFunction<I0, K> keyExtractor0,
81+
final SerializableFunction<I1, K> keyExtractor1, final Class<I0> input0Class, final Class<I1> input1Class,
82+
final Class<K> keyClass, final WatermarkStrategy<I0> leftWatermarkStrategy,
83+
final WatermarkStrategy<I1> rightWatermarkStrategy, final Duration duration) {
84+
super(keyExtractor0, keyExtractor1, input0Class, input1Class, keyClass);
85+
86+
this.leftWatermarkStrategy = leftWatermarkStrategy;
87+
this.rightWatermarkStrategy = rightWatermarkStrategy;
88+
this.duration = duration;
89+
}
90+
91+
@Override
92+
public Tuple<Collection<ExecutionLineageNode>, Collection<ChannelInstance>> evaluate(final ChannelInstance[] inputs,
93+
final ChannelInstance[] kputs, final FlinkExecutor flinkExecutor, final OperatorContext operatorContext)
94+
throws Exception {
95+
assert inputs.length == this.getNumInputs();
96+
assert kputs.length == this.getNumOutputs();
97+
98+
final DataStreamChannel.Instance input0 = (DataStreamChannel.Instance) inputs[0];
99+
final DataStreamChannel.Instance input1 = (DataStreamChannel.Instance) inputs[1];
100+
final DataStreamChannel.Instance output = (DataStreamChannel.Instance) kputs[0];
101+
102+
final DataStream<I0> dataStream0 = input0.provideDataStream();
103+
final DataStream<I1> dataStream1 = input1.provideDataStream();
104+
105+
final DataStream<Tuple2<I0, I1>> outputStream = dataStream0
106+
.assignTimestampsAndWatermarks(leftWatermarkStrategy)
107+
.join(dataStream1.assignTimestampsAndWatermarks(rightWatermarkStrategy))
108+
.where(FunctionCompiler.compileKeySelector(keyDescriptor0))
109+
.equalTo(FunctionCompiler.compileKeySelector(keyDescriptor1))
110+
.window(TumblingEventTimeWindows.of(duration))
111+
.apply(new Joiner());
112+
113+
output.accept(outputStream);
114+
115+
return ExecutionOperator.modelLazyExecution(inputs, kputs, operatorContext);
116+
}
117+
118+
@Override
119+
public boolean containsAction() {
120+
return false;
121+
}
122+
123+
@Override
124+
public List<ChannelDescriptor> getSupportedInputChannels(final int index) {
125+
assert index <= this.getNumInputs() || (index == 0 && this.getNumInputs() == 0);
126+
return Arrays.asList(DataSetChannel.DESCRIPTOR, DataSetChannel.DESCRIPTOR_MANY);
127+
}
128+
129+
@Override
130+
public List<ChannelDescriptor> getSupportedOutputChannels(final int index) {
131+
assert index <= this.getNumOutputs() || (index == 0 && this.getNumOutputs() == 0);
132+
// return Collections.singletonList(DataSetChannel.DESCRIPTOR);
133+
return Arrays.asList(DataSetChannel.DESCRIPTOR, DataSetChannel.DESCRIPTOR_MANY);
134+
}
135+
}

wayang-platforms/wayang-flink/src/test/java/org/apache/wayang/flink/operators/FlinkDataStreamTests.java

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,18 @@
2323

2424
import org.apache.flink.streaming.api.datastream.DataStream;
2525
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
26-
26+
import org.apache.wayang.basic.data.Tuple2;
27+
import org.apache.wayang.basic.function.ProjectionDescriptor;
2728
import org.apache.wayang.core.function.FunctionDescriptor.SerializableFunction;
2829
import org.apache.wayang.core.platform.ChannelInstance;
2930
import org.apache.wayang.core.types.DataSetType;
31+
import org.apache.wayang.core.types.DataUnitType;
3032
import org.apache.wayang.flink.channels.DataStreamChannel;
3133
import org.apache.wayang.java.channels.CollectionChannel;
32-
34+
import org.junit.jupiter.api.RepeatedTest;
3335
import org.junit.jupiter.api.Test;
36+
37+
import static org.junit.jupiter.api.Assertions.assertEquals;
3438
import static org.junit.jupiter.api.Assertions.assertTrue;
3539

3640
public class FlinkDataStreamTests extends FlinkOperatorTestBase {
@@ -113,4 +117,55 @@ public void javaConversion() throws Exception {
113117

114118
assertTrue(sinkOutput.provideCollection().size() > 0);
115119
}
120+
121+
@RepeatedTest(5)
122+
public void joinTest() throws Exception {
123+
final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
124+
// Set up channels
125+
final DataStreamChannel.Instance input1 = this.createDataStreamChannelInstance();
126+
input1.accept(
127+
env.fromData(new Tuple2<>(1, "b"), new Tuple2<>(1, "c"), new Tuple2<>(2, "d"), new Tuple2<>(3, "e")));
128+
final DataStreamChannel.Instance input2 = this.createDataStreamChannelInstance();
129+
input2.accept(
130+
env.fromData(new Tuple2<>("x", 1), new Tuple2<>("y", 1), new Tuple2<>("z", 2), new Tuple2<>("w", 4)));
131+
132+
final DataStreamChannel.Instance output = this.createDataStreamChannelInstance();
133+
134+
// Set up the ChannelInstances.
135+
final ChannelInstance[] inputs = new ChannelInstance[] { input1, input2 };
136+
final ChannelInstance[] outputs = new ChannelInstance[] { output };
137+
138+
// Set up JoinOperator
139+
final ProjectionDescriptor<Tuple2<Integer, String>, Integer> left = new ProjectionDescriptor<>(
140+
DataUnitType.createBasicUnchecked(Tuple2.class),
141+
DataUnitType.createBasic(Integer.class),
142+
"field0");
143+
final ProjectionDescriptor<Tuple2<String, Integer>, Integer> right = new ProjectionDescriptor<>(
144+
DataUnitType.createBasicUnchecked(Tuple2.class),
145+
DataUnitType.createBasic(Integer.class),
146+
"field1");
147+
final FlinkDataStreamJoinOperator<Tuple2<Integer, String>, Tuple2<String, Integer>, Integer> join = new FlinkDataStreamJoinOperator<>(
148+
left, right);
149+
150+
// Execute.
151+
this.evaluate(join, inputs, outputs);
152+
153+
final DataStream<Tuple2<?, ?>> stream = output.<Tuple2<?, ?>>provideDataStream();
154+
final Iterator<Tuple2<?, ?>> ints = stream.executeAndCollect();
155+
156+
final ArrayList<Tuple2<?, ?>> collection = new ArrayList<>();
157+
ints.forEachRemaining(collection::add);
158+
159+
assertEquals(5, collection.size());
160+
assertTrue(collection.stream()
161+
.anyMatch(res -> res.equals(new Tuple2<>(new Tuple2<>(1, "b"), new Tuple2<>("x", 1)))));
162+
assertTrue(collection.stream()
163+
.anyMatch(res -> res.equals(new Tuple2<>(new Tuple2<>(1, "b"), new Tuple2<>("y", 1)))));
164+
assertTrue(collection.stream()
165+
.anyMatch(res -> res.equals(new Tuple2<>(new Tuple2<>(1, "c"), new Tuple2<>("x", 1)))));
166+
assertTrue(collection.stream()
167+
.anyMatch(res -> res.equals(new Tuple2<>(new Tuple2<>(1, "c"), new Tuple2<>("y", 1)))));
168+
assertTrue(collection.stream()
169+
.anyMatch(res -> res.equals(new Tuple2<>(new Tuple2<>(2, "d"), new Tuple2<>("z", 2)))));
170+
}
116171
}

0 commit comments

Comments
 (0)