FastAPI+SQLAlchemyをpytestでテスト

Posted by rhoboro on 2021-02-27

FastAPIでSQLAlchemyを利用するサンプルコードが公式ドキュメントにあります。 サンプルコードのmain.pyの重要な部分を抜粋するとこんな感じです。

app = FastAPI()

# Dependency
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

@app.get("/users/", response_model=List[schemas.User])
def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
    users = crud.get_users(db, skip=skip, limit=limit)
    return users

SessionLocal()で作成したセッションオブジェクトをDIによりリクエストハンドラで利用できるようにしています。

この記事ではpytestを使って、この /users/ エンドポイントをテストすることを考えます。 このエンドポイントの処理はDBの状態によってレスポンスが変わるため、テストでは次のフローを行うことになります。

  1. DBの状態を意図した状態にする
  2. このエンドポイントにリクエストを投げる
  3. レスポンスが意図したものになっていることを確認する

なお、今回のコードはrhoboro/fastapi-sqlalchemyリポジトリに置いています。

FastAPI のテストクライアント

FastAPIでエンドポイントごとのテストを行う場合は、公式ドキュメントで紹介されているTestClientを利用します。 TestClientを利用するにはrequestsが必要なのでpytestと一緒にインストールし、ついでにテストコードを格納するパッケージも用意します。

$ pip install pytest requests
$ mkdir sql_app/tests
$ touch sql_app/tests/__init__.py

それでは、 sql_app/tests/test_main.py を用意し、このエンドポイントのテストを書いていきます。

from fastapi.testclient import TestClient

from sql_app.main import app


client = TestClient(app)


def test_read_users_empty():
    response = client.get("/users/")
    assert response.status_code == 200
    assert response.json() == []

これで最初のテストが完成しました。 TestClientのインスタンスが持つget()メソッドを使うと、このエンドポイントに実際にリクエストが投げられ、レスポンスが返ってきます。

pytestコマンドでこのテストケースを実行してみましょう。 まだユーザーを作成していなければDBは空の状態なので、このテストはパスするはずです。

$ pytest -v
=========================================================================================== test session starts ============================================================================================
platform darwin -- Python 3.9.1, pytest-6.2.2, py-1.10.0, pluggy-0.13.1 -- /Users/rhoboro/go/src/github.com/rhoboro/fastapi-sqlalchemy/venv/bin/python3
cachedir: .pytest_cache
rootdir: /Users/rhoboro/go/src/github.com/rhoboro/fastapi-sqlalchemy
collected 1 item

sql_app/tests/test_main.py::test_read_users_empty PASSED

これで先ほどあげたフローのうち、2.と3.の方法がわかりました。

  1. DBの状態を意図した状態にする
  2. このエンドポイントにリクエストを投げる
  3. レスポンスが意図したものになっていることを確認する

それでは、ここから1.の部分について深掘りしていきます。

pytestで前処理と後処理を記述する

テストコードを書く際、「1. DBの状態を意図した状態にする」は一般に「前処理」と呼ばれます。 pytestでこの「前処理」やテストを行った後に実行される「後処理」を行うには、fixtureと呼ばれる機能を使います。

それでは sql_app/tests/conftest.py を用意して、fixtureを書いていきます。 シンプルなfixtureの例は次のようになります。 fixtureの詳細は公式ドキュメントを確認してください。

import pytest

@pytest.fixture(scope="function")
def test_db():
    print("SetUp")

    # ここより上に前処理を書く
    yield 1  # ここでテストが実行される
    # ここより下に後処理を書く

    print("TearDown")

この fixture を利用するには、テストケースの引数にこのfixture名を指定するだけです。 ここで実引数として渡されてくる値は、fixtureでyieldに渡した値になります。

def test_read_users_empty(test_db):
    print(test_db)  # => 1
    ...

この状態でテストを実行すると、テストケースの実行前後で前処理、後処理が行われるので 実際に試してみてください。 pytest は標準出力を常にキャプチャして失敗時のみ出力するため、テストケースに assert False を入れるなどして失敗するようにしておくと確認しやすいです。

テスト用のDBを用意する

ここからは実際にテスト用のDBを用意し、テストしたい処理を実行する前にデータを投入していきます。 先にテストケースのコードを記載しておきます。

テストケースには test_db という名前のfixtureで作成したSQLAlchemyのセッションオブジェクトを渡しています。 そして、そのセッションオブジェクトを利用してエンドポイントへのリクエストを実行する前にDBにデータを追加し、追加したデータとレスポンスが一致するか確認しています。

from fastapi.testclient import TestClient

from sql_app.main import app
from sql_app.models import User

client = TestClient(app)

...

def test_read_users(test_db):
    # テスト用のデータを用意
    user1 = User(email="[email protected]", hashed_password="unsecurepass")
    user2 = User(email="[email protected]", hashed_password="unsecurepass")
    test_db.add_all([user1, user2])
    test_db.flush()
    test_db.commit()

    # テスト対象の処理を実行
    response = client.get("/users/")

    # 処理の結果の確認
    assert response.status_code == 200
    assert response.json() == [
        {"email": "[email protected]", "id": 1, "is_active": True, "items": []},
        {"email": "[email protected]", "id": 2, "is_active": True, "items": []},
    ]

セッションオブジェクトを作成している test_db の中身は次のようになっています。

まず、テストDBでの永続化は必要ありません。 そのためSessionクラスのサブクラスTestingSessionを用意し、commit()呼び出し時に永続化しないようにしておきます(参考)。 そして、リクエストハンドラ内でセッションインスタンスを取得するためのget_db()を上書きし、TestingSessionクラスが利用されるようにします。

