Enterprise TensorFlow - Executing a TensorFlow Session in Java
A TensorFlow Session can be executed in Java in the same way as in Python. This post shows how.
We have managed to load a TensorFlow SavedModel in Java. Now it is time to get results out of the model. Luckily, the idiom for this is just the same as in low-level TensorFlow. All we need to do is identify the nodes that define the input and output of our computing graph, wrap data in Tensors and run a session.
As with the previous posts, you can find the complete working code on GitHub.
Running a TensorFlow session in Java requires the following steps:
- Wrap your input data in Tensor objects using the static helpers in the Tensors class
- Get a Session object
- Create a Runner object for the session
- Assign the input Tensors to the proper nodes in your graph with Runner.feed
- Define the output you want returned with Runner.fetch
- Execute the computation with Runner.run
- Unwrap the result Tensors using one of the Tensor's convenience methods or a copyTo call
- Make sure you close all Tensor objects
This might look quite daunting, but it is very simple in practice, thanks to a well-documented API, lots of helper functions and a nice fluent interface.
Wrapping data in org.tensorflow.Tensor objects
The Tensor class is the most important class when using the TensorFlow Java wrapper. It is used to wrap data for the TensorFlow engine and to unwrap the results that come back. The most complicated part of running our model in Java is wrapping and unwrapping our data correctly. Luckily, if we do something wrong, the resulting error messages are meaningful and verbose, so this is normally an easy job.
In almost all use cases, you can simply call one of the helper methods in the Tensors class to create a Tensor of the proper shape, data type and content. Here is an example of wrapping a single input float value:
final Tensor<Float> t = Tensors.create(f);
There are helper methods for all data types and up to six dimensions, so you should find everything you need there. As we will see later, you may still want to wrap all Tensor creation in functions of your own to make resource handling a bit easier: you must make sure to call .close() on every Tensor you create!
In very rare cases you may have to resort to the create calls on the Tensor class itself, which allow you to create a Tensor of any shape. For completeness' sake, here is an example of manually wrapping a float in a Tensor (do not do this unless you absolutely have to):
final Tensor<Float> t = Tensor.create(
    new long[] {1},                    // the shape
    FloatBuffer.wrap(new float[] {f})  // the data
);
Running a session and retrieving results
As in the Python low-level API, a model is executed in a session. To get a handle to a Session object, we just call the SavedModelBundle.session() method. The Session object is in turn used to get a Runner. The Runner provides a fluent API that binds Tensors to nodes in the graph with Runner.feed and defines which Tensors to return after the computation is complete with Runner.fetch. The fluent API works like a builder: each call returns the Runner again, so we can chain calls. When everything is wired up, we call Runner.run() to perform the computation and return the result. The result is a list of Tensors whose length equals the number of Runner.fetch calls; each call adds one list entry. This is a complete example chaining all calls into one long statement:
final Tensor<?> result =
    // gets the session
    bundle.session()
    // creates a runner
    .runner()
    // binds tensors to input nodes in the graph, in our case
    // `values` is an array of floats, toTensor creates a Tensor
    // object, the first argument is a string with the name of
    // the input node
    .feed("wine_type",            toTensor(values[1], tensorsToClose))
    .feed("fixed_acidity",        toTensor(values[2], tensorsToClose))
    .feed("volatile_acidity",     toTensor(values[3], tensorsToClose))
    .feed("citric_acid",          toTensor(values[4], tensorsToClose))
    .feed("residual_sugar",       toTensor(values[5], tensorsToClose))
    .feed("chlorides",            toTensor(values[6], tensorsToClose))
    .feed("free_sulfur_dioxide",  toTensor(values[7], tensorsToClose))
    .feed("total_sulfur_dioxide", toTensor(values[8], tensorsToClose))
    .feed("density",              toTensor(values[9], tensorsToClose))
    .feed("ph",                   toTensor(values[10], tensorsToClose))
    .feed("sulphates",            toTensor(values[11], tensorsToClose))
    .feed("alcohol",              toTensor(values[12], tensorsToClose))
    // define which output tensor to return
    // (you can chain multiple `fetch` calls to
    // return more than one tensor)
    .fetch("dnn/head/logits:0")
    // execute the runner - this returns a list
    .run()
    // We have only one fetch call, so we get a
    // one-element-list. The `get(0)` call fetches
    // the first element of the list
    .get(0);
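The twelve feed calls above all follow the same pattern, so the pairing of node names and array positions can also be built once and then looped over. The following is only a sketch under assumptions: FeatureRow and its named method are hypothetical names introduced for illustration, not part of the TensorFlow API, and the offset of 1 merely mirrors the fact that the first feature sits at values[1] in the example above.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical helper for illustration; not part of TensorFlow. */
final class FeatureRow {
    private FeatureRow() {}

    /**
     * Pairs each input node name with its value: names[i] is matched with
     * values[offset + i]. The returned map keeps the order of the names.
     */
    static Map<String, Float> named(final String[] names,
                                    final float[] values,
                                    final int offset) {
        if (offset < 0 || values.length - offset < names.length) {
            throw new IllegalArgumentException(
                "not enough values for " + names.length + " names at offset " + offset);
        }
        final Map<String, Float> row = new LinkedHashMap<>();
        for (int i = 0; i < names.length; i++) {
            row.put(names[i], values[offset + i]);
        }
        return row;
    }
}
```

Since every feed call returns the Runner, the map could then be fed in a loop: keep the Runner in a local variable, call runner.feed(entry.getKey(), toTensor(entry.getValue(), tensorsToClose)) for each entry, and finish with fetch and run as before.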
Unwrapping resulting Tensors
What is left now is to get the result out of the Tensor returned by the run() call. If the result Tensor is simply a scalar, you can just call Tensor.floatValue(), Tensor.booleanValue() etc. If the resulting tensor is not a scalar, the data needs to be retrieved with Tensor.copyTo(U destination), where destination is a multidimensional array. Prepackaged neural network regression estimators, for example, always return a two-dimensional tensor, even if you only have a single numerical result. In that case, you can retrieve the result like this:
float[][] resultValues = (float[][]) result.copyTo(new float[1][1]);
float prediction = resultValues[0][0];
The type and number of dimensions of the array depend on your model.
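If the dimensions are not known up front, the destination array can be allocated from a shape at runtime. The sketch below is an illustration under assumptions: ResultArrays and allocateFloats are hypothetical names, not part of the TensorFlow API, and the helper only covers float results. It uses plain Java reflection to build a nested float array from a long[] such as the one Tensor.shape() returns.

```java
import java.lang.reflect.Array;

/** Hypothetical helper for illustration; not part of TensorFlow. */
final class ResultArrays {
    private ResultArrays() {}

    /**
     * Allocates a nested float array whose dimensions match the given shape,
     * e.g. the shape {1, 1} yields a float[1][1].
     */
    static Object allocateFloats(final long[] shape) {
        if (shape.length == 0) {
            // A scalar has no array form; Tensor.floatValue() covers that case.
            throw new IllegalArgumentException("scalar shape has no array form");
        }
        final int[] dims = new int[shape.length];
        for (int i = 0; i < shape.length; i++) {
            if (shape[i] < 0 || shape[i] > Integer.MAX_VALUE) {
                throw new IllegalArgumentException("unsupported dimension: " + shape[i]);
            }
            dims[i] = (int) shape[i];
        }
        return Array.newInstance(float.class, dims);
    }
}
```

For the regression example above, the call would presumably read (float[][]) result.copyTo(ResultArrays.allocateFloats(result.shape())); the cast still has to name the concrete array type your model produces.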
Resource management
Two types of objects will need manual closing for proper resource handling: Sessions and Tensors. Note that all Tensor objects - whether created manually or returned from running a session - must be closed manually. I prefer to do this by performing all Tensor creation in helper functions that collect all created Tensors in a Collection and then free everything in a finally block after I am done:
private static Tensor<Float> toTensor(final float f,
        final Collection<Tensor<?>> tensorsToClose)
{
    final Tensor<Float> t = Tensors.create(f);
    if (tensorsToClose != null) {
        tensorsToClose.add(t);
    }
    return t;
}

private static void closeTensors(final Collection<Tensor<?>> ts) {
    for (final Tensor<?> t : ts) {
        try {
            t.close();
        } catch (final Exception e) {
            // TODO: decide on the error handling best fitting your use case here
            // In most cases logging is the only useful thing left to do
            System.err.println("Error closing Tensor.");
            e.printStackTrace();
        }
    }
    ts.clear();
}

private void runSession(final float foo /* , more params here */) {
    final List<Tensor<?>> tensorsToClose = new ArrayList<Tensor<?>>();
    try {
        // run session
        final List<Tensor<?>> result = bundle.session().runner()
            .feed("foo", toTensor(foo, tensorsToClose))
            // ... feed more tensors as necessary ...
            .fetch("some_node")
            // ... fetch more tensors as necessary ...
            .run();
        // mark result for cleanup
        tensorsToClose.addAll(result);
        // ... do something with the result ...
    } finally {
        closeTensors(tensorsToClose);
    }
}
Note the absence of Session closing: the session is created once for the SavedModelBundle, and the session() call returns an existing reference, not a new session. The Session is thread-safe, so it can be reused everywhere. It only needs closing when you are completely done, so you should close the Session only at the end of your program or when you shut down your server. You can simply do this by closing your SavedModelBundle, which frees all resources associated with the SavedModel. (You may even omit closing the Session, as the end of your JVM process should free all resources associated with it anyway. I never had any negative effects, but do this at your own risk!)
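The collect-and-close pattern from this section can also be packaged as a small scope object, so that try-with-resources does the cleanup instead of a hand-written finally block. This is a minimal sketch, not the approach used in the example project: CloseableScope and its track method are hypothetical names, not part of the TensorFlow API. It relies only on the fact that Tensor implements AutoCloseable, and unlike the logging helper above it rethrows the first close failure after closing everything else.

```java
import java.util.ArrayDeque;
import java.util.Deque;

/**
 * Hypothetical helper for illustration; not part of TensorFlow. Collects
 * AutoCloseable resources and closes all of them when the scope is closed.
 */
final class CloseableScope implements AutoCloseable {
    // The most recently tracked resource sits at the head of the deque.
    private final Deque<AutoCloseable> tracked = new ArrayDeque<>();

    /** Registers a resource for cleanup and returns it unchanged. */
    <T extends AutoCloseable> T track(final T resource) {
        tracked.push(resource);
        return resource;
    }

    /** Number of resources currently waiting to be closed. */
    int size() {
        return tracked.size();
    }

    /** Closes everything in reverse order of registration. */
    @Override
    public void close() {
        RuntimeException failure = null;
        while (!tracked.isEmpty()) {
            try {
                tracked.pop().close();
            } catch (final Exception e) {
                // Keep going so that one failure does not leak the rest.
                if (failure == null) {
                    failure = new RuntimeException("Error closing resource", e);
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }
}
```

Inside a try (CloseableScope scope = new CloseableScope()) block, a feed call would then read .feed("foo", scope.track(Tensors.create(foo))), and each Tensor in the list returned by run() would be passed through scope.track as well.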
Determining the proper names for input and output nodes
If you have written your own Estimator, you probably know the names of the input and output nodes you need, as well as their shapes. Sometimes, however, you might have used a prepackaged Estimator whose output node names you do not know, or you may not have written the model yourself and need to inspect the saved data to find out what to call. In that case, just as for the tag needed to load the MetaGraphs in the previous post, you need to inspect your SavedModel on the command line to determine your tag, your input and output node names, their shapes and data types. This can be done by successive calls to saved_model_cli like this (we use the SavedModel from our example project here; your output will obviously depend on the model you use):
chris$ ~/Library/Python/3.6/bin/saved_model_cli show \
--dir saved_models/1512127459/
The given SavedModel contains the following tag-sets:
serve
chris$ ~/Library/Python/3.6/bin/saved_model_cli show \
--dir saved_models/1512127459/ \
--tag_set serve
The given SavedModel MetaGraphDef contains
SignatureDefs with the following keys:
SignatureDef key: "predict"
chris$ ~/Library/Python/3.6/bin/saved_model_cli show \
--dir saved_models/1512127459/ \
--tag_set serve \
--signature_def predict
The given SavedModel SignatureDef contains the following input(s):
inputs['wine_type'] tensor_info:
dtype: DT_FLOAT
shape: (-1)
name: wine_type:0
...
The given SavedModel SignatureDef contains the
following output(s):
outputs['predictions'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: dnn/head/logits:0
Method name is: tensorflow/serving/predict
As you can see, we need successive calls to saved_model_cli show to "dig deeper" into our SavedModel to determine the names, shapes and data types of the input and output tensors. Regrettably, as far as I know, this information cannot be retrieved generically with the Java API. The type and shape of result Tensors, however, can also be examined by calls to Tensor.shape() and Tensor.dataType().
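One detail in the CLI output deserves a note: tensor names such as wine_type:0 consist of an operation name and an output index, separated by a colon. The feed calls earlier pass just "wine_type", while the fetch call passes "dnn/head/logits:0". As far as I can tell from the Java API documentation, Runner.feed and Runner.fetch accept both forms and treat a missing index as 0. The sketch below only illustrates that naming convention: TensorName is a hypothetical class, not part of the TensorFlow API, and it is a simplified reading of the format rather than the library's own parser.

```java
/** Hypothetical value class for illustration; not part of TensorFlow. */
final class TensorName {
    final String operation;
    final int outputIndex;

    private TensorName(final String operation, final int outputIndex) {
        this.operation = operation;
        this.outputIndex = outputIndex;
    }

    /**
     * Splits "operation:index" into its two parts. A name without a colon
     * is taken as the operation name with output index 0.
     */
    static TensorName parse(final String name) {
        final int colon = name.lastIndexOf(':');
        if (colon < 0) {
            return new TensorName(name, 0);
        }
        final String operation = name.substring(0, colon);
        // Throws NumberFormatException if the suffix is not a number.
        final int index = Integer.parseInt(name.substring(colon + 1));
        return new TensorName(operation, index);
    }
}
```

Under that reading, "wine_type" and "wine_type:0" refer to the same tensor, which is why the short form works in the feed calls.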
Summary
Running a TensorFlow session in Java is pretty easy. Just remember:
- SavedModelBundle and the corresponding Session are thread-safe, Tensors are not
- Use the saved_model_cli to determine the name and shape of your input and output nodes
- Wrap your input data with the helper methods in the Tensors class
- Use the fluent API on the SavedModelBundle to acquire and run a session: bundle.session().runner().feed(...).fetch(...).run()
- For scalar results: Use Tensor.floatValue() etc. to retrieve data from the resulting Tensors
- For non-scalar results: Use the proper array type and shape to retrieve data from your resulting Tensors using Tensor.copyTo
- Only call close on your SavedModelBundle when you are completely done and want to shut down your JVM, e.g. on server shutdown