Commit 6c5f1fd9 authored by Nik Vaessen

initial commit
################################################################################
#
# This file defines a typed batch produced by the dataloader(s)
#
# Author(s): Nik Vaessen
################################################################################
import dataclasses

from typing import List

import torch as t

from torch.utils.data._utils.collate import default_collate

from skeleton.data.collating import collate_append_constant
################################################################################
# data sample and data batch classes
@dataclasses.dataclass
class SpeakerClassificationDataSample:
    # a unique identifier for this ground_truth and input tensor pairing
    key: str

    # integer value for the class of this sample
    ground_truth: int

    # tensor of floats with a shape depending on the particular task;
    # most likely [NUM_FEATURES, NUM_FRAMES]
    network_input: t.Tensor


@dataclasses.dataclass
class SpeakerClassificationDataBatch:
    # the number of samples this batch contains
    batch_size: int

    # list of strings with length BATCH_SIZE, where each index matches
    # a unique identifier to the ground_truth or input tensor at the
    # same batch dimension
    keys: List[str]

    # tensor of floats with shape [BATCH_SIZE, NUM_FEATURES, NUM_FRAMES]
    network_input: t.Tensor

    # tensor of integers with shape [BATCH_SIZE]
    ground_truth: t.Tensor

    def __len__(self):
        return self.batch_size

    def to(self, device: t.device) -> "SpeakerClassificationDataBatch":
        return SpeakerClassificationDataBatch(
            self.batch_size,
            self.keys,
            self.network_input.to(device),
            self.ground_truth.to(device),
        )
    @staticmethod
    def default_collate_fn(
        lst: List[SpeakerClassificationDataSample],
    ) -> "SpeakerClassificationDataBatch":
        batch_size = len(lst)
        keys = default_collate([sample.key for sample in lst])
        network_input = default_collate([sample.network_input for sample in lst])
        ground_truth = t.squeeze(
            default_collate([sample.ground_truth for sample in lst])
        )

        return SpeakerClassificationDataBatch(
            batch_size=batch_size,
            keys=keys,
            network_input=network_input,
            ground_truth=ground_truth,
        )
    @staticmethod
    def pad_right_collate_fn(
        lst: List[SpeakerClassificationDataSample],
    ) -> "SpeakerClassificationDataBatch":
        batch_size = len(lst)
        keys = default_collate([sample.key for sample in lst])
        # network_input is [NUM_FEATURES, NUM_FRAMES], so the frame dimension
        # that may vary between samples is dim 1
        network_input = collate_append_constant(
            [sample.network_input for sample in lst], frame_dim=1, feature_dim=0
        )
        ground_truth = t.squeeze(
            default_collate([sample.ground_truth for sample in lst])
        )

        return SpeakerClassificationDataBatch(
            batch_size=batch_size,
            keys=keys,
            network_input=network_input,
            ground_truth=ground_truth,
        )
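

################################################################################
# usage sketch: collate a few dummy samples with pad_right_collate_fn, assuming
# network_input is laid out as [NUM_FEATURES, NUM_FRAMES]. All keys, labels,
# and shapes below are made-up placeholders.

if __name__ == "__main__":
    samples = [
        SpeakerClassificationDataSample(
            key=f"utt-{i}",
            ground_truth=i % 2,
            # 40 features, with a number of frames that varies per sample
            network_input=t.rand(40, 100 + i),
        )
        for i in range(4)
    ]
    batch = SpeakerClassificationDataBatch.pad_right_collate_fn(samples)

    # every sample is padded on the right up to the longest sample (103 frames)
    assert batch.network_input.shape == (4, 40, 103)
    assert batch.ground_truth.shape == (4,)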
################################################################################
#
# Utility functions for collating a batch with different number of frames in
# each sample.
#
# The functions assume an input dimensionality of
#
# [NUM_FEATURES, NUM_FRAMES] for MFCC-like input
# or
# [NUM_FRAMES] for wave-like input
#
# This will result in batches with respective dimensionality
# [NUM_SAMPLES, NUM_FEATURES, NUM_FRAMES] or [NUM_SAMPLES, NUM_FRAMES]
#
# Author(s): Nik Vaessen
################################################################################
from typing import List, Callable, Optional
import torch as t
from torch.nn import (
    ConstantPad1d,
    ConstantPad2d,
)
################################################################################
# private utility functions
def _determine_max_num_frames(
    samples: List[t.Tensor], frame_dim: int = 0, feature_dim: Optional[int] = None
):
    if len(samples) <= 0:
        raise ValueError("expected non-empty list")
    if frame_dim == feature_dim:
        raise ValueError("frame_dim and feature_dim cannot be equal")
    if not (0 <= frame_dim <= 1):
        raise ValueError(f"frame_dim should be either 0 or 1, not {frame_dim}")
    if feature_dim is not None and not (0 <= feature_dim <= 1):
        raise ValueError(f"feature_dim should be either 0 or 1, not {feature_dim}")

    # assume [NUM_FRAMES], or a 2D shape where `frame_dim` indexes the frames
    max_frames = -1
    num_features = None

    for idx, sample in enumerate(samples):
        num_dim = len(sample.shape)

        if not (num_dim == 1 or num_dim == 2):
            raise ValueError(
                "only 1 or 2-dimensional samples are supported. "
                f"Received sample with shape {sample.shape}"
            )
        elif num_dim == 2:
            if feature_dim is None:
                raise ValueError(
                    "padding a 2-dimensional tensor requires setting `feature_dim`"
                )
            if idx == 0:
                num_features = sample.shape[feature_dim]
            elif num_features != sample.shape[feature_dim]:
                raise ValueError(
                    "list has an inconsistent number of features. "
                    f"Received at least {num_features} and {sample.shape[feature_dim]}"
                )

        num_frames = sample.shape[frame_dim]
        if num_frames > max_frames:
            max_frames = num_frames

    return max_frames
def _generic_append_padding(
    samples: List[t.Tensor],
    padding_init: Callable[[int, int, int], t.nn.Module],
    frame_dim: int = 0,
    feature_dim: Optional[int] = 1,
):
    max_frames = _determine_max_num_frames(samples, frame_dim, feature_dim)

    padded_samples = []
    for sample in samples:
        num_dim = len(sample.shape)
        num_frames = sample.shape[frame_dim]

        padded_sample = padding_init(num_dim, num_frames, max_frames)(sample)
        padded_samples.append(padded_sample)

    return t.stack(padded_samples)
################################################################################
# constant collating
def collate_append_constant(
    samples: List[t.Tensor],
    value: float = 0,
    frame_dim: int = 0,
    feature_dim: Optional[int] = None,
):
    def padding_init(num_dim: int, num_frames: int, max_frames: int, v=value):
        padding_right = max_frames - num_frames

        if num_dim == 1:
            return ConstantPad1d((0, padding_right), v)
        else:
            if frame_dim == 0:
                # pad the second-to-last dimension at the bottom
                return ConstantPad2d((0, 0, 0, padding_right), v)
            elif frame_dim == 1:
                # pad the last dimension on the right
                return ConstantPad2d((0, padding_right, 0, 0), v)
            else:
                raise ValueError("frame_dim can only be 0 or 1")

    return _generic_append_padding(samples, padding_init, frame_dim, feature_dim)
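

################################################################################
# usage sketch: collate_append_constant on 1D (wave-like) and 2D (MFCC-like)
# inputs; the shapes and values below are illustrative placeholders.

if __name__ == "__main__":
    # 1D: [NUM_FRAMES] -> [NUM_SAMPLES, NUM_FRAMES]
    waves = [t.rand(16_000), t.rand(12_000), t.rand(14_000)]
    wave_batch = collate_append_constant(waves, frame_dim=0)
    assert wave_batch.shape == (3, 16_000)

    # 2D: [NUM_FEATURES, NUM_FRAMES] -> [NUM_SAMPLES, NUM_FEATURES, NUM_FRAMES]
    specs = [t.rand(40, 98), t.rand(40, 120), t.rand(40, 101)]
    spec_batch = collate_append_constant(specs, frame_dim=1, feature_dim=0)
    assert spec_batch.shape == (3, 40, 120)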
################################################################################
#
# This file defines the preprocessing pipeline for train, val and test data.
#
# Author(s): Nik Vaessen
################################################################################
import torch as t
import torchaudio
from skeleton.data.batch import SpeakerClassificationDataSample
################################################################################
# train data pipeline
class Preprocessor:
    def __init__(
        self,
        audio_length_seconds: float,
        n_mels: int,
        normalize: bool,
        normalize_channel_wise: bool,
    ) -> None:
        # despite the `mfcc` name, this transform produces a mel spectrogram
        self.mfcc_transform = torchaudio.transforms.MelSpectrogram(n_mels=n_mels)

        # assume a 16 kHz sampling rate
        self.audio_length = int(audio_length_seconds * 16_000)

        self.normalize = normalize
        self.normalize_channel_wise = normalize_channel_wise

    def train_data_pipeline(
        self,
        sample: SpeakerClassificationDataSample,
    ) -> SpeakerClassificationDataSample:
        # crop the audio to a fixed length before computing the spectrogram
        sample.network_input = sample.network_input[0 : self.audio_length]
        sample.network_input = self._calculate_mfcc(sample.network_input)

        if self.normalize:
            sample.network_input, _, _ = normalize(
                sample.network_input, channel_wise=self.normalize_channel_wise
            )

        return sample

    def val_data_pipeline(
        self,
        sample: SpeakerClassificationDataSample,
    ) -> SpeakerClassificationDataSample:
        sample.network_input = sample.network_input[0 : self.audio_length]
        sample.network_input = self._calculate_mfcc(sample.network_input)

        if self.normalize:
            sample.network_input, _, _ = normalize(
                sample.network_input, channel_wise=self.normalize_channel_wise
            )

        return sample

    def test_data_pipeline(
        self,
        sample: SpeakerClassificationDataSample,
    ) -> SpeakerClassificationDataSample:
        # we evaluate on the full audio file
        sample.network_input = self._calculate_mfcc(sample.network_input)

        if self.normalize:
            sample.network_input, _, _ = normalize(
                sample.network_input, channel_wise=self.normalize_channel_wise
            )

        return sample

    def _calculate_mfcc(self, wav_tensor: t.Tensor) -> t.Tensor:
        # mel spectrogram of shape [n_mels, n_frames]
        return self.mfcc_transform(wav_tensor)
################################################################################
# utility method for normalizing
def normalize(spectrogram: t.Tensor, channel_wise: bool):
    if len(spectrogram.shape) != 2:
        raise ValueError("expected to normalize over 2D input")

    if channel_wise:
        # calculate mean/std over the last dimension
        # (assuming shape [n_mels, n_frames])
        std, mean = t.std_mean(spectrogram, dim=1)
    else:
        std, mean = t.std_mean(spectrogram)

    normalized_spectrogram = (spectrogram.T - mean) / (std + 1e-9)
    normalized_spectrogram = normalized_spectrogram.T

    return normalized_spectrogram, mean, std
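

################################################################################
# usage sketch: run a dummy waveform through the train pipeline; the key,
# label, mel count, and audio values below are illustrative placeholders.

if __name__ == "__main__":
    preprocessor = Preprocessor(
        audio_length_seconds=3,
        n_mels=40,
        normalize=True,
        normalize_channel_wise=True,
    )
    sample = SpeakerClassificationDataSample(
        key="utt-0", ground_truth=0, network_input=t.rand(48_000)
    )
    sample = preprocessor.train_data_pipeline(sample)

    # MelSpectrogram output has shape [n_mels, n_frames]
    assert sample.network_input.shape[0] == 40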
################################################################################
#
# This file implements a datamodule for the tiny-voxceleb dataset.
# The data is loaded from pre-generated webdataset shards.
# (See `scripts/generate_shards.py` for how to generate shards.)
#
# Author(s): Nik Vaessen
################################################################################
import pathlib
import json
from typing import Iterator, List, Optional, Tuple
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
import webdataset as wds
from skeleton.data.batch import (
    SpeakerClassificationDataBatch,
    SpeakerClassificationDataSample,
)
from skeleton.data.preprocess import Preprocessor
from skeleton.evaluation.evaluator import EvaluationPair
################################################################################
# data module implementation
class TinyVoxcelebDataModule(LightningDataModule):
    def __init__(
        self,
        shard_folder: pathlib.Path,
        batch_size: int,
        preprocessor: Preprocessor,
        val_trials_path: pathlib.Path,
        dev_trials_path: pathlib.Path,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.shard_folder = shard_folder
        self.preprocessor = preprocessor

        self.val_trials_path = val_trials_path
        self.dev_trials_path = dev_trials_path

        self._num_speakers = None

    def setup(self, stage: Optional[str] = None) -> None:
        # train dataloader
        self.train_ds = init_dataset(self.shard_folder / "train")
        self.train_ds = self.train_ds.map(
            self.preprocessor.train_data_pipeline
        ).batched(self.batch_size, collation_fn=_collate_samples)

        # val dataloader
        self.val_ds = init_dataset(self.shard_folder / "val")
        self.val_ds = self.val_ds.map(self.preprocessor.val_data_pipeline).batched(
            self.batch_size, collation_fn=_collate_samples
        )

        # dev dataloader
        # we explicitly evaluate with a batch size of 1
        self.dev_ds = init_dataset(self.shard_folder / "dev")
        self.dev_ds = self.dev_ds.map(self.preprocessor.test_data_pipeline).batched(
            1, collation_fn=_collate_samples
        )

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        return self.train_ds

    def val_dataloader(self) -> EVAL_DATALOADERS:
        return self.val_ds

    def test_dataloader(self) -> EVAL_DATALOADERS:
        return self.dev_ds

    @property
    def num_speakers(self):
        if self._num_speakers is None:
            with (self.shard_folder / "train" / "meta.json").open("r") as f:
                train_meta = json.load(f)

            self._num_speakers = train_meta["num_speakers"]

        return self._num_speakers

    @property
    def val_trials(self) -> List[EvaluationPair]:
        return load_evaluation_pairs(self.val_trials_path)

    @property
    def dev_trials(self) -> List[EvaluationPair]:
        return load_evaluation_pairs(self.dev_trials_path)
################################################################################
# helper methods
def init_dataset(root_folder: pathlib.Path):
    return (
        wds.WebDataset(_find_urls(root_folder))
        .decode(wds.torch_audio)
        .map(_map_decoded_dict_to_batch)
    )


def _find_urls(root_shard_folder: pathlib.Path):
    return [str(s) for s in root_shard_folder.glob("*.tar*")]


def _map_decoded_dict_to_batch(d: dict):
    return SpeakerClassificationDataSample(
        key=d["__key__"],
        ground_truth=d["json"]["speaker_id_idx"],
        # d["wav"] is a (waveform, sample_rate) tuple; squeeze the channel dim
        network_input=d["wav"][0].squeeze(),
    )


def _collate_samples(
    sample_iter: Iterator[SpeakerClassificationDataSample],
) -> SpeakerClassificationDataBatch:
    return SpeakerClassificationDataBatch.pad_right_collate_fn(
        [s for s in sample_iter]
    )
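

################################################################################
# usage sketch: wiring the datamodule together, assuming shards and trial
# lists exist at the paths below; every path and hyperparameter here is an
# illustrative placeholder.

if __name__ == "__main__":
    dm = TinyVoxcelebDataModule(
        shard_folder=pathlib.Path("data/shards"),
        batch_size=32,
        preprocessor=Preprocessor(
            audio_length_seconds=3,
            n_mels=40,
            normalize=True,
            normalize_channel_wise=True,
        ),
        val_trials_path=pathlib.Path("data/val_trials.txt"),
        dev_trials_path=pathlib.Path("data/dev_trials.txt"),
    )
    dm.setup()

    for batch in dm.train_dataloader():
        print(batch.batch_size, batch.network_input.shape)
        break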
################################################################################
# method for reading the text files containing evaluation trials
def read_test_pairs_file(
    pairs_file_path: pathlib.Path,
) -> Iterator[Tuple[Optional[bool], str, str]]:
    with pairs_file_path.open("r") as f:
        for line in f.readlines():
            line = line.strip()

            if line.count(" ") < 2:
                continue

            gt, path1, path2 = line.split(" ")

            # a negative label means the ground truth is unknown
            yield bool(int(gt)) if int(gt) >= 0 else None, path1, path2


def load_evaluation_pairs(file_path: pathlib.Path):
    pairs = []

    for gt, path1, path2 in read_test_pairs_file(file_path):
        utt1id = path1.split(".wav")[0]
        utt2id = path2.split(".wav")[0]

        if path1.count("/") == 2:
            spk1id = path1.split("/")[0]
            spk2id = path2.split("/")[0]

            # sanity-check the label against the speaker ids in the paths
            if gt is not None and (spk1id == spk2id) != gt:
                raise ValueError(f"read gt={gt} for line `{path1} {path2}`")

        pairs.append(EvaluationPair(gt, utt1id, utt2id))

    return pairs
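

################################################################################
# format sketch: the trial-list layout the parser above expects. Each line
# holds an integer label and two utterance paths; the identifiers below are
# illustrative placeholders.
#
#   1 speaker_a/session_1/utt_1.wav speaker_a/session_2/utt_3.wav
#   0 speaker_a/session_1/utt_1.wav speaker_b/session_1/utt_2.wav
#
# A negative label is parsed as `None` (unknown ground truth), and lines with
# fewer than two spaces are skipped.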
################################################################################
#
# This file implements two quantitative measures for speaker identification:
#
# * equal error rate
# * minimum detection cost
#
# Author(s): Nik Vaessen
################################################################################
from operator import itemgetter
from typing import List, Tuple
import numpy as np
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from sklearn.metrics import roc_curve
################################################################################
# helper methods for both measures
def _verify_correct_scores(
    groundtruth_scores: List[int], predicted_scores: List[float]
):
    if len(groundtruth_scores) != len(predicted_scores):
        raise ValueError(
            f"length of input lists should match, while"
            f" groundtruth_scores={len(groundtruth_scores)} and"
            f" predicted_scores={len(predicted_scores)}"
        )
    if not all(np.isin(groundtruth_scores, [0, 1])):
        raise ValueError(
            f"groundtruth values should be either 0 or 1, while "
            f"they are actually one of {np.unique(groundtruth_scores)}"
        )
################################################################################
# EER (equal-error-rate)
def calculate_eer(
    groundtruth_scores: List[int], predicted_scores: List[float], pos_label: int = 1
):
    """
    Calculate the equal error rate between a list of groundtruth pos/neg scores
    and a list of predicted pos/neg scores.

    Adapted from: https://github.com/a-nagrani/VoxSRC2020/blob/master/compute_EER.py

    :param groundtruth_scores: a list of groundtruth integer values (either 0 or 1)
    :param predicted_scores: a list of prediction float values (in range [0, 1])
    :param pos_label: which value (either 0 or 1) represents positive. Defaults to 1
    :return: a tuple containing the equal error rate and the corresponding threshold
    """
    _verify_correct_scores(groundtruth_scores, predicted_scores)

    if not all(np.isin([pos_label], [0, 1])):
        raise ValueError(f"The positive label should be either 0 or 1, not {pos_label}")

    fpr, tpr, thresholds = roc_curve(
        groundtruth_scores, predicted_scores, pos_label=pos_label
    )

    # the EER is the point on the ROC curve where fpr == 1 - tpr
    eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)
    thresh = interp1d(fpr, thresholds)(eer).item()

    return eer, thresh
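

################################################################################
# usage sketch: calculate_eer on a handful of made-up trial scores.

if __name__ == "__main__":
    groundtruth = [1, 1, 1, 0, 0, 0]
    predictions = [0.9, 0.8, 0.4, 0.6, 0.2, 0.1]

    eer, threshold = calculate_eer(groundtruth, predictions)
    # one positive trial scores below and one negative trial scores above
    # the threshold, so both error rates are 1/3 at the EER point
    print(f"eer={eer:.3f} at threshold={threshold:.3f}")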
################################################################################
# minimum detection cost - taken from
# https://github.com/a-nagrani/VoxSRC2020/blob/master/compute_min_dcf.py
# Copyright 2018 David Snyder
# This script is modified from the Kaldi toolkit -
# https://github.com/kaldi-asr/kaldi/blob/8ce3a95761e0eb97d95d3db2fcb6b2bfb7ffec5b/egs/sre08/v1/sid/compute_min_dcf.py
def _compute_error_rates(
    groundtruth_scores: List[int],
    predicted_scores: List[float],
) -> Tuple[List[float], List[float], List[float]]:
    """
    Creates a list of false-negative rates, a list of false-positive rates
    and a list of decision thresholds that give those error-rates.

    :param groundtruth_scores: a list of groundtruth integer values (either 0 or 1)
    :param predicted_scores: a list of prediction float values (in range [0, 1])
    :return: a triple with a list of false negative rates, a list of false
     positive rates, and a list of decision thresholds for those rates.
    """
    # Sort the scores from smallest to largest, and also get the corresponding
    # indexes of the sorted scores. We will treat the sorted scores as the
    # thresholds at which the error-rates are evaluated.
    sorted_indexes, thresholds = zip(
        *sorted(
            [(index, threshold) for index, threshold in enumerate(predicted_scores)],
            key=itemgetter(1),
        )
    )
    groundtruth_scores = [groundtruth_scores[i] for i in sorted_indexes]

    fnrs = []
    fprs = []

    # At the end of this loop, fnrs[i] is the number of errors made by
    # incorrectly rejecting scores less than thresholds[i]. And, fprs[i]
    # is the total number of times that we have correctly accepted scores
    # greater than thresholds[i].
    for i in range(0, len(groundtruth_scores)):
        if i == 0:
            fnrs.append(groundtruth_scores[i])
            fprs.append(1 - groundtruth_scores[i])
        else:
            fnrs.append(fnrs[i - 1] + groundtruth_scores[i])
            fprs.append(fprs[i - 1] + 1 - groundtruth_scores[i])

    fnrs_norm = sum(groundtruth_scores)
    fprs_norm = len(groundtruth_scores) - fnrs_norm

    # Now divide by the total number of false negative errors to
    # obtain the false negative rates across all thresholds
    fnrs = [x / float(fnrs_norm) for x in fnrs]

    # Divide by the total number of correct positives to get the
    # true positive rate. Subtract these quantities from 1 to
    # get the false positive rates.
    fprs = [1 - x / float(fprs_norm) for x in fprs]

    return fnrs, fprs, thresholds
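

################################################################################
# usage sketch: error-rate curves from a handful of made-up trial scores; a
# minimum detection cost would be derived from these (fnrs, fprs, thresholds).

if __name__ == "__main__":
    groundtruth = [1, 0, 1, 0, 1, 0]
    predictions = [0.9, 0.6, 0.8, 0.2, 0.4, 0.1]

    fnrs, fprs, thresholds = _compute_error_rates(groundtruth, predictions)
    for fnr, fpr, threshold in zip(fnrs, fprs, thresholds):
        print(f"threshold={threshold:.1f}: fnr={fnr:.2f}, fpr={fpr:.2f}")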