# Source code for supar.structs.chain

# -*- coding: utf-8 -*-

from __future__ import annotations

from typing import Optional

import torch
from supar.structs.dist import StructuredDistribution
from supar.structs.semiring import LogSemiring, Semiring
from torch.distributions.utils import lazy_property


class LinearChainCRF(StructuredDistribution):
    r"""
    Linear-chain CRFs :cite:`lafferty-etal-2001-crf`.

    Args:
        scores (~torch.Tensor): ``[batch_size, seq_len, n_tags]``.
            Log potentials.
        trans (~torch.Tensor): ``[n_tags+1, n_tags+1]``.
            Transition scores.
            ``trans[-1, :-1]``/``trans[:-1, -1]`` represent transitions for start/end positions respectively.
            If ``None``, a matrix filled with ``LogSemiring.one`` is used. Default: ``None``.
        lens (~torch.LongTensor): ``[batch_size]``.
            Sentence lengths for masking.
            If ``None``, every sequence is treated as having length ``seq_len``. Default: ``None``.

    Examples:
        >>> from supar import LinearChainCRF
        >>> batch_size, seq_len, n_tags = 2, 5, 4
        >>> lens = torch.tensor([3, 4])
        >>> value = torch.randint(n_tags, (batch_size, seq_len))
        >>> s1 = LinearChainCRF(torch.randn(batch_size, seq_len, n_tags),
        ...                     torch.randn(n_tags+1, n_tags+1),
        ...                     lens)
        >>> s2 = LinearChainCRF(torch.randn(batch_size, seq_len, n_tags),
        ...                     torch.randn(n_tags+1, n_tags+1),
        ...                     lens)
        >>> s1.max
        tensor([4.4120, 8.9672], grad_fn=<MaxBackward0>)
        >>> s1.argmax
        tensor([[2, 0, 3, 0, 0],
                [3, 3, 3, 2, 0]])
        >>> s1.log_partition
        tensor([ 6.3486, 10.9106], grad_fn=<LogsumexpBackward>)
        >>> s1.log_prob(value)
        tensor([ -8.1515, -10.5572], grad_fn=<SubBackward0>)
        >>> s1.entropy
        tensor([3.4150, 3.6549], grad_fn=<SelectBackward>)
        >>> s1.kl(s2)
        tensor([4.0333, 4.3807], grad_fn=<SelectBackward>)
    """

    def __init__(
        self,
        scores: torch.Tensor,
        trans: Optional[torch.Tensor] = None,
        lens: Optional[torch.LongTensor] = None
    ) -> None:
        # NOTE(review): `self.scores` (read below) is presumably set by the parent initializer - confirm.
        super().__init__(scores, lens=lens)

        batch_size, seq_len, self.n_tags = scores.shape[:3]
        # Without explicit lengths, every sequence spans the full `seq_len`.
        self.lens = scores.new_full((batch_size,), seq_len).long() if lens is None else lens
        # [batch_size, seq_len]; True at positions that fall inside each sequence.
        self.mask = self.lens.unsqueeze(-1).gt(self.lens.new_tensor(range(seq_len)))

        # The extra last row/column hold the start/end transitions;
        # the default fills the whole matrix with the log-semiring identity.
        self.trans = self.scores.new_full((self.n_tags+1, self.n_tags+1), LogSemiring.one) if trans is None else trans

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(n_tags={self.n_tags})"

    def __add__(self, other: LinearChainCRF) -> LinearChainCRF:
        r"""
        Pairs this distribution with ``other`` by stacking their scores and transitions along a new last dimension.

        The lengths of ``self`` are reused for the result.
        NOTE(review): presumably consumed by semirings that operate on two distributions at once
        (e.g. for ``kl``) - confirm against the base class.
        """
        return LinearChainCRF(torch.stack((self.scores, other.scores), -1),
                              torch.stack((self.trans, other.trans), -1),
                              self.lens)

    @lazy_property
    def argmax(self) -> torch.LongTensor:
        r"""
        Tag indices for each sequence as a ``[batch_size, seq_len]`` tensor, with zeros at padded positions.
        """
        # NOTE(review): `self.backward(...)` and `self.max` come from the parent class; the result is presumably
        # non-zero exactly at the selected (batch, position, tag) triples, so the third index from
        # `torch.where` gives the tag at every unmasked position - confirm.
        return self.lens.new_zeros(self.mask.shape).masked_scatter_(self.mask, torch.where(self.backward(self.max.sum()))[2])

    def topk(self, k: int) -> torch.LongTensor:
        r"""
        Args:
            k (int):
                The number of tag sequences to return per sentence.

        Returns:
            ~torch.LongTensor:
                ``[batch_size, seq_len, k]``, with zeros at padded positions.
        """
        # One backward pass per element of the batch-summed `kmax` result; the decoded tags are stacked on the last dim.
        # NOTE(review): `self.kmax` is defined on the parent class - its output layout is assumed to be k-first.
        preds = torch.stack([torch.where(self.backward(i))[2] for i in self.kmax(k).sum(0)], -1)
        return self.lens.new_zeros(*self.mask.shape, k).masked_scatter_(self.mask.unsqueeze(-1), preds)

    def score(self, value: torch.LongTensor) -> torch.Tensor:
        r"""
        Computes the unnormalized score of the given tag sequences.

        Args:
            value (~torch.LongTensor): ``[batch_size, seq_len]``.
                Tag sequences to be scored.

        Returns:
            ~torch.Tensor:
                ``[batch_size]``. For each sequence, the emission and transition scores accumulated
                over its unmasked positions, plus the transition from its last tag to the end position.
        """
        # Work in time-major layout.
        scores, mask, value = self.scores.transpose(0, 1), self.mask.t(), value.t()
        # The tag preceding the first position is -1, i.e. the start row of `trans`.
        prev, succ = torch.cat((torch.full_like(value[:1], -1), value[:-1]), 0), value
        # [seq_len, batch_size]
        alpha = scores.gather(-1, value.unsqueeze(-1)).squeeze(-1)
        # [batch_size]
        # NOTE(review): `one_mask` presumably resets padded positions to the semiring identity
        # so that they do not contribute to the product - confirm.
        alpha = LogSemiring.prod(LogSemiring.one_mask(LogSemiring.mul(alpha, self.trans[prev, succ]), ~mask), 0)
        # Add the transition from the last real tag (at position lens - 1) to the end column of `trans`.
        alpha = alpha + self.trans[value.gather(0, self.lens.unsqueeze(0) - 1).squeeze(0), torch.full_like(value[0], -1)]
        return alpha

    def forward(self, semiring: Semiring) -> torch.Tensor:
        r"""
        Runs the forward recursion over the chain under the given semiring.

        Args:
            semiring (Semiring):
                Supplies ``convert``/``unconvert`` and the ``mul``/``dot`` operations used by the recursion.

        Returns:
            ~torch.Tensor:
                The result of ``semiring.unconvert`` applied to the final accumulated values.
        """
        # [seq_len, batch_size, n_tags, ...]
        scores = semiring.convert(self.scores.transpose(0, 1))
        trans = semiring.convert(self.trans)
        mask = self.mask.t()

        # [batch_size, n_tags]
        # Start transitions combined with the scores of the first position.
        alpha = semiring.mul(trans[-1, :-1], scores[0])
        for i in range(1, len(mask)):
            # Only sequences still active at step i are advanced; finished ones keep their last alpha.
            alpha[mask[i]] = semiring.mul(semiring.dot(alpha.unsqueeze(2), trans[:-1, :-1], 1), scores[i])[mask[i]]
        # Fold in the end transitions while reducing over the tag dimension.
        alpha = semiring.dot(alpha, trans[:-1, -1], 1)
        return semiring.unconvert(alpha)