torch.nn.CTCLoss and warpctc

it 2022-05-05

1. torch.nn.CTCLoss

```python
import torch
from torch.nn import CTCLoss

torch.backends.cudnn.benchmark = True

T = 50      # Input sequence length
C = 20      # Number of classes (including blank)
N = 16      # Batch size
S = 30      # Target sequence length of longest target in batch (padding length)
S_min = 10  # Minimum target length

torch.manual_seed(1234)

# CTCLoss expects log-probabilities of shape (T, N, C), hence log_softmax over the class dim
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
# Target class indices are drawn from [1, C); index 0 is the blank
target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)

criterion = CTCLoss()  # defaults: blank=0, reduction='mean'
loss = criterion(input, target, input_lengths, target_lengths)
loss.backward()
print(loss)
```

Output:

tensor(10.5236, grad_fn=<MeanBackward0>)
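The value above uses the default reduction='mean'. According to the PyTorch documentation, that mode divides each sample's loss by its target length and then averages over the batch. The minimal sketch below (not from the original post) recomputes this by hand from reduction='none', reusing the same setup as the code above.

```python
import torch
from torch.nn import CTCLoss

torch.manual_seed(1234)
T, C, N, S, S_min = 50, 20, 16, 30, 10
log_probs = torch.randn(T, N, C).log_softmax(2)
targets = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)

# Default reduction: per-sample loss / target length, then mean over the batch
mean_loss = CTCLoss()(log_probs, targets, input_lengths, target_lengths)

# reduction='none' returns one (unnormalized) loss per sample, shape (N,)
per_sample = CTCLoss(reduction="none")(log_probs, targets, input_lengths, target_lengths)
manual_mean = (per_sample / target_lengths).mean()

print(mean_loss.item(), manual_mean.item())
```

This matters when comparing against another CTC implementation: a loss that is summed over the batch, or not normalized by target length, will give a different number for the same inputs.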

2. warpctc_pytorch CTCLoss

Source code on GitHub: warpctc-pytorch

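```python
import torch
from warpctc_pytorch import CTCLoss

ctc_loss = CTCLoss()

# Expected shape: (seqLength, batch, alphabet_size) = (2, 1, 5).
# These are raw activations: warpctc applies softmax internally.
probs = torch.FloatTensor([[[0.1, 0.6, 0.1, 0.1, 0.1],
                            [0.1, 0.1, 0.6, 0.1, 0.1]]]).transpose(0, 1).contiguous()
labels = torch.IntTensor([1, 2])    # all targets of the batch concatenated into one 1-D tensor
label_sizes = torch.IntTensor([2])  # length of each target
probs_sizes = torch.IntTensor([2])  # length of each input sequence
probs.requires_grad_(True)  # tells autograd to compute gradients for probs

cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
cost.backward()
print(cost)
```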

Output:

tensor([2.4629], grad_fn=<_CTCBackward>)

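The arguments are:

- probs: Tensor of shape (seqLength x batch x outputDim) containing the output from the network. Despite the name, these are unnormalized activations; warpctc applies softmax to them internally.
- labels: 1-dimensional Tensor containing all the targets of the batch concatenated into one sequence.
- probs_lens: Tensor of size (batch) containing the length of each output sequence from the network.
- label_lens: Tensor of size (batch) containing the label length of each example.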
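To check that the softmax step accounts for the difference, the sketch below feeds torch.nn.functional.ctc_loss the log_softmax of the same activations used in the warpctc example. It assumes blank index 0 and uses reduction='sum' so that the default per-target-length division does not apply. With T = 2 and label [1, 2], the only valid alignment is "1, 2", and each of those symbols has softmax probability e^0.6 / (e^0.6 + 4·e^0.1) ≈ 0.2919, so the loss is about -2·ln(0.2919) ≈ 2.4629, the same as the warpctc output above.

```python
import torch
import torch.nn.functional as F

# Same raw activations as the warpctc example, shape (T=2, N=1, C=5)
acts = torch.tensor([[0.1, 0.6, 0.1, 0.1, 0.1],
                     [0.1, 0.1, 0.6, 0.1, 0.1]]).unsqueeze(1)

# torch.nn's CTC expects log-probabilities, so apply log_softmax explicitly
log_probs = acts.log_softmax(dim=2)

targets = torch.tensor([1, 2], dtype=torch.long)  # 1-D concatenated labels
input_lengths = torch.tensor([2], dtype=torch.long)
target_lengths = torch.tensor([2], dtype=torch.long)

loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                  blank=0, reduction="sum")
print(loss)  # ≈ 2.4629
```

With the default reduction='mean' the same call would return half of this value, because the single target has length 2.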

3. Summary

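In warpctc, labels should be a 1-D tensor of size sum(target_lens) holding all targets of the batch concatenated into one sequence. In torch.nn.CTCLoss, targets can be either a padded tensor of size (N, S) or the same 1-D concatenation of size sum(target_lens). The main difference is the input: torch.nn.CTCLoss expects log-probabilities (the output of log_softmax()), while warpctc takes raw activations and applies softmax itself, so log_softmax() should not be applied before calling warpctc. Note also that torch.nn.CTCLoss defaults to reduction='mean' (each loss divided by its target length, then averaged over the batch), so its value is not directly comparable to warpctc's cost without adjusting the reduction. Where possible, I would still try warpctc first.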
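A minimal sketch of the two target layouts, checked with torch.nn.CTCLoss only (warpctc is not needed to run it). The sizes, seed, and target lengths are arbitrary illustrative values. It builds the 1-D concatenated targets described for warpctc's labels argument from a padded (N, S) tensor and confirms that torch.nn.CTCLoss gives the same loss for both layouts. For warpctc the concatenated tensor would also need to be an IntTensor (e.g. concat.int()), as in the warpctc example.

```python
import torch
from torch.nn import CTCLoss

torch.manual_seed(0)
T, N, C, S = 20, 3, 6, 5
lengths = [5, 3, 4]  # illustrative target lengths, each <= S

log_probs = torch.randn(T, N, C).log_softmax(2)
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.tensor(lengths, dtype=torch.long)

# Padded (N, S) layout: entries past each target's length are ignored
padded = torch.randint(1, C, (N, S), dtype=torch.long)

# 1-D concatenated layout of size sum(target_lengths)
concat = torch.cat([padded[i, :lengths[i]] for i in range(N)])

ctc = CTCLoss(reduction="sum")
loss_padded = ctc(log_probs, padded, input_lengths, target_lengths)
loss_concat = ctc(log_probs, concat, input_lengths, target_lengths)
print(loss_padded.item(), loss_concat.item())
```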
