Skip to content

Commit 93fbadb

Browse files
committed
Ensure plan enumeration & channel conversion are deterministic. Use ordered sets for scope/slot tracking, add stable cost tiebreaker, and cover with determinism tests.
1 parent 7bcf2ff commit 93fbadb

6 files changed

Lines changed: 225 additions & 32 deletions

File tree

wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/channels/ChannelConversionGraph.java

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,16 @@
4747
import java.util.Collections;
4848
import java.util.Comparator;
4949
import java.util.HashMap;
50-
import java.util.HashSet;
5150
import java.util.Iterator;
5251
import java.util.LinkedList;
52+
import java.util.LinkedHashSet;
5353
import java.util.List;
5454
import java.util.Map;
5555
import java.util.Random;
5656
import java.util.Set;
5757
import java.util.function.ToDoubleFunction;
5858
import java.util.stream.Collectors;
59+
import java.util.stream.StreamSupport;
5960

6061
/**
6162
* This graph contains a set of {@link ChannelConversion}s.
@@ -168,7 +169,7 @@ private Tree mergeTrees(Collection<Tree> trees) {
168169
final Tree firstTree = iterator.next();
169170
Bitmask combinationSettledIndices = new Bitmask(firstTree.settledDestinationIndices);
170171
int maxSettledIndices = combinationSettledIndices.cardinality();
171-
final HashSet<ChannelDescriptor> employedChannelDescriptors = new HashSet<>(firstTree.employedChannelDescriptors);
172+
final LinkedHashSet<ChannelDescriptor> employedChannelDescriptors = new LinkedHashSet<>(firstTree.employedChannelDescriptors);
172173
int maxVisitedChannelDescriptors = employedChannelDescriptors.size();
173174
double costs = firstTree.costs;
174175
TreeVertex newRoot = new TreeVertex(firstTree.root.channelDescriptor, firstTree.root.settledIndices);
@@ -222,7 +223,11 @@ public static class CostbasedTreeSelectionStrategy implements TreeSelectionStrat
222223

223224
@Override
224225
public Tree select(Tree t1, Tree t2) {
225-
return t1.costs <= t2.costs ? t1 : t2;
226+
int cmp = Double.compare(t1.costs, t2.costs);
227+
if (cmp == 0) {
228+
cmp = t1.getDeterministicSignature().compareTo(t2.getDeterministicSignature());
229+
}
230+
return cmp <= 0 ? t1 : t2;
226231
}
227232

228233
}
@@ -381,7 +386,7 @@ private ShortestTreeSearcher(OutputSlot<?> sourceOutput,
381386
this.existingDestinationChannelIndices = new Bitmask();
382387

383388
this.collectExistingChannels(sourceChannel);
384-
this.openChannelDescriptors = new HashSet<>(openChannels.size());
389+
this.openChannelDescriptors = new LinkedHashSet<>(openChannels.size());
385390
for (Channel openChannel : openChannels) {
386391
this.openChannelDescriptors.add(openChannel.getDescriptor());
387392
}
@@ -477,7 +482,9 @@ private Set<ChannelDescriptor> resolveSupportedChannels(final InputSlot<?> input
477482
final List<ChannelDescriptor> supportedInputChannels = owner.getSupportedInputChannels(input.getIndex());
478483
if (input.isLoopInvariant()) {
479484
// Loop input is needed in several iterations and must therefore be reusable.
480-
return supportedInputChannels.stream().filter(ChannelDescriptor::isReusable).collect(Collectors.toSet());
485+
return supportedInputChannels.stream()
486+
.filter(ChannelDescriptor::isReusable)
487+
.collect(Collectors.toCollection(LinkedHashSet::new));
481488
} else {
482489
return WayangCollections.asSet(supportedInputChannels);
483490
}
@@ -546,7 +553,7 @@ private void kernelizeChannelRequests() {
546553
}
547554
if (channelDescriptors.size() - numReusableChannels == 1) {
548555
iterator.remove();
549-
channelDescriptors = new HashSet<>(channelDescriptors);
556+
channelDescriptors = new LinkedHashSet<>(channelDescriptors);
550557
channelDescriptors.removeIf(channelDescriptor -> !channelDescriptor.isReusable());
551558
kernelDestChannelDescriptorSetsToIndicesUpdates.add(new Tuple<>(channelDescriptors, indices));
552559
}
@@ -575,7 +582,7 @@ private void kernelizeChannelRequests() {
575582
*/
576583
private Tree searchTree() {
577584
// Prepare the recursive traversal.
578-
final HashSet<ChannelDescriptor> visitedChannelDescriptors = new HashSet<>(16);
585+
final LinkedHashSet<ChannelDescriptor> visitedChannelDescriptors = new LinkedHashSet<>(16);
579586
visitedChannelDescriptors.add(this.sourceChannelDescriptor);
580587

581588
// Perform the traversal.
@@ -777,7 +784,7 @@ private Set<ChannelDescriptor> getSuccessorChannelDescriptors(ChannelDescriptor
777784
final Channel channel = this.existingChannels.get(descriptor);
778785
if (channel == null || this.openChannelDescriptors.contains(descriptor)) return null;
779786

780-
Set<ChannelDescriptor> result = new HashSet<>();
787+
Set<ChannelDescriptor> result = new LinkedHashSet<>();
781788
for (ExecutionTask consumer : channel.getConsumers()) {
782789
if (!consumer.getOperator().isAuxiliary()) continue;
783790
for (Channel successorChannel : consumer.getOutputChannels()) {
@@ -988,7 +995,12 @@ private static class Tree {
988995
*
989996
* @see TreeVertex#channelDescriptor
990997
*/
991-
private final Set<ChannelDescriptor> employedChannelDescriptors = new HashSet<>();
998+
private final Set<ChannelDescriptor> employedChannelDescriptors = new LinkedHashSet<>();
999+
1000+
/**
1001+
* Cached deterministic signature for tie-breaking.
1002+
*/
1003+
private String deterministicSignature;
9921004

9931005
/**
9941006
* The sum of the costs of all {@link TreeEdge}s of this instance.
@@ -1010,6 +1022,7 @@ static Tree singleton(ChannelDescriptor channelDescriptor, Bitmask settledIndice
10101022
this.root = root;
10111023
this.settledDestinationIndices = settledDestinationIndices;
10121024
this.employedChannelDescriptors.add(root.channelDescriptor);
1025+
this.deterministicSignature = null;
10131026
}
10141027

10151028
/**
@@ -1033,6 +1046,21 @@ void reroot(ChannelDescriptor newRootChannelDescriptor,
10331046
this.employedChannelDescriptors.add(newRootChannelDescriptor);
10341047
this.settledDestinationIndices.orInPlace(newRootSettledIndices);
10351048
this.costs += edge.costEstimate;
1049+
this.deterministicSignature = null;
1050+
}
1051+
1052+
private String getDeterministicSignature() {
1053+
if (this.deterministicSignature == null) {
1054+
final String descriptorSignature = this.employedChannelDescriptors.stream()
1055+
.map(Object::toString)
1056+
.sorted()
1057+
.collect(Collectors.joining("|"));
1058+
final String indexSignature = StreamSupport.stream(this.settledDestinationIndices.spliterator(), false)
1059+
.map(String::valueOf)
1060+
.collect(Collectors.joining(","));
1061+
this.deterministicSignature = descriptorSignature + "#" + indexSignature;
1062+
}
1063+
return this.deterministicSignature;
10361064
}
10371065

10381066
@Override
@@ -1090,7 +1118,7 @@ private void copyEdgesFrom(TreeVertex that) {
10901118
* @return a {@link Set} of said {@link ChannelConversion}s
10911119
*/
10921120
private Set<ChannelConversion> getChildChannelConversions() {
1093-
Set<ChannelConversion> channelConversions = new HashSet<>();
1121+
Set<ChannelConversion> channelConversions = new LinkedHashSet<>();
10941122
for (TreeEdge edge : this.outEdges) {
10951123
channelConversions.add(edge.channelConversion);
10961124
channelConversions.addAll(edge.destination.getChildChannelConversions());

wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumeration.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
import java.util.Collection;
4343
import java.util.Collections;
4444
import java.util.HashMap;
45-
import java.util.HashSet;
45+
import java.util.LinkedHashSet;
4646
import java.util.LinkedList;
4747
import java.util.List;
4848
import java.util.Map;
@@ -91,7 +91,7 @@ public class PlanEnumeration {
9191
* Creates a new instance.
9292
*/
9393
public PlanEnumeration() {
94-
this(new HashSet<>(), new HashSet<>(), new HashSet<>());
94+
this(new LinkedHashSet<>(), new LinkedHashSet<>(), new LinkedHashSet<>());
9595
}
9696

9797
/**

wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanImplementation.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import java.util.HashMap;
5454
import java.util.HashSet;
5555
import java.util.Iterator;
56+
import java.util.LinkedHashSet;
5657
import java.util.LinkedList;
5758
import java.util.List;
5859
import java.util.Map;
@@ -255,7 +256,7 @@ Collection<InputSlot<?>> findExecutionOperatorInputs(final InputSlot<?> someInpu
255256

256257
// Discern LoopHeadOperator InputSlots and loop body InputSlots.
257258
final List<LoopImplementation.IterationImplementation> iterationImpls = loopImplementation.getIterationImplementations();
258-
final Collection<InputSlot<?>> collector = new HashSet<>(innerInputs.size());
259+
final Collection<InputSlot<?>> collector = new LinkedHashSet<>(innerInputs.size());
259260
for (InputSlot<?> innerInput : innerInputs) {
260261
if (innerInput.getOwner() == loopSubplan.getLoopHead()) {
261262
final LoopImplementation.IterationImplementation initialIterationImpl = iterationImpls.get(0);
@@ -329,7 +330,7 @@ Collection<Tuple<OutputSlot<?>, PlanImplementation>> findExecutionOperatorOutput
329330
// For all the iterations, return the potential OutputSlots.
330331
final List<LoopImplementation.IterationImplementation> iterationImpls =
331332
loopImplementation.getIterationImplementations();
332-
final Set<Tuple<OutputSlot<?>, PlanImplementation>> collector = new HashSet<>(iterationImpls.size());
333+
final Set<Tuple<OutputSlot<?>, PlanImplementation>> collector = new LinkedHashSet<>(iterationImpls.size());
333334
for (LoopImplementation.IterationImplementation iterationImpl : iterationImpls) {
334335
final Collection<Tuple<OutputSlot<?>, PlanImplementation>> outputsWithContext =
335336
iterationImpl.getBodyImplementation().findExecutionOperatorOutputWithContext(innerOutput);
@@ -695,8 +696,8 @@ public double getSquashedCostEstimate() {
695696

696697
private Tuple<List<ProbabilisticDoubleInterval>, List<Double>> getParallelOperatorJunctionAllCostEstimate(Operator operator) {
697698

698-
Set<Operator> inputOperators = new HashSet<>();
699-
Set<Junction> inputJunction = new HashSet<>();
699+
Set<Operator> inputOperators = new LinkedHashSet<>();
700+
Set<Junction> inputJunction = new LinkedHashSet<>();
700701

701702
List<ProbabilisticDoubleInterval> probalisticCost = new ArrayList<>();
702703
List<Double> squashedCost = new ArrayList<>();

wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/util/WayangCollections.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import java.util.Collection;
2525
import java.util.Collections;
2626
import java.util.HashMap;
27-
import java.util.HashSet;
27+
import java.util.LinkedHashSet;
2828
import java.util.Iterator;
2929
import java.util.LinkedList;
3030
import java.util.List;
@@ -59,7 +59,7 @@ public static <T> Set<T> asSet(Collection<T> collection) {
5959
if (collection instanceof Set<?>) {
6060
return (Set<T>) collection;
6161
}
62-
return new HashSet<>(collection);
62+
return new LinkedHashSet<>(collection);
6363
}
6464

6565
/**
@@ -69,7 +69,7 @@ public static <T> Set<T> asSet(Iterable<T> iterable) {
6969
if (iterable instanceof Set<?>) {
7070
return (Set<T>) iterable;
7171
}
72-
Set<T> set = new HashSet<>();
72+
Set<T> set = new LinkedHashSet<>();
7373
for (T t : iterable) {
7474
set.add(t);
7575
}
@@ -80,7 +80,7 @@ public static <T> Set<T> asSet(Iterable<T> iterable) {
8080
* Provides the given {@code values} as {@link Set}.
8181
*/
8282
public static <T> Set<T> asSet(T... values) {
83-
Set<T> set = new HashSet<>(values.length);
83+
Set<T> set = new LinkedHashSet<>(values.length);
8484
for (T value : values) {
8585
set.add(value);
8686
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.wayang.core.optimizer.channels;
20+
21+
import org.apache.wayang.core.api.Configuration;
22+
import org.apache.wayang.core.api.Job;
23+
import org.apache.wayang.core.optimizer.DefaultOptimizationContext;
24+
import org.apache.wayang.core.optimizer.OptimizationContext;
25+
import org.apache.wayang.core.optimizer.OptimizationUtils;
26+
import org.apache.wayang.core.optimizer.cardinality.CardinalityEstimate;
27+
import org.apache.wayang.core.plan.executionplan.Channel;
28+
import org.apache.wayang.core.plan.executionplan.ExecutionTask;
29+
import org.apache.wayang.core.plan.wayangplan.ExecutionOperator;
30+
import org.apache.wayang.core.plan.wayangplan.InputSlot;
31+
import org.apache.wayang.core.plan.wayangplan.OutputSlot;
32+
import org.apache.wayang.core.platform.ChannelDescriptor;
33+
import org.apache.wayang.core.platform.Junction;
34+
import org.apache.wayang.core.test.DummyExecutionOperator;
35+
import org.apache.wayang.core.test.DummyExternalReusableChannel;
36+
import org.apache.wayang.core.test.DummyNonReusableChannel;
37+
import org.apache.wayang.core.test.DummyReusableChannel;
38+
import org.apache.wayang.core.test.MockFactory;
39+
import org.junit.jupiter.api.Test;
40+
41+
import java.util.ArrayList;
42+
import java.util.Arrays;
43+
import java.util.Collections;
44+
import java.util.List;
45+
import java.util.function.Supplier;
46+
import java.util.stream.Collectors;
47+
48+
import static org.junit.jupiter.api.Assertions.assertEquals;
49+
50+
class ChannelConversionGraphDeterminismTest {
51+
52+
private static Supplier<ExecutionOperator> createDummyExecutionOperatorFactory(ChannelDescriptor channelDescriptor) {
53+
return () -> {
54+
ExecutionOperator execOp = new DummyExecutionOperator(1, 1, false);
55+
execOp.getSupportedOutputChannels(0).add(channelDescriptor);
56+
return execOp;
57+
};
58+
}
59+
60+
private static DefaultChannelConversion conversion(ChannelDescriptor source, ChannelDescriptor target) {
61+
return new DefaultChannelConversion(source, target, createDummyExecutionOperatorFactory(target));
62+
}
63+
64+
@Test
65+
void channelConversionSelectionIsStable() {
66+
List<String> first = computeJunctionFingerprint();
67+
List<String> second = computeJunctionFingerprint();
68+
assertEquals(first, second, "Channel conversion choices must be deterministic.");
69+
}
70+
71+
private static List<String> computeJunctionFingerprint() {
72+
Configuration configuration = new Configuration();
73+
ChannelConversionGraph graph = new ChannelConversionGraph(configuration);
74+
graph.add(conversion(DummyReusableChannel.DESCRIPTOR, DummyNonReusableChannel.DESCRIPTOR));
75+
graph.add(conversion(DummyReusableChannel.DESCRIPTOR, DummyExternalReusableChannel.DESCRIPTOR));
76+
graph.add(conversion(DummyExternalReusableChannel.DESCRIPTOR, DummyNonReusableChannel.DESCRIPTOR));
77+
graph.add(conversion(DummyNonReusableChannel.DESCRIPTOR, DummyReusableChannel.DESCRIPTOR));
78+
79+
Job job = MockFactory.createJob(configuration);
80+
OptimizationContext optimizationContext = new DefaultOptimizationContext(job);
81+
82+
DummyExecutionOperator sourceOperator = new DummyExecutionOperator(0, 1, false);
83+
sourceOperator.getSupportedOutputChannels(0).add(DummyReusableChannel.DESCRIPTOR);
84+
optimizationContext.addOneTimeOperator(sourceOperator)
85+
.setOutputCardinality(0, new CardinalityEstimate(1000, 1000, 1d));
86+
87+
DummyExecutionOperator destOperator0 = new DummyExecutionOperator(1, 1, false);
88+
destOperator0.getSupportedInputChannels(0).add(DummyNonReusableChannel.DESCRIPTOR);
89+
90+
DummyExecutionOperator destOperator1 = new DummyExecutionOperator(1, 1, false);
91+
destOperator1.getSupportedInputChannels(0).add(DummyExternalReusableChannel.DESCRIPTOR);
92+
93+
Junction junction = graph.findMinimumCostJunction(
94+
sourceOperator.getOutput(0),
95+
Arrays.asList(destOperator0.getInput(0), destOperator1.getInput(0)),
96+
optimizationContext,
97+
false
98+
);
99+
100+
return describeJunction(junction);
101+
}
102+
103+
private static List<String> describeJunction(Junction junction) {
104+
List<String> descriptorList = new ArrayList<>();
105+
descriptorList.add(describeChannel(junction.getSourceChannel(), true));
106+
for (int i = 0; i < junction.getNumTargets(); i++) {
107+
descriptorList.add(describeChannel(junction.getTargetChannel(i), false));
108+
}
109+
return descriptorList;
110+
}
111+
112+
private static String describeChannel(Channel channel, boolean isSourceChannel) {
113+
if (channel == null) {
114+
return "null";
115+
}
116+
List<String> descriptors = new ArrayList<>();
117+
Channel cursor = channel;
118+
while (cursor != null) {
119+
descriptors.add(cursor.getDescriptor().toString() + (cursor.isCopy() ? ":copy" : ":orig"));
120+
ExecutionTask producer = cursor.getProducer();
121+
if (producer == null || producer.getNumInputChannels() == 0) {
122+
break;
123+
}
124+
// If we are describing the top-level source channel (junction entry), stop once we reach the producer that
125+
// has no inputs. For target channels, follow until the conversion tree ends.
126+
if (isSourceChannel) {
127+
cursor = producer.getNumInputChannels() == 0 ? null : producer.getInputChannel(0);
128+
} else if (producer.getNumInputChannels() == 0) {
129+
cursor = null;
130+
} else {
131+
cursor = producer.getInputChannel(0);
132+
}
133+
}
134+
Collections.reverse(descriptors);
135+
return descriptors.stream().collect(Collectors.joining("->"));
136+
}
137+
}

0 commit comments

Comments
 (0)