Source code for maru_handler.handler

# SPDX-License-Identifier: Apache-2.0
# Copyright 2026 XCENA Inc.
"""MaruHandler - Main interface for Maru shared memory KV cache client.

This module provides the primary entry point for clients to interact with
the Maru shared memory KV cache system.

Example:
    from maru import MaruConfig, MaruHandler

    config = MaruConfig(server_url="tcp://localhost:5555")
    with MaruHandler(config) as handler:
        # Zero-copy store: alloc → write to buf → store
        data = b"example payload"
        handle = handler.alloc(size=len(data))
        handle.buf[:] = data
        handler.store(key="12345", handle=handle)

        result = handler.retrieve(key="12345")  # returns MemoryInfo
"""

import ctypes
import logging
import threading

import numpy as np

from maru_common import MaruConfig
from maru_shm import MaruHandle

from .memory import (
    AllocHandle,
    DaxMapper,
    MemoryInfo,
    OwnedRegionManager,
    PagedMemoryAllocator,
)
from .rpc_client import RpcClient

logger = logging.getLogger(__name__)


def _gil_free_memcpy(dst: memoryview, src: memoryview | bytes, nbytes: int) -> None:
    """Copy *nbytes* from *src* into *dst*, releasing the GIL during copy.

    Uses ``ctypes.memmove`` which releases the GIL (all ctypes foreign-function
    calls do) for the actual memcpy, allowing other Python threads to run
    concurrently.
    """
    dst_c = (ctypes.c_char * nbytes).from_buffer(dst)
    if isinstance(src, memoryview) and not src.readonly:
        src_c = (ctypes.c_char * nbytes).from_buffer(src)
    elif isinstance(src, memoryview):
        # read-only memoryview — zero-copy view via numpy to get raw pointer
        arr = np.frombuffer(src[:nbytes], dtype=np.uint8)
        src_c = arr.ctypes.data
    else:
        # bytes — ctypes.memmove accepts bytes directly
        src_c = src
    ctypes.memmove(dst_c, src_c, nbytes)
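
# A minimal usage sketch for the helper above (illustrative only; in this
# module ``dst`` is normally an mmap-backed view from DaxMapper, not a
# bytearray):
#
#     payload = b"hello"
#     dst = memoryview(bytearray(len(payload)))
#     _gil_free_memcpy(dst, payload, len(payload))
#     assert bytes(dst) == payload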


class MaruHandler:
    """Main interface for Maru shared memory KV cache operations.

    This class handles:

    - Connection management to MaruServer
    - Memory mapping via DaxMapper
    - KV store/retrieve operations

    Thread-safety:

    - Read operations (exists, retrieve, batch_exists, batch_retrieve) are
      lock-free — they rely on RpcAsyncClient (already thread-safe) and
      DaxMapper's internal lock for lazy mapping.
    - Write operations (store, batch_store, delete) are serialized by
      ``_write_lock`` to guarantee atomicity of allocate-write-register.
    - ``close()`` sets the ``_closing`` event to reject new operations, then
      acquires ``_write_lock`` to wait for in-flight writes before teardown.

    Architecture::

        MaruHandler
        ├── RpcClient (server communication, sole RPC owner)
        ├── DaxMapper (memory mapping via MaruShmClient, owns all mmap/munmap)
        ├── OwnedRegionManager (owned regions + allocation, no RPC)
        │   ├── OwnedRegion 1 (PagedMemoryAllocator)
        │   ├── OwnedRegion 2 (PagedMemoryAllocator)
        │   └── ...
        └── _key_to_location (key -> (region_id, page_index))
    """

    def __init__(self, config: MaruConfig | None = None):
        """Initialize MaruHandler.

        Args:
            config: Configuration object. If None, uses defaults.
        """
        self._config = config or MaruConfig()

        if self._config.use_async_rpc:
            from .rpc_async_client import RpcAsyncClient

            self._rpc = RpcAsyncClient(
                self._config.server_url,
                timeout_ms=self._config.timeout_ms,
                max_inflight=self._config.max_inflight,
            )
        else:
            self._rpc = RpcClient(
                self._config.server_url,
                timeout_ms=self._config.timeout_ms,
            )

        self._mapper = DaxMapper()

        # Managers (initialized on connect)
        self._owned: OwnedRegionManager | None = None

        # Thread-safety
        self._write_lock = threading.Lock()
        self._closing = threading.Event()

        # Connection state
        self._key_to_location: dict[str, tuple[int, int]] = {}
        self._connected = False

        logger.debug("Created MaruHandler with config: %s", self._config)

    # =========================================================================
    # Connection Management
    # =========================================================================

    def connect(self) -> bool:
        """Connect to the server and request a memory allocation.

        Returns:
            True if successful
        """
        if self._connected:
            return True

        try:
            # 1. Connect RPC client
            self._rpc.connect()

            # 2. Initialize managers
            self._owned = OwnedRegionManager(
                mapper=self._mapper,
                chunk_size=self._config.chunk_size_bytes,
            )

            # 3. Request initial owned region via RPC
            response = self._rpc.request_alloc(
                instance_id=self._config.instance_id,
                size=self._config.pool_size,
            )
            if not response.success or response.handle is None:
                logger.error(
                    "Failed to request initial allocation: %s",
                    getattr(response, "error", "unknown"),
                )
                self._owned = None
                self._rpc.close()
                return False

            # 4. Add region to OwnedRegionManager (mmap + allocator)
            try:
                self._owned.add_region(response.handle)
            except Exception:
                logger.error("Failed to init initial region", exc_info=True)
                try:
                    self._rpc.return_alloc(
                        self._config.instance_id,
                        response.handle.region_id,
                    )
                except Exception:
                    logger.debug(
                        "Failed to return allocation during cleanup", exc_info=True
                    )
                self._owned = None
                self._rpc.close()
                return False

            self._connected = True

            # 5. Pre-map shared regions (eager mapping)
            if self._config.eager_map:
                self._premap_shared_regions()

            logger.info(
                "Connected: chunk_size=%d",
                self._config.chunk_size_bytes,
            )
            return True

        except Exception:
            logger.error("Failed to connect", exc_info=True)
            return False

    def close(self) -> None:
        """Close the connection and return all allocations.

        Sets the ``_closing`` event to reject new operations, then acquires
        ``_write_lock`` to wait for in-flight writes before teardown.
        """
        if not self._connected:
            return

        self._closing.set()  # reject new operations immediately

        try:
            with self._write_lock:
                # 1. Close owned regions (allocator cleanup only) → get region_ids
                owned_region_ids: list[int] = []
                if self._owned is not None:
                    owned_region_ids = self._owned.close()

                # 2. Return allocations to server via RPC
                for rid in owned_region_ids:
                    try:
                        self._rpc.return_alloc(self._config.instance_id, rid)
                    except Exception:
                        logger.error(
                            "Failed to return region %d", rid, exc_info=True
                        )

                # 3. Unmap all regions (owned + shared) via DaxMapper
                self._mapper.close()

                # 4. Close RPC connection
                self._rpc.close()
        except Exception:
            logger.error("Error during close", exc_info=True)
        finally:
            self._connected = False
            self._owned = None
            self._key_to_location.clear()
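
    # Lifecycle sketch (hedged: assumes ``auto_connect`` is disabled so
    # connect() is called explicitly; the config fields are the ones used
    # in __init__ above):
    #
    #     handler = MaruHandler(MaruConfig(server_url="tcp://localhost:5555"))
    #     try:
    #         if not handler.connect():
    #             raise RuntimeError("MaruServer unreachable")
    #         ...  # store / retrieve / batch operations
    #     finally:
    #         handler.close()  # waits for in-flight writes, then unmaps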

    # =========================================================================
    # KV Operations
    # =========================================================================

    def alloc(self, size: int) -> AllocHandle:
        """Allocate a page and return a handle with a writable mmap memoryview.

        The caller writes directly to ``handle.buf``, then passes the handle
        to ``store(key, handle=handle)`` to register without copying.

        Args:
            size: Required bytes (must be <= chunk_size)

        Returns:
            AllocHandle with writable memoryview

        Raises:
            RuntimeError: If not connected or closing
            ValueError: If size exceeds chunk_size or allocation fails
        """
        self._ensure_connected()

        with self._write_lock:
            if self._closing.is_set():
                raise RuntimeError("Handler is closing")

            chunk_size = self._owned.get_chunk_size()
            if size > chunk_size:
                raise ValueError(
                    f"Requested size {size} exceeds chunk_size {chunk_size}"
                )

            result = self._owned.allocate()
            if result is None:
                if not self._expand_region():
                    raise ValueError("Cannot allocate page: pool exhausted")
                result = self._owned.allocate()
                if result is None:
                    raise ValueError("Cannot allocate page after expansion")
            region_id, page_index = result

            buf = self._mapper.get_buffer_view(
                region_id,
                page_index * chunk_size,
                size,
            )
            if buf is None:
                self._owned.free(region_id, page_index)
                raise ValueError(f"Failed to get buffer view for region {region_id}")

            handle = AllocHandle(
                buf=buf,
                _region_id=region_id,
                _page_index=page_index,
                _size=size,
            )
            logger.debug(
                "alloc: size=%d, region=%d, page=%d", size, region_id, page_index
            )
            return handle

    def free(self, handle: AllocHandle) -> None:
        """Free a page previously obtained via alloc().

        Can be called before store() (discard) or after (eviction).

        Args:
            handle: AllocHandle from alloc()

        Raises:
            ValueError: If handle is not tracked (already freed or invalid)
        """
        self._ensure_connected()

        with self._write_lock:
            region_id = handle._region_id
            page_index = handle._page_index

            # Find and remove the key mapping if stored
            key_to_remove = None
            for key, loc in self._key_to_location.items():
                if loc == (region_id, page_index):
                    key_to_remove = key
                    break
            if key_to_remove is not None:
                del self._key_to_location[key_to_remove]

            self._owned.free(region_id, page_index)
            logger.debug(
                "free: region=%d, page=%d, key=%s",
                region_id,
                page_index,
                key_to_remove,
            )
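
    # Zero-copy write path with discard-on-error (illustrative; ``payload``
    # is a hypothetical bytes object no larger than chunk_size):
    #
    #     handle = handler.alloc(size=len(payload))
    #     try:
    #         handle.buf[:] = payload
    #     except Exception:
    #         handler.free(handle)  # discard the page before it is registered
    #         raise
    #     handler.store(key="chunk-0", handle=handle)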

    def store(
        self,
        key: str,
        info: MemoryInfo | memoryview | None = None,
        prefix: bytes | None = None,
        *,
        data: memoryview | None = None,
        handle: AllocHandle | None = None,
    ) -> bool:
        """Store data to the KV cache.

        If ``handle`` is provided (zero-copy path), data is already written
        to the mmap region via alloc() and only register_kv is performed.
        Otherwise, allocate + memcpy + register are performed in one call.

        Args:
            key: The chunk key string
            info: MemoryInfo or memoryview with data
            prefix: Optional bytes to prepend (e.g., serialized metadata header)
            data: memoryview with data (preferred, keyword-only)
            handle: AllocHandle from alloc() for zero-copy store

        Returns:
            True if successful
        """
        self._ensure_connected()

        with self._write_lock:
            if self._closing.is_set():
                raise RuntimeError("Handler is closing")

            # Duplicate skip: check if key already exists (common to both paths)
            if key in self._key_to_location:
                if handle is not None:
                    self._owned.free(handle._region_id, handle._page_index)
                logger.debug("store: key=%s already in local map, skipping", key)
                return True
            elif self._rpc.exists_kv(key):
                if handle is not None:
                    self._owned.free(handle._region_id, handle._page_index)
                logger.debug(
                    "store: key=%s already exists on server, skipping", key
                )
                return True

            if handle is not None:
                # ── Zero-copy path ──
                if data is not None or info is not None:
                    raise ValueError("Cannot specify both handle and data/info")

                region_id = handle._region_id
                page_index = handle._page_index
                offset = page_index * self._owned.get_chunk_size()
                total_size = handle._size

                is_new = self._rpc.register_kv(
                    key=key,
                    region_id=region_id,
                    kv_offset=offset,
                    kv_length=total_size,
                )
                if not is_new:
                    self._owned.free(region_id, page_index)
                    logger.debug(
                        "store: key=%s lost register race, freed page "
                        "(region=%d, page=%d)",
                        key,
                        region_id,
                        page_index,
                    )
                    return True

                self._key_to_location[key] = (region_id, page_index)
                logger.debug(
                    "Stored (zero-copy) key=%s: region=%d, page=%d, "
                    "offset=%d, size=%d",
                    key,
                    region_id,
                    page_index,
                    offset,
                    total_size,
                )
                return True

            # ── Copy path: allocate + memcpy + register ──
            # Resolve source memoryview from either parameter
            if data is not None:
                src = data
            elif isinstance(info, memoryview):
                src = info
            elif isinstance(info, MemoryInfo):
                src = info.view
            else:
                raise TypeError(
                    "Must provide data (memoryview) or info (MemoryInfo | memoryview)"
                )

            # Normalize to 1D unsigned-byte view for mmap slice assignment
            if src.format != "B":
                src = src.cast("B")

            data_size = len(src)
            prefix_len = len(prefix) if prefix else 0
            total_size = prefix_len + data_size

            logger.debug(
                "store: key=%s, data=%d bytes, prefix=%d bytes, "
                "total=%d bytes, readonly=%s",
                key,
                data_size,
                prefix_len,
                total_size,
                src.readonly,
            )

            if total_size > self._owned.get_chunk_size():
                logger.error(
                    "Total size %d exceeds chunk_size %d",
                    total_size,
                    self._owned.get_chunk_size(),
                )
                return False

            # Allocate page + CXL write + register (new or overwrite only)
            # 1. Allocate a page (expand the pool if needed)
            result = self._owned.allocate()
            if result is None:
                if not self._expand_region():
                    logger.error("Cannot allocate page for key %s", key)
                    return False
                result = self._owned.allocate()
                if result is None:
                    return False
            region_id, page_index = result

            # 2. Get writable memoryview slice for the page
            buf = self._mapper.get_buffer_view(
                region_id,
                page_index * self._owned.get_chunk_size(),
                total_size,
            )
            if buf is None:
                self._owned.free(region_id, page_index)
                return False

            # 3. Write prefix + data via GIL-free memcpy
            mv_offset = 0
            if prefix:
                _gil_free_memcpy(buf[mv_offset:], prefix, prefix_len)
                mv_offset += prefix_len
            _gil_free_memcpy(buf[mv_offset:], src, data_size)

            # 4. Register KV with server
            offset = page_index * self._owned.get_chunk_size()
            is_new = self._rpc.register_kv(
                key=key,
                region_id=region_id,
                kv_offset=offset,
                kv_length=total_size,
            )
            if not is_new:
                # Race condition: another instance registered the same key
                # between our exists_kv check and register_kv call.
                # Free the page we just wrote — the data is identical anyway.
                self._owned.free(region_id, page_index)
                logger.debug(
                    "store: key=%s lost register race, freed page "
                    "(region=%d, page=%d)",
                    key,
                    region_id,
                    page_index,
                )
                return True

            # 5. Track
            self._key_to_location[key] = (region_id, page_index)
            logger.debug(
                "Stored key=%s: region=%d, page=%d, offset=%d, size=%d",
                key,
                region_id,
                page_index,
                offset,
                total_size,
            )
            return True
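
    # Copy-path store sketch (illustrative; ``header`` and ``payload`` are
    # hypothetical bytes objects — store() prepends the prefix inside the
    # same page before registering):
    #
    #     ok = handler.store(
    #         key="chunk-1",
    #         data=memoryview(payload),
    #         prefix=header,
    #     )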

    def retrieve(self, key: str) -> MemoryInfo | None:
        """Retrieve a zero-copy MemoryInfo from the KV cache.

        Returns a MemoryInfo with a memoryview slice of the mmap region.
        Works for both owned (RW) and shared (RO) regions.

        WARNING: The returned memoryview is only valid while the region
        remains mapped. Do not use after calling close().

        Args:
            key: The chunk key string

        Returns:
            MemoryInfo with memoryview, or None if not found
        """
        self._ensure_connected()

        result = self._rpc.lookup_kv(key)
        if not result.found or result.handle is None:
            logger.debug("Key %s not found", key)
            return None

        handle = result.handle
        region_id = handle.region_id

        # Shared region: on-demand mapping
        if not self._owned.is_owned(region_id):
            if self._mapper.get_region(region_id) is None:
                try:
                    self._mapper.map_region(handle)
                except Exception:
                    logger.error(
                        "Failed to map shared region %d", region_id, exc_info=True
                    )
                    return None

        buf = self._mapper.get_buffer_view(
            region_id, result.kv_offset, result.kv_length
        )
        if buf is None:
            logger.error("Region %d: get_buffer_view returned None", region_id)
            return None

        logger.debug(
            "retrieve: key=%s, region=%d, page=%d, offset=%d, size=%d, "
            "readonly=%s, owned=%s",
            key,
            region_id,
            result.kv_offset // self._owned.get_chunk_size(),
            result.kv_offset,
            result.kv_length,
            buf.readonly,
            self._owned.is_owned(region_id),
        )
        return MemoryInfo(view=buf)
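
    # Retrieve sketch (illustrative): the returned view aliases the mapped
    # region, so copy it out if it must outlive close():
    #
    #     info = handler.retrieve("chunk-0")
    #     if info is not None:
    #         payload = bytes(info.view)  # detach from the mmap lifetime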

    def exists(self, key: str) -> bool:
        """Check if a key exists.

        Args:
            key: The chunk key string

        Returns:
            True if exists
        """
        self._ensure_connected()
        return self._rpc.exists_kv(key)

    def delete(self, key: str) -> bool:
        """Delete a key and free the corresponding page.

        Args:
            key: The chunk key string

        Returns:
            True if deleted
        """
        self._ensure_connected()

        with self._write_lock:
            if self._closing.is_set():
                raise RuntimeError("Handler is closing")

            # RPC first, then local free — prevents inconsistency on RPC failure
            result = self._rpc.delete_kv(key)
            if result:
                location = self._key_to_location.pop(key, None)
                if location is not None:
                    region_id, page_index = location
                    self._owned.free(region_id, page_index)
                logger.debug("Deleted key=%s", key)
            else:
                logger.debug("Delete key=%s: not found on server", key)
            return result

    def healthcheck(self) -> bool:
        """Check if the handler and MaruServer are healthy.

        Verifies local connection state and sends a heartbeat RPC to confirm
        the MaruServer is responsive.

        Returns:
            True if connected and server responded to heartbeat
        """
        if not self._connected or self._closing.is_set():
            return False
        try:
            return self._rpc.heartbeat()
        except Exception as e:
            logger.warning("Healthcheck failed: %s", e)
            return False

    def get_stats(self) -> dict:
        """Get server statistics."""
        self._ensure_connected()
        stats = self._rpc.get_stats()
        result = {
            "kv_manager": {
                "total_entries": stats.kv_manager.total_entries,
                "total_size": stats.kv_manager.total_size,
            },
            "allocation_manager": {
                "num_allocations": stats.allocation_manager.num_allocations,
                "total_allocated": stats.allocation_manager.total_allocated,
                "active_clients": stats.allocation_manager.active_clients,
            },
        }
        if self._owned is not None:
            store_stats = self._owned.get_stats()
            result["store_regions"] = store_stats
            # Backward compat: first region stats as "allocator"
            regions_list = store_stats.get("regions", [])
            if regions_list:
                result["allocator"] = regions_list[0]
        return result
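
    # Shape of the mapping returned above (keys taken from this method;
    # values are illustrative):
    #
    #     {
    #         "kv_manager": {"total_entries": 42, "total_size": 1 << 20},
    #         "allocation_manager": {
    #             "num_allocations": 2,
    #             "total_allocated": 2 << 30,
    #             "active_clients": 1,
    #         },
    #         "store_regions": {"regions": [...], ...},  # once connected
    #         "allocator": {...},  # first region stats (backward compat)
    #     }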

    # =========================================================================
    # Batch Operations
    # =========================================================================

    def batch_retrieve(self, keys: list[str]) -> list[MemoryInfo | None]:
        """Retrieve multiple values as MemoryInfo in batch.

        Uses a single batch RPC call for lookup, returns zero-copy memoryview
        slices for both owned (RW) and shared (RO) regions.

        WARNING: Returned memoryviews are only valid while regions remain
        mapped.

        Args:
            keys: List of chunk key strings

        Returns:
            List of MemoryInfo (None for keys not found)
        """
        self._ensure_connected()

        try:
            batch_resp = self._rpc.batch_lookup_kv(keys)
        except Exception:
            logger.error("batch_retrieve RPC failed", exc_info=True)
            return [None] * len(keys)

        results: list[MemoryInfo | None] = []
        for i, entry in enumerate(batch_resp.entries):
            if not entry.found or entry.handle is None:
                results.append(None)
                continue

            handle = entry.handle
            region_id = handle.region_id

            # Ensure region is mapped
            if not self._owned.is_owned(region_id):
                if self._mapper.get_region(region_id) is None:
                    try:
                        self._mapper.map_region(handle)
                    except Exception:
                        logger.error(
                            "Failed to map shared region %d",
                            region_id,
                            exc_info=True,
                        )
                        results.append(None)
                        continue

            buf = self._mapper.get_buffer_view(
                region_id, entry.kv_offset, entry.kv_length
            )
            if buf is None:
                logger.error("Region %d: get_buffer_view returned None", region_id)
                results.append(None)
                continue

            logger.debug(
                "batch_retrieve: key=%s, region=%d, page=%d, "
                "offset=%d, size=%d, readonly=%s",
                keys[i],
                region_id,
                entry.kv_offset // self._owned.get_chunk_size(),
                entry.kv_offset,
                entry.kv_length,
                buf.readonly,
            )
            results.append(MemoryInfo(view=buf))

        hits = sum(1 for r in results if r is not None)
        ro_count = sum(1 for r in results if r is not None and r.view.readonly)
        logger.debug(
            "batch_retrieve: %d/%d hits, %d readonly (shared), %d writable (owned)",
            hits,
            len(keys),
            ro_count,
            hits - ro_count,
        )
        return results

    def batch_store(
        self,
        keys: list[str],
        infos: list[MemoryInfo | memoryview],
        prefixes: list[bytes | None] | None = None,
    ) -> list[bool]:
        """Store multiple key-value pairs in batch.

        Uses a single batch RPC call for registration.

        On RPC failure, allocated pages are freed but data already written
        to those pages is not zeroed. This is safe because the pages are
        never registered with the server and will be overwritten on reuse.

        Args:
            keys: List of chunk key strings
            infos: List of MemoryInfo or memoryview with data
            prefixes: Optional list of prefix bytes per entry

        Returns:
            List of booleans indicating success for each key
        """
        self._ensure_connected()

        if len(keys) != len(infos):
            raise ValueError("keys and infos must have the same length")
        if prefixes is not None and len(prefixes) != len(keys):
            raise ValueError("prefixes must have the same length as keys")

        with self._write_lock:
            if self._closing.is_set():
                raise RuntimeError("Handler is closing")

            chunk_size = self._owned.get_chunk_size()
            results = [True] * len(keys)
            register_entries = []
            allocations: dict[int, tuple[int, int]] = {}

            # Phase 1: Batch check which keys already exist (avoid CXL write waste)
            try:
                exists_resp = self._rpc.batch_exists_kv(keys)
                exists_results = exists_resp.results
            except Exception:
                logger.error(
                    "batch_exists RPC failed, proceeding without check",
                    exc_info=True,
                )
                exists_results = [False] * len(keys)

            skipped = sum(exists_results)
            if skipped > 0:
                logger.debug(
                    "batch_store: %d/%d keys already exist, skipping CXL write",
                    skipped,
                    len(keys),
                )

            # Phase 2: Only process new keys (skip duplicates)
            for i, (key, info) in enumerate(zip(keys, infos, strict=True)):
                if key in self._key_to_location:
                    # Same instance already stored — same key = same content, skip
                    logger.debug(
                        "batch_store: key=%s already in local map, skipping", key
                    )
                    continue  # results[i] stays True (idempotent)
                if exists_results[i]:
                    # Another instance already registered — skip CXL write
                    logger.debug(
                        "batch_store: key=%s already exists on server, skipping", key
                    )
                    continue  # results[i] stays True (idempotent)

                prefix = prefixes[i] if prefixes else None
                prefix_len = len(prefix) if prefix else 0

                # Normalize to 1D unsigned-byte view for mmap slice assignment
                src = info if isinstance(info, memoryview) else info.view
                if src.format != "B":
                    src = src.cast("B")
                data_size = len(src)
                total_size = prefix_len + data_size

                if total_size > chunk_size:
                    logger.error(
                        "Total size %d exceeds chunk_size %d for key %s",
                        total_size,
                        chunk_size,
                        key,
                    )
                    results[i] = False
                    continue

                # Allocate page (expand if needed)
                alloc_result = self._owned.allocate()
                if alloc_result is None:
                    if not self._expand_region():
                        logger.error("Cannot allocate page for key %s", key)
                        results[i] = False
                        continue
                    alloc_result = self._owned.allocate()
                    if alloc_result is None:
                        results[i] = False
                        continue
                region_id, page_index = alloc_result

                # Write to page via GIL-free memcpy
                buf = self._mapper.get_buffer_view(
                    region_id, page_index * chunk_size, total_size
                )
                if buf is None:
                    self._owned.free(region_id, page_index)
                    results[i] = False
                    continue

                # Record the allocation only after the buffer view succeeds,
                # so RPC-failure cleanup below never frees this page twice.
                allocations[i] = (region_id, page_index)

                mv_offset = 0
                if prefix:
                    _gil_free_memcpy(buf[mv_offset:], prefix, prefix_len)
                    mv_offset += prefix_len
                _gil_free_memcpy(buf[mv_offset:], src, data_size)

                offset = page_index * chunk_size
                register_entries.append((key, region_id, offset, total_size))

            # Batch register
            if register_entries:
                try:
                    batch_resp = self._rpc.batch_register_kv(register_entries)
                except Exception:
                    logger.error("Batch register RPC failed", exc_info=True)
                    for _idx, (rid, pidx) in allocations.items():
                        self._owned.free(rid, pidx)
                    return [False] * len(keys)

                if not batch_resp.success:
                    logger.error("Batch register RPC failed")
                    for _idx, (rid, pidx) in allocations.items():
                        self._owned.free(rid, pidx)
                    return [False] * len(keys)

                batch_idx = 0
                for i in range(len(keys)):
                    if results[i] and i in allocations:
                        if batch_idx < len(batch_resp.results):
                            results[i] = batch_resp.results[batch_idx]
                        batch_idx += 1

            # Track
            for i, key in enumerate(keys):
                if results[i] and i in allocations:
                    self._key_to_location[key] = allocations[i]

            total_bytes = sum(
                (
                    infos[i].nbytes
                    if isinstance(infos[i], memoryview)
                    else infos[i].view.nbytes
                )
                for i in range(len(keys))
                if results[i]
            )
            logger.debug(
                "batch_store: %d/%d succeeded, total_data=%d bytes",
                sum(results),
                len(keys),
                total_bytes,
            )

            return results

    def batch_exists(self, keys: list[str]) -> list[bool]:
        """Check if multiple keys exist.

        Uses a single batch RPC call instead of N individual calls.

        Args:
            keys: List of chunk key strings

        Returns:
            List of booleans indicating existence for each key
        """
        self._ensure_connected()

        try:
            batch_resp = self._rpc.batch_exists_kv(keys)
        except Exception:
            logger.error("batch_exists RPC failed", exc_info=True)
            return [False] * len(keys)
        return batch_resp.results
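
    # Batch round-trip sketch (illustrative; ``blobs`` is a hypothetical
    # list of bytes objects):
    #
    #     keys = [f"chunk-{i}" for i in range(len(blobs))]
    #     oks = handler.batch_store(keys, [memoryview(b) for b in blobs])
    #     hits = handler.batch_retrieve(keys)  # list[MemoryInfo | None]
    #     mask = handler.batch_exists(keys)    # list[bool]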

    # =========================================================================
    # Properties
    # =========================================================================

    @property
    def pool_handle(self) -> MaruHandle | None:
        """Get initial pool handle (backward compat)."""
        if self._owned is None:
            return None
        first_rid = self._owned.get_first_region_id()
        if first_rid is None:
            return None
        mapped = self._mapper.get_region(first_rid)
        return mapped.handle if mapped else None

    @property
    def allocator(self) -> PagedMemoryAllocator | None:
        """Get the first region's allocator (backward compat)."""
        if self._owned is None:
            return None
        return self._owned.get_first_allocator()

    @property
    def owned_region_manager(self) -> OwnedRegionManager | None:
        """Get the owned region manager."""
        return self._owned

    @property
    def instance_id(self) -> str:
        """Get instance ID."""
        return self._config.instance_id

    @property
    def connected(self) -> bool:
        """Check if connected."""
        return self._connected

    # =========================================================================
    # Helpers
    # =========================================================================

    def _expand_region(self) -> bool:
        """Request a new store region from the server and add it.

        Returns:
            True if expansion succeeded.
        """
        try:
            response = self._rpc.request_alloc(
                instance_id=self._config.instance_id,
                size=self._config.pool_size,
            )
        except Exception:
            logger.error("RPC request_alloc failed during expand", exc_info=True)
            return False

        if not response.success or response.handle is None:
            logger.error(
                "Server refused region expansion: %s",
                getattr(response, "error", "unknown"),
            )
            return False

        handle = response.handle
        try:
            self._owned.add_region(handle)
            logger.info("Expanded: new store region %d", handle.region_id)
            return True
        except Exception:
            logger.error("Failed to init expanded region", exc_info=True)
            try:
                self._rpc.return_alloc(self._config.instance_id, handle.region_id)
            except Exception:
                logger.debug(
                    "Failed to return allocation during expansion cleanup",
                    exc_info=True,
                )
            return False

    def _premap_shared_regions(self) -> None:
        """Pre-map all existing shared regions from other instances.

        Called during connect() to eliminate mmap from the retrieve hot path.
        Failures are logged but do not block connection — lazy fallback
        remains as a safety net in retrieve().
        """
        try:
            response = self._rpc.list_allocations(
                exclude_instance_id=self._config.instance_id
            )
        except Exception as e:
            logger.warning("Failed to list allocations for pre-map: %s", e)
            return

        if not response.success:
            logger.warning(
                "list_allocations failed: %s",
                response.error or "unknown",
            )
            return

        # NOTE: Race window exists between list_allocations() and map_region().
        # A region owner may disconnect between these calls, making the handle
        # stale. This is safe — map_region() failure is caught below and lazy
        # fallback in retrieve() handles it.
        mapped_count = 0
        for handle in response.allocations:
            if self._mapper.get_region(handle.region_id) is not None:
                continue  # already mapped (own region)
            try:
                self._mapper.map_region(handle, prefault=False)
                mapped_count += 1
            except Exception as e:
                logger.warning(
                    "Failed to pre-map shared region %d: %s",
                    handle.region_id,
                    e,
                )

        logger.info(
            "Pre-mapped %d shared regions (%d total from server)",
            mapped_count,
            len(response.allocations),
        )

    def _ensure_connected(self) -> None:
        """Ensure connected, raise if not or if closing."""
        if self._closing.is_set():
            raise RuntimeError("Handler is closing")
        if not self._connected or self._owned is None:
            raise RuntimeError("Not connected. Call connect() first.")

    # =========================================================================
    # Context Manager
    # =========================================================================

    def __enter__(self) -> "MaruHandler":
        """Context manager entry."""
        if self._config.auto_connect:
            self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Context manager exit."""
        self.close()