
Middleware Guide

Use middleware to run custom logic before and after module execution.

1. Overview

Middleware are hooks that run before and after module execution, used for:

  • Logging: Record invocation logs
  • Performance Monitoring: Collect execution time and metrics
  • Error Handling: Unified exception handling
  • Data Transformation: Modify inputs/outputs
  • Caching: Cache execution results
  • Rate Limiting: Control invocation frequency

2. Middleware Interface

from typing import Any
from apcore import Context


class Middleware:
    """Middleware base class"""

    def before(
        self,
        module_id: str,
        inputs: dict[str, Any],
        context: Context
    ) -> dict[str, Any] | None:
        """
        Called before module execution

        Args:
            module_id: Module ID
            inputs: Input parameters
            context: Invocation context

        Returns:
            Modified inputs, or None to keep unchanged
        """
        return None

    def after(
        self,
        module_id: str,
        inputs: dict[str, Any],
        output: dict[str, Any],
        context: Context
    ) -> dict[str, Any] | None:
        """
        Called after module execution

        Args:
            module_id: Module ID
            inputs: Input parameters
            output: Output result
            context: Invocation context

        Returns:
            Modified output, or None to keep unchanged
        """
        return None

    def on_error(
        self,
        module_id: str,
        inputs: dict[str, Any],
        error: Exception,
        context: Context
    ) -> dict[str, Any] | None:
        """
        Called when module execution fails

        Args:
            module_id: Module ID
            inputs: Input parameters
            error: Exception object
            context: Invocation context

        Returns:
            Alternative output result (for error recovery), or None to continue raising exception
        """
        return None

3. Quick Start

3.1 Create Simple Middleware

from apcore import Middleware, Context


class LoggingMiddleware(Middleware):
    """Logging middleware"""

    def before(self, module_id: str, inputs: dict, context: Context) -> None:
        print(f"[{context.trace_id}] Calling {module_id}")
        # ⚠️ Use context.redacted_inputs to avoid leaking sensitive data
        print(f"  Inputs: {context.redacted_inputs}")

    def after(self, module_id: str, inputs: dict, output: dict, context: Context) -> None:
        print(f"[{context.trace_id}] {module_id} completed")
        print(f"  Output: {output}")

    def on_error(self, module_id: str, inputs: dict, error: Exception, context: Context) -> None:
        print(f"[{context.trace_id}] {module_id} failed: {error}")

3.2 Using Middleware

from apcore import Registry, Executor

registry = Registry(extensions_dir="./extensions")
registry.discover()

executor = Executor(
    registry=registry,
    middlewares=[LoggingMiddleware()]
)

# Middleware automatically executes when calling modules
result = executor.call(
    module_id="executor.email.send_email",
    inputs={"to": "[email protected]", "subject": "Hi", "body": "Hello"}
)

Output:

[abc-123] Calling executor.email.send_email
  Inputs: {'to': '[email protected]', 'subject': 'Hi', 'body': 'Hello'}
[abc-123] executor.email.send_email completed
  Output: {'success': True, 'message_id': 'msg_456'}

4. Execution Model

4.1 Onion Model

Middleware executes in an onion model:

Request → MW1.before → MW2.before → MW3.before → Module execution
Response ← MW1.after  ← MW2.after  ← MW3.after  ←  Output result

4.2 Execution Order

executor = Executor(
    registry=registry,
    middlewares=[
        MiddlewareA(),  # First
        MiddlewareB(),  # Second
        MiddlewareC()   # Third
    ]
)

Execution order:

1. MiddlewareA.before()
2. MiddlewareB.before()
3. MiddlewareC.before()
4. module.execute()
5. MiddlewareC.after()
6. MiddlewareB.after()
7. MiddlewareA.after()

4.3 Error Handling

Request → MW1.before   → MW2.before   → Module execution (Exception!)
          MW1.on_error ← MW2.on_error ← Exception propagation

If any on_error returns non-None, the exception is "swallowed" and that value is returned as the result.
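
For instance, here is a minimal sketch of an error-swallowing middleware (the TimeoutError check and fallback payload are illustrative; section 6.3 shows a fuller fallback pattern):

from apcore import Middleware, Context


class SwallowTimeoutMiddleware(Middleware):
    """Sketch: recover from timeouts with a default result"""

    def on_error(self, module_id: str, inputs: dict, error: Exception, context: Context) -> dict | None:
        if isinstance(error, TimeoutError):
            # A non-None return value "swallows" the exception and becomes the call result
            return {"success": False, "reason": "timeout"}
        return None  # Any other exception keeps propagating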

4.4 Middleware Execution State Machine

  ┌──────┐    ┌────────────────┐    ┌──────────┐    ┌─────────────────┐    ┌──────┐
  │ init │───▶│ before (order) │───▶│ execute  │───▶│ after (reverse) │───▶│ done │
  └──────┘    └───────┬────────┘    └────┬─────┘    └────────┬────────┘    └──────┘
                      │                  │                   │
                      │ exception        │ exception         │ exception
                      ▼                  ▼                   ▼
                 ┌──────────┐       ┌──────────┐        ┌──────────┐
                 │ skip rest│       │ on_error │        │ on_error │
                 │ go error │       │ (reverse)│        │ (reverse)│
                 └──────────┘       └──────────┘        └──────────┘

Precise exception handling semantics:

  • before() stage: skip the remaining before hooks and module execution, enter the on_error chain (e.g., rate-limit rejection)
  • execute() stage: enter the on_error chain in reverse order (module execution failure)
  • after() stage: enter the on_error chain in reverse order; the remaining after hooks are skipped (after-stage errors are handled uniformly)
  • on_error() stage: log the error but keep executing the remaining on_error hooks (error handling must not fail again)

Thread safety requirements:

  • Middleware instances can be shared by multiple Executors
  • Middleware's before()/after()/on_error() must be thread-safe
  • When instance variables hold shared state, protect them with thread-safe primitives (e.g., threading.Lock)
  • Passing middleware state via context.data is the recommended thread-safe approach (see the sketch below)
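
A minimal sketch of that approach, assuming context.data is scoped to a single invocation as the guidance above implies: per-call state lives in context.data, and any cross-call aggregation sits behind a lock.

import threading
import time
from apcore import Middleware, Context


class ThreadSafeTimingMiddleware(Middleware):
    """Sketch: per-call state in context.data, shared totals behind a lock"""

    def __init__(self):
        self._lock = threading.Lock()
        self.total_calls = 0  # shared across threads, guarded by _lock

    def before(self, module_id: str, inputs: dict, context: Context) -> None:
        # Per-call state: context.data belongs to this invocation only
        context.data["ts_start"] = time.perf_counter()

    def after(self, module_id: str, inputs: dict, output: dict, context: Context) -> None:
        duration = time.perf_counter() - context.data.get("ts_start", time.perf_counter())
        context.logger.info(f"{module_id} took {duration * 1000:.1f}ms")
        with self._lock:
            self.total_calls += 1  # shared counter, updated under the lock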

5. Common Middleware

5.1 Logging Middleware

import logging
from datetime import datetime
from apcore import Middleware, Context


class LoggingMiddleware(Middleware):
    """Detailed logging middleware"""

    def __init__(self, logger: logging.Logger | None = None):
        self.logger = logger or logging.getLogger("apcore")
        self._start_times: dict[str, datetime] = {}

    def before(self, module_id: str, inputs: dict, context: Context) -> None:
        self._start_times[context.trace_id] = datetime.now()
        # ⚠️ Security: Use context.redacted_inputs instead of raw inputs to avoid leaking sensitive data
        self.logger.info(
            f"[{context.trace_id}] START {module_id}",
            extra={
                "trace_id": context.trace_id,
                "module_id": module_id,
                "caller_id": context.caller_id,
                "inputs": context.redacted_inputs
            }
        )

    def after(self, module_id: str, inputs: dict, output: dict, context: Context) -> None:
        start = self._start_times.pop(context.trace_id, datetime.now())
        duration = (datetime.now() - start).total_seconds() * 1000

        self.logger.info(
            f"[{context.trace_id}] END {module_id} ({duration:.2f}ms)",
            extra={
                "trace_id": context.trace_id,
                "module_id": module_id,
                "duration_ms": duration,
                "output": output
            }
        )

    def on_error(self, module_id: str, inputs: dict, error: Exception, context: Context) -> None:
        # ⚠️ Security: Use context.redacted_inputs instead of raw inputs
        self.logger.error(
            f"[{context.trace_id}] ERROR {module_id}: {error}",
            extra={
                "trace_id": context.trace_id,
                "module_id": module_id,
                "error": str(error),
                "inputs": context.redacted_inputs
            },
            exc_info=True
        )

5.2 Performance Monitoring Middleware

import time
from apcore import Middleware, Context


class MetricsMiddleware(Middleware):
    """Performance metrics collection middleware"""

    def __init__(self):
        self.call_counts: dict[str, int] = {}
        self.durations: dict[str, list[float]] = {}
        self.error_counts: dict[str, int] = {}
        self._start_times: dict[str, float] = {}

    def before(self, module_id: str, inputs: dict, context: Context) -> None:
        self._start_times[context.trace_id] = time.perf_counter()

        # Increment call count
        self.call_counts[module_id] = self.call_counts.get(module_id, 0) + 1

    def after(self, module_id: str, inputs: dict, output: dict, context: Context) -> None:
        start = self._start_times.pop(context.trace_id, time.perf_counter())
        duration = time.perf_counter() - start

        # Record execution time
        if module_id not in self.durations:
            self.durations[module_id] = []
        self.durations[module_id].append(duration)

    def on_error(self, module_id: str, inputs: dict, error: Exception, context: Context) -> None:
        self.error_counts[module_id] = self.error_counts.get(module_id, 0) + 1

    def get_stats(self, module_id: str) -> dict:
        """Get module statistics"""
        durations = self.durations.get(module_id, [])
        return {
            "call_count": self.call_counts.get(module_id, 0),
            "error_count": self.error_counts.get(module_id, 0),
            "avg_duration": sum(durations) / len(durations) if durations else 0,
            "min_duration": min(durations) if durations else 0,
            "max_duration": max(durations) if durations else 0,
        }
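
A usage sketch, reusing the Registry/Executor setup and email module from the Quick Start (the printed numbers are illustrative):

metrics = MetricsMiddleware()
executor = Executor(registry=registry, middlewares=[metrics])

executor.call(
    module_id="executor.email.send_email",
    inputs={"to": "[email protected]", "subject": "Hi", "body": "Hello"}
)

print(metrics.get_stats("executor.email.send_email"))
# e.g. {'call_count': 1, 'error_count': 0, 'avg_duration': 0.045, 'min_duration': 0.045, 'max_duration': 0.045}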

5.3 Cache Middleware

import hashlib
import json
import time

from apcore import Middleware, Context


class CacheMiddleware(Middleware):
    """Simple cache middleware"""

    def __init__(self, ttl_seconds: int = 300):
        self.cache: dict[str, tuple[dict, float]] = {}
        self.ttl = ttl_seconds

    def _cache_key(self, module_id: str, inputs: dict) -> str:
        """Generate cache key"""
        content = json.dumps({"module_id": module_id, "inputs": inputs}, sort_keys=True)
        return hashlib.md5(content.encode()).hexdigest()

    def before(self, module_id: str, inputs: dict, context: Context) -> dict | None:
        key = self._cache_key(module_id, inputs)

        if key in self.cache:
            cached_output, cached_time = self.cache[key]
            if time.time() - cached_time < self.ttl:
                # Cache hit, skip execution
                context.data["cache_hit"] = True
                return cached_output

        return None

    def after(self, module_id: str, inputs: dict, output: dict, context: Context) -> None:
        if not context.data.get("cache_hit"):
            # Cache miss, store result
            key = self._cache_key(module_id, inputs)
            self.cache[key] = (output, time.time())
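
A usage sketch; the module ID below is hypothetical and stands in for an idempotent, read-only module (caching a side-effecting module such as send_email would not be appropriate):

executor = Executor(
    registry=registry,
    middlewares=[CacheMiddleware(ttl_seconds=60)]
)

# First call executes the module and stores the result
executor.call(module_id="provider.weather.get_forecast", inputs={"city": "Berlin"})

# A second identical call within 60 seconds is served from the cache
executor.call(module_id="provider.weather.get_forecast", inputs={"city": "Berlin"})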

5.4 Retry Middleware

import time
from apcore import Middleware, Context


class RetryMiddleware(Middleware):
    """Automatic retry middleware"""

    def __init__(
        self,
        max_retries: int = 3,
        delay_seconds: float = 1.0,
        exponential_backoff: bool = True
    ):
        self.max_retries = max_retries
        self.delay = delay_seconds
        self.exponential = exponential_backoff
        self._retry_counts: dict[str, int] = {}

    def on_error(self, module_id: str, inputs: dict, error: Exception, context: Context) -> dict | None:
        key = context.trace_id
        retries = self._retry_counts.get(key, 0)

        if retries < self.max_retries:
            # Calculate delay
            delay = self.delay * (2 ** retries) if self.exponential else self.delay
            time.sleep(delay)

            # Increment retry count
            self._retry_counts[key] = retries + 1

            # Request a retry: set a marker and let the exception propagate so the Executor can retry
            context.data["retry_requested"] = True
            return None

        # Clean up retry count
        self._retry_counts.pop(key, None)
        return None  # Exceeded retry limit, continue raising exception

5.5 Rate Limit Middleware

import time
from collections import deque
from apcore import Middleware, Context, ModuleError


class RateLimitMiddleware(Middleware):
    """Rate limiting middleware"""

    def __init__(self, max_calls: int = 100, window_seconds: int = 60):
        self.max_calls = max_calls
        self.window = window_seconds
        self._call_times: dict[str, deque] = {}

    def before(self, module_id: str, inputs: dict, context: Context) -> None:
        now = time.time()

        if module_id not in self._call_times:
            self._call_times[module_id] = deque()

        calls = self._call_times[module_id]

        # Remove calls outside the window
        while calls and calls[0] < now - self.window:
            calls.popleft()

        # Check if limit is exceeded
        if len(calls) >= self.max_calls:
            raise ModuleError(
                f"Rate limit exceeded for {module_id}",
                module_id=module_id,
                trace_id=context.trace_id
            )

        # Record current call
        calls.append(now)
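
A usage sketch: with a limit of 2 calls per window, the third call is rejected in before() and the ModuleError surfaces to the caller (assuming no on_error hook recovers it):

executor = Executor(
    registry=registry,
    middlewares=[RateLimitMiddleware(max_calls=2, window_seconds=60)]
)

try:
    for _ in range(3):
        executor.call(
            module_id="executor.email.send_email",
            inputs={"to": "[email protected]", "subject": "Hi", "body": "Hello"}
        )
except ModuleError as e:
    print(f"Rejected: {e}")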

6. Modifying Inputs/Outputs

6.1 Modifying Inputs

class InputTransformMiddleware(Middleware):
    """Input transformation middleware"""

    def before(self, module_id: str, inputs: dict, context: Context) -> dict:
        # Return modified inputs
        modified = inputs.copy()

        # Add default values
        if "timeout" not in modified:
            modified["timeout"] = 30

        # Add trace information
        modified["_trace_id"] = context.trace_id

        return modified

6.2 Modifying Outputs

class OutputEnrichMiddleware(Middleware):
    """Output enrichment middleware"""

    def after(self, module_id: str, inputs: dict, output: dict, context: Context) -> dict:
        # Return modified output
        enriched = output.copy()

        # Add metadata
        enriched["_metadata"] = {
            "module_id": module_id,
            "trace_id": context.trace_id,
            "timestamp": datetime.now().isoformat()
        }

        return enriched

6.3 Error Recovery

class FallbackMiddleware(Middleware):
    """Fallback middleware"""

    def __init__(self, fallback_values: dict[str, dict]):
        """
        Args:
            fallback_values: {module_id: default_output}
        """
        self.fallback_values = fallback_values

    def on_error(self, module_id: str, inputs: dict, error: Exception, context: Context) -> dict | None:
        if module_id in self.fallback_values:
            # Return fallback value, "swallow" exception
            return self.fallback_values[module_id]

        return None  # Continue raising exception
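
A usage sketch (the fallback payload is illustrative):

executor = Executor(
    registry=registry,
    middlewares=[
        FallbackMiddleware(fallback_values={
            "executor.email.send_email": {"success": False, "queued": True}
        })
    ]
)

# If send_email raises, the call returns the fallback instead of propagating the exception
result = executor.call(
    module_id="executor.email.send_email",
    inputs={"to": "[email protected]", "subject": "Hi", "body": "Hello"}
)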

7. Middleware Composition

7.1 Function-first Middleware (use_before / use_after)

In addition to class-based middleware, you can directly register functions using use_before() / use_after():

executor = Executor(registry=registry)

# Register before hook
executor.use_before(
    lambda module_id, inputs, ctx: print(f"Calling {module_id}")
)

# Register after hook
executor.use_after(
    lambda module_id, inputs, output, ctx: print(f"Done {module_id}")
)

# Can also use regular functions
def log_before(module_id: str, inputs: dict, context: Context) -> None:
    context.logger.info(f"START {module_id}")

def log_after(module_id: str, inputs: dict, output: dict, context: Context) -> None:
    context.logger.info(f"END {module_id}")

executor.use_before(log_before)
executor.use_after(log_after)

Use cases:

  • Class-based (Middleware): needs on_error, shared state, or complex logic
  • Function-first (use_before / use_after): simple hooks, rapid prototyping, a single concern

7.2 Chaining

executor = Executor(registry=registry)

# Chain middleware
executor \
    .use(LoggingMiddleware()) \
    .use(MetricsMiddleware()) \
    .use(CacheMiddleware(ttl_seconds=60)) \
    .use(RetryMiddleware(max_retries=3))

7.3 Conditional Middleware

class ConditionalMiddleware(Middleware):
    """Conditional middleware: only applies to specific modules"""

    def __init__(self, inner: Middleware, pattern: str):
        self.inner = inner
        self.pattern = pattern

    def _matches(self, module_id: str) -> bool:
        import fnmatch
        return fnmatch.fnmatch(module_id, self.pattern)

    def before(self, module_id: str, inputs: dict, context: Context):
        if self._matches(module_id):
            return self.inner.before(module_id, inputs, context)

    def after(self, module_id: str, inputs: dict, output: dict, context: Context):
        if self._matches(module_id):
            return self.inner.after(module_id, inputs, output, context)

    def on_error(self, module_id: str, inputs: dict, error: Exception, context: Context):
        if self._matches(module_id):
            return self.inner.on_error(module_id, inputs, error, context)


# Usage: Enable cache only for executor.* modules
executor.use(
    ConditionalMiddleware(CacheMiddleware(), "executor.*")
)

7.4 Middleware Group

class MiddlewareGroup(Middleware):
    """Middleware group: use multiple middleware as one"""

    def __init__(self, middlewares: list[Middleware]):
        self.middlewares = middlewares

    def before(self, module_id: str, inputs: dict, context: Context):
        result = inputs
        for mw in self.middlewares:
            new_result = mw.before(module_id, result, context)
            if new_result is not None:
                result = new_result
        return result if result != inputs else None

    def after(self, module_id: str, inputs: dict, output: dict, context: Context):
        result = output
        for mw in reversed(self.middlewares):
            new_result = mw.after(module_id, inputs, result, context)
            if new_result is not None:
                result = new_result
        return result if result != output else None

    def on_error(self, module_id: str, inputs: dict, error: Exception, context: Context):
        # Forward errors as well (reverse order); the first non-None recovery wins,
        # but every grouped middleware still gets to observe the error
        recovery = None
        for mw in reversed(self.middlewares):
            result = mw.on_error(module_id, inputs, error, context)
            if result is not None and recovery is None:
                recovery = result
        return recovery


# Usage
production_middlewares = MiddlewareGroup([
    LoggingMiddleware(),
    MetricsMiddleware(),
    RateLimitMiddleware()
])

executor.use(production_middlewares)

8. Async Middleware

from apcore import AsyncMiddleware, Context


class AsyncLoggingMiddleware(AsyncMiddleware):
    """Async logging middleware"""

    async def before(self, module_id: str, inputs: dict, context: Context) -> None:
        # Async log writing
        await self._async_log(f"Calling {module_id}")

    async def after(self, module_id: str, inputs: dict, output: dict, context: Context) -> None:
        await self._async_log(f"Completed {module_id}")

    async def _async_log(self, message: str) -> None:
        # Async write to logging service
        import aiohttp
        async with aiohttp.ClientSession() as session:
            await session.post("https://logs.example.com/api/log", json={"message": message})

9. Middleware Security Specifications

9.1 Log Redaction

Middleware must never log the raw inputs parameter directly. Use context.redacted_inputs instead, which the Executor generates automatically from x-sensitive: true markers in the Schema.

# ❌ Dangerous: May leak passwords, API keys, and other sensitive data
self.logger.info(f"inputs: {inputs}")

# ✅ Safe: Sensitive fields are replaced with ***REDACTED***
self.logger.info(f"inputs: {context.redacted_inputs}")

9.2 _secret_ Prefix Convention

When passing middleware state via context.data, if it contains sensitive information, use the _secret_ prefix:

# Set sensitive data
context.data["_secret_auth_token"] = "Bearer sk-..."

# Framework guarantee: keys starting with _secret_ are automatically redacted in log serialization
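
As a sketch, a middleware can stash a token for its own after() hook without it ever appearing in logs (the token value shown is illustrative):

from apcore import Middleware, Context


class AuthTokenMiddleware(Middleware):
    """Sketch: pass a secret between hooks via context.data"""

    def before(self, module_id: str, inputs: dict, context: Context) -> None:
        # The _secret_ prefix keeps the value out of serialized logs
        context.data["_secret_auth_token"] = "Bearer sk-..."

    def after(self, module_id: str, inputs: dict, output: dict, context: Context) -> None:
        # Drop the secret as soon as it is no longer needed
        context.data.pop("_secret_auth_token", None)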

9.3 Five-Layer Log Security Model

  • L1: Schema x-sensitive: true marks sensitive fields in the Schema
  • L2: the Executor auto-generates redacted_inputs (based on the A13 redact_sensitive() algorithm)
  • L3: the _secret_ prefix convention for sensitive keys in context.data
  • L4: error message sanitization; propagate_error() does not include raw inputs in error messages
  • L5: context.logger auto-injects trace_id/module_id and unifies the log format

10. Best Practices

10.1 Middleware Order

# Recommended order
executor = Executor(
    registry=registry,
    middlewares=[
        # 1. Rate limiting (check first)
        RateLimitMiddleware(),

        # 2. Logging (record all requests)
        LoggingMiddleware(),

        # 3. Performance monitoring
        MetricsMiddleware(),

        # 4. Cache (avoid duplicate execution)
        CacheMiddleware(),

        # 5. Retry (retry on execution failure)
        RetryMiddleware(),

        # 6. Fallback (last resort)
        FallbackMiddleware(fallback_values={...})
    ]
)

10.2 Performance Considerations

class EfficientMiddleware(Middleware):
    """Efficient middleware"""

    def before(self, module_id: str, inputs: dict, context: Context) -> None:
        # ✓ Quick check
        if not self._should_process(module_id):
            return

        # ✗ Avoid time-consuming operations in before
        # self._slow_operation()

        # ✓ Use context.data to pass data to after
        context.data["start_time"] = time.perf_counter()

    def after(self, module_id: str, inputs: dict, output: dict, context: Context) -> None:
        start = context.data.get("start_time")
        if start:
            duration = time.perf_counter() - start
            # Process...

10.3 Error Handling

class SafeMiddleware(Middleware):
    """Safe middleware: own errors don't affect main flow"""

    def before(self, module_id: str, inputs: dict, context: Context) -> None:
        try:
            # Middleware logic
            self._do_something()
        except Exception as e:
            # Log error but don't raise
            logging.warning(f"Middleware error: {e}")

    def after(self, module_id: str, inputs: dict, output: dict, context: Context) -> None:
        try:
            self._do_something_else()
        except Exception as e:
            logging.warning(f"Middleware error: {e}")

11. Complete Example

from apcore import Registry, Executor, Middleware, Context
import logging
import time


# 1. Custom logging middleware
class AppLoggingMiddleware(Middleware):
    def __init__(self):
        self.logger = logging.getLogger("app")
        self._times: dict[str, float] = {}

    def before(self, module_id: str, inputs: dict, context: Context) -> None:
        self._times[context.trace_id] = time.perf_counter()
        self.logger.info(f"[{context.trace_id}] → {module_id}")

    def after(self, module_id: str, inputs: dict, output: dict, context: Context) -> None:
        duration = (time.perf_counter() - self._times.pop(context.trace_id, time.perf_counter())) * 1000
        success = output.get("success", True)
        status = "✓" if success else "✗"
        self.logger.info(f"[{context.trace_id}] {status} {module_id} ({duration:.1f}ms)")

    def on_error(self, module_id: str, inputs: dict, error: Exception, context: Context) -> None:
        self.logger.error(f"[{context.trace_id}] ✗ {module_id}: {error}")


# 2. Simple metrics middleware
class SimpleMetricsMiddleware(Middleware):
    def __init__(self):
        self.calls = 0
        self.errors = 0

    def before(self, module_id: str, inputs: dict, context: Context) -> None:
        self.calls += 1

    def on_error(self, module_id: str, inputs: dict, error: Exception, context: Context) -> None:
        self.errors += 1


# 3. Setup
logging.basicConfig(level=logging.INFO, format="%(message)s")

registry = Registry(extensions_dir="./extensions")
registry.discover()

metrics = SimpleMetricsMiddleware()

executor = Executor(
    registry=registry,
    middlewares=[
        AppLoggingMiddleware(),
        metrics
    ]
)

# 4. Usage
result = executor.call(
    module_id="executor.email.send_email",
    inputs={"to": "[email protected]", "subject": "Hi", "body": "Hello"}
)

print(f"\nTotal calls: {metrics.calls}")
print(f"Total errors: {metrics.errors}")

Output:

[abc-123] → executor.email.send_email
[abc-123] ✓ executor.email.send_email (45.2ms)

Total calls: 1
Total errors: 0

Next Steps