diff --git a/server/auth/basic/auth_basic.go b/server/auth/basic/auth_basic.go index 9ec3021f2..b679dc1fd 100644 --- a/server/auth/basic/auth_basic.go +++ b/server/auth/basic/auth_basic.go @@ -27,7 +27,7 @@ func parseSecret(bsecret []byte) (uname, password string, err error) { secret := string(bsecret) splitAt := strings.Index(secret, ":") - if splitAt < 1 { + if splitAt < 0 { err = types.ErrMalformed return } @@ -77,17 +77,17 @@ func (BasicAuth) UpdateRecord(rec *auth.Rec, secret []byte) error { return err } - storedUID, _, _, _, err := store.Users.GetAuthRecord("basic", uname) + login, _, _, _, err := store.Users.GetAuthRecord(rec.Uid, "basic") if err != nil { return err } - // Record not found - probably invalid login - if storedUID.IsZero() { + // User does not have a record. + if login == "" { return types.ErrNotFound } - if storedUID != rec.Uid { - // User is trying to change login to one that is already taken - return types.ErrDuplicate + if uname == "" { + // User is changing just the password. + uname = login } passhash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { @@ -95,7 +95,7 @@ func (BasicAuth) UpdateRecord(rec *auth.Rec, secret []byte) error { } var expires time.Time if rec.Lifetime > 0 { - expires = time.Now().Add(rec.Lifetime) + expires = types.TimeNow().Add(rec.Lifetime) } _, err = store.Users.UpdateAuthRecord(rec.Uid, auth.LevelAuth, "basic", uname, passhash, expires) if err != nil { @@ -111,7 +111,11 @@ func (BasicAuth) Authenticate(secret []byte) (*auth.Rec, error) { return nil, err } - uid, authLvl, passhash, expires, err := store.Users.GetAuthRecord("basic", uname) + if len(uname) < minLoginLength || len(uname) > maxLoginLength { + return nil, types.ErrFailed + } + + uid, authLvl, passhash, expires, err := store.Users.GetAuthUniqueRecord("basic", uname) if err != nil { return nil, err } else if uid.IsZero() { @@ -146,7 +150,11 @@ func (BasicAuth) IsUnique(secret []byte) (bool, error) { return false, err } - uid, _, _, _, err := store.Users.GetAuthRecord("basic", uname) + if len(uname) < minLoginLength || len(uname) > maxLoginLength { + return false, types.ErrPolicy + } + + uid, _, _, _, err := store.Users.GetAuthUniqueRecord("basic", uname) if err != nil { return false, err } diff --git a/server/db/mysql/adapter.go b/server/db/mysql/adapter.go index 2eb7b5ff8..088e7bbd3 100644 --- a/server/db/mysql/adapter.go +++ b/server/db/mysql/adapter.go @@ -186,7 +186,7 @@ func (a *adapter) CreateDb(reset bool) error { if _, err = tx.Exec( `CREATE TABLE kvmeta(` + - "`key` CHAR(32)," + + "`key` CHAR(32)," + "`value` TEXT," + "PRIMARY KEY(`key`)" + `)`); err != nil { @@ -198,16 +198,16 @@ func (a *adapter) CreateDb(reset bool) error { if _, err = tx.Exec( `CREATE TABLE users( - id BIGINT NOT NULL, - createdat DATETIME(3) NOT NULL, - updatedat DATETIME(3) NOT NULL, - deletedat DATETIME(3), - state INT DEFAULT 0, - access JSON, - lastseen DATETIME, - useragent VARCHAR(255) DEFAULT '', - public JSON, - tags JSON, + id BIGINT NOT NULL, + createdat DATETIME(3) NOT NULL, + updatedat DATETIME(3) NOT NULL, + deletedat DATETIME(3), + state INT DEFAULT 0, + access JSON, + lastseen DATETIME, + useragent VARCHAR(255) DEFAULT '', + public JSON, + tags JSON, PRIMARY KEY(id) )`); err != nil { return err @@ -216,9 +216,9 @@ func (a *adapter) CreateDb(reset bool) error { // Indexed user tags. if _, err = tx.Exec( `CREATE TABLE usertags( - id INT NOT NULL AUTO_INCREMENT, - userid BIGINT NOT NULL, - tag VARCHAR(96) NOT NULL, + id INT NOT NULL AUTO_INCREMENT, + userid BIGINT NOT NULL, + tag VARCHAR(96) NOT NULL, PRIMARY KEY(id), FOREIGN KEY(userid) REFERENCES users(id), INDEX usertags_tag (tag) @@ -229,13 +229,13 @@ func (a *adapter) CreateDb(reset bool) error { // Indexed devices. Normalized into a separate table. if _, err = tx.Exec( `CREATE TABLE devices( - id INT NOT NULL AUTO_INCREMENT, - userid BIGINT NOT NULL, - hash CHAR(16) NOT NULL, - deviceid TEXT NOT NULL, - platform VARCHAR(32), - lastseen DATETIME NOT NULL, - lang VARCHAR(8), + id INT NOT NULL AUTO_INCREMENT, + userid BIGINT NOT NULL, + hash CHAR(16) NOT NULL, + deviceid TEXT NOT NULL, + platform VARCHAR(32), + lastseen DATETIME NOT NULL, + lang VARCHAR(8), PRIMARY KEY(id), FOREIGN KEY(userid) REFERENCES users(id), UNIQUE INDEX devices_hash (hash) @@ -245,16 +245,18 @@ func (a *adapter) CreateDb(reset bool) error { // Authentication records for the basic authentication scheme. if _, err = tx.Exec( - `CREATE TABLE basicauth( - id INT NOT NULL AUTO_INCREMENT, - login VARCHAR(32) NOT NULL, - userid BIGINT NOT NULL, - authlvl INT NOT NULL, - secret VARCHAR(255) NOT NULL, - expires DATETIME, + `CREATE TABLE auth( + id INT NOT NULL AUTO_INCREMENT, + uname VARCHAR(32) NOT NULL, + userid BIGINT NOT NULL, + scheme VARCHAR(16) NOT NULL, + authlvl INT NOT NULL, + secret VARCHAR(255) NOT NULL, + expires DATETIME, PRIMARY KEY(id), FOREIGN KEY(userid) REFERENCES users(id), - UNIQUE INDEX basicauth_login(login) + UNIQUE INDEX auth_userid_scheme(userid, scheme), + UNIQUE INDEX auth_uname(uname) )`); err != nil { return err } @@ -262,18 +264,18 @@ func (a *adapter) CreateDb(reset bool) error { // Topics if _, err = tx.Exec( `CREATE TABLE topics( - id INT NOT NULL AUTO_INCREMENT, - createdat DATETIME(3) NOT NULL, - updatedat DATETIME(3) NOT NULL, - deletedat DATETIME(3), - touchedat DATETIME(3), - name CHAR(25) NOT NULL, - usebt INT DEFAULT 0, - access JSON, - seqid INT NOT NULL DEFAULT 0, - delid INT DEFAULT 0, - public JSON, - tags JSON, + id INT NOT NULL AUTO_INCREMENT, + createdat DATETIME(3) NOT NULL, + updatedat DATETIME(3) NOT NULL, + deletedat DATETIME(3), + touchedat DATETIME(3), + name CHAR(25) NOT NULL, + usebt INT DEFAULT 0, + access JSON, + seqid INT NOT NULL DEFAULT 0, + delid INT DEFAULT 0, + public JSON, + tags JSON, PRIMARY KEY(id), UNIQUE INDEX topics_name(name) )`); err != nil { @@ -283,9 +285,9 @@ func (a *adapter) CreateDb(reset bool) error { // Indexed topic tags. if _, err = tx.Exec( `CREATE TABLE topictags( - id INT NOT NULL AUTO_INCREMENT, - topic CHAR(25) NOT NULL, - tag VARCHAR(96) NOT NULL, + id INT NOT NULL AUTO_INCREMENT, + topic CHAR(25) NOT NULL, + tag VARCHAR(96) NOT NULL, PRIMARY KEY(id), FOREIGN KEY(topic) REFERENCES topics(name), INDEX topictags_tag(tag) @@ -319,16 +321,16 @@ func (a *adapter) CreateDb(reset bool) error { // Messages if _, err = tx.Exec( `CREATE TABLE messages( - id INT NOT NULL AUTO_INCREMENT, - createdat DATETIME(3) NOT NULL, - updatedat DATETIME(3) NOT NULL, - deletedat DATETIME(3), - delid INT DEFAULT 0, - seqid INT NOT NULL, - topic CHAR(25) NOT NULL,` + - "`from` BIGINT NOT NULL," + - `head JSON, - content JSON, + id INT NOT NULL AUTO_INCREMENT, + createdat DATETIME(3) NOT NULL, + updatedat DATETIME(3) NOT NULL, + deletedat DATETIME(3), + delid INT DEFAULT 0, + seqid INT NOT NULL, + topic CHAR(25) NOT NULL,` + + "`from` BIGINT NOT NULL," + + `head JSON, + content JSON, PRIMARY KEY(id),` + "FOREIGN KEY(`from`) REFERENCES users(id)," + `FOREIGN KEY(topic) REFERENCES topics(name), @@ -340,12 +342,12 @@ func (a *adapter) CreateDb(reset bool) error { // Deletion log if _, err = tx.Exec( `CREATE TABLE dellog( - id INT NOT NULL AUTO_INCREMENT, - topic VARCHAR(25) NOT NULL, - deletedfor BIGINT NOT NULL DEFAULT 0, - delid INT NOT NULL, - low INT NOT NULL, - hi INT NOT NULL, + id INT NOT NULL AUTO_INCREMENT, + topic VARCHAR(25) NOT NULL, + deletedfor BIGINT NOT NULL DEFAULT 0, + delid INT NOT NULL, + low INT NOT NULL, + hi INT NOT NULL, PRIMARY KEY(id), FOREIGN KEY(topic) REFERENCES topics(name), UNIQUE INDEX dellog_topic_delid_deletedfor(topic,delid,deletedfor), @@ -435,15 +437,15 @@ func (a *adapter) UserCreate(user *t.User) error { } // Add user's authentication record -func (a *adapter) AuthAddRecord(uid t.Uid, authLvl auth.Level, unique string, +func (a *adapter) AuthAddRecord(uid t.Uid, scheme, unique string, authLvl auth.Level, secret []byte, expires time.Time) (bool, error) { var exp *time.Time if !expires.IsZero() { exp = &expires } - _, err := a.db.Exec("INSERT INTO basicauth(login,userid,authLvl,secret,expires) VALUES(?,?,?,?,?)", - unique, store.DecodeUid(uid), authLvl, secret, exp) + _, err := a.db.Exec("INSERT INTO auth(uname,userid,scheme,authLvl,secret,expires) VALUES(?,?,?,?,?,?)", + unique, store.DecodeUid(uid), scheme, authLvl, secret, exp) if err != nil { if isDupe(err) { return true, t.ErrDuplicate @@ -455,13 +457,13 @@ func (a *adapter) AuthAddRecord(uid t.Uid, authLvl auth.Level, unique string, // Delete user's authentication record func (a *adapter) AuthDelRecord(user t.Uid, unique string) error { - _, err := a.db.Exec("DELETE FROM basicauth WHERE userid=? AND login=?", store.DecodeUid(user), unique) + _, err := a.db.Exec("DELETE FROM auth WHERE userid=? AND uname=?", store.DecodeUid(user), unique) return err } // Delete user's all authentication records func (a *adapter) AuthDelAllRecords(uid t.Uid) (int, error) { - res, err := a.db.Exec("DELETE FROM basicauth WHERE userid=?", store.DecodeUid(uid)) + res, err := a.db.Exec("DELETE FROM auth WHERE userid=?", store.DecodeUid(uid)) if err != nil { return 0, err } @@ -471,25 +473,52 @@ func (a *adapter) AuthDelAllRecords(uid t.Uid) (int, error) { } // Update user's authentication secret -func (a *adapter) AuthUpdRecord(unique string, authLvl auth.Level, secret []byte, expires time.Time) (int, error) { +func (a *adapter) AuthUpdRecord(uid t.Uid, scheme, unique string, authLvl auth.Level, + secret []byte, expires time.Time) (bool, error) { var exp *time.Time if !expires.IsZero() { exp = &expires } - res, err := a.db.Exec("UPDATE basicauth SET authLvl=?,secret=?,expires=? WHERE login=?", - authLvl, secret, exp, unique) + _, err := a.db.Exec("UPDATE auth SET uname=?,authLvl=?,secret=?,expires=? WHERE uname=?", + unique, authLvl, secret, exp, unique) + if isDupe(err) { + return true, t.ErrDuplicate + } + + return false, err +} + +// Retrieve user's authentication record +func (a *adapter) AuthGetRecord(uid t.Uid, scheme string) (string, auth.Level, []byte, time.Time, error) { + var expires time.Time + + var record struct { + Uname string + Authlvl auth.Level + Secret []byte + Expires *time.Time + } + err := a.db.Get(&record, "SELECT uname,secret,expires,authlvl FROM auth WHERE userid=? AND scheme=?", + store.DecodeUid(uid), scheme) if err != nil { - return 0, err + if err == sql.ErrNoRows { + // Nothing found - clear the error + err = nil + } + return "", 0, nil, expires, err } - count, _ := res.RowsAffected() - return int(count), nil + if record.Expires != nil { + expires = *record.Expires + } + + return record.Uname, record.Authlvl, record.Secret, expires, nil } // Retrieve user's authentication record -func (a *adapter) AuthGetRecord(unique string) (t.Uid, auth.Level, []byte, time.Time, error) { +func (a *adapter) AuthGetUniqueRecord(unique string) (t.Uid, auth.Level, []byte, time.Time, error) { var expires time.Time var record struct { @@ -499,7 +528,7 @@ func (a *adapter) AuthGetRecord(unique string) (t.Uid, auth.Level, []byte, time. Expires *time.Time } - err := a.db.Get(&record, "SELECT userid, secret, expires, authlvl FROM basicauth WHERE login=?", unique) + err := a.db.Get(&record, "SELECT userid,secret,expires,authlvl FROM auth WHERE uname=?", unique) if err != nil { if err == sql.ErrNoRows { // Nothing found - clear the error @@ -512,7 +541,6 @@ func (a *adapter) AuthGetRecord(unique string) (t.Uid, auth.Level, []byte, time. expires = *record.Expires } - // log.Println("loggin in user Id=", user.Uid(), user.Id) return store.EncodeUid(record.Userid), record.Authlvl, record.Secret, expires, nil } diff --git a/server/db/mysql/schema.sql b/server/db/mysql/schema.sql index 4e3c87f24..141514fc0 100644 --- a/server/db/mysql/schema.sql +++ b/server/db/mysql/schema.sql @@ -55,17 +55,19 @@ CREATE TABLE devices( ); # Authentication records for the basic authentication scheme. -CREATE TABLE basicauth( - id INT NOT NULL AUTO_INCREMENT, - login VARCHAR(32) NOT NULL, - userid BIGINT NOT NULL, - authlvl INT NOT NULL, - secret VARCHAR(255) NOT NULL, - expires DATETIME, +CREATE TABLE auth( + id INT NOT NULL AUTO_INCREMENT, + uname VARCHAR(32) NOT NULL, + userid BIGINT NOT NULL, + scheme VARCHAR(16) NOT NULL, + authlvl INT NOT NULL, + secret VARCHAR(255) NOT NULL, + expires DATETIME, PRIMARY KEY(id), FOREIGN KEY(userid) REFERENCES users(id), - UNIQUE INDEX basicauth_login (login) + UNIQUE INDEX auth_userid_scheme(userid, scheme), + UNIQUE INDEX auth_uname (uname) ); diff --git a/server/db/rethinkdb/adapter.go b/server/db/rethinkdb/adapter.go index aadbfab47..8c7ae4563 100644 --- a/server/db/rethinkdb/adapter.go +++ b/server/db/rethinkdb/adapter.go @@ -292,13 +292,14 @@ func (a *adapter) UserCreate(user *t.User) error { } // Add user's authentication record -func (a *adapter) AuthAddRecord(uid t.Uid, authLvl auth.Level, unique string, +func (a *adapter) AuthAddRecord(uid t.Uid, scheme, unique string, authLvl auth.Level, secret []byte, expires time.Time) (bool, error) { _, err := rdb.DB(a.dbName).Table("auth").Insert( map[string]interface{}{ "unique": unique, "userid": uid.String(), + "scheme": scheme, "authLvl": authLvl, "secret": secret, "expires": expires}).RunWrite(a.conn) @@ -326,18 +327,77 @@ func (a *adapter) AuthDelAllRecords(uid t.Uid) (int, error) { return res.Deleted, err } -// Update user's authentication secret -func (a *adapter) AuthUpdRecord(unique string, authLvl auth.Level, secret []byte, expires time.Time) (int, error) { - res, err := rdb.DB(a.dbName).Table("auth").Get(unique).Update( - map[string]interface{}{ - "authLvl": authLvl, - "secret": secret, - "expires": expires}).RunWrite(a.conn) - return res.Updated, err +// Update user's authentication secret. +func (a *adapter) AuthUpdRecord(uid t.Uid, scheme, unique string, authLvl auth.Level, + secret []byte, expires time.Time) (bool, error) { + // The 'unique' is used as a primary key (no other way to ensure uniqueness in RethinkDB). + // The primary key is immutable. If 'unique' has changed, we have to replace the old record with a new one: + // 1. Check if 'unique' has changed. + // 2. If not, execute update by 'unique' + // 3. If yes, first insert the new record (it may fail due to dublicate 'unique') then delete the old one. + var dupe bool + // Get the old 'unique' + res, err := rdb.DB(a.dbName).Table("auth").GetAllByIndex("userid", uid.String()). + Filter(map[string]interface{}{"scheme": scheme}). + Pluck("unique").Default(nil).Run(a.conn) + if err != nil { + return dupe, err + } + if res.IsNil() { + // If the record is not found, don't update it + return dupe, t.ErrNotFound + } + var record struct { + Unique string `gorethink:"unique"` + } + if err = res.One(&record); err != nil { + return dupe, err + } + if record.Unique == unique { + // Unique has not changed + _, err = rdb.DB(a.dbName).Table("auth").Get(unique).Update( + map[string]interface{}{ + "authLvl": authLvl, + "secret": secret, + "expires": expires}).RunWrite(a.conn) + } else { + // Unique has changed. Insert-Delete. + dupe, err = a.AuthAddRecord(uid, scheme, unique, authLvl, secret, expires) + if err == nil { + // We can't do much with the error here. No support for transactions :( + a.AuthDelRecord(uid, unique) + } + } + return dupe, err +} + +// Retrieve user's authentication record +func (a *adapter) AuthGetRecord(uid t.Uid, scheme string) (string, auth.Level, []byte, time.Time, error) { + // Default() is needed to prevent Pluck from returning an error + row, err := rdb.DB(a.dbName).Table("auth").GetAllByIndex("userid", uid.String()). + Filter(map[string]interface{}{"scheme": scheme}). + Pluck("unique", "secret", "expires", "authLvl").Default(nil).Run(a.conn) + if err != nil || row.IsNil() { + return "", 0, nil, time.Time{}, err + } + + var record struct { + Unique string `gorethink:"unique"` + AuthLvl auth.Level `gorethink:"authLvl"` + Secret []byte `gorethink:"secret"` + Expires time.Time `gorethink:"expires"` + } + + if err = row.One(&record); err != nil { + return "", 0, nil, time.Time{}, err + } + + // log.Println("loggin in user Id=", user.Uid(), user.Id) + return record.Unique, record.AuthLvl, record.Secret, record.Expires, nil } // Retrieve user's authentication record -func (a *adapter) AuthGetRecord(unique string) (t.Uid, auth.Level, []byte, time.Time, error) { +func (a *adapter) AuthGetUniqueRecord(unique string) (t.Uid, auth.Level, []byte, time.Time, error) { // Default() is needed to prevent Pluck from returning an error row, err := rdb.DB(a.dbName).Table("auth").Get(unique).Pluck( "userid", "secret", "expires", "authLvl").Default(nil).Run(a.conn) diff --git a/server/store/adapter/adapter.go b/server/store/adapter/adapter.go index 3c2990454..b6d35f08c 100644 --- a/server/store/adapter/adapter.go +++ b/server/store/adapter/adapter.go @@ -36,11 +36,12 @@ type Adapter interface { CredGet(uid t.Uid, method string) ([]*t.Credential, error) // Authentication management for the basic authentication scheme - AuthGetRecord(unique string) (t.Uid, auth.Level, []byte, time.Time, error) - AuthAddRecord(user t.Uid, authLvl auth.Level, unique string, secret []byte, expires time.Time) (bool, error) + AuthGetUniqueRecord(unique string) (t.Uid, auth.Level, []byte, time.Time, error) + AuthGetRecord(user t.Uid, scheme string) (string, auth.Level, []byte, time.Time, error) + AuthAddRecord(user t.Uid, scheme, unique string, authLvl auth.Level, secret []byte, expires time.Time) (bool, error) AuthDelRecord(user t.Uid, unique string) error AuthDelAllRecords(uid t.Uid) (int, error) - AuthUpdRecord(unique string, authLvl auth.Level, secret []byte, expires time.Time) (int, error) + AuthUpdRecord(user t.Uid, scheme, unique string, authLvl auth.Level, secret []byte, expires time.Time) (bool, error) // Topic/contact management diff --git a/server/store/store.go b/server/store/store.go index f4f61ceac..469f5142c 100644 --- a/server/store/store.go +++ b/server/store/store.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "sort" + "strings" "time" "github.com/tinode/chat/server/auth" @@ -192,25 +193,36 @@ func (UsersObjMapper) Create(user *types.User, private interface{}) (*types.User // GetAuthRecord takes a unique identifier and a authentication scheme name, fetches user ID and // authentication secret. -func (UsersObjMapper) GetAuthRecord(scheme, unique string) (types.Uid, auth.Level, []byte, time.Time, error) { - return adp.AuthGetRecord(scheme + ":" + unique) +func (UsersObjMapper) GetAuthRecord(user types.Uid, scheme string) (string, auth.Level, []byte, time.Time, error) { + unique, authLvl, secret, expires, err := adp.AuthGetRecord(user, scheme) + if err == nil { + parts := strings.Split(unique, ":") + unique = parts[1] + } + return unique, authLvl, secret, expires, err +} + +// GetAuthRecord takes a unique identifier and a authentication scheme name, fetches user ID and +// authentication secret. +func (UsersObjMapper) GetAuthUniqueRecord(scheme, unique string) (types.Uid, auth.Level, []byte, time.Time, error) { + return adp.AuthGetUniqueRecord(scheme + ":" + unique) } // AddAuthRecord creates a new authentication record for the given user. func (UsersObjMapper) AddAuthRecord(uid types.Uid, authLvl auth.Level, scheme, unique string, secret []byte, expires time.Time) (bool, error) { - return adp.AuthAddRecord(uid, authLvl, scheme+":"+unique, secret, expires) + return adp.AuthAddRecord(uid, scheme, scheme+":"+unique, authLvl, secret, expires) } // UpdateAuthRecord updates authentication record with a new secret and expiration time. func (UsersObjMapper) UpdateAuthRecord(uid types.Uid, authLvl auth.Level, scheme, unique string, - secret []byte, expires time.Time) (int, error) { + secret []byte, expires time.Time) (bool, error) { - return adp.AuthUpdRecord(scheme+":"+unique, authLvl, secret, expires) + return adp.AuthUpdRecord(uid, scheme, scheme+":"+unique, authLvl, secret, expires) } -// DelAuthRecord deletes user's all auth records of the givel scheme. +// DelAuthRecord deletes user's all auth records of the given scheme. func (UsersObjMapper) DelAuthRecords(uid types.Uid, scheme string) error { return adp.AuthDelRecord(uid, scheme) }