# Source code for supar.utils.fn

# -*- coding: utf-8 -*-

import os
import sys
import unicodedata
import urllib
import zipfile

import torch


def ispunct(token):
    r"""
    Checks whether a token consists solely of punctuation characters.

    Args:
        token (str): The token to inspect.

    Returns:
        bool: ``True`` if the Unicode general category of every character starts with ``'P'``
        (an empty token therefore yields ``True``); ``False`` otherwise.
    """

    for char in token:
        category = unicodedata.category(char)
        if not category.startswith('P'):
            return False
    return True


def isfullwidth(token):
    r"""
    Checks whether a token consists solely of full-width characters.

    Args:
        token (str): The token to inspect.

    Returns:
        bool: ``True`` if the East Asian width of every character is one of
        ``'W'`` (wide), ``'F'`` (full-width) or ``'A'`` (ambiguous); ``False`` otherwise.
        An empty token yields ``True``.
    """

    wide_classes = ['W', 'F', 'A']
    for char in token:
        if unicodedata.east_asian_width(char) not in wide_classes:
            return False
    return True


def islatin(token):
    r"""
    Checks whether a token consists solely of Latin characters.

    Args:
        token (str): The token to inspect.

    Returns:
        bool: ``True`` if the Unicode name of every character contains ``'LATIN'``
        (an empty token therefore yields ``True``); ``False`` otherwise.
    """

    # characters without a Unicode name (e.g., control characters such as '\n') make
    # `unicodedata.name` raise a ValueError unless a default is supplied;
    # such characters are simply treated as non-Latin
    return all('LATIN' in unicodedata.name(char, '') for char in token)


def isdigit(token):
    r"""
    Checks whether a token consists solely of digit characters.

    Args:
        token (str): The token to inspect.

    Returns:
        bool: ``True`` if the Unicode name of every character contains ``'DIGIT'``
        (an empty token therefore yields ``True``); ``False`` otherwise.
    """

    # characters without a Unicode name (e.g., control characters such as '\t') make
    # `unicodedata.name` raise a ValueError unless a default is supplied;
    # such characters are simply treated as non-digits
    return all('DIGIT' in unicodedata.name(char, '') for char in token)


def tohalfwidth(token):
    r"""
    Converts a token to its half-width form via Unicode ``NFKC`` normalization.

    Args:
        token (str): The token to convert.

    Returns:
        str: The ``NFKC``-normalized token.
    """

    normalized = unicodedata.normalize('NFKC', token)
    return normalized