import pytest
from sqlalchemy import create_engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm.session import close_all_sessions

from sql_app.main import app, get_db
from sql_app.models import Base


class TestingSession(Session):
    def commit(self):
        # テストなので永続化しない
        self.flush()
        self.expire_all()


@pytest.fixture(scope="function")
def test_db():
    engine = create_engine("sqlite:///./sql_app_test.db", connect_args={"check_same_thread": False})
    Base.metadata.create_all(bind=engine)

    TestSessionLocal = sessionmaker(
        class_=TestingSession, autocommit=False, autoflush=False, bind=engine
    )

    db = TestSessionLocal()

    # sql_app/main.py の get_db() を差し替える
    # https://fastapi.tiangolo.com/advanced/testing-dependencies/
    def get_db_for_testing():
        try:
            yield db
            db.commit()
        except SQLAlchemyError as e:
            assert e is not None
            db.rollback()

    app.dependency_overrides[get_db] = get_db_for_testing

    # テストケース実行
    yield db

    # 後処理
    db.rollback()
    close_all_sessions()
    engine.dispose()

これでテストを実行してみると、テストは成功するはずです。

$ pytest -vv                                                                                                       [main:fastapi-sqlalchemy]
=========================================================================================== test session starts ============================================================================================
platform darwin -- Python 3.9.1, pytest-6.2.2, py-1.10.0, pluggy-0.13.1 -- /Users/rhoboro/go/src/github.com/rhoboro/fastapi-sqlalchemy/venv/bin/python3
cachedir: .pytest_cache
rootdir: /Users/rhoboro/go/src/github.com/rhoboro/fastapi-sqlalchemy
collected 1 item

sql_app/tests/test_main.py::test_read_users_empty PASSED                                                                                                                                              [50%]
sql_app/tests/test_main.py::test_read_users PASSED                                                                                                                                                   [100%]

============================================================================================ 1 passed in 1.39s =============================================================================================

query_propertyを利用している場合

SQLAlchemy ORM を利用している場合は、次のようにscoped_sessionとquery_propertyを利用しているかもしれません。

from sqlalchemy.orm import scoped_session, sessionmaker


SessionLocal = scoped_session(sessionmaker(autocommit=False, autoflush=False, bind=engine))

Base = declarative_base()
Base.query = SessionLocal.query_property()

実は、この場合さきほどのテストは追加したデータがうまく読み取れず空のデータが返ってくるためテストは失敗します。 これは、テストケースの実行中に使っているセッションオブジェクトと、リクエストハンドラ内で使っているセッションオブジェクトが違うことに起因しています。 TestingSessionではDBへの変更が永続化されないようにしているため、テストケースのセッションオブジェクトでのDB状態の変更を、リクエストハンドラが利用するセッションオブジェクトから知る方法がないためです。

これを解決するには、テストケースの実行中は常に同一のセッションオブジェクトが利用されるようにすればよいです。 具体的には次のようなコードにします。

@pytest.fixture(scope="function")
def test_db():
    ...
    # テストケースごとに異なるスコープを持たせる
    function_scope = uuid4().hex
    TestSessionLocal = scoped_session(
        sessionmaker(class_=TestingSession, autocommit=False, autoflush=False, bind=engine),
        scopefunc=lambda: function_scope,
    )
    Base.query = TestSessionLocal.query_property()

    ...
    # 後処理
    TestSessionLocal().rollback()
    TestSessionLocal.remove()
    close_all_sessions()
    engine.dispose()

scoped_sessionでは、同じスコープ内であれば同じセッションオブジェクトが返されます。 デフォルトではこのスコープがスレッドローカルになっていますが、テストケースの実行とリクエストハンドラの処理が別スレッドで行われているため、それぞれで違うセッションオブジェクトになっています。

scoped_sessionのスコープを変更するには scopefunc 引数が使えます。 この引数にハッシュ化可能なオブジェクトを返す関数を指定すると、そのハッシュ値で同一スコープかどうかの判定が行われます。 上記のコードでは、test_dbの実行ごと、つまりテストケースごとに同一のスコープとなるようにscopefuncを設定しています。 そのため、 query_property を利用していても利用可能なテストコードになっています。

おまけ

ここまではsqlite3を利用していましたが、業務ではPostgreSQLやMySQLなどを使うことも多いです。 これらを利用しているとテスト用のデータベースを用意するのが地味に面倒なため、わたしは次のようにテスト実行時にtestデータベースを作成し、終了時に削除しています。 なお、アプリケーション本体ではデフォルトのデータベースを使い、テスト時には test_db 内のcreate_engine()に渡す際に /test を追加し、テスト時のみtestデータベースを参照するようにしています。(参考

@pytest.fixture(scope="session", autouse=True)
def test_database() -> Generator:
    engine = create_engine(DB_URI + "/test")
    conn = engine.connect()
    # トランザクションを一度終了させる
    conn.execute("commit")
    try:
        conn.execute("drop database test")
    except SQLAlchemyError as e:
        pass
    finally:
        conn.close()

    conn = engine.connect()
    # トランザクションを一度終了させる
    conn.execute("commit")
    conn.execute("create database test")
    conn.close()

    yield

    conn = engine.connect()
    # トランザクションを一度終了させる
    conn.execute("commit")
    conn.execute("drop database test")
    conn.close()

後処理がうまくできていないと最後の drop database 実行時にエラーとなるので、エラーが起きた際はテスト後の後処理を見直してみてください。