# Source code for maru_handler.handler

# SPDX-License-Identifier: Apache-2.0
# Copyright 2026 XCENA Inc.
"""MaruHandler - Main interface for Maru shared memory KV cache client.

This module provides the primary entry point for clients to interact with
the Maru shared memory KV cache system.

Example:
    from maru import MaruConfig, MaruHandler

    config = MaruConfig(server_url="tcp://localhost:5555")
    with MaruHandler(config) as handler:
        # Zero-copy store: alloc → write to buf → store
        handle = handler.alloc(size=len(data))
        handle.buf[:len(data)] = data
        handler.store(key="12345", handle=handle)

        result = handler.retrieve(key="12345")  # returns MemoryInfo
"""

import logging
import threading
from collections.abc import Callable

from maru_common import MaruConfig
from maru_shm import MaruHandle

from .memory import (
    AllocHandle,
    DaxMapper,
    MemoryInfo,
    OwnedRegionManager,
    PagedMemoryAllocator,
)
from .rpc_client import RpcClient

logger = logging.getLogger(__name__)


class MaruHandler:
    """Main interface for Maru shared memory KV cache operations.

    This class handles:
    - Connection management to MaruServer
    - Memory mapping via DaxMapper
    - KV store/retrieve operations

    Thread-safety:
    - Read operations (exists, retrieve, batch_exists, batch_retrieve) are
      lock-free — they rely on RpcAsyncClient (already thread-safe) and
      DaxMapper's internal lock for lazy mapping.
    - Write operations (store, batch_store, delete) are serialized by
      ``_write_lock`` to guarantee atomicity of allocate-write-register.
    - ``close()`` sets ``_closing`` event to reject new operations, then
      acquires ``_write_lock`` to wait for in-flight writes before teardown.

    Architecture::

        MaruHandler
        ├── RpcClient (server communication, sole RPC owner)
        ├── DaxMapper (memory mapping via MaruShmClient, owns all mmap/munmap)
        ├── OwnedRegionManager (owned regions + allocation, no RPC)
        │   ├── OwnedRegion 1 (PagedMemoryAllocator)
        │   ├── OwnedRegion 2 (PagedMemoryAllocator)
        │   └── ...
        └── _key_to_location (key -> (region_id, page_index))
    """

    def __init__(self, config: MaruConfig | None = None):
        """Initialize MaruHandler.

        Args:
            config: Configuration object. If None, uses defaults.
        """
        self._config = config or MaruConfig()

        # RPC transport: the async variant allows multiple in-flight
        # requests; the sync client is one request at a time.
        if self._config.use_async_rpc:
            # Local import: avoids paying for the async stack when unused.
            from .rpc_async_client import RpcAsyncClient

            self._rpc = RpcAsyncClient(
                self._config.server_url,
                timeout_ms=self._config.timeout_ms,
                max_inflight=self._config.max_inflight,
            )
        else:
            self._rpc = RpcClient(
                self._config.server_url,
                timeout_ms=self._config.timeout_ms,
            )

        # Created during connect() once the server's rm_address is known.
        self._mapper: DaxMapper | None = None

        # Managers (initialized on connect)
        self._owned: OwnedRegionManager | None = None

        # Thread-safety: _write_lock serializes writers; _closing rejects
        # new operations once close() has started.
        self._write_lock = threading.Lock()
        self._closing = threading.Event()

        # Connection state: key -> (region_id, page_index) for locally
        # registered entries.
        self._key_to_location: dict[str, tuple[int, int]] = {}
        self._connected = False

        # Region-added callback (set by CxlMemoryAdapter)
        self._on_region_added: Callable[[int, int], None] | None = None

        # Expansion policy: expand_size falls back to the initial pool size.
        self._auto_expand = self._config.auto_expand
        self._expand_size = self._config.expand_size or self._config.pool_size

        logger.debug("Created MaruHandler with config: %s", self._config)

    # =========================================================================
    # Public Accessors
    # =========================================================================

    @property
    def mapper(self) -> DaxMapper:
        """Deprecated: Use get_buffer_view() instead."""
        # NOTE(review): returns None before connect() despite the DaxMapper
        # annotation — deprecated callers should not rely on it pre-connect.
        return self._mapper
