# Source code for supar.structs.tree

# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from supar.structs.dist import StructuredDistribution
from supar.structs.fn import mst
from supar.structs.semiring import LogSemiring
from supar.utils.fn import stripe
from torch.distributions.utils import lazy_property


class MatrixTree(StructuredDistribution):
    r"""
    MatrixTree for calculating partitions and marginals of non-projective dependency trees in :math:`O(n^3)`
    by an adaptation of Kirchhoff's MatrixTree Theorem :cite:`koo-etal-2007-structured`.

    Args:
        scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
            Scores of all possible dependent-head pairs.
        lens (~torch.LongTensor): ``[batch_size]``.
            Sentence lengths for masking, regardless of root positions. Default: ``None``.
        multiroot (bool):
            If ``False``, requires the tree to contain only a single root. Default: ``False``.

    Examples:
        >>> from supar import MatrixTree
        >>> batch_size, seq_len = 2, 5
        >>> lens = torch.tensor([3, 4])
        >>> arcs = torch.tensor([[0, 2, 0, 4, 2], [0, 3, 1, 0, 3]])
        >>> s1 = MatrixTree(torch.randn(batch_size, seq_len, seq_len), lens)
        >>> s2 = MatrixTree(torch.randn(batch_size, seq_len, seq_len), lens)
        >>> s1.max
        tensor([0.7174, 3.7910], grad_fn=<SumBackward1>)
        >>> s1.argmax
        tensor([[0, 0, 1, 1, 0],
                [0, 4, 1, 0, 3]])
        >>> s1.log_partition
        tensor([2.0229, 6.0558], grad_fn=<CopyBackwards>)
        >>> s1.log_prob(arcs)
        tensor([-3.2209, -2.5756], grad_fn=<SubBackward0>)
        >>> s1.entropy
        tensor([1.9711, 3.4497], grad_fn=<SubBackward0>)
        >>> s1.kl(s2)
        tensor([1.3354, 2.6914], grad_fn=<AddBackward0>)
    """

    def __init__(self, scores, lens=None, multiroot=False):
        super().__init__(scores)

        batch_size, seq_len = scores.shape[:2]
        # without explicit lengths, every sentence fills the padded length (position 0 is the root)
        self.lens = scores.new_full((batch_size,), seq_len-1).long() if lens is None else lens
        # True at dependent positions 1..lens; the root position 0 is excluded
        self.mask = (self.lens.unsqueeze(-1) + 1).gt(self.lens.new_tensor(range(seq_len)))
        self.mask = self.mask.index_fill(1, self.lens.new_tensor(0), 0)

        self.multiroot = multiroot

    def __repr__(self):
        return f"{self.__class__.__name__}(multiroot={self.multiroot})"

    def __add__(self, other):
        # NOTE(review): the sibling CRFs stack along the last dim (-1), whereas this stacks along dim 0,
        # so `scores.shape[:2]` in `__init__` would no longer be (batch_size, seq_len). `forward` does not
        # support a trailing semiring dim either — confirm whether combining MatrixTrees is supported.
        return MatrixTree(torch.stack((self.scores, other.scores)), self.lens, self.multiroot)

    @lazy_property
    def max(self):
        r"""``[batch_size]``. Summed arc scores of the tree returned by :attr:`argmax`."""
        arcs = self.argmax
        return LogSemiring.prod(LogSemiring.one_mask(self.scores.gather(-1, arcs.unsqueeze(-1)).squeeze(-1), ~self.mask), -1)

    @lazy_property
    def argmax(self):
        r"""``[batch_size, seq_len]``. Head of each position, decoded by ``mst`` honoring :attr:`multiroot`."""
        with torch.no_grad():
            return mst(self.scores, self.mask, self.multiroot)

    def kmax(self, k):
        # TODO: Camerini algorithm
        raise NotImplementedError

    def sample(self):
        raise NotImplementedError

    @lazy_property
    def entropy(self):
        # H(p) = log Z - E_p[score]; arc-factored, so the expectation is sum(marginals * scores)
        return self.log_partition - (self.marginals * self.scores).sum((-1, -2))

    def cross_entropy(self, other):
        # H(p, q) = log Z_q - E_p[score_q]
        return other.log_partition - (self.marginals * other.scores).sum((-1, -2))

    def kl(self, other):
        # KL(p || q) = log Z_q - log Z_p + E_p[score_p - score_q]
        return other.log_partition - self.log_partition + (self.marginals * (self.scores - other.scores)).sum((-1, -2))

    def score(self, value, partial=False):
        r"""
        Scores a batch of (possibly partially annotated) trees.

        Args:
            value (~torch.LongTensor): ``[batch_size, seq_len]``.
                Head index of each position.
            partial (bool):
                If ``True``, negative heads in ``value`` are treated as unannotated. The result is then the
                log partition over all trees consistent with the annotated arcs. Default: ``False``.

        Returns:
            ~torch.Tensor: ``[batch_size]``.
        """
        arcs = value
        if partial:
            mask, lens = self.mask, self.lens
            mask = mask.index_fill(1, self.lens.new_tensor(0), 1)
            mask = mask.unsqueeze(1) & mask.unsqueeze(2)
            # the root has no head; mark it (and any negative entry) as "any head allowed"
            arcs = arcs.index_fill(1, lens.new_tensor(0), -1).unsqueeze(-1)
            arcs = arcs.eq(lens.new_tensor(range(mask.shape[1]))) | arcs.lt(0)
            scores = LogSemiring.zero_mask(self.scores, ~(arcs & mask))
            # rebuild with the same lengths (the constructor expects `lens`, not the mask) and the same root
            # constraint, so the constrained partition is comparable with `self.log_partition`
            kwargs = {**self.kwargs, 'multiroot': self.multiroot}
            return self.__class__(scores, lens, **kwargs).log_partition
        return LogSemiring.prod(LogSemiring.one_mask(self.scores.gather(-1, arcs.unsqueeze(-1)).squeeze(-1), ~self.mask), -1)

    @torch.enable_grad()
    def forward(self, semiring):
        r"""
        Computes the log partition ``[batch_size]`` via the Matrix-Tree theorem.

        Only ``semiring.zero_mask`` is used from ``semiring``; the result is the log-determinant of the
        Laplacian minor, returned in float32.
        """
        s_arc = self.scores
        batch_size, *_ = s_arc.shape
        mask = self.mask.index_fill(1, self.lens.new_tensor(0), 1)
        s_arc = semiring.zero_mask(s_arc, ~(mask.unsqueeze(-1) & mask.unsqueeze(-2)))
        # A(i, j) = exp(s(i, j))
        # double precision to prevent overflows
        A = torch.exp(s_arc).double()
        # Weighted degree matrix
        # D(i, j) = sum_j(A(i, j)), if h == m
        #           0, otherwise
        D = torch.zeros_like(A)
        D.diagonal(0, 1, 2).copy_(A.sum(-1))
        # Laplacian matrix
        # L(i, j) = D(i, j) - A(i, j)
        L = D - A
        if not self.multiroot:
            # single-root adaptation: root attachments move out of the degrees into column 1
            L.diagonal(0, 1, 2).add_(-A[..., 0])
            L[..., 1] = A[..., 0]
        # padded rows become identity rows so they do not affect the determinant
        L = nn.init.eye_(torch.empty_like(A[0])).repeat(batch_size, 1, 1).masked_scatter_(mask.unsqueeze(-1), L[mask])
        # Z = L^(0, 0), the minor of L w.r.t row 0 and column 0
        return L[:, 1:, 1:].slogdet()[1].float()
