// SPDX-License-Identifier: GPL-2.0-only /* * Copyright (c) 2025, Google LLC. * Pasha Tatashin */ #define _GNU_SOURCE #include #include #include #include #include #include #include #include #include #include #include #include #include #include "luo_test_utils.h" int luo_open_device(void) { return open(LUO_DEVICE, O_RDWR); } int luo_create_session(int luo_fd, const char *name) { struct liveupdate_ioctl_create_session arg = { .size = sizeof(arg) }; snprintf((char *)arg.name, LIVEUPDATE_SESSION_NAME_LENGTH, "%.*s", LIVEUPDATE_SESSION_NAME_LENGTH - 1, name); if (ioctl(luo_fd, LIVEUPDATE_IOCTL_CREATE_SESSION, &arg) < 0) return -errno; return arg.fd; } int luo_retrieve_session(int luo_fd, const char *name) { struct liveupdate_ioctl_retrieve_session arg = { .size = sizeof(arg) }; snprintf((char *)arg.name, LIVEUPDATE_SESSION_NAME_LENGTH, "%.*s", LIVEUPDATE_SESSION_NAME_LENGTH - 1, name); if (ioctl(luo_fd, LIVEUPDATE_IOCTL_RETRIEVE_SESSION, &arg) < 0) return -errno; return arg.fd; } int create_and_preserve_memfd(int session_fd, int token, const char *data) { struct liveupdate_session_preserve_fd arg = { .size = sizeof(arg) }; long page_size = sysconf(_SC_PAGE_SIZE); void *map = MAP_FAILED; int mfd = -1, ret = -1; mfd = memfd_create("test_mfd", 0); if (mfd < 0) return -errno; if (ftruncate(mfd, page_size) != 0) goto out; map = mmap(NULL, page_size, PROT_WRITE, MAP_SHARED, mfd, 0); if (map == MAP_FAILED) goto out; snprintf(map, page_size, "%s", data); munmap(map, page_size); arg.fd = mfd; arg.token = token; if (ioctl(session_fd, LIVEUPDATE_SESSION_PRESERVE_FD, &arg) < 0) goto out; ret = 0; out: if (ret != 0 && errno != 0) ret = -errno; if (mfd >= 0) close(mfd); return ret; } int restore_and_verify_memfd(int session_fd, int token, const char *expected_data) { struct liveupdate_session_retrieve_fd arg = { .size = sizeof(arg) }; long page_size = sysconf(_SC_PAGE_SIZE); void *map = MAP_FAILED; int mfd = -1, ret = -1; arg.token = token; if (ioctl(session_fd, LIVEUPDATE_SESSION_RETRIEVE_FD, &arg) < 0) return -errno; mfd = arg.fd; map = mmap(NULL, page_size, PROT_READ, MAP_SHARED, mfd, 0); if (map == MAP_FAILED) goto out; if (expected_data && strcmp(expected_data, map) != 0) { ksft_print_msg("Data mismatch! Expected '%s', Got '%s'\n", expected_data, (char *)map); ret = -EINVAL; goto out_munmap; } ret = mfd; out_munmap: munmap(map, page_size); out: if (ret < 0 && errno != 0) ret = -errno; if (ret < 0 && mfd >= 0) close(mfd); return ret; } int luo_session_finish(int session_fd) { struct liveupdate_session_finish arg = { .size = sizeof(arg) }; if (ioctl(session_fd, LIVEUPDATE_SESSION_FINISH, &arg) < 0) return -errno; return 0; } void create_state_file(int luo_fd, const char *session_name, int token, int next_stage) { char buf[32]; int state_session_fd; state_session_fd = luo_create_session(luo_fd, session_name); if (state_session_fd < 0) fail_exit("luo_create_session for state tracking"); snprintf(buf, sizeof(buf), "%d", next_stage); if (create_and_preserve_memfd(state_session_fd, token, buf) < 0) fail_exit("create_and_preserve_memfd for state tracking"); /* * DO NOT close session FD, otherwise it is going to be unpreserved */ } void restore_and_read_stage(int state_session_fd, int token, int *stage) { char buf[32] = {0}; int mfd; mfd = restore_and_verify_memfd(state_session_fd, token, NULL); if (mfd < 0) fail_exit("failed to restore state memfd"); if (read(mfd, buf, sizeof(buf) - 1) < 0) fail_exit("failed to read state mfd"); *stage = atoi(buf); close(mfd); } void daemonize_and_wait(void) { pid_t pid; ksft_print_msg("[STAGE 1] Forking persistent child to hold sessions...\n"); pid = fork(); if (pid < 0) fail_exit("fork failed"); if (pid > 0) { ksft_print_msg("[STAGE 1] Child PID: %d. Resources are pinned.\n", pid); ksft_print_msg("[STAGE 1] You may now perform kexec reboot.\n"); exit(EXIT_SUCCESS); } /* Detach from terminal so closing the window doesn't kill us */ if (setsid() < 0) fail_exit("setsid failed"); close(STDIN_FILENO); close(STDOUT_FILENO); close(STDERR_FILENO); /* Change dir to root to avoid locking filesystems */ if (chdir("/") < 0) exit(EXIT_FAILURE); while (1) sleep(60); } static int parse_stage_args(int argc, char *argv[]) { static struct option long_options[] = { {"stage", required_argument, 0, 's'}, {0, 0, 0, 0} }; int option_index = 0; int stage = 1; int opt; optind = 1; while ((opt = getopt_long(argc, argv, "s:", long_options, &option_index)) != -1) { switch (opt) { case 's': stage = atoi(optarg); if (stage != 1 && stage != 2) fail_exit("Invalid stage argument"); break; default: fail_exit("Unknown argument"); } } return stage; } int luo_test(int argc, char *argv[], const char *state_session_name, luo_test_stage1_fn stage1, luo_test_stage2_fn stage2) { int target_stage = parse_stage_args(argc, argv); int luo_fd = luo_open_device(); int state_session_fd; int detected_stage; if (luo_fd < 0) { ksft_exit_skip("Failed to open %s. Is the luo module loaded?\n", LUO_DEVICE); } state_session_fd = luo_retrieve_session(luo_fd, state_session_name); if (state_session_fd == -ENOENT) detected_stage = 1; else if (state_session_fd >= 0) detected_stage = 2; else fail_exit("Failed to check for state session"); if (target_stage != detected_stage) { ksft_exit_fail_msg("Stage mismatch Requested --stage %d, but system is in stage %d.\n" "(State session %s: %s)\n", target_stage, detected_stage, state_session_name, (detected_stage == 2) ? "EXISTS" : "MISSING"); } if (target_stage == 1) stage1(luo_fd); else stage2(luo_fd, state_session_fd); return 0; }