[How-To] Using deep learning models within the Java ecosystem

After weeks of training and optimizing a neural net, at some point it might be ready for production. Most deep learning projects never reach this point, and for the rest it's time to think about frameworks and the technology stack. There are several frameworks that promise easy productionizing of deep neural nets. Some might be a good fit, others are not. This post is not about these paid solutions; it is about how to bring your deep learning model into production without third-party tools. It shows an example with Flask as a RESTful API and a direct integration in Java for batch predictions. This post focuses on deep neural networks; for traditional ML, take a look at this series.

Out-of-the-box API solutions

First, let's look at some tools that might be feasible for most projects. The straightforward solution for TensorFlow models is TensorFlow Serving. I wrote about how to use the SavedModel approach and TensorFlow Serving HERE.
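For reference, TensorFlow Serving exposes a REST endpoint of the form POST http://host:8501/v1/models/<model_name>:predict, which takes a JSON body with an "instances" list and answers with a "predictions" list. Below is a minimal client sketch using only the Python standard library; the host, the model name "mymodel" and the input values are placeholders, not values from this post.

```python
import json
import urllib.request


def build_predict_request(instances, model_name="mymodel",
                          host="localhost", port=8501):
    """Build a POST request for TensorFlow Serving's REST predict API."""
    url = "http://{}:{}/v1/models/{}:predict".format(host, port, model_name)
    body = json.dumps({"instances": instances}).encode("utf-8")
    return urllib.request.Request(
        url, data=body,
        headers={"Content-Type": "application/json"},
        method="POST")


def predict(instances, **kwargs):
    """Send the request and return the list of predictions."""
    request = build_predict_request(instances, **kwargs)
    with urllib.request.urlopen(request, timeout=30) as response:
        return json.loads(response.read().decode("utf-8"))["predictions"]
```

With a running TensorFlow Serving instance, predict([[1.0, 2.0, 3.0]]) returns one prediction per input row. Splitting request construction from sending keeps the payload format easy to check without a server.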

Even simpler is serving predictions from a trained model with Jupyter Kernel Gateway, which allows headless access to Jupyter notebooks. However, it is clearly not meant for production use.

Serving with Flask as a RESTful API

The straightforward solution is to use any Python web server and wrap the model in a REST API. Flask needs little explanation: it provides a simple HTTP service for on-demand predictions. Here is a very simple example web app, which can then be deployed on servers and cloud services like AWS. Here is a good tutorial.

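import os

# Use only the CPU. These variables have to be set
# before TensorFlow is imported to take effect reliably.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import numpy as np
from flask import Flask, jsonify, request
from tensorflow.keras.models import load_model

# Class names in the order of the model outputs (..., add the remaining classes)
LABELS = ['Class A', 'Class B']

model = None
app = Flask(__name__)


def init_if_necessary():
    # Load the model lazily on the first request
    global model
    if model is None:
        print('init')
        model = load_model('mymodel.model')


@app.route("/predict", methods=["POST"])
def predict():
    init_if_necessary()
    # Expects a JSON body like {"instances": [...]} with one entry per example
    payload = request.get_json(force=True)
    preds = model.predict(np.array(payload["instances"]))
    # NumPy arrays are not JSON serializable, so convert them to lists
    return jsonify({"predictions": preds.tolist(), "label": LABELS})


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)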

Serving models directly in Java

All these options involve third-party frameworks, HTTP, or some other kind of network communication. In applications where you predict millions of examples, it's better to run the predictions directly on the local machine and avoid the per-request network overhead.
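To put rough, purely illustrative numbers on this (assumed, not measured): with N = 1,000,000 examples and an assumed network overhead of t_request = 5 ms per HTTP call, one call per example adds

$$
T_{\text{overhead}} = N \cdot t_{\text{request}} = 10^{6} \cdot 5\,\text{ms} = 5000\,\text{s} \approx 83\,\text{min}
$$

Batching several examples per request divides this by the batch size, but predicting locally avoids the per-request cost altogether.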

DL4J

There is deeplearning4j, which integrates directly into the Java world. Unfortunately, I cannot really recommend it: it moves the ML code to Java, and honestly, ML in Java is no fun for data wrangling and too verbose; plus, you must rewrite your code. Furthermore, it does not support important layers such as GRU or TimeDistributed. For simple ANNs it's OK, but researchers do not use it, so there are no state-of-the-art models available.

TensorFlow Java API

There is also a Java API for TensorFlow, which can load SavedModels. Before looking at the Java API, let's think about what a deep learning framework like TensorFlow actually does. It is basically a library for parallel computing that can use GPUs through CUDA, but also SIMD instruction sets like SSE and AVX on CPUs. Python is just the API for accessing the C++ core; in the end, highly optimized native binaries do the work. The Java API has to ship all these binaries. It introduces a huge 145 MB dependency called tensorflow-jni (JNI, the Java Native Interface, is how Java calls native C/C++ libraries). We don't want a 145 MB binary package in our application, let alone a 350 MB package with GPU support! Besides that, the Java API is very limited, while Python is often already installed on servers, and adding TensorFlow with pip install tensorflow is easy.

It is not difficult to call Python from Java directly. This way, we can write everything related to our machine learning model in Python and simply call the script from Java, like an API. All machine learning code stays in Python, and we use the Java ecosystem for everything else.

Calling Python from Java

Example: A Java backend runs batch jobs, and we want to predict a lot of examples at once, in this case classifying text with a deep learning model. Calling an HTTP service is not an option. The only prerequisite for running Python scripts from Java is to install Python and all required libraries on the machine where the Java application runs.
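Because the Java side depends on Python and its libraries being installed, it can help to fail fast with a clear message. Here is a small hypothetical dependency check that the Java application (or the shell script) could run before the actual prediction. The module list is an assumption; adjust it to your project. importlib.util.find_spec returns None for a top-level module that cannot be found, without importing it.

```python
import importlib.util
import sys

# Hypothetical list of modules the prediction script needs
REQUIRED_MODULES = ["tensorflow", "numpy"]


def missing_modules(names):
    """Return the names of all top-level modules that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def main():
    """Print missing modules to stderr; return a non-zero code if any are missing."""
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing Python modules: " + ", ".join(missing), file=sys.stderr)
        return 1
    return 0
```

To use it as a standalone script, end the file with if __name__ == "__main__": sys.exit(main()). The non-zero exit code then shows up in Java via Process.exitValue().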

Calling a Python script from Java is straightforward. In a batch processing scenario, we export the preprocessed data to CSV, JSON, or Avro. Python reads this data to make the predictions, and we call the Python script either directly or via a bash script.

import com.opencsv.bean.CsvToBean;
import com.opencsv.bean.CsvToBeanBuilder;
import java.io.File;
import java.io.IOException;
import java.io.Reader;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.concurrent.TimeUnit;

public void run(Path predictionsFilePath) {
    log.info("Starting prediction");
    // Target file for the exported input data
    Path dataFile = Paths.get("data.json");
    predictionsFilePath.toFile().deleteOnExit();

    // Run the bash script, which calls the Python script, passing the model,
    // the data file and the file the predictions are written to.
    // Every flag and value is a separate argument; run.sh adds "predict" itself.
    ProcessBuilder pr = new ProcessBuilder("/mypath/run.sh",
            "--modelpath", "latest_model",
            "--datafile", dataFile.toAbsolutePath().toString(),
            "--targetfile", predictionsFilePath.toAbsolutePath().toString());
    pr.directory(new File("/mypath"));
    pr.redirectOutput(ProcessBuilder.Redirect.INHERIT);
    pr.redirectError(ProcessBuilder.Redirect.INHERIT);
    try {
        // Java does the text preprocessing and exports a JSON file with the data
        preProcessor.exportJson(dataFile);
        log.info("Executing Python based prediction");
        Process process = pr.start();
        // If the predictions take longer than 15 minutes, we assume something went wrong
        if (!process.waitFor(15, TimeUnit.MINUTES)) {
            process.destroyForcibly();
            throw new RuntimeException("Prediction timed out after 15 minutes");
        }
        if (process.exitValue() != 0) {
            throw new RuntimeException("Prediction failed with exit code " + process.exitValue());
        }
        processPredictions(predictionsFilePath);
    } catch (IOException e) {
        throw new RuntimeException("Execution failed", e);
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw new RuntimeException("Execution interrupted", e);
    }
}

public void processPredictions(Path predictionsFilePath) throws IOException {
    try (Reader reader = Files.newBufferedReader(predictionsFilePath)) {
        CsvToBean<MyType> csvToBean = new CsvToBeanBuilder<MyType>(reader)
                .withType(MyType.class)
                .withIgnoreLeadingWhiteSpace(true)
                .build();
        List<MyType> predictions = csvToBean.parse();
        for (MyType bean : predictions) {
            // Do something with the predictions
        }
        log.info("Successfully processed " + predictions.size() + " predictions");
    }
}

The Java code calls this shell script, where we can set up additional things (for example, the Python environment) and finally call the Python script, passing the parameters:

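#!/bin/bash
# Set up the environment here if needed (e.g. activate a virtualenv).
# "$@" forwards every argument from Java unchanged, one argument per value.
python main.py predict "$@"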
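The post does not show how main.py parses its arguments. Here is a hypothetical sketch with argparse: "predict" and "train" subcommands match the operation values checked in main.py, the flag names follow the Java snippet, and --modelpath is stored as opts.modelname so that it matches the attribute main.py reads. It assumes each flag and value arrives as a separate argument.

```python
import argparse


def parse_args(argv=None):
    """Parse the arguments that run.sh forwards from Java (hypothetical layout)."""
    parser = argparse.ArgumentParser(description="Train or run the model")
    subparsers = parser.add_subparsers(dest="operation", required=True)

    predict = subparsers.add_parser("predict", help="run batch predictions")
    # Stored as opts.modelname, the attribute main.py uses
    predict.add_argument("--modelpath", dest="modelname", required=True,
                         help="model base name, e.g. latest_model")
    predict.add_argument("--datafile", required=True,
                         help="input data exported by Java")
    predict.add_argument("--targetfile", required=True,
                         help="file the predictions are written to")

    subparsers.add_parser("train", help="train a new model")
    return parser.parse_args(argv)
```

main.py could then keep the result as self.opts = parse_args(). Using subcommands lets prediction and training share one entry point while each gets its own required flags.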

In our Python script, we can parameterize the prediction and training scenarios, so both share the same codebase. Here is a short example of the idea. We load an NLP model, which consists of a TensorFlow model plus a dictionary (latest_model.model, latest_model.dict).

# At the top of main.py:
import pickle
from tensorflow.keras.models import load_model

if self.opts.operation == 'predict':
    print('----------- Running Prediction ---------')
    print('Using given model: {}'.format(self.opts.modelname))
    # The NLP model consists of the Keras model and a pickled word dictionary
    model = load_model(self.opts.modelname + '.model',
                       custom_objects={})  # register custom layers here
    with open(self.opts.modelname + '.dict', 'rb') as fp:
        word_dict = pickle.load(fp)

    predictor = Predictor(self.opts.targetfile,
                          model, word_dict, self.opts.datafile)
    predictor.predict()
elif self.opts.operation == 'train':
    print('----------- Running Training ---------')
    trainer = Trainer()
    trainer.train(self.get_training_data())
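The post does not show the Predictor class either. Here is a hypothetical sketch of what it could look like, assuming the JSON data file is a list of objects with "id" and "text" fields, word_dict maps words to integer indices with 0 reserved for padding and unknown words, and the model returns one probability vector per example. The CSV header (id, label, probability) is also an assumption and has to match the fields of MyType on the Java side.

```python
import csv
import json


class Predictor:
    """Hypothetical sketch: reads the JSON data file exported by Java,
    encodes the texts with word_dict, predicts in batches and writes a CSV."""

    def __init__(self, targetfile, model, word_dict, datafile,
                 maxlen=100, batch_size=256, to_input=None):
        self.targetfile = targetfile
        self.model = model
        self.word_dict = word_dict
        self.datafile = datafile
        self.maxlen = maxlen
        self.batch_size = batch_size
        # For a Keras model pass numpy.asarray here; the identity default
        # keeps the sketch free of third-party dependencies.
        self.to_input = to_input or (lambda batch: batch)

    def encode(self, text):
        """Map words to indices (0 = unknown/padding), then pad or truncate.
        The tokenization must match the one used during training."""
        ids = [self.word_dict.get(word, 0) for word in text.lower().split()]
        ids = ids[:self.maxlen]
        return ids + [0] * (self.maxlen - len(ids))

    def predict(self):
        with open(self.datafile, encoding="utf-8") as fp:
            records = json.load(fp)

        with open(self.targetfile, "w", newline="", encoding="utf-8") as out:
            writer = csv.writer(out)
            writer.writerow(["id", "label", "probability"])
            # Predict in batches to keep memory usage bounded
            for start in range(0, len(records), self.batch_size):
                batch = records[start:start + self.batch_size]
                encoded = [self.encode(record["text"]) for record in batch]
                scores = self.model.predict(self.to_input(encoded))
                for record, probs in zip(batch, scores):
                    probs = list(probs)
                    # Index of the most likely class
                    best = max(range(len(probs)), key=probs.__getitem__)
                    writer.writerow([record["id"], best, probs[best]])
```

Writing the label index rather than a class name keeps the CSV small; the Java side can map indices to names if needed.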

 
