diff options
Diffstat (limited to 'db/user.go')
-rw-r--r-- | db/user.go | 149 |
1 files changed, 149 insertions, 0 deletions
diff --git a/db/user.go b/db/user.go new file mode 100644 index 0000000..37c0744 --- /dev/null +++ b/db/user.go @@ -0,0 +1,149 @@ +package db + +import ( + "golang.org/x/crypto/bcrypt" + "crypto/rand" + "encoding/base64" + "database/sql" +) + + +func getNextUserId(db *sql.DB) (int, error) { + row := db.QueryRow("select max(userid) from users;") + if row.Err() != nil { + return -1, row.Err() + } + + var id int + err := row.Scan(&id) + if err != nil { + return -1, err + } + return id, nil +} + + +func saltPassword(password string, salt []byte) []byte { + salted := []byte(password) + salted = append(salted, salt...) + return salted +} + + +func (p *Phlox) CreateUser(username, password string) (User, error) { + user := User{} + + userId, err := getNextUserId(p.db) + if err != nil { + return user, err + } + + salt := make([]byte, 10) + _, err = rand.Read(salt) + if err != nil { + return user, err + } + + salted := saltPassword(password, salt) + + hash, err := bcrypt.GenerateFromPassword(salted, bcrypt.DefaultCost) + if err != nil { + return user, err + } + + hash64 := base64.StdEncoding.EncodeToString(hash) + salt64 := base64.StdEncoding.EncodeToString(salt) + + _, err = p.db.Exec("insert into users (userid, username, passwordhash, salt) values (?, ?, ?, ?)", userId, username, hash64, salt64) + if err != nil { + return user, err + } + + user.Id = userId + user.Name = username + user.PasswordHash = hash + user.Salt = salt + + return user, nil +} + + + +func extractUser(s Scanner) (User, error) { + var userid int + var username string + var hash64 string + var salt64 string + err := s.Scan(&userid, &username, &hash64, &salt64) + if err != nil { + return User{}, err + } + + hash, err := base64.StdEncoding.DecodeString(hash64) + if err != nil { + return User{}, err + } + salt, err := base64.StdEncoding.DecodeString(salt64) + if err != nil { + return User{}, err + } + + user := User{ + Id: userid, + Name: username, + PasswordHash: hash, + Salt: salt, + } + + return user, nil +} + + +func (p *Phlox) AuthenticateUser(username, password string) (bool, User, error) { + row := p.db.QueryRow("select * from users where username = ?;", username) + user, err := extractUser(row) + if err != nil { + return false, User{}, err + } + + salted := saltPassword(password, user.Salt) + err = bcrypt.CompareHashAndPassword(user.PasswordHash, salted) + if err != nil { + // bad password + return false, User{}, nil + } else { + // success! + return true, user, nil + } +} + + +func (p *Phlox) GetById(id int) (User, error) { + row := p.db.QueryRow("select * from users where userid = ?;", id) + user, err := extractUser(row) + if err != nil { + return User{}, err + } + + return user, nil +} + + +func (p *Phlox) AllUsers() ([]User, error) { + users := make([]User, 0) + rows, err := p.db.Query("select * from users;") + if err != nil { + return users, err + } + defer rows.Close() + + for rows.Next() { + user, err := extractUser(rows) + if err != nil { + return users, err + } + users = append(users, user) + } + + return users, nil +} |