[docs] def get_buffer_view( self, region_id: int, offset: int, size: int ) -> memoryview | None: """Get a memoryview slice from a mapped region. Args: region_id: The region ID (owned or shared). offset: Byte offset within the region. size: Number of bytes to view. Returns: Writable memoryview, or None if region not mapped. """ return self._mapper.get_buffer_view(region_id, offset, size)
[docs] def get_region_page_count(self, region_id: int) -> int | None: """Get page count for a region (owned or shared). Args: region_id: The region ID. Returns: Number of pages, or None if region not found. """ if self._owned is not None: region = self._owned.get_owned_region(region_id) if region is not None: return region.allocator.page_count mapped = self._mapper.get_region(region_id) if mapped is None: return None return mapped.size // self._config.chunk_size_bytes
[docs] def get_owned_region_ids(self) -> list[int]: """Get list of currently owned region IDs. Returns: List of region IDs. Empty if not connected. """ if self._owned is None: return [] return self._owned.get_region_ids()
[docs] def get_chunk_size(self) -> int: """Get the configured chunk size in bytes. Returns: Chunk size in bytes. """ return self._config.chunk_size_bytes
[docs] def set_on_region_added(self, callback: Callable[[int, int], None] | None) -> None: """Register callback invoked with (region_id, page_count) after region added. On registration, replays callback for all existing owned regions so the caller doesn't need separate init-time logic. Args: callback: Called with (region_id, page_count), or None to unregister. """ self._on_region_added = callback if callback is not None and self._owned is not None: for rid in self._owned.get_region_ids(): region = self._owned.get_owned_region(rid) if region is not None: logger.debug( "on_region_added replay: region=%d pages=%d", rid, region.allocator.page_count, ) callback(rid, region.allocator.page_count)
    # =========================================================================
    # Connection Management
    # =========================================================================
    def connect(self) -> bool:
        """Connect to the server and request a memory allocation.

        Steps: RPC connect → handshake (rm_address) → create mapper and
        owned-region manager → request initial region → add region (mmap +
        allocator) → optionally pre-map shared regions. Each failure path
        returns False after rolling back whatever was set up.

        Returns:
            True if successful
        """
        if self._connected:
            return True
        # NOTE(review): a prior close() leaves ``_closing`` set and it is
        # never cleared here, so operations on a re-connected handler would
        # still raise "Handler is closing" — confirm whether reconnect is a
        # supported use case.
        try:
            # 1. Connect RPC client
            self._rpc.connect()

            # 1b. Handshake to get server config (rm_address). Best-effort:
            # fall back to the locally configured address on any failure.
            try:
                handshake_resp = self._rpc.handshake()
                rm_address = handshake_resp.get("rm_address") or self._config.rm_address
            except Exception:
                logger.debug("Handshake failed, using config rm_address", exc_info=True)
                rm_address = self._config.rm_address
            self._mapper = DaxMapper(rm_address=rm_address)

            # 2. Initialize managers
            self._owned = OwnedRegionManager(
                mapper=self._mapper,
                chunk_size=self._config.chunk_size_bytes,
            )

            # 3. Request initial owned region via RPC
            try:
                response = self._rpc.request_alloc(
                    instance_id=self._config.instance_id,
                    size=self._config.pool_size,
                )
            except Exception:
                logger.error(
                    "RPC request_alloc failed during connect",
                    exc_info=True,
                )
                return False
            if not response.success or response.handle is None:
                logger.error(
                    "Initial allocation failed: %s",
                    getattr(response, "error", "unknown"),
                )
                # Roll back manager and RPC connection before failing.
                if self._owned is not None:
                    self._owned.close()
                    self._owned = None
                self._rpc.close()
                return False

            # 4. Add region to OwnedRegionManager (mmap + allocator)
            handle = response.handle
            try:
                self._owned.add_region(handle)
            except Exception:
                logger.error("Failed to init region", exc_info=True)
                # Best-effort: give the allocation back to the server so it
                # is not orphaned by our local failure.
                try:
                    self._rpc.return_alloc(self._config.instance_id, handle.region_id)
                except Exception:
                    logger.debug(
                        "Failed to return allocation during cleanup",
                        exc_info=True,
                    )
                if self._owned is not None:
                    self._owned.close()
                    self._owned = None
                self._rpc.close()
                return False

            self._connected = True

            # 5. Pre-map shared regions (eager mapping) — eliminates mmap
            # from the retrieve hot path; failures fall back to lazy mapping.
            if self._config.eager_map:
                self._premap_shared_regions()

            logger.info(
                "Connected: chunk_size=%d",
                self._config.chunk_size_bytes,
            )
            return True
        except Exception:
            logger.error("Failed to connect", exc_info=True)
            return False
    def close(self) -> None:
        """Close the connection and return all allocations.

        Sets ``_closing`` event to reject new operations, then acquires
        ``_write_lock`` to wait for in-flight writes before teardown.

        Teardown order matters: local allocator cleanup first (yields the
        region ids), then return allocations to the server, then unmap, then
        drop the RPC connection (needed by step 2, so it must die last).
        """
        if not self._connected:
            return
        self._closing.set()  # reject new operations immediately
        try:
            with self._write_lock:
                # 1. Close owned regions (allocator cleanup only) → get region_ids
                owned_region_ids: list[int] = []
                if self._owned is not None:
                    owned_region_ids = self._owned.close()

                # 2. Return allocations to server via RPC. Best-effort per
                # region: one failure must not strand the others.
                for rid in owned_region_ids:
                    try:
                        self._rpc.return_alloc(self._config.instance_id, rid)
                    except Exception:
                        logger.error("Failed to return region %d", rid, exc_info=True)

                # 3. Unmap all regions (owned + shared) via DaxMapper.
                # Any memoryview handed out by retrieve() is invalid after this.
                if self._mapper is not None:
                    self._mapper.close()

                # 4. Close RPC connection
                self._rpc.close()
        except Exception:
            logger.error("Error during close", exc_info=True)
        finally:
            # Always mark disconnected, even on partial teardown.
            # NOTE(review): self._mapper is not reset to None here, unlike
            # self._owned — confirm whether that asymmetry is intentional.
            self._connected = False
            self._owned = None
            self._key_to_location.clear()
    # =========================================================================
    # KV Operations
    # =========================================================================
    def alloc(self, size: int) -> AllocHandle:
        """Allocate a page and return a handle with a writable memoryview.

        The caller writes directly to ``handle.buf``, then passes the handle
        to ``store(key, handle)`` to register without copying.

        Args:
            size: Required bytes (must be <= chunk_size)

        Returns:
            AllocHandle with writable memoryview and allocation metadata

        Raises:
            RuntimeError: If not connected or closing
            ValueError: If size exceeds chunk_size or allocation fails
        """
        self._ensure_connected()
        with self._write_lock:
            # Re-check under the lock: close() may have begun after
            # _ensure_connected() passed.
            if self._closing.is_set():
                raise RuntimeError("Handler is closing")
            chunk_size = self._owned.get_chunk_size()
            if size > chunk_size:
                raise ValueError(
                    f"Requested size {size} exceeds chunk_size {chunk_size}"
                )
            result = self._owned.allocate()
            if result is None:
                # Pool exhausted: attempt a one-shot expansion (new region
                # from the server), then retry the allocation once.
                if not self._expand_region():
                    if not self._auto_expand:
                        raise ValueError(
                            "Cannot allocate page: pool exhausted "
                            "and auto_expand is disabled"
                        )
                    raise ValueError(
                        "Cannot allocate page: pool exhausted after expansion attempt"
                    )
                result = self._owned.allocate()
                if result is None:
                    raise ValueError("Cannot allocate page after expansion")
            region_id, page_index = result
            # Writable window of `size` bytes at the start of the page.
            buf = self._mapper.get_buffer_view(
                region_id,
                page_index * chunk_size,
                size,
            )
            if buf is None:
                # Roll back the allocation so the page is not leaked.
                self._owned.free(region_id, page_index)
                raise ValueError(f"Failed to get buffer view for region {region_id}")
            handle = AllocHandle(
                buf=buf,
                _region_id=region_id,
                _page_index=page_index,
                _size=size,
            )
            logger.debug(
                "alloc: size=%d, region=%d, page=%d",
                size,
                region_id,
                page_index,
            )
            return handle
