GraphQLのN+1問題をDataloaderで解決

こんにちは。株式会社PRVENT開発部、バックエンドチームの横川です。
今回は、GraphQLで発生するN+1問題と弊社の解決方法を紹介しようと思います。

N+1問題とは?

ループ処理の中で都度SQLが発行され、SQLが大量増殖してしまう問題です。

  • テーブルAからN件のデータ取得を1回
  • テーブルBから、テーブルAの各行に紐づくデータを(1件ずつ)取得を計N回

1回+N回のSQLが発行されるので、N+1問題です。(実行順で言うと1+N)
Nの数が多いとその分処理に時間がかかりますし、DBにも負荷がかかります。

GraphQLで起きるN+1問題

前回の記事のコードをリファクタしつつ、まずはN+1対策せずにお薬データを全件取得するQueryを実装しました。

# schema.graphqls

type Query {
  Medicine(ID: ID!): MedicineDetail!
  Medicines: [MedicineDetail!]! # ←全てのお薬取得
}

# お薬
type MedicineDetail {
  ID: ID!
  name: String!
  medicneCategory: MedicineCategoryDetail!
}

# お薬カテゴリ
type MedicineCategoryDetail {
  ID: ID!
  name: String!
}
// schema.resolvers.go(一部抜粋)

package graph

import (
    // 省略
)

// お薬カテゴリ取得 (お薬N件に対してN回実行)
func (r *medicineDetailResolver) MedicneCategory(ctx context.Context, obj *model.MedicineDetail) (*model.MedicineCategoryDetail, error) {
    // DB接続
    db := infra.DBConnection()
    defer db.Close()
    // sqlログ出力設定
    db = infra.LogConf(db)
    var m model.MedicineCategory
    // お薬カテゴリ取得
    err := db.QueryRow("SELECT * FROM medicine_categories where id = $1", obj.MedicneCategory.ID).Scan(&m.ID, &m.Name)
    if err != nil {
        return nil, err
    }
    return &model.MedicineCategoryDetail{
        ID:   m.ID,
        Name: m.Name,
    }, nil
}

// お薬全件取得
func (r *queryResolver) Medicines(ctx context.Context) ([]*model.MedicineDetail, error) {
    db := infra.DBConnection()
    defer db.Close()
    db = infra.LogConf(db)
    ms := []*model.MedicineDetail{}
    // お薬全件取得
    rows, err := db.Query("SELECT * FROM medicines")
    if err != nil {
        return nil, err
    }
    defer rows.Close()
    for rows.Next() {
        var m model.Medicine
        if err := rows.Scan(&m.ID, &m.Name, &m.MedicineCategoryID); err != nil {
            return nil, err
        }
        ms = append(ms, &model.MedicineDetail{
            ID:   m.ID,
            Name: m.Name,
            MedicneCategory: &model.MedicineCategoryDetail{
                ID: m.MedicineCategoryID,
            },
        })
    }
    return ms, nil
}

↓こんなデータを用意して

medicine_categories(お薬カテゴリテーブル)

ID name
1 頭痛薬
2 風邪薬
3 痒み止め

medicines(お薬テーブル)

ID name medicine_category_id
1 ロキソニン 1
2 バファリン 2
3 ムヒ 3

以下のリクエストを投げると

スクリーンショット 2022-09-29 15.22.16(2).png (104.4 kB)

QueryContext query="SELECT * FROM medicines"
QueryContext args=["2"] query="SELECT * FROM medicine_categories where id = $1"
QueryContext args=["3"] query="SELECT * FROM medicine_categories where id = $1"
QueryContext args=["1"] query="SELECT * FROM medicine_categories where id = $1"

お薬3件取得に対して3回お薬カテゴリ取得のSQLが発行されます。

REST APIなどで発生するN+1問題だとORMのメソッドで関連データを先読みしたり、テーブル結合で解決するかと思います。
ですが、GraphQLだと取得するデータをクライアント側で選択できることが強みなので関連データの先読みはしたくないです。
このような問題を弊社ではDataloaderライブラリを使うことで解決しています。

