Skip to main content

Guide to Development and Usage with Apache Wayang (incubating)

This section provides a set of examples to illustrate how to use Apache Wayang for different tasks.

Example 1: Machine Learning for query optimization in Apache Wayang

Apache Wayang can be customized with concrete implementations of the EstimatableCost interface in order to optimize for a desired metric. The implementation can be enabled by providing it to a Configuration.

public class CustomEstimatableCost implements EstimatableCost {
/* Skeleton for a user-defined cost model: implement the EstimatableCost
* interface methods here to make the optimizer score plans with your own
* cost function(s). Enable it via Configuration.setCostModel(...).
*/
}
public class WordCount {
    public static void main(String[] args) {
        /* Build the configuration that the optimizer will consult. */
        Configuration config = new Configuration();
        /* Register an EstimatableCost implementation as the cost model. */
        config.setCostModel(new CustomEstimatableCost());
        /* Create a Wayang context and declare the candidate platforms. */
        WayangContext wayangContext = new WayangContext(config)
                .withPlugin(Java.basicPlugin())
                .withPlugin(Spark.basicPlugin());
        /*... omitted */
    }
}

In combination with an encoding scheme and a third-party package to load ML models, the following example shows how to predict the runtimes of query execution plans in Apache Wayang (incubating):

/**
 * An {@link EstimatableCost} that predicts plan costs with a pre-trained ML
 * model (loaded through {@link OrtMLModel}). All estimate methods share one
 * prediction path; on any prediction failure they fall back to a zero cost so
 * that optimization can still proceed.
 */
public class MLCost implements EstimatableCost {
    public EstimatableCostFactory getFactory() {
        return new Factory();
    }

    /** Factory the optimizer uses to obtain fresh {@link MLCost} instances. */
    public static class Factory implements EstimatableCostFactory {
        @Override public EstimatableCost makeCost() {
            return new MLCost();
        }
    }

    /**
     * Encodes the plan and runs the ONNX model to obtain a cost prediction.
     *
     * @param plan the plan implementation to score
     * @return the model's predicted cost
     * @throws Exception if the model cannot be loaded or evaluated
     */
    private double predict(PlanImplementation plan) throws Exception {
        Configuration config = plan
            .getOptimizationContext()
            .getConfiguration();
        OrtMLModel model = OrtMLModel.getInstance(config);

        return model.runModel(OneHotEncoder.encode(plan));
    }

    /**
     * Wraps {@link #predict(PlanImplementation)} in an exact interval,
     * falling back to a zero interval when prediction fails.
     */
    private ProbabilisticDoubleInterval predictInterval(PlanImplementation plan) {
        try {
            return ProbabilisticDoubleInterval.ofExactly(this.predict(plan));
        } catch (Exception e) {
            // Best-effort: a failed prediction must not abort optimization.
            return ProbabilisticDoubleInterval.zero;
        }
    }

    @Override public ProbabilisticDoubleInterval getEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
        return this.predictInterval(plan);
    }

    @Override public ProbabilisticDoubleInterval getParallelEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
        return this.predictInterval(plan);
    }

    /** Returns a squashed cost estimate, or 0 when prediction fails. */
    @Override public double getSquashedEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            return this.predict(plan);
        } catch (Exception e) {
            return 0;
        }
    }

    /** Returns a squashed parallel cost estimate, or 0 when prediction fails. */
    @Override public double getSquashedParallelEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            return this.predict(plan);
        } catch (Exception e) {
            return 0;
        }
    }

    /**
     * Returns both the interval and squashed estimates for the given plan.
     * The {@code operator} parameter is not consulted by this implementation.
     */
    @Override public Tuple<List<ProbabilisticDoubleInterval>, List<Double>> getParallelOperatorJunctionAllCostEstimate(PlanImplementation plan, Operator operator) {
        List<ProbabilisticDoubleInterval> intervalList = new ArrayList<>();
        List<Double> doubleList = new ArrayList<>();
        intervalList.add(this.getEstimate(plan, true));
        doubleList.add(this.getSquashedEstimate(plan, true));

        return new Tuple<>(intervalList, doubleList);
    }

    /**
     * Picks the candidate plan with the lowest squashed cost estimate.
     *
     * @throws WayangException if {@code executionPlans} is empty
     */
    public PlanImplementation pickBestExecutionPlan(
        Collection<PlanImplementation> executionPlans,
        ExecutionPlan existingPlan,
        Set<Channel> openChannels,
        Set<ExecutionStage> executedStages) {
        return executionPlans.stream()
            .reduce((p1, p2) -> {
                // Keep whichever plan has the smaller squashed cost.
                final double t1 = p1.getSquashedCostEstimate();
                final double t2 = p2.getSquashedCostEstimate();
                return t1 < t2 ? p1 : p2;
            })
            .orElseThrow(() -> new WayangException("Could not find an execution plan."));
    }
}

