Full tensors in Torchscript

I was productizing some PyTorch models recently and ran into an issue with the way torch.full_like behaves in Torchscript.

import torch

The original code used torch.full_like to create a tensor with the same shape as an existing tensor, while forcing the new tensor's dtype to torch.float16. A minimal example that illustrates the issue is below.

def full(val: float, base: torch.Tensor):
    return torch.full_like(base, val, dtype=torch.float16)
full(42.0, torch.tensor([1, 2]))
tensor([42., 42.], dtype=torch.float16)
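A quick eager-mode check of what torch.full_like keeps: with an explicit dtype, only the shape of base (along with its device and layout) carries over, so the result matches a tensor built with torch.full from base.shape. This is a sketch on an illustrative 2-D input; the variable names base, out and same are just for this example.

```python
import torch

def full(val: float, base: torch.Tensor) -> torch.Tensor:
    # Same as the example above: copy base's shape, force float16.
    return torch.full_like(base, val, dtype=torch.float16)

# An illustrative 2-D int64 input; its dtype and values are ignored.
base = torch.arange(6).reshape(2, 3)
out = full(42.0, base)

print(out.shape)  # torch.Size([2, 3])
print(out.dtype)  # torch.float16

# Building the tensor from the shape alone gives the same values.
same = torch.full(base.shape, 42.0, dtype=torch.float16)
print(out.tolist() == same.tolist())  # True
```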

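Converting this directly to Torchscript fails because Torchscript can't match the call to either overload of aten::full_like: the plain overload has no dtype keyword, and the dtype overload (in the PyTorch version used here) also requires layout and device, neither of which has a default. Note that the snippet below misspells dtype as dytpe, which is why the error mentions an unknown keyword dytpe; spelling it correctly doesn't help, because layout and device are still missing.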

@torch.jit.script
def full1(val: float, base: torch.Tensor):
    return torch.full_like(base, val, dytpe=torch.float16)

full1(42.0, torch.tensor([1, 2]))

...

RuntimeError: 
Arguments for call are not valid.
The following variants are available:

  aten::full_like(Tensor self, Scalar fill_value, *, int? memory_format=None) -> (Tensor):
  Keyword argument dytpe unknown.

  aten::full_like.dtype(Tensor self, Scalar fill_value, *, int dtype, int layout, Device device, bool pin_memory=False, int? memory_format=None) -> (Tensor):
  Argument dtype not provided.

The original call is:
  File "<ipython-input-8-d96dada73816>", line 3
@torch.jit.script
def full1(val: float, base: torch.Tensor):
    return torch.full_like(base, val, dytpe=torch.float16)
           ~~~~~~~~~~~~~~~ <--- HERE
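When productizing, it can be worth catching this kind of failure in a unit test rather than at export time. Below is a minimal sketch that assumes a hypothetical helper, scripts_ok, which simply tries torch.jit.script and reports whether compilation succeeded; full_typo and full_explicit are illustrative copies of full1 and of the fully specified call. The traceback above shows a RuntimeError, but the sketch catches Exception to stay on the safe side. Because the full_like schema may differ between PyTorch versions, probing the installed version this way is more reliable than assuming.

```python
import torch

def scripts_ok(fn) -> bool:
    """Hypothetical helper: True if torch.jit.script can compile fn."""
    try:
        torch.jit.script(fn)  # compiles eagerly, so errors surface here
    except Exception:  # the traceback above shows a RuntimeError
        return False
    return True

def full_typo(val: float, base: torch.Tensor):
    # Same misspelled keyword as full1; no overload accepts "dytpe".
    return torch.full_like(base, val, dytpe=torch.float16)

def full_explicit(val: float, base: torch.Tensor):
    # Every argument of the dtype overload supplied explicitly.
    return torch.full_like(
        base, val, dtype=torch.float16, layout=base.layout, device=base.device
    )

print(scripts_ok(full_typo))      # False
print(scripts_ok(full_explicit))  # True
```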

Workarounds

Provide all the arguments

The obvious solution is to provide every argument the dtype overload requires. That works, but it is rather verbose.

@torch.jit.script
def full2(val: float, base: torch.Tensor):
    return torch.full_like(
        base,
        val,
        dtype=torch.float16,
        layout=base.layout,
        device=base.device,
    )

full2(42.0, torch.tensor([1, 2]))
tensor([42., 42.], dtype=torch.float16)

Switch to torch.full

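In this case, only the shape of base is needed, and the default device and layout are fine, so torch.full, which takes a shape rather than a tensor, is appropriate. Unlike full2, it creates the new tensor on the default device rather than on base's device.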

@torch.jit.script
def full3(val: float, base: torch.Tensor):
    return torch.full(base.shape, val, dtype=torch.float16)

full3(42.0, torch.tensor([1, 2]))
tensor([42., 42.], dtype=torch.float16)
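If base might live on a GPU, the device difference matters. A sketch of a variant that also passes device=base.device, so the two workarounds agree; full4 is a hypothetical name, and full2 is repeated so the block runs on its own.

```python
import torch

@torch.jit.script
def full4(val: float, base: torch.Tensor):
    # Hypothetical variant of full3: also copy base's device.
    return torch.full(base.shape, val, dtype=torch.float16, device=base.device)

@torch.jit.script
def full2(val: float, base: torch.Tensor):
    # The first workaround, repeated so this block is self-contained.
    return torch.full_like(
        base, val, dtype=torch.float16, layout=base.layout, device=base.device
    )

base = torch.tensor([1, 2])
print(full4(42.0, base))  # tensor([42., 42.], dtype=torch.float16)
print(full4(42.0, base).tolist() == full2(42.0, base).tolist())  # True
```

With CUDA available, full4(42.0, base.cuda()) should return a CUDA tensor, whereas full3 would return a CPU one under default settings.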
