mte/kasan: Implementing KASAN memory protection for ARM64 hardware MTE

1. Add hw_tags.c, which will call arm64_mte to implement tagging of memory blocks by operating registers
2. It has been able to run normally on the default NX memory allocator, excluding mempool and tlsf
3. On more complex configurations, memory tests such as memstress can run normally without system crashes

Signed-off-by: wangmingrong1 <wangmingrong1@xiaomi.com>
This commit is contained in:
wangmingrong1 2025-03-10 12:50:07 +08:00 committed by Xiang Xiao
parent 448ace4761
commit 8f541d2ef2
15 changed files with 326 additions and 22 deletions

View file

@ -546,6 +546,10 @@ config ARCH_HAVE_DEBUG
bool "Architecture have debug support"
default n
config ARCH_HAVE_MEMTAG
bool "Architecture have memory tagging support"
default n
config ARCH_HAVE_PERF_EVENTS
bool
default n

View file

@ -185,6 +185,7 @@ menu "ARMv8.5 architectural features"
config ARM64_MTE
bool "Memory Tagging Extension support"
select ARCH_HAVE_MEMTAG
select ARM64_TBI
default y

View file

@ -506,11 +506,15 @@ uint64_t arm64_get_mpid(int cpu);
int arm64_get_cpuid(uint64_t mpid);
#endif
#ifdef CONFIG_ARM64_MTE
void arm64_enable_mte(void);
#else
#define arm64_enable_mte()
#endif
/****************************************************************************
* Name: arm64_mte_init
*
* Description:
* Initialize MTE settings and enable memory tagging
*
****************************************************************************/
void arm64_mte_init(void);
#endif /* __ASSEMBLY__ */

View file

@ -170,8 +170,10 @@
* in the address range [59:55] = 0b00000 are unchecked accesses.
*/
#define TCR_TCMA0 (1ULL << 57)
#define TCR_TCMA1 (1ULL << 58)
#define TCR_TCMA0 BIT(57)
#define TCR_TCMA1 BIT(58)
#define TCR_MTX0_SHIFT BIT(60)
#define TCR_MTX1_SHIFT BIT(61)
#define TCR_PS_BITS_4GB 0x0ULL
#define TCR_PS_BITS_64GB 0x1ULL

View file

