GraphQLのN+1問題をDataloaderで解決
こんにちは。株式会社PRVENT開発部、バックエンドチームの横川です。
今回は、GraphQLで発生するN+1問題と弊社の解決方法を紹介しようと思います。
N+1問題とは?
ループ処理の中で都度SQLが発行され、SQLが大量増殖してしまう問題です。
- テーブルAからN件のデータ取得を1回
- テーブルBから、テーブルAの各行に紐づくデータを(1件ずつ)取得を計N回
1回+N回のSQLが発行されるので、N+1問題です。(実行順で言うと1+N)
Nの数が多いとその分処理に時間がかかりますし、DBにも負荷がかかります。
GraphQLで起きるN+1問題
前回の記事のコードをリファクタしつつ、まずはN+1対策せずにお薬データを全件取得するQueryを実装しました。
# schema.graphqls type Query { Medicine(ID: ID!): MedicineDetail! Medicines: [MedicineDetail!]! # ←全てのお薬取得 } # お薬 type MedicineDetail { ID: ID! name: String! medicneCategory: MedicineCategoryDetail! } # お薬カテゴリ type MedicineCategoryDetail { ID: ID! name: String! }
// schema.resolvers.go(一部抜粋) package graph import ( // 省略 ) // お薬カテゴリ取得 (お薬N件に対してN回実行) func (r *medicineDetailResolver) MedicneCategory(ctx context.Context, obj *model.MedicineDetail) (*model.MedicineCategoryDetail, error) { // DB接続 db := infra.DBConnection() defer db.Close() // sqlログ出力設定 db = infra.LogConf(db) var m model.MedicineCategory // お薬カテゴリ取得 err := db.QueryRow("SELECT * FROM medicine_categories where id = $1", obj.MedicneCategory.ID).Scan(&m.ID, &m.Name) if err != nil { return nil, err } return &model.MedicineCategoryDetail{ ID: m.ID, Name: m.Name, }, nil } // お薬全件取得 func (r *queryResolver) Medicines(ctx context.Context) ([]*model.MedicineDetail, error) { db := infra.DBConnection() defer db.Close() db = infra.LogConf(db) ms := []*model.MedicineDetail{} // お薬全件取得 rows, err := db.Query("SELECT * FROM medicines") if err != nil { return nil, err } defer rows.Close() for rows.Next() { var m model.Medicine if err := rows.Scan(&m.ID, &m.Name, &m.MedicineCategoryID); err != nil { return nil, err } ms = append(ms, &model.MedicineDetail{ ID: m.ID, Name: m.Name, MedicneCategory: &model.MedicineCategoryDetail{ ID: m.MedicineCategoryID, }, }) } return ms, nil }
↓こんなデータを用意して
medicine_categories(お薬カテゴリテーブル)
ID | name |
---|---|
1 | 頭痛薬 |
2 | 風邪薬 |
3 | 痒み止め |
medicines(お薬テーブル)
ID | name | medicine_category_id |
---|---|---|
1 | ロキソニン | 1 |
2 | バファリン | 2 |
3 | ムヒ | 3 |
以下のリクエストを投げると
QueryContext query="SELECT * FROM medicines" QueryContext args=["2"] query="SELECT * FROM medicine_categories where id = $1" QueryContext args=["3"] query="SELECT * FROM medicine_categories where id = $1" QueryContext args=["1"] query="SELECT * FROM medicine_categories where id = $1"
お薬3件取得に対して3回お薬カテゴリ取得のSQLが発行されます。
REST APIなどで発生するN+1問題だとORMのメソッドで関連データを先読みしたり、テーブル結合で解決するかと思います。
ですが、GraphQLだと取得するデータをクライアント側で選択できることが強みなので関連データの先読みはしたくないです。
このような問題を弊社ではDataloaderライブラリを使うことで解決しています。
Dataloaderとは
データ取得をバッチ化するためのライブラリです。
先ほどお薬カテゴリの取得が3回発生しましたが、Dataloaderはそれを順番に実行するのではなく一定時間待機してその間に発生したリクエスト(取得したいデータのキー)を蓄積します。
その後、溜まったキーをSQLのin句などでまとめて取得できる仕組みです。
オリジナルは Facebook社が開発する graphql/dataloader で、弊社ではGo用のDataloaderライブラリ vektah/dataloaden を使用しています。
Dataloader導入
githubを参考にしつつDataloaderを導入します。
インストール後dataloader
ディレクトリを作成し以下のようなdataloader.go
ファイルを作成しておきます。
// dataloader/dataloader.go package dataloader
その後、dataloaderディレクトで下記コマンドを実行しソースを自動生成します。
go run github.com/vektah/dataloaden ${Loader名} ${取得するデータのキーの型} ${返却する型へのパス}/${返却する型} go run github.com/vektah/dataloaden MedicineCategoryLoader string *gqlgen-medicines/graph/model.MedicineCategoryDetail
すると、dataloader
ディレクトリ配下にファイルが自動生成されるので、それを使うためのミドルウェアを作成します。
// medicinecategoryloader_gen.go(自動生成ファイル) / Code generated by github.com/vektah/dataloaden, DO NOT EDIT. package dataloader import ( "sync" "time" "gqlgen-medicines/graph/model" ) // MedicineCategoryLoaderConfig captures the config to create a new MedicineCategoryLoader type MedicineCategoryLoaderConfig struct { // Fetch is a method that provides the data for the loader Fetch func(keys []string) ([]*model.MedicineCategoryDetail, []error) // Wait is how long wait before sending a batch Wait time.Duration // MaxBatch will limit the maximum number of keys to send in one batch, 0 = not limit MaxBatch int } // NewMedicineCategoryLoader creates a new MedicineCategoryLoader given a fetch, wait, and maxBatch func NewMedicineCategoryLoader(config MedicineCategoryLoaderConfig) *MedicineCategoryLoader { return &MedicineCategoryLoader{ fetch: config.Fetch, wait: config.Wait, maxBatch: config.MaxBatch, } } // MedicineCategoryLoader batches and caches requests type MedicineCategoryLoader struct { // this method provides the data for the loader fetch func(keys []string) ([]*model.MedicineCategoryDetail, []error) // how long to done before sending a batch wait time.Duration // this will limit the maximum number of keys to send in one batch, 0 = no limit maxBatch int // INTERNAL // lazily created cache cache map[string]*model.MedicineCategoryDetail // the current batch. keys will continue to be collected until timeout is hit, // then everything will be sent to the fetch method and out to the listeners batch *medicineCategoryLoaderBatch // mutex to prevent races mu sync.Mutex } type medicineCategoryLoaderBatch struct { keys []string data []*model.MedicineCategoryDetail error []error closing bool done chan struct{} } // Load a MedicineCategoryDetail by key, batching and caching will be applied automatically func (l *MedicineCategoryLoader) Load(key string) (*model.MedicineCategoryDetail, error) { return l.LoadThunk(key)() } // 長いので省略
// middleware/dataloader.go package middleware import ( // 省略 ) const LoaderKey = "Loader" func LoaderMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // 自動生成されたDataloaderのメソッドを使ってLoader作成 loader := dataloader.NewMedicineCategoryLoader(dataloader.MedicineCategoryLoaderConfig{ MaxBatch: 100, // 最大100リクエストバッチ化する Wait: 10 * time.Millisecond, // 10ms待機する // 最大10ミリ秒待機した結果 or 100リクエスト分のお薬カテゴリIDのスライスが keys という名前で渡ってくる。 Fetch: func(keys []string) ([]*model.MedicineCategoryDetail, []error) { var errors = []error{} db := infra.DBConnection() defer db.Close() db = infra.LogConf(db) // お薬カテゴリをまとめて取得する rows, err := db.Query("SELECT * FROM medicine_categories WHERE id = ANY($1)", pq.Array(keys)) if err != nil { errors = append(errors, err) return nil, errors } defer rows.Close() ms := map[string]*model.MedicineCategoryDetail{} // idごとに取得したお薬カテゴリをマッピング for rows.Next() { var m model.MedicineCategory if err := rows.Scan(&m.ID, &m.Name); err != nil { errors = append(errors, err) return nil, errors } ms[m.ID] = &model.MedicineCategoryDetail{ ID: m.ID, Name: m.Name, } } // 渡ってきたkeysの順番にお薬カテゴリを並び替え result := make([]*model.MedicineCategoryDetail, len(keys)) for i, key := range keys { result[i] = ms[key] } return result, nil }, }) // contextにloaderをセット ctx := context.WithValue(r.Context(), LoaderKey, loader) r = r.WithContext(ctx) next.ServeHTTP(w, r) }) } // loaderを使うための関数(リゾルバで呼び出す) func MedicineCategory(ctx context.Context, id string) (*model.MedicineCategoryDetail, error) { v := ctx.Value(LoaderKey) loader, ok := v.(*dataloader.MedicineCategoryLoader) if !ok { return nil, errors.New("failed to get loader from current context") } return loader.Load(id) }
// server.go package main import ( // 省略 ) const defaultPort = "8080" func main() { port := os.Getenv("PORT") if port == "" { port = defaultPort } srv := handler.NewDefaultServer(generated.NewExecutableSchema(generated.Config{Resolvers: &graph.Resolver{}})) http.Handle("/", playground.Handler("GraphQL playground", "/query")) http.Handle("/query", middleware.LoaderMiddleware(srv)) // ミドルウェアでハンドラーをラップ log.Printf("connect to http://localhost:%s/ for GraphQL playground", port) log.Fatal(http.ListenAndServe(":"+port, nil)) }
リゾルバもDataloaderを使うように修正します。
// schema.resolvers.go(一部抜粋) package graph import ( // 省略 ) // お薬カテゴリ取得 (お薬N件に対してN回実行) func (r *medicineDetailResolver) MedicneCategory(ctx context.Context, obj *model.MedicineDetail) (*model.MedicineCategoryDetail, error) { // 処理はLoaderに全て寄せたのでリゾルバでは呼ぶだけ return middleware.MedicineCategory(ctx, obj.ID) }
以上で導入は完了です。
ちなみにお薬カテゴリの取得が発生した場合の挙動が以下です。
- 10ms待機してIDを蓄積
- 溜まったIDを
where = ANY($1)
に入れてお薬カテゴリをまとめて取得 - それぞれのお薬に紐付いて返却
リクエストを投げた結果がこちらです
QueryContext query="SELECT * FROM medicines" QueryContext args=["{\"3\",\"1\",\"2\"}"] query="SELECT * FROM medicine_categories WHERE id = ANY($1)"
お薬カテゴリ取得のSQLが1回になりました🙌
最後に
今回は、GraphQLで発生するN+1問題と、Dataloaderを使った解決方法を紹介しました。
自分がDataloaderを導入した1年ほど前は、gqlgen公式のdataloaderのページで今回紹介したvektah/dataloadenの使用方法が書かれていましたが、いつの間にか graph-gophers/dataloaderに変わってたので、こちらも触ってみようと思います。