class DependencyCRF(StructuredDistribution):
    r"""
    First-order TreeCRF for projective dependency trees :cite:`eisner-2000-bilexical,zhang-etal-2020-efficient`.

    Args:
        scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
            Scores of all possible dependent-head pairs.
        lens (~torch.LongTensor): ``[batch_size]``.
            Sentence lengths for masking, regardless of root positions. Default: ``None``.
        multiroot (bool):
            If ``False``, requires the tree to contain only a single root. Default: ``False``.

    Examples:
        >>> from supar import DependencyCRF
        >>> batch_size, seq_len = 2, 5
        >>> lens = torch.tensor([3, 4])
        >>> arcs = torch.tensor([[0, 2, 0, 4, 2], [0, 3, 1, 0, 3]])
        >>> s1 = DependencyCRF(torch.randn(batch_size, seq_len, seq_len), lens)
        >>> s2 = DependencyCRF(torch.randn(batch_size, seq_len, seq_len), lens)
        >>> s1.max
        tensor([3.6346, 1.7194], grad_fn=<IndexBackward>)
        >>> s1.argmax
        tensor([[0, 2, 3, 0, 0],
                [0, 0, 3, 1, 1]])
        >>> s1.log_partition
        tensor([4.1007, 3.3383], grad_fn=<IndexBackward>)
        >>> s1.log_prob(arcs)
        tensor([-1.3866, -5.5352], grad_fn=<SubBackward0>)
        >>> s1.entropy
        tensor([0.9979, 2.6056], grad_fn=<IndexBackward>)
        >>> s1.kl(s2)
        tensor([1.6631, 2.6558], grad_fn=<IndexBackward>)
    """

    def __init__(self, scores, lens=None, multiroot=False):
        super().__init__(scores)

        batch_size, seq_len = scores.shape[:2]
        # without explicit lengths, every sentence fills the padded length (position 0 is the root)
        self.lens = scores.new_full((batch_size,), seq_len-1).long() if lens is None else lens
        # True at dependent positions 1..lens; the root position 0 is excluded
        self.mask = (self.lens.unsqueeze(-1) + 1).gt(self.lens.new_tensor(range(seq_len)))
        self.mask = self.mask.index_fill(1, self.lens.new_tensor(0), 0)

        self.multiroot = multiroot

    def __repr__(self):
        return f"{self.__class__.__name__}(multiroot={self.multiroot})"

    def __add__(self, other):
        # stack along a trailing dim so both distributions are evaluated in one pass
        return DependencyCRF(torch.stack((self.scores, other.scores), -1), self.lens, self.multiroot)

    @lazy_property
    def argmax(self):
        r"""
        ``[batch_size, seq_len]``. Head of each position in the best tree.

        The heads are read from the nonzero entries of ``self.backward(self.max.sum())``. This assumes
        ``backward`` returns the gradient of the max score w.r.t. ``scores``, i.e. an arc indicator.
        """
        return self.lens.new_zeros(self.mask.shape).masked_scatter_(self.mask, torch.where(self.backward(self.max.sum()))[2])

    def topk(self, k):
        r"""``[batch_size, seq_len, k]``. Heads of the ``k`` best trees, recovered the same way as :attr:`argmax`."""
        preds = torch.stack([torch.where(self.backward(i))[2] for i in self.kmax(k).sum(0)], -1)
        return self.lens.new_zeros(*self.mask.shape, k).masked_scatter_(self.mask.unsqueeze(-1), preds)

    def score(self, value, partial=False):
        r"""
        Scores a batch of (possibly partially annotated) trees.

        Args:
            value (~torch.LongTensor): ``[batch_size, seq_len]``.
                Head index of each position.
            partial (bool):
                If ``True``, negative heads in ``value`` are treated as unannotated. The result is then the
                log partition over all trees consistent with the annotated arcs. Default: ``False``.

        Returns:
            ~torch.Tensor: ``[batch_size]``.
        """
        arcs = value
        if partial:
            mask, lens = self.mask, self.lens
            mask = mask.index_fill(1, self.lens.new_tensor(0), 1)
            mask = mask.unsqueeze(1) & mask.unsqueeze(2)
            # the root has no head; mark it (and any negative entry) as "any head allowed"
            arcs = arcs.index_fill(1, lens.new_tensor(0), -1).unsqueeze(-1)
            arcs = arcs.eq(lens.new_tensor(range(mask.shape[1]))) | arcs.lt(0)
            scores = LogSemiring.zero_mask(self.scores, ~(arcs & mask))
            # rebuild with the same lengths (the constructor expects `lens`, not the mask) and the same root
            # constraint, so the constrained partition is comparable with `self.log_partition`
            kwargs = {**self.kwargs, 'multiroot': self.multiroot}
            return self.__class__(scores, lens, **kwargs).log_partition
        return LogSemiring.prod(LogSemiring.one_mask(self.scores.gather(-1, arcs.unsqueeze(-1)).squeeze(-1), ~self.mask), -1)

    # enable_grad so that `argmax`/`topk`, which differentiate through this DP, also work under an outer
    # `torch.no_grad()` (consistent with the other structures in this module)
    @torch.enable_grad()
    def forward(self, semiring):
        r"""Runs Eisner's inside algorithm under ``semiring``; returns the root-to-end complete span ``[batch_size, ...]``."""
        s_arc = self.scores
        batch_size, seq_len = s_arc.shape[:2]
        # [seq_len, seq_len, batch_size, ...], (h->m)
        s_arc = semiring.convert(s_arc.movedim((1, 2), (1, 0)))
        s_i = semiring.zeros_like(s_arc)
        s_c = semiring.zeros_like(s_arc)
        semiring.one_(s_c.diagonal().movedim(-1, 1))

        for w in range(1, seq_len):
            n = seq_len - w
            # [n, batch_size, ...]
            il = ir = semiring.dot(stripe(s_c, n, w), stripe(s_c, n, w, (w, 1)), 1)
            # I(j->i) = <C(i->r), C(j->r+1)> * s(j->i), i <= r < j
            # fill the w-th diagonal of the lower triangular part of s_i with I(j->i) of n spans
            s_i.diagonal(-w).copy_(semiring.mul(il, s_arc.diagonal(-w).movedim(-1, 0)).movedim(0, -1))
            # I(i->j) = <C(i->r), C(j->r+1)> * s(i->j), i <= r < j
            # fill the w-th diagonal of the upper triangular part of s_i with I(i->j) of n spans
            s_i.diagonal(w).copy_(semiring.mul(ir, s_arc.diagonal(w).movedim(-1, 0)).movedim(0, -1))

            # [n, batch_size, ...]
            # C(j->i) = <C(r->i), I(j->r)>, i <= r < j
            cl = semiring.dot(stripe(s_c, n, w, (0, 0), 0), stripe(s_i, n, w, (w, 0)), 1)
            s_c.diagonal(-w).copy_(cl.movedim(0, -1))
            # C(i->j) = <I(i->r), C(r->j)>, i < r <= j
            cr = semiring.dot(stripe(s_i, n, w, (0, 1)), stripe(s_c, n, w, (1, w), 0), 1)
            s_c.diagonal(w).copy_(cr.movedim(0, -1))
            if not self.multiroot:
                # single root: a complete span from the root may only end at the sentence end
                s_c[0, w][self.lens.ne(w)] = semiring.zero
        return semiring.unconvert(s_c)[0][self.lens, range(batch_size)]