def kmeans(x, k, max_it=32):
    r"""
    KMeans algorithm for clustering the sentences by length.

    Args:
        x (list[int]):
            The list of sentence lengths.
        k (int):
            The number of clusters.
            This is an approximate value. The final number of clusters can be less or equal to `k`.
        max_it (int):
            Maximum number of iterations.
            If centroids does not converge after several iterations, the algorithm will be early stopped.

    Returns:
        list[float], list[list[int]]:
            The first list contains average lengths of sentences in each cluster.
            The second is the list of clusters holding indices of data points.
            Two empty lists are returned if `x` is empty.

    Examples:
        >>> x = torch.randint(10,20,(10,)).tolist()
        >>> x
        [15, 10, 17, 11, 18, 13, 17, 19, 18, 14]
        >>> centroids, clusters = kmeans(x, 3)
        >>> centroids
        [10.5, 14.0, 17.799999237060547]
        >>> clusters
        [[1, 3], [0, 5, 9], [2, 4, 6, 7, 8]]
    """

    # there is nothing to cluster; bail out early since taking the min
    # over a zero-sized dimension below would raise an error
    if len(x) == 0:
        return [], []
    # the number of clusters must not be greater than the number of datapoints
    x, k = torch.tensor(x, dtype=torch.float), min(len(x), k)
    # collect unique datapoints
    d = x.unique()
    # initialize k centroids randomly
    # (fewer than k centroids are drawn if there are fewer than k unique values)
    c = d[torch.randperm(len(d))[:k]]
    # assign each datapoint to the cluster with the closest centroid
    dists, y = torch.abs_(x.unsqueeze(-1) - c).min(-1)

    for _ in range(max_it):
        # if an empty cluster is encountered,
        # choose the farthest datapoint from the biggest cluster and move that to the empty one
        mask = torch.arange(k).unsqueeze(-1).eq(y)
        none = torch.where(~mask.any(-1))[0].tolist()
        while len(none) > 0:
            for i in none:
                # the biggest cluster
                b = torch.where(mask[mask.sum(-1).argmax()])[0]
                # the datapoint farthest from the centroid of cluster b
                f = dists[b].argmax()
                # update the assigned cluster of f
                y[b[f]] = i
                # re-calculate the mask
                mask = torch.arange(k).unsqueeze(-1).eq(y)
            none = torch.where(~mask.any(-1))[0].tolist()
        # update the centroids
        c, old = (x * mask).sum(-1) / mask.sum(-1), c
        # re-assign all datapoints to clusters
        dists, y = torch.abs_(x.unsqueeze(-1) - c).min(-1)
        # stop iteration early if the centroids converge
        if c.equal(old):
            break
    # assign all datapoints to the new-generated clusters
    # the empty ones are discarded
    assigned = y.unique().tolist()
    # get the centroids of the assigned clusters
    centroids = c[assigned].tolist()
    # map all values of datapoints to buckets
    clusters = [torch.where(y.eq(i))[0].tolist() for i in assigned]

    return centroids, clusters
def stripe(x, n, w, offset=(0, 0), dim=1):
    r"""
    Returns a diagonal stripe of the tensor.

    Args:
        x (~torch.Tensor): the input tensor with 2 or more dims.
        n (int): the length of the stripe.
        w (int): the width of the stripe.
        offset (tuple): the offset of the first two dims.
        dim (int): 1 if returns a horizontal stripe; 0 otherwise.

    Returns:
        a diagonal stripe of the tensor.

    Examples:
        >>> x = torch.arange(25).view(5, 5)
        >>> x
        tensor([[ 0,  1,  2,  3,  4],
                [ 5,  6,  7,  8,  9],
                [10, 11, 12, 13, 14],
                [15, 16, 17, 18, 19],
                [20, 21, 22, 23, 24]])
        >>> stripe(x, 2, 3)
        tensor([[0, 1, 2],
                [6, 7, 8]])
        >>> stripe(x, 2, 3, (1, 1))
        tensor([[ 6,  7,  8],
                [12, 13, 14]])
        >>> stripe(x, 2, 3, (1, 1), 0)
        tensor([[ 6, 11, 16],
                [12, 17, 22]])
    """

    x, seq_len = x.contiguous(), x.size(1)
    # the number of elements held by each cell of the first two dims
    stride, numel = list(x.stride()), x[0, 0].numel()
    # stepping along the first output dim moves one row down and one column right
    stride[0] = (seq_len + 1) * numel
    # stepping along the second output dim moves right (horizontal) or down (vertical)
    stride[1] = (1 if dim == 1 else seq_len) * numel
    # `storage_offset` of `as_strided` is absolute w.r.t. the underlying storage,
    # so the offset of `x` itself must be included, otherwise contiguous views
    # that do not start at the beginning of their storage (e.g., `t[1:]`) yield wrong values
    return x.as_strided(size=(n, w, *x.shape[2:]),
                        stride=stride,
                        storage_offset=x.storage_offset() + (offset[0]*seq_len+offset[1])*numel)
