Today, I am going to walk through how I moved my Keras-trained model from my desktop onto my Android phone.
There are certain processes I want to run against my model while I am away from my computer that I can’t run from the web. They range from data collection to larger data processing tasks. For example, I don’t want my website to scrape all the spreads every time someone visits it.
How I Got Here
After I created my deep learning model to predict NCAA basketball scores in Google Colab, I decided I needed to deploy it in a few more places. First, I used TensorFlow.js to deploy it on the web. After that, I pulled the model down to my desktop and wrapped it in a larger .NET application. Today, I decided to create an Android application so the model runs on my local device instead of calling out to the web. It also lets me run a broader set of commands than I could on the web.
Being a GDE (Google Developer Expert) gives me access to certain groups and people at Google. One of those groups is the ML on Mobile group, which works on TensorFlow Lite. We have had a few meetings, and that was the final push I needed to carve out some time for this project.
TensorFlow Lite is fantastic. Since all of my work was done in TensorFlow 2 and Keras, I expected it to be a simple conversion followed by learning the Java API calls.
First, I had to convert my Keras model to a TF Lite model. This took only a few lines of Python:
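import tensorflow as tf
# Convert the model (restored_model is the trained Keras model loaded earlier)
converter = tf.lite.TFLiteConverter.from_keras_model(restored_model)
tflite_model = converter.convert()
# Save the TF Lite model.
with tf.io.gfile.GFile('model.tflite', 'wb') as f:
    f.write(tflite_model)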
Second, I added the model.tflite file to the ‘assets’ folder of my project in Android Studio.
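Third, I had to edit the app’s build.gradle file in two ways: tell the build not to compress .tflite files (an aaptOptions { noCompress "tflite" } block inside the android section), and add the TensorFlow Lite library as a dependency (org.tensorflow:tensorflow-lite).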
Next, I needed to write the code that interacts with the model.
Android Calling Code
The key class here is Interpreter, which loads your model and runs the predictions. I fought with it like crazy. First, I had to figure out how to turn the InputStream I got from the assets folder into a File object (I ended up writing it to local storage). Then, I had to work out what the input and output objects needed to be.
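Writing the asset out to local storage comes down to plain Java stream copying. The sketch below is a guess at what a helper like CreateFile() might do, not the actual implementation; ModelFileHelper and copyToFileIfMissing are hypothetical names. On Android, the InputStream would come from context.getAssets().open("model.tflite") and the target could be new File(context.getFilesDir(), "model.tflite"); those Android calls are left out so the sketch compiles as ordinary Java.

```java
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

// Hypothetical helper: a sketch, not the author's actual CreateFile() code.
class ModelFileHelper {

    // Copies every byte from 'in' into 'target' unless 'target' already exists.
    // On Android, 'in' would typically be context.getAssets().open("model.tflite").
    // The input stream is always closed before this method returns.
    static File copyToFileIfMissing(InputStream in, File target) throws IOException {
        try (InputStream src = in) {
            if (target.exists()) {
                return target; // Model was already written on a previous run.
            }
            try (OutputStream out = new FileOutputStream(target)) {
                byte[] buffer = new byte[8192];
                int read;
                // Copy in chunks until the end of the stream is reached.
                while ((read = src.read(buffer)) != -1) {
                    out.write(buffer, 0, read);
                }
            }
        }
        return target;
    }
}
```

The returned File is what gets handed to new Interpreter(file). Skipping the copy when the file already exists keeps later app launches cheap, but it also means an updated model in the assets folder won’t replace the old copy unless the file name changes or the old file is deleted.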
My model is simple: it takes in 6 decimal values and outputs a single number. So my input was a 2D float array holding one row of six values, and my output was a 2D float array holding a single value. Here is the code:
import java.io.File;
import org.tensorflow.lite.Interpreter;
File mdl = CreateFile(); //Creates the TF Lite model file if it doesn't exist
Interpreter intp = new Interpreter(mdl);
//Create the input: one row of 6 floats, i.e. float[1][6]. NOTE: I needed to do some normalization.
float[][] inputs = BuildInputArray(72.1, 63.8, -5.1, 65.1, 70.8, -2.3);
//Create the [1][1] output array that the interpreter fills in
float[][] out = new float[1][1];
//Run the prediction
intp.run(inputs, out);
//Close the model
intp.close();
//Get the results
float results = out[0][0];
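The snippet above relies on BuildInputArray, which isn’t shown, and the post doesn’t say which normalization was applied. Below is a minimal sketch assuming z-score normalization with per-feature means and standard deviations saved from training; InputBuilder, buildInputArray and newOutputBuffer are hypothetical names, and the signature differs from the original call because the statistics are passed in explicitly.

```java
// Hypothetical sketch: assumes z-score normalization, which may not match
// what the original BuildInputArray actually does.
class InputBuilder {

    // Builds the [1][n] input (a batch containing one row) for a model whose
    // input layer takes n features; n is 6 for the model in this post.
    static float[][] buildInputArray(double[] raw, double[] means, double[] stds) {
        if (raw.length != means.length || raw.length != stds.length) {
            throw new IllegalArgumentException("feature, mean and std arrays must have equal length");
        }
        float[][] input = new float[1][raw.length];
        for (int i = 0; i < raw.length; i++) {
            // (value - mean) / std; must match the transform used on the training data.
            input[0][i] = (float) ((raw[i] - means[i]) / stds[i]);
        }
        return input;
    }

    // Allocates the [1][1] output buffer that Interpreter.run() writes into.
    static float[][] newOutputBuffer() {
        return new float[1][1];
    }
}
```

For a model with a [1, 6] input tensor and a [1, 1] output tensor, these arrays can be passed directly to Interpreter.run(input, output), after which the prediction sits in output[0][0]. Whatever normalization is used here has to match what was applied to the training data, or the predictions will be off.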
Those are my simple steps for getting a trained model onto a mobile device. In the future, I will be adding more features so that I can do all the work from my phone instead of going to my desktop each day during the season to get my gambling picks.