Guide to Development and Usage with Apache Wayang (incubating)
This section provides a set of examples to illustrate how to use Apache Wayang for different tasks.
Example 1: Machine Learning for query optimization in Apache Wayang
Apache Wayang can be customized with concrete implementations of the EstimatableCost interface in order to optimize for a desired metric. The implementation is enabled by providing it to a Configuration.
public class CustomEstimatableCost implements EstimatableCost {
    /* Implement the EstimatableCost interface methods here to supply the
     * cost function(s) the optimizer should minimize (e.g. predicted
     * runtime or monetary cost).
     */
}
public class WordCount {
    public static void main(String[] args) {
        /* Create a configuration that will carry the custom cost model. */
        Configuration config = new Configuration();
        /* Register an EstimatableCost implementation; the optimizer will use it
         * instead of the default cost model. */
        config.setCostModel(new CustomEstimatableCost());
        /* Create a Wayang context and specify the platforms Wayang may choose from. */
        WayangContext wayangContext = new WayangContext(config)
            .withPlugin(Java.basicPlugin())
            .withPlugin(Spark.basicPlugin());
        /* ... omitted */
    }
}
In combination with an encoding scheme and a third-party package to load ML models, the following example shows how to predict the runtimes of query execution plans in Apache Wayang (incubating):
/**
 * An {@link EstimatableCost} that predicts plan costs with a pre-trained ML model.
 * Plans are one-hot encoded and fed to an ONNX model loaded via {@link OrtMLModel}.
 */
public class MLCost implements EstimatableCost {

    public EstimatableCostFactory getFactory() {
        return new Factory();
    }

    /** Factory through which the optimizer instantiates this cost model. */
    public static class Factory implements EstimatableCostFactory {
        @Override public EstimatableCost makeCost() {
            return new MLCost();
        }
    }

    /**
     * Encodes the plan and runs the ML model on it.
     * All public estimate methods funnel through this single pipeline.
     *
     * @param plan the plan to score
     * @return the model's raw cost prediction
     * @throws Exception if the model cannot be loaded or evaluated
     */
    private double predict(PlanImplementation plan) throws Exception {
        Configuration config = plan
            .getOptimizationContext()
            .getConfiguration();
        OrtMLModel model = OrtMLModel.getInstance(config);
        return model.runModel(OneHotEncoder.encode(plan));
    }

    @Override public ProbabilisticDoubleInterval getEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            return ProbabilisticDoubleInterval.ofExactly(this.predict(plan));
        } catch (Exception e) {
            // Deliberate best-effort fallback: an unavailable model yields a zero estimate.
            return ProbabilisticDoubleInterval.zero;
        }
    }

    @Override public ProbabilisticDoubleInterval getParallelEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            return ProbabilisticDoubleInterval.ofExactly(this.predict(plan));
        } catch (Exception e) {
            return ProbabilisticDoubleInterval.zero;
        }
    }

    /** Returns a squashed (scalar) cost estimate, or 0 if the model fails. */
    @Override public double getSquashedEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            return this.predict(plan);
        } catch (Exception e) {
            return 0;
        }
    }

    @Override public double getSquashedParallelEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            return this.predict(plan);
        } catch (Exception e) {
            return 0;
        }
    }

    /** Returns the interval and squashed estimates for the given plan as single-element lists. */
    @Override public Tuple<List<ProbabilisticDoubleInterval>, List<Double>> getParallelOperatorJunctionAllCostEstimate(PlanImplementation plan, Operator operator) {
        List<ProbabilisticDoubleInterval> intervalList = new ArrayList<>();
        List<Double> doubleList = new ArrayList<>();
        intervalList.add(this.getEstimate(plan, true));
        doubleList.add(this.getSquashedEstimate(plan, true));
        return new Tuple<>(intervalList, doubleList);
    }

    /**
     * Picks the plan with the smallest squashed cost estimate.
     *
     * @throws WayangException if no candidate plan exists
     */
    public PlanImplementation pickBestExecutionPlan(
            Collection<PlanImplementation> executionPlans,
            ExecutionPlan existingPlan,
            Set<Channel> openChannels,
            Set<ExecutionStage> executedStages) {
        final PlanImplementation bestPlanImplementation = executionPlans.stream()
            .reduce((p1, p2) -> {
                final double t1 = p1.getSquashedCostEstimate();
                final double t2 = p2.getSquashedCostEstimate();
                // Ties keep the earlier plan (p1), matching stream encounter order.
                return t1 < t2 ? p1 : p2;
            })
            .orElseThrow(() -> new WayangException("Could not find an execution plan."));
        return bestPlanImplementation;
    }
}
Third-party packages such as OnnxRuntime can be used to load pre-trained .onnx files that contain the desired ML models.
/**
 * Lazily-initialized singleton wrapper around an ONNX Runtime session that
 * serves cost predictions for encoded query plans.
 */
public class OrtMLModel {

    private static OrtMLModel INSTANCE;

    private OrtSession session;
    private OrtEnvironment env;
    private final Map<String, OnnxTensor> inputMap = new HashMap<>();
    private final Set<String> requestedOutputs = new HashSet<>();

    /**
     * Returns the shared model instance, loading it on first use.
     * Synchronized so that concurrent optimizer threads cannot race and
     * create two instances (and thus two native sessions).
     *
     * @param configuration provides the {@code wayang.ml.model.file} property
     * @throws OrtException if the model file cannot be loaded
     */
    public static synchronized OrtMLModel getInstance(Configuration configuration) throws OrtException {
        if (INSTANCE == null) {
            INSTANCE = new OrtMLModel(configuration);
        }
        return INSTANCE;
    }

    private OrtMLModel(Configuration configuration) throws OrtException {
        this.loadModel(configuration.getStringProperty("wayang.ml.model.file"));
    }

    /** Creates the ONNX environment and session once; later calls are no-ops. */
    public void loadModel(String filePath) throws OrtException {
        if (this.env == null) {
            this.env = OrtEnvironment.getEnvironment();
        }
        if (this.session == null) {
            this.session = env.createSession(filePath, new OrtSession.SessionOptions());
        }
    }

    /** Releases the native session and environment. The instance is unusable afterwards. */
    public void closeSession() throws OrtException {
        this.session.close();
        this.env.close();
    }

    /**
     * Runs the model on the encoded plan vector.
     *
     * @param encodedVector the one-hot encoding of a plan
     * @return the predicted cost, or {@code Double.NaN} if the output cannot be unwrapped
     * @throws OrtException if tensor creation or model execution fails
     */
    public double runModel(Vector<Long> encodedVector) throws OrtException {
        double costPrediction;
        this.requestedOutputs.add("output");
        // OnnxTensor wraps native memory; close it after the run to avoid leaking
        // (the previous code overwrote the map entry each call without ever closing).
        try (OnnxTensor tensor = OnnxTensor.createTensor(env, encodedVector)) {
            this.inputMap.put("input", tensor);
            try (Result r = session.run(inputMap, requestedOutputs)) {
                try {
                    costPrediction = ((double[]) r.get("output").get().getValue())[0];
                } catch (OrtException e) {
                    // Keep best-effort semantics: signal unwrap failure via NaN.
                    costPrediction = Double.NaN;
                }
            }
        }
        return costPrediction;
    }
}