def pad(tensors, padding_value=0, total_length=None, padding_side='right'):
    r"""
    Stacks a list of tensors into a single tensor, padding each dim to the largest size found.

    Args:
        tensors (list[~torch.Tensor]):
            The tensors to pad. The number of dims is taken from the first tensor.
        padding_value (int):
            The value used for filling the padded positions. Default: 0.
        total_length (int):
            If specified, the second dim of the output (i.e., the first dim of each tensor)
            is padded to this length, which must not be smaller than the longest tensor.
            Default: ``None``.
        padding_side (str):
            If ``'left'``, the paddings are placed before the data along every dim;
            otherwise they are placed after the data. Default: ``'right'``.

    Returns:
        ~torch.Tensor:
            A tensor of shape ``[len(tensors), *max_sizes]`` with the dtype and device of the first tensor.
    """

    size = [len(tensors)] + [max(tensor.size(i) for tensor in tensors)
                             for i in range(len(tensors[0].size()))]
    if total_length is not None:
        assert total_length >= size[1]
        size[1] = total_length
    out_tensor = tensors[0].data.new(*size).fill_(padding_value)
    for i, tensor in enumerate(tensors):
        if padding_side == 'left':
            # count from the end via explicit start positions:
            # `slice(-0, None)` would select the whole dim for zero-length tensors
            index = tuple(slice(full - length, None) for full, length in zip(size[1:], tensor.size()))
        else:
            index = tuple(slice(0, length) for length in tensor.size())
        # multidimensional indexing requires a tuple; non-tuple sequences of slices are deprecated
        out_tensor[i][index] = tensor
    return out_tensor


def download(url, reload=False):
    r"""
    Downloads a file into ``~/.cache/supar`` (unless already cached) and returns its local path.

    If the downloaded file is a zipfile, it must contain exactly one member,
    which is extracted next to the archive and whose path is returned instead.

    Args:
        url (str):
            The url of the file. The basename of its path component is used as the local filename.
        reload (bool):
            If ``True``, removes the cached file and forces a fresh download (and extraction).
            Default: ``False``.

    Returns:
        str: The path of the downloaded (or extracted) file.

    Raises:
        RuntimeError: If the url is unavailable or the zipfile does not hold exactly one member.
    """

    # `import urllib` alone does not guarantee that the submodules are loaded
    from urllib.error import URLError
    from urllib.parse import urlparse

    path = os.path.join(os.path.expanduser('~/.cache/supar'), os.path.basename(urlparse(url).path))
    os.makedirs(os.path.dirname(path), exist_ok=True)
    if reload and os.path.exists(path):
        os.remove(path)
    if not os.path.exists(path):
        sys.stderr.write(f"Downloading: {url} to {path}\n")
        try:
            torch.hub.download_url_to_file(url, path, progress=True)
        except URLError as e:
            raise RuntimeError(f"File {url} unavailable. Please try other sources.") from e
    if zipfile.is_zipfile(path):
        with zipfile.ZipFile(path) as f:
            members = f.infolist()
            # validate before touching `members[0]`, otherwise an empty archive raises an IndexError
            if len(members) != 1:
                raise RuntimeError('Only one file (not dir) is allowed in the zipfile.')
            path = os.path.join(os.path.dirname(path), members[0].filename)
            if reload or not os.path.exists(path):
                f.extractall(os.path.dirname(path))
    return path


def get_rng_state():
    r"""
    Collects the states of the random number generators of PyTorch.

    Returns:
        dict:
            The CPU generator state is stored under ``'rng_state'``.
            If CUDA is available, the state of the current CUDA device is stored under ``'cuda_rng_state'``.
    """

    state = {'rng_state': torch.get_rng_state()}
    if torch.cuda.is_available():
        state['cuda_rng_state'] = torch.cuda.get_rng_state()
    return state


def set_rng_state(state):
    r"""
    Restores the states of the random number generators of PyTorch.

    Args:
        state (dict):
            A dict as returned by :func:`get_rng_state`.
            The CUDA state is restored only if CUDA is available and the dict holds ``'cuda_rng_state'``.
    """

    torch.set_rng_state(state['rng_state'])
    # a state saved on a CPU-only machine carries no CUDA entry
    if torch.cuda.is_available() and 'cuda_rng_state' in state:
        torch.cuda.set_rng_state(state['cuda_rng_state'])