TLDR: The default dtype of a NumPy array is float64 (double). If a tensor is created from such an array with
torch.as_tensor, it adopts that dtype, which does not match the default dtype of a PyTorch neural network model's parameters, float32. Using that tensor as input to the model results in the error
expected scalar type Float but found Double.
import torch
from torch.nn import Linear
import numpy as np

model = Linear(3, 1)
input = np.array([[3.14, 3, 3], [1, 2, 3]])
t_input = torch.as_tensor(input)
model(t_input)  # raises RuntimeError
After creating a neural network model and an appropriate tensor, the following error pops up when using the tensor as input:
RuntimeError: expected scalar type Float but found Double
One of the first Google results suggests converting the model to the tensor's dtype, and this does indeed work:
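A minimal sketch of that workaround, assuming it amounts to calling the standard nn.Module method double() so that the model's parameters become float64 and match the input tensor:

```python
import numpy as np
import torch
from torch.nn import Linear

model = Linear(3, 1)
input = np.array([[3.14, 3, 3], [1, 2, 3]])  # float64 by default
t_input = torch.as_tensor(input)              # keeps float64 -> torch.float64

# Convert all of the model's parameters to float64 so they match the input.
model = model.double()
output = model(t_input)
print(output.dtype)  # torch.float64
```

This makes the error go away, but it doubles the memory used by the parameters and leaves open why the dtypes differed in the first place.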
But what is the root problem?
Upon further inspection, you will notice that the tensor created from the NumPy array has the dtype torch.float64. This is because the function
torch.as_tensor() infers the dtype from the data source (PyTorch doc), and the default dtype of a NumPy array is indeed float64 (NumPy doc).
t_input.dtype  # torch.float64
input.dtype    # dtype('float64')
Let’s switch perspectives! In the torch module, the default dtype is float32 (PyTorch doc), and it can be changed with
torch.set_default_dtype(). This setting determines the dtype used when creating a neural network layer or a tensor from Python data. Thus a tensor created from a list of Python floats is a compatible input for a neural network created under the same default. However, if you create a tensor from existing data such as a NumPy array using the
torch.as_tensor() function, it infers the dtype from the source data instead.
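The following sketch contrasts the two cases, assuming a fresh Python session where the default dtype has not been changed yet. It also shows that torch.set_default_dtype() affects layers created afterwards:

```python
import numpy as np
import torch

# A tensor built from a Python list containing floats uses the default dtype...
from_list = torch.tensor([[3.14, 3, 3], [1, 2, 3]])
print(from_list.dtype)  # torch.float32

# ...while torch.as_tensor keeps the dtype of an existing NumPy array.
from_numpy = torch.as_tensor(np.array([[3.14, 3, 3], [1, 2, 3]]))
print(from_numpy.dtype)  # torch.float64

# Changing the default dtype affects layers created afterwards.
torch.set_default_dtype(torch.float64)
layer = torch.nn.Linear(3, 1)
print(layer.weight.dtype)  # torch.float64
torch.set_default_dtype(torch.float32)  # restore the usual default
```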
Practically speaking, you could change the dtype of your NumPy array to float32; however, you would lose precision this way.
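A short sketch of that option, using NumPy's astype() before the conversion. The last lines illustrate the precision loss, since 3.14 has no exact float32 representation:

```python
import numpy as np
import torch
from torch.nn import Linear

model = Linear(3, 1)  # float32 parameters by default
input = np.array([[3.14, 3, 3], [1, 2, 3]])  # float64

# Cast the array to float32 before handing it to torch.as_tensor.
input32 = input.astype(np.float32)
t_input32 = torch.as_tensor(input32)
output = model(t_input32)
print(output.dtype)  # torch.float32

# The float32 value is only an approximation of the original float64 one.
print(float(input[0, 0]), float(input32[0, 0]))
```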
torch.Tensor(input) also works without problems; however, that constructor implicitly converts the data to float32 (the default dtype), thus also losing precision.
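A quick check of that behavior, assuming the default dtype is still float32. Unlike torch.as_tensor, the legacy torch.Tensor constructor does not infer the dtype from the NumPy array:

```python
import numpy as np
import torch
from torch.nn import Linear

model = Linear(3, 1)
input = np.array([[3.14, 3, 3], [1, 2, 3]])  # float64

# torch.Tensor always builds a tensor of the default dtype (float32 here),
# silently downcasting the float64 data.
t_input = torch.Tensor(input)
print(t_input.dtype)  # torch.float32

output = model(t_input)  # works: input and parameters are both float32
```

If the downcast is intended, torch.as_tensor(input, dtype=torch.float32) states it explicitly instead of relying on the global default.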
Further googling also turned up this SO post, which confirms exactly these findings.