diff options
Diffstat (limited to 'db')
-rw-r--r-- | db/db.go | 143 | ||||
-rw-r--r-- | db/go.mod | 5 | ||||
-rw-r--r-- | db/go.sum | 4 | ||||
-rw-r--r-- | db/user.go | 149 |
4 files changed, 298 insertions, 3 deletions
@@ -1,10 +1,147 @@ package db import ( - "fmt" + "os" + "io/fs" + "errors" + "database/sql" + _ "github.com/mattn/go-sqlite3" ) -func HelloWorld() { - fmt.Println("hello, world!") +type User struct { + Id int + Name string + PasswordHash []byte + Salt []byte +} + + +type Session struct { + Id string + UserId string +} + +type Endpoint struct { + Id string + Name string + Path string + Address string +} + + +type Model interface { + Create(filename string) error + Open(filename string) error + Close() error + + GetSchemaVersion() (int, error) + + CreateUser(username, password string) (User, error) + AuthenticateUser(username, password string) (User, error) + GetById(id string) (User, error) + AllUsers() ([]User, error) + + CreateSession(user User) (Session, error) + DeleteSession(session Session) error + CheckSession(session Session) (bool, error) + AllSessions() ([]Session, error) + + CreateEndpoint(name, path, address string) (Endpoint, error) + DeleteEndpoint(endpoint Endpoint) error + GetEndpointByName(name string) (Endpoint, error) + GetEndpointByPath(path string) (Endpoint, error) + GetEndpointByAddress(address string) (Endpoint, error) +} + + +// interface to wrap sql.Row and sql.Rows +type Scanner interface { + Scan(dest ...any) error +} + + +type Phlox struct { + db *sql.DB +} + + +func fileExists(filename string) (bool, error) { + _, err := os.Stat(filename) + if err == nil { + return true, nil + } else if errors.Is(err, fs.ErrNotExist) { + return false, nil + } else { + // unknown error + return false, err + } +} + + +func (p *Phlox) Create(filename string) error { + exist, err := fileExists(filename) + if err != nil { + return err + } + + if exist { + // file already exists + return fs.ErrExist + } + + p.db, err = sql.Open("sqlite3", filename) + if err != nil { + return err + } + + _, err = p.db.Exec(` + create table schema(version integer); + insert into schema values(0); + + create table users (userid integer not null primary key, username string, passwordhash string, salt string); + create table sessions (sessionid string not null primary key, userid integer, foreign key(userid) references users(userid)); + create table endpoints (endpointid integer not null primary key, name string, path string, address string); + `) + if err != nil { + return err + } + + return nil +} + + +func (p *Phlox) Open(filename string) error { + exist, err := fileExists(filename) + if err != nil { + return err + } + + if !exist { + // no such file! + return fs.ErrNotExist + } + + p.db, err = sql.Open("sqlite3", filename) + if err != nil { + return err + } + + return nil +} + + +func (p *Phlox) GetSchemaVersion() (int, error) { + row := p.db.QueryRow("select max(version) from schema;") + if row.Err() != nil { + return -1, row.Err() + } + + var version int + err := row.Scan(&version) + if err != nil { + return -1, err + } + + return version, nil } @@ -1,3 +1,8 @@ module sanine.net/git/phlox/db go 1.19 + +require ( + github.com/mattn/go-sqlite3 v1.14.16 + golang.org/x/crypto v0.8.0 +) diff --git a/db/go.sum b/db/go.sum new file mode 100644 index 0000000..48de3bc --- /dev/null +++ b/db/go.sum @@ -0,0 +1,4 @@ +github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= +github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ= +golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= diff --git a/db/user.go b/db/user.go new file mode 100644 index 0000000..37c0744 --- /dev/null +++ b/db/user.go @@ -0,0 +1,149 @@ +package db + +import ( + "golang.org/x/crypto/bcrypt" + "crypto/rand" + "encoding/base64" + "database/sql" +) + + +func getNextUserId(db *sql.DB) (int, error) { + row := db.QueryRow("select max(userid) from users;") + if row.Err() != nil { + return -1, row.Err() + } + + var id int + err := row.Scan(&id) + if err != nil { + return -1, err + } + return id, nil +} + + +func saltPassword(password string, salt []byte) []byte { + salted := []byte(password) + salted = append(salted, salt...) + return salted +} + + +func (p *Phlox) CreateUser(username, password string) (User, error) { + user := User{} + + userId, err := getNextUserId(p.db) + if err != nil { + return user, err + } + + salt := make([]byte, 10) + _, err = rand.Read(salt) + if err != nil { + return user, err + } + + salted := saltPassword(password, salt) + + hash, err := bcrypt.GenerateFromPassword(salted, bcrypt.DefaultCost) + if err != nil { + return user, err + } + + hash64 := base64.StdEncoding.EncodeToString(hash) + salt64 := base64.StdEncoding.EncodeToString(salt) + + _, err = p.db.Exec("insert into users (userid, username, passwordhash, salt) values (?, ?, ?, ?)", userId, username, hash64, salt64) + if err != nil { + return user, err + } + + user.Id = userId + user.Name = username + user.PasswordHash = hash + user.Salt = salt + + return user, nil +} + + + +func extractUser(s Scanner) (User, error) { + var userid int + var username string + var hash64 string + var salt64 string + err := s.Scan(&userid, &username, &hash64, &salt64) + if err != nil { + return User{}, err + } + + hash, err := base64.StdEncoding.DecodeString(hash64) + if err != nil { + return User{}, err + } + salt, err := base64.StdEncoding.DecodeString(salt64) + if err != nil { + return User{}, err + } + + user := User{ + Id: userid, + Name: username, + PasswordHash: hash, + Salt: salt, + } + + return user, nil +} + + +func (p *Phlox) AuthenticateUser(username, password string) (bool, User, error) { + row := p.db.QueryRow("select * from users where username = ?;", username) + user, err := extractUser(row) + if err != nil { + return false, User{}, err + } + + salted := saltPassword(password, user.Salt) + err = bcrypt.CompareHashAndPassword(user.PasswordHash, salted) + if err != nil { + // bad password + return false, User{}, nil + } else { + // success! + return true, user, nil + } +} + + +func (p *Phlox) GetById(id int) (User, error) { + row := p.db.QueryRow("select * from users where userid = ?;", id) + user, err := extractUser(row) + if err != nil { + return User{}, err + } + + return user, nil +} + + +func (p *Phlox) AllUsers() ([]User, error) { + users := make([]User, 0) + rows, err := p.db.Query("select * from users;") + if err != nil { + return users, err + } + defer rows.Close() + + for rows.Next() { + user, err := extractUser(rows) + if err != nil { + return users, err + } + users = append(users, user) + } + + return users, nil +} |