Enterprise TensorFlow - Eine TensorFlow Session in Java ausführen
Eine TensorFlow Session wird in Java genauso ausgeführt wie in Python. Dieser Post zeigt wie.
Im vorigen Post wurde gezeigt was nötig ist, um ein TensorFlow SavedModel
in Java zu laden. Jetzt ist es an der Zeit, Ergebnisse mit dem Modell zu erzeugen. Zum Glück ist das Vorgehen auf der JVM hier genau das gleiche wie bei Low-Level-TensorFlow in Python. Alles, was wir tun müssen, ist, die Knoten zu identifizieren, die die Ein- und Ausgabe unseres Computing-Graphen definieren, Daten in Tensoren zu verpacken und in einer Session die Berechnung durchzuführen.
Wie bei den vorherigen Beiträgen ist der komplette Code auf github zu finden.
Das Ausführen einer TensorFlow-Session in Java erfordert die folgenden Schritte:
- Eingabedaten in Tensorobjekte verpacken, mit statischen Helfern der
Tensors
-Klasse. Session
-Objekt holen- Erstellen eines
Runner
-Objekts für die Sitzung - Eingangs-Tensoren mit
Runner.feed
den richtigen Knoten im Graphen zuordnen - Definieren der Ausgabe, die mit
Runner.fetch
zurückgegeben werden soll. - Ausführen der Berechnung mit
Runner.run
- Entpacken der Ergebnis Tensoren mit einer der Utility-Methoden der
Tensors
-Klasse oder einemcopyTo
-Aufruf. - Nicht vergessen, alle Resourcen (Tensorobjekte) zu schließen.
Das sieht vielleicht ziemlich beängstigend aus, ist aber in der Praxis sehr einfach, dank einer gut dokumentierten API, vielen Hilfsfunktionen und einem netten fluent-Interface.
Daten in org.tensorflow.Tensor
-Objekte einpacken
Die Tensor
-Klasse ist die wichtigste Klasse beim Einsatz des TensorFlow Java Wrappers. Sie wird verwendet, um Daten zu verpacken und zu entpacken, um sie der TensorFlow-Engine zuzuführen und Ergebnisse zu erhalten. Der komplizierteste Teil der Ausführung unseres Modells in Java ist das korrekte Ein- und Auspacken unserer Daten. Zum Glück, sind im Fehlerfall die daraus resultierenden Fehlermeldungen sehr aussagekräftig und ausführlich, so dass dies normalerweise ein einfacher Job ist.
In 99% aller Anwendungsfälle reicht einfach eine der Hilfsmethoden in der Tensors
-Klasse, um einen Tensor
mit der richtigen Form, dem richtigen Datentyp und dem richtigen Inhalt zu erstellen. Hier ist ein Beispiel für das verpacken eines einzelnen Input-Fließkommawertes:
final Tensor<Float> t = Tensors.create(f);
Es gibt Hilfsmethoden für alle Datentypen und bis zu sechs Dimensionen, daher sollten dort alles zu finden sein, was man braucht. Wie wir später sehen werden, ist es dennoch sinnvoll, das erzeugen von Tensor
-Objekten in r-Erstellungen in eigene Funktionsaufrufe zu packen, um das Ressourcenhandling ein wenig zu vereinfachen - es muss sichergestellt werden .close()
auf allen erstellten Tensoren aufgerufen wird!
In sehr, sehr seltenen Fällen müssen Sie auf die Tensor
-Klasse selbst zurückgreifen - so können Sie einen Tensor
jeder beliebigen Form erstellen. Der Vollständigkeit halber sei hier ein Beispiel für das manuelle Verpacken eines float
in einen Tensor
genannt (Finger weg von dieser Methode, es sei denn, es ist unbedingt nötig):
final Tensor<Float> t = Tensor.create(
new long[] {1}, // the shape
FloatBuffer.wrap(new float[] {f}) // the data
);
Ausführen einer Sitzung und Output der Ergebnisse
Wie in der Python Low-Level-API wird ein Modell in einer Sitzung ausgeführt. Um ein Handle auf ein Session
-Objekt zu bekommen, rufen wir einfach die Methode SavedModelBundle.session()
auf. Das Session
-Objekt wird wiederum verwendet, um einen Runner
zu erhalten. Der Runner
bietet eine fluent API, die verwendet wird, um Tensor
en an Knoten im Graphen mit Runner.feed
zu binden und zu definieren, welche Tensoren nach der Berechnung mit Runner.fetch
zurückgegeben werden sollen. Die fluent API funktioniert wie ein Builder, jeder Aufruf gibt wieder den Runner
zurück, so dass wir Aufrufe verketten können. Wenn alles verkabelt ist, rufen wir Runner.run()
auf, um die Berechnung durchzuführen und das Ergebnis zurückzugeben. Das Ergebnis ist eine Liste von Tensor
en, die Anzahl der Elemente in der Liste hängt von der Anzahl der Runner.fetch
-Aufrufe ab, jeder Aufruf erzeugt einen zusätzlichen Listeneintrag. Dies ist ein vollständiges Beispiel, das alle Aufrufe zu einer langen Anweisung verkettet:
final Tensor<?> result =
// gets the session
bundle.session()
// creates a runner
.runner()
// binds tensors to input nodes in the graph, in our case
// `values` is an array of floats, toTensor creates a Tensor
// object, the first argument is a string with the name of
// the input node
.feed("wine_type" , toTensor(values[1], tensorsToClose))
.feed("fixed_acidity" , toTensor(values[2], tensorsToClose))
.feed("volatile_acidity" , toTensor(values[3], tensorsToClose))
.feed("citric_acid" , toTensor(values[4], tensorsToClose))
.feed("residual_sugar" , toTensor(values[5], tensorsToClose))
.feed("chlorides" , toTensor(values[6], tensorsToClose))
.feed("free_sulfur_dioxide" , toTensor(values[7], tensorsToClose))
.feed("total_sulfur_dioxide", toTensor(values[8], tensorsToClose))
.feed("density" , toTensor(values[9], tensorsToClose))
.feed("ph" , toTensor(values[10], tensorsToClose))
.feed("sulphates" , toTensor(values[11], tensorsToClose))
.feed("alcohol" , toTensor(values[12], tensorsToClose))
// define which output tensor to return
// (you can chain multiple `fetch` calls to
// return more then one tensor)
.fetch("dnn/head/logits:0")
// execute the runner - this returns a list
.run()
// We have only one fetch call, so we get a
// one-element-list. The `get(0)` call fetches
// the first element of the list
.get(0);
Auspacken der resultierenden Tensoren
Was jetzt noch bleibt, ist, das Ergebnis aus dem Tensor zu holen, der vom run()
-Aufruf zurückgegeben wird. Wenn der Ergebnis-Tensor nur ein Skalar ist, kann dafür einfach Tensor.floatValue()
, Tensor.booleanValue()
usw. aufgerufen werden. Wenn der resultierende Tensor kein Skalar ist, müssen die resultierenden Daten mit Tensor.copyTo(U destination)
abgerufen werden, wobei destination
ein multidimensionales Array ist. Eine vorgefertigte Regression mit neuronalen Netzen gibt beispielsweise immer einen zweidimensionalen Tensor zurück, auch wenn en nur ein einziges numerisches Ergebnis gibt. In diesem Fall kommt man wie folgt an das Ergebnis:
float[][] resultValues = (float[][]) result.copyTo(new float[1][1]);
float prediction = resultValues[0][0];
Der Type und die Anzahl der Dimensionen des Arrays hängen vom jeweiligen Modell ab.
Ressourcenverwaltung
Zwei Arten von Objekten müssen manuell geschlossen werden: Session
s und Tensor
en. Beachten Sie, dass alle Tensor
-Objekte - ob manuell erstellt oder aus einer laufenden Sitzung zurückgegeben - manuell geschlossen werden müssen. Ich ziehe es vor, dies zu tun, indem ich Tensor-Erzeuging immer in einer Hilfsfunktion ausführe, die alle erstellten Tensoren in einer Collection
sammelt um dann alles in einem finally
Block freizugeben, wenn ich fertig bin:
private static Tensor<Float> toTensor(final float f,
final Collection<Tensor<?>> tensorsToClose)
{
final Tensor<Float> t = Tensors.create(f);
if (tensorsToClose != null) {
tensorsToClose.add(t);
}
return t;
}
private static void closeTensors(final Collection<Tensor<?>> ts) {
for (final Tensor<?> t : ts) {
try {
t.close();
} catch (final Exception e) {
// TODO: decide on the error handling best fitting your use case here
// In most cases logging is the only useful thing left to do
System.err.println("Error closing Tensor.");
e.printStackTrace();
}
}
ts.clear();
}
private void runSession(final float foo, /* more params here */) {
final List<Tensor<?>> tensorsToClose = new ArrayList<Tensor<?>>();
try {
// run session
final List<Tensor<?>> result = bundle.session().runner()
.feed("foo", toTensor(foo, tensorsToClose))
// ... feed more tensors as necessary ...
.fetch("some_node")
// ... fetch more tensors as necessary ...
.run();
// mark result for cleanup
tensorsToClose.addAddAll(result);
// ... do something with the result ...
} finally {
closeTensors(tensorsToClose);
}
}
Hierbei muss man darauf achten, auf der Session
_nicht _close()
aufzurufen! Die Session
wird einmalig für das SavedModelBundle
erstellt, der session()
Aufruf gibt eine bestehende Referenz zurück, keine neue Session
. Die Session
ist thread-safe, so dass sie überall wiederverwendet werden kann. Sie muss erst geschlossen werden, wenn keine weiteren Aufrufe mehr folgen. Daher sollten die Session
nur am Ende des Programms oder beim Herunterfahren eines Servers beendet werden. Sie können dies einfach tun, indem Sie Ihr SavedModelBundle
schließen, das alle mit dem SavedModel
verbundenen Ressourcen freigibt. (Sie können sogar das Schließen der Sitzung auslassen, da das Ende Ihres JVM-Prozesses ohnehin alle damit verbundenen Ressourcen schließen sollte - ich hatte nie irgendwelche negativen Auswirkungen, aber tun Sie dies auf eigenes Risiko!)
Die richtigen Namen für Ein- und Ausgabeknoten bestimmen
Wenn Sie einen eigenen Estimator
geschrieben haben, wissen Sie wahrscheinlich wie Ihre Ein- und Ausgabeknoten heißen und welche Form sie haben. Manchmal haben Sie jedoch einen vorgefertigten Estimator
verwendet, bei dem Sie nicht wissen, wie die Ausgabeknoten aufgerufen werden, oder Sie haben das Modell nicht selbst geschrieben und müssen die gespeicherten Daten überprüfen, um zu wissen, was Sie aufrufen sollen. In diesem Fall müssen Sie Ihr SavedModel
auf der Kommandozeile überprüfen, um Ihr Tag, Ihre Ein- und Ausgabeknotennamen, deren Formen und Datentypen zu bestimmen. Dies kann durch aufeinanderfolgende Aufrufe des saved_model_cli
geschehen (wir verwenden hier das SavedModel
aus unserem Beispielprojekt, Ihre Ausgabe hängt natürlich vom verwendeten Modell ab):
chris$ ~/Library/Python/3.6/bin/saved_model_cli show \
--dir saved_models/1512127459/
The given SavedModel contains the following tag- sets:
serve
chris$ ~/Library/Python/3.6/bin/saved_model_cli show \
--dir saved_models/1512127459/ \
--tag_set serve
The given SavedModel MetaGraphDef contains
SignatureDefs with the following keys:
SignatureDef key: "predict"
chris$ ~/Library/Python/3.6/bin/saved_model_cli show\
--dir saved_models/1512127459/ \
--tag_set serve \
--signature_def predict
The given SavedModel SignatureDef contains the following input(s):
inputs['wine_type'] tensor_info:
dtype: DT_FLOAT
shape: (-1)
name: wine_type:0
...
The given SavedModel SignatureDef contains the
following output(s):
outputs['predictions'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: dnn/head/logits:0
Method name is: tensorflow/serving/predict
Wie Sie sehen können, benötigen wir aufeinanderfolgende Aufrufe von saved_model_cli show
, um tiefer in unserem SavedModel
vorzudringen, um die Namen, Formen und Datentypen der Ein- und Ausgangstensoren zu bestimmen. Leider können diese Informationen nicht generisch mit der Java-API (AFAIK) abgerufen werden. Die Art und Form des Ergebnis-Tensors zumindest kann aber auch durch Aufrufe von Tensor.shape()
und Tensor.dataType()
überprüft werden.
Zusammenfassung
Eine TensorFlow-Sitzung in Java auszuführen ist ziemlich einfach. Hier noch einmal eine Checkliste:
SavedModelBundle
und die zugehörigeSession
sind thread-safe,Tensor
en nicht- Verwenden Sie das
saved_model_cli
, um die Namen und die Form Ihrer Ein- und Ausgabeknoten zu bestimmen. - Verpacken Sie Ihre Eingabedaten mit den Hilfsmethoden in der Klasse
Tensors
- Verwenden Sie die fluent-API des
SavedModelBundleBundle
, um eineSession
zu erhalten und auszuführen:bundle.session().runner().feed(...).fetch(...).run()
- Für skalare Ergebnisse: Verwenden Sie
Tensor.floatValue()
usw., um Daten aus den resultierendenTensor
en zu holen. - Für nicht skalare Ergebnisse: Verwenden Sie den richtigen Array-Typ und die richtige Form, um Daten aus Ihren resultierenden Tensoren mit
Tensor.copyTo()
zu holen. - Schließen Sie das
SavedModelBundle
erst dann, wenn Sie Ihre JVM beenden möchten, z.B. beim Herunterfahren des Servers.