Pytorch – Skalartyp Float erwartet, aber Double gefunden

TLDR: Der Standard-Datentyp eines Numpy-Arrays ist double/float64. Wenn ein Tensor aus diesem Array mit torch.as_tensor() erstellt wird, nimmt er diesen Datentyp an. Der Standarddatentyp eines neuronalen Netzwerkmodells ist allerdings float32. Die Verwendung des float64 Tensors als Eingabe für das NN-Modell ist somit nicht kompatibel und führt zu der entsprechenden Fehlermeldung “Skalartyp Float erwartet, aber Double gefunden”.

Beobachtung

import torch
from torch.nn import Linear
import numpy as np
model = Linear(3,1)
input = np.array([[3.14,3,3],[1,2,3]])
t_input = torch.as_tensor(input)
model(t_input)

Nachdem ein Neuronales Netzmodell und ein entsprechender Tensor erstellt wurde, erscheint folgender Fehler, wenn der Tensor als Eingabewert verwendet wurde:

RuntimeError: expected scalar type Float but found Double

Auflösung

Einer der ersten Google-Einträge schlägt vor, das Modell in den angegebenen Datentyp zu konvertieren, und siehe da es funktioniert tatsächlich:

model.double()
model(t_input)

Aber was ist überhaupt das eigentliche Problem?

Bei näherer Untersuchung kann festgestellt werden, dass der aus dem Numpy-Array erstellte Tensor den dtype torch.float64 besitzt. Das liegt daran, dass die Funktion torch.as_tensor() den dtype aus der Datenquelle ableitet (pytorch Doku). Und der Standard dtype eines Numpy-Arrays ist nunmal float64 (Numpy Doku).

t_input.dtype
# torch.float64

input.dtype
# dtype('float64')

Weitere Anmerkungen

Treten wir doch einen Schritt zurück und betrachten das Problem von einer anderen Perspektive. Im torch-Modul ist der dtype standardmäßig float32 (pytorch Doku) und kann mit torch.set_default_dtype() angepasst werden. Dies beeinflusst die verwendeten dtypes beim Erstellen eines neuronalen Netzes oder Tensors. Unter Vorraussetzung, dass keine anderen Einstellungen getätigt wurden, ist ein Tensor, welcher aus einer Liste erstellt wurde, somit als Eingabewert problemlos mit einem neuronalem Netz kompatibel. Wenn jedoch ein Tensor aus einer vorhandenen Datenquelle wie einem Numpy-Array mit der Funktion torch.as_Tensor() erstellt wurde, wird der dtype aus dem Datentyp der Datenquelle abgeleitet.

Prinzipiell könnte daher der Datentyp des numpy Arrays auch auf float32 konvertiert werden, würde aber an Präzision verlieren.

Zu beachten ist auch, dass der Aufruf von torch.Tensor(input) zwar zu keinen Fehlermeldungen führt, allerdings diese Funktion die Daten implizit in float32 / den Standard-dtype umwandelt, wodurch wiederum potentiell Präzision verloren geht.

Weitere Referenzen

Nach weiterer ergiebiger Suche, habe ich auch diesen SO Beitrag gefunden, der meine Beobachtungen bestätigt.

Kommentar verfassen

Your email address will not be published. Required fields are marked *

hungsblog | Nguyen Hung Manh | Dresden
Nach oben scrollen