Class TensorflowModel

  • All Implemented Interfaces:
    java.io.Serializable, java.lang.AutoCloseable, Model

    public class TensorflowModel
    extends DLModel
    implements java.lang.AutoCloseable
    See Also:
    Serialized Form
    • Field Summary

      • Fields inherited from class org.apache.wayang.basic.model.DLModel

        out
    • Method Summary

      Modifier and Type Method Description
      void close()  
      Op getAccuracyCalculation()  
      Op getCriterion()  
      Optimizer getOptimizer()  
      <XT extends org.tensorflow.ndarray.NdArray<?>, PT extends org.tensorflow.ndarray.NdArray<?> & org.tensorflow.types.family.TType> PT predict(XT x)
      <XT extends org.tensorflow.ndarray.NdArray<?>, YT extends org.tensorflow.ndarray.NdArray<?>> void train(XT x, YT y, int epoch, int batchSize)
      • Methods inherited from class org.apache.wayang.basic.model.DLModel

        getOut
      • Methods inherited from class java.lang.Object

        clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
    • Constructor Detail

      • TensorflowModel

        public TensorflowModel(DLModel model,
                               Op criterion,
                               Optimizer optimizer,
                               Op accuracyCalculation)
    • Method Detail

      • train

        public <XT extends org.tensorflow.ndarray.NdArray<?>,
                YT extends org.tensorflow.ndarray.NdArray<?>>
        void train(XT x,
                   YT y,
                   int epoch,
                   int batchSize)
      • predict

        public <XT extends org.tensorflow.ndarray.NdArray<?>,
                PT extends org.tensorflow.ndarray.NdArray<?> & org.tensorflow.types.family.TType>
        PT predict(XT x)
      • getCriterion

        public Op getCriterion()
      • getOptimizer

        public Optimizer getOptimizer()
      • getAccuracyCalculation

        public Op getAccuracyCalculation()
      • close

        public void close()
        Specified by:
        close in interface java.lang.AutoCloseable
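
        Usage sketch:
        Because this class implements java.lang.AutoCloseable, an instance can be managed with a try-with-resources statement, which calls close() automatically when the block exits. The sketch below illustrates only that pattern. StandInModel, its use() method, and CloseSketch are hypothetical names introduced for illustration; they are not part of Wayang or TensorFlow, and nothing here reflects what TensorflowModel.close() actually releases.

```java
import java.util.ArrayList;
import java.util.List;

public class CloseSketch {

    /**
     * Hypothetical stand-in for a model that owns resources.
     * This is NOT org.apache.wayang's TensorflowModel; it only mirrors
     * the fact that the real class implements AutoCloseable.
     */
    static final class StandInModel implements AutoCloseable {
        private final List<String> log;
        private boolean closed = false;

        StandInModel(List<String> log) {
            this.log = log;
        }

        /** Placeholder for a train/predict-style call; it only records that it ran. */
        void use() {
            if (closed) {
                throw new IllegalStateException("model already closed");
            }
            log.add("use");
        }

        /** Records the close call; a real model would release its resources here. */
        @Override
        public void close() {
            closed = true;
            log.add("close");
        }
    }

    /** Uses the stand-in inside try-with-resources and returns the recorded call order. */
    static List<String> run() {
        List<String> log = new ArrayList<>();
        try (StandInModel model = new StandInModel(log)) {
            model.use();
        } // close() is invoked here, even if use() had thrown
        return log;
    }

    public static void main(String[] args) {
        System.out.println(run()); // prints [use, close]
    }
}
```

        With the real class, the same shape would wrap the constructor call shown under Constructor Detail, with train and predict invoked inside the try block.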