class Dependency2oCRF(StructuredDistribution):
    r"""
    Second-order TreeCRF for projective dependency trees :cite:`mcdonald-pereira-2006-online,zhang-etal-2020-efficient`.

    Args:
        scores (tuple(~torch.Tensor, ~torch.Tensor)):
            A pair ``(s_arc, s_sib)``. ``s_arc`` (``[batch_size, seq_len, seq_len]``) holds the scores of all
            possible dependent-head pairs. ``s_sib`` (``[batch_size, seq_len, seq_len, seq_len]``) holds the
            second-order scores of head-dependent-sibling triples.
        lens (~torch.LongTensor): ``[batch_size]``.
            Sentence lengths for masking, regardless of root positions. Default: ``None``.
        multiroot (bool):
            If ``False``, requires the tree to contain only a single root. Default: ``False``.

    Examples:
        >>> from supar import Dependency2oCRF
        >>> batch_size, seq_len = 2, 5
        >>> lens = torch.tensor([3, 4])
        >>> arcs = torch.tensor([[0, 2, 0, 4, 2], [0, 3, 1, 0, 3]])
        >>> sibs = torch.tensor([CoNLL.get_sibs(i) for i in arcs[:, 1:].tolist()])
        >>> s1 = Dependency2oCRF((torch.randn(batch_size, seq_len, seq_len),
                                  torch.randn(batch_size, seq_len, seq_len, seq_len)),
                                 lens)
        >>> s2 = Dependency2oCRF((torch.randn(batch_size, seq_len, seq_len),
                                  torch.randn(batch_size, seq_len, seq_len, seq_len)),
                                 lens)
        >>> s1.max
        tensor([0.7574, 3.3634], grad_fn=<IndexBackward>)
        >>> s1.argmax
        tensor([[0, 3, 3, 0, 0],
                [0, 4, 4, 4, 0]])
        >>> s1.log_partition
        tensor([1.9906, 4.3599], grad_fn=<IndexBackward>)
        >>> s1.log_prob((arcs, sibs))
        tensor([-0.6975, -6.2845], grad_fn=<SubBackward0>)
        >>> s1.entropy
        tensor([1.6436, 2.1717], grad_fn=<IndexBackward>)
        >>> s1.kl(s2)
        tensor([0.4929, 2.0759], grad_fn=<IndexBackward>)
    """

    def __init__(self, scores, lens=None, multiroot=False):
        super().__init__(scores)

        batch_size, seq_len = scores[0].shape[:2]
        # without explicit lengths, every sentence fills the padded length (position 0 is the root)
        self.lens = scores[0].new_full((batch_size,), seq_len-1).long() if lens is None else lens
        # True at dependent positions 1..lens; the root position 0 is excluded
        self.mask = (self.lens.unsqueeze(-1) + 1).gt(self.lens.new_tensor(range(seq_len)))
        self.mask = self.mask.index_fill(1, self.lens.new_tensor(0), 0)

        self.multiroot = multiroot

    def __repr__(self):
        return f"{self.__class__.__name__}(multiroot={self.multiroot})"

    def __add__(self, other):
        # stack arc and sibling scores pairwise along a trailing dim
        return Dependency2oCRF([torch.stack((i, j), -1) for i, j in zip(self.scores, other.scores)], self.lens, self.multiroot)

    @lazy_property
    def argmax(self):
        r"""
        ``[batch_size, seq_len]``. Head of each position in the best tree.

        The heads are read from the nonzero entries of ``self.backward(self.max.sum())`` (assumed to be the
        gradient of the max score w.r.t. the arc scores).
        """
        return self.lens.new_zeros(self.mask.shape).masked_scatter_(self.mask, torch.where(self.backward(self.max.sum()))[2])

    def topk(self, k):
        r"""``[batch_size, seq_len, k]``. Heads of the ``k`` best trees, recovered the same way as :attr:`argmax`."""
        preds = torch.stack([torch.where(self.backward(i))[2] for i in self.kmax(k).sum(0)], -1)
        return self.lens.new_zeros(*self.mask.shape, k).masked_scatter_(self.mask.unsqueeze(-1), preds)

    def score(self, value, partial=False):
        r"""
        Scores a batch of (possibly partially annotated) trees.

        Args:
            value (tuple(~torch.LongTensor, ~torch.LongTensor)):
                ``(arcs, sibs)``. ``arcs`` (``[batch_size, seq_len]``) holds the head of each position.
                ``sibs`` indexes the last dim of ``s_sib``; entries ``<= 0`` are ignored.
            partial (bool):
                If ``True``, negative heads in ``arcs`` are treated as unannotated and ``sibs`` is unused. The
                result is then the log partition over all trees consistent with the annotated arcs.
                Default: ``False``.

        Returns:
            ~torch.Tensor: ``[batch_size]``.
        """
        arcs, sibs = value
        if partial:
            mask, lens = self.mask, self.lens
            mask = mask.index_fill(1, self.lens.new_tensor(0), 1)
            mask = mask.unsqueeze(1) & mask.unsqueeze(2)
            # the root has no head; mark it (and any negative entry) as "any head allowed"
            arcs = arcs.index_fill(1, lens.new_tensor(0), -1).unsqueeze(-1)
            arcs = arcs.eq(lens.new_tensor(range(mask.shape[1]))) | arcs.lt(0)
            s_arc, s_sib = LogSemiring.zero_mask(self.scores[0], ~(arcs & mask)), self.scores[1]
            # rebuild with the same lengths (the constructor expects `lens`, not the mask) and the same root
            # constraint, so the constrained partition is comparable with `self.log_partition`
            kwargs = {**self.kwargs, 'multiroot': self.multiroot}
            return self.__class__((s_arc, s_sib), lens, **kwargs).log_partition
        s_arc = self.scores[0].gather(-1, arcs.unsqueeze(-1)).squeeze(-1)
        s_arc = LogSemiring.prod(LogSemiring.one_mask(s_arc, ~self.mask), -1)
        s_sib = self.scores[1].gather(-1, sibs.unsqueeze(-1)).squeeze(-1)
        s_sib = LogSemiring.prod(LogSemiring.one_mask(s_sib, ~sibs.gt(0)), (-1, -2))
        return LogSemiring.mul(s_arc, s_sib)

    @torch.enable_grad()
    def forward(self, semiring):
        r"""Runs the second-order inside algorithm under ``semiring``; returns the root-to-end complete span ``[batch_size, ...]``."""
        s_arc, s_sib = self.scores
        batch_size, seq_len = s_arc.shape[:2]
        # [seq_len, seq_len, batch_size, ...], (h->m)
        s_arc = semiring.convert(s_arc.movedim((1, 2), (1, 0)))
        # [seq_len, seq_len, seq_len, batch_size, ...], (h->m->s)
        s_sib = semiring.convert(s_sib.movedim((0, 2), (3, 0)))
        s_i = semiring.zeros_like(s_arc)
        s_s = semiring.zeros_like(s_arc)
        s_c = semiring.zeros_like(s_arc)
        semiring.one_(s_c.diagonal().movedim(-1, 1))

        for w in range(1, seq_len):
            n = seq_len - w
            # I(j->i) = <I(j->r), S(j->r, i)> * s(j->i), i < r < j
            #           <C(j->j), C(i->j-1)> * s(j->i), otherwise
            # [n, w, batch_size, ...]
            il = semiring.times(stripe(s_i, n, w, (w, 1)),
                                stripe(s_s, n, w, (1, 0), 0),
                                stripe(s_sib[range(w, n+w), range(n), :], n, w, (0, 1)))
            il[:, -1] = semiring.mul(stripe(s_c, n, 1, (w, w)), stripe(s_c, n, 1, (0, w - 1))).squeeze(1)
            il = semiring.sum(il, 1)
            s_i.diagonal(-w).copy_(semiring.mul(il, s_arc.diagonal(-w).movedim(-1, 0)).movedim(0, -1))
            # I(i->j) = <I(i->r), S(i->r, j)> * s(i->j), i < r < j
            #           <C(i->i), C(j->i+1)> * s(i->j), otherwise
            # [n, w, batch_size, ...]
            ir = semiring.times(stripe(s_i, n, w),
                                stripe(s_s, n, w, (0, w), 0),
                                stripe(s_sib[range(n), range(w, n+w), :], n, w))
            if not self.multiroot:
                # single root: the root may not take a further (sibling-linked) dependent
                semiring.zero_(ir[0])
            ir[:, 0] = semiring.mul(stripe(s_c, n, 1), stripe(s_c, n, 1, (w, 1))).squeeze(1)
            ir = semiring.sum(ir, 1)
            s_i.diagonal(w).copy_(semiring.mul(ir, s_arc.diagonal(w).movedim(-1, 0)).movedim(0, -1))

            # [batch_size, ..., n]
            sl = sr = semiring.dot(stripe(s_c, n, w), stripe(s_c, n, w, (w, 1)), 1).movedim(0, -1)
            # S(j, i) = <C(i->r), C(j->r+1)>, i <= r < j
            s_s.diagonal(-w).copy_(sl)
            # S(i, j) = <C(i->r), C(j->r+1)>, i <= r < j
            s_s.diagonal(w).copy_(sr)

            # [n, batch_size, ...]
            # C(j->i) = <C(r->i), I(j->r)>, i <= r < j
            cl = semiring.dot(stripe(s_c, n, w, (0, 0), 0), stripe(s_i, n, w, (w, 0)), 1)
            s_c.diagonal(-w).copy_(cl.movedim(0, -1))
            # C(i->j) = <I(i->r), C(r->j)>, i < r <= j
            cr = semiring.dot(stripe(s_i, n, w, (0, 1)), stripe(s_c, n, w, (1, w), 0), 1)
            s_c.diagonal(w).copy_(cr.movedim(0, -1))
        return semiring.unconvert(s_c)[0][self.lens, range(batch_size)]
