Skip to main content

Module 0x2::authenticator_state

use 0x1::option;
use 0x1::string;
use 0x1::u64;
use 0x2::dynamic_field;
use 0x2::object;
use 0x2::transfer;
use 0x2::tx_context;

Resource AuthenticatorState

Singleton shared object which stores the global authenticator state. The actual state is stored in a dynamic field of type AuthenticatorStateInner to support future versions of the authenticator state.

struct AuthenticatorState has key
Fields
id: object::UID
version: u64

Struct AuthenticatorStateInner

struct AuthenticatorStateInner has store
Fields
version: u64
active_jwks: vector<authenticator_state::ActiveJwk>
List of currently active JWKs.

Struct JWK

Must match the JWK struct in fastcrypto-zkp

struct JWK has copy, drop, store
Fields

Struct JwkId

Must match the JwkId struct in fastcrypto-zkp

struct JwkId has copy, drop, store
Fields

Struct ActiveJwk

struct ActiveJwk has copy, drop, store
Fields

Constants

Sender is not @0x0 the system address.

const ENotSystemAddress: u64 = 0;

// Version of the inner state: written by `create` and required by `load_inner` / `load_inner_mut`.
const CurrentVersion: u64 = 1;

// Abort code used by `check_sorted` when two adjacent JWKs are not in strictly increasing order.
const EJwksNotSorted: u64 = 2;

// Abort code used when the stored version differs from `CurrentVersion`, or when the inner
// object's own version differs from the version recorded on the outer object.
const EWrongInnerVersion: u64 = 1;

Function active_jwk_equal

fun active_jwk_equal(a: &authenticator_state::ActiveJwk, b: &authenticator_state::ActiveJwk): bool
Implementation
/// Returns true when `a` and `b` hold the same key material and the same key id.
/// The `epoch` field does not take part in the comparison.
fun active_jwk_equal(a: &ActiveJwk, b: &ActiveJwk): bool {
    // Different key material: no need to look at the id.
    if (!jwk_equal(&a.jwk, &b.jwk)) return false;
    jwk_id_equal(&a.jwk_id, &b.jwk_id)
}

Function jwk_equal

fun jwk_equal(a: &authenticator_state::JWK, b: &authenticator_state::JWK): bool
Implementation
/// Field-by-field equality of two JWKs: `kty`, `e`, `n` and `alg` must all match.
fun jwk_equal(a: &JWK, b: &JWK): bool {
    if (&a.kty != &b.kty) return false;
    if (&a.e != &b.e) return false;
    if (&a.n != &b.n) return false;
    &a.alg == &b.alg
}

Function jwk_id_equal

fun jwk_id_equal(a: &authenticator_state::JwkId, b: &authenticator_state::JwkId): bool
Implementation
/// Two JWK ids are equal when both the issuer (`iss`) and the key id (`kid`) match.
fun jwk_id_equal(a: &JwkId, b: &JwkId): bool {
    if (&a.iss != &b.iss) return false;
    &a.kid == &b.kid
}

Function string_bytes_lt

fun string_bytes_lt(a: &string::String, b: &string::String): bool
Implementation
/// Strict "less than" over the raw bytes of two strings.
/// A shorter string always sorts before a longer one; strings of equal length are
/// ordered by their first differing byte. Equal strings yield `false`.
fun string_bytes_lt(a: &String, b: &String): bool {
    let lhs = a.as_bytes();
    let rhs = b.as_bytes();
    let lhs_len = lhs.length();
    let rhs_len = rhs.length();

    // Lengths differ: the length alone decides the order.
    if (lhs_len != rhs_len) return lhs_len < rhs_len;

    // Same length: the first position where the bytes differ decides.
    let mut pos = 0;
    while (pos < lhs_len) {
        let l = lhs[pos];
        let r = rhs[pos];
        if (l != r) return l < r;
        pos = pos + 1;
    };

    // Every byte matched, so `a` is not strictly less than `b`.
    false
}

Function jwk_lt

