diff options
Diffstat (limited to 'db')
-rw-r--r-- | db/db.go | 11 | ||||
-rw-r--r-- | db/endpoint.go | 85 | ||||
-rw-r--r-- | db/session.go | 127 |
3 files changed, 220 insertions, 3 deletions
@@ -2,6 +2,7 @@ package db import ( "os" + "time" "io/fs" "errors" "database/sql" @@ -19,11 +20,13 @@ type User struct { type Session struct { Id string - UserId string + UserId int + Created time.Time + Modified time.Time } type Endpoint struct { - Id string + Id int Name string Path string Address string @@ -45,6 +48,7 @@ type Model interface { CreateSession(user User) (Session, error) DeleteSession(session Session) error CheckSession(session Session) (bool, error) + CleanSessions(maxIdle time.Duration) error AllSessions() ([]Session, error) CreateEndpoint(name, path, address string) (Endpoint, error) @@ -52,6 +56,7 @@ type Model interface { GetEndpointByName(name string) (Endpoint, error) GetEndpointByPath(path string) (Endpoint, error) GetEndpointByAddress(address string) (Endpoint, error) + AllEndpoints() ([]Endpoint, error) } @@ -100,7 +105,7 @@ func (p *Phlox) Create(filename string) error { insert into schema values(0); create table users (userid integer not null primary key, username string, passwordhash string, salt string); - create table sessions (sessionid string not null primary key, userid integer, foreign key(userid) references users(userid)); + create table sessions (sessionid string not null primary key, userid integer, created string, modified string, foreign key(userid) references users(userid)); create table endpoints (endpointid integer not null primary key, name string, path string, address string); `) if err != nil { diff --git a/db/endpoint.go b/db/endpoint.go new file mode 100644 index 0000000..f22d77b --- /dev/null +++ b/db/endpoint.go @@ -0,0 +1,85 @@ +package db + +func (p *Phlox) CreateEndpoint(name, path, address string) (Endpoint, error) { + var id int + row := p.db.QueryRow("select max(endpointid) from endpoints;") + err := row.Scan(&id) + if err != nil { + return Endpoint{}, err + } + id += 1 + + _, err = p.db.Exec( + "insert into endpoints values (?, ?, ?, ?)", + id, name, path, address, + ) + if err != nil { + return Endpoint{}, err + } + + endpoint := Endpoint{ + Id: id, + Name: name, + Path: path, + Address: address, + } + + return endpoint, nil +} + + +func (p *Phlox) DeleteEndpoint(endpoint Endpoint) error { + _, err := p.db.Exec("delete from endpoints where endpointid = ?;", endpoint.Id) + return err +} + + +func extractEndpoint(s Scanner) (Endpoint, error) { + var endpoint Endpoint + err := s.Scan( + &endpoint.Id, + &endpoint.Name, + &endpoint.Path, + &endpoint.Address, + ) + return endpoint, err +} + + +func queryEndpoint(p *Phlox, query, param string) (Endpoint, error) { + row := p.db.QueryRow(query, param) + endpoint, err := extractEndpoint(row) + return endpoint, err +} + +func (p *Phlox) GetEndpointByName(name string) (Endpoint, error) { + return queryEndpoint(p, "select * from endpoints where name = ?;", name) +} + +func (p *Phlox) GetEndpointByPath(path string) (Endpoint, error) { + return queryEndpoint(p, "select * from endpoints where path = ?;", path) +} + +func (p *Phlox) GetEndpointByAddress(address string) (Endpoint, error) { + return queryEndpoint(p, "select * from endpoints where address = ?;", address) +} + + +func (p *Phlox) AllEndpoints() ([]Endpoint, error) { + endpoints := make([]Endpoint, 0) + rows, err := p.db.Query("select * from endpoints;") + if err != nil { + return endpoints, err + } + defer rows.Close() + + for rows.Next() { + endpoint, err := extractEndpoint(rows) + if err != nil { + return endpoints, err + } + endpoints = append(endpoints, endpoint) + } + + return endpoints, nil +} diff --git a/db/session.go b/db/session.go new file mode 100644 index 0000000..bddedda --- /dev/null +++ b/db/session.go @@ -0,0 +1,127 @@ +package db + +import ( + "database/sql" + "crypto/rand" + "encoding/base64" + "errors" + "time" +) + + +func (p *Phlox) CreateSession(user User) (Session, error) { + bytes := make([]byte, 32) + _, err := rand.Read(bytes) + if err != nil { + return Session{}, err + } + + sessionid := base64.StdEncoding.EncodeToString(bytes) + userid := user.Id + now := time.Now().UTC() + nowStr := now.Format(time.RFC3339) + + _, err = p.db.Exec( + "insert into sessions (sessionid, userid, created, modified) values (?, ?, ?, ?);", + sessionid, + userid, + nowStr, nowStr, + ) + if err != nil { + return Session{}, err + } + + return Session{ + Id: sessionid, + UserId: userid, + Created: now, + Modified: now, + }, nil +} + + +func (p *Phlox) DeleteSession(session Session) error { + _, err := p.db.Exec("delete from sessions where sessionid = ?;", session.Id) + return err +} + + +func extractSession(s Scanner) (Session, error) { + var ( + session Session + createdStr string + modifiedStr string + ) + + // scan + err := s.Scan(&session.Id, &session.UserId, &createdStr, &modifiedStr) + if err != nil { + return Session{}, err + } + + // parse times + session.Created, err = time.Parse(time.RFC3339, createdStr) + if err != nil { + return Session{}, err + } + session.Modified, err = time.Parse(time.RFC3339, modifiedStr) + if err != nil { + return Session{}, err + } + + return session, nil +} + + +func (p *Phlox) CheckSession(session Session) (bool, error) { + row := p.db.QueryRow("select * from sessions where sessionid = ?", session.Id) + session, err := extractSession(row) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + // no row returned, so invalid session + return false, nil + } else { + // some other error + return false, err + } + } + + return true, nil +} + + +func (p *Phlox) TouchSession(session Session) error { + now := time.Now().UTC().Format(time.RFC3339) + _, err := p.db.Exec( + "update sessions set modified = ? where sessionid = ?;", + now, session.Id, + ) + return err +} + + +func (p *Phlox) CleanSessions(maxIdle time.Duration) error { + expire := time.Now().UTC().Add(-maxIdle).Format(time.RFC3339) + _, err := p.db.Exec("delete from sessions where modified < ?;", expire) + return err +} + + +func (p *Phlox) AllSessions() ([]Session, error) { + sessions := make([]Session, 0) + rows, err := p.db.Query("select * from sessions;") + if err != nil { + return sessions, err + } + defer rows.Close() + + for rows.Next() { + session, err := extractSession(rows) + if err != nil { + return sessions, err + } + sessions = append(sessions, session) + } + + return sessions, nil +} |