#![doc = include_str!("../README.md")]
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![allow(dead_code)]
#![deny(missing_docs)]
use std::borrow::Cow;
use std::future::Future;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
#[cfg(not(feature = "structured"))]
mod unescape;
#[cfg(feature = "integration-fs")]
pub mod integrations;
use base64::Engine;
use futures::FutureExt;
#[cfg(any(feature = "ecdsa", feature = "hmac"))]
use rand::Rng;
#[cfg(any(feature = "ecdsa", feature = "rsa", feature = "hmac"))]
use sha2::Digest;
#[cfg(feature = "structured")]
use serde::{de::DeserializeOwned, Serialize};
#[cfg(feature = "chacha20")]
use chacha20::cipher::{KeyIvInit, StreamCipher};
#[cfg(feature = "hmac")]
use hmac::{Hmac, Mac};
#[cfg(feature = "ecdsa")]
use p256::ecdsa::signature::{Signer, Verifier};
#[cfg(feature = "rsa")]
use rsa::RsaPublicKey;
#[cfg(feature = "chacha20")]
pub use chacha20;
#[cfg(feature = "hmac")]
pub use hmac;
#[cfg(feature = "ecdsa")]
pub use p256;
#[cfg(feature = "rsa")]
pub use rsa;
// Every token is signed and verified with one of these algorithms, so building the crate
// without any of them is a configuration error.
#[cfg(not(any(feature = "ecdsa", feature = "rsa", feature = "hmac")))]
compile_error!("At least one algorithm has to be enabled.");
#[cfg(not(feature = "structured"))]
/// Stand-in for `serde::Serialize` used when the `structured` feature is disabled.
///
/// It has no methods and is implemented for every type, so the `T: Serialize` bounds used
/// throughout the crate are always satisfied without depending on `serde`.
pub trait Serialize {}
// Blanket impl: with `structured` off, every type satisfies the `Serialize` bound.
#[cfg(not(feature = "structured"))]
impl<T> Serialize for T {}
#[cfg(not(feature = "structured"))]
/// Stand-in for `serde::de::DeserializeOwned` used when the `structured` feature is disabled.
///
/// It has no methods and is implemented for every type, so the `T: DeserializeOwned` bounds
/// used throughout the crate are always satisfied without depending on `serde`.
pub trait DeserializeOwned {}
// Blanket impl: with `structured` off, every type satisfies the `DeserializeOwned` bound.
#[cfg(not(feature = "structured"))]
impl<T> DeserializeOwned for T {}
/// Base64 engine used for every JWT segment (header, payload and signature).
///
/// Uses the URL-safe alphabet and writes no `=` padding when encoding. When decoding it accepts
/// input with or without padding (`DecodePaddingMode::Indifferent`).
const BASE64_ENGINE: base64::engine::general_purpose::GeneralPurpose =
base64::engine::general_purpose::GeneralPurpose::new(
&base64::alphabet::URL_SAFE,
base64::engine::general_purpose::GeneralPurposeConfig::new()
.with_encode_padding(false)
.with_decode_padding_mode(base64::engine::DecodePaddingMode::Indifferent),
);
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// If the system clock is set to a time before 1970-01-01.
fn seconds_since_epoch() -> u64 {
    let elapsed = UNIX_EPOCH.elapsed().unwrap();
    elapsed.as_secs()
}
/// Finds the cookie `name` in the request's `cookie` header(s).
///
/// Returns the full text of the header containing it, together with the byte offset at which
/// the cookie's value starts (just past `name=`). Pass the result to `extract_cookie_value` to
/// obtain the value itself.
fn get_cookie<'a, T>(req: &'a kvarn::prelude::Request<T>, name: &str) -> Option<(&'a str, usize)> {
    let (header, value_start, _header_index) = get_cookie_with_header_pos(req, name)?;
    Some((header, value_start))
}
/// Locates the cookie `name` in the request's `cookie` header(s).
///
/// Returns `(header, value_start, header_index)`:
/// - `header` is the full text of the first `cookie` header containing the cookie;
/// - `value_start` is the byte offset in `header` just past `name=`;
/// - `header_index` is that header's position among *all* `cookie` headers. Headers that
///   cannot be read as text are skipped but still counted, so the index can be used with
///   `HeaderMap` entry iteration.
///
/// `name=` only matches at the start of a cookie pair, i.e. at the start of the header or
/// after a `;` (optionally followed by whitespace). Looking up `id` therefore doesn't match
/// `session_id=…`, nor a cookie value that happens to contain `id=`.
fn get_cookie_with_header_pos<'a, T>(
    req: &'a kvarn::prelude::Request<T>,
    name: &str,
) -> Option<(&'a str, usize, usize)> {
    let filter = format!("{name}=");
    req.headers()
        .get_all("cookie")
        .into_iter()
        .enumerate()
        .filter_map(|(p, h)| h.to_str().ok().map(|h| (p, h)))
        .find_map(|(header_pos, header)| {
            find_cookie_name_start(header, &filter)
                .map(|pos| (header, pos + filter.len(), header_pos))
        })
}