fun jwk_lt(a: &authenticator_state::ActiveJwk, b: &authenticator_state::ActiveJwk): bool
Implementation
/// Strict ordering over active JWKs. Fields are compared in the order
/// `iss`, `kid`, `kty`, `e`, `n`, `alg`; the first field that differs decides the
/// result via `string_bytes_lt`. The `epoch` field is not considered.
fun jwk_lt(a: &ActiveJwk, b: &ActiveJwk): bool {
    let (a_id, b_id) = (&a.jwk_id, &b.jwk_id);
    let (a_key, b_key) = (&a.jwk, &b.jwk);

    if (&a_id.iss != &b_id.iss) {
        string_bytes_lt(&a_id.iss, &b_id.iss)
    } else if (&a_id.kid != &b_id.kid) {
        string_bytes_lt(&a_id.kid, &b_id.kid)
    } else if (&a_key.kty != &b_key.kty) {
        string_bytes_lt(&a_key.kty, &b_key.kty)
    } else if (&a_key.e != &b_key.e) {
        string_bytes_lt(&a_key.e, &b_key.e)
    } else if (&a_key.n != &b_key.n) {
        string_bytes_lt(&a_key.n, &b_key.n)
    } else {
        // Everything before `alg` is equal, so `alg` is the final tie-breaker.
        string_bytes_lt(&a_key.alg, &b_key.alg)
    }
}

Function create

Create and share the AuthenticatorState object. This function is called exactly once, when the authenticator state object is first created. Can only be called by genesis or change_epoch transactions.

fun create(ctx: &tx_context::TxContext)
Implementation
/// Builds the `AuthenticatorState` object, attaches an empty inner state as a dynamic
/// field keyed by `CurrentVersion`, and shares the object.
/// Aborts with `ENotSystemAddress` unless the sender is `@0x0`.
fun create(ctx: &TxContext) {
    assert!(ctx.sender() == @0x0, ENotSystemAddress);

    let mut state = AuthenticatorState {
        id: object::authenticator_state(),
        version: CurrentVersion,
    };

    // The inner state starts with no active JWKs and lives under the version key.
    dynamic_field::add(
        &mut state.id,
        CurrentVersion,
        AuthenticatorStateInner {
            version: CurrentVersion,
            active_jwks: vector[],
        },
    );

    transfer::share_object(state);
}

Function load_inner_mut

fun load_inner_mut(self: &mut authenticator_state::AuthenticatorState): &mut authenticator_state::AuthenticatorStateInner
Implementation
/// Mutably borrows the inner state stored as a dynamic field keyed by the object's version.
/// Aborts with `EWrongInnerVersion` if the object's version is not `CurrentVersion`, or if
/// the inner state's own version disagrees with it.
fun load_inner_mut(self: &mut AuthenticatorState): &mut AuthenticatorStateInner {
    let expected = self.version;

    // replace this with a lazy update function when we add a new version of the inner object.
    assert!(expected == CurrentVersion, EWrongInnerVersion);

    let inner: &mut AuthenticatorStateInner = dynamic_field::borrow_mut(&mut self.id, expected);
    assert!(inner.version == expected, EWrongInnerVersion);

    inner
}

Function load_inner

fun load_inner(self: &authenticator_state::AuthenticatorState): &authenticator_state::AuthenticatorStateInner
Implementation
/// Immutably borrows the inner state stored as a dynamic field keyed by the object's version.
/// Aborts with `EWrongInnerVersion` if the object's version is not `CurrentVersion`, or if
/// the inner state's own version disagrees with it.
fun load_inner(self: &AuthenticatorState): &AuthenticatorStateInner {
    let expected = self.version;

    // replace this with a lazy update function when we add a new version of the inner object.
    assert!(expected == CurrentVersion, EWrongInnerVersion);

    let inner: &AuthenticatorStateInner = dynamic_field::borrow(&self.id, expected);
    assert!(inner.version == expected, EWrongInnerVersion);

    inner
}

Function check_sorted

fun check_sorted(new_active_jwks: &vector<authenticator_state::ActiveJwk>)
Implementation
/// Asserts that `new_active_jwks` is in strictly increasing order according to `jwk_lt`.
/// Aborts with `EJwksNotSorted` at the first adjacent pair that is out of order (or equal).
/// Vectors with zero or one element are trivially sorted and pass.
fun check_sorted(new_active_jwks: &vector<ActiveJwk>) {
    let len = new_active_jwks.length();
    let mut i = 0;
    // Compare with `i + 1 < len` instead of `i < len - 1`: on an empty vector the
    // subtraction would underflow, and u64 underflow aborts the transaction.
    while (i + 1 < len) {
        let a = &new_active_jwks[i];
        let b = &new_active_jwks[i + 1];
        assert!(jwk_lt(a, b), EJwksNotSorted);
        i = i + 1;
    };
}

