我正在尝试对我的注册处理程序和数据库调用实施单元测试。但是,它在我的注册处理程序中的数据库调用上引发紧急错误。它是一个简单的注册处理程序,用于接收包含用户名、密码和电子邮件的 json。然后,我将使用 select 语句来检查该用户名是否在 signup 处理程序本身内重复。
当我向该处理程序发送我的发布请求时,这一切都有效。然而,当我实际进行单元测试时,它不起作用并给我抛出了两条错误消息。我觉得这是因为数据库没有在测试环境中初始化,但我不知道如何在不使用第三方框架进行模拟数据库的情况下做到这一点。
错误消息
panic: runtime error: invalid memory address or nil pointer dereference [recovered]
panic: runtime error: invalid memory address or nil pointer dereference
signup.go
package handler
type signupjson struct {
username string `json:"username"`
password string `json:"password"`
email string `json:"email"`
}
func signup(w http.responsewriter, r *http.request) {
// set headers
w.header().set("content-type", "application/json")
var newuser auth_management.signupjson
// reading the request body and unmarshal the body to the loginjson struct
bs, _ := io.readall(req.body)
if err := json.unmarshal(bs, &newuser); err != nil {
utils.responsejson(w, http.statusinternalservererror, "internal server error")
log.println("internal server error in unmarshal json body in signup route:", err)
return
}
ctx := context.background()
ctx, cancel = context.withtimeout(ctx, time.minute * 2)
defer cancel()
// check if username already exists in database (duplicates not allowed)
isexistingusername := database.getusername(ctx, newuser.username) // throws panic error here when testing
if isexistingusername {
utils.responsejson(w, http.statusbadrequest, "username has already been taken. please try again.")
return
}
// other code logic...
}
sqlquery.go
package database
var sql_select_from_users = "select %s from users where %s = $1;"
func getusername(ctx context.context, username string) bool {
row := conn.queryrow(ctx, fmt.sprintf(sql_select_from_users, "username", "username"), username)
return row.scan() != pgx.errnorows
}
signup_test.go
package handler
func test_signup(t *testing.t) {
var tests = []struct {
name string
posteddata signupjson
expectedstatuscode int
}{
{
name: "valid login",
posteddata: signupjson{
username: "testusername",
password: "testpassword",
email: "[email protected]",
},
expectedstatuscode: 200,
},
}
for _, e := range tests {
jsonstr, err := json.marshal(e.posteddata)
if err != nil {
t.fatal(err)
}
// setting a request for testing
req, _ := http.newrequest(http.methodpost, "/signup", strings.newreader(string(jsonstr)))
req.header.set("content-type", "application/json")
// setting and recording the response
res := httptest.newrecorder()
handler := http.handlerfunc(signup)
handler.servehttp(res, req)
if res.code != e.expectedstatuscode {
t.errorf("%s: returned wrong status code; expected %d but got %d", e.name, e.expectedstatuscode, res.code)
}
}
}
setup_test.go
func TestMain(m *testing.M) {
os.Exit(m.Run())
}
我在这里看到了一个类似的问题,但不确定这是否是正确的方法,因为没有响应,而且答案很混乱:how to write an unit test for a handler that invokes a function that invokes a function that intersted with db in golang using pgx 驱动程序?
正确答案
让我尝试帮助您弄清楚如何实现这些目标。我对你的代码进行了一些重构,但总体思路和使用的工具仍然与你的相同。首先,我将分享分为两个文件的生产代码:handlers/handlers.go
和 repo/repo.go
。
handlers/handlers.go
文件
package handlers
import (
"context"
"database/sql"
"encoding/json"
"io"
"net/http"
"time"
"handlertest/repo"
)
type signupjson struct {
username string `json:"username"`
password string `json:"password"`
email string `json:"email"`
}
func signup(w http.responsewriter, r *http.request) {
w.header().set("content-type", "application/json")
var newuser signupjson
bs, _ := io.readall(r.body)
if err := json.unmarshal(bs, &newuser); err != nil {
w.writeheader(http.statusbadrequest)
w.write([]byte(err.error()))
return
}
ctx, cancel := context.withtimeout(r.context(), time.minute*2)
defer cancel()
db, _ := ctx.value("db").(*sql.db)
if isexistingusername := repo.getusername(ctx, db, newuser.username); isexistingusername {
w.writeheader(http.statusbadrequest)
w.write([]byte("username already present"))
return
}
w.writeheader(http.statusok)
}
这里有两个主要区别:
- 使用的
context
。您不必实例化另一个ctx
,只需使用与http.request
一起提供的那个即可。 - 使用的
sql
客户端。正确的方法是通过context.context
来传递。对于这种情况,您不必构建任何结构或使用任何接口等。只需编写一个需要*sql.db
作为参数的函数即可。请记住这一点,函数是一等公民。
当然,还有重构的空间。 "db"
应该是一个常量,我们必须检查上下文值中是否存在此条目,但为了简洁起见,我省略了这些检查。
repo/repo.go
文件
package repo
import (
"context"
"database/sql"
"github.com/jackc/pgx/v5"
)
func getusername(ctx context.context, db *sql.db, username string) bool {
row := db.queryrowcontext(ctx, "select username from users where username = $1", username)
return row.scan() != pgx.errnorows
}
这里的代码与您的非常相似,除了以下两个小问题:
- 当您希望考虑上下文时,有一个名为
queryrowcontext
的专用方法。 - 当您必须构建 sql 查询时,请使用准备好的语句功能。不要将内容与
fmt.sprintf
连接起来,原因有两个:安全性和可测试性。
现在,我们要看看测试代码。
handlers/handlers_test.go
文件
package handlers
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/jackc/pgx/v5"
"github.com/stretchr/testify/assert"
)
func TestSignUp(t *testing.T) {
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
if err != nil {
t.Fatalf("err not expected while open a mock db, %v", err)
}
defer db.Close()
t.Run("NewUser", func(t *testing.T) {
mock.ExpectQuery("SELECT username FROM users WHERE username = $1").WithArgs("[email protected]").WillReturnError(pgx.ErrNoRows)
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(`{"username": "[email protected]", "password": "1234", "email": "[email protected]"}`))
ctx := context.WithValue(r.Context(), "DB", db)
r = r.WithContext(ctx)
SignUp(w, r)
assert.Equal(t, http.StatusOK, w.Code)
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("not all expectations were met: %v", err)
}
})
t.Run("AlreadyExistentUser", func(t *testing.T) {
rows := sqlmock.NewRows([]string{"username"}).AddRow("[email protected]")
mock.ExpectQuery("SELECT username FROM users WHERE username = $1").WithArgs("[email protected]").WillReturnRows(rows)
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(`{"username": "[email protected]", "password": "1234", "email": "[email protected]"}`))
ctx := context.WithValue(r.Context(), "DB", db)
r = r.WithContext(ctx)
SignUp(w, r)
assert.Equal(t, http.StatusBadRequest, w.Code)
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("not all expectations were met: %v", err)
}
})
}
这里,与您的版本相比有很多变化。让我快速回顾一下:
- 使用子测试功能为测试提供层次结构。
- 使用
httptest
包,它提供了用于构建和断言 http 请求和响应的内容。 - 使用
sqlmock
包。模拟数据库的事实上的标准。 - 使用
context
传递sql
客户端以及http.request
。 - 已使用
github.com/stretchr/testify/assert
包完成断言。
这同样适用于这里:有重构的空间(例如,您可以使用表驱动测试功能重新设计测试)。
片尾
这可以被认为是编写 go 代码的惯用方式。我知道这可能非常具有挑战性,尤其是在一开始。如果您需要有关某些部分的更多详细信息,请告诉我,我将很乐意为您提供帮助,谢谢!