Class ModelTrainingServices
- java.lang.Object
  - com.playground.playground.interface_adapter.modelling.ModelTrainingServices
public class ModelTrainingServices
extends java.lang.Object

This class facilitates training a model and logging training statistics to a file.
Constructor Summary
Constructors
Constructor | Description
ModelTrainingServices(org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator data, org.deeplearning4j.nn.multilayer.MultiLayerNetwork model, java.lang.String statsFileName, org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator testData) | Constructor for the ModelTrainingServices class, which initializes the datasets and model.
Method Summary
Modifier and Type | Method | Description
org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator | getData() | Returns the training dataset.
org.deeplearning4j.nn.multilayer.MultiLayerNetwork | getModel() | Returns the model.
java.lang.String | getStatsFileName() | Returns the name of the logging stats file.
org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator | getTestData() | Returns the testing dataset.
void | setModel(org.deeplearning4j.nn.multilayer.MultiLayerNetwork model) | Sets the model.
void | setStatsFileName(java.lang.String statsFileName) | Sets the name of the logging stats file.
java.lang.Object[] | trainModel(boolean verbose) | Trains the model passed to the constructor, which was created using the NeuralNet class.
Constructor Detail
ModelTrainingServices
public ModelTrainingServices(org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator data,
                             org.deeplearning4j.nn.multilayer.MultiLayerNetwork model,
                             java.lang.String statsFileName,
                             org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator testData)
Constructor for the ModelTrainingServices class, which initializes the datasets and model.
- Parameters:
data - The training dataset.
model - The network model to train.
statsFileName - The name of the logging file that will be saved to disk.
testData - The testing dataset.
Method Detail
trainModel
public java.lang.Object[] trainModel(boolean verbose)
Trains the model passed to the constructor, which was created using the NeuralNet class.
- Parameters:
verbose - Whether to print output to the logger.
- Returns:
- An Object[] containing the training score, the testing score, and an ArrayList of predictions.
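Because trainModel returns a raw java.lang.Object[], callers have to cast each element themselves. The sketch below wraps that result in a typed record. It assumes, based only on the description above, that index 0 holds the training score, index 1 the testing score, and index 2 a list of predictions; the element types (boxed numbers for the scores, a java.util.List for the predictions) and the TrainingResult and TrainingResultDemo names are hypothetical, so check them against the actual implementation before relying on this.

```java
import java.util.ArrayList;
import java.util.List;

public class TrainingResultDemo {

    // Hypothetical typed view of the Object[] returned by trainModel(boolean).
    // ASSUMPTION: [0] = training score, [1] = testing score, [2] = list of predictions.
    public record TrainingResult(double trainingScore, double testingScore, List<?> predictions) {

        public static TrainingResult fromArray(Object[] raw) {
            if (raw == null || raw.length < 3) {
                throw new IllegalArgumentException("expected at least 3 elements");
            }
            // ASSUMPTION: both scores are boxed numbers (for example Double).
            double train = ((Number) raw[0]).doubleValue();
            double test = ((Number) raw[1]).doubleValue();
            // ASSUMPTION: predictions are stored in a java.util.List such as an ArrayList.
            List<?> predictions = (List<?>) raw[2];
            return new TrainingResult(train, test, predictions);
        }
    }

    public static void main(String[] args) {
        // Stand-in for: Object[] raw = services.trainModel(false);
        Object[] raw = {0.42, 0.57, new ArrayList<>(List.of(1, 0, 1))};
        TrainingResult result = TrainingResult.fromArray(raw);
        System.out.println(result.trainingScore() + " " + result.testingScore()
                + " " + result.predictions().size());
    }
}
```

Converting the array once at the call site keeps the unchecked casts in a single place, so a change in the element order only has to be fixed in fromArray.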
getData
public org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator getData()
- Returns:
- The training dataset.
getTestData
public org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator getTestData()
- Returns:
- The testing dataset.
getModel
public org.deeplearning4j.nn.multilayer.MultiLayerNetwork getModel()
- Returns:
- The model.
setModel
public void setModel(org.deeplearning4j.nn.multilayer.MultiLayerNetwork model)
- Parameters:
model - The model to set.
getStatsFileName
public java.lang.String getStatsFileName()
- Returns:
- The name of the logging stats file.
setStatsFileName
public void setStatsFileName(java.lang.String statsFileName)
- Parameters:
statsFileName - The stats file name to set.
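Taken together, the constructor and accessors describe a holder for four fields, of which only the model and the stats file name can be replaced after construction. The sketch below mirrors that shape so it compiles without DeepLearning4J: the hypothetical type parameters D and M stand in for INDArrayDataSetIterator and MultiLayerNetwork, and the class name TrainingServicesSketch is illustrative. Marking the two dataset fields final is a choice of this sketch (no setters are documented for them), not something the original page states.

```java
// Hypothetical, dependency-free sketch of ModelTrainingServices' state and accessors.
// D stands in for INDArrayDataSetIterator and M for MultiLayerNetwork.
public class TrainingServicesSketch<D, M> {
    private final D data;          // training dataset; no setter is documented
    private final D testData;      // testing dataset; no setter is documented
    private M model;               // replaceable through setModel
    private String statsFileName;  // replaceable through setStatsFileName

    // Same parameter order as the documented constructor.
    public TrainingServicesSketch(D data, M model, String statsFileName, D testData) {
        this.data = data;
        this.model = model;
        this.statsFileName = statsFileName;
        this.testData = testData;
    }

    public D getData() { return data; }
    public D getTestData() { return testData; }
    public M getModel() { return model; }
    public void setModel(M model) { this.model = model; }
    public String getStatsFileName() { return statsFileName; }
    public void setStatsFileName(String statsFileName) { this.statsFileName = statsFileName; }

    public static void main(String[] args) {
        // Strings and Integers stand in for the real DL4J iterator and network objects.
        TrainingServicesSketch<String, Integer> services =
                new TrainingServicesSketch<>("train-set", 1, "stats-run1.bin", "test-set");
        services.setModel(2);
        services.setStatsFileName("stats-run2.bin");
        System.out.println(services.getModel() + " " + services.getStatsFileName());
    }
}
```

With this shape, retraining a different network on the same data only requires calling setModel (and, if the logs should not overwrite each other, setStatsFileName) before training again.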