/// Byte offset of the first occurrence of `needle` (`"name="`) in `header` that begins a
/// cookie pair: either at the very start of the header or preceded by `;` (plus optional
/// whitespace).
fn find_cookie_name_start(header: &str, needle: &str) -> Option<usize> {
    header
        .match_indices(needle)
        .map(|(idx, _)| idx)
        .find(|&idx| {
            let before = header[..idx].trim_end();
            before.is_empty() || before.ends_with(';')
        })
}
/// Extracts a cookie's value from a `(header, value_start)` pair as returned by `get_cookie`.
///
/// The value runs from `value_start` up to, but not including, the next `;`, or to the end of
/// the header if no `;` follows.
fn extract_cookie_value(d: (&str, usize)) -> &str {
    let (header, value_start) = d;
    let tail = &header[value_start..];
    match tail.find(';') {
        Some(end) => &tail[..end],
        None => tail,
    }
}
/// Removes the cookie `cookie_name` from the request's `cookie` header.
///
/// Only the first occurrence found by `get_cookie_with_header_pos` is removed. The separator
/// joining it to its neighbours goes with it, so `a=1; name=v; b=2` becomes `a=1; b=2` and
/// `a=1; name=v` becomes `a=1`. Pairs separated by a bare `;` (no space) are handled too.
///
/// Returns `true` if a cookie was removed and `false` if none by that name was present.
fn remove_cookie(req: &mut kvarn::FatRequest, cookie_name: &str) -> bool {
    use kvarn::prelude::*;
    let (cookie, pos, header_pos) = match get_cookie_with_header_pos(req, cookie_name) {
        Some(found) => found,
        None => return false,
    };
    // `pos` points just past `name=`; step back over the `=` and the name.
    let name_start = pos - cookie_name.len() - 1;
    let new_cookie_header = match cookie[name_start..].find(';') {
        // Not the last pair: drop `name=value;` and the whitespace before the next pair.
        Some(sep) => {
            let rest = cookie[name_start + sep + 1..].trim_start();
            format!("{}{}", &cookie[..name_start], rest)
        }
        // Last pair: drop it together with the `;` joining it to the previous pair.
        None => {
            let before = cookie[..name_start].trim_end();
            before.strip_suffix(';').unwrap_or(before).to_owned()
        }
    };
    match req.headers_mut().entry("cookie") {
        header::Entry::Occupied(mut entry) => {
            let target = entry
                .iter_mut()
                .nth(header_pos)
                .expect("the header index came from iterating the same `cookie` headers");
            *target = HeaderValue::from_str(&new_cookie_header)
                .expect("unreachable, as we just removed bytes");
        }
        header::Entry::Vacant(_) => unreachable!(
            "The header must be present, since we got the data from it in the previous call"
        ),
    }
    true
}
/// Appends a `set-cookie` header to `response` that tells the client to delete the cookie
/// `cookie_name` at `cookie_path`: an empty (`""`) value with `Max-Age=0`.
///
/// # Panics
///
/// If `cookie_name` or `cookie_path` contains bytes that are not allowed in a header value.
fn remove_set_cookie(
    response: &mut kvarn::prelude::Response<kvarn::prelude::Bytes>,
    cookie_name: &str,
    cookie_path: &str,
) {
    let expired = format!("{cookie_name}=\"\"; Path={cookie_path}; Max-Age=0");
    let headers = response.headers_mut();
    let value = kvarn::prelude::HeaderValue::from_str(&expired)
        .expect("a user-supplied cookie_name or the cookie_path contains illegal bytes for use in a header");
    headers.append("set-cookie", value);
}
/// Data carried in the payload of an authentication JWT.
///
/// When encoded, the variant is recorded in the payload under the `__variant` key
/// (`"e"`, `"t"`, `"n"`, `"tn"` or `"s"`) alongside the data; presumably this is what lets the
/// validator rebuild the same variant — NOTE(review): confirm against the decoding side.
#[derive(Debug)]
pub enum AuthData<T: Serialize + DeserializeOwned = ()> {
/// No data; the token only carries its issue and expiry times.
None,
/// A string, stored under the `text` key.
Text(String),
/// A number, stored under the `num` key.
/// With the `structured` feature, encoding NaN or an infinity panics.
Number(f64),
/// A string and a number, stored under `text` and `num`.
TextNumber(String, f64),
/// Arbitrary serializable data. Requires the `structured` feature; without it, encoding
/// this variant into a JWT panics.
Structured(T),
}
#[cfg(feature = "hmac")]
/// Computes the HMAC-SHA256 tag of `bytes` keyed with `secret`.
fn hmac_sha256(secret: &[u8], bytes: &[u8]) -> impl AsRef<[u8]> {
    let mut mac = Hmac::<sha2::Sha256>::new_from_slice(secret).unwrap();
    mac.update(bytes);
    let tag = mac.finalize();
    tag.into_bytes()
}
/// Appends the raw octets of `ip` to `buf`: 4 bytes for IPv4, 16 bytes for IPv6.
fn ip_to_bytes(ip: IpAddr, buf: &mut Vec<u8>) {
    match ip {
        IpAddr::V4(addr) => buf.extend_from_slice(&addr.octets()),
        IpAddr::V6(addr) => buf.extend_from_slice(&addr.octets()),
    }
}
impl<T: Serialize + DeserializeOwned> AuthData<T> {
    /// Encodes `self` as a signed JWT, building the payload with `serde_json`.
    ///
    /// The token is `base64url(header) "." base64url(payload) "." base64url(signature)`.
    /// The payload is a JSON object holding:
    /// - `__variant`: `"e"` (None), `"t"` (Text), `"n"` (Number), `"tn"` (TextNumber) or
    ///   `"s"` (Structured);
    /// - the data: `text` and/or `num`, or, for `Structured`, the fields of the serialized
    ///   object. A structured value that does not serialize to a JSON object is stored under
    ///   `v` with `__deserialize_v: true`;
    /// - `iat` (now, in seconds since the Unix epoch) and `exp` (`iat + seconds_before_expiry`,
    ///   saturating at `u64::MAX`).
    ///
    /// User keys that collide with `__variant`, `__deserialize_v`, `iat` or `exp` are
    /// overwritten, and a warning is logged.
    ///
    /// If `ip` is `Some`, its octets are appended to the signing input, so the signature only
    /// verifies together with that address.
    ///
    /// # Panics
    ///
    /// If a number is NaN or infinite, or if the structured data fails to serialize.
    #[cfg(feature = "structured")]
    fn into_jwt(
        self,
        signing_algo: &ComputedAlgo,
        header: &[u8],
        seconds_before_expiry: u64,
        ip: Option<IpAddr>,
    ) -> String {
        let mut s = BASE64_ENGINE.encode(header);
        let mut map = match self {
            Self::None => {
                let mut map = serde_json::Map::new();
                map.insert("__variant".to_owned(), "e".into());
                map
            }
            Self::Text(t) => {
                let mut map = serde_json::Map::new();
                map.insert("text".to_owned(), serde_json::Value::String(t));
                map.insert("__variant".to_owned(), "t".into());
                map
            }
            Self::Number(n) => {
                let mut map = serde_json::Map::new();
                map.insert(
                    "num".to_owned(),
                    serde_json::Value::Number(
                        serde_json::Number::from_f64(n)
                            .expect("JWTs cannot contain NaN or infinities"),
                    ),
                );
                map.insert("__variant".to_owned(), "n".into());
                map
            }
            Self::TextNumber(t, n) => {
                let mut map = serde_json::Map::new();
                map.insert("text".to_owned(), serde_json::Value::String(t));
                map.insert(
                    "num".to_owned(),
                    serde_json::Value::Number(
                        serde_json::Number::from_f64(n)
                            .expect("JWTs cannot contain NaN or infinities"),
                    ),
                );
                map.insert("__variant".to_owned(), "tn".into());
                map
            }
            Self::Structured(t) => {
                let mut v =
                    serde_json::to_value(t).expect("failed to serialize structured auth data");
                if let Some(map) = v.as_object_mut() {
                    // Objects are flattened into the payload itself.
                    let mut map = core::mem::take(map);
                    if map.contains_key("__variant") {
                        log::warn!("`__variant` key in JWT payload will be overridden");
                    }
                    if map.contains_key("__deserialize_v") {
                        log::warn!("`__deserialize_v` key in JWT payload will be overridden");
                        // Force `false`; presumably so the decoder reads the object itself
                        // rather than a wrapped `v` value.
                        map.insert("__deserialize_v".to_owned(), serde_json::Value::Bool(false));
                    }
                    map.insert("__variant".to_owned(), "s".into());
                    map
                } else {
                    // Non-object values (numbers, arrays, ...) are wrapped under `v`.
                    let mut map = serde_json::Map::new();
                    map.insert("v".to_owned(), v);
                    map.insert("__deserialize_v".to_owned(), serde_json::Value::Bool(true));
                    map.insert("__variant".to_owned(), "s".into());
                    map
                }
            }
        };
        if map.contains_key("iat") {
            log::warn!("`iat` key in JWT payload will be overridden");
        }
        if map.contains_key("exp") {
            log::warn!("`exp` key in JWT payload will be overridden");
        }
        let now = seconds_since_epoch();
        map.insert("iat".to_owned(), serde_json::Value::Number(now.into()));
        map.insert(
            "exp".to_owned(),
            // Saturate so a "practically never expires" value such as `u64::MAX` can't
            // overflow (a panic in debug builds, an already-expired token in release builds).
            serde_json::Value::Number(now.saturating_add(seconds_before_expiry).into()),
        );
        let payload = serde_json::Value::Object(map).to_string();
        s.push('.');
        BASE64_ENGINE.encode_string(payload.as_bytes(), &mut s);
        append_jwt_signature(&mut s, signing_algo, ip);
        s
    }
    /// Encodes `self` as a signed JWT, building the JSON payload by hand (no `serde`).
    ///
    /// The payload has the same keys as with the `structured` feature:
    /// `__variant`, `text` and/or `num`, `iat` and `exp` (saturating at `u64::MAX`).
    /// Strings are escaped with Rust's `str::escape_default`, which is *not* JSON escaping for
    /// non-ASCII characters or `'`. NOTE(review): presumably the crate's `unescape` module
    /// reverses this on validation; third-party JSON parsers may reject such payloads.
    ///
    /// If `ip` is `Some`, its octets are appended to the signing input.
    ///
    /// # Panics
    ///
    /// On `AuthData::Structured`, which needs the `structured` feature.
    #[cfg(not(feature = "structured"))]
    fn into_jwt(
        self,
        signing_algo: &ComputedAlgo,
        header: &[u8],
        seconds_before_expiry: u64,
        ip: Option<IpAddr>,
    ) -> String {
        let mut s = BASE64_ENGINE.encode(header);
        let mut json = String::new();
        json.push_str(r#"{"__variant":"#);
        match self {
            Self::None => json.push_str(r#""e","#),
            Self::Text(t) => {
                json.push_str(r#""t","text":""#);
                json.extend(t.escape_default());
                json.push_str("\",");
            }
            Self::Number(n) => {
                json.push_str(r#""n","num":"#);
                json.push_str(&n.to_string());
                json.push(',');
            }
            Self::TextNumber(t, n) => {
                json.push_str(r#""tn","text":""#);
                json.extend(t.escape_default());
                json.push_str("\",");
                json.push_str(r#""num":"#);
                json.push_str(&n.to_string());
                json.push(',');
            }
            Self::Structured(_t) => {
                panic!("Using AuthData::Structured without the serde feature enabled")
            }
        }
        let now = seconds_since_epoch();
        json.push_str(r#""iat":"#);
        json.push_str(&now.to_string());
        json.push(',');
        json.push_str(r#""exp":"#);
        // Saturating: see the `structured` variant of this method.
        json.push_str(&now.saturating_add(seconds_before_expiry).to_string());
        json.push('}');
        s.push('.');
        BASE64_ENGINE.encode_string(json.as_bytes(), &mut s);
        append_jwt_signature(&mut s, signing_algo, ip);
        s
    }
    /// Like `into_jwt`, with a minimal header whose `alg` matches `signing_algo`:
    /// `{"alg":"HS256"}`, `{"alg":"RS256"}` or `{"alg":"ES256"}`.
    fn into_jwt_with_default_header(
        self,
        signing_algo: &ComputedAlgo,
        seconds_before_expiry: u64,
        ip: Option<IpAddr>,
    ) -> String {
        static HS_HEADER: &[u8] = r#"{"alg":"HS256"}"#.as_bytes();
        static RS_HEADER: &[u8] = r#"{"alg":"RS256"}"#.as_bytes();
        static EP_HEADER: &[u8] = r#"{"alg":"ES256"}"#.as_bytes();
        let header = match signing_algo {
            #[cfg(feature = "hmac")]
            ComputedAlgo::HmacSha256 { .. } => HS_HEADER,
            #[cfg(feature = "rsa")]
            ComputedAlgo::RSASha256 { .. } => RS_HEADER,
            #[cfg(feature = "ecdsa")]
            ComputedAlgo::EcdsaP256 { .. } => EP_HEADER,
        };
        self.into_jwt(signing_algo, header, seconds_before_expiry, ip)
    }
}

/// Signs `jwt` (the already-encoded `header.payload`) and appends `.` plus the base64url
/// signature.
///
/// The signed bytes are `jwt` followed by the octets of `ip`, when given:
/// - HMAC-SHA256 is keyed with the shared secret;
/// - RSA signs the SHA-256 digest with PKCS#1 v1.5;
/// - ECDSA P-256 signs and emits a DER-encoded signature. NOTE(review): RFC 7518 specifies
///   raw `R || S` for ES256, so other JWT libraries will likely reject these tokens. Validation
///   in this crate expects DER, so changing this requires changing both sides.
fn append_jwt_signature(jwt: &mut String, signing_algo: &ComputedAlgo, ip: Option<IpAddr>) {
    match signing_algo {
        #[cfg(feature = "hmac")]
        ComputedAlgo::HmacSha256 { secret, .. } => {
            let mut hmac = Hmac::<sha2::Sha256>::new_from_slice(secret).unwrap();
            hmac.update(jwt.as_bytes());
            if let Some(ip) = ip {
                hmac.update(IpBytes::from(ip).as_ref());
            }
            let sig = hmac.finalize().into_bytes();
            jwt.push('.');
            BASE64_ENGINE.encode_string(sig, jwt);
        }
        #[cfg(feature = "rsa")]
        ComputedAlgo::RSASha256 { private_key, .. } => {
            let mut hasher = sha2::Sha256::new();
            hasher.update(jwt.as_bytes());
            if let Some(ip) = ip {
                hasher.update(IpBytes::from(ip).as_ref());
            }
            let hash = hasher.finalize();
            let signature = private_key
                .sign(rsa::Pkcs1v15Sign::new::<sha2::Sha256>(), &hash)
                .expect("failed to sign JWT with RSA key");
            jwt.push('.');
            BASE64_ENGINE.encode_string(signature, jwt);
        }
        #[cfg(feature = "ecdsa")]
        ComputedAlgo::EcdsaP256 { private_key, .. } => {
            let signature: p256::ecdsa::Signature = match ip {
                Some(ip) => {
                    let mut message = Vec::with_capacity(jwt.len() + 16);
                    message.extend_from_slice(jwt.as_bytes());
                    message.extend_from_slice(IpBytes::from(ip).as_ref());
                    private_key.sign(&message)
                }
                None => private_key.sign(jwt.as_bytes()),
            };
            jwt.push('.');
            BASE64_ENGINE.encode_string(signature.to_der(), jwt);
        }
    }
}
/// Outcome of checking a request's authentication.
#[derive(Debug)]
pub enum Validation<T: Serialize + DeserializeOwned> {
/// The request could not be authorized.
Unauthorized,
/// The request is authorized; holds the data that was stored in its token.
Authorized(AuthData<T>),
}
/// The raw octets of an IP address, stored inline without heap allocation.
///
/// Used to mix the client's address into signatures.
enum IpBytes {
    V4([u8; 4]),
    V6([u8; 16]),
}
impl From<IpAddr> for IpBytes {
    /// Copies out the address octets (network byte order).
    fn from(addr: IpAddr) -> Self {
        match addr {
            IpAddr::V4(v4) => IpBytes::V4(v4.octets()),
            IpAddr::V6(v6) => IpBytes::V6(v6.octets()),
        }
    }
}
impl AsRef<[u8]> for IpBytes {
    /// Views the octets as a slice: 4 bytes for IPv4, 16 for IPv6.
    fn as_ref(&self) -> &[u8] {
        match *self {
            IpBytes::V4(ref octets) => &octets[..],
            IpBytes::V6(ref octets) => &octets[..],
        }
    }
}
/// Verification of a JWT signature.
trait Validate {
/// Checks that `signature` is valid for `data`.
///
/// To match how `into_jwt` signs, when `ip` is `Some` the signed input is `data` followed
/// by the octets of `ip`. Returns `Ok(())` if the signature verifies and `Err(())` otherwise.
fn validate(&self, data: &[u8], signature: &[u8], ip: Option<IpAddr>) -> Result<(), ()>;
}
#[cfg(any(feature = "rsa", feature = "ecdsa"))]
impl Validate for ValidationAlgo {
    /// Forwards to the implementation for `&ValidationAlgo`.
    fn validate(&self, data: &[u8], signature: &[u8], ip: Option<IpAddr>) -> Result<(), ()> {
        <&Self as Validate>::validate(&self, data, signature, ip)
    }
}
#[cfg(any(feature = "rsa", feature = "ecdsa"))]
impl<'a> Validate for &'a ValidationAlgo {
    /// Verifies `signature` over `data`, followed by the octets of `ip` if given, using only
    /// the public key.
    ///
    /// The IP bytes are included for every algorithm, mirroring how tokens are signed.
    /// Without them, IP-bound ES256 tokens would never validate in validate-only mode.
    fn validate(&self, data: &[u8], signature: &[u8], ip: Option<IpAddr>) -> Result<(), ()> {
        match *self {
            #[cfg(feature = "rsa")]
            ValidationAlgo::RSASha256 { public_key } => {
                let mut hasher = sha2::Sha256::new();
                hasher.update(data);
                if let Some(ip) = ip {
                    hasher.update(IpBytes::from(ip).as_ref());
                }
                let hash = hasher.finalize();
                public_key
                    .verify(rsa::Pkcs1v15Sign::new::<sha2::Sha256>(), &hash, signature)
                    .map_err(|_| ())
            }
            #[cfg(feature = "ecdsa")]
            ValidationAlgo::EcdsaP256 { public_key } => {
                let sig = p256::ecdsa::Signature::from_der(signature).map_err(|_| ())?;
                if let Some(ip) = ip {
                    // Signing appends the IP octets to the message, so verification must too.
                    let mut buf = Vec::with_capacity(data.len() + 16);
                    buf.extend_from_slice(data);
                    buf.extend_from_slice(IpBytes::from(ip).as_ref());
                    public_key.verify(&buf, &sig).map_err(|_| ())
                } else {
                    public_key.verify(data, &sig).map_err(|_| ())
                }
            }
        }
    }
}
impl Validate for ComputedAlgo {
    /// Forwards to the implementation for `&ComputedAlgo`.
    fn validate(&self, data: &[u8], signature: &[u8], ip: Option<IpAddr>) -> Result<(), ()> {
        <&Self as Validate>::validate(&self, data, signature, ip)
    }
}
impl<'a> Validate for &'a ComputedAlgo {
    /// Verifies `signature` over `data`, followed by the octets of `ip` if given.
    ///
    /// RSA and ECDSA verify with the public key. HMAC recomputes the tag with the shared
    /// secret and compares it in constant time.
    fn validate(&self, data: &[u8], signature: &[u8], ip: Option<IpAddr>) -> Result<(), ()> {
        match *self {
            #[cfg(feature = "rsa")]
            ComputedAlgo::RSASha256 { public_key, .. } => {
                let mut hasher = sha2::Sha256::new();
                hasher.update(data);
                if let Some(ip) = ip {
                    hasher.update(IpBytes::from(ip).as_ref());
                }
                let hash = hasher.finalize();
                public_key
                    .verify(rsa::Pkcs1v15Sign::new::<sha2::Sha256>(), &hash, signature)
                    .map_err(|_| ())
            }
            #[cfg(feature = "hmac")]
            ComputedAlgo::HmacSha256 { secret, .. } => {
                let mut hmac = Hmac::<sha2::Sha256>::new_from_slice(secret).unwrap();
                hmac.update(data);
                if let Some(ip) = ip {
                    hmac.update(IpBytes::from(ip).as_ref());
                }
                // `verify_slice` compares in constant time. A plain `==` on the tag can leak,
                // through timing, how many leading bytes of a forged signature were correct.
                hmac.verify_slice(signature).map_err(|_| ())
            }
            #[cfg(feature = "ecdsa")]
            ComputedAlgo::EcdsaP256 { public_key, .. } => {
                let sig = p256::ecdsa::Signature::from_der(signature).map_err(|_| ())?;
                if let Some(ip) = ip {
                    let mut buf = Vec::with_capacity(data.len() + 16);
                    buf.extend_from_slice(data);
                    buf.extend_from_slice(IpBytes::from(ip).as_ref());
                    public_key.verify(&buf, &sig).map_err(|_| ())
                } else {
                    public_key.verify(data, &sig).map_err(|_| ())
                }
            }
        }
    }
}
impl Validate for Mode {
    /// Forwards to the implementation for `&Mode`.
    fn validate(&self, data: &[u8], signature: &[u8], ip: Option<IpAddr>) -> Result<(), ()> {
        <&Self as Validate>::validate(&self, data, signature, ip)
    }
}
impl<'a> Validate for &'a Mode {
    /// Dispatches to the algorithm held by the mode: the signing algorithm in `Sign` mode, or
    /// the public-key-only algorithm in `Validate` mode.
    fn validate(&self, data: &[u8], signature: &[u8], ip: Option<IpAddr>) -> Result<(), ()> {
        match self {
            Mode::Sign(algo) => algo.validate(data, signature, ip),
            #[cfg(any(feature = "rsa", feature = "ecdsa"))]
            Mode::Validate(algo) => algo.validate(data, signature, ip),
        }
    }
}
#[cfg(all(test, feature = "ecdsa"))]
impl<'a> Validate for &'a [u8] {
    /// Test helper: derives an ECDSA P-256 key from the slice via `ecdsa_sk`, then verifies the
    /// DER `signature` over `data`, followed by the octets of `ip` if given.
    fn validate(&self, data: &[u8], signature: &[u8], ip: Option<IpAddr>) -> Result<(), ()> {
        let signing_key = ecdsa_sk(self);
        let verifying_key = signing_key.verifying_key();
        let sig = p256::ecdsa::Signature::from_der(signature).map_err(|_| ())?;
        let message: Cow<'_, [u8]> = match ip {
            None => Cow::Borrowed(data),
            Some(ip) => {
                let mut buf = Vec::with_capacity(data.len() + 16);
                buf.extend_from_slice(data);
                buf.extend_from_slice(IpBytes::from(ip).as_ref());
                Cow::Owned(buf)
            }
        };
        verifying_key.verify(&*message, &sig).map_err(|_| ())
    }
}
#[cfg(all(test, feature = "ecdsa"))]
impl<'a, const LEN: usize> Validate for &'a [u8; LEN] {
    /// Test helper: forwards to the `&[u8]` implementation.
    fn validate(&self, data: &[u8], signature: &[u8], ip: Option<IpAddr>) -> Result<(), ()> {
        let key_material: &[u8] = &self[..];
        key_material.validate(data, signature, ip)
    }
}
/// Unwraps an `Option`, or makes the enclosing function return `Self::Unauthorized` when it
/// is `None`.
macro_rules! or_unauthorized {
    ($v: expr) => {
        match $v {
            ::core::option::Option::Some(v) => v,
            ::core::option::Option::None => return Self::Unauthorized,
        }
    };
}
#[cfg(feature = "structured")]
/// Checks a JWT's signature and validity window, and returns its payload.
///
/// `s` must look like `header.payload.signature`, each part base64url-encoded. The decoded
/// signature is checked by `validate` against the bytes of `header.payload`, with `ip` passed
/// along. The payload must then be a JSON object with non-negative integer `exp` and `iat`
/// fields. The token is rejected if it has expired (`exp` before now) or claims to be issued
/// in the future (`iat` after now).
///
/// Returns the whole payload as a JSON value, or `None` if any step fails.
fn validate(s: &str, validator: impl Validate, ip: Option<IpAddr>) -> Option<serde_json::Value> {
    let mut segments = s.splitn(3, '.');
    let header = segments.next()?;
    let body = segments.next()?;
    let signature_b64 = segments.next()?;

    // The signature covers `header.payload` exactly as it appears in the token.
    let signed_part = &s[..header.len() + 1 + body.len()];
    let signature = BASE64_ENGINE.decode(signature_b64).ok()?;
    validator
        .validate(signed_part.as_bytes(), &signature, ip)
        .ok()?;

    let decoded = BASE64_ENGINE.decode(body).ok()?;
    let json = String::from_utf8(decoded).ok()?;
    let value: serde_json::Value = json.parse().ok()?;
    let claims = value.as_object()?;
    let exp = claims.get("exp")?.as_u64()?;
    let iat = claims.get("iat")?.as_u64()?;
    let now = seconds_since_epoch();
    if now > exp || now < iat {
        return None;
    }
    Some(value)
}
#[cfg(not(feature = "structured"))]
fn validate(s: &str, validate: impl Validate, ip: Option<IpAddr>) -> Option<JwtData> {
let parts = s.splitn(3, '.').collect::<Vec<_>>();
if parts.len() != 3 {
return None;
}
let signature_input = &s[..parts[0].len() + 1 + parts[1].len()];
let remote_signature = BASE64_ENGINE.decode(parts[2]).ok()?;
if validate
.validate(signature_input.as_bytes(), &remote_signature, ip)
.is_err()
{
return None;
}
let payload = BASE64_ENGINE
.decode(parts[1])
.ok()
.and_then(|p| String::from_utf8(p).ok())?;
let mut entries = payload.strip_prefix('{')?.strip_suffix('}')?.trim();
let mut data = JwtData::default();
let mut last_missed_comma = false;
loop {
entries = if let Some(s) = entries.strip_prefix(',') {