From 3424d3115199554359ba01bca669343907ec0190 Mon Sep 17 00:00:00 2001 From: phoenix Date: Mon, 7 Apr 2025 01:22:57 +0000 Subject: [PATCH] Login endpoint (#20) Reviewed-on: https://git.kundeng.us/phoenix/icarus_auth/pulls/20 Co-authored-by: phoenix Co-committed-by: phoenix --- .env.sample | 1 + .gitea/workflows/workflow.yml | 1 + Cargo.toml | 5 +- src/callers/login.rs | 101 +++++++++++++++++++++ src/callers/mod.rs | 2 + src/hashing/mod.rs | 4 + src/lib.rs | 12 ++- src/main.rs | 164 +++++++++++++++++++++++++++------- src/repo/mod.rs | 60 +++++++++++++ src/token_stuff/mod.rs | 87 ++++++++++++++++++ 10 files changed, 402 insertions(+), 35 deletions(-) create mode 100644 src/callers/login.rs create mode 100644 src/token_stuff/mod.rs diff --git a/.env.sample b/.env.sample index 135f9aa..c7494ce 100644 --- a/.env.sample +++ b/.env.sample @@ -1 +1,2 @@ DATABASE_URL=postgres://username:password@localhost/database_name +SECRET_KEY=refero34o8rfhfjn983thf39fhc943rf923n3h \ No newline at end of file diff --git a/.gitea/workflows/workflow.yml b/.gitea/workflows/workflow.yml index 557a245..aa3fded 100644 --- a/.gitea/workflows/workflow.yml +++ b/.gitea/workflows/workflow.yml @@ -73,6 +73,7 @@ jobs: # Define DATABASE_URL for tests to use DATABASE_URL: postgresql://${{ secrets.DB_TEST_USER || 'testuser' }}:${{ secrets.DB_TEST_PASSWORD || 'testpassword' }}@postgres:5432/${{ secrets.DB_TEST_NAME || 'testdb' }} RUST_LOG: info # Optional: configure test log level + SECRET_KEY: ${{ secrets.TOKEN_SECRET_KEY }} # Make SSH agent available if tests fetch private dependencies SSH_AUTH_SOCK: ${{ env.SSH_AUTH_SOCK }} run: | diff --git a/Cargo.toml b/Cargo.toml index 6231883..701c694 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "icarus_auth" -version = "0.2.0" +version = "0.3.0" edition = "2024" rust-version = "1.86" @@ -18,7 +18,8 @@ uuid = { version = "1.16.0", features = ["v4", "serde"] } argon2 = { version = "0.5.3", features = ["std"] } # Use the latest 0.5.x version rand = { version = "0.9" } time = { version = "0.3.41", features = ["macros", "serde"] } -icarus_models = { git = "ssh://git@git.kundeng.us/phoenix/icarus_models.git", tag = "v0.4.0" } +josekit = { version = "0.10.1" } +icarus_models = { git = "ssh://git@git.kundeng.us/phoenix/icarus_models.git", tag = "v0.4.1" } [dev-dependencies] http-body-util = { version = "0.1.3" } diff --git a/src/callers/login.rs b/src/callers/login.rs new file mode 100644 index 0000000..66286ee --- /dev/null +++ b/src/callers/login.rs @@ -0,0 +1,101 @@ +pub mod request { + use serde::{Deserialize, Serialize}; + + #[derive(Default, Deserialize, Serialize)] + pub struct Request { + pub username: String, + pub password: String, + } +} + +pub mod response { + use serde::{Deserialize, Serialize}; + + #[derive(Default, Deserialize, Serialize)] + pub struct Response { + pub message: String, + pub data: Vec, + } +} + +pub mod endpoint { + use axum::{Json, http::StatusCode}; + + use crate::hashing; + use crate::repo; + use crate::token_stuff; + + use super::request; + use super::response; + + async fn not_found(message: &str) -> (StatusCode, Json) { + ( + StatusCode::NOT_FOUND, + Json(response::Response { + message: String::from(message), + data: Vec::new(), + }), + ) + } + + pub async fn login( + axum::Extension(pool): axum::Extension, + Json(payload): Json, + ) -> (StatusCode, Json) { + let usr = icarus_models::user::User { + username: payload.username, + password: payload.password, + ..Default::default() + }; + + // Check if user exists + match repo::user::exists(&pool, &usr.username).await { + Ok(exists) => { + if !exists { + return not_found("Not Found").await; + } + } + Err(err) => { + return not_found(&err.to_string()).await; + } + }; + + let user = repo::user::get(&pool, &usr.username).await.unwrap(); + let salt = repo::salt::get(&pool, &user.salt_id).await.unwrap(); + let salt_str = hashing::get_salt(&salt.salt).unwrap(); + + // Check if password is correct + match hashing::hash_password(&usr.password, &salt_str) { + Ok(hash_password) => { + if hashing::verify_password(&usr.password, hash_password.clone()).unwrap() { + // Create token + let key = token_stuff::get_key().unwrap(); + let (token_literal, duration) = token_stuff::create_token(&key).unwrap(); + + if token_stuff::verify_token(&key, &token_literal) { + ( + StatusCode::OK, + Json(response::Response { + message: String::from("Successful"), + data: vec![icarus_models::login_result::LoginResult { + id: user.id, + username: user.username, + token: token_literal, + token_type: String::from(token_stuff::TOKENTYPE), + expiration: duration, + }], + }), + ) + } else { + return not_found("Could not verify password").await; + } + } else { + return not_found("Error Hashing").await; + } + } + Err(err) => { + return not_found(&err.to_string()).await; + } + } + } +} diff --git a/src/callers/mod.rs b/src/callers/mod.rs index 33ddec1..ab9f31e 100644 --- a/src/callers/mod.rs +++ b/src/callers/mod.rs @@ -1,8 +1,10 @@ pub mod common; +pub mod login; pub mod register; pub mod endpoints { pub const ROOT: &str = "/"; pub const REGISTER: &str = "/api/v2/register"; pub const DBTEST: &str = "/api/v2/test/db"; + pub const LOGIN: &str = "/api/v2/login"; } diff --git a/src/hashing/mod.rs b/src/hashing/mod.rs index 1386d0c..a3fbbd4 100644 --- a/src/hashing/mod.rs +++ b/src/hashing/mod.rs @@ -15,6 +15,10 @@ pub fn generate_salt() -> Result { Ok(salt) } +pub fn get_salt(s: &str) -> Result { + SaltString::from_b64(s) +} + pub fn hash_password( password: &String, salt: &SaltString, diff --git a/src/lib.rs b/src/lib.rs index 1e67995..9ab8f4d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ pub mod callers; pub mod config; pub mod hashing; pub mod repo; +pub mod token_stuff; pub mod keys { pub const DBURL: &str = "DATABASE_URL"; @@ -15,7 +16,7 @@ mod connection_settings { pub const MAXCONN: u32 = 5; } -pub mod db_pool { +pub mod db { use sqlx::postgres::PgPoolOptions; use std::env; @@ -38,4 +39,13 @@ pub mod db_pool { env::var(keys::DBURL).expect(keys::error::ERROR) } + + pub async fn migrations(pool: &sqlx::PgPool) { + // Run migrations using the sqlx::migrate! macro + // Assumes your migrations are in a ./migrations folder relative to Cargo.toml + sqlx::migrate!("./migrations") + .run(pool) + .await + .expect("Failed to run migrations"); + } } diff --git a/src/main.rs b/src/main.rs index da52960..2dfe642 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,17 +14,6 @@ async fn main() { axum::serve(listener, app).await.unwrap(); } -mod db { - pub async fn migrations(pool: &sqlx::PgPool) { - // Run migrations using the sqlx::migrate! macro - // Assumes your migrations are in a ./migrations folder relative to Cargo.toml - sqlx::migrate!("./migrations") - .run(pool) - .await - .expect("Failed to run migrations on testcontainer DB"); - } -} - mod init { use axum::{ Router, @@ -32,7 +21,6 @@ mod init { }; use crate::callers; - use crate::db; pub async fn routes() -> Router { // build our application with a route @@ -43,14 +31,18 @@ mod init { callers::endpoints::REGISTER, post(callers::register::register_user), ) + .route( + callers::endpoints::LOGIN, + post(callers::login::endpoint::login), + ) } pub async fn app() -> Router { - let pool = icarus_auth::db_pool::create_pool() + let pool = icarus_auth::db::create_pool() .await .expect("Failed to create pool"); - db::migrations(&pool).await; + icarus_auth::db::migrations(&pool).await; routes().await.layer(axum::Extension(pool)) } @@ -141,6 +133,30 @@ mod tests { } } + fn get_test_register_request() -> icarus_auth::callers::register::request::Request { + icarus_auth::callers::register::request::Request { + username: String::from("somethingsss"), + password: String::from("Raindown!"), + email: String::from("dev@null.com"), + phone: String::from("1234567890"), + firstname: String::from("Bob"), + lastname: String::from("Smith"), + } + } + + fn get_test_register_payload( + usr: &icarus_auth::callers::register::request::Request, + ) -> serde_json::Value { + json!({ + "username": &usr.username, + "password": &usr.password, + "email": &usr.email, + "phone": &usr.phone, + "firstname": &usr.firstname, + "lastname": &usr.lastname, + }) + } + #[tokio::test] async fn test_hello_world() { let app = init::app().await; @@ -180,27 +196,12 @@ mod tests { let pool = db_mgr::connect_to_db(&db_name).await.unwrap(); - db::migrations(&pool).await; + icarus_auth::db::migrations(&pool).await; let app = init::routes().await.layer(axum::Extension(pool)); - let usr = icarus_auth::callers::register::request::Request { - username: String::from("somethingsss"), - password: String::from("Raindown!"), - email: String::from("dev@null.com"), - phone: String::from("1234567890"), - firstname: String::from("Bob"), - lastname: String::from("Smith"), - }; - - let payload = json!({ - "username": &usr.username, - "password": &usr.password, - "email": &usr.email, - "phone": &usr.phone, - "firstname": &usr.firstname, - "lastname": &usr.lastname, - }); + let usr = get_test_register_request(); + let payload = get_test_register_payload(&usr); let response = app .oneshot( @@ -244,4 +245,103 @@ mod tests { let _ = db_mgr::drop_database(&tm_pool, &db_name).await; } + + #[tokio::test] + async fn test_login_user() { + let tm_pool = db_mgr::get_pool().await.unwrap(); + + let db_name = db_mgr::generate_db_name().await; + + match db_mgr::create_database(&tm_pool, &db_name).await { + Ok(_) => { + println!("Success"); + } + Err(e) => { + assert!(false, "Error: {:?}", e.to_string()); + } + } + + let pool = db_mgr::connect_to_db(&db_name).await.unwrap(); + + icarus_auth::db::migrations(&pool).await; + + let app = init::routes().await.layer(axum::Extension(pool)); + + let usr = get_test_register_request(); + let payload = get_test_register_payload(&usr); + + let response = app + .clone() + .oneshot( + Request::builder() + .method(axum::http::Method::POST) + .uri(callers::endpoints::REGISTER) + .header(axum::http::header::CONTENT_TYPE, "application/json") + .body(Body::from(payload.to_string())) + .unwrap(), + ) + .await; + + match response { + Ok(resp) => { + assert_eq!( + resp.status(), + StatusCode::CREATED, + "Message: {:?} {:?}", + resp, + usr.username + ); + let body = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .unwrap(); + let parsed_body: callers::register::response::Response = + serde_json::from_slice(&body).unwrap(); + let returned_usr = &parsed_body.data[0]; + + assert_eq!(false, returned_usr.id.is_nil(), "Id is not populated"); + + assert_eq!( + usr.username, returned_usr.username, + "Usernames do not match" + ); + assert!(returned_usr.date_created.is_some(), "Date Created is empty"); + + let login_payload = json!({ + "username": &usr.username, + "password": &usr.password, + }); + + match app + .oneshot( + Request::builder() + .method(axum::http::Method::POST) + .uri(callers::endpoints::LOGIN) + .header(axum::http::header::CONTENT_TYPE, "application/json") + .body(Body::from(login_payload.to_string())) + .unwrap(), + ) + .await + { + Ok(resp) => { + assert_eq!(StatusCode::OK, resp.status(), "Status is not right"); + let body = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .unwrap(); + let parsed_body: callers::login::response::Response = + serde_json::from_slice(&body).unwrap(); + let login_result = &parsed_body.data[0]; + assert!(!login_result.id.is_nil(), "Id is nil"); + } + Err(err) => { + assert!(false, "Error: {:?}", err.to_string()); + } + } + } + Err(err) => { + assert!(false, "Error: {:?}", err.to_string()); + } + }; + + let _ = db_mgr::drop_database(&tm_pool, &db_name).await; + } } diff --git a/src/repo/mod.rs b/src/repo/mod.rs index 049a840..b8a8c8c 100644 --- a/src/repo/mod.rs +++ b/src/repo/mod.rs @@ -7,6 +7,41 @@ pub mod user { pub date_created: Option, } + pub async fn get( + pool: &sqlx::PgPool, + username: &String, + ) -> Result { + let result = sqlx::query( + r#" + SELECT * FROM "user" WHERE username = $1 + "#, + ) + .bind(username) + .fetch_optional(pool) + .await; + + match result { + Ok(r) => match r { + Some(r) => Ok(icarus_models::user::User { + id: r.try_get("id")?, + username: r.try_get("username")?, + password: r.try_get("password")?, + email: r.try_get("email")?, + email_verified: r.try_get("email_verified")?, + phone: r.try_get("phone")?, + salt_id: r.try_get("salt_id")?, + firstname: r.try_get("firstname")?, + lastname: r.try_get("lastname")?, + date_created: r.try_get("date_created")?, + last_login: r.try_get("last_login")?, + status: r.try_get("status")?, + }), + None => Err(sqlx::Error::RowNotFound), + }, + Err(e) => Err(e), + } + } + pub async fn exists(pool: &sqlx::PgPool, username: &String) -> Result { let result = sqlx::query( r#" @@ -72,6 +107,31 @@ pub mod salt { pub id: uuid::Uuid, } + pub async fn get( + pool: &sqlx::PgPool, + id: &uuid::Uuid, + ) -> Result { + let result = sqlx::query( + r#" + SELECT * FROM "salt" WHERE id = $1 + "#, + ) + .bind(id) + .fetch_optional(pool) + .await; + + match result { + Ok(r) => match r { + Some(r) => Ok(icarus_models::user::salt::Salt { + id: r.try_get("id")?, + salt: r.try_get("salt")?, + }), + None => Err(sqlx::Error::RowNotFound), + }, + Err(e) => Err(e), + } + } + pub async fn insert( pool: &sqlx::PgPool, salt: &icarus_models::user::salt::Salt, diff --git a/src/token_stuff/mod.rs b/src/token_stuff/mod.rs new file mode 100644 index 0000000..2771dec --- /dev/null +++ b/src/token_stuff/mod.rs @@ -0,0 +1,87 @@ +use josekit::{ + self, + jws::{JwsHeader, alg::hmac::HmacJwsAlgorithm::Hs256}, + jwt::{self, JwtPayload}, +}; + +use time; + +pub const TOKENTYPE: &str = "JWT"; +pub const KEY_ENV: &str = "SECRET_KEY"; +pub const MESSAGE: &str = "Something random"; +pub const ISSUER: &str = "icarus_auth"; +pub const AUDIENCE: &str = "icarus"; + +pub fn get_key() -> Result { + dotenvy::dotenv().ok(); + let key = std::env::var(KEY_ENV).expect("SECRET_KEY_NOT_FOUND"); + Ok(key) +} + +pub fn get_expiration() -> time::Result { + let now = time::OffsetDateTime::now_utc(); + let epoch = time::OffsetDateTime::UNIX_EPOCH; + let since_the_epoch = now - epoch; + Ok(since_the_epoch) +} + +pub fn create_token(provided_key: &String) -> Result<(String, i64), josekit::JoseError> { + let mut header = JwsHeader::new(); + header.set_token_type(TOKENTYPE); + + let mut payload = JwtPayload::new(); + payload.set_subject(MESSAGE); + payload.set_issuer(ISSUER); + payload.set_audience(vec![AUDIENCE]); + match get_expiration() { + Ok(duration) => { + let expire = duration.whole_seconds(); + let _ = payload.set_claim( + "expiration", + Some(serde_json::to_value(expire.to_string()).unwrap()), + ); + + let key: String = if provided_key.is_empty() { + get_key().unwrap() + } else { + provided_key.to_owned() + }; + + let signer = Hs256.signer_from_bytes(key.as_bytes()).unwrap(); + Ok(( + josekit::jwt::encode_with_signer(&payload, &header, &signer).unwrap(), + duration.whole_seconds(), + )) + } + Err(e) => Err(josekit::JoseError::InvalidClaim(e.into())), + } +} + +pub fn verify_token(key: &String, token: &String) -> bool { + let ver = Hs256.verifier_from_bytes(key.as_bytes()).unwrap(); + let (payload, _header) = jwt::decode_with_verifier(token, &ver).unwrap(); + match payload.subject() { + Some(_sub) => true, + None => false, + } +} + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn test_tokenize() { + let special_key = get_key().unwrap(); + match create_token(&special_key) { + Ok((token, _duration)) => { + let result = verify_token(&special_key, &token); + assert!(result, "Token not verified"); + } + Err(err) => { + assert!(false, "Error: {:?}", err.to_string()); + } + }; + } +}