class ParakeetExtractor:
    """Turn raw mono waveforms into normalized log-mel spectrogram features.

    Long audio is split into fixed-duration clips (with a minimum length for
    the trailing clip); each clip is pre-emphasized, converted to a log-mel
    spectrogram via STFT + mel filter bank, and mean/std-normalized per mel
    bin over its valid (unpadded) frames.
    """

    def __init__(self, config: PretrainedConfig) -> None:
        """Build the extractor from a HuggingFace ``PretrainedConfig``.

        The attribute is named *exactly* ``config`` because
        ``HFParakeetEncoder._get_subsampling_output_length`` (invoked with
        ``self`` in :meth:`audio_token_count` below) reads ``self.config``.
        """
        self.config = ExtractorConfig.from_hf_config(config)
        # Clip sizes in samples: target length of a full clip, and the
        # minimum length allowed for the trailing (remainder) clip.
        self._clip_target_samples = int(
            round(self.config.clip_duration_s * self.config.sampling_rate)
        )
        self._tail_min_samples = int(
            round(self.config.clip_min_duration_s * self.config.sampling_rate)
        )

    @staticmethod
    @cache
    def _get_window(win_length: int, device: str) -> torch.Tensor:
        """Return a cached symmetric (``periodic=False``) Hann window on `device`."""
        return torch.hann_window(win_length, periodic=False, device=device)

    @staticmethod
    @cache
    def _get_mel_filters(
        feature_size: int, sampling_rate: int, n_fft: int, device: str
    ) -> torch.Tensor:
        """Return a cached slaney-normalized mel filter bank.

        Shape after the transpose is ``(feature_size, n_fft // 2 + 1)`` so it
        can be matmul'ed directly against a power spectrogram.
        """
        filter_bank = mel_filter_bank(
            num_frequency_bins=n_fft // 2 + 1,
            num_mel_filters=feature_size,
            min_frequency=0.0,
            max_frequency=sampling_rate / 2,
            sampling_rate=sampling_rate,
            norm="slaney",
            mel_scale="slaney",
        )
        return torch.from_numpy(filter_bank.T).to(device=device, dtype=torch.float32)

    def _torch_extract_fbank_features(self, waveform: torch.Tensor, device: str):
        """STFT the batched waveform and project it onto the mel filter bank."""
        # Normalize the device string so the @cache helpers share one entry
        # for equivalent spellings (e.g. "cuda" vs "cuda:0").
        device = str(torch.device(device))
        cfg = self.config
        window = self._get_window(cfg.win_length, device)
        stft = torch.stft(
            waveform,
            self.config.n_fft,
            hop_length=cfg.hop_length,
            win_length=cfg.win_length,
            window=window,
            return_complex=True,
            pad_mode="constant",  # zero padding (torch.stft centers by default)
        )
        mel_filters = self._get_mel_filters(
            cfg.feature_size, cfg.sampling_rate, cfg.n_fft, device
        )
        return self._apply_mel_filters(stft, mel_filters)

    @torch.compile(dynamic=True)
    def _apply_mel_filters(
        self, stft_output: torch.Tensor, mel_filters: torch.Tensor
    ) -> torch.Tensor:
        """Power spectrum -> log-mel, returned as (batch, time, mel)."""
        magnitudes = stft_output.real.square() + stft_output.imag.square()
        mel_spec = mel_filters @ magnitudes
        mel_spec = torch.log(mel_spec + LOG_ZERO_GUARD_VALUE)  # guard log(0)
        return mel_spec.permute(0, 2, 1)

    @torch.compile(dynamic=True)
    def _apply_preemphasis(
        self, input_features: torch.Tensor, audio_lengths: torch.Tensor
    ) -> torch.Tensor:
        """Apply ``y[t] = x[t] - k * x[t-1]`` and zero out the padded tail.

        `input_features` is (batch, samples); `audio_lengths` gives the valid
        sample count per row so padding never leaks into the output.
        """
        timemask = torch.arange(
            input_features.shape[1], device=input_features.device
        ).unsqueeze(0) < audio_lengths.unsqueeze(1)
        input_features = torch.cat(
            [
                input_features[:, :1],  # first sample passes through unchanged
                input_features[:, 1:]
                - self.config.preemphasis * input_features[:, :-1],
            ],
            dim=1,
        )
        input_features = input_features.masked_fill(~timemask, 0.0)
        return input_features

    @torch.compile(dynamic=True)
    def _normalize_mel_features(
        self, mel_features: torch.Tensor, audio_lengths: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Per-utterance, per-mel-bin mean/std normalization over valid frames.

        Returns the normalized (batch, time, mel) features and a
        (batch, time) boolean attention mask of the valid frames.
        """
        # Valid frame count for a centered STFT:
        # (len + 2*(n_fft//2) - n_fft) // hop.
        # NOTE(review): torch.stft(center=True) emits one more frame than
        # this (1 + len // hop), so the final frame is excluded from the
        # mask — looks intentional, but confirm against the encoder.
        features_lengths = torch.floor_divide(
            audio_lengths + self.config.n_fft // 2 * 2 - self.config.n_fft,
            self.config.hop_length,
        )
        attention_mask = (
            torch.arange(mel_features.shape[1], device=mel_features.device)[None, :]
            < features_lengths[:, None]
        )
        mask = attention_mask.unsqueeze(-1)
        lengths = attention_mask.sum(dim=1)  # clamped to the actual time axis
        mel_features_masked = mel_features * mask
        mean = (mel_features_masked.sum(dim=1) / lengths.unsqueeze(-1)).unsqueeze(1)
        # Unbiased (Bessel-corrected) variance over the valid frames only.
        variance = ((mel_features_masked - mean) ** 2 * mask).sum(dim=1) / (
            lengths - 1
        ).unsqueeze(-1)
        std = torch.sqrt(variance).unsqueeze(1)
        return (mel_features - mean) / (std + EPSILON) * mask, attention_mask

    def _pad_raw_speech(
        self, raw_speech: list[torch.Tensor], max_len: int, device: str
    ) -> torch.Tensor:
        """Right-pad the 1-D clips with `padding_value` into one (N, max_len) batch."""
        output = torch.full(
            (len(raw_speech), max_len),
            self.config.padding_value,
            device=device,
            dtype=torch.float32,
        )
        dsts = [output[i, : raw_speech[i].shape[0]] for i in range(len(raw_speech))]
        # squeeze(-1) drops a trailing singleton dim if present; clips from
        # split_audio_into_clips are already 1-D, so it is usually a no-op.
        srcs = [s.squeeze(-1) for s in raw_speech]
        # single kernel horizontal fusion
        torch._foreach_copy_(dsts, srcs)
        return output

    def _clip_sizes(self, audio_len: int) -> list[int]:
        """Split `audio_len` samples into full clips plus a tail clip.

        The tail (and any audio shorter than the minimum) is padded up to
        ``_tail_min_samples``; the caller pads the waveform to match.
        """
        audio_len = max(audio_len, self._tail_min_samples)
        num_full_clips, remainder = divmod(audio_len, self._clip_target_samples)
        clip_sizes = [self._clip_target_samples] * num_full_clips
        if remainder > 0:
            clip_sizes.append(max(remainder, self._tail_min_samples))
        return clip_sizes

    def audio_token_count(self, audio_len: int) -> int:
        """Return the total encoder token count for `audio_len` raw samples."""
        total_tokens = 0
        for clip_size in self._clip_sizes(audio_len):
            num_frames = clip_size // self.config.hop_length
            # Relies on self.config exposing the fields the encoder's
            # subsampling-length helper expects (see __init__ note).
            n_tokens = HFParakeetEncoder._get_subsampling_output_length(
                self, torch.tensor([num_frames], dtype=torch.float)
            )
            total_tokens += int(n_tokens.item())
        return max(1, total_tokens)

    def split_audio_into_clips(self, audio: torch.Tensor) -> list[torch.Tensor]:
        """Slice a 1-D waveform into clips, zero-padding the tail if needed."""
        assert audio.ndim == 1
        audio_len = int(audio.shape[0])
        clip_sizes = self._clip_sizes(audio_len)
        target_len = sum(clip_sizes)
        if audio_len < target_len:
            audio = torch.nn.functional.pad(audio, (0, target_len - audio_len))
        clips = list[torch.Tensor]()
        offset = 0
        for clip_size in clip_sizes:
            clips.append(audio[offset : offset + clip_size])
            offset += clip_size
        return clips

    def __call__(
        self,
        raw_speech: list[np.ndarray],
        *,
        device: str = "cpu",
    ) -> dict[str, Any]:
        """Extract batched log-mel features from a list of raw waveforms.

        Returns ``input_audio_features`` (clips, time, mel),
        ``feature_attention_mask`` (clips, time), and ``audio_num_clips``,
        the per-input clip counts needed to regroup the clip axis.
        """
        speech_tensors = [
            torch.as_tensor(speech, device=device, dtype=torch.float32)
            for speech in raw_speech
        ]
        for i, speech in enumerate(speech_tensors):
            if len(speech.shape) > 1:
                logger.warning(
                    "Only mono-channel audio is supported for input to %s. "
                    "We will take the mean of the channels to convert to mono.",
                    self.__class__.__name__,
                )
                speech_tensors[i] = speech.mean(-1)
        audio_clips = list[torch.Tensor]()
        audio_num_clips = list[int]()
        for audio in speech_tensors:
            clips = self.split_audio_into_clips(audio)
            audio_clips.extend(clips)
            audio_num_clips.append(len(clips))
        audio_lengths = torch.tensor(
            [len(speech) for speech in audio_clips], dtype=torch.long, device=device
        )
        max_length = max(len(speech) for speech in audio_clips)
        input_features = self._pad_raw_speech(audio_clips, max_length, device)
        input_features = self._apply_preemphasis(input_features, audio_lengths)
        input_features = self._torch_extract_fbank_features(input_features, device)
        input_features, attention_mask = self._normalize_mel_features(
            input_features, audio_lengths
        )
        return {
            "input_audio_features": input_features,
            "feature_attention_mask": attention_mask,
            "audio_num_clips": audio_num_clips,
        }

    @staticmethod
    def audio_length(raw_config: PretrainedConfig, audio_tokens: int) -> int:
        """Inverse of :meth:`audio_token_count`: samples covered by `audio_tokens`."""
        config = ExtractorConfig.from_hf_config(raw_config)
        return int(audio_tokens * config.subsampling_factor * config.hop_length)