ITEEK株式会社のロゴ

ITEEK

ブログ

Blogs

実践Pytest入門(PyCon JP 2024)

2025-04-04

実践Pytest入門(PyCon JP 2024)

概要

レガシーコードは、テストがなく修正や拡張が難しいコードを指します。

このようなコードは技術的負債となり、長期的には開発速度の低下や保守コストの増大につながります。


この問題を解決するために、Pytestを活用した単体テストの実践が効果的です。

Pytestは、Pythonのテストフレームワークで、テストコードの作成と実行を容易にします。


実際にどのようにテストを書くことができるか見ていきましょう。


シンプルな関数のテスト

身長と体重からBMIを計算する例


========================

def calc_bmi(*, height: float, weight: float) -> float:

if (height <= 0) or (weight <= 0):

raise ValueError("Height and weight must br greater then 0")

return weight / (height ** 2)

========================


多数の分岐条件を持つケースのテスト

所得額から給与所得控除額を計算する例


========================

import pytest


def calc_exemption_amount(*, income: int) -> int:

if income < 0:

raise ValueError("Income must be positive.")

if income <= 1_625_000:

return 550_000

if 1_625_000 < income <= 1_800_000:

return int(income * 0.4) - 100_000

if 1_800_000 < income <= 3_600_000:

return int(income * 0.3) - 800_000

if 3_600_000 < income <= 6_600_000:

return int(income * 0.2) - 440_000

if 6_600_000 < income <= 8_500_000:

return int(income * 0.1) - 1_100_000

return 1_950_000

class TestCalcExemptionAmount:

@pytest.mark.parametrize(

("income", "expected"),

[

(1_625_000, 550_000),

(1_625_003, 550_001),

(1_800_000, 620_000),

(1_800_001, 1_340_000),

(3_600_000, 1_880_000),

(3_600_001, 1_160_000),

(6_600_000, 1_760_000),

(6_600_001, 1_760_001),

(8_500_000, 1_950_000),

(8_500_001, 1_950_001),

]

)

def test_income_and_exemption(self, income, expected):

assert calc_exemption_amount(income) == expected

========================


OS環境変数に依存するケースのテスト

OS環境変数に設定されているAPIのURLを取得する例


========================

"""tests/env/conftest.py"""

import pytest


@pytest.fixture

def mock_env_api_url(monkeypatch):

monkeypatch.setenv("API_URL", "http://localhost:8080")

"""tests/env/test_env.py"""

import pytest


os.environ["API_URL"] = "https://production.example.com"


def get_api_url() -> str | None:

return os.getenv("API_URL")

class TestGetAPIURL:

def test_get_api_url(self, mock_env_api_url):

assert get_api_url() == "http://localhost:8080"

========================


システム日時に依存するケースのテスト

システム時刻が営業時間中かどうかを取得する例


========================

import pytest

from freezegun import freeze_time

from datetime import datetime, time


def is_in_business() -> bool:

now = datetime.now()

if now.weekday() in (5, 6):

return False

if time(9, 0, 0) <= now.time() <= time(17, 0, 0):

return True

return False


class TestIsInBusiness:

@pytest.mark.parametrize(

("now", "expected"),

[

("2024-09-27 08:59:59", False),

("2024-09-27 09:00:00", True),

("2024-09-27 17:00:00", True),

("2024-09-27 17:00:01", False),

("2024-09-28 12:00:00", False),

("2024-09-29 12:00:00", False),

]

)

def test_is_in_business(self, now, expected):

with freeze_time(now):

assert is_in_business() == expected

========================


ファイル入出力のテスト

ファイルのテキストを読み込み、文字列に含まれる猫を犬に置換してファイル出力する例


========================

import pytest

from pathlib import Path

import re


def cat_to_dog(*, input_path: Path, output_path: Path) -> None:

input_text = input_path.read_text()

output_text = re.sub("猫", "犬", input_text)

output_path.write_text(output_text)

class TestCatToDog:

def test_normal(self, tmp_path):

intput_path = tmp_path / "input.txt"

output_path = tmp_path / "output.txt"

input_path.write_text("吾輩は猫である。名前はまだない。")

cat_to_dog(input_path=input_path, output_path=output_path)

assert output_path.read_text() == "吾輩は犬である。名前はまだない。"

========================


外部APIに依存したケースのテスト

郵便番号から住所を取得する外部APIを使用する例


========================

"""tests/api/conftest.py"""

import re

import pytest

import requests

from dataclasses import dataclass


ResultsType = list[dict[str, str]] | None


@dataclass

class MockResponse:

message: str | None = None

results: ResultsType = None

def raise_for_status(self) -> None:

return None

def json(self) -> dict[str, str | ResultsType]:

return {"message": self.message, "results": self.results}

@pytest.fixture

def mock_response(monkeypatch) -> None:

def mock_get(*args, **kwargs) -> MockResponse:

zipcode = kwargs["params"]["zipcode"]

if zipcode == "0000000":

return MockResponse()

elif re.match("^[0-9]{7}$", zipcode):

return MockResponse(results=[{"address1": "都道府県", "address2": "市区町村", "address3": "番地"}])

else:

return MockResponse(message="郵便番号の桁数や値が不正です")

monkeypatch.setattr(requests, "get", mock_get)

"""tests/api/test_api.py"""

import pytest

import requests


ENDPOINT = "http://zipcloud.ibsnet.co.jp/api/search"


def get_address(*, zipcode: str) -> str | None:

response = requests.get(ENDPOINT, params={"zipcode": zipcode}, timeout=5)

response.raise_for_status()

data = response.json()

if (message := data["message"]) is not None:

raise ValueError(message)

if (results := data["results"]) is None:

return None

return f"{results[0]['address1']} {results[0]['address2']} {results[0]['address3']}"

class TestGetAddress:

@pytest.mark.parametrize(

("zipcode", "expected"),

[

("0000000", None),

("1111111", "都道府県 市区町村 番地")

]

)

def test_get_address(self, mock_response, zipcode, expected):

assert get_address(zipcode=zipcode) == expected

@pytest.mark.parametrize("zipcode", ["1", "12345678", "dummy"])

def test_invalid_zipcode(self, mock_response, zipcode):

with pytest.raises(ValueError) as e:

get_address(zipcode=zipcode)

assert str(e.value) == "郵便番号の桁数や値が不正です"

========================


DB接続を伴うケースのテスト

DBのユーザを追加取得するテスト


========================

"""src/db.py"""

import os

from datetime import date


from sqlalchemy import String, create_engine, insert, select

from sqlalchemy.engine import URL

from sqlalchemy.orm import declarative_base, Mapped, mapped_column, scoped_session, sessionmaker

from sqlalchemy.orm.scoping import scoped_session as scoping_scoped_session


Base = declarative_base()


class User(Base):

__tablename__ = "user"

id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)

name: Mapped[str] = mapped_column(String(12), nullable=False)

birthday: Mapped[date] = mapped_column(nullable=False)

def __eq__(self, other) -> bool:

return (self.id == other.id) and (self.name == other.name) and (self.birthday == other.birthday)

DATABASE_CONFIG = {

"drivername": "mysql+pymysql",

"username": os.environ["MYSQL_USER"],

"password": os.environ["MYSQL_PASSWORD"],

"host": os.environ["MYSQL_HOST"],

"port": os.environ["MYSQL_POST"],

"database": os.environ["MYSQL_DATABASE"],

"query": {"charset": "utf8"}}

}

engine = create_engine(URL.create(**DATABASE_CONFIG), echo=False)

Base.metadata.create_all(engine)

Session = scoped_session(sessionmaker(engine))


def get_user(db_session: scoping_scoped_session, user_id: int) -> User | None:

stmt = select(User).where(User.id == user_id)

return db_session.scalar(stmt)

def add_user(db_session: scoping_scoped_session, user: User) -> int:

with db_session() as session:

session.add(user)

session.commit()

return user.id

"""tests/api/conftest.py"""

import os

from operator import itemgetter

import pytest

from pytest_mysql import factories

from sqlalchemy import create_engine

from sqlalchemy.engine import URL

from sqlalchemy.orm import scoped_session, sessionmaker

from sqlalchemy.pool import NullPool

from db import Base


host, port, user, passwd = itemgetter("MYSQL_HOST", "MYSQL_PORT", "MYSQL_USER", "MYSQL_PASSWORD")(os.environ)

mysql_noproc = factories.mysql_noproc(host=host, port=port, user=user)

mysql_fixture = factories.mysql("mysql_noproc", password=password, dbname="test")


@pytest.fixture

def test_session(mysql_fixture):

url = URL.create(drivername="mysql+pymysql", username=user, password=passwd, host=host, port=port, database="test", query={"charset": "utf8"})

engine = create_engine(url, echo=False, poolclass=NullPool)

Base.metadata.create_all(engine)

session = scoped_session(sessionmaker(autocommit=False, autoflush=False, bind=engine))

try:

yield session

except Exception:

session.rollback()

else:

session.commit()

finally:

session.close()

Base.metadata.drop_all(engine)

"""tests/api/user.py"""

from datetime import date

from db import User, get_user, add_user


class TestUser:

def test_get_user(self, test_session):

user = User(name="sato", birthday=date(1999, 12, 31))

with test_session() as session:

session.add(user)

session.commit()

user_id = user.id

user_ = get_user(db_session=test_session, user_id=user_id)

assert user == user_

def test_no_user(self, test_session):

assert get_user(db_session=test_session, user_id=1) is None

def test_add_user(self, test_session):

user = User(name="sato", birthday=date(1999, 12, 31))

user_id = add_user(db_session=test_session, user=user)

assert user == test_session.scalar(select(User).where(User.id == user_id))

def test_duplicate_pk(self, test_session):

try:

for _ in range(2):

add_user(db_session=test_session, user=User(id=1, name="sato", birthday=date(1999, 12, 31)))

except sqlalchemy.exc.IntegrityError as e:

assert isinstance(e.orig, pymysql.err.IntegrityError)

assert e.orig.args[0] == 1062

========================


まとめ

これらのパターンを理解し、適切にテストを実装することで、コードの品質を維持しやすくなります。

さらに、テスト駆動開発(TDD)を採用することで、テストとリファクタリングを同時に行うことができます。

TDDのサイクルは以下のようにすると良いです。

  • 失敗するテストを書く(レッド)
  • テストを通すコードを書く(グリーン)
  • コードをリファクタリングする


このアプローチにより、テストが容易なコードが自然と生まれ、結果的に開発スピードの向上につながります。

Pytestを活用し、適切なテストを書くことで、レガシーコード化を防ぎ、長期的に保守性の高いアプリケーションを開発していきましょう。

お問い合わせ

Contact

案件のご相談やお見積りなど、お気軽にご連絡ください。