Function update_authenticator_state

Record a new set of active_jwks. Called when executing the AuthenticatorStateUpdate system transaction. The new input vector must be sorted and must not contain duplicates. If a new JWK is already present, but with a previous epoch, then the epoch is updated to indicate that the JWK has been validated in the current epoch and should not be expired.

fun update_authenticator_state(self: &mut authenticator_state::AuthenticatorState, new_active_jwks: vector<authenticator_state::ActiveJwk>, ctx: &tx_context::TxContext)
Implementation
/// Merges `new_active_jwks` into the stored list of active JWKs.
///
/// Steps, in order:
/// 1. Abort with `ENotSystemAddress` unless the sender is `@0x0`.
/// 2. `check_sorted` aborts with `EJwksNotSorted` unless the input is strictly increasing
///    under `jwk_lt`.
/// 3. `deduplicate` drops input entries whose `jwk_id` repeats the previously kept entry.
/// 4. A two-pointer merge of the stored list and the input builds the new list, which then
///    replaces `inner.active_jwks`.
fun update_authenticator_state(
    self: &mut AuthenticatorState,
    new_active_jwks: vector<ActiveJwk>,
    ctx: &TxContext,
) {
    // Validator will make a special system call with sender set as 0x0.
    assert!(ctx.sender() == @0x0, ENotSystemAddress);

    check_sorted(&new_active_jwks);
    let new_active_jwks = deduplicate(new_active_jwks);

    let inner = self.load_inner_mut();

    // `i` walks the stored list, `j` walks the new list; `res` collects the merged output.
    let mut res = vector[];
    let mut i = 0;
    let mut j = 0;
    let active_jwks_len = inner.active_jwks.length();
    let new_active_jwks_len = new_active_jwks.length();

    while (i < active_jwks_len && j < new_active_jwks_len) {
        let old_jwk = &inner.active_jwks[i];
        let new_jwk = &new_active_jwks[j];

        // when they are equal, push only one, but use the max epoch of the two
        if (active_jwk_equal(old_jwk, new_jwk)) {
            let mut jwk = *old_jwk;
            jwk.epoch = old_jwk.epoch.max(new_jwk.epoch);
            res.push_back(jwk);
            i = i + 1;
            j = j + 1;
        } else if (jwk_id_equal(&old_jwk.jwk_id, &new_jwk.jwk_id)) {
            // if only jwk_id is equal, then the key has changed. Providers should not send
            // JWKs like this, but if they do, we must ignore the new JWK to avoid having a
            // liveness / forking issues
            res.push_back(*old_jwk);
            i = i + 1;
            j = j + 1;
        } else if (jwk_lt(old_jwk, new_jwk)) {
            // the stored entry sorts first: emit it and advance only the stored cursor
            res.push_back(*old_jwk);
            i = i + 1;
        } else {
            // the new entry sorts first: emit it and advance only the input cursor
            res.push_back(*new_jwk);
            j = j + 1;
        }
    };

    // At most one of the two lists still has entries left; append them unchanged.
    while (i < active_jwks_len) {
        res.push_back(inner.active_jwks[i]);
        i = i + 1;
    };
    while (j < new_active_jwks_len) {
        res.push_back(new_active_jwks[j]);
        j = j + 1;
    };

    inner.active_jwks = res;
}

Function deduplicate

fun deduplicate(jwks: vector<authenticator_state::ActiveJwk>): vector<authenticator_state::ActiveJwk>
Implementation
/// Returns a copy of `jwks` in which every entry whose `jwk_id` equals the `jwk_id` of the
/// most recently kept entry has been dropped. Only the id is compared, so of several
/// consecutive entries sharing an id, the first one is the one retained.
fun deduplicate(jwks: vector<ActiveJwk>): vector<ActiveJwk> {
    let mut unique: vector<ActiveJwk> = vector[];
    let total = jwks.length();
    let mut idx = 0;
    while (idx < total) {
        let candidate = &jwks[idx];
        // The last element of `unique` is always the most recently kept entry, so its
        // id is what the candidate has to differ from.
        let repeats_last_kept =
            !unique.is_empty() &&
            jwk_id_equal(&unique[unique.length() - 1].jwk_id, &candidate.jwk_id);
        if (!repeats_last_kept) {
            unique.push_back(*candidate);
        };
        idx = idx + 1;
    };
    unique
}

