Class ModelTrainingServices


  • public class ModelTrainingServices
    extends java.lang.Object
    Facilitates training a model and logging training statistics to a file.
    • Constructor Summary

      Constructors 
      Constructor Description
      ModelTrainingServices(org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator data, org.deeplearning4j.nn.multilayer.MultiLayerNetwork model, java.lang.String statsFileName, org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator testData)
      Constructor for the ModelTrainingServices class, which initializes the datasets and the model.
    • Method Summary

      Modifier and Type Method Description
      org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator getData()  
      org.deeplearning4j.nn.multilayer.MultiLayerNetwork getModel()  
      java.lang.String getStatsFileName()  
      org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator getTestData()  
      void setModel(org.deeplearning4j.nn.multilayer.MultiLayerNetwork model)  
      void setStatsFileName(java.lang.String statsFileName)  
      java.lang.Object[] trainModel(boolean verbose)
      Trains the model passed to the constructor (a network built with the NeuralNet class).
      • Methods inherited from class java.lang.Object

        clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
    • Constructor Detail

      • ModelTrainingServices

        public ModelTrainingServices(org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator data,
                                     org.deeplearning4j.nn.multilayer.MultiLayerNetwork model,
                                     java.lang.String statsFileName,
                                     org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator testData)
        Constructor for the ModelTrainingServices class, which initializes the datasets and the model.
        Parameters:
        data - The training dataset.
        model - The network (a MultiLayerNetwork) to train.
        statsFileName - The name of the log file that will be written to disk.
        testData - The testing dataset.
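The constructor takes two separate iterators, one for training and one for testing, so the two sets should not overlap. The following is a minimal, framework-free sketch of producing two disjoint sets from a single list of labelled rows. `TrainTestSplit` and its `of` method are hypothetical helpers, not part of this API or of DL4J; each resulting list would still have to be wrapped in an INDArrayDataSetIterator before being passed to the constructor (not shown).

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical helper: splits labelled rows into disjoint training and testing lists. */
final class TrainTestSplit<E> {
    final List<E> train;
    final List<E> test;

    private TrainTestSplit(List<E> train, List<E> test) {
        this.train = train;
        this.test = test;
    }

    /**
     * Shuffles a copy of rows with a fixed seed, then puts the last
     * round(size * testFraction) rows into test and the rest into train.
     */
    static <E> TrainTestSplit<E> of(List<E> rows, double testFraction, long seed) {
        if (testFraction <= 0.0 || testFraction >= 1.0) {
            throw new IllegalArgumentException("testFraction must be in (0, 1)");
        }
        List<E> shuffled = new ArrayList<>(rows);          // leave the caller's list untouched
        Collections.shuffle(shuffled, new Random(seed));   // reproducible order for a given seed
        int testSize = (int) Math.round(shuffled.size() * testFraction);
        int cut = shuffled.size() - testSize;
        return new TrainTestSplit<>(
                new ArrayList<>(shuffled.subList(0, cut)),
                new ArrayList<>(shuffled.subList(cut, shuffled.size())));
    }
}
```

Shuffling before the cut avoids putting, for example, all rows of one class into the test set when the input is sorted; the fixed seed keeps the split repeatable between runs.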
    • Method Detail

      • trainModel

        public java.lang.Object[] trainModel(boolean verbose)
        Trains the model passed to the constructor (a network built with the NeuralNet class).
        Parameters:
        verbose - Whether to write progress output to the logger.
        Returns:
        An Object[] containing the trainingScore, the testingScore and an ArrayList of predictions.
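Because trainModel returns an untyped Object[], callers have to cast each element. The sketch below wraps that array in a small typed view and fails fast if the shape is unexpected. `TrainingResult` is a hypothetical helper, and it assumes (from the order listed above, not from the implementation) that index 0 is the training score, index 1 the testing score, both boxed numbers, and index 2 a List of predictions.

```java
import java.util.List;

/**
 * Hypothetical typed view of the Object[] returned by trainModel(boolean).
 * Assumed layout: [0] training score, [1] testing score (both Numbers),
 * [2] a List of predictions.
 */
final class TrainingResult {
    final double trainingScore;
    final double testingScore;
    final List<?> predictions;

    private TrainingResult(double trainingScore, double testingScore, List<?> predictions) {
        this.trainingScore = trainingScore;
        this.testingScore = testingScore;
        this.predictions = predictions;
    }

    /** Validates the array's length and element types, then unboxes the scores. */
    static TrainingResult from(Object[] raw) {
        if (raw == null || raw.length != 3) {
            throw new IllegalArgumentException("expected 3 elements, got "
                    + (raw == null ? "null" : raw.length));
        }
        if (!(raw[0] instanceof Number) || !(raw[1] instanceof Number)) {
            throw new IllegalArgumentException("elements 0 and 1 must be numeric scores");
        }
        if (!(raw[2] instanceof List)) {
            throw new IllegalArgumentException("element 2 must be a List of predictions");
        }
        return new TrainingResult(
                ((Number) raw[0]).doubleValue(),
                ((Number) raw[1]).doubleValue(),
                (List<?>) raw[2]);
    }
}
```

With a ModelTrainingServices instance named services, usage would look like `TrainingResult result = TrainingResult.from(services.trainModel(true));`, after which result.testingScore can be read without further casts.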
      • getData

        public org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator getData()
        Returns:
        The training dataset.
      • getTestData

        public org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator getTestData()
        Returns:
        The testing dataset.
      • getModel

        public org.deeplearning4j.nn.multilayer.MultiLayerNetwork getModel()
        Returns:
        The model.
      • setModel

        public void setModel(org.deeplearning4j.nn.multilayer.MultiLayerNetwork model)
        Parameters:
        model - The model to set.
      • getStatsFileName

        public java.lang.String getStatsFileName()
        Returns:
        The name of the logging stats file.
      • setStatsFileName

        public void setStatsFileName(java.lang.String statsFileName)
        Parameters:
        statsFileName - The stats file name to set.