Third-party packages such as OnnxRuntime can be used to load pre-trained .onnx files that contain desired ML models.

/**
 * Lazily-initialized singleton wrapper around an ONNX Runtime session that
 * serves cost predictions from a pre-trained {@code .onnx} model.
 */
public class OrtMLModel {

    private static OrtMLModel INSTANCE;

    private OrtSession session;
    private OrtEnvironment env;

    private final Map<String, OnnxTensor> inputMap = new HashMap<>();
    private final Set<String> requestedOutputs = new HashSet<>();

    /**
     * Returns the singleton, creating it on first use. Synchronized so that
     * concurrent first calls cannot create two instances (and two native
     * sessions). NOTE: only the first caller's {@code configuration} is used;
     * subsequent calls reuse the already-created instance.
     *
     * @throws OrtException if the model file cannot be loaded
     */
    public static synchronized OrtMLModel getInstance(Configuration configuration) throws OrtException {
        if (INSTANCE == null) {
            INSTANCE = new OrtMLModel(configuration);
        }

        return INSTANCE;
    }

    private OrtMLModel(Configuration configuration) throws OrtException {
        // The model path is taken from the "wayang.ml.model.file" property.
        this.loadModel(configuration.getStringProperty("wayang.ml.model.file"));
    }

    /**
     * Creates the ONNX environment and session once; later calls are no-ops
     * because both fields are only set when still null.
     */
    public void loadModel(String filePath) throws OrtException {
        if (this.env == null) {
            this.env = OrtEnvironment.getEnvironment();
        }

        if (this.session == null) {
            this.session = env.createSession(filePath, new OrtSession.SessionOptions());
        }
    }

    /** Releases the native session and environment. */
    public void closeSession() throws OrtException {
        this.session.close();
        this.env.close();
    }

    /**
     * @param encodedVector one-hot encoding of the plan to score
     * @return NaN on error, and a predicted cost on any other value.
     * @throws OrtException
     */
    public double runModel(Vector<Long> encodedVector) throws OrtException {
        double costPrediction;

        OnnxTensor tensor = OnnxTensor.createTensor(env, encodedVector);
        try {
            this.inputMap.put("input", tensor);
            this.requestedOutputs.add("output");

            // Extracts the first value of the named output, or NaN on failure
            // (matching the documented "NaN on error" contract).
            BiFunction<Result, String, Double> unwrapFunc = (r, s) -> {
                try {
                    return ((double[]) r.get(s).get().getValue())[0];
                } catch (OrtException e) {
                    return Double.NaN;
                }
            };

            try (Result r = session.run(inputMap, requestedOutputs)) {
                costPrediction = unwrapFunc.apply(r, "output");
            }
        } finally {
            // Close the native tensor to avoid leaking off-heap memory: the
            // original code overwrote the "input" entry on every call without
            // ever closing the previous tensor.
            this.inputMap.remove("input");
            tensor.close();
        }

        return costPrediction;
    }
}