[docs] def free(self, handle: AllocHandle) -> None: """Free a page previously obtained via alloc(). Can be called before store() (discard) or after (eviction). Args: handle: AllocHandle from alloc() Raises: ValueError: If handle is not tracked (already freed or invalid) """ self._ensure_connected() with self._write_lock: region_id = handle._region_id page_index = handle._page_index # Find and remove the key mapping if stored key_to_remove = None for key, loc in self._key_to_location.items(): if loc == (region_id, page_index): key_to_remove = key break if key_to_remove is not None: del self._key_to_location[key_to_remove] self._owned.free(region_id, page_index) logger.debug( "free: region=%d, page=%d, key=%s", region_id, page_index, key_to_remove, )
[docs] def store( self, key: str, handle: AllocHandle, ) -> bool: """Register a pre-written page in the KV cache (zero-copy). Data must already be written to the page via ``handle.buf``. This method only performs duplicate check + metadata registration. Args: key: The chunk key string handle: AllocHandle from alloc() Returns: True if successful """ self._ensure_connected() with self._write_lock: if self._closing.is_set(): raise RuntimeError("Handler is closing") # Duplicate skip if key in self._key_to_location: self._owned.free(handle._region_id, handle._page_index) logger.debug("store: key=%s already in local map, skipping", key) return True elif self._rpc.exists_kv(key): self._owned.free(handle._region_id, handle._page_index) logger.debug("store: key=%s already exists on server, skipping", key) return True region_id = handle._region_id page_index = handle._page_index offset = page_index * self._owned.get_chunk_size() total_size = handle._size try: is_new = self._rpc.register_kv( key=key, region_id=region_id, kv_offset=offset, kv_length=total_size, ) except Exception: self._owned.free(region_id, page_index) logger.error( "store: register_kv RPC failed for key=%s, freed page (region=%d, page=%d)", key, region_id, page_index, exc_info=True, ) return False if not is_new: self._owned.free(region_id, page_index) logger.debug( "store: key=%s lost register race, freed page (region=%d, page=%d)", key, region_id, page_index, ) return True self._key_to_location[key] = (region_id, page_index) logger.debug( "store: key=%s, region=%d, page=%d, offset=%d, size=%d", key, region_id, page_index, offset, total_size, ) return True
[docs] def retrieve(self, key: str) -> MemoryInfo | None: """Retrieve a zero-copy MemoryInfo from the KV cache. Returns a MemoryInfo with a memoryview slice of the mmap region. Works for both owned (RW) and shared (RO) regions. WARNING: The returned memoryview is only valid while the region remains mapped. Do not use after calling close(). Args: key: The chunk key string Returns: MemoryInfo with memoryview, or None if not found """ self._ensure_connected() result = self._rpc.lookup_kv(key) if not result.found or result.handle is None: logger.debug("Key %s not found", key) return None handle = result.handle region_id = handle.region_id # Shared region: on-demand mapping if not self._owned.is_owned(region_id): if self._mapper.get_region(region_id) is None: try: self._mapper.map_region(handle) except Exception: logger.error( "Failed to map shared region %d", region_id, exc_info=True ) return None buf = self._mapper.get_buffer_view( region_id, result.kv_offset, result.kv_length ) if buf is None: logger.error("Region %d: get_buffer_view returned None", region_id) return None logger.debug( "retrieve: key=%s, region=%d, page=%d, offset=%d, size=%d, " "readonly=%s, owned=%s", key, region_id, result.kv_offset // self._owned.get_chunk_size(), result.kv_offset, result.kv_length, buf.readonly, self._owned.is_owned(region_id), ) chunk_size = self._owned.get_chunk_size() page_index = result.kv_offset // chunk_size return MemoryInfo(view=buf, region_id=region_id, page_index=page_index)
[docs] def exists(self, key: str) -> bool: """Check if a key exists. Args: key: The chunk key string Returns: True if exists """ self._ensure_connected() return self._rpc.exists_kv(key)
[docs] def pin(self, key: str) -> bool: """Check if a key exists and pin it atomically. If the key exists, increments pin_count to protect from eviction. Args: key: The chunk key string Returns: True if exists (and was pinned) """ self._ensure_connected() return self._rpc.pin_kv(key)
[docs] def unpin(self, key: str) -> bool: """Unpin a KV entry, making it eligible for eviction. Args: key: The chunk key string Returns: True if unpinned successfully """ self._ensure_connected() return self._rpc.unpin(key)
[docs] def delete(self, key: str) -> bool: """Delete a key and free the corresponding page. Args: key: The chunk key string Returns: True if deleted """ self._ensure_connected() with self._write_lock: if self._closing.is_set(): raise RuntimeError("Handler is closing") # RPC first, then local free — prevents inconsistency on RPC failure result = self._rpc.delete_kv(key) if result: location = self._key_to_location.pop(key, None) if location is not None: region_id, page_index = location self._owned.free(region_id, page_index) logger.debug("Deleted key=%s", key) else: logger.debug("Delete key=%s: not found on server", key) return result
[docs] def healthcheck(self) -> bool: """Check if the handler and MaruServer are healthy. Verifies local connection state and sends a heartbeat RPC to confirm the MaruServer is responsive. Returns: True if connected and server responded to heartbeat """ if not self._connected or self._closing.is_set(): return False try: return self._rpc.heartbeat() except Exception as e: logger.warning("Healthcheck failed: %s", e) return False
[docs] def get_stats(self) -> dict: """Get server statistics.""" self._ensure_connected() stats = self._rpc.get_stats() result = { "kv_manager": { "total_entries": stats.kv_manager.total_entries, "total_size": stats.kv_manager.total_size, }, "allocation_manager": { "num_allocations": stats.allocation_manager.num_allocations, "total_allocated": stats.allocation_manager.total_allocated, "active_clients": stats.allocation_manager.active_clients, }, } if self._owned is not None: store_stats = self._owned.get_stats() result["store_regions"] = store_stats # Backward compat: first region stats as "allocator" regions_list = store_stats.get("regions", []) if regions_list: result["allocator"] = regions_list[0] return result
    # =========================================================================
    # Batch Operations
    # =========================================================================
    def batch_retrieve(self, keys: list[str]) -> list[MemoryInfo | None]:
        """Retrieve multiple values as MemoryInfo in batch.

        Uses a single batch RPC call for lookup, returns zero-copy memoryview
        slices for both owned (RW) and shared (RO) regions.

        WARNING: Returned memoryviews are only valid while regions remain
        mapped. Do not use after calling close().

        Args:
            keys: List of chunk key strings

        Returns:
            List of MemoryInfo (None for keys not found)
        """
        self._ensure_connected()
        try:
            batch_resp = self._rpc.batch_lookup_kv(keys)
        except Exception:
            # A failed batch lookup is treated as a miss for every key.
            logger.error("batch_retrieve RPC failed", exc_info=True)
            return [None] * len(keys)
        results: list[MemoryInfo | None] = []
        for i, entry in enumerate(batch_resp.entries):
            if not entry.found or entry.handle is None:
                results.append(None)
                continue
            handle = entry.handle
            region_id = handle.region_id
            # Ensure region is mapped (lazy mapping for shared regions; a
            # mapping failure downgrades just this key to a miss).
            if not self._owned.is_owned(region_id):
                if self._mapper.get_region(region_id) is None:
                    try:
                        self._mapper.map_region(handle)
                    except Exception:
                        logger.error(
                            "Failed to map shared region %d",
                            region_id,
                            exc_info=True,
                        )
                        results.append(None)
                        continue
            buf = self._mapper.get_buffer_view(
                region_id, entry.kv_offset, entry.kv_length
            )
            if buf is None:
                logger.error("Region %d: get_buffer_view returned None", region_id)
                results.append(None)
                continue
            logger.debug(
                "batch_retrieve: key=%s, region=%d, page=%d, "
                "offset=%d, size=%d, readonly=%s",
                keys[i],
                region_id,
                entry.kv_offset // self._owned.get_chunk_size(),
                entry.kv_offset,
                entry.kv_length,
                buf.readonly,
            )
            chunk_size = self._owned.get_chunk_size()
            page_index = entry.kv_offset // chunk_size
            results.append(
                MemoryInfo(view=buf, region_id=region_id, page_index=page_index)
            )
        hits = sum(1 for r in results if r is not None)
        ro_count = sum(1 for r in results if r is not None and r.view.readonly)
        logger.debug(
            "batch_retrieve: %d/%d hits, %d readonly (shared), %d writable (owned)",
            hits,
            len(keys),
            ro_count,
            hits - ro_count,
        )
        return results
    def batch_store(
        self,
        keys: list[str],
        handles: list[AllocHandle],
    ) -> list[bool]:
        """Register multiple pre-written pages in batch (zero-copy).

        Data must already be written to each page via ``handle.buf``.
        Uses a single batch RPC call for metadata registration.

        On RPC failure, allocated pages are freed but data already written
        to those pages is not zeroed. This is safe because the pages are
        never registered with the server and will be overwritten on reuse.

        Args:
            keys: List of chunk key strings
            handles: List of AllocHandle from alloc()

        Returns:
            List of booleans indicating success for each key
            (duplicate-skipped keys report True, matching store())

        Raises:
            RuntimeError: If not connected or closing
            ValueError: If keys and handles lengths differ
        """
        self._ensure_connected()
        if len(keys) != len(handles):
            raise ValueError("keys and handles must have the same length")
        with self._write_lock:
            if self._closing.is_set():
                raise RuntimeError("Handler is closing")
            chunk_size = self._owned.get_chunk_size()
            results = [True] * len(keys)
            register_entries = []
            # Maps original key index -> (region_id, page_index) for keys
            # that actually go to the server (non-duplicates).
            allocations: dict[int, tuple[int, int]] = {}

            # Phase 1: Batch check which keys already exist. Best-effort:
            # on RPC failure proceed as if nothing exists — batch_register_kv
            # below is the authoritative dedup point.
            try:
                exists_resp = self._rpc.batch_exists_kv(keys)
                exists_results = exists_resp.results
            except Exception:
                logger.error(
                    "batch_exists RPC failed, proceeding without check", exc_info=True
                )
                exists_results = [False] * len(keys)
            skipped = sum(exists_results)
            if skipped > 0:
                logger.debug(
                    "batch_store: %d/%d keys already exist, skipping",
                    skipped,
                    len(keys),
                )

            # Phase 2: Build register entries, free duplicates (their pages
            # are no longer needed — another copy already holds the data).
            for i, (key, handle) in enumerate(zip(keys, handles, strict=True)):
                if key in self._key_to_location:
                    self._owned.free(handle._region_id, handle._page_index)
                    logger.debug(
                        "batch_store: key=%s already in local map, skipping", key
                    )
                    continue
                if exists_results[i]:
                    self._owned.free(handle._region_id, handle._page_index)
                    logger.debug(
                        "batch_store: key=%s already exists on server, skipping",
                        key,
                    )
                    continue
                region_id = handle._region_id
                page_index = handle._page_index
                allocations[i] = (region_id, page_index)
                offset = page_index * chunk_size
                register_entries.append((key, region_id, offset, handle._size))

            # Phase 3: Batch register. On failure, free every page that was
            # queued for registration so nothing leaks.
            # NOTE(review): the failure paths return all-False, including
            # for duplicate keys that store() would report as True — confirm
            # whether callers rely on this distinction.
            if register_entries:
                try:
                    batch_resp = self._rpc.batch_register_kv(register_entries)
                except Exception:
                    logger.error("Batch register RPC failed", exc_info=True)
                    for _idx, (rid, pidx) in allocations.items():
                        self._owned.free(rid, pidx)
                    return [False] * len(keys)
                if not batch_resp.success:
                    logger.error("Batch register RPC failed")
                    for _idx, (rid, pidx) in allocations.items():
                        self._owned.free(rid, pidx)
                    return [False] * len(keys)
                # Map the dense per-entry results back onto the original
                # key indices (duplicates were not sent to the server).
                batch_idx = 0
                for i in range(len(keys)):
                    if results[i] and i in allocations:
                        if batch_idx < len(batch_resp.results):
                            results[i] = batch_resp.results[batch_idx]
                        batch_idx += 1

            # Track successfully registered locations locally.
            for i, key in enumerate(keys):
                if results[i] and i in allocations:
                    self._key_to_location[key] = allocations[i]
            total_bytes = sum(handles[i]._size for i in range(len(keys)) if results[i])
            logger.debug(
                "batch_store: %d/%d succeeded, total_data=%d bytes",
                sum(results),
                len(keys),
                total_bytes,
            )
            return results
