Class ModelTrainingServices
- java.lang.Object
  - com.playground.playground.interface_adapter.modelling.ModelTrainingServices
public class ModelTrainingServices extends java.lang.Object
Facilitates training a model and logging its training statistics.
Constructor Summary
Constructor | Description
ModelTrainingServices(org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator data, org.deeplearning4j.nn.multilayer.MultiLayerNetwork model, java.lang.String statsFileName, org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator testData) | Constructor for the ModelTrainingServices class, which initializes the datasets and the model.
Method Summary
Modifier and Type | Method | Description
org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator | getData() | Returns the training dataset.
org.deeplearning4j.nn.multilayer.MultiLayerNetwork | getModel() | Returns the model.
java.lang.String | getStatsFileName() | Returns the name of the logging stats file.
org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator | getTestData() | Returns the testing dataset.
void | setModel(org.deeplearning4j.nn.multilayer.MultiLayerNetwork model) | Sets the model.
void | setStatsFileName(java.lang.String statsFileName) | Sets the stats file name.
java.lang.Object[] | trainModel(boolean verbose) | Trains the model set through the constructor, which is created using the NeuralNet class.
Constructor Detail
ModelTrainingServices
public ModelTrainingServices(org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator data, org.deeplearning4j.nn.multilayer.MultiLayerNetwork model, java.lang.String statsFileName, org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator testData)
Constructor for the ModelTrainingServices class, which initializes the datasets and the model.
Parameters:
data - The training dataset.
model - The model (network) to train.
statsFileName - The name of the logging file that will be saved to disk.
testData - The testing dataset.
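Taken together, the constructor and the accessors documented under Method Detail describe a holder for four values: two datasets, a replaceable model, and a replaceable stats file name. The following is a minimal, DL4J-free sketch of that contract, assuming the constructor only stores its arguments. The class name `TrainingServicesSketch` is hypothetical, and the type parameters `D` and `M` stand in for `INDArrayDataSetIterator` and `MultiLayerNetwork` so the sketch compiles without Deeplearning4j on the classpath.

```java
/**
 * Hypothetical stand-in for ModelTrainingServices.
 * D plays the role of INDArrayDataSetIterator, M the role of MultiLayerNetwork.
 */
class TrainingServicesSketch<D, M> {
    private final D data;          // training dataset; no setter is documented
    private final D testData;      // testing dataset; no setter is documented
    private M model;               // replaceable through setModel
    private String statsFileName;  // replaceable through setStatsFileName

    // Assumption: the real constructor simply stores its four arguments.
    TrainingServicesSketch(D data, M model, String statsFileName, D testData) {
        this.data = data;
        this.model = model;
        this.statsFileName = statsFileName;
        this.testData = testData;
    }

    public D getData() { return data; }
    public D getTestData() { return testData; }
    public M getModel() { return model; }
    public void setModel(M model) { this.model = model; }
    public String getStatsFileName() { return statsFileName; }
    public void setStatsFileName(String statsFileName) { this.statsFileName = statsFileName; }

    public static void main(String[] args) {
        // Strings stand in for the dataset iterators and the network.
        TrainingServicesSketch<String, String> s =
            new TrainingServicesSketch<>("train", "net-v1", "stats.log", "test");
        s.setModel("net-v2");
        System.out.println(s.getModel() + " " + s.getStatsFileName()); // net-v2 stats.log
    }
}
```

Only the model and the stats file name have setters, so the sketch marks the two datasets `final`; whether the real class does the same is not stated.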
Method Detail
trainModel
public java.lang.Object[] trainModel(boolean verbose)
Trains the model set through the constructor, which is created using the NeuralNet class.
Parameters:
verbose - Whether to print to the logger.
Returns:
An Object array containing the training score, the testing score, and an ArrayList of predictions.
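Because `trainModel` returns an untyped `java.lang.Object[]`, every caller has to cast its elements. A minimal sketch of a typed wrapper follows. It assumes index 0 holds the training score and index 1 the testing score (both numeric) and index 2 a `java.util.List` of predictions; the class `TrainingResult`, its `from` method, and these element types are assumptions, not part of the documented API.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** Hypothetical typed view of the Object[] returned by trainModel(verbose). */
final class TrainingResult {
    final double trainingScore;
    final double testingScore;
    final List<?> predictions;

    private TrainingResult(double trainingScore, double testingScore, List<?> predictions) {
        this.trainingScore = trainingScore;
        this.testingScore = testingScore;
        this.predictions = Collections.unmodifiableList(predictions);
    }

    /** Assumed layout: {trainingScore, testingScore, predictions}. */
    static TrainingResult from(Object[] raw) {
        if (raw == null || raw.length != 3) {
            throw new IllegalArgumentException(
                "expected 3 elements, got " + (raw == null ? "null" : raw.length));
        }
        // Casting to Number accepts Double, Float, etc.; a ClassCastException
        // here means the assumed layout does not match the real return value.
        double train = ((Number) raw[0]).doubleValue();
        double test = ((Number) raw[1]).doubleValue();
        List<?> preds = (List<?>) raw[2];
        return new TrainingResult(train, test, preds);
    }

    public static void main(String[] args) {
        // Stand-in for: Object[] raw = services.trainModel(false);
        Object[] raw = {0.42, 0.57, new ArrayList<>(List.of(1, 0, 1))};
        TrainingResult r = TrainingResult.from(raw);
        System.out.println(r.trainingScore + " " + r.testingScore + " " + r.predictions.size());
    }
}
```

Centralizing the casts in one factory means a change in the array layout fails loudly in a single place instead of at every call site.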
getData
public org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator getData()
Returns:
The training dataset.
getTestData
public org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator getTestData()
Returns:
The testing dataset.
getModel
public org.deeplearning4j.nn.multilayer.MultiLayerNetwork getModel()
Returns:
The model.
setModel
public void setModel(org.deeplearning4j.nn.multilayer.MultiLayerNetwork model)
Parameters:
model - The model to set.
getStatsFileName
public java.lang.String getStatsFileName()
Returns:
The name of the logging stats file.
setStatsFileName
public void setStatsFileName(java.lang.String statsFileName)
Parameters:
statsFileName - The stats file name to set.
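Since `setModel` and `setStatsFileName` let one ModelTrainingServices instance be reused across several training runs, a caller may want each run to log to its own file. A minimal sketch follows, assuming the class writes its statistics to exactly the name it is given and that the name is a bare file name without directory components. The helper `StatsFileNames.forRun` and its `_runNNN` naming scheme are hypothetical.

```java
/** Hypothetical helper that builds one stats file name per training run. */
final class StatsFileNames {
    private StatsFileNames() {}

    /**
     * Inserts a zero-padded run index before the extension, e.g.
     * forRun("stats.log", 3) gives "stats_run003.log". Names with no dot,
     * or whose only dot is the leading one, get the suffix appended.
     * Assumes a bare file name (a dot in a directory name would be misread).
     */
    static String forRun(String baseName, int run) {
        if (run < 0) {
            throw new IllegalArgumentException("run must be non-negative");
        }
        int dot = baseName.lastIndexOf('.');
        String suffix = String.format("_run%03d", run);
        return dot > 0
            ? baseName.substring(0, dot) + suffix + baseName.substring(dot)
            : baseName + suffix;
    }

    public static void main(String[] args) {
        for (int run = 0; run < 2; run++) {
            // In real code, before each run (hypothetical usage):
            //   services.setModel(nextModel);
            //   services.setStatsFileName(StatsFileNames.forRun("stats.log", run));
            System.out.println(StatsFileNames.forRun("stats.log", run));
        }
        // prints stats_run000.log, then stats_run001.log
    }
}
```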