Dataloaderとは

データ取得をバッチ化するためのライブラリです。
先ほどお薬カテゴリの取得が3回発生しましたが、Dataloaderはそれを順番に実行するのではなく一定時間待機してその間に発生したリクエスト(取得したいデータのキー)を蓄積します。
その後、溜まったキーをSQLのin句などでまとめて取得できる仕組みです。
オリジナルは Facebook社が開発する graphql/dataloader で、弊社ではGo用のDataloaderライブラリ vektah/dataloaden を使用しています。

Dataloader導入

githubを参考にしつつDataloaderを導入します。
インストール後dataloaderディレクトリを作成し以下のようなdataloader.goファイルを作成しておきます。

// dataloader/dataloader.go

package dataloader

その後、dataloaderディレクトで下記コマンドを実行しソースを自動生成します。

go run github.com/vektah/dataloaden ${Loader} ${取得するデータのキーの型} ${返却する型へのパス}/${返却する型}

go run github.com/vektah/dataloaden MedicineCategoryLoader string *gqlgen-medicines/graph/model.MedicineCategoryDetail

すると、dataloaderディレクトリ配下にファイルが自動生成されるので、それを使うためのミドルウェアを作成します。

// medicinecategoryloader_gen.go(自動生成ファイル)

/ Code generated by github.com/vektah/dataloaden, DO NOT EDIT.

package dataloader

import (
    "sync"
    "time"

    "gqlgen-medicines/graph/model"
)

// MedicineCategoryLoaderConfig captures the config to create a new MedicineCategoryLoader
type MedicineCategoryLoaderConfig struct {
    // Fetch is a method that provides the data for the loader
    Fetch func(keys []string) ([]*model.MedicineCategoryDetail, []error)

    // Wait is how long wait before sending a batch
    Wait time.Duration

    // MaxBatch will limit the maximum number of keys to send in one batch, 0 = not limit
    MaxBatch int
}

// NewMedicineCategoryLoader creates a new MedicineCategoryLoader given a fetch, wait, and maxBatch
func NewMedicineCategoryLoader(config MedicineCategoryLoaderConfig) *MedicineCategoryLoader {
    return &MedicineCategoryLoader{
        fetch:    config.Fetch,
        wait:     config.Wait,
        maxBatch: config.MaxBatch,
    }
}

// MedicineCategoryLoader batches and caches requests
type MedicineCategoryLoader struct {
    // this method provides the data for the loader
    fetch func(keys []string) ([]*model.MedicineCategoryDetail, []error)

    // how long to done before sending a batch
    wait time.Duration

    // this will limit the maximum number of keys to send in one batch, 0 = no limit
    maxBatch int

    // INTERNAL

    // lazily created cache
    cache map[string]*model.MedicineCategoryDetail

    // the current batch. keys will continue to be collected until timeout is hit,
    // then everything will be sent to the fetch method and out to the listeners
    batch *medicineCategoryLoaderBatch

    // mutex to prevent races
    mu sync.Mutex
}

type medicineCategoryLoaderBatch struct {
    keys    []string
    data    []*model.MedicineCategoryDetail
    error   []error
    closing bool
    done    chan struct{}
}

// Load a MedicineCategoryDetail by key, batching and caching will be applied automatically
func (l *MedicineCategoryLoader) Load(key string) (*model.MedicineCategoryDetail, error) {
    return l.LoadThunk(key)()
}


// 長いので省略
// middleware/dataloader.go

package middleware

import (
    // 省略
)

const LoaderKey = "Loader"

func LoaderMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        // 自動生成されたDataloaderのメソッドを使ってLoader作成
        loader := dataloader.NewMedicineCategoryLoader(dataloader.MedicineCategoryLoaderConfig{
            MaxBatch: 100, // 最大100リクエストバッチ化する
            Wait:     10 * time.Millisecond, // 10ms待機する
            // 最大10ミリ秒待機した結果 or 100リクエスト分のお薬カテゴリIDのスライスが keys という名前で渡ってくる。
            Fetch: func(keys []string) ([]*model.MedicineCategoryDetail, []error) {
                var errors = []error{}
                db := infra.DBConnection()
                defer db.Close()
                db = infra.LogConf(db)

                // お薬カテゴリをまとめて取得する
                rows, err := db.Query("SELECT * FROM medicine_categories WHERE id = ANY($1)", pq.Array(keys))
                if err != nil {
                    errors = append(errors, err)
                    return nil, errors
                }
                defer rows.Close()
                ms := map[string]*model.MedicineCategoryDetail{}
                // idごとに取得したお薬カテゴリをマッピング
                for rows.Next() {
                    var m model.MedicineCategory
                    if err := rows.Scan(&m.ID, &m.Name); err != nil {
                        errors = append(errors, err)
                        return nil, errors
                    }
                    ms[m.ID] = &model.MedicineCategoryDetail{
                        ID:   m.ID,
                        Name: m.Name,
                    }
                }
                // 渡ってきたkeysの順番にお薬カテゴリを並び替え
                result := make([]*model.MedicineCategoryDetail, len(keys))
                for i, key := range keys {
                    result[i] = ms[key]
                }
                return result, nil
            },
        })
        // contextにloaderをセット
        ctx := context.WithValue(r.Context(), LoaderKey, loader)
        r = r.WithContext(ctx)
        next.ServeHTTP(w, r)
    })
}

// loaderを使うための関数(リゾルバで呼び出す)
func MedicineCategory(ctx context.Context, id string) (*model.MedicineCategoryDetail, error) {
    v := ctx.Value(LoaderKey)
    loader, ok := v.(*dataloader.MedicineCategoryLoader)

    if !ok {
        return nil, errors.New("failed to get loader from current context")
    }

    return loader.Load(id)
}
// server.go

package main

import (
    // 省略
)

const defaultPort = "8080"

func main() {
    port := os.Getenv("PORT")
    if port == "" {
        port = defaultPort
    }

    srv := handler.NewDefaultServer(generated.NewExecutableSchema(generated.Config{Resolvers: &graph.Resolver{}}))

    http.Handle("/", playground.Handler("GraphQL playground", "/query"))
    http.Handle("/query", middleware.LoaderMiddleware(srv)) // ミドルウェアでハンドラーをラップ
    log.Printf("connect to http://localhost:%s/ for GraphQL playground", port)
    log.Fatal(http.ListenAndServe(":"+port, nil))
}

ゾルバもDataloaderを使うように修正します。

// schema.resolvers.go(一部抜粋)

package graph

import (
    // 省略
)

// お薬カテゴリ取得 (お薬N件に対してN回実行)
func (r *medicineDetailResolver) MedicneCategory(ctx context.Context, obj *model.MedicineDetail) (*model.MedicineCategoryDetail, error) {
     // 処理はLoaderに全て寄せたのでリゾルバでは呼ぶだけ
    return middleware.MedicineCategory(ctx, obj.ID)
}

以上で導入は完了です。
ちなみにお薬カテゴリの取得が発生した場合の挙動が以下です。

  1. 10ms待機してIDを蓄積
  2. 溜まったIDをwhere = ANY($1)に入れてお薬カテゴリをまとめて取得
  3. それぞれのお薬に紐付いて返却

リクエストを投げた結果がこちらです
スクリーンショット 2022-09-29 15.22.16(2).png (104.4 kB)

QueryContext query="SELECT * FROM medicines"
QueryContext args=["{\"3\",\"1\",\"2\"}"] query="SELECT * FROM medicine_categories WHERE id = ANY($1)"

お薬カテゴリ取得のSQLが1回になりました🙌

最後に

今回は、GraphQLで発生するN+1問題と、Dataloaderを使った解決方法を紹介しました。
自分がDataloaderを導入した1年ほど前は、gqlgen公式のdataloaderのページで今回紹介したvektah/dataloadenの使用方法が書かれていましたが、いつの間にか graph-gophers/dataloaderに変わってたので、こちらも触ってみようと思います。