# Source code for supar.structs.dist

# -*- coding: utf-8 -*-

import torch
import torch.autograd as autograd
from supar.structs.semiring import (CrossEntropySemiring, EntropySemiring,
                                    KLDivergenceSemiring, KMaxSemiring,
                                    LogSemiring, MaxSemiring, SampledSemiring)
from torch.distributions.distribution import Distribution
from torch.distributions.utils import lazy_property


class StructuredDistribution(Distribution):
    r"""
    Base class for structured distribution :math:`p(y)` :cite:`eisner-2016-inside,goodman-1999-semiring,li-eisner-2009-first`.

    Subclasses must implement :meth:`forward` and :meth:`score` (and, where supported, :attr:`argmax` / :meth:`topk`);
    the remaining quantities are derived from :meth:`forward` by choosing a semiring, or by differentiating its output
    with respect to the scores (see :meth:`backward`).

    Args:
        scores (torch.Tensor):
            Log potentials, also for high-order cases.
            A sequence of tensors is also accepted; each one is marked as requiring grad, and gradients
            (e.g. :attr:`marginals`) are taken with respect to the first tensor only.
        kwargs:
            Extra keyword arguments, stored unchanged in ``self.kwargs`` (presumably consumed by subclasses).
    """

    def __init__(self, scores, **kwargs):
        # ``requires_grad_`` mutates the caller's tensor(s) in place; it is needed so that
        # ``backward`` can differentiate semiring outputs with respect to the scores.
        if isinstance(scores, torch.Tensor):
            self.scores = scores.requires_grad_()
        else:
            self.scores = [s.requires_grad_() for s in scores]
        self.kwargs = kwargs

    def __repr__(self):
        return f"{self.__class__.__name__}()"

    @lazy_property
    def log_partition(self):
        r"""
        Computes the log partition function of the distribution :math:`p(y)`.
        """

        return self.forward(LogSemiring)

    @lazy_property
    def marginals(self):
        r"""
        Computes marginal probabilities of the distribution :math:`p(y)`.

        Obtained as the gradient of the (summed) log partition w.r.t. the scores.
        """

        return self.backward(self.log_partition.sum())

    @lazy_property
    def max(self):
        r"""
        Computes the max score of the distribution :math:`p(y)`.
        """

        return self.forward(MaxSemiring)

    @lazy_property
    def argmax(self):
        r"""
        Computes :math:`\arg\max_y p(y)` of the distribution :math:`p(y)`.

        Raises:
            NotImplementedError: always, in this base class.
        """

        raise NotImplementedError

    @lazy_property
    def mode(self):
        # Alias of ``argmax`` (the mode of a distribution is its most probable value).
        return self.argmax

    def kmax(self, k):
        r"""
        Computes the k-max of the distribution :math:`p(y)`.
        """

        return self.forward(KMaxSemiring(k))

    def topk(self, k):
        r"""
        Computes the k-argmax of the distribution :math:`p(y)`.

        Raises:
            NotImplementedError: always, in this base class.
        """

        raise NotImplementedError

    def sample(self):
        r"""
        Obtains a structured sample from the distribution :math:`y \sim p(y)`.

        Runs :meth:`forward` with ``SampledSemiring`` and differentiates the result w.r.t. the scores;
        the returned tensor is detached from the autograd graph.

        TODO: multi-sampling.
        """

        return self.backward(self.forward(SampledSemiring).sum()).detach()

    @lazy_property
    def entropy(self):
        r"""
        Computes entropy :math:`H[p]` of the distribution :math:`p(y)`.
        """

        return self.forward(EntropySemiring)

    def cross_entropy(self, other):
        r"""
        Computes cross-entropy :math:`H[p,q]` of self and another distribution.

        Args:
            other (~supar.structs.dist.StructuredDistribution): Comparison distribution.

        Note:
            Relies on ``self + other`` combining the two distributions; ``__add__`` is not defined
            in this base class, so a subclass has to provide it.
        """

        return (self + other).forward(CrossEntropySemiring)

    def kl(self, other):
        r"""
        Computes KL-divergence :math:`KL[p \parallel q]=H[p,q]-H[p]` of self and another distribution.

        Args:
            other (~supar.structs.dist.StructuredDistribution): Comparison distribution.

        Note:
            Relies on ``self + other`` combining the two distributions; ``__add__`` is not defined
            in this base class, so a subclass has to provide it.
        """

        return (self + other).forward(KLDivergenceSemiring)

    def log_prob(self, value, **kwargs):
        r"""
        Computes the log probability :math:`\log p(y) = s(y) - \log Z` of the given structures.

        Args:
            value: Structures to evaluate, in whatever form the subclass's :meth:`score` accepts.
            kwargs: Forwarded unchanged to :meth:`score`.
        """

        return self.score(value, **kwargs) - self.log_partition

    def score(self, value, *args, **kwargs):
        r"""
        Computes the (unnormalized) score :math:`s(y)` of the given structures.

        Accepts extra positional/keyword arguments so that :meth:`log_prob`, which forwards its
        ``**kwargs`` here, reaches ``NotImplementedError`` rather than a ``TypeError`` on this base class.

        Raises:
            NotImplementedError: always, in this base class.
        """

        raise NotImplementedError

    @torch.enable_grad()
    def forward(self, semiring):
        r"""
        Runs the structure's dynamic program under ``semiring``.

        Executed with gradients enabled so that :meth:`backward` can differentiate its output
        even when called inside a ``torch.no_grad()`` context.

        Raises:
            NotImplementedError: always, in this base class.
        """

        raise NotImplementedError

    def backward(self, log_partition):
        r"""
        Differentiates ``log_partition`` w.r.t. the scores.

        Args:
            log_partition (torch.Tensor): Scalar output of :meth:`forward` (summed by the callers here).

        Returns:
            torch.Tensor: Gradient w.r.t. ``self.scores`` (or w.r.t. ``self.scores[0]`` when scores is a
            sequence). ``create_graph=True`` keeps the result differentiable.
        """

        target = self.scores if isinstance(self.scores, torch.Tensor) else self.scores[0]
        return autograd.grad(log_partition, target, create_graph=True)[0]