3rdparty/ccc: Fix some bounds checks and other error handling logic

This commit is contained in:
chaoticgd 2024-11-25 06:40:41 +00:00 committed by lightningterror
parent ffa06fbb09
commit 132e3e507d
7 changed files with 83 additions and 54 deletions

View File

@ -10,34 +10,35 @@ Result<ElfFile> ElfFile::parse(std::vector<u8> image)
ElfFile elf; ElfFile elf;
elf.image = std::move(image); elf.image = std::move(image);
const ElfIdentHeader* ident = get_packed<ElfIdentHeader>(elf.image, 0); const ElfIdentHeader* ident = get_unaligned<ElfIdentHeader>(elf.image, 0);
CCC_CHECK(ident, "ELF ident header out of range."); CCC_CHECK(ident, "ELF ident header out of range.");
CCC_CHECK(ident->magic == CCC_FOURCC("\x7f\x45\x4c\x46"), "Not an ELF file."); CCC_CHECK(ident->magic == CCC_FOURCC("\x7f\x45\x4c\x46"), "Not an ELF file.");
CCC_CHECK(ident->e_class == ElfIdentClass::B32, "Wrong ELF class (not 32 bit)."); CCC_CHECK(ident->e_class == ElfIdentClass::B32, "Wrong ELF class (not 32 bit).");
const ElfFileHeader* header = get_packed<ElfFileHeader>(elf.image, sizeof(ElfIdentHeader)); const ElfFileHeader* header = get_unaligned<ElfFileHeader>(elf.image, sizeof(ElfIdentHeader));
CCC_CHECK(header, "ELF file header out of range."); CCC_CHECK(header, "ELF file header out of range.");
elf.file_header = *header; elf.file_header = *header;
const ElfSectionHeader* shstr_section_header = get_packed<ElfSectionHeader>(elf.image, header->shoff + header->shstrndx * sizeof(ElfSectionHeader)); const ElfSectionHeader* shstr_section_header =
get_unaligned<ElfSectionHeader>(elf.image, header->shoff + header->shstrndx * sizeof(ElfSectionHeader));
CCC_CHECK(shstr_section_header, "ELF section name header out of range."); CCC_CHECK(shstr_section_header, "ELF section name header out of range.");
for(u32 i = 0; i < header->shnum; i++) { for(u32 i = 0; i < header->shnum; i++) {
u64 header_offset = header->shoff + i * sizeof(ElfSectionHeader); u64 header_offset = header->shoff + i * sizeof(ElfSectionHeader);
const ElfSectionHeader* section_header = get_packed<ElfSectionHeader>(elf.image, header_offset); const ElfSectionHeader* section_header = get_unaligned<ElfSectionHeader>(elf.image, header_offset);
CCC_CHECK(section_header, "ELF section header out of range."); CCC_CHECK(section_header, "ELF section header out of range.");
const char* name = get_string(elf.image, shstr_section_header->offset + section_header->name); std::optional<std::string_view> name = get_string(elf.image, shstr_section_header->offset + section_header->name);
CCC_CHECK(section_header, "ELF section name out of range."); CCC_CHECK(name.has_value(), "ELF section name out of range.");
ElfSection& section = elf.sections.emplace_back(); ElfSection& section = elf.sections.emplace_back();
section.name = name; section.name = *name;
section.header = *section_header; section.header = *section_header;
} }
for(u32 i = 0; i < header->phnum; i++) { for(u32 i = 0; i < header->phnum; i++) {
u64 header_offset = header->phoff + i * sizeof(ElfProgramHeader); u64 header_offset = header->phoff + i * sizeof(ElfProgramHeader);
const ElfProgramHeader* program_header = get_packed<ElfProgramHeader>(elf.image, header_offset); const ElfProgramHeader* program_header = get_unaligned<ElfProgramHeader>(elf.image, header_offset);
CCC_CHECK(program_header, "ELF program header out of range."); CCC_CHECK(program_header, "ELF program header out of range.");
elf.segments.emplace_back(*program_header); elf.segments.emplace_back(*program_header);

View File

@ -60,7 +60,7 @@ Result<void> import_symbols(
DemanglerFunctions demangler) DemanglerFunctions demangler)
{ {
for(u32 i = 0; i < symtab.size() / sizeof(Symbol); i++) { for(u32 i = 0; i < symtab.size() / sizeof(Symbol); i++) {
const Symbol* symbol = get_packed<Symbol>(symtab, i * sizeof(Symbol)); const Symbol* symbol = get_unaligned<Symbol>(symtab, i * sizeof(Symbol));
CCC_ASSERT(symbol); CCC_ASSERT(symbol);
Address address; Address address;
@ -86,13 +86,14 @@ Result<void> import_symbols(
} }
} }
const char* string = get_string(strtab, symbol->name); std::optional<std::string_view> string_view = get_string(strtab, symbol->name);
CCC_CHECK(string, "Symbol string out of range."); CCC_CHECK(string_view.has_value(), "Symbol string out of range.");
std::string string(*string_view);
switch(symbol->type()) { switch(symbol->type()) {
case SymbolType::NOTYPE: { case SymbolType::NOTYPE: {
Result<Label*> label = database.labels.create_symbol( Result<Label*> label = database.labels.create_symbol(
string, group.source, group.module_symbol, address, importer_flags, demangler); std::move(string), group.source, group.module_symbol, address, importer_flags, demangler);
CCC_RETURN_IF_ERROR(label); CCC_RETURN_IF_ERROR(label);
// These symbols get emitted at the same addresses as functions // These symbols get emitted at the same addresses as functions
@ -108,7 +109,7 @@ Result<void> import_symbols(
case SymbolType::OBJECT: { case SymbolType::OBJECT: {
if(symbol->size != 0) { if(symbol->size != 0) {
Result<GlobalVariable*> global_variable = database.global_variables.create_symbol( Result<GlobalVariable*> global_variable = database.global_variables.create_symbol(
string, group.source, group.module_symbol, address, importer_flags, demangler); std::move(string), group.source, group.module_symbol, address, importer_flags, demangler);
CCC_RETURN_IF_ERROR(global_variable); CCC_RETURN_IF_ERROR(global_variable);
if(*global_variable) { if(*global_variable) {
@ -116,7 +117,7 @@ Result<void> import_symbols(
} }
} else { } else {
Result<Label*> label = database.labels.create_symbol( Result<Label*> label = database.labels.create_symbol(
string, group.source, group.module_symbol, address, importer_flags, demangler); std::move(string), group.source, group.module_symbol, address, importer_flags, demangler);
CCC_RETURN_IF_ERROR(label); CCC_RETURN_IF_ERROR(label);
} }
@ -124,7 +125,7 @@ Result<void> import_symbols(
} }
case SymbolType::FUNC: { case SymbolType::FUNC: {
Result<Function*> function = database.functions.create_symbol( Result<Function*> function = database.functions.create_symbol(
string, group.source, group.module_symbol, address, importer_flags, demangler); std::move(string), group.source, group.module_symbol, address, importer_flags, demangler);
CCC_RETURN_IF_ERROR(function); CCC_RETURN_IF_ERROR(function);
if(*function) { if(*function) {
@ -135,7 +136,7 @@ Result<void> import_symbols(
} }
case SymbolType::FILE: { case SymbolType::FILE: {
Result<SourceFile*> source_file = database.source_files.create_symbol( Result<SourceFile*> source_file = database.source_files.create_symbol(
string, group.source, group.module_symbol); std::move(string), group.source, group.module_symbol);
CCC_RETURN_IF_ERROR(source_file); CCC_RETURN_IF_ERROR(source_file);
break; break;
@ -153,18 +154,18 @@ Result<void> print_symbol_table(FILE* out, std::span<const u8> symtab, std::span
fprintf(out, " Num: Value Size Type Bind Vis Ndx Name\n"); fprintf(out, " Num: Value Size Type Bind Vis Ndx Name\n");
for(u32 i = 0; i < symtab.size() / sizeof(Symbol); i++) { for(u32 i = 0; i < symtab.size() / sizeof(Symbol); i++) {
const Symbol* symbol = get_packed<Symbol>(symtab, i * sizeof(Symbol)); const Symbol* symbol = get_unaligned<Symbol>(symtab, i * sizeof(Symbol));
CCC_ASSERT(symbol); CCC_ASSERT(symbol);
const char* type = symbol_type_to_string(symbol->type()); const char* type = symbol_type_to_string(symbol->type());
const char* bind = symbol_bind_to_string(symbol->bind()); const char* bind = symbol_bind_to_string(symbol->bind());
const char* visibility = symbol_visibility_to_string(symbol->visibility()); const char* visibility = symbol_visibility_to_string(symbol->visibility());
const char* string = get_string(strtab, symbol->name); std::optional<std::string_view> string = get_string(strtab, symbol->name);
CCC_CHECK(string, "Symbol string out of range."); CCC_CHECK(string.has_value(), "Symbol string out of range.");
fprintf(out, "%6u: %08x %5u %-7s %-7s %-7s %3u %s\n", fprintf(out, "%6u: %08x %5u %-7s %-7s %-7s %3u %s\n",
i, symbol->value, symbol->size, type, bind, visibility, symbol->shndx, string); i, symbol->value, symbol->size, type, bind, visibility, symbol->shndx, string->data());
} }

View File

@ -90,7 +90,7 @@ Result<void> SymbolTableReader::init(std::span<const u8> elf, s32 section_offset
m_elf = elf; m_elf = elf;
m_section_offset = section_offset; m_section_offset = section_offset;
m_hdrr = get_packed<SymbolicHeader>(m_elf, m_section_offset); m_hdrr = get_unaligned<SymbolicHeader>(m_elf, m_section_offset);
CCC_CHECK(m_hdrr != nullptr, "MIPS debug section header out of bounds."); CCC_CHECK(m_hdrr != nullptr, "MIPS debug section header out of bounds.");
CCC_CHECK(m_hdrr->magic == 0x7009, "Invalid symbolic header."); CCC_CHECK(m_hdrr->magic == 0x7009, "Invalid symbolic header.");
@ -116,7 +116,7 @@ Result<File> SymbolTableReader::parse_file(s32 index) const
File file; File file;
u64 fd_offset = m_hdrr->file_descriptors_offset + index * sizeof(FileDescriptor); u64 fd_offset = m_hdrr->file_descriptors_offset + index * sizeof(FileDescriptor);
const FileDescriptor* fd_header = get_packed<FileDescriptor>(m_elf, fd_offset + m_fudge_offset); const FileDescriptor* fd_header = get_unaligned<FileDescriptor>(m_elf, fd_offset + m_fudge_offset);
CCC_CHECK(fd_header != nullptr, "MIPS debug file descriptor out of bounds."); CCC_CHECK(fd_header != nullptr, "MIPS debug file descriptor out of bounds.");
CCC_CHECK(fd_header->f_big_endian == 0, "Not little endian or bad file descriptor table."); CCC_CHECK(fd_header->f_big_endian == 0, "Not little endian or bad file descriptor table.");
@ -124,16 +124,16 @@ Result<File> SymbolTableReader::parse_file(s32 index) const
s32 rel_raw_path_offset = fd_header->strings_offset + fd_header->file_path_string_offset; s32 rel_raw_path_offset = fd_header->strings_offset + fd_header->file_path_string_offset;
s32 raw_path_offset = m_hdrr->local_strings_offset + rel_raw_path_offset + m_fudge_offset; s32 raw_path_offset = m_hdrr->local_strings_offset + rel_raw_path_offset + m_fudge_offset;
const char* command_line_path = get_string(m_elf, raw_path_offset); std::optional<std::string_view> command_line_path = get_string(m_elf, raw_path_offset);
if(command_line_path) { if(command_line_path.has_value()) {
file.command_line_path = command_line_path; file.command_line_path = *command_line_path;
} }
// Parse local symbols. // Parse local symbols.
for(s64 j = 0; j < fd_header->symbol_count; j++) { for(s64 j = 0; j < fd_header->symbol_count; j++) {
u64 rel_symbol_offset = (fd_header->isym_base + j) * sizeof(SymbolHeader); u64 rel_symbol_offset = (fd_header->isym_base + j) * sizeof(SymbolHeader);
u64 symbol_offset = m_hdrr->local_symbols_offset + rel_symbol_offset + m_fudge_offset; u64 symbol_offset = m_hdrr->local_symbols_offset + rel_symbol_offset + m_fudge_offset;
const SymbolHeader* symbol_header = get_packed<SymbolHeader>(m_elf, symbol_offset); const SymbolHeader* symbol_header = get_unaligned<SymbolHeader>(m_elf, symbol_offset);
CCC_CHECK(symbol_header != nullptr, "Symbol header out of bounds."); CCC_CHECK(symbol_header != nullptr, "Symbol header out of bounds.");
s32 strings_offset = m_hdrr->local_strings_offset + fd_header->strings_offset + m_fudge_offset; s32 strings_offset = m_hdrr->local_strings_offset + fd_header->strings_offset + m_fudge_offset;
@ -155,7 +155,7 @@ Result<File> SymbolTableReader::parse_file(s32 index) const
for(s64 i = 0; i < fd_header->procedure_descriptor_count; i++) { for(s64 i = 0; i < fd_header->procedure_descriptor_count; i++) {
u64 rel_procedure_offset = (fd_header->ipd_first + i) * sizeof(ProcedureDescriptor); u64 rel_procedure_offset = (fd_header->ipd_first + i) * sizeof(ProcedureDescriptor);
u64 procedure_offset = m_hdrr->procedure_descriptors_offset + rel_procedure_offset + m_fudge_offset; u64 procedure_offset = m_hdrr->procedure_descriptors_offset + rel_procedure_offset + m_fudge_offset;
const ProcedureDescriptor* procedure_descriptor = get_packed<ProcedureDescriptor>(m_elf, procedure_offset); const ProcedureDescriptor* procedure_descriptor = get_unaligned<ProcedureDescriptor>(m_elf, procedure_offset);
CCC_CHECK(procedure_descriptor != nullptr, "Procedure descriptor out of bounds."); CCC_CHECK(procedure_descriptor != nullptr, "Procedure descriptor out of bounds.");
CCC_CHECK(procedure_descriptor->symbol_index < file.symbols.size(), "Symbol index out of bounds."); CCC_CHECK(procedure_descriptor->symbol_index < file.symbols.size(), "Symbol index out of bounds.");
@ -175,7 +175,7 @@ Result<std::vector<Symbol>> SymbolTableReader::parse_external_symbols() const
std::vector<Symbol> external_symbols; std::vector<Symbol> external_symbols;
for(s64 i = 0; i < m_hdrr->external_symbols_count; i++) { for(s64 i = 0; i < m_hdrr->external_symbols_count; i++) {
u64 sym_offset = m_hdrr->external_symbols_offset + i * sizeof(ExternalSymbolHeader); u64 sym_offset = m_hdrr->external_symbols_offset + i * sizeof(ExternalSymbolHeader);
const ExternalSymbolHeader* external_header = get_packed<ExternalSymbolHeader>(m_elf, sym_offset + m_fudge_offset); const ExternalSymbolHeader* external_header = get_unaligned<ExternalSymbolHeader>(m_elf, sym_offset + m_fudge_offset);
CCC_CHECK(external_header != nullptr, "External header out of bounds."); CCC_CHECK(external_header != nullptr, "External header out of bounds.");
Result<Symbol> sym = get_symbol(external_header->symbol, m_elf, m_hdrr->external_strings_offset + m_fudge_offset); Result<Symbol> sym = get_symbol(external_header->symbol, m_elf, m_hdrr->external_strings_offset + m_fudge_offset);
@ -351,9 +351,9 @@ static Result<Symbol> get_symbol(const SymbolHeader& header, std::span<const u8>
{ {
Symbol symbol; Symbol symbol;
const char* string = get_string(elf, strings_offset + header.iss); std::optional<std::string_view> string = get_string(elf, strings_offset + header.iss);
CCC_CHECK(string, "Symbol has invalid string."); CCC_CHECK(string.has_value(), "Symbol has invalid string.");
symbol.string = string; symbol.string = string->data();
symbol.value = header.value; symbol.value = header.value;
symbol.symbol_type = (SymbolType) header.st; symbol.symbol_type = (SymbolType) header.st;

View File

@ -54,18 +54,19 @@ static const char* sndll_symbol_type_to_string(SNDLLSymbolType type);
Result<SNDLLFile> parse_sndll_file(std::span<const u8> image, Address address, SNDLLType type) Result<SNDLLFile> parse_sndll_file(std::span<const u8> image, Address address, SNDLLType type)
{ {
const u32* magic = get_packed<u32>(image, 0); std::optional<u32> magic = copy_unaligned<u32>(image, 0);
CCC_CHECK(magic.has_value(), "Failed to read SNDLL header.");
CCC_CHECK((*magic & 0xffffff) == CCC_FOURCC("SNR\00"), "Not a SNDLL %s.", address.valid() ? "section" : "file"); CCC_CHECK((*magic & 0xffffff) == CCC_FOURCC("SNR\00"), "Not a SNDLL %s.", address.valid() ? "section" : "file");
char version = *magic >> 24; char version = *magic >> 24;
switch(version) { switch(version) {
case '1': { case '1': {
const SNDLLHeaderV1* header = get_packed<SNDLLHeaderV1>(image, 0); const SNDLLHeaderV1* header = get_unaligned<SNDLLHeaderV1>(image, 0);
CCC_CHECK(header, "File too small to contain SNDLL V1 header."); CCC_CHECK(header, "File too small to contain SNDLL V1 header.");
return parse_sndll_common(image, address, type, header->common, SNDLL_V1); return parse_sndll_common(image, address, type, header->common, SNDLL_V1);
} }
case '2': { case '2': {
const SNDLLHeaderV2* header = get_packed<SNDLLHeaderV2>(image, 0); const SNDLLHeaderV2* header = get_unaligned<SNDLLHeaderV2>(image, 0);
CCC_CHECK(header, "File too small to contain SNDLL V2 header."); CCC_CHECK(header, "File too small to contain SNDLL V2 header.");
return parse_sndll_common(image, address, type, header->common, SNDLL_V2); return parse_sndll_common(image, address, type, header->common, SNDLL_V2);
} }
@ -84,10 +85,9 @@ static Result<SNDLLFile> parse_sndll_common(
sndll.version = version; sndll.version = version;
if(common.elf_path) { if(common.elf_path) {
const char* elf_path = get_string(image, common.elf_path); std::optional<std::string_view> elf_path = get_string(image, common.elf_path);
if(elf_path) { CCC_CHECK(elf_path.has_value(), "SNDLL header has invalid ELF path field.");
sndll.elf_path = elf_path; sndll.elf_path = *elf_path;
}
} }
CCC_CHECK(common.symbol_count < (32 * 1024 * 1024) / sizeof(SNDLLSymbol), "SNDLL symbol count is too high."); CCC_CHECK(common.symbol_count < (32 * 1024 * 1024) / sizeof(SNDLLSymbol), "SNDLL symbol count is too high.");
@ -95,10 +95,10 @@ static Result<SNDLLFile> parse_sndll_common(
for(u32 i = 0; i < common.symbol_count; i++) { for(u32 i = 0; i < common.symbol_count; i++) {
u32 symbol_offset = common.symbols - address.get_or_zero() + i * sizeof(SNDLLSymbolHeader); u32 symbol_offset = common.symbols - address.get_or_zero() + i * sizeof(SNDLLSymbolHeader);
const SNDLLSymbolHeader* symbol_header = get_packed<SNDLLSymbolHeader>(image, symbol_offset); const SNDLLSymbolHeader* symbol_header = get_unaligned<SNDLLSymbolHeader>(image, symbol_offset);
CCC_CHECK(symbol_header, "SNDLL symbol out of range."); CCC_CHECK(symbol_header, "SNDLL symbol out of range.");
const char* string = nullptr; std::optional<std::string_view> string;
if(symbol_header->string) { if(symbol_header->string) {
string = get_string(image, symbol_header->string - address.get_or_zero()); string = get_string(image, symbol_header->string - address.get_or_zero());
} }
@ -106,7 +106,9 @@ static Result<SNDLLFile> parse_sndll_common(
SNDLLSymbol& symbol = sndll.symbols.emplace_back(); SNDLLSymbol& symbol = sndll.symbols.emplace_back();
symbol.type = symbol_header->type; symbol.type = symbol_header->type;
symbol.value = symbol_header->value; symbol.value = symbol_header->value;
symbol.string = string; if(string.has_value()) {
symbol.string = *string;
}
} }
return sndll; return sndll;

View File

@ -7,8 +7,8 @@ namespace ccc {
Result<std::unique_ptr<SymbolFile>> parse_symbol_file(std::vector<u8> image, std::string file_name) Result<std::unique_ptr<SymbolFile>> parse_symbol_file(std::vector<u8> image, std::string file_name)
{ {
const u32* magic = get_packed<u32>(image, 0); const std::optional<u32> magic = copy_unaligned<u32>(image, 0);
CCC_CHECK(magic, "File too small."); CCC_CHECK(magic.has_value(), "File too small.");
std::unique_ptr<SymbolFile> symbol_file; std::unique_ptr<SymbolFile> symbol_file;

View File

@ -51,14 +51,17 @@ void set_custom_error_callback(CustomErrorCallback callback)
custom_error_callback = callback; custom_error_callback = callback;
} }
const char* get_string(std::span<const u8> bytes, u64 offset) std::optional<std::string_view> get_string(std::span<const u8> bytes, u64 offset)
{ {
for(const unsigned char* c = bytes.data() + offset; c < bytes.data() + bytes.size(); c++) { for(u64 i = offset; i < bytes.size(); i++) {
if(*c == '\0') { if(bytes[i] == '\0') {
return (const char*) &bytes[offset]; return std::string_view(
reinterpret_cast<const char*>(&bytes[offset]),
reinterpret_cast<const char*>(&bytes[i]));
} }
} }
return nullptr;
return std::nullopt;
} }
std::string merge_paths(const std::string& base, const std::string& path) std::string merge_paths(const std::string& base, const std::string& path)

View File

@ -72,7 +72,7 @@ void set_custom_error_callback(CustomErrorCallback callback);
} }
#define CCC_ABORT_IF_FALSE(condition, ...) \ #define CCC_ABORT_IF_FALSE(condition, ...) \
if (!(condition)) { \ if(!(condition)) { \
ccc::Error error = ccc::format_error(__FILE__, __LINE__, __VA_ARGS__); \ ccc::Error error = ccc::format_error(__FILE__, __LINE__, __VA_ARGS__); \
ccc::report_error(error); \ ccc::report_error(error); \
abort(); \ abort(); \
@ -208,16 +208,38 @@ void warn_impl(const char* source_file, int source_line, const char* format, Arg
#endif #endif
template <typename T> template <typename T>
const T* get_packed(std::span<const u8> bytes, u64 offset) const T* get_aligned(std::span<const u8> bytes, u64 offset)
{ {
if(offset + sizeof(T) <= bytes.size()) { if(offset > bytes.size() || bytes.size() - offset < sizeof(T) || offset % alignof(T) != 0) {
return reinterpret_cast<const T*>(&bytes[offset]);
} else {
return nullptr; return nullptr;
} }
return reinterpret_cast<const T*>(&bytes[offset]);
} }
const char* get_string(std::span<const u8> bytes, u64 offset); template <typename T>
// Returns a pointer to a T that starts `offset` bytes into `bytes`, or nullptr
// when a whole T does not fit between that offset and the end of the span.
// The offset itself is not checked against the alignment of T.
const T* get_unaligned(std::span<const u8> bytes, u64 offset)
{
	const u64 size = bytes.size();
	// Compare via subtraction so that a very large offset cannot wrap around
	// and slip past the range check.
	const bool in_range = offset <= size && size - offset >= sizeof(T);
	return in_range ? reinterpret_cast<const T*>(bytes.data() + offset) : nullptr;
}
// Copies a T out of `bytes` starting at `offset` and returns it by value.
// Returns std::nullopt when a whole T does not fit between the offset and the
// end of the span. The value is read with memcpy rather than through a
// T pointer into the buffer.
template <typename T>
std::optional<T> copy_unaligned(std::span<const u8> bytes, u64 offset)
{
	const u64 size = bytes.size();
	// Subtraction form keeps the range check free of unsigned wrap-around.
	if(offset <= size && size - offset >= sizeof(T)) {
		T result;
		memcpy(&result, bytes.data() + offset, sizeof(T));
		return result;
	}
	return std::nullopt;
}
std::optional<std::string_view> get_string(std::span<const u8> bytes, u64 offset);
#define CCC_BEGIN_END(x) (x).begin(), (x).end() #define CCC_BEGIN_END(x) (x).begin(), (x).end()
#define CCC_ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0])) #define CCC_ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0]))