class ConstituencyCRF(StructuredDistribution):
    r"""
    Constituency TreeCRF :cite:`zhang-etal-2020-fast,stern-etal-2017-minimal`.

    Args:
        scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
            Scores of all constituents, indexed by span ``(left, right)`` boundaries.
        lens (~torch.LongTensor): ``[batch_size]``.
            Sentence lengths for masking. Default: ``None``.

    Examples:
        >>> from supar import ConstituencyCRF
        >>> batch_size, seq_len = 2, 5
        >>> lens = torch.tensor([3, 4])
        >>> charts = torch.tensor([[[0, 1, 0, 1, 0],
                                    [0, 0, 1, 1, 0],
                                    [0, 0, 0, 1, 0],
                                    [0, 0, 0, 0, 0],
                                    [0, 0, 0, 0, 0]],
                                   [[0, 1, 1, 0, 1],
                                    [0, 0, 1, 0, 0],
                                    [0, 0, 0, 1, 1],
                                    [0, 0, 0, 0, 1],
                                    [0, 0, 0, 0, 0]]]).bool()
        >>> s1 = ConstituencyCRF(torch.randn(batch_size, seq_len, seq_len), lens)
        >>> s2 = ConstituencyCRF(torch.randn(batch_size, seq_len, seq_len), lens)
        >>> s1.max
        tensor([ 2.5068, -0.5628], grad_fn=<IndexBackward>)
        >>> s1.argmax
        [[[0, 3], [0, 1], [1, 3], [1, 2], [2, 3]], [[0, 4], [0, 2], [0, 1], [1, 2], [2, 4], [2, 3], [3, 4]]]
        >>> s1.log_partition
        tensor([2.9235, 0.0154], grad_fn=<IndexBackward>)
        >>> s1.log_prob(charts)
        tensor([-0.4167, -0.5781], grad_fn=<SubBackward0>)
        >>> s1.entropy
        tensor([0.6415, 1.2026], grad_fn=<IndexBackward>)
        >>> s1.kl(s2)
        tensor([0.0362, 2.9017], grad_fn=<IndexBackward>)
    """

    def __init__(self, scores, lens=None):
        super().__init__(scores)

        batch_size, seq_len = scores.shape[:2]
        # without explicit lengths, every sentence fills the padded length
        self.lens = scores.new_full((batch_size,), seq_len-1).long() if lens is None else lens
        # span mask [batch_size, seq_len, seq_len]: True where right <= lens and left < right
        self.mask = (self.lens.unsqueeze(-1) + 1).gt(self.lens.new_tensor(range(seq_len)))
        self.mask = self.mask.unsqueeze(1) & scores.new_ones(scores.shape[:3]).bool().triu_(1)

    def __repr__(self):
        return f"{self.__class__.__name__}()"

    def __add__(self, other):
        # stack along a trailing dim so both distributions are evaluated in one pass
        return ConstituencyCRF(torch.stack((self.scores, other.scores), -1), self.lens)

    @lazy_property
    def argmax(self):
        # one list of [left, right] spans per sentence, taken from the nonzero entries of
        # `self.backward(self.max.sum())` (assumed to be an indicator of the best tree's spans),
        # sorted by left boundary ascending, then right boundary descending
        return [sorted(torch.nonzero(i).tolist(), key=lambda x:(x[0], -x[1])) for i in self.backward(self.max.sum())]

    def topk(self, k):
        # per sentence, a tuple of k span lists, each recovered and sorted the same way as `argmax`
        return list(zip(*[[sorted(torch.nonzero(i).tolist(), key=lambda x:(x[0], -x[1])) for i in self.backward(i)]
                          for i in self.kmax(k).sum(0)]))

    def score(self, value):
        # `value` is a boolean span chart like `scores`; multiply (in log space: sum) the scores of the
        # valid spans it marks, yielding one value per sentence
        return LogSemiring.prod(LogSemiring.prod(LogSemiring.one_mask(self.scores, ~(self.mask & value)), -1), -1)

    @torch.enable_grad()
    def forward(self, semiring):
        # inside algorithm over spans; returns the score of the full span (0, lens) per sentence
        batch_size, seq_len = self.scores.shape[:2]
        # [seq_len, seq_len, batch_size, ...], (l->r)
        scores = semiring.convert(self.scores.movedim((1, 2), (0, 1)))
        s = semiring.zeros_like(scores)

        for w in range(1, seq_len):
            n = seq_len - w

            if w == 1:
                # width-1 spans have no split point: their inside score is just the span score
                s.diagonal(w).copy_(scores.diagonal(w))
                continue
            # [n, batch_size, ...]
            # S(i, j) = <S(i, r), S(r, j)> over split points i < r < j, then times the span score
            s_s = semiring.dot(stripe(s, n, w-1, (0, 1)), stripe(s, n, w-1, (1, w), 0), 1)
            s.diagonal(w).copy_(semiring.mul(s_s, scores.diagonal(w).movedim(-1, 0)).movedim(0, -1))

        return semiring.unconvert(s)[0][self.lens, range(batch_size)]