package db import ( "golang.org/x/crypto/bcrypt" "crypto/rand" "encoding/base64" "database/sql" ) func getNextUserId(db *sql.DB) (int, error) { row := db.QueryRow("select max(userid) from users;") if row.Err() != nil { return -1, row.Err() } var id int err := row.Scan(&id) if err != nil { return -1, err } return id, nil } func saltPassword(password string, salt []byte) []byte { salted := []byte(password) salted = append(salted, salt...) return salted } func hashPassword(password string, salt []byte) ([]byte, error) { salted := saltPassword(password, salt) hash, err := bcrypt.GenerateFromPassword(salted, bcrypt.DefaultCost) if err != nil { return []byte{}, err } return hash, nil } func (p *Phlox) CreateUser(username, password string) (User, error) { user := User{} userId, err := getNextUserId(p.db) if err != nil { return user, err } salt := make([]byte, 10) _, err = rand.Read(salt) if err != nil { return user, err } hash, err := hashPassword(password, salt) if err != nil { return user, err } hash64 := base64.StdEncoding.EncodeToString(hash) salt64 := base64.StdEncoding.EncodeToString(salt) _, err = p.db.Exec("insert into users (userid, username, passwordhash, salt) values (?, ?, ?, ?)", userId, username, hash64, salt64) if err != nil { return user, err } user.Id = userId user.Name = username user.PasswordHash = hash user.Salt = salt return user, nil } func (p *Phlox) DeleteUser(user User) error { _, err := p.db.Exec("delete from users where userid=?;", user.Id) return err } func (p *Phlox) SetPassword(user User, password string) error { hash, err := hashPassword(password, user.Salt) if err != nil { return err } hash64 := base64.StdEncoding.EncodeToString(hash) _, err = p.db.Exec("update users set passwordhash=? where userid=?;", hash64, user.Id) return err } func extractUser(s Scanner) (User, error) { var userid int var username string var hash64 string var salt64 string err := s.Scan(&userid, &username, &hash64, &salt64) if err != nil { return User{}, err } hash, err := base64.StdEncoding.DecodeString(hash64) if err != nil { return User{}, err } salt, err := base64.StdEncoding.DecodeString(salt64) if err != nil { return User{}, err } user := User{ Id: userid, Name: username, PasswordHash: hash, Salt: salt, } return user, nil } func (p *Phlox) AuthenticateUser(username, password string) (bool, User, error) { row := p.db.QueryRow("select * from users where username = ?;", username) user, err := extractUser(row) if err != nil { return false, User{}, err } salted := saltPassword(password, user.Salt) err = bcrypt.CompareHashAndPassword(user.PasswordHash, salted) if err != nil { // bad password return false, User{}, nil } else { // success! return true, user, nil } } func (p *Phlox) GetByUsername(username string) (User, error) { row := p.db.QueryRow("select * from users where username = ?;", username) user, err := extractUser(row) if err != nil { return User{}, err } return user, nil } func (p *Phlox) GetById(id int) (User, error) { row := p.db.QueryRow("select * from users where userid = ?;", id) user, err := extractUser(row) if err != nil { return User{}, err } return user, nil } func (p *Phlox) AllUsers() ([]User, error) { users := make([]User, 0) rows, err := p.db.Query("select * from users;") if err != nil { return users, err } defer rows.Close() for rows.Next() { user, err := extractUser(rows) if err != nil { return users, err } users = append(users, user) } return users, nil }