[docs] def batch_exists(self, keys: list[str]) -> list[bool]: """Check if multiple keys exist. Uses a single batch RPC call instead of N individual calls. Args: keys: List of chunk key strings Returns: List of booleans indicating existence for each key """ self._ensure_connected() try: batch_resp = self._rpc.batch_exists_kv(keys) except Exception: logger.error("batch_exists RPC failed", exc_info=True) return [False] * len(keys) return batch_resp.results
[docs] def batch_pin(self, keys: list[str]) -> list[bool]: """Check existence and pin multiple keys in a single RPC call. Args: keys: List of chunk key strings Returns: List of booleans — True if key exists (and was pinned). """ self._ensure_connected() return self._rpc.batch_pin_kv(keys).results
[docs] def batch_unpin(self, keys: list[str]) -> list[bool]: """Unpin multiple keys in a single RPC call. Args: keys: List of chunk key strings Returns: List of booleans — True if successfully unpinned. """ self._ensure_connected() return self._rpc.batch_unpin(keys).results
    # =========================================================================
    # Properties
    # =========================================================================

    @property
    def pool_handle(self) -> MaruHandle | None:
        """Get initial pool handle (backward compat)."""
        if self._owned is None:
            return None
        first_rid = self._owned.get_first_region_id()
        if first_rid is None:
            return None
        mapped = self._mapper.get_region(first_rid)
        return mapped.handle if mapped else None

    @property
    def allocator(self) -> PagedMemoryAllocator | None:
        """Get the first region's allocator (backward compat)."""
        if self._owned is None:
            return None
        return self._owned.get_first_allocator()

    @property
    def owned_region_manager(self) -> OwnedRegionManager | None:
        """Deprecated: Use get_owned_region_ids(), get_region_page_count() instead."""
        return self._owned

    @property
    def instance_id(self) -> str:
        """Get instance ID."""
        return self._config.instance_id

    @property
    def connected(self) -> bool:
        """Check if connected."""
        return self._connected

    # =========================================================================
    # Helpers
    # =========================================================================

    def _expand_region(self) -> bool:
        """Request a new store region from the server and add it.

        Gated by ``auto_expand`` config.

        Returns:
            True if expansion succeeded.
        """
        if not self._auto_expand:
            logger.warning(
                "Pool exhausted but auto_expand is disabled. "
                "Set auto_expand=True in MaruConfig to enable."
            )
            return False
        try:
            response = self._rpc.request_alloc(
                instance_id=self._config.instance_id,
                size=self._expand_size,
            )
        except Exception:
            logger.error("RPC request_alloc failed during expand", exc_info=True)
            return False
        if not response.success or response.handle is None:
            logger.warning(
                "Region expansion refused: %s",
                getattr(response, "error", "unknown"),
            )
            return False
        handle = response.handle
        try:
            region = self._owned.add_region(handle)
            logger.info("Expanded: new store region %d", handle.region_id)
            # Notify the adapter (if registered) about the new capacity.
            if self._on_region_added is not None:
                logger.debug(
                    "on_region_added fire: region=%d pages=%d",
                    handle.region_id,
                    region.allocator.page_count,
                )
                self._on_region_added(handle.region_id, region.allocator.page_count)
            return True
        except Exception:
            logger.error("Failed to init expanded region", exc_info=True)
            # Best-effort: return the allocation so the server does not
            # consider it in use after our local failure.
            try:
                self._rpc.return_alloc(self._config.instance_id, handle.region_id)
            except Exception:
                logger.debug(
                    "Failed to return allocation during expansion cleanup",
                    exc_info=True,
                )
            return False

    def _premap_shared_regions(self) -> None:
        """Pre-map all existing shared regions from other instances.

        Called during connect() to eliminate mmap from the retrieve hot
        path. Failures are logged but do not block connection — lazy
        fallback remains as safety net in retrieve().
        """
        try:
            response = self._rpc.list_allocations(
                exclude_instance_id=self._config.instance_id
            )
        except Exception as e:
            logger.warning("Failed to list allocations for pre-map: %s", e)
            return
        if not response.success:
            logger.warning(
                "list_allocations failed: %s",
                response.error or "unknown",
            )
            return
        # NOTE: Race window exists between list_allocations() and map_region().
        # A region owner may disconnect between these calls, making the handle
        # stale. This is safe — map_region() failure is caught below and lazy
        # fallback in retrieve() handles it.
        mapped_count = 0
        for handle in response.allocations:
            if self._mapper.get_region(handle.region_id) is not None:
                continue  # already mapped (own region)
            try:
                # prefault=False: defer page faults to first access.
                self._mapper.map_region(handle, prefault=False)
                mapped_count += 1
            except Exception as e:
                logger.warning(
                    "Failed to pre-map shared region %d: %s",
                    handle.region_id,
                    e,
                )
        logger.info(
            "Pre-mapped %d shared regions (%d total from server)",
            mapped_count,
            len(response.allocations),
        )

    def _ensure_connected(self) -> None:
        """Ensure connected, raise if not or if closing."""
        # Check _closing first so a handler mid-close reports "closing"
        # rather than "not connected".
        if self._closing.is_set():
            raise RuntimeError("Handler is closing")
        if not self._connected or self._owned is None:
            raise RuntimeError("Not connected. Call connect() first.")

    # =========================================================================
    # Context Manager
    # =========================================================================

    def __enter__(self) -> "MaruHandler":
        """Context manager entry."""
        # Connection is opt-in via config; connect() failure is not raised
        # here — callers can inspect .connected.
        if self._config.auto_connect:
            self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Context manager exit."""
        self.close()