Rendezvous-Tokyo

無料枠のCerebrium(Saas)でtacotron2の推論を実行してもらう

前回、tacotronの学習から推論まで行なった。

せっかくなので推論をAPIで呼び出せるようにしたいなと思い調べたら
cerebrium (opens new window)なるものを見つけた。

比較的簡単に使えたが、あまり記事がなかったので参考までに残しておく。

軽く試して諦めたもの

これらも無料枠で試してみたが、基本的にCPUしか使えないため動かなかった。

cerebriumへのデプロイ

cerebrium (opens new window)のアカウントを作るとチュートリアルあるので済ませておく。
まずは cerebrium init first-project で基本ファイルを生成する。

  • main.py
  • requirements.txt
  • pkglist.txt
  • conda_pkglist.txt
  • config.yaml

1. ルートディレクトリにtacotronのソースをコピーする

cerebrium init first-project で作ったディレクトリに前回のソースを配置。

2. main.pyの編集

  • Itemのtextはリクエストパラメータで使う。
  • predict()内で推論させてAPIの戻り値としてbase64の音声ファイルを返却する。
import base64
from pydantic import BaseModel
from inference import gen


class Item(BaseModel):
    text: str


def predict(item, run_id, logger):
    item = Item(**item)

    filename = gen(item.text)
    with open(filename, "rb") as file:
        mp3_data = file.read()
    mp3_base64 = base64.b64encode(mp3_data).decode("utf-8")
    return mp3_base64

3. requirements.txt

typing-extensions==4.6.1
falcon==1.2.0
inflect==0.2.5
librosa==0.9.1
Unidecode==0.4.20
torch==2.0.1
torchaudio==2.0.2
torchvision==0.15.2
tensorflow==2.14.0
tensorflow-estimator==2.14.0
typing_extensions==4.6.1
IPython==8.17.2
matplotlib==3.7.2
numpy==1.24.3
pyopenjtalk-prebuilt==0.3.0
decorator==5.1.1

4. config.yaml

トレーニングデータなんかは exclude に登録。

%YAML 1.2
---
hardware: TURING_5000
cpu: 2
min_replicas: 0
log_level: INFO
include: '[./*, main.py, requirements.txt, pkglist.txt, conda_pkglist.txt]'
exclude: '[./.*, ./__*, ./out*, ./save*]'
cooldown: 60
disable_animation: false

5. inference.py

  • mian.pyから呼び出せるように関数化。
  • ローカルのMacではCPUが使われて、Cerebrium上ではcudaが使われるように分岐を追加
  • 保存ファイルはCerebrium上では完全パス指定。
import sys

import IPython.display as ipd
import matplotlib
import matplotlib.pylab as plt

sys.path.append("waveglow/")
import datetime

import numpy as np
import pyopenjtalk
import torch
import torchaudio
from torch import mps

from audio_processing import griffin_lim
from hparams import create_hparams
from layers import STFT, TacotronSTFT
from model import Tacotron2
from text import text_to_sequence
from train import load_model


def gen(text: str):
    torch.nn.Module.dump_patches = False
    hparams = create_hparams()

    checkpoint_path = "checkpoint_12000"
    model = load_model(hparams)
    model.load_state_dict(torch.load(checkpoint_path)["state_dict"])

    cuda_available = torch.cuda.is_available()

    if cuda_available:
        model = model.cuda().eval()
    else:
        model = model.to("cpu").eval()

    waveglow_path = "waveglow_256channels_universal_v5.pt"
    waveglow = torch.load(waveglow_path)["model"]

    if cuda_available:
        waveglow = waveglow.cuda().eval()
    else:
        waveglow = waveglow.to("cpu").eval()

    for k in waveglow.convinv:
        k.float()

    phoneme_text = (
        pyopenjtalk.g2p(text, kana=False).replace("pau", ",").replace(" ", "")
    )
    sequence = np.array(text_to_sequence(phoneme_text, ["basic_cleaners"]))[None, :]

    if cuda_available:
        model = model.float().to("cuda")
        sequence = torch.from_numpy(sequence).to("cuda").long()
    else:
        model = model.to("cpu")
        sequence = torch.from_numpy(sequence).to("cpu").long()

    mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence)

    with torch.no_grad():
        audio = waveglow.infer(mel_outputs_postnet, sigma=0.666)

    if cuda_available:
        audio_tensor = torch.from_numpy(audio[0].data.cpu().numpy())
    else:
        audio_tensor = torch.from_numpy(audio[0].data.cpu().numpy())

    audio_tensor = audio_tensor.unsqueeze(0)

    filename = None  # ファイル名を初期化

    if cuda_available:
        filename = "/worker/app/result_cuda.wav"
    else:
        filename = "./result_cpu.wav"

    torchaudio.save(
        filepath=filename, src=audio_tensor, sample_rate=hparams.sampling_rate
    )
    return filename


if __name__ == "__main__":
    gen("デフォルトテキスト")

デプロイ

次のコマンドでデプロイする

cerebrium deploy test-model --config-file ./config.yaml --disable-syntax-check --disable-predict

成功するとログの最後にリクエストのサンプルコマンドがもらえるから叩いて確認できる

curl -X POST https://run.cerebrium.ai/v3/p-12345hoge/test-model/predict \
     -H 'Content-Type: application/json' \
     -H 'Authorization: hoge \
     --data '{"text": "はじめまして"}'

確認してみる

ChatGPTに依頼してCerebriumにリクエストするHTML生成してもらった。

感想

まず無料で使える時点で神。
サポートのチャットの返信も速くて好印象でした。

以上。