How to save a CNTK model to a file in C#


The end result of training is a model that can be used in production as a safe, reliable and accurate component. Model training is usually a frustrating and time-consuming process, quite unlike the short demos used to introduce a library. Once the model is built with the right combination of parameter values and network architecture, modelling turns into an interesting and enjoyable task, since the model calculates the values just as we expect.

Most of the time we have to save the current state of the model and continue training later, for various reasons:

  • to change the parameters of the learner,
  • to switch from one machine to another,
  • to share the state of the model with a teammate,
  • to switch from CPU to GPU,
  • etc.

In all of the above cases the current state of the model should be saved, so that training can continue from the last saved state. Training from the beginning is not a solution, since we have already invested time and knowledge in the progress made so far.

CNTK supports two ways of persisting a model:

  • saving it in a production-ready (evaluation-ready) state, and
  • saving a checkpoint of the model for later training.

In the first case, the model is prepared for evaluation and production but cannot be trained again, because it is stripped of all information except what evaluation needs. During the saving process, only one file is generated.

In the second case, besides the model file, another file is generated with the name “modelname.ckp”. This file contains all the information needed to continue training. Once the trainer checkpoint is persisted, we can continue training the model even if we change the following:

  • the training data set (as long as it has the same dimensions and data types),
  • the parameters of the learner,
  • the learner itself.

What we cannot change if we want to continue training is the network model. In other words, the model must keep the same number of layers and the same input and output dimensions.

Saving, loading and evaluating the model

Once the model is trained, it can be persisted as a separate file. As a separate file, it can be loaded and evaluated with a different dataset, but the number of features and labels must remain the same as when the model was trained. Use this method when you want to share the model with someone else, or when you want to deploy the model to production.

The model is saved simply by calling the CNTK method Save:

public void SaveTrainedModel(Function model, string fileName)
{
    model.Save(fileName);
}

The model evaluation requires several steps:

  • load the model from the file,
  • extract the features and label from the model,
  • call the model's Evaluate method, passing the batch created from the features, the label and the evaluation dataset.

The model is loaded by calling the Load method:

public Function LoadTrainedModel(string fileName, DeviceDescriptor device)
{
   return Function.Load(fileName, device, ModelFormat.CNTKv2);
}

Once the model is loaded, the features and label are extracted from the model in the following way:

//load the model from file
Function ffnn_model = Function.Load(modelFile, device);
//extract features and label from the model
Variable feature = ffnn_model.Arguments[0];
Variable label = ffnn_model.Output;

The next step is creating the minibatch in order to pass the data to the evaluation. In this case we are going to create only one row from the Iris example:

//Example: 5.0f, 3.5f, 1.3f, 0.3f, setosa
float[] xVal = new float[4] { 5.0f, 3.5f, 1.3f, 0.3f };
Value xValues = Value.CreateBatch<float>(new int[] {feature.Shape[0] }, xVal, device);
//Value yValues = - we don't need it, because we are going to calculate it
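
The same call can also take more than one row, because Value.CreateBatch expects the samples as one flat sequence, concatenated one after another. The sketch below only illustrates that flattening step; BatchHelper and Flatten are hypothetical names that are not part of CNTK or of the original sample.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper (not part of CNTK): prepares several rows for one batch.
public static class BatchHelper
{
    // Concatenates equally sized rows into one flat array, sample after sample.
    public static float[] Flatten(IEnumerable<float[]> rows, int featureCount)
    {
        var flat = new List<float>();
        foreach (var row in rows)
        {
            // Every row must have exactly as many values as the input variable expects.
            if (row.Length != featureCount)
                throw new ArgumentException(
                    $"Expected {featureCount} features per row, got {row.Length}.");
            flat.AddRange(row);
        }
        return flat.ToArray();
    }
}
```

The flattened array can then be passed to Value.CreateBatch<float> in place of xVal, with feature.Shape[0] as the featureCount argument. The number of samples in the batch is the length of the array divided by the number of features.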

Once we have created the variable and the values, we can map them, pass them to the model evaluation and calculate the result:

//map the variables and values
var inputDataMap = new Dictionary<Variable, Value>();
inputDataMap.Add(feature,xValues);
var outputDataMap = new Dictionary<Variable, Value>();
outputDataMap.Add(label, null);
//evaluate the model
ffnn_model.Evaluate(inputDataMap, outputDataMap, device);
//extract the result as a list of output vectors, one per sample
var outputData = outputDataMap[label].GetDenseData<float>(label);

The evaluation result should be transformed into a proper format and compared with the expected result:

//transforms into class value
var actualLabels = outputData.Select(l => l.IndexOf(l.Max())).ToList();
var flower = actualLabels.FirstOrDefault();
var strFlower = flower == 0 ? "setosa" : flower == 1 ? "versicolor" : "virginica";
Console.WriteLine($"Model Prediction: Input({xVal[0]},{xVal[1]},{xVal[2]},{xVal[3]}), Iris Flower={strFlower}");
Console.WriteLine($"Model Expectation: Input({xVal[0]},{xVal[1]},{xVal[2]},{xVal[3]}), Iris Flower= setosa");
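
The transformation above can also be wrapped in a small helper, so that the mapping from output vector to flower name lives in one place. This is only a sketch: IrisLabels, ArgMax and ToName are hypothetical names, and the class order (0 = setosa, 1 = versicolor, 2 = virginica) is an assumption about how the labels were encoded during training.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper (not part of CNTK): maps one output vector to a flower name.
public static class IrisLabels
{
    // Assumed class order: index 0 = setosa, 1 = versicolor, 2 = virginica.
    private static readonly string[] Names = { "setosa", "versicolor", "virginica" };

    // Returns the index of the largest value in one output vector.
    public static int ArgMax(IList<float> scores)
    {
        if (scores == null || scores.Count == 0)
            throw new ArgumentException("The output vector is empty.", nameof(scores));
        int best = 0;
        for (int i = 1; i < scores.Count; i++)
            if (scores[i] > scores[best]) best = i;
        return best;
    }

    // Translates the winning index into the assumed class name.
    public static string ToName(IList<float> scores)
    {
        int index = ArgMax(scores);
        if (index >= Names.Length)
            throw new ArgumentOutOfRangeException(nameof(scores), "More outputs than known classes.");
        return Names[index];
    }
}
```

With this helper, each element of outputData can be passed to IrisLabels.ToName to get the predicted flower name directly.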

Training a previously saved model

Training a previously saved model is very simple, since it requires no special coding. Right after the trainer is created with everything it needs (network, learning rate, momentum and so on), you just need to call

trainer.RestoreFromCheckpoint(strIrisFilePath);

No additional code is needed. The above method can only be called after you have successfully saved the model state by calling

trainer.SaveCheckpoint(strIrisFilePath);

The method is usually called at the end of the training process.
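
Before restoring, it can be useful to check that a checkpoint actually exists on disk. The sketch below assumes, based on the file naming described earlier, that the trainer state sits next to the model file under the same name plus the ".ckp" extension. CheckpointHelper and its methods are hypothetical names, not part of CNTK.

```csharp
using System.IO;

// Hypothetical helper (not part of CNTK): checks for a saved checkpoint on disk.
public static class CheckpointHelper
{
    // Assumption taken from the text above: the trainer state is stored in a
    // companion file named after the model file, with ".ckp" appended.
    public static string TrainerStatePath(string modelPath)
    {
        return modelPath + ".ckp";
    }

    // True only when both the model file and its companion file exist.
    public static bool HasCheckpoint(string modelPath)
    {
        return File.Exists(modelPath) && File.Exists(TrainerStatePath(modelPath));
    }
}
```

A training program could then call trainer.RestoreFromCheckpoint(strIrisFilePath) only when CheckpointHelper.HasCheckpoint(strIrisFilePath) returns true, and start training from scratch otherwise.
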
Complete source code from this blog post can be found here.

About Bahrudin Hrnjica

PhD in Mechanical Engineering, Microsoft MVP for Visual Studio and Development Technologies. Likes .NET, Math,Data Science, Evolutionary Algorithms, Machine Learning, Blogging.

Posted on 26/11/2017, in .NET, C#, CNTK, CodeProject.