@ -29,18 +29,24 @@
#include <stdio.h>
#include "arm64_arch.h"
#include "arm64_mmu.h"
/****************************************************************************
* Pre-processor Definitions
****************************************************************************/
#define GCR_EL1_VAL 0x10001
#define MTE_TAG_SHIFT 56
/* The alignment length of the MTE must be a multiple of sixteen */
#define MTE_MM_AILGN 16
/****************************************************************************
* Private Functions
****************************************************************************/
static int arm64_mte_is_support(void)
static int mte_is_support(void)
{
int supported;
__asm__ volatile (
@ -57,11 +63,67 @@ static int arm64_mte_is_support(void)
* Public Functions
****************************************************************************/
void arm64_enable_mte(void)
uint8_t up_memtag_get_tag(const void *addr)
{
return 0xf0 | (uint8_t)(((uint64_t)addr) >> MTE_TAG_SHIFT);
}
uint8_t up_memtag_get_random_tag(const void *addr)
{
asm("irg %0, %0" : "=r" (addr));
return up_memtag_get_tag(addr);
}
void *up_memtag_set_tag(const void *addr, uint8_t tag)
{
return (FAR void *)
((((uint64_t)addr) & ~((uint64_t)0xff << MTE_TAG_SHIFT)) |
((uint64_t)tag << MTE_TAG_SHIFT));
}
/* Set MTE state */
bool up_memtag_bypass(bool bypass)
{
uint64_t val = read_sysreg(sctlr_el1);
bool state = !(val & SCTLR_TCF1_BIT);
if (bypass)
{
val &= ~SCTLR_TCF1_BIT;
}
else
{
val |= SCTLR_TCF1_BIT;
}
write_sysreg(val, sctlr_el1);
return state;
}
/* Set memory tags for a given memory range */
void up_memtag_tag_mem(const void *addr, size_t size)
{
size_t i;
DEBUGASSERT((uintptr_t)addr % MTE_MM_AILGN == 0);
DEBUGASSERT(size % MTE_MM_AILGN == 0);
for (i = 0; i < size; i += MTE_MM_AILGN)
{
asm("stg %0, [%0]" : : "r"(addr + i));
}
}
/* Initialize MTE settings and enable memory tagging */
void arm64_mte_init(void)
{
uint64_t val;
if (!arm64_mte_is_support())
if (!mte_is_support())
{
return;
}
@ -78,6 +140,14 @@ void arm64_enable_mte(void)
assert(!(read_sysreg(ttbr0_el1) & TTBR_CNP_BIT));
assert(!(read_sysreg(ttbr1_el1) & TTBR_CNP_BIT));
/* Controls the default value for skipping high bytes */
val = read_sysreg(tcr_el1);
val |= TCR_TCMA1;
write_sysreg(val, tcr_el1);
/* Enable the MTE function */
val = read_sysreg(sctlr_el1);
val |= SCTLR_ATA_BIT | SCTLR_TCF1_BIT;
write_sysreg(val, sctlr_el1);

View file

@ -161,7 +161,9 @@ void arm64_chip_boot(void)
arm64_mmu_init(true);
arm64_enable_mte();
#ifdef CONFIG_ARM64_MTE
arm64_mte_init();
#endif
#ifdef CONFIG_DEVICE_TREE
fdt_register((const char *)0x40000000);

View file

@ -3067,6 +3067,65 @@ int up_get_legacy_irq(uint32_t devfn, uint8_t line, uint8_t pin);
#endif
#ifdef CONFIG_ARCH_HAVE_SYSCALL
/****************************************************************************
* Name: up_assert
****************************************************************************/
void up_assert(FAR const char *filename, int linenum, FAR const char *msg);
#endif
#ifdef CONFIG_ARCH_HAVE_MEMTAG
/****************************************************************************
* Name: up_memtag_bypass
*
* Description:
* Set MTE state bypass or not
*
****************************************************************************/
bool up_memtag_bypass(bool bypass);
/****************************************************************************
* Name: up_memtag_get_tag
****************************************************************************/
uint8_t up_memtag_get_tag(const void *addr);
/****************************************************************************
* Name: up_memtag_get_random_tag
*
* Description:
* Get a random label based on the address through the mte register
*
****************************************************************************/
uint8_t up_memtag_get_random_tag(const void *addr);
/****************************************************************************
* Name: up_memtag_set_tag
*
* Description:
* Get the address with label
*
****************************************************************************/
void *up_memtag_set_tag(const void *addr, uint8_t tag);
/****************************************************************************
* Name: up_memtag_tag_mem
*
* Description:
* Set memory tags for a given memory range
*
****************************************************************************/
void up_memtag_tag_mem(const void *addr, size_t size);
#endif /* CONFIG_ARCH_HAVE_MEMTAG */
#undef EXTERN
#if defined(__cplusplus)
}

View file

@ -48,6 +48,7 @@
# define kasan_stop()
# define kasan_debugpoint(t,a,s) 0
# define kasan_init_early()
# define kasan_bypass(state) (state)
#else
# define kasan_init_early() kasan_stop()
@ -180,7 +181,11 @@ uint8_t kasan_get_tag(FAR const void *addr);
*
****************************************************************************/
#ifdef CONFIG_MM_KASAN_INSTRUMENT
void kasan_start(void);
#else
# define kasan_start()
#endif
/****************************************************************************
* Name: kasan_stop
@ -198,7 +203,11 @@ void kasan_start(void);
*
****************************************************************************/
#ifdef CONFIG_MM_KASAN_INSTRUMENT
void kasan_stop(void);
#else
# define kasan_stop()
#endif
/****************************************************************************
* Name: kasan_debugpoint
@ -222,6 +231,12 @@ void kasan_stop(void);
int kasan_debugpoint(int type, FAR void *addr, size_t size);
/****************************************************************************
* Name: kasan_bypass
****************************************************************************/
bool kasan_bypass(bool state);
#undef EXTERN
#ifdef __cplusplus
}

View file

@ -32,14 +32,19 @@ config MM_KASAN_GENERIC
KASan generic mode that does not require hardware support at all
config MM_KASAN_SW_TAGS
bool "KAsan SW tags"
select ARM64_TBI
bool "KAsan softtags tags"
select ARM64_TBI if ARCH_ARM64
select MM_KASAN_INSTRUMENT
depends on ARCH_ARM64
---help---
KAsan based on software tags
endchoice # KAsan Mode
config MM_KASAN_HW_TAGS
bool "KAsan hardware tags"
select ARM64_MTE if ARCH_ARM64
---help---
KAsan based on hardware tags
endchoice
config MM_KASAN_INSTRUMENT_ALL
bool "Enable KASan for the entire image"
@ -52,6 +57,8 @@ config MM_KASAN_INSTRUMENT_ALL
to check. Enabling this option will get image size increased
and performance decreased significantly.
if MM_KASAN_INSTRUMENT
config MM_KASAN_REGIONS
int "Kasan region count"
default 8
@ -126,4 +133,5 @@ config MM_KASAN_GLOBAL_ALIGN
It is recommended to use 1, 2, 4, 8, 16, 32.
The maximum value is 32.
endif # MM_KASAN_INSTRUMENT
endif # MM_KASAN

View file

@ -230,6 +230,11 @@ FAR void *kasan_unpoison(FAR const void *addr, size_t size)
return (FAR void *)addr;
}
bool kasan_bypass(bool state)
{
return false;
}
void kasan_register(FAR void *addr, FAR size_t *size)
{
FAR struct kasan_region_s *region;

View file

@ -42,6 +42,8 @@
# include "generic.c"
#elif defined(CONFIG_MM_KASAN_SW_TAGS)
# include "sw_tags.c"
#elif defined(CONFIG_MM_KASAN_HW_TAGS)
# include "hw_tags.c"
#else
# define kasan_is_poisoned(addr, size) false
#endif
@ -96,6 +98,8 @@
#define KASAN_INIT_VALUE 0xcafe
#ifdef CONFIG_MM_KASAN_INSTRUMENT
/****************************************************************************
* Private Types
****************************************************************************/
@ -405,3 +409,6 @@ DEFINE_ASAN_LOAD_STORE(2)
DEFINE_ASAN_LOAD_STORE(4)
DEFINE_ASAN_LOAD_STORE(8)
DEFINE_ASAN_LOAD_STORE(16)
#endif /* CONFIG_MM_KASAN_INSTRUMENT */

94
mm/kasan/hw_tags.c Normal file
View file

@ -0,0 +1,94 @@
/****************************************************************************
* mm/kasan/hw_tags.c
*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership. The
* ASF licenses this file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance with the
* License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
****************************************************************************/
/****************************************************************************
* Included Files
****************************************************************************/
#include <nuttx/arch.h>
/****************************************************************************
* Private Function
****************************************************************************/
static FAR void *
kasan_poison_tag(FAR const void *addr, size_t size, uint8_t tag)
{
FAR void *tag_addr;
/* Label this address pointer */
tag_addr = up_memtag_set_tag(addr, tag);
/* Add MTE hardware label to memory block */
up_memtag_tag_mem(tag_addr, size);
return tag_addr;
}
/****************************************************************************
* Public Functions
****************************************************************************/
bool kasan_bypass(bool state)
{
return up_memtag_bypass(state);
}
FAR void *kasan_clear_tag(FAR const void *addr)
{
return up_memtag_set_tag(addr, 0);
}
void kasan_poison(FAR const void *addr, size_t size)
{
uint8_t tag = up_memtag_get_random_tag(addr);
kasan_poison_tag(addr, size, tag);
}
uint8_t kasan_get_tag(FAR const void *addr)
{
return up_memtag_get_tag(addr);
}
FAR void *kasan_set_tag(FAR const void *addr, uint8_t tag)
{
return up_memtag_set_tag(addr, tag);
}
FAR void *kasan_unpoison(FAR const void *addr, size_t size)
{
uint8_t tag = up_memtag_get_random_tag(addr);
return kasan_poison_tag(addr, size, tag);
}
void kasan_register(FAR void *addr, FAR size_t *size)
{
uint8_t tag = up_memtag_get_random_tag(addr);
kasan_poison_tag(addr, *size, tag);
}
void kasan_unregister(FAR void *addr)
{
}

View file

@ -32,6 +32,7 @@
#include <debug.h>
#include <nuttx/arch.h>
#include <nuttx/mm/kasan.h>
#include <nuttx/mm/mm.h>
#include "mm_heap/mm.h"
@ -69,7 +70,16 @@ int mm_lock(FAR struct mm_heap_s *heap)
* Or, touch the heap internal data directly.
*/
return nxmutex_is_locked(&heap->mm_lock) ? -EAGAIN : 0;
if (nxmutex_is_locked(&heap->mm_lock))
{
return -EAGAIN;
}
else
{
kasan_bypass(true);
return 0;
}
# else
/* Can't take mutex in SMP interrupt handler */
@ -94,7 +104,13 @@ int mm_lock(FAR struct mm_heap_s *heap)
}
else
{
return nxmutex_lock(&heap->mm_lock);
int ret = nxmutex_lock(&heap->mm_lock);
if (ret >= 0)
{
kasan_bypass(true);
}
return 0;
}
}
@ -111,10 +127,12 @@ void mm_unlock(FAR struct mm_heap_s *heap)
#if defined(CONFIG_BUILD_FLAT) || defined(__KERNEL__)
if (up_interrupt_context())
{
kasan_bypass(false);
return;
}
#endif
kasan_bypass(false);
DEBUGVERIFY(nxmutex_unlock(&heap->mm_lock));
}
@ -128,8 +146,12 @@ void mm_unlock(FAR struct mm_heap_s *heap)
irqstate_t mm_lock_irq(FAR struct mm_heap_s *heap)
{
irqstate_t flags = up_irq_save();
UNUSED(heap);
return up_irq_save();
kasan_bypass(true);
return flags;
}
/****************************************************************************
@ -143,5 +165,6 @@ irqstate_t mm_lock_irq(FAR struct mm_heap_s *heap)
void mm_unlock_irq(FAR struct mm_heap_s *heap, irqstate_t state)
{
UNUSED(heap);
kasan_bypass(false);
up_irq_restore(state);
}

View file

@ -29,6 +29,7 @@
#include <assert.h>
#include <debug.h>
#include <nuttx/mm/kasan.h>
#include <nuttx/mm/mm.h>
#include "mm_heap/mm.h"
@ -40,12 +41,17 @@
size_t mm_malloc_size(FAR struct mm_heap_s *heap, FAR void *mem)
{
FAR struct mm_freenode_s *node;
ssize_t size;
bool flag;
flag = kasan_bypass(true);
#ifdef CONFIG_MM_HEAP_MEMPOOL
if (heap->mm_mpool)
{
ssize_t size = mempool_multiple_alloc_size(heap->mm_mpool, mem);
size = mempool_multiple_alloc_size(heap->mm_mpool, mem);
if (size >= 0)
{
kasan_bypass(flag);
return size;
}
}
@ -66,5 +72,8 @@ size_t mm_malloc_size(FAR struct mm_heap_s *heap, FAR void *mem)
DEBUGASSERT(MM_NODE_IS_ALLOC(node));
return MM_SIZEOF_NODE(node) - MM_ALLOCNODE_OVERHEAD;
size = MM_SIZEOF_NODE(node) - MM_ALLOCNODE_OVERHEAD;
kasan_bypass(flag);
return size;
}

View file

@ -388,11 +388,12 @@ FAR void *mm_realloc(FAR struct mm_heap_s *heap, FAR void *oldmem,
heap->mm_curused - newsize);
sched_note_heap(NOTE_HEAP_ALLOC, heap, newmem, newsize,
heap->mm_curused);
size = MM_SIZEOF_NODE(oldnode);
mm_unlock(heap);
MM_ADD_BACKTRACE(heap, (FAR char *)newmem - MM_SIZEOF_ALLOCNODE);
newmem = kasan_unpoison(newmem, MM_SIZEOF_NODE(oldnode) -
MM_ALLOCNODE_OVERHEAD);
newmem = kasan_unpoison(newmem, size - MM_ALLOCNODE_OVERHEAD);
oldmem = kasan_set_tag(oldmem, kasan_get_tag(newmem));
if (newmem != oldmem)