Function expire_jwks

fun expire_jwks(self: &mut authenticator_state::AuthenticatorState, min_epoch: u64, ctx: &tx_context::TxContext)
Implementation
/// Removes expired JWKs from the stored list, issuer by issuer.
///
/// A JWK is kept when its epoch is at least `min_epoch`. As an exception, if the highest
/// epoch among an issuer's JWKs is below `min_epoch`, all of that issuer's JWKs are kept.
/// Aborts with `ENotSystemAddress` unless the sender is `@0x0`.
///
/// NOTE(review): both passes detect an issuer change by comparing with the previous entry,
/// so this assumes JWKs of the same issuer are adjacent in `active_jwks` — confirm that the
/// stored list is always kept sorted by issuer.
fun expire_jwks(
    self: &mut AuthenticatorState,
    // jwks below this epoch are dropped, unless their issuer has no jwk at or above it
    min_epoch: u64,
    ctx: &TxContext) {
    // This will only be called by sui_system::advance_epoch
    assert!(ctx.sender() == @0x0, ENotSystemAddress);

    let inner = load_inner_mut(self);

    let len = inner.active_jwks.length();

    // First pass: record the maximum epoch seen for each issuer. The vector gets one entry
    // per run of consecutive JWKs with the same issuer, in the order the runs appear.
    let mut issuer_max_epochs = vector[];
    let mut i = 0;
    let mut prev_issuer: Option<String> = option::none();

    while (i < len) {
        let cur = &inner.active_jwks[i];
        let cur_iss = &cur.jwk_id.iss;
        if (prev_issuer.is_none()) {
            // very first JWK: start the first issuer run
            prev_issuer.fill(*cur_iss);
            issuer_max_epochs.push_back(cur.epoch);
        } else {
            if (cur_iss == prev_issuer.borrow()) {
                // same issuer as the previous JWK: raise the run's max epoch if needed
                let back = issuer_max_epochs.length() - 1;
                let prev_max_epoch = &mut issuer_max_epochs[back];
                *prev_max_epoch = (*prev_max_epoch).max(cur.epoch);
            } else {
                // issuer changed: start a new run
                *prev_issuer.borrow_mut() = *cur_iss;
                issuer_max_epochs.push_back(cur.epoch);
            }
        };
        i = i + 1;
    };

    // Now, filter out any JWKs that are below the min_epoch, unless that issuer has no
    // JWKs >= the min_epoch, in which case we keep all of them.
    let mut new_active_jwks: vector<ActiveJwk> = vector[];
    let mut prev_issuer: Option<String> = option::none();
    let mut i = 0;
    // `j` indexes `issuer_max_epochs`; it advances each time the issuer changes, mirroring
    // how the first pass pushed a new entry.
    let mut j = 0;
    while (i < len) {
        let jwk = &inner.active_jwks[i];
        let cur_iss = &jwk.jwk_id.iss;

        if (prev_issuer.is_none()) {
            prev_issuer.fill(*cur_iss);
        } else if (cur_iss != prev_issuer.borrow()) {
            *prev_issuer.borrow_mut() = *cur_iss;
            j = j + 1;
        };

        let max_epoch_for_iss = &issuer_max_epochs[j];

        // Keep the JWK if the issuer has *no* JWKs that meet the minimum epoch (expire
        // nothing for that issuer), or if this JWK itself meets the minimum epoch.
        if (*max_epoch_for_iss < min_epoch || jwk.epoch >= min_epoch) {
            new_active_jwks.push_back(*jwk);
        };
        i = i + 1;
    };
    inner.active_jwks = new_active_jwks;
}

Function get_active_jwks

Get the current active_jwks. Called when the node starts up in order to load the current JWK state from the chain.

fun get_active_jwks(self: &authenticator_state::AuthenticatorState, ctx: &tx_context::TxContext): vector<authenticator_state::ActiveJwk>
Implementation
/// Returns a copy of the currently stored active JWKs.
/// Aborts with `ENotSystemAddress` unless the sender is `@0x0`.
fun get_active_jwks(self: &AuthenticatorState, ctx: &TxContext): vector<ActiveJwk> {
    assert!(ctx.sender() == @0x0, ENotSystemAddress);
    let inner = load_inner(self);
    inner.active_jwks
}