summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--db/db.go143
-rw-r--r--db/go.mod5
-rw-r--r--db/go.sum4
-rw-r--r--db/user.go149
4 files changed, 298 insertions, 3 deletions
diff --git a/db/db.go b/db/db.go
index 2c59694..4824a6d 100644
--- a/db/db.go
+++ b/db/db.go
@@ -1,10 +1,147 @@
package db
import (
- "fmt"
+ "os"
+ "io/fs"
+ "errors"
+ "database/sql"
+ _ "github.com/mattn/go-sqlite3"
)
-func HelloWorld() {
- fmt.Println("hello, world!")
+type User struct {
+ Id int
+ Name string
+ PasswordHash []byte
+ Salt []byte
+}
+
+
+type Session struct {
+ Id string
+ UserId string
+}
+
+type Endpoint struct {
+ Id string
+ Name string
+ Path string
+ Address string
+}
+
+
+type Model interface {
+ Create(filename string) error
+ Open(filename string) error
+ Close() error
+
+ GetSchemaVersion() (int, error)
+
+ CreateUser(username, password string) (User, error)
+ AuthenticateUser(username, password string) (User, error)
+ GetById(id string) (User, error)
+ AllUsers() ([]User, error)
+
+ CreateSession(user User) (Session, error)
+ DeleteSession(session Session) error
+ CheckSession(session Session) (bool, error)
+ AllSessions() ([]Session, error)
+
+ CreateEndpoint(name, path, address string) (Endpoint, error)
+ DeleteEndpoint(endpoint Endpoint) error
+ GetEndpointByName(name string) (Endpoint, error)
+ GetEndpointByPath(path string) (Endpoint, error)
+ GetEndpointByAddress(address string) (Endpoint, error)
+}
+
+
+// interface to wrap sql.Row and sql.Rows
+type Scanner interface {
+ Scan(dest ...any) error
+}
+
+
+type Phlox struct {
+ db *sql.DB
+}
+
+
+func fileExists(filename string) (bool, error) {
+ _, err := os.Stat(filename)
+ if err == nil {
+ return true, nil
+ } else if errors.Is(err, fs.ErrNotExist) {
+ return false, nil
+ } else {
+ // unknown error
+ return false, err
+ }
+}
+
+
+func (p *Phlox) Create(filename string) error {
+ exist, err := fileExists(filename)
+ if err != nil {
+ return err
+ }
+
+ if exist {
+ // file already exists
+ return fs.ErrExist
+ }
+
+ p.db, err = sql.Open("sqlite3", filename)
+ if err != nil {
+ return err
+ }
+
+ _, err = p.db.Exec(`
+ create table schema(version integer);
+ insert into schema values(0);
+
+ create table users (userid integer not null primary key, username string, passwordhash string, salt string);
+ create table sessions (sessionid string not null primary key, userid integer, foreign key(userid) references users(userid));
+ create table endpoints (endpointid integer not null primary key, name string, path string, address string);
+ `)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+
+func (p *Phlox) Open(filename string) error {
+ exist, err := fileExists(filename)
+ if err != nil {
+ return err
+ }
+
+ if !exist {
+ // no such file!
+ return fs.ErrNotExist
+ }
+
+ p.db, err = sql.Open("sqlite3", filename)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+
+func (p *Phlox) GetSchemaVersion() (int, error) {
+ row := p.db.QueryRow("select max(version) from schema;")
+ if row.Err() != nil {
+ return -1, row.Err()
+ }
+
+ var version int
+ err := row.Scan(&version)
+ if err != nil {
+ return -1, err
+ }
+
+ return version, nil
}
diff --git a/db/go.mod b/db/go.mod
index 51b2b67..7931102 100644
--- a/db/go.mod
+++ b/db/go.mod
@@ -1,3 +1,8 @@
module sanine.net/git/phlox/db
go 1.19
+
+require (
+ github.com/mattn/go-sqlite3 v1.14.16
+ golang.org/x/crypto v0.8.0
+)
diff --git a/db/go.sum b/db/go.sum
new file mode 100644
index 0000000..48de3bc
--- /dev/null
+++ b/db/go.sum
@@ -0,0 +1,4 @@
+github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
+github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
+golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ=
+golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
diff --git a/db/user.go b/db/user.go
new file mode 100644
index 0000000..37c0744
--- /dev/null
+++ b/db/user.go
@@ -0,0 +1,149 @@
+package db
+
+import (
+ "golang.org/x/crypto/bcrypt"
+ "crypto/rand"
+ "encoding/base64"
+ "database/sql"
+)
+
+
+func getNextUserId(db *sql.DB) (int, error) {
+ row := db.QueryRow("select max(userid) from users;")
+ if row.Err() != nil {
+ return -1, row.Err()
+ }
+
+ var id int
+ err := row.Scan(&id)
+ if err != nil {
+ return -1, err
+ }
+ return id, nil
+}
+
+
+func saltPassword(password string, salt []byte) []byte {
+ salted := []byte(password)
+ salted = append(salted, salt...)
+ return salted
+}
+
+
+func (p *Phlox) CreateUser(username, password string) (User, error) {
+ user := User{}
+
+ userId, err := getNextUserId(p.db)
+ if err != nil {
+ return user, err
+ }
+
+ salt := make([]byte, 10)
+ _, err = rand.Read(salt)
+ if err != nil {
+ return user, err
+ }
+
+ salted := saltPassword(password, salt)
+
+ hash, err := bcrypt.GenerateFromPassword(salted, bcrypt.DefaultCost)
+ if err != nil {
+ return user, err
+ }
+
+ hash64 := base64.StdEncoding.EncodeToString(hash)
+ salt64 := base64.StdEncoding.EncodeToString(salt)
+
+ _, err = p.db.Exec("insert into users (userid, username, passwordhash, salt) values (?, ?, ?, ?)", userId, username, hash64, salt64)
+ if err != nil {
+ return user, err
+ }
+
+ user.Id = userId
+ user.Name = username
+ user.PasswordHash = hash
+ user.Salt = salt
+
+ return user, nil
+}
+
+
+
+func extractUser(s Scanner) (User, error) {
+ var userid int
+ var username string
+ var hash64 string
+ var salt64 string
+ err := s.Scan(&userid, &username, &hash64, &salt64)
+ if err != nil {
+ return User{}, err
+ }
+
+ hash, err := base64.StdEncoding.DecodeString(hash64)
+ if err != nil {
+ return User{}, err
+ }
+ salt, err := base64.StdEncoding.DecodeString(salt64)
+ if err != nil {
+ return User{}, err
+ }
+
+ user := User{
+ Id: userid,
+ Name: username,
+ PasswordHash: hash,
+ Salt: salt,
+ }
+
+ return user, nil
+}
+
+
+func (p *Phlox) AuthenticateUser(username, password string) (bool, User, error) {
+ row := p.db.QueryRow("select * from users where username = ?;", username)
+ user, err := extractUser(row)
+ if err != nil {
+ return false, User{}, err
+ }
+
+ salted := saltPassword(password, user.Salt)
+ err = bcrypt.CompareHashAndPassword(user.PasswordHash, salted)
+ if err != nil {
+ // bad password
+ return false, User{}, nil
+ } else {
+ // success!
+ return true, user, nil
+ }
+}
+
+
+func (p *Phlox) GetById(id int) (User, error) {
+ row := p.db.QueryRow("select * from users where userid = ?;", id)
+ user, err := extractUser(row)
+ if err != nil {
+ return User{}, err
+ }
+
+ return user, nil
+}
+
+
+func (p *Phlox) AllUsers() ([]User, error) {
+ users := make([]User, 0)
+ rows, err := p.db.Query("select * from users;")
+ if err != nil {
+ return users, err
+ }
+ defer rows.Close()
+
+ for rows.Next() {
+ user, err := extractUser(rows)
+ if err != nil {
+ return users, err
+ }
+ users = append(users, user)
+ }
+
+ return users, nil
+}