diff --git a/AiDotNetBenchmarkTests/AiDotNetBenchmarkTests.csproj b/AiDotNetBenchmarkTests/AiDotNetBenchmarkTests.csproj
index 81661d704..43bf58f15 100644
--- a/AiDotNetBenchmarkTests/AiDotNetBenchmarkTests.csproj
+++ b/AiDotNetBenchmarkTests/AiDotNetBenchmarkTests.csproj
@@ -6,6 +6,7 @@
enable
enable
latest
+ true
$(NoWarn);CA1822
diff --git a/AiDotNetBenchmarkTests/InferenceOptimization/AttentionBenchmark.cs b/AiDotNetBenchmarkTests/InferenceOptimization/AttentionBenchmark.cs
new file mode 100644
index 000000000..95ee71519
--- /dev/null
+++ b/AiDotNetBenchmarkTests/InferenceOptimization/AttentionBenchmark.cs
@@ -0,0 +1,135 @@
+using BenchmarkDotNet.Attributes;
+using BenchmarkDotNet.Jobs;
+using AiDotNet.InferenceOptimization;
+using AiDotNet.InferenceOptimization.Kernels;
+using AiDotNet.LinearAlgebra;
+using System;
+
+namespace AiDotNetBenchmarkTests.InferenceOptimization
+{
+ ///
+ /// Benchmarks for fused attention kernel
+ ///
+ [SimpleJob(RuntimeMoniker.Net80)]
+ [MemoryDiagnoser]
+ [CsvExporter]
+ [HtmlExporter]
+ public class AttentionBenchmark
+ {
+ private Tensor _q;
+ private Tensor _k;
+ private Tensor _v;
+ private AttentionKernel _attentionKernel;
+
+ [Params(64, 128, 256)]
+ public int SequenceLength { get; set; }
+
+ [Params(32, 64)]
+ public int FeatureDim { get; set; }
+
+ [GlobalSetup]
+ public void Setup()
+ {
+ OptimizationInitializer.Initialize(enableProfiling: false);
+
+ _attentionKernel = new AttentionKernel();
+
+ // Initialize Q, K, V tensors
+ var random = new Random(42);
+ _q = new Tensor(new[] { 1, SequenceLength, FeatureDim });
+ _k = new Tensor(new[] { 1, SequenceLength, FeatureDim });
+ _v = new Tensor(new[] { 1, SequenceLength, FeatureDim });
+
+ for (int i = 0; i < _q.Data.Length; i++)
+ {
+ _q.Data[i] = (float)random.NextDouble();
+ }
+
+ for (int i = 0; i < _k.Data.Length; i++)
+ {
+ _k.Data[i] = (float)random.NextDouble();
+ }
+
+ for (int i = 0; i < _v.Data.Length; i++)
+ {
+ _v.Data[i] = (float)random.NextDouble();
+ }
+ }
+
+ [Benchmark(Baseline = true)]
+ public Tensor NaiveAttention()
+ {
+ // Naive implementation: QK^T, softmax, multiply by V
+ float scale = 1.0f / MathF.Sqrt(FeatureDim);
+
+ // Compute attention scores
+ var scores = new float[SequenceLength * SequenceLength];
+
+ for (int i = 0; i < SequenceLength; i++)
+ {
+ for (int j = 0; j < SequenceLength; j++)
+ {
+ float score = 0.0f;
+ for (int k = 0; k < FeatureDim; k++)
+ {
+ score += _q.Data[i * FeatureDim + k] * _k.Data[j * FeatureDim + k];
+ }
+ scores[i * SequenceLength + j] = score * scale;
+ }
+ }
+
+ // Apply softmax
+ for (int i = 0; i < SequenceLength; i++)
+ {
+ float maxVal = float.NegativeInfinity;
+ for (int j = 0; j < SequenceLength; j++)
+ {
+ if (scores[i * SequenceLength + j] > maxVal)
+ maxVal = scores[i * SequenceLength + j];
+ }
+
+ float sum = 0.0f;
+ for (int j = 0; j < SequenceLength; j++)
+ {
+ scores[i * SequenceLength + j] = MathF.Exp(scores[i * SequenceLength + j] - maxVal);
+ sum += scores[i * SequenceLength + j];
+ }
+
+ for (int j = 0; j < SequenceLength; j++)
+ {
+ scores[i * SequenceLength + j] /= sum;
+ }
+ }
+
+ // Multiply by V
+ var result = new Tensor(new[] { 1, SequenceLength, FeatureDim });
+
+ for (int i = 0; i < SequenceLength; i++)
+ {
+ for (int j = 0; j < FeatureDim; j++)
+ {
+ float sum = 0.0f;
+ for (int k = 0; k < SequenceLength; k++)
+ {
+ sum += scores[i * SequenceLength + k] * _v.Data[k * FeatureDim + j];
+ }
+ result.Data[i * FeatureDim + j] = sum;
+ }
+ }
+
+ return result;
+ }
+
+ [Benchmark]
+ public Tensor OptimizedAttention()
+ {
+ return _attentionKernel.Execute(_q, _k, _v);
+ }
+
+ [Benchmark]
+ public Tensor MultiHeadAttention()
+ {
+ return _attentionKernel.MultiHeadAttention(_q, _k, _v, numHeads: 8);
+ }
+ }
+}
diff --git a/AiDotNetBenchmarkTests/InferenceOptimization/GemmBenchmark.cs b/AiDotNetBenchmarkTests/InferenceOptimization/GemmBenchmark.cs
new file mode 100644
index 000000000..1da8a5815
--- /dev/null
+++ b/AiDotNetBenchmarkTests/InferenceOptimization/GemmBenchmark.cs
@@ -0,0 +1,84 @@
+using BenchmarkDotNet.Attributes;
+using BenchmarkDotNet.Jobs;
+using AiDotNet.InferenceOptimization;
+using AiDotNet.InferenceOptimization.Kernels;
+using AiDotNet.LinearAlgebra;
+using System;
+
+namespace AiDotNetBenchmarkTests.InferenceOptimization
+{
+ ///
+ /// Benchmarks for GEMM (General Matrix Multiplication) kernel
+ /// Tests optimized implementation against naive implementation
+ ///
+ [SimpleJob(RuntimeMoniker.Net80)]
+ [MemoryDiagnoser]
+ [CsvExporter]
+ [HtmlExporter]
+ public class GemmBenchmark
+ {
+ private Tensor _matrixA;
+ private Tensor _matrixB;
+ private GemmKernel _gemmKernel;
+
+ [Params(64, 128, 256, 512, 1024)]
+ public int MatrixSize { get; set; }
+
+ [GlobalSetup]
+ public void Setup()
+ {
+ OptimizationInitializer.Initialize(enableProfiling: false);
+
+ _gemmKernel = new GemmKernel();
+
+ // Initialize matrices with random data
+ var random = new Random(42);
+ _matrixA = new Tensor(new[] { MatrixSize, MatrixSize });
+ _matrixB = new Tensor(new[] { MatrixSize, MatrixSize });
+
+ for (int i = 0; i < _matrixA.Data.Length; i++)
+ {
+ _matrixA.Data[i] = (float)random.NextDouble();
+ }
+
+ for (int i = 0; i < _matrixB.Data.Length; i++)
+ {
+ _matrixB.Data[i] = (float)random.NextDouble();
+ }
+ }
+
+ [Benchmark(Baseline = true)]
+ public Tensor NaiveGemm()
+ {
+ // Naive triple-nested loop implementation
+ var result = new Tensor(new[] { MatrixSize, MatrixSize });
+
+ for (int i = 0; i < MatrixSize; i++)
+ {
+ for (int j = 0; j < MatrixSize; j++)
+ {
+ float sum = 0.0f;
+ for (int k = 0; k < MatrixSize; k++)
+ {
+ sum += _matrixA.Data[i * MatrixSize + k] * _matrixB.Data[k * MatrixSize + j];
+ }
+ result.Data[i * MatrixSize + j] = sum;
+ }
+ }
+
+ return result;
+ }
+
+ [Benchmark]
+ public Tensor OptimizedGemm()
+ {
+ return _gemmKernel.Execute(_matrixA, _matrixB);
+ }
+
+ [Benchmark]
+ public Tensor OptimizedGemmTranspose()
+ {
+ return _gemmKernel.GemmTransposeB(_matrixA, _matrixB);
+ }
+ }
+}
diff --git a/AiDotNetBenchmarkTests/InferenceOptimization/README.md b/AiDotNetBenchmarkTests/InferenceOptimization/README.md
new file mode 100644
index 000000000..c4551f1cb
--- /dev/null
+++ b/AiDotNetBenchmarkTests/InferenceOptimization/README.md
@@ -0,0 +1,209 @@
+# Inference Optimization Benchmarks
+
+This directory contains benchmark tests for the inference optimization components.
+
+## Running Benchmarks
+
+### All Benchmarks
+
+```bash
+cd AiDotNetBenchmarkTests
+dotnet run -c Release --project InferenceOptimization
+```
+
+### Individual Benchmarks
+
+```bash
+# GEMM benchmark
+dotnet run -c Release --filter "*GemmBenchmark*"
+
+# SIMD benchmark
+dotnet run -c Release --filter "*SimdBenchmark*"
+
+# Attention benchmark
+dotnet run -c Release --filter "*AttentionBenchmark*"
+```
+
+## Benchmark Descriptions
+
+### GemmBenchmark
+Tests matrix multiplication performance:
+- **NaiveGemm**: Baseline triple-nested loop implementation
+- **OptimizedGemm**: Cache-blocked SIMD-optimized implementation
+- **OptimizedGemmTranspose**: Optimized implementation for transposed matrices
+
+**Matrix sizes tested**: 64x64, 128x128, 256x256, 512x512, 1024x1024
+
+**Expected results**:
+- 2-3x speedup on AVX2 systems
+- 2.5x speedup on ARM NEON systems
+- Better speedup for larger matrices
+
+### SimdBenchmark
+Tests SIMD-optimized vector operations:
+- **Vector Addition**: Element-wise addition
+- **Vector Multiplication**: Element-wise multiplication
+- **Dot Product**: Inner product with FMA optimization
+- **ReLU**: Activation function
+- **Sum**: Reduction operation
+
+**Array sizes tested**: 1K, 10K, 100K, 1M elements
+
+**Expected results**:
+- 4-8x speedup on AVX2 systems (processes 8 floats at once)
+- 2-4x speedup on SSE systems (processes 4 floats at once)
+- 2-4x speedup on NEON systems (processes 4 floats at once)
+
+### AttentionBenchmark
+Tests fused attention kernel performance:
+- **NaiveAttention**: Standard three-step implementation (QK^T, softmax, V)
+- **OptimizedAttention**: Fused implementation with SIMD
+- **MultiHeadAttention**: Multi-head variant (8 heads)
+
+**Parameters tested**:
+- Sequence lengths: 64, 128, 256
+- Feature dimensions: 32, 64
+
+**Expected results**:
+- 2-2.5x speedup from memory traffic reduction
+- Better performance for longer sequences
+
+## Interpreting Results
+
+BenchmarkDotNet produces detailed reports including:
+
+### Timing Metrics
+- **Mean**: Average execution time
+- **Error**: Half of 99.9% confidence interval
+- **StdDev**: Standard deviation
+- **Median**: 50th percentile
+
+### Memory Metrics
+- **Gen0/Gen1/Gen2**: Garbage collection frequency
+- **Allocated**: Total memory allocated
+
+### Speedup Calculation
+```text
+Speedup = Baseline Time / Optimized Time
+```
+
+Example output:
+```text
+| Method | MatrixSize | Mean | Error | StdDev | Ratio |
+|---------------------- |----------- |----------:|---------:|---------:|------:|
+| NaiveGemm | 256 | 27.45 ms | 0.421 ms | 0.394 ms | 1.00 |
+| OptimizedGemm | 256 | 9.12 ms | 0.142 ms | 0.133 ms | 0.33 |
+
+Speedup = 27.45 / 9.12 = 3.01x
+```
+
+## Performance Targets
+
+### GEMM
+- ✅ Target: 2-5x speedup
+- Platform dependent:
+ - AVX2: 2.5-3x
+ - AVX-512: 3-4x
+ - NEON: 2-2.5x
+
+### SIMD Operations
+- ✅ Target: 2-8x speedup
+- Depends on:
+ - Instruction set (AVX2 > SSE > scalar)
+ - Array size (larger = better amortization)
+ - Operation type (simple ops get higher speedup)
+
+### Attention
+- ✅ Target: 2-3x speedup
+- Benefits:
+ - Reduced memory traffic
+ - Fused operations
+ - Cache efficiency
+
+## Platform-Specific Results
+
+Your benchmark results will vary based on:
+
+1. **CPU Architecture**
+ - Intel/AMD x86_64: Best with AVX2/AVX-512
+ - ARM: Good with NEON
+ - Older CPUs: Falls back to SSE or scalar
+
+2. **Cache Hierarchy**
+ - Larger caches = Better performance for blocked algorithms
+ - L1/L2/L3 sizes affect optimal tile sizes
+
+3. **Memory Bandwidth**
+ - DDR4/DDR5 speed affects large matrix operations
+ - Memory channels matter for parallel operations
+
+4. **Thermal Throttling**
+ - Sustained benchmarks may hit thermal limits
+ - Use adequate cooling
+
+## Comparing to Reference Implementations
+
+To compare against Intel MKL or OpenBLAS:
+
+```csharp
+// Add reference implementation
+[Benchmark]
+public Tensor MKL_SGEMM()
+{
+ // Call Intel MKL cblas_sgemm
+ // ...
+}
+```
+
+## Contributing
+
+To add new benchmarks:
+
+1. Create a new class inheriting benchmark attributes
+2. Add `[Params]` for different sizes/configurations
+3. Implement baseline (naive) version
+4. Implement optimized version
+5. Add `[GlobalSetup]` for initialization
+6. Mark baseline with `[Benchmark(Baseline = true)]`
+7. Add memory diagnostics: `[MemoryDiagnoser]`
+
+## CI/CD Integration
+
+Add to your CI pipeline:
+
+```yaml
+- name: Run Benchmarks
+ run: |
+ cd AiDotNetBenchmarkTests
+ dotnet run -c Release --filter "*InferenceOptimization*"
+
+- name: Upload Results
+ uses: actions/upload-artifact@v3
+ with:
+ name: benchmark-results
+ path: BenchmarkDotNet.Artifacts/results/
+```
+
+## Troubleshooting
+
+### Benchmark takes too long
+- Reduce parameter ranges
+- Use `[SimpleJob]` instead of full job
+- Reduce warmup/iteration counts
+
+### Inconsistent results
+- Close other applications
+- Disable CPU frequency scaling
+- Run multiple iterations
+- Check for thermal throttling
+
+### Out of memory
+- Reduce test sizes
+- Add `[MemoryDiagnoser]` to track allocations
+- Consider streaming benchmarks for large data
+
+## References
+
+- [BenchmarkDotNet Documentation](https://benchmarkdotnet.org/)
+- [Intel Intrinsics Guide](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/)
+- [ARM NEON Documentation](https://developer.arm.com/architectures/instruction-sets/simd-isas/neon)
diff --git a/AiDotNetBenchmarkTests/InferenceOptimization/SimdBenchmark.cs b/AiDotNetBenchmarkTests/InferenceOptimization/SimdBenchmark.cs
new file mode 100644
index 000000000..fb9dcf925
--- /dev/null
+++ b/AiDotNetBenchmarkTests/InferenceOptimization/SimdBenchmark.cs
@@ -0,0 +1,168 @@
+using BenchmarkDotNet.Attributes;
+using BenchmarkDotNet.Configs;
+using BenchmarkDotNet.Jobs;
+using AiDotNet.InferenceOptimization;
+using AiDotNet.Tensors.Engines.Simd;
+using System;
+
+namespace AiDotNetBenchmarkTests.InferenceOptimization
+{
+ ///
+ /// Benchmarks for SIMD-optimized operations
+ ///
+ [SimpleJob(RuntimeMoniker.Net80)]
+ [MemoryDiagnoser]
+ [CsvExporter]
+ [HtmlExporter]
+ [GroupBenchmarksBy(BenchmarkLogicalGroupRule.ByCategory)]
+ public class SimdBenchmark
+ {
+ private float[] _arrayA;
+ private float[] _arrayB;
+ private float[] _result;
+
+ [Params(1000, 10000, 100000, 1000000)]
+ public int ArraySize { get; set; }
+
+ [GlobalSetup]
+ public void Setup()
+ {
+ OptimizationInitializer.Initialize(enableProfiling: false);
+
+ var random = new Random(42);
+ _arrayA = new float[ArraySize];
+ _arrayB = new float[ArraySize];
+ _result = new float[ArraySize];
+
+ for (int i = 0; i < ArraySize; i++)
+ {
+ _arrayA[i] = (float)random.NextDouble();
+ _arrayB[i] = (float)random.NextDouble();
+ }
+ }
+
+ #region Vector Addition
+
+ [Benchmark(Baseline = true)]
+ [BenchmarkCategory("VectorAdd")]
+ public void VectorAdd_Scalar()
+ {
+ for (int i = 0; i < ArraySize; i++)
+ {
+ _result[i] = _arrayA[i] + _arrayB[i];
+ }
+ }
+
+ [Benchmark]
+ [BenchmarkCategory("VectorAdd")]
+ public unsafe void VectorAdd_SIMD()
+ {
+ fixed (float* pA = _arrayA, pB = _arrayB, pR = _result)
+ {
+ SimdKernels.VectorAdd(pA, pB, pR, ArraySize);
+ }
+ }
+
+ #endregion
+
+ #region Vector Multiplication
+
+ [Benchmark(Baseline = true)]
+ [BenchmarkCategory("VectorMultiply")]
+ public void VectorMultiply_Scalar()
+ {
+ for (int i = 0; i < ArraySize; i++)
+ {
+ _result[i] = _arrayA[i] * _arrayB[i];
+ }
+ }
+
+ [Benchmark]
+ [BenchmarkCategory("VectorMultiply")]
+ public unsafe void VectorMultiply_SIMD()
+ {
+ fixed (float* pA = _arrayA, pB = _arrayB, pR = _result)
+ {
+ SimdKernels.VectorMultiply(pA, pB, pR, ArraySize);
+ }
+ }
+
+ #endregion
+
+ #region Dot Product
+
+ [Benchmark(Baseline = true)]
+ [BenchmarkCategory("DotProduct")]
+ public float DotProduct_Scalar()
+ {
+ float sum = 0.0f;
+ for (int i = 0; i < ArraySize; i++)
+ {
+ sum += _arrayA[i] * _arrayB[i];
+ }
+ return sum;
+ }
+
+ [Benchmark]
+ [BenchmarkCategory("DotProduct")]
+ public unsafe float DotProduct_SIMD()
+ {
+ fixed (float* pA = _arrayA, pB = _arrayB)
+ {
+ return SimdKernels.DotProduct(pA, pB, ArraySize);
+ }
+ }
+
+ #endregion
+
+ #region ReLU Activation
+
+ [Benchmark(Baseline = true)]
+ [BenchmarkCategory("ReLU")]
+ public void ReLU_Scalar()
+ {
+ for (int i = 0; i < ArraySize; i++)
+ {
+ _result[i] = Math.Max(0.0f, _arrayA[i]);
+ }
+ }
+
+ [Benchmark]
+ [BenchmarkCategory("ReLU")]
+ public unsafe void ReLU_SIMD()
+ {
+ fixed (float* pA = _arrayA, pR = _result)
+ {
+ SimdKernels.ReLU(pA, pR, ArraySize);
+ }
+ }
+
+ #endregion
+
+ #region Sum Reduction
+
+ [Benchmark(Baseline = true)]
+ [BenchmarkCategory("Sum")]
+ public float Sum_Scalar()
+ {
+ float sum = 0.0f;
+ for (int i = 0; i < ArraySize; i++)
+ {
+ sum += _arrayA[i];
+ }
+ return sum;
+ }
+
+ [Benchmark]
+ [BenchmarkCategory("Sum")]
+ public unsafe float Sum_SIMD()
+ {
+ fixed (float* pA = _arrayA)
+ {
+ return SimdKernels.Sum(pA, ArraySize);
+ }
+ }
+
+ #endregion
+ }
+}
diff --git a/INTEGRATION_PLAN_PR433.md b/INTEGRATION_PLAN_PR433.md
new file mode 100644
index 000000000..1cb9457a1
--- /dev/null
+++ b/INTEGRATION_PLAN_PR433.md
@@ -0,0 +1,257 @@
+# PR #433 Integration Plan: Option A - Enhance CpuEngine with SIMD
+
+## Executive Summary
+
+This plan integrates the InferenceOptimization code from PR #433 into the existing IEngine architecture, eliminating duplication and ensuring all optimizations benefit the entire codebase.
+
+---
+
+## Part 1: Analysis Summary
+
+### What Already Exists in Master
+| Component | Location | Notes |
+|-----------|----------|-------|
+| IEngine interface | `AiDotNet.Tensors/Engines/IEngine.cs` | Unified engine abstraction |
+| CpuEngine | `AiDotNet.Tensors/Engines/CpuEngine.cs` | Generic O(n³) MatrixMultiply, NO SIMD |
+| GpuEngine | `AiDotNet.Tensors/Engines/GpuEngine.cs` | GPU operations |
+| TensorBase | `AiDotNet.Tensors/LinearAlgebra/TensorBase.cs` | Uses `Shape`, protected `_data` |
+
+### What PR #433 Adds (44 files in InferenceOptimization/)
+| Category | Files | Value |
+|----------|-------|-------|
+| SIMD Kernels | `SimdKernels.cs` | HIGH - AVX/AVX2/SSE/NEON explicit intrinsics |
+| Optimized GEMM | `GemmKernel.cs` | DUPLICATE - conflicts with CpuEngine |
+| Attention | `AttentionKernel.cs` | DUPLICATE - uses wrong Tensor API |
+| Convolution | `ConvolutionKernel.cs` | DUPLICATE - uses wrong Tensor API |
+| Platform Detection | `PlatformDetector.cs` | HIGH - CPU/SIMD capability detection |
+| CPU Optimization | `CacheOptimizer.cs`, `LoopOptimizer.cs` | HIGH - cache/loop optimization utilities |
+| Performance Profiler | `PerformanceProfiler.cs` | MEDIUM - profiling infrastructure |
+| Graph Optimization | `Core/*.cs`, `Passes/*.cs` | HIGH - 13 optimization passes |
+| IR System | `IR/*.cs` | HIGH - HLIR/LLIR intermediate representation |
+| Custom Operators | `CustomOperatorRegistry.cs`, `ICustomOperator.cs` | MEDIUM - extensibility |
+| GPU Infrastructure | `GpuKernelBase.cs` | LOW - base class only |
+
+### API Mismatch Issues
+The InferenceOptimization code uses a different Tensor API:
+- **PR #433 uses**: `tensor.Dimensions`, `tensor.Data`, `new Tensor(int[])`
+- **Actual API**: `tensor.Shape`, `tensor._data` (protected), `new TensorBase(int[])`
+
+This causes 130+ build errors.
+
+---
+
+## Part 2: Integration Steps
+
+### Step 1: Move SimdKernels to AiDotNet.Tensors (KEEP)
+
+**Action**: Move `SimdKernels.cs` to `AiDotNet.Tensors/Engines/Simd/` folder
+
+**Changes**:
+```
+src/InferenceOptimization/Kernels/SimdKernels.cs
+ → src/AiDotNet.Tensors/Engines/Simd/SimdKernels.cs
+```
+
+**Namespace change**: `AiDotNet.InferenceOptimization.Kernels` → `AiDotNet.Tensors.Engines.Simd`
+
+### Step 2: Integrate SIMD into CpuEngine
+
+**Action**: Add SIMD-accelerated paths to CpuEngine methods
+
+**Target methods to enhance**:
+1. `VectorAdd` - use `SimdKernels.VectorAdd` when T is float
+2. `VectorMultiply` - use `SimdKernels.VectorMultiply` when T is float
+3. `DotProduct` - use `SimdKernels.DotProduct` when T is float
+4. `MatrixMultiply` - use SIMD-optimized GEMM when T is float
+
+**Pattern**:
+```csharp
+public Vector VectorAdd(Vector a, Vector b)
+{
+ // Check if we can use SIMD optimization
+ if (typeof(T) == typeof(float) && SimdCapabilities.HasSimd)
+ {
+ return VectorAddSimd((Vector)(object)a, (Vector)(object)b);
+ }
+
+ // Generic fallback
+ return VectorAddGeneric(a, b);
+}
+```
+
+### Step 3: Move PlatformDetector to AiDotNet.Tensors (KEEP)
+
+**Action**: Move and integrate platform detection
+
+**Changes**:
+```
+src/InferenceOptimization/PlatformDetector.cs
+ → src/AiDotNet.Tensors/Engines/PlatformDetector.cs
+```
+
+**Integration**:
+- Initialize at `AiDotNetEngine` startup
+- Expose via `AiDotNetEngine.Capabilities`
+
+### Step 4: Move CPU Optimization Utilities (KEEP)
+
+**Action**: Move cache and loop optimizers
+
+**Changes**:
+```
+src/InferenceOptimization/CpuOptimization/CacheOptimizer.cs
+ → src/AiDotNet.Tensors/Engines/Optimization/CacheOptimizer.cs
+
+src/InferenceOptimization/CpuOptimization/LoopOptimizer.cs
+ → src/AiDotNet.Tensors/Engines/Optimization/LoopOptimizer.cs
+```
+
+**Use in**: CpuEngine MatrixMultiply for cache-blocked algorithms
+
+### Step 5: Move Performance Profiler (KEEP)
+
+**Action**: Move profiler to Helpers namespace
+
+**Changes**:
+```
+src/InferenceOptimization/Profiling/PerformanceProfiler.cs
+ → src/Helpers/PerformanceProfiler.cs
+```
+
+### Step 6: Fix Graph Optimization Passes (KEEP with fixes)
+
+**Action**: Keep IR and Passes but fix Tensor API usage
+
+**Files to fix** (use `Shape` instead of `Dimensions`):
+- `src/InferenceOptimization/Core/*.cs`
+- `src/InferenceOptimization/Passes/*.cs`
+- `src/InferenceOptimization/IR/*.cs`
+
+**API changes needed**:
+| Wrong | Correct |
+|-------|---------|
+| `tensor.Dimensions` | `tensor.Shape` |
+| `tensor.Data` | Create public accessor or use indexer |
+| `new Tensor(int[])` | Use TensorBase constructor |
+
+### Step 7: FIX Kernel Files API (KEEP ALL)
+
+**Action**: Fix Tensor API in all kernel files to use TensorBase properly
+
+**Files to FIX (change `Dimensions` → `Shape`, fix data access)**:
+- `src/InferenceOptimization/Kernels/GemmKernel.cs` - Industry-standard cache-blocked SIMD GEMM
+- `src/InferenceOptimization/Kernels/AttentionKernel.cs` - Fused transformer attention
+- `src/InferenceOptimization/Kernels/ConvolutionKernel.cs` - Optimized convolutions
+
+**Files to KEEP (extensibility infrastructure)**:
+- `src/InferenceOptimization/ICustomOperator.cs` - Extensibility interface
+- `src/InferenceOptimization/CustomOperatorRegistry.cs` - Operator registration system
+- `src/InferenceOptimization/OptimizationInitializer.cs` - Initialization entry point
+- `src/InferenceOptimization/GpuOptimization/GpuKernelBase.cs` - GPU kernel base class
+
+**Files to DELETE (examples only)**:
+- `src/InferenceOptimization/Examples/*.cs` - Example files that will be outdated
+
+### Step 8: Update Namespace References
+
+**Action**: Update all using statements
+
+**Old namespaces to replace**:
+- `AiDotNet.InferenceOptimization.Kernels` → `AiDotNet.Tensors.Engines.Simd`
+- `AiDotNet.InferenceOptimization.CpuOptimization` → `AiDotNet.Tensors.Engines.Optimization`
+- `AiDotNet.InferenceOptimization.Profiling` → `AiDotNet.Helpers`
+
+---
+
+## Part 3: File Disposition Matrix
+
+| File | Action | Notes |
+|------|--------|-------|
+| `Kernels/SimdKernels.cs` | MOVE | → `AiDotNet.Tensors/Engines/Simd/` |
+| `Kernels/GemmKernel.cs` | FIX | Fix Tensor API - industry-standard cache-blocked GEMM |
+| `Kernels/AttentionKernel.cs` | FIX | Fix Tensor API - fused transformer attention |
+| `Kernels/ConvolutionKernel.cs` | FIX | Fix Tensor API - optimized convolutions |
+| `PlatformDetector.cs` | MOVE | → `AiDotNet.Tensors/Engines/` |
+| `CpuOptimization/CacheOptimizer.cs` | MOVE | → `AiDotNet.Tensors/Engines/Optimization/` |
+| `CpuOptimization/LoopOptimizer.cs` | MOVE | → `AiDotNet.Tensors/Engines/Optimization/` |
+| `Profiling/PerformanceProfiler.cs` | MOVE | → `Helpers/` |
+| `Core/*.cs` | FIX | Fix Tensor API in place |
+| `Passes/*.cs` | FIX | Fix Tensor API in place |
+| `IR/*.cs` | FIX | Fix Tensor API in place |
+| `CustomOperatorRegistry.cs` | FIX | Fix nullability + keep for extensibility |
+| `ICustomOperator.cs` | KEEP | Extensibility interface |
+| `OptimizationInitializer.cs` | FIX | Fix nullability + keep as entry point |
+| `GpuOptimization/GpuKernelBase.cs` | FIX | Fix nullability + keep for future GPU work |
+| `Examples/*.cs` | DELETE | Will be outdated after API fixes |
+| `README.md` | UPDATE | Update for new architecture |
+
+---
+
+## Part 4: Expected Outcomes
+
+### After Integration:
+1. **Single Architecture**: All optimizations flow through IEngine
+2. **SIMD Everywhere**: CpuEngine automatically uses SIMD for float operations
+3. **No API Conflicts**: Graph passes use correct TensorBase API
+4. **Clean Codebase**: No duplicate kernel implementations
+
+### Issue #412 Completion After Integration:
+| Requirement | Status | Notes |
+|-------------|--------|-------|
+| SIMD Vectorization | 95% | Integrated into CpuEngine |
+| Optimized GEMM | 90% | Cache-blocked in CpuEngine |
+| Platform Detection | 100% | PlatformDetector integrated |
+| CPU Optimization | 90% | CacheOptimizer, LoopOptimizer |
+| Graph Optimization | 80% | IR system, 13 passes (needs API fix) |
+| Custom Operators | REMOVED | Not needed with IEngine |
+| GPU Optimization | 30% | Future work |
+| Benchmarks vs MKL | 0% | Future work |
+
+---
+
+## Part 5: Implementation Order
+
+1. **Phase 1 - Core SIMD** (Critical Path)
+ - [ ] Move SimdKernels.cs to AiDotNet.Tensors
+ - [ ] Move PlatformDetector.cs
+ - [ ] Integrate SIMD paths into CpuEngine
+
+2. **Phase 2 - Utilities**
+ - [ ] Move CacheOptimizer.cs
+ - [ ] Move LoopOptimizer.cs
+ - [ ] Move PerformanceProfiler.cs
+
+3. **Phase 3 - Graph Optimization**
+ - [ ] Fix Tensor API usage in Core/*.cs
+ - [ ] Fix Tensor API usage in Passes/*.cs
+ - [ ] Fix Tensor API usage in IR/*.cs
+
+4. **Phase 4 - Cleanup**
+ - [ ] Delete duplicate kernel files
+ - [ ] Delete unused infrastructure files
+ - [ ] Update README.md
+ - [ ] Build and test
+
+---
+
+## Appendix: Build Error Categories (130 errors)
+
+1. **CS1061 - Missing Members** (~100 errors)
+ - `Tensor` does not contain `Dimensions` → Use `Shape`
+ - `Tensor` does not contain `Data` → Use protected `_data` or accessor
+
+2. **CS8618 - Non-nullable Properties** (~15 errors)
+ - Properties need default values or nullable types
+
+3. **CS8603 - Possible Null Reference** (~10 errors)
+ - Need null checks or nullable return types
+
+4. **CS0103/CS0104 - Ambiguous/Missing References** (~5 errors)
+ - `Avx512VL` not found
+ - `Aes` ambiguous between X86 and ARM
+
+---
+
+*Plan created: 2025-12-14*
+*Target: PR #433, Issue #412*
+*Approach: Option A - Enhance CpuEngine with SIMD*
diff --git a/docs/INFERENCE_MVP_PHASES.md b/docs/INFERENCE_MVP_PHASES.md
new file mode 100644
index 000000000..de402b925
--- /dev/null
+++ b/docs/INFERENCE_MVP_PHASES.md
@@ -0,0 +1,180 @@
+# Inference MVP Phases (PR #433) — Implementation Plan
+
+This plan breaks down the remaining inference MVP work into phases that can be implemented and verified independently, while preserving the project’s facade philosophy (minimal public API surface) and default-first configuration approach.
+
+## Goals
+
+- Keep the public surface area limited to `PredictionModelBuilder` (build/train) and `PredictionModelResult` (inference), with a small number of carefully chosen inference entrypoints (e.g., session start).
+- Use industry-standard defaults when users do not specify options; allow opt-out via `InferenceOptimizationConfig`.
+- Ensure all PR #433 components are wired end-to-end: config → inference optimizer → optimized layers/kernels → serving integration.
+- Exceed industry standards where feasible (paged KV-cache, FP16 KV-cache, batching, speculative decoding, dynamic scheduling), without exposing internal IP.
+
+## Non-Goals (MVP)
+
+- Exposing layer-level kernels or optimizers publicly.
+- Full-blown public text-generation UX/API surface (tokenization, sampling, streaming) beyond what serving needs.
+- GPU-specific paging kernels (CPU-first correctness is priority for MVP).
+
+---
+
+## Phase 0 — Baseline Safety & Diagnostics
+
+**Outcome:** Optimization pipeline is observable and safe-by-default.
+
+1. Add internal inference diagnostics:
+ - Record decisions (enabled/disabled) per feature (KV-cache, masking mode resolution, flash attention, paging, speculative).
+ - Record exceptions with a reason and feature tag (do not throw from the facade in normal inference).
+2. Add stability guardrails:
+ - Avoid mutating user-supplied model instances unless explicitly requested (e.g., `cloneModel: false` internal path).
+ - When deep copy fails for a model, fall back to baseline inference and record diagnostics.
+3. Verification:
+ - Unit tests for optimizer decisions and non-throwing fallbacks.
+
+---
+
+## Phase 1 — Attention Rewrite Integration
+
+**Outcome:** All supported attention layers are rewritten consistently based on config, and the optimizer can run end-to-end.
+
+1. Wire attention rewrite decisions in `InferenceOptimizer`:
+ - `MultiHeadAttentionLayer` → `FlashAttentionLayer` when enabled.
+ - `MultiHeadAttentionLayer`/`FlashAttentionLayer` → cached attention when KV-cache enabled (causal default for sessions).
+ - `SelfAttentionLayer` conversion path to multi-head for downstream rewrites (when feasible).
+2. Add missing attention layer support:
+ - Ensure `SelfAttentionLayer`, `AttentionLayer`, and `GraphAttentionLayer` are handled for cloning/serialization and/or have clear non-supported fallbacks.
+3. Ensure cloning is truly deep:
+ - Serialization/deserialization round-trip must preserve parameters exactly.
+4. Verification:
+ - Unit tests that assert rewritten layers exist as expected per config.
+ - Clone tests verifying parameters match before mutation.
+
+---
+
+## Phase 2 — Paged Attention + Paged KV-Cache (Industry Standard Default)
+
+**Outcome:** Paged KV-cache is available and auto-enabled by default for KV-cache workloads, with opt-out.
+
+1. Default behavior:
+ - `EnablePagedKVCache = true` by default; user can disable.
+2. Wiring:
+ - `InferenceOptimizer` initializes paged cache when paged attention layers exist; otherwise falls back to contiguous KV-cache.
+3. Session behavior:
+ - Sessions should prefer causal masking when user sets `AttentionMasking=Auto`.
+4. Verification:
+ - Unit tests for paged attention kernel and cache mechanics (block tables, COW).
+ - Integration tests for session cache growth and reset.
+
+---
+
+## Phase 3 — KV-Cache Precision (FP16 default, opt-out)
+
+**Outcome:** KV-cache can store keys/values in FP16 to reduce memory and improve capacity/throughput, with opt-out to FP32.
+
+1. Configuration:
+ - Add `KVCachePrecision` with default `Auto` selecting FP16 when possible.
+2. Implementation:
+ - KV-cache uses FP16 backing storage when enabled and `T` supports conversion.
+3. Wiring:
+ - `InferenceOptimizer` resolves cache data type and records decision.
+4. Verification:
+ - Unit tests for FP16 round-trip and memory usage calculations.
+
+---
+
+## Phase 4 — Inference Sessions (Multi-Sequence, Facade-Friendly)
+
+**Outcome:** `PredictionModelResult.BeginInferenceSession()` supports multiple independent sequences for serving-style workloads without exposing internal implementation details.
+
+1. API:
+ - Keep public API minimal (session + sequence objects), do not expose layer internals.
+2. Behavior:
+ - Each sequence maintains independent KV-cache state and can `Reset()`.
+ - Session is safe to use for multiple sequences in parallel (thread-safety where feasible).
+3. Internal diagnostics:
+ - Provide internal-only stats hooks to validate caching behavior in integration tests without expanding public API.
+4. Verification:
+ - Integration tests for:
+ - Stateless `Predict()` behavior.
+ - Sequence independence.
+ - Reset restoring initial state.
+
+---
+
+## Phase 5 — Batching (Serving-First) + Resource Arbitration
+
+**Outcome:** `EnableBatching` is honored in serving and session contexts, with guardrails around incompatibilities with speculation.
+
+1. Serving integration:
+ - Use `AiDotNet.Serving` to host batching behavior and backpressure.
+2. Conflict policy:
+ - Document and implement a policy when batching and speculative decoding both enabled:
+ - Default: prioritize batching for throughput, optional override to prioritize speculation for latency.
+3. Verification:
+ - Serving-side tests around batch coalescing and max batch size enforcement.
+
+---
+
+## Phase 6 — Speculative Decoding MVP (Draft Model + Policy)
+
+**Outcome:** `EnableSpeculativeDecoding` and `SpeculationPolicy` are fully wired, with a draft model option and safe defaults.
+
+1. Configuration:
+ - Draft model selection via `DraftModelType` (NGram, small neural) and speculation depth.
+2. Session + serving:
+ - Enable speculation wherever sessions are used when flag is enabled.
+ - Serving integration for production usage (streaming/latency).
+3. Verification:
+ - Unit tests for policy decisions and safe fallback when draft model not available.
+
+---
+
+## Phase 7 - Dynamic Speculation & Alternative Speculators (Medusa/EAGLE)
+
+**Outcome:** Add next-gen speculative methods and dynamic scheduling options.
+
+1. Dynamic scheduling:
+ - Adaptive speculation depth based on acceptance rate, queue pressure, and compute budget.
+ - MVP implementation: serving-side `ContinuousBatcher` backs off speculative decoding under load and after observing low draft acceptance rates.
+2. Alternative methods:
+ - Add config hooks for Medusa/EAGLE-style multi-head draft proposals as a future opt-in.
+3. Verification:
+ - Bench-style tests (non-flaky) for acceptance-rate-driven behavior.
+
+---
+
+## Phase 8 - Inference Quantization (Gap-Closing)
+
+**Outcome:** Extend quantization support beyond training into inference areas where it is industry standard.
+
+1. KV-cache quantization:
+ - Optional per-layer KV-cache quantization (e.g., int8) with dequant on read.
+ - MVP implementation: `InferenceOptimizationConfig.KVCacheQuantization = Int8` routes KV-cache storage to int8 quantized backing storage with dequant-on-read.
+2. Weight-only quantization:
+ - Optional weight-only quant for inference (e.g., int8/int4) with fast matmul paths.
+3. Weight + activation quantization (advanced):
+ - Add as opt-in; ensure correctness-first.
+4. Verification:
+ - Unit tests validating numerics and shape correctness.
+
+---
+
+## Phase 9 - Multi-LoRA (Serving-First, Secure Defaults)
+
+**Outcome:** Multi-LoRA can be selected per request without leaking internal implementation details.
+
+1. Serving integration:
+ - Prefer selecting LoRA adapters from headers/metadata on serving side.
+ - MVP implementation: serving can route to a pre-loaded per-adapter model variant using `X-AiDotNet-Lora`/`X-AiDotNet-Adapter` headers (`{baseModelName}__{adapterId}`).
+2. Session integration:
+ - Optional adapter selection per sequence, but keep surface minimal.
+3. Verification:
+ - Serving tests for adapter routing and isolation.
+
+---
+
+## Release Checklist (Per Phase)
+
+- `dotnet build AiDotNet.sln -c Release`
+- Targeted `dotnet test` runs for touched areas
+- Update docs and XML comments to match project conventions (summary/remarks + “For Beginners” sections where appropriate)
+- Commit with conventional prefix (`feat:`, `fix:`, `test:`, `docs:`, `refactor:`) in small, regular increments
diff --git a/docs/PR433_FACADE_INFERENCE_PLAN.md b/docs/PR433_FACADE_INFERENCE_PLAN.md
new file mode 100644
index 000000000..f0e72603c
--- /dev/null
+++ b/docs/PR433_FACADE_INFERENCE_PLAN.md
@@ -0,0 +1,658 @@
+# PR #433 Facade Integration Plan (Inference Optimizations)
+
+## 0) Purpose
+
+Integrate PR #433 inference optimizations into the *existing facade pattern* so that:
+- Users configure everything via `PredictionModelBuilder.ConfigureInferenceOptimizations(...)`.
+- Users consume everything via `PredictionModelResult` (e.g., `Predict`, plus a session API).
+- Internal implementation details (optimizer, caches, kernels, batching, speculative decoding) remain hidden behind the facade.
+
+This plan is written to be actionable for a junior engineer without requiring deep prior context.
+
+---
+
+## 1) Constraints / Non‑Goals
+
+### 1.1 Facade constraints (must keep)
+- No new user-facing entry points besides:
+ - `PredictionModelBuilder` for configuration + training/build.
+ - `PredictionModelResult` for inference.
+- Any “advanced” types should be:
+ - `internal`, or
+ - nested under `PredictionModelResult` (if we must expose a session object).
+
+### 1.2 Correctness constraints (must keep)
+- "Industry standard defaults" when user omits config values.
+- Avoid semantic changes by default for non-autoregressive models, but prefer serving-grade defaults when inference optimizations are explicitly enabled:
+ - `AttentionMasking=Auto` should assume causal for inference sessions unless the model/layer explicitly indicates bidirectional attention.
+ - Paged KV-cache should be auto-enabled by default (opt-out via config).
+ - KV-cache must not silently change results for models that are not compatible with caching.
+
+### 1.3 Non-goals (for this integration pass)
+- Rewriting training-time behavior.
+- Full rewrite of the `src/InferenceOptimization/*` SIMD/IR plan (tracked separately in `INTEGRATION_PLAN_PR433.md`).
+
+---
+
+## 2) Current Issues / Gaps (What to Fix)
+
+### 2.1 “Disconnected” implementations
+New PR #433 classes exist but are not consistently invoked through the facade:
+- `FlashAttention.cs`, `CachedMultiHeadAttention.cs`, `InferenceOptimizer.cs`, KV cache variants, etc.
+
+### 2.2 Model cloning/serialization assumptions
+We need to treat `NeuralNetworkBase.Clone()` behavior carefully:
+- Confirm whether it is deep copy vs shallow copy in practice.
+- If cloning relies on serialization, ensure all attention layers can be deserialized with correct constructor arguments (constructor mismatch is a known risk).
+- If serialization cannot faithfully round-trip all layer types, implement a robust deep-clone path that still preserves the facade (users never touch these internals).
+
+### 2.3 Layer coverage
+The optimizer rewrite logic must recognize more attention layer types, not just:
+- `MultiHeadAttentionLayer` and `FlashAttentionLayer`.
+
+Add support/coverage for (at minimum):
+- `AttentionLayer` (mask-aware)
+- `SelfAttentionLayer`
+- `GraphAttentionLayer`
+- Any transformer encoder/decoder attention layers in `src/NeuralNetworks/Layers/*`
+
+### 2.4 Missing facade integration
+`EnableBatching` and `EnableSpeculativeDecoding` must be wired end-to-end:
+- Builder stores config.
+- Result uses config during inference.
+- Serving project should be able to leverage the same internal components.
+
+### 2.5 Paged attention integration missing
+`PagedAttention` / `PagedKVCache` exist but are not selected/used based on config.
+
+### 2.6 Tests
+We need:
+- Unit tests for low-level correctness (some exist).
+- Integration tests that validate the facade wiring end-to-end via `PredictionModelBuilder` + `PredictionModelResult`.
+
+---
+
+## 3) Target UX (What the user sees)
+
+### 3.1 Builder usage (no new user-facing components)
+```csharp
+var result = await new PredictionModelBuilder, Tensor>()
+ .ConfigureModel(model)
+ .ConfigureInferenceOptimizations(new InferenceOptimizationConfig
+ {
+ EnableFlashAttention = true,
+ EnableKVCache = true,
+ EnableBatching = true,
+ EnableSpeculativeDecoding = true
+ })
+ .BuildAsync(x, y);
+```
+
+### 3.2 Result usage
+Default usage stays the same:
+```csharp
+var y = result.Predict(x);
+```
+
+For sequence-based inference (KV-cache, generation, serving), add a *single* facade-friendly entry:
+- `result.BeginInferenceSession()` returns a session object that manages caches/batching internally.
+
+Example shape:
+```csharp
+using var session = result.BeginInferenceSession();
+var y1 = session.Predict(x1);
+var y2 = session.Predict(x2);
+```
+
+Notes:
+- The session type should be nested under `PredictionModelResult` (or otherwise kept hidden).
+- "Clear cache" should be internal to session lifecycle; avoid exposing raw cache types.
+- Avoid adding new public cache-control methods on `PredictionModelResult`; prefer session `Dispose()`/`Reset()`.
+
+---
+
+## 4) Implementation Plan (Phased)
+
+### Phase A - Baseline facade wiring + safety (no batching/spec yet)
+
+**Goal:** Ensure optimizations are *actually applied* and correct during inference for supported model types.
+
+1) **Confirm cloning semantics**
+ - Inspect `NeuralNetworkBase.Clone()` and `DeepCopy()` implementation.
+ - Decide policy:
+ - If deep clone is reliable for all relevant layers, prefer cloning before rewrites.
+ - If cloning is unreliable, either:
+ - Fix serialization/deserialization for attention layers, or
+ - Apply rewrites in-place but only inside an inference-session-scoped model instance owned by the result (never mutate the user's original model object).
+ - Required outcome for facade safety:
+ - Never mutate the user's original model object.
+ - Any optimized/mutated model instance is owned by the result/session internally.
+ - Acceptance criteria:
+ - No runtime errors when inference optimizations are enabled.
+ - No cross-request contamination (weights/caches).
+ - Fallback behavior + diagnostics (must be explicit):
+ - If an optimization is unsupported for the current model/layer/platform, it must be **auto-disabled** and inference proceeds with the baseline path (never throw by default).
+ - If an optimization throws at runtime, catch and **fall back to baseline** for that session/sequence where possible.
+ - Record decisions/exceptions via internal diagnostics (e.g., `InferenceDiagnostics.RecordDecision/RecordException`) so we can validate selection in tests and troubleshoot in serving logs without expanding the public API.
+
+2) **Make serialization/deserialization round-trip attention layers (if used by clone)**
+ - Inventory layer constructors that require metadata:
+ - `MultiHeadAttentionLayer` (e.g., head count)
+ - `SelfAttentionLayer`
+ - `AttentionLayer` (mask semantics)
+ - `GraphAttentionLayer` (graph-specific parameters)
+ - Update the model serialization format in a backward-compatible way:
+ - Preserve current header/structure so existing saved models still load.
+ - Add per-layer metadata payload (versioned) needed to reconstruct constructors.
+ - Extend deserialization to use metadata when present and fall back to safe defaults when absent.
+ - Acceptance criteria:
+ - `Clone()` becomes a true deep copy for supported model types.
+ - `OptimizeForInference()` can safely operate on a cloned instance.
+
+3) **Extend optimizer layer detection/rewrite coverage**
+ - Identify all attention-related layer types and their expected shapes/masking behavior:
+ - `MultiHeadAttentionLayer`
+ - `FlashAttentionLayer`
+ - `AttentionLayer` (mask input)
+ - `SelfAttentionLayer`
+ - `GraphAttentionLayer` (graph adjacency/attention specifics)
+ - For each layer type, decide:
+ - Can we rewrite to `FlashAttentionLayer` safely?
+ - Can we wrap/replace with KV-cache variants safely?
+ - If not, skip and record diagnostics (internal).
+ - Acceptance criteria:
+ - Transformer models built via default layer helper are optimized.
+ - Non-transformer models are unchanged.
+
+4) **Masking policy**
+ - Make masking decision consistent and safe:
+ - `AttentionMasking=Auto` defaults to causal masking for inference sessions unless a bidirectional model is clearly indicated.
+ - `AttentionMasking=Causal` forces causal mask.
+ - `AttentionMasking=Disabled` disables causal mask even for text generation.
+ - Heuristics for `Auto` (ordered, conservative):
+ - If the layer explicitly takes an attention mask input and one is provided, honor it (don’t infer).
+ - If KV-cache is enabled (or a session is created), assume causal unless explicitly overridden.
+ - If model metadata/task type exists and indicates non-causal (e.g., encoder/classification), prefer non-causal.
+ - If uncertain, default to causal only inside a session (so plain `Predict()` stays safest).
+ - Acceptance criteria:
+ - No causal mask applied for classification/encoder tasks unless forced.
+
+5) **Paged attention config plumbing (no swap yet)**
+ - Add config surface for paged attention selection with industry standard defaults:
+ - `EnablePagedKVCache = true` by default (opt-out).
+ - `PagedKVCacheBlockSize = 16` by default (configurable).
+ - Validate config values.
+ - Acceptance criteria:
+ - Configuration exists and is validated; actual usage comes in Phase C.
+
+---
+
+### Phase B — Inference Session API (facade-compliant)
+
+**Goal:** Provide a safe, explicit lifecycle for stateful inference features (KV-cache, batching queues).
+
+1) **Add `BeginInferenceSession()`**
+ - Implement as a method on `PredictionModelResult`:
+ - `public InferenceSession BeginInferenceSession(...)`
+ - Session responsibilities:
+ - Own (or reference) an optimized model instance.
+ - Own cache state (KV cache or paged KV cache).
+ - Provide session-scoped prediction methods.
+ - Reset/clear caches on `Dispose()`.
+
+2) **Decide session API surface**
+ - Minimum recommended:
+ - `Predict(TInput input)` for session-scoped inference.
+ - Optional: `Reset()` (but keep it on session, not on result).
+ - Avoid exposing raw cache objects.
+
+3) **Backwards compatibility**
+ - Keep `PredictionModelResult.Predict` working:
+ - It may internally use a “default session” with conservative behavior.
+ - But for KV-cache and generation patterns, prefer explicit session use.
+
+4) **Serving integration point**
+ - Provide internal hooks so `AiDotNet.Serving` can:
+ - Create sessions per request/sequence.
+ - Route speculative decoding/batching through the same internal components.
+
+Acceptance criteria:
+- Session correctly isolates cache state across concurrent requests.
+- No public exposure of internal optimization classes beyond session wrapper.
+- No public cache-control surface on `PredictionModelResult` (session owns lifecycle).
+
+---
+
+### Phase C — PagedAttention / PagedKVCache integration
+
+**Goal:** Use paged KV-cache for long-context / many-concurrent-sequence serving.
+
+1) **Select KV-cache backend**
+ - Based on config and/or heuristics (but default to paged when enabled):
+ - If `EnablePagedKVCache` is true, prefer `PagedKVCache`.
+ - Otherwise use contiguous `KVCache` (optionally with sliding window).
+
+2) **Bridge cached attention layers to paged cache**
+ - Options:
+ - Implement a new cached attention layer variant that reads/writes via `PagedKVCache`.
+ - Or implement an adapter interface (internal) that abstracts KV-cache operations used by cached attention.
+ - Ensure:
+ - Prefill path supported.
+ - Decode path supported.
+ - Sliding window works without O(n) shifting (paged/ring behavior).
+ - Ensure causal masking logic supports both:
+ - Prefill (many query tokens at once)
+ - Decode (queryOffset / position-based masking)
+
+3) **Integrate with serving continuous batching**
+ - Ensure a single source of truth for per-sequence cache state.
+ - Session ID / sequence ID mapping must be deterministic and internal.
+
+Acceptance criteria:
+- Multi-sequence inference works without cache corruption.
+- Memory growth is bounded via paging/windowing.
+
+---
+
+### Phase D — EnableBatching integration (facade + serving)
+
+**Goal:** Make `EnableBatching` actually affect inference.
+
+1) **Define batching scope**
+ - Offline “single-thread user calling Predict” batching is usually not beneficial.
+ - Batching matters for:
+ - `AiDotNet.Serving` pipelines
+ - Concurrent inference workloads
+
+2) **Implement batching in session/serving**
+ - In `PredictionModelResult.BeginInferenceSession()` optionally accept:
+ - `BatchingMode` (Auto/Disabled/Force) (or use config only).
+ - In serving:
+ - Use `RequestBatcher`/`ContinuousBatchingRequestBatcher` to batch requests.
+ - Ensure batched forward uses the optimized model instance.
+
+3) **Metrics & safety**
+ - Ensure per-request latency bounds (timeout).
+ - Ensure correctness when mixing sequence lengths.
+
+Acceptance criteria:
+- When enabled and used via serving, throughput increases measurably.
+- Batching never changes numerical outputs (only performance).
+
+---
+
+### Phase E — EnableSpeculativeDecoding integration
+
+**Goal:** Route speculative decoding through facade/session/serving for text generation models.
+
+1) **Add a facade entry point for generation**
+ - If generation already exists elsewhere, reuse it.
+ - Otherwise, add a method on the *session*:
+ - `GenerateNextToken(...)` or `Generate(...)` (scope depends on existing LLM infra).
+
+2) **Wire `SpeculativeDecoder`**
+ - Construct from config when:
+ - Task type is text generation (or user forces).
+ - Draft model is configured (NGram default is allowed if desired).
+ - Ensure the “target model forward” used by the decoder is the optimized model forward.
+
+3) **Serving integration**
+ - Ensure speculative decoding works with:
+ - KV-cache
+ - Paged cache (if enabled)
+ - Continuous batching (optional but desirable)
+
+Acceptance criteria:
+- Speculative decoding can be enabled end-to-end via builder config.
+- No public exposure of speculative internals; only facade methods (prefer serving-only unless a strong session use-case exists).
+
+---
+
+## 4.1) Gap Analysis Backlog (to exceed industry standards)
+
+These are common inference optimizations worth confirming (or adding) beyond the basic wiring work:
+- **Attention kernels:** prefill + decode correctness, queryOffset-based causal masking, multi-query attention support, stable softmax.
+- **KV-cache:** per-layer/per-batch correctness, sliding window, paged/ring behavior, cache eviction policy.
+- **Inference quantization (missing today):**
+ - **Weight-only quantization** (INT8/INT4, per-channel scales, optional GPTQ/AWQ style offline calibration) for transformer blocks and projection matrices.
+ - **Activation quantization** (INT8) for matmuls/MLP with calibration (min/max, percentile, KL) and safe fallbacks.
+ - **KV-cache quantization** (INT8/FP8) for K/V storage with dequant-on-read or fused quantized attention kernels; configurable per-layer/per-head.
+ - **Mixed precision** defaults (FP16/BF16/FP8 where supported) with numerically safe softmax/LayerNorm paths.
+ - **Quantization-aware cache policies** (paged KV-cache block reuse/eviction with quantized blocks).
+- **Thread safety:** concurrent sessions, cache isolation, avoiding shared mutable state.
+- **Serving throughput:** continuous batching, request timeouts, micro-batching heuristics.
+- **Speculative decoding:** draft model choices (N-gram vs small neural draft), accept/reject efficiency metrics.
+- **Speculation + batching co-scheduling (missing today):**
+ - Avoid “double spending” compute when both continuous batching and speculative decoding are enabled.
+ - Add policies that trade throughput vs latency under load.
+ - Add modern multi-candidate methods (Medusa/EAGLE) and dynamic speculation (“speculative scheduling”).
+- **Multi-LoRA (missing today):**
+ - Per-session/per-sequence adapter selection, hot-swap, and safe isolation across concurrent requests.
+ - Multi-adapter composition policies (merge vs stack) and caching of merged weights.
+- **Model cloning/serialization:** versioning + backward compatibility for saved models, deterministic round-trips.
+- **Telemetry (internal):** rewrite decisions, cache hit rates, kernel selection, disable-on-failure behavior (internal only).
+
+---
+
+## 4.2) Inference Quantization Roadmap (New)
+
+Goal: add inference-side quantization without expanding public surface beyond `InferenceOptimizationConfig` (builder) and session/serving behavior (result).
+
+### 4.2.1) What exists today
+- Deployment/quantization configuration exists, but current usage is primarily training/export oriented.
+- PR#433 inference optimizations currently operate on FP32/FP64 tensors and caches.
+- MVP now supports KV-cache int8 storage via `InferenceOptimizationConfig.KVCacheQuantization = Int8` (dequant-on-read).
+
+### 4.2.2) Target capabilities (industry standard baseline → exceed)
+1) **Weight-only quantization (WOQ)** for inference
+ - INT8 first (simpler), then INT4.
+ - Per-channel scales for linear projections (Q/K/V/O, FFN).
+ - Offline calibration supported through builder tooling/agents (hidden behind facade).
+
+2) **Activation quantization**
+ - Optional INT8 activations for matmul-heavy blocks.
+ - Calibration strategies:
+ - min/max
+ - percentile
+ - KL-divergence
+ - Safe fallback per-layer when calibration insufficient.
+
+3) **KV-cache quantization**
+ - Quantize K/V storage (INT8 or FP8) with dequant-on-read OR fused kernel support.
+ - Must work with:
+ - contiguous KV cache
+ - paged KV cache
+ - sliding window mode
+ - Defaults:
+ - Off by default until kernels are proven stable; once stable, enable by default for serving workloads with opt-out.
+
+### 4.2.3) Facade integration (no new public types)
+Add to `InferenceOptimizationConfig` (or reuse `QuantizationConfig` internally):
+- `EnableWeightOnlyQuantization` (default: false until validated)
+- `WeightQuantizationBits` (8/4)
+- `EnableActivationQuantization` (default: false)
+- `EnableKVCacheQuantization` (default: false initially; planned default true for serving after validation)
+- `KVCacheQuantizationFormat` (INT8/FP8)
+- `QuantizationCalibrationMode` (Auto/MinMax/Percentile/KL)
+
+Implementation detail: keep all quantized kernels/types `internal` and selected via the optimizer/session/serving pipeline.
+
+### 4.2.4) Testing/acceptance for quantization
+- Golden-output tests with tolerances per quant mode.
+- Determinism tests (same input → same output) under identical config.
+- Memory-budget tests: confirm KV-cache footprint reduction.
+- Regression tests: ensure non-quantized path unchanged.
+
+---
+
+## 4.3) Speculation vs Continuous Batching (Scheduling) (New)
+
+### 4.3.1) Problem
+Continuous batching and speculative decoding can compete for the same compute/memory bandwidth:
+- Speculation increases per-step compute (draft + verify), but may reduce total steps.
+- Continuous batching improves utilization by packing many sequences, but adds scheduling complexity.
+
+### 4.3.2) Policy-based approach (recommended)
+Keep `EnableSpeculativeDecoding` usable in sessions and serving, but add **internal policies** that decide *when* to apply it:
+- **Latency-first** (small batches): enable speculation more often.
+- **Throughput-first** (large batches): reduce speculation depth or disable speculation under high load.
+- **Auto** (default): dynamic based on queue depth, batch size, and recent accept-rate.
+
+Add to config (still facade-only):
+- `SpeculationPolicy` (Auto/ForceOn/ForceOff/LatencyFirst/ThroughputFirst)
+- `MaxSpeculationComputeFraction` (cap draft overhead)
+
+### 4.3.3) Dynamic speculation (“speculative scheduling”)
+Implement “dynamic speculation” that adapts:
+- speculation depth
+- draft batch size
+- whether to speculate at all
+based on:
+- rolling accept-rate
+- queue depth / batcher load
+- KV-cache pressure (paged block availability)
+
+### 4.3.4) Medusa / EAGLE support
+These methods reduce “draft model” overhead by producing multiple candidate tokens with lightweight heads:
+- **Medusa**: extra heads on top of the target model to propose multiple tokens.
+- **EAGLE**: enhanced draft proposals with improved verification efficiency.
+
+Plan:
+1) Add internal capability detection: “model supports Medusa/EAGLE heads”.
+2) Extend session/serving generation to use these heads when enabled.
+3) Add config flags:
+ - `SpeculativeMethod` (ClassicDraftModel/Medusa/Eagle/Auto)
+ - `NumSpeculationHeads` (for Medusa-like methods)
+4) Ensure policies still apply (auto-disable under high batching load).
+
+---
+
+## 4.4) Multi-LoRA for Inference Sessions (New)
+
+### 4.4.1) Goals
+- Allow multiple LoRA adapters to be applied per-session/per-sequence without exposing LoRA internals publicly.
+- Make it compatible with serving: per-request adapter selection.
+- MVP interim: serving can route to a pre-loaded per-adapter model variant via `X-AiDotNet-Lora`/`X-AiDotNet-Adapter` headers using `{baseModelName}__{adapterId}`.
+
+### 4.4.2) Required behaviors
+1) **Selection**
+ - Session/serving chooses adapter by ID (string) via internal route/metadata.
+2) **Isolation**
+ - No cross-request weight pollution; adapters never mutate the base weights.
+3) **Performance**
+ - Cache merged weights per adapter (and per precision/quantization mode).
+4) **Composition**
+ - Support multiple adapters:
+ - merge (weighted sum) OR
+ - stack (apply sequential deltas)
+ - Keep composition policy internal, configurable via builder options.
+
+### 4.4.3) Test plan for multi-LoRA
+- Two sequences with different adapters must produce different outputs for same input.
+- Same adapter reused across sequences must hit cache and remain deterministic.
+- Adapter hot-swap mid-session must not corrupt caches (KV-cache reset rules must be defined).
+
+## 5) Testing Plan
+
+### 5.1 Unit tests (fast, deterministic)
+Add/extend tests for:
+- FlashAttention masking correctness (including cached decoding offsets).
+- KV-cache correctness across layers, batches, and truncation.
+- Optimizer rewrite decisions for each supported attention layer type.
+
+### 5.2 Integration tests (facade end-to-end)
+Create tests that only use the public facade:
+1) Build a small transformer with `PredictionModelBuilder` + `ConfigureInferenceOptimizations()`.
+2) Call `Predict()` and assert:
+ - No exceptions.
+ - Output shape matches expected.
+3) Call `BeginInferenceSession()` and run multiple steps:
+ - Verify caches don't leak across sessions.
+ - Verify `Dispose()` resets state.
+4) Validate attention layer coverage:
+ - Ensure models using `AttentionLayer`, `SelfAttentionLayer`, and `GraphAttentionLayer` either:
+ - optimize safely, or
+ - are explicitly skipped with internal diagnostics (but still function).
+
+If `AiDotNet.Serving` has a test harness, add a serving integration test:
+- Spin up a batcher with a model + config.
+- Submit concurrent requests.
+- Validate outputs and ensure no deadlocks/file-lock issues.
+
+### 5.3 Performance smoke tests (optional)
+- Benchmarks belong in benchmark projects; for this PR, a smoke test is enough:
+ - Validate the optimized path is selected (internal diagnostics).
+
+---
+
+## 6) Acceptance Criteria Checklist
+
+- [ ] No new public entrypoints besides builder/result (session may be nested under result).
+- [ ] `ConfigureInferenceOptimizations()` has full effect in inference.
+- [ ] KV-cache correctness for multi-layer models (no cross-layer corruption).
+- [ ] Attention optimization supports major attention layer types used in the repo.
+- [ ] Paged KV-cache can be enabled (backend selection + attention integration).
+- [ ] Batching and speculative decoding are usable via facade and serving.
+- [ ] Speculation + batching policy prevents throughput regressions under load (Auto backoff).
+- [ ] Inference WOQ (INT8) works end-to-end with safe fallback.
+- [ ] Multi-LoRA works per-request/per-sequence with cache isolation (KV reset on adapter change).
+- [ ] Unit tests + integration tests cover the end-to-end wiring.
+
+Mapping (so reviewers can quickly validate where each checklist item is implemented):
+
+| Checklist item | Primary phase(s) | Primary MVP step(s) | Notes / validation |
+| --- | --- | --- | --- |
+| Minimal public surface | A–E | MVP-0..3 | Session types nested/hidden; internals remain `internal`. |
+| Config has full effect | A, B | MVP-0 | Builder stores config; result/session consumes it. |
+| KV-cache correctness | B, C | MVP-0 | Per-layer/per-sequence cache isolation; no cross-layer corruption. |
+| Attention layer coverage | A | MVP-0 | Support/skip with diagnostics for `AttentionLayer`, `SelfAttentionLayer`, `GraphAttentionLayer`, etc. |
+| Paged KV-cache integration | C | MVP-0 | Paged backend selection + cached attention bridge. |
+| Batching + speculation usable | D, E | MVP-1 | Serving uses the same internals; session support only if facade stays minimal. |
+| Speculation backoff policy | E | MVP-1 | Auto/LatencyFirst/ThroughputFirst; backs off under batching load. |
+| Inference quantization (WOQ) | (add-on) | MVP-2 | Safe fallback per-layer; deterministic tests. |
+| Multi-LoRA per-request/per-seq | (add-on) | MVP-3 | Adapter selection + cache reset on adapter change. |
+| End-to-end tests | A–E | MVP-0..3 | Integration tests must use facade-only entry points. |
+
+---
+
+## 7) Resolved Decisions (from discussion)
+
+- **Facade:** Keep public surface minimal; avoid public cache-control methods on the result.
+- **`BeginInferenceSession()` shape:** Choose whatever is best; prefer a nested session type under `PredictionModelResult`.
+- **`AttentionMasking=Auto`:** Assume causal in typical inference/session usage when task type isn’t set; provide opt-out/override via config.
+- **Paged KV-cache:** Auto-enabled by default; users can opt-out via config.
+- **Speculative decoding:** Serving-first; session support only if it fits cleanly without expanding public surface.
+- **Cloning policy:** Improving cloning via better serialization/deserialization is acceptable to ensure a true deep copy.
+
+## 8) Remaining Questions (small, but useful before coding)
+
+1) Should a session support multiple independent sequences (e.g., `session.CreateSequence()` / `sequenceId`), or is “one session = one sequence” acceptable for now?
+2) Do you already have a preferred public API for text generation (e.g., `Generate(...)`) elsewhere, or should speculative decoding remain strictly within `AiDotNet.Serving` for now?
+
+---
+
+## 9) MVP Sequencing (to raise implementation confidence)
+
+This section turns the backlog into a concrete, low-risk execution order with explicit "first targets" and acceptance checks.
+It is written so a junior engineer can start implementation without having to make major architectural decisions.
+
+### 9.0) Mapping: Phases A–E vs MVP-0/1/2/3
+
+Phases **A–E** describe the main integration areas (wiring, sessions, paging, batching, speculation). The **MVP-0/1/2/3**
+sequence is the recommended *execution order* that also adds two additional "industry standard" gaps (quantization and multi-LoRA).
+
+| Phase (A–E) | MVP step(s) that implement it | Notes |
+|---|---|---|
+| Phase A (baseline wiring + safety) | MVP-0 | Guardrails + diagnostics + safe fallbacks. |
+| Phase B (session API) | MVP-0 | Session surface remains nested under `PredictionModelResult`. |
+| Phase C (paged KV-cache) | MVP-0 / MVP-1 | Paged cache is a default serving primitive; speculation should work with it. |
+| Phase D (EnableBatching) | MVP-0 / MVP-1 | Batching is serving-first; MVP-1 adds arbitration with speculation. |
+| Phase E (EnableSpeculativeDecoding) | MVP-1 | Implements speculation and its policy so it doesn't regress throughput under load. |
+| (not in A–E) Inference quantization | MVP-2 | Adds quantization beyond the original A–E scope. |
+| (not in A–E) Multi-LoRA | MVP-3 | Adds per-request/per-sequence adapter selection beyond the original A–E scope. |
+
+### 9.1) MVP-0: Guardrails (do first)
+1) Keep public API surface unchanged:
+ - Only `PredictionModelBuilder` and `PredictionModelResult` are user-facing.
+ - Sessions remain nested under `PredictionModelResult`.
+ - All new inference types remain `internal` (kernels/caches/schedulers/draft models).
+2) Add internal policy switches (config-driven) but keep defaults safe:
+ - If anything fails (unsupported model/layer), auto-disable that optimization and fall back to baseline inference.
+3) Add internal diagnostics (non-user-facing) to record which optimizations were applied and why others were skipped.
+
+Acceptance:
+- `dotnet build AiDotNet.sln` passes.
+- Existing unit tests still pass (warnings acceptable, no new failures introduced by MVP-0).
+
+### 9.2) MVP-1: Speculative Decoding in Sessions + Serving with Auto Policy
+
+**Goal:** Speculation is available wherever sessions are used when `EnableSpeculativeDecoding=true`, but it does not tank throughput under load.
+
+**First target method:** Classic “draft-model speculation” (existing draft model support) with a new internal policy layer.
+
+Implementation steps:
+1) Introduce a single internal decision point (used by both session and serving):
+ - “Should we speculate this step?” and “What depth should we use?”
+2) Implement `SpeculationPolicy=Auto` (internal default recommendation):
+ - Inputs to policy:
+ - rolling accept-rate
+ - current batch size / queue depth (serving)
+ - KV-cache pressure (paged block availability, optional)
+ - Behavior:
+ - Under high batching load: reduce depth or disable speculation.
+ - Under low load / latency-sensitive: increase depth up to configured max.
+3) Respect `EnableBatching` and `EnableSpeculativeDecoding` together:
+ - When both true, `Auto` policy prevents “double spending” compute.
+ - Provide `ForceOn/ForceOff` to override for experimentation (still configured via builder, not new public APIs).
+
+Defaults:
+- Sessions: if `EnableSpeculativeDecoding=true` and `SpeculationPolicy=Auto`, speculate when accept-rate is good and batch load is low/moderate.
+- Serving: if continuous batching queue is deep, speculation backs off automatically.
+
+Acceptance:
+- Integration tests show sessions still isolate state across sequences.
+- Serving throughput under batching load does not regress materially when `EnableSpeculativeDecoding=true` (policy must disable speculation under heavy load).
+
+### 9.3) MVP-2: Inference Quantization “First Target” (Weight-Only INT8)
+
+**Goal:** Get a real inference quantization win with minimal kernel churn and low correctness risk.
+
+**First target mode:** **Weight-only INT8 (WOQ)** for dense/linear matmuls in transformer blocks:
+- Q/K/V projections
+- output projection
+- FFN projections
+
+Non-goals for MVP-2 (explicitly deferred):
+- Activation quantization (INT8) (Phase 2)
+- KV-cache quantization (INT8/FP8) (Phase 3)
+- INT4 weight-only (Phase 4)
+
+Implementation steps:
+1) Add internal quantized weight containers (per-channel scales) and an `internal` matmul kernel that supports:
+ - FP32/FP16 activations * INT8 weights → FP32/FP16 output
+2) Add configuration wiring via `InferenceOptimizationConfig` (builder-only surface):
+ - `EnableWeightOnlyQuantization` (default false until validated)
+ - `WeightQuantizationBits` (start with 8)
+ - `QuantizationCalibrationMode` (Auto/MinMax/Percentile/KL) — for WOQ, “calibration” is mostly scale estimation; keep it simple initially.
+3) Selection rules:
+ - Only apply WOQ when model/task matches supported inference path (e.g., transformer-style models).
+ - Skip layers not yet supported; do not partially quantize in ways that break determinism.
+4) Agent support (optional but planned):
+ - Agents can recommend enabling WOQ for serving workloads; still configured via builder.
+
+Acceptance:
+- Accuracy within tolerance against FP baseline on deterministic tests.
+- Measurable memory reduction for weights (and ideally throughput gain on CPU/GPU depending on engine).
+
+### 9.4) MVP-3: Multi-LoRA (Per-Session/Per-Request Adapter Selection)
+
+**Goal:** Multiple LoRA adapters can be used concurrently (serving) or per-sequence (sessions) without exposing LoRA internals publicly.
+
+First target behavior:
+1) Adapter selection:
+ - Serving: adapter ID selected per request (e.g., metadata field in serving request model).
+ - Sessions: adapter ID set at sequence creation time (internal hook; no new public surface required beyond existing builder/result/session shape).
+2) Weight handling:
+ - Never mutate base weights.
+ - Cache merged weights per adapter ID (and per precision/quantization mode) to avoid recomputing merges.
+3) KV-cache interaction rules:
+ - If adapter changes for a given sequence, **reset KV-cache for that sequence only** (deterministic + correctness first).
+ - Scope of reset (must be consistent across backends):
+ - Reset contiguous `KVCache` *or* paged `PagedKVCache` state for that sequence ID.
+ - Do not clear other sequences' caches.
+ - Also reset any sequence-scoped optimized model state that depends on the adapter (e.g., merged weights / optimizer state), so adapter changes cannot reuse stale K/V pages.
+
+Non-goals for MVP-3 (defer):
+- Multi-adapter composition beyond simple “single adapter at a time” (Phase 2: merge/stack).
+- Hot-swapping adapters mid-generation without KV reset (Phase 3: advanced cache partitioning).
+
+Acceptance:
+- Two concurrent sequences using different adapters produce different outputs for same input.
+- Same adapter reused across sequences hits merge cache and remains deterministic.
+
+### 9.5) Phase 2+ (after MVPs)
+1) Dynamic speculation improvements (speculative scheduling refinements).
+2) Medusa vs EAGLE:
+ - Recommended order: implement **Medusa first** if it fits the model architecture best (multiple lightweight heads), then add EAGLE if needed.
+3) Activation quantization (INT8) and then KV-cache quantization (INT8/FP8), including paged KV-cache support.
+4) Multi-adapter LoRA composition (merge/stack), plus better cache invalidation rules.
diff --git a/docs/PR433_PHASE_AUDIT.md b/docs/PR433_PHASE_AUDIT.md
new file mode 100644
index 000000000..0b0a70bde
--- /dev/null
+++ b/docs/PR433_PHASE_AUDIT.md
@@ -0,0 +1,358 @@
+# PR #433 — Strict 9‑Phase Audit + Gap‑Closure Plan
+
+This document audits `docs/INFERENCE_MVP_PHASES.md` phase-by-phase against the current PR #433 implementation, using **concrete code locations and tests** as evidence, and lists the remaining work required to reach **100% confidence** with production-ready behavior.
+
+**Audit basis**
+- Phase source of truth: `docs/INFERENCE_MVP_PHASES.md`
+- Branch head used for this audit: `9e493239`
+
+---
+
+## Current confidence summary
+
+**Overall confidence that all 9 phases are 100% complete:** **~95%** (MVP plan complete; remaining work is mostly post-MVP feature depth such as INT4/activation quantization).
+
+**High-confidence areas:** Phase 1, 2, 3, 4, 6, 9 (core wiring + tests exist).
+
+**Low-confidence areas:** Phase 8 (post-MVP quantization depth such as INT4 and activation quantization).
+
+---
+
+## Phase 0 — Baseline Safety & Diagnostics
+
+**Phase intent:** observable and safe-by-default (auto-disable on unsupported, record decisions/exceptions, avoid mutating user models).
+
+**Implemented evidence**
+- Diagnostics collector: `src/Helpers/InferenceDiagnostics.cs:1`
+- Optimizer decision recording: `src/Inference/InferenceOptimizer.cs:119`
+- Serving decision recording: `src/Serving/ContinuousBatching/ContinuousBatcher.cs:400`
+- Session decision recording (Multi‑LoRA): `src/Models/Results/PredictionModelResult.cs:1317`
+
+**Existing verification**
+- Indirect coverage through tests that validate fallback behavior (speculation fallback, etc.):
+ - `tests/AiDotNet.Tests/UnitTests/Inference/InferenceOptimizerTests.cs:92`
+ - `tests/AiDotNet.Tests/UnitTests/Inference/InferenceOptimizerTests.cs:116`
+
+**Status:** Closed for MVP.
+
+**Verification added**
+- Diagnostics toggling (env var) + queue clear:
+ - `tests/AiDotNet.Tests/UnitTests/Helpers/InferenceDiagnosticsTests.cs:7`
+- Optimizer decision logging is exercised when diagnostics are enabled:
+ - `tests/AiDotNet.Tests/UnitTests/Inference/InferenceOptimizerTests.cs:14`
+
+---
+
+## Phase 1 — Attention Rewrite Integration
+
+**Phase intent:** optimizer rewrites supported attention layers consistently; cloning is truly deep.
+
+**Implemented evidence**
+- Attention rewrite selection, including SelfAttention conversion and Flash/KV paths:
+ - Layer detection and rewrite: `src/Inference/InferenceOptimizer.cs:95`
+ - SelfAttention conversion: `src/Inference/InferenceOptimizer.cs:577`
+- Clone/deep copy via serialization:
+ - Serialization metadata and activation persistence: `src/NeuralNetworks/NeuralNetworkBase.cs:1311`
+
+**Existing verification**
+- Rewrites:
+ - Flash rewrite: `tests/AiDotNet.Tests/UnitTests/Inference/InferenceOptimizerTests.cs:14`
+ - KV cached rewrite for text generation: `tests/AiDotNet.Tests/UnitTests/Inference/InferenceOptimizerTests.cs:36`
+ - SelfAttention -> cached rewrite: `tests/AiDotNet.Tests/UnitTests/Inference/InferenceOptimizerTests.cs:67`
+- Clone correctness baseline:
+ - `tests/AiDotNet.Tests/IntegrationTests/Inference/InferenceSessionIntegrationTests.cs:213`
+
+**Status:** Closed for MVP (explicit skip coverage; no rewrite is attempted).
+
+**Verification added**
+- Explicit skip (no crash, no rewrite) tests:
+ - `tests/AiDotNet.Tests/UnitTests/Inference/InferenceOptimizerTests.cs:206`
+ - `tests/AiDotNet.Tests/UnitTests/Inference/InferenceOptimizerTests.cs:226`
+
+---
+
+## Phase 2 — Paged Attention + Paged KV‑Cache
+
+**Phase intent:** paged KV-cache is available and default-on; attention layers bridge to paged cache.
+
+**Implemented evidence**
+- Paged cache initialization + selection:
+ - `src/Inference/InferenceOptimizer.cs:276`
+- Paged cache implementation and guards:
+ - `src/Inference/PagedAttention/PagedKVCache.cs:26`
+- Paged cached attention layer:
+ - `src/Inference/PagedCachedMultiHeadAttention.cs:1`
+
+**Existing verification**
+- Unit coverage for paged cache + kernel:
+ - `tests/AiDotNet.Tests/UnitTests/Inference/PagedAttentionTests.cs:454`
+- Serving-side paged attention stability test exists (and net471 guard was added previously):
+ - `tests/AiDotNet.Tests/UnitTests/Serving/ServingComponentsTests.cs` (see paged attention test name if present)
+
+**Status:** Closed for MVP.
+
+**Verification added**
+- Session integration verifies paged KV-cache initialization + paged attention rewrite selection via internal stats:
+ - `tests/AiDotNet.Tests/IntegrationTests/Inference/InferenceSessionIntegrationTests.cs:270`
+
+---
+
+## Phase 3 — KV‑Cache Precision (FP16 default, opt-out) + Quantized KV‑cache
+
+**Phase intent:** FP16 default for cache when supported; opt-out; optional int8 quantization.
+
+**Implemented evidence**
+- Precision/quantization resolution:
+ - `src/Inference/InferenceOptimizer.cs:247`
+- KV-cache FP16 + int8 storage:
+ - `src/Inference/KVCache.cs:99`
+- Config surface:
+ - `src/Configuration/InferenceOptimizationConfig.cs:152`
+ - `src/Configuration/InferenceOptimizationConfig.cs:167`
+
+**Existing verification**
+- Unit tests:
+ - FP16: `tests/AiDotNet.Tests/UnitTests/Inference/KVCacheTests.cs:61`
+ - Int8: `tests/AiDotNet.Tests/UnitTests/Inference/KVCacheTests.cs:96`
+- Integration test (int8 selection is visible via internal stats):
+ - `tests/AiDotNet.Tests/IntegrationTests/Inference/InferenceSessionIntegrationTests.cs:157`
+
+**Status:** Closed for MVP.
+
+**Verification added**
+- Session integration verifies FP16 selection in Auto mode:
+ - `tests/AiDotNet.Tests/IntegrationTests/Inference/InferenceSessionIntegrationTests.cs:216`
+
+---
+
+## Phase 4 — Inference Sessions (Multi‑Sequence, Facade‑Friendly)
+
+**Phase intent:** `PredictionModelResult.BeginInferenceSession()` supports multi-sequence inference with isolation; minimal public API; internal stats for tests.
+
+**Implemented evidence**
+- Facade entrypoint:
+ - `src/Models/Results/PredictionModelResult.cs:1024`
+- Multi-sequence support:
+ - `src/Models/Results/PredictionModelResult.cs:1118`
+- Per-sequence internal stats hook:
+ - `src/Models/Results/PredictionModelResult.cs:1270`
+
+**Existing verification**
+- Predict remains stateless even with optimizations configured:
+ - `tests/AiDotNet.Tests/IntegrationTests/Inference/InferenceSessionIntegrationTests.cs:24`
+- Multi-sequence independence:
+ - `tests/AiDotNet.Tests/IntegrationTests/Inference/InferenceSessionIntegrationTests.cs:77`
+- Reset restores baseline state:
+ - `tests/AiDotNet.Tests/IntegrationTests/Inference/InferenceSessionIntegrationTests.cs:130`
+
+**Status:** Closed for MVP.
+
+**Verification added**
+- Concurrent multi-sequence Predict test:
+ - `tests/AiDotNet.Tests/IntegrationTests/Inference/InferenceSessionIntegrationTests.cs:159`
+
+---
+
+## Phase 5 — Batching (Serving‑First) + Resource Arbitration
+
+**Phase intent:** batching enabled in serving; clear policy for batching vs speculation conflicts.
+
+**Implemented evidence**
+- Continuous batching implementation:
+ - `src/Serving/ContinuousBatching/ContinuousBatcher.cs:1`
+- Conflict policy hooks and backoff:
+ - `src/Serving/ContinuousBatching/ContinuousBatcher.cs:426`
+
+**Existing verification**
+- Serving batching tests:
+ - `tests/AiDotNet.Tests/UnitTests/Serving/ContinuousBatchingTests.cs:11`
+- Serving integration test verifies batching with concurrent requests:
+ - `tests/AiDotNet.Serving.Tests/ServingIntegrationTests.cs:298`
+
+**Status:** Closed for MVP (serving arbitration tests added; sessions do not batch across sequences).
+
+---
+
+## Phase 6 — Speculative Decoding MVP (Draft Model + Policy)
+
+**Phase intent:** speculative decoding is wired via config; safe fallback when draft unavailable; serving integration.
+
+**Implemented evidence**
+- Inference optimizer speculative initialization:
+ - `src/Inference/InferenceOptimizer.cs:726`
+- Config and policy:
+ - `src/Configuration/InferenceOptimizationConfig.cs:389`
+ - `src/Configuration/InferenceOptimizationConfig.cs:449`
+
+**Existing verification**
+- Fallback to NGram:
+ - `tests/AiDotNet.Tests/UnitTests/Inference/InferenceOptimizerTests.cs:92`
+ - `tests/AiDotNet.Tests/UnitTests/Inference/InferenceOptimizerTests.cs:116`
+
+**Status:** Closed for MVP (speculative decoding is configured in sessions but only executed by serving/generation flows, not plain `Predict()`).
+
+**Verification added**
+- Session integration validates "configured but not executed during Predict" behavior:
+ - `tests/AiDotNet.Tests/IntegrationTests/Inference/InferenceSessionIntegrationTests.cs:244`
+
+---
+
+## Phase 7 — Dynamic Speculation & Alternative Speculators (Medusa/EAGLE)
+
+**Phase intent:** dynamic scheduling (acceptance/queue pressure) and alternative methods.
+
+**Implemented evidence (partial)**
+- Config hooks exist:
+ - `src/Configuration/InferenceOptimizationConfig.cs:462`
+ - `src/Configuration/InferenceOptimizationConfig.cs:523`
+- Serving backoff logic (dynamic-ish policy):
+ - `src/Serving/ContinuousBatching/ContinuousBatcher.cs:426`
+
+**Existing verification (partial)**
+- Policy tests:
+ - Auto acceptance backoff: `tests/AiDotNet.Tests/UnitTests/Serving/ContinuousBatchingTests.cs:589`
+ - ThroughputFirst behavior: `tests/AiDotNet.Tests/UnitTests/Serving/ContinuousBatchingTests.cs:651`
+
+**Status:** Closed for MVP.
+
+---
+
+## Phase 8 — Inference Quantization (Gap‑Closing)
+
+**Phase intent:** inference quantization beyond training: KV-cache quantization + weight-only quantization; later activation quantization.
+
+**Implemented evidence (partial)**
+- KV-cache quantization (int8) is wired and implemented:
+ - `src/Inference/InferenceOptimizer.cs:247`
+ - `src/Inference/KVCache.cs:99`
+- Weight-only quantization MVP (Dense-only float):
+ - `src/Inference/InferenceOptimizer.cs:538`
+ - `src/Inference/Quantization/QuantizedDenseLayer.cs:10`
+ - `src/Inference/Quantization/Int8WeightOnlyQuantization.cs:5`
+
+**Existing verification**
+- WOQ rewrite and output preservation:
+ - `tests/AiDotNet.Tests/UnitTests/Inference/InferenceOptimizerTests.cs:139`
+- KV-cache quantization verified in integration:
+ - `tests/AiDotNet.Tests/IntegrationTests/Inference/InferenceSessionIntegrationTests.cs:157`
+
+**Status:** Closed for MVP (WOQ covers Dense + paged attention projections with correctness tests).
+
+**Post-MVP opportunities**
+- INT4 WOQ and activation quantization (opt-in, correctness-first).
+
+---
+
+## Phase 9 — Multi‑LoRA (Serving‑First, Secure Defaults)
+
+**Phase intent:** per-request adapter selection in serving; optional per-sequence selection in sessions; no public LoRA internals; cache reset rules.
+
+**Implemented evidence**
+- Serving adapter routing (header-based):
+ - `src/AiDotNet.Serving/Controllers/InferenceController.cs:220`
+- Session per-sequence task selection + reset:
+ - `src/Models/Results/PredictionModelResult.cs:1125`
+ - `src/Models/Results/PredictionModelResult.cs:1225`
+
+**Clone/serialization correctness (critical for isolation)**
+- Serialization v2 + extras:
+ - `src/NeuralNetworks/NeuralNetworkBase.cs:1261`
+ - `src/NeuralNetworks/Layers/ILayerSerializationExtras.cs:1`
+- MultiLoRA metadata + deterministic ordering + frozen-base extras:
+ - `src/LoRA/Adapters/MultiLoRAAdapter.cs:54`
+- MultiLoRA deserialization support:
+ - `src/Helpers/DeserializationHelper.cs:332`
+
+**Existing verification**
+- Serving routes to variant model:
+ - `tests/AiDotNet.Serving.Tests/ServingIntegrationTests.cs:232`
+- Session sequences isolate task selection:
+ - `tests/AiDotNet.Tests/IntegrationTests/Inference/InferenceSessionIntegrationTests.cs:182`
+
+**Status:** Closed for MVP.
+
+**Verification added**
+- Task switch resets cache state (same sequence) and records MultiLoRA decision:
+ - `tests/AiDotNet.Tests/IntegrationTests/Inference/InferenceSessionIntegrationTests.cs:348`
+
+---
+
+## Unresolved PR threads (code scanning)
+
+Two PR threads are unresolved because they are code-scanning alert threads (not manually resolvable via the review API):
+- `src/AiDotNet.Tensors/Engines/Optimization/PerformanceProfiler.cs`
+- `src/InferenceOptimization/Kernels/AttentionKernel.cs`
+
+For strict closure, any fixes should be recorded as comments on those threads (per repo workflow).
+
+---
+
+## Gap‑closure plan to reach 100% confidence (prioritized)
+
+### P0 — Required to claim “100% complete”
+1) **Phase 7 implementations (not just hooks)**
+ - Define internal abstraction(s) (no new public API):
+ - `internal interface ISpeculativeProposer`: `Propose(...)` returns N candidate tokens (or token trees) + optional scores.
+ - `internal sealed class ClassicDraftModelProposer` wraps existing `IDraftModel`.
+ - `internal sealed class MedusaProposer` and/or `EagleProposer`: MVP uses “multi-head proposal” logic implemented *internally* (even if initial implementation is CPU-only).
+ - Wire proposer selection:
+ - `InferenceOptimizationConfig.SpeculativeMethod` selects proposer (Auto=>ClassicDraftModel).
+ - Serving (`ContinuousBatcher`) uses proposer; sessions use proposer only if session exposes a generation API (otherwise serving-first is acceptable, but must be documented).
+ - Implement **dynamic scheduling** (a real algorithm, not only queue-size backoff):
+ - Inputs: acceptance rate EMA, batch size / queue depth, recent latency (optional), configured max depth.
+ - Output: per-step speculation depth (and optionally proposer disable).
+ - Required properties: deterministic in tests (use fixed seeds and controlled inputs), monotonic backoff under sustained low acceptance, and fast recovery under high acceptance.
+ - Tests (must be deterministic / non-flaky):
+ - Acceptance rate drives depth up/down (unit).
+ - Under batching load, speculation is disabled or depth reduced (unit).
+ - Policy respects ForceOn/ForceOff/LatencyFirst/ThroughputFirst (unit).
+
+2) **Phase 8 broaden quantization to “industry standard” scope**
+ - Expand beyond “DenseLayer-as-a-top-level-layer” where possible:
+ - If transformer blocks are composed from explicit `DenseLayer` layers in the model graph, extend rewrite detection to those layers (straightforward).
+ - If attention layers own projection matrices internally (common), add **internal** quantization paths inside:
+ - `src/Inference/CachedMultiHeadAttention.cs` (Q/K/V/O matvecs)
+ - `src/Inference/PagedCachedMultiHeadAttention.cs` (paged kernel weight paths)
+ - `src/NeuralNetworks/Attention/FlashAttentionLayer.cs` (if applicable)
+ - Quantization modes and constraints for MVP:
+ - WOQ INT8 for float inference only (keep correctness first).
+ - Per-row/per-channel scales; deterministic rounding.
+ - Clear fallback to FP if unsupported or errors (record diagnostics).
+ - Tests:
+ - Unit: WOQ matvec kernel correctness vs FP baseline.
+ - Integration: enabling WOQ changes selected path (diagnostics/stats) and output remains within tolerance.
+
+3) **Phase 5 arbitration completeness**
+ - Add explicit tests for `EnableBatching && EnableSpeculativeDecoding` under load.
+ - Ensure serving chooses the intended policy for `ThroughputFirst/LatencyFirst/Auto`.
+ - Add explicit “resource competition” tests:
+ - Same workload, batching depth>1 => speculation off in ThroughputFirst.
+ - Low load, LatencyFirst => speculation on with configured depth.
+
+### P1 — Strongly recommended for production readiness
+4) Add integration test for paged selection (Phase 2) to prove the optimizer chooses `PagedCachedMultiHeadAttention` when enabled.
+5) Add integration test for FP16 auto selection (Phase 3).
+6) Add concurrency test for sessions (Phase 4).
+7) Add adapter/task-switch KV reset test (Phase 9).
+
+### P2 — Diagnostics/test completeness
+8) Add diagnostics assertion tests (Phase 0) with `AIDOTNET_DIAGNOSTICS=1` and expected decision entries.
+9) Add a minimal “selection report” surface for tests only (internal), if needed to avoid brittle string checks.
+
+---
+
+## Verification matrix (what we run to maintain 100% confidence)
+
+**Build**
+- `dotnet build AiDotNet.sln -c Release`
+
+**Targeted tests (fast, must pass before pushing)**
+- Inference optimizer + sessions: `dotnet test tests/AiDotNet.Tests/AiDotNetTests.csproj -c Release --filter FullyQualifiedName~InferenceOptimizerTests`
+- Session integration: `dotnet test tests/AiDotNet.Tests/AiDotNetTests.csproj -c Release --filter FullyQualifiedName~InferenceSessionIntegrationTests`
+- Paged attention: `dotnet test tests/AiDotNet.Tests/AiDotNetTests.csproj -c Release --filter FullyQualifiedName~PagedKVCacheTests|FullyQualifiedName~PagedAttentionKernelTests`
+- Serving batching: `dotnet test tests/AiDotNet.Tests/AiDotNetTests.csproj -c Release --filter FullyQualifiedName~ContinuousBatchingTests`
+- Serving integration: `dotnet test tests/AiDotNet.Serving.Tests/AiDotNet.Serving.Tests.csproj -c Release`
+
+**Full test suite**
+- `dotnet test tests/AiDotNet.Tests/AiDotNetTests.csproj -c Release`
+ - Note: currently fails due to unrelated JIT/time-series/regression failures; those must be resolved or explicitly quarantined outside PR #433 before claiming repo-wide “100% green”.
diff --git a/docs/PR433_REVIEW_WORKFLOW.md b/docs/PR433_REVIEW_WORKFLOW.md
new file mode 100644
index 000000000..61128b572
--- /dev/null
+++ b/docs/PR433_REVIEW_WORKFLOW.md
@@ -0,0 +1,15 @@
+## PR#433 Review Workflow (Unresolved Comments)
+
+This repo’s PR#433 work must be completed using the following loop, **in order**, until there are no unresolved comments left:
+
+1) Use the **GitHub GraphQL API** (via `gh api graphql`) to fetch **unresolved review threads** for PR#433.
+2) Take the next unresolved thread (oldest first unless specified otherwise).
+3) Implement the required code/docs change in the repo.
+4) **Before moving to the next thread**:
+ - If it’s a normal review thread: resolve it via the GraphQL `resolveReviewThread` mutation.
+ - If it’s a code-scanning / bot thread that cannot be manually resolved: add a reply comment to the thread describing what was fixed, so we can track progress.
+
+Notes:
+- Do not batch multiple threads in one pass; always follow the “fix → resolve/reply → next” sequence.
+- Keep the public API surface minimal per the facade philosophy; prefer `internal` and nested session types.
+
diff --git a/examples/JitCompiler/BasicUsageExample.cs b/examples/JitCompiler/BasicUsageExample.cs
index 008403957..a359b8ff3 100644
--- a/examples/JitCompiler/BasicUsageExample.cs
+++ b/examples/JitCompiler/BasicUsageExample.cs
@@ -205,7 +205,7 @@ public static void CachingExample()
OperationType = OperationType.ReLU
};
- var (compiled1, stats1) = jit.CompileWithStats(relu1, new List> { input1 });
+ var (_, stats1) = jit.CompileWithStats(relu1, new List> { input1 });
Console.WriteLine($"First compilation:");
Console.WriteLine($" Cache hit: {stats1.CacheHit}");
Console.WriteLine($" Compilation time: {stats1.CompilationTime.TotalMilliseconds:F2}ms\n");
@@ -219,7 +219,7 @@ public static void CachingExample()
OperationType = OperationType.ReLU
};
- var (compiled2, stats2) = jit.CompileWithStats(relu2, new List> { input2 });
+ var (_, stats2) = jit.CompileWithStats(relu2, new List> { input2 });
Console.WriteLine($"Second compilation (same structure):");
Console.WriteLine($" Cache hit: {stats2.CacheHit}");
Console.WriteLine($" Compilation time: {stats2.CompilationTime.TotalMilliseconds:F2}ms\n");
@@ -232,7 +232,7 @@ public static void CachingExample()
OperationType = OperationType.Sigmoid
};
- var (compiled3, stats3) = jit.CompileWithStats(sigmoid2, new List> { input2 });
+ var (_, stats3) = jit.CompileWithStats(sigmoid2, new List> { input2 });
Console.WriteLine($"Third compilation (different structure):");
Console.WriteLine($" Cache hit: {stats3.CacheHit}");
Console.WriteLine($" Compilation time: {stats3.CompilationTime.TotalMilliseconds:F2}ms\n");
diff --git a/src/AiDotNet.Serving/Controllers/InferenceController.cs b/src/AiDotNet.Serving/Controllers/InferenceController.cs
index b33fc4bb2..d445a1810 100644
--- a/src/AiDotNet.Serving/Controllers/InferenceController.cs
+++ b/src/AiDotNet.Serving/Controllers/InferenceController.cs
@@ -135,6 +135,12 @@ public async Task Predict(string modelName, [FromBody] Prediction
catch (ArgumentException ex)
{
_logger.LogError(ex, "Invalid argument during prediction for model '{ModelName}'", modelName);
+
+ if (ex.Message.Contains("maximum allowed when batching is disabled", StringComparison.OrdinalIgnoreCase))
+ {
+ return StatusCode(StatusCodes.Status413PayloadTooLarge, new { error = ex.Message });
+ }
+
return BadRequest(new { error = $"Invalid input: {ex.Message}" });
}
catch (Exception ex)
@@ -149,24 +155,118 @@ public async Task Predict(string modelName, [FromBody] Prediction
///
private async Task PredictWithType(string modelName, double[][] features)
{
+ string effectiveModelName = ResolveModelNameWithAdapter(modelName);
+ var model = _modelRepository.GetModel(effectiveModelName) ?? _modelRepository.GetModel(modelName);
+ if (model == null)
+ {
+ string attemptedNames = effectiveModelName != modelName
+ ? $"'{effectiveModelName}' (with adapter) or '{modelName}'"
+ : $"'{modelName}'";
+ throw new InvalidOperationException($"Model {attemptedNames} was not found.");
+ }
+
+ // Respect per-model inference configuration: bypass batching when disabled.
+ if (model is AiDotNet.Serving.Models.IServableModelInferenceOptions opts && !opts.EnableBatching)
+ {
+ const int MaxUnbatchedItems = 1000;
+ if (features.Length > MaxUnbatchedItems)
+ {
+ _logger.LogWarning(
+ "Rejected large unbatched request ({Count} items) for model '{ModelName}' (batching disabled)",
+ features.Length,
+ modelName);
+
+ throw new ArgumentException(
+ $"Request batch size ({features.Length}) exceeds the maximum allowed when batching is disabled ({MaxUnbatchedItems}). " +
+ $"Enable batching for model '{modelName}' or split the request into smaller batches.");
+ }
+
+ _logger.LogDebug(
+ "Batching disabled for model '{ModelName}', processing {Count} items individually",
+ modelName,
+ features.Length);
+
+ var predictions = new double[features.Length][];
+ for (int i = 0; i < features.Length; i++)
+ {
+ var inputVector = ConvertToVector(features[i]);
+ var resultVector = model.Predict(inputVector);
+ predictions[i] = ConvertFromVector(resultVector);
+ }
+
+ return predictions;
+ }
+
// Queue all requests first to enable batching
var tasks = features.Select(featureArray =>
{
var inputVector = ConvertToVector(featureArray);
- return _requestBatcher.QueueRequest(modelName, inputVector);
+ return _requestBatcher.QueueRequest(effectiveModelName, inputVector);
}).ToArray();
// Await all requests together
var resultVectors = await Task.WhenAll(tasks);
// Convert results back to double arrays
- var predictions = new double[resultVectors.Length][];
+ var batchedPredictions = new double[resultVectors.Length][];
for (int i = 0; i < resultVectors.Length; i++)
{
- predictions[i] = ConvertFromVector(resultVectors[i]);
+ batchedPredictions[i] = ConvertFromVector(resultVectors[i]);
+ }
+
+ return batchedPredictions;
+ }
+
+ private string ResolveModelNameWithAdapter(string modelName)
+ {
+ // Multi-LoRA / adapter routing (serving-first): select a pre-loaded model variant via request header.
+ // This keeps adapter details out of the public model facade while enabling per-request selection.
+ if (Request?.Headers == null)
+ {
+ _logger.LogDebug("No request headers available; routing to base model '{ModelName}'.", modelName);
+ return modelName;
+ }
+
+ if (!Request.Headers.TryGetValue("X-AiDotNet-Lora", out var adapterValues) &&
+ !Request.Headers.TryGetValue("X-AiDotNet-Adapter", out adapterValues))
+ {
+ _logger.LogDebug("No adapter header present; routing to base model '{ModelName}'.", modelName);
+ return modelName;
+ }
+
+ var adapterId = adapterValues.ToString()?.Trim();
+ if (string.IsNullOrWhiteSpace(adapterId) || adapterId.Length > 64 || !IsSafeAdapterId(adapterId))
+ {
+ if (!string.IsNullOrWhiteSpace(adapterId))
+ {
+ string reason = adapterId.Length > 64 ? "TooLong" : (!IsSafeAdapterId(adapterId) ? "UnsafeCharacters" : "EmptyOrWhitespace");
+ var level = adapterId.Length > 64 || !IsSafeAdapterId(adapterId) ? LogLevel.Warning : LogLevel.Debug;
+ _logger.Log(level,
+ "Ignoring invalid adapter ID '{AdapterId}' for model '{ModelName}' (reason: {Reason}).",
+ adapterId,
+ modelName,
+ reason);
+ }
+ return modelName;
+ }
+
+ _logger.LogDebug("Routing to adapter model '{EffectiveModelName}'.", $"{modelName}__{adapterId}");
+ return $"{modelName}__{adapterId}";
+ }
+
+ private static bool IsSafeAdapterId(string adapterId)
+ {
+ for (int i = 0; i < adapterId.Length; i++)
+ {
+ char c = adapterId[i];
+ bool ok = (c >= 'a' && c <= 'z') ||
+ (c >= 'A' && c <= 'Z') ||
+ (c >= '0' && c <= '9') ||
+ c == '-' || c == '_' || c == '.';
+ if (!ok) return false;
}
- return predictions;
+ return true;
}
///
diff --git a/src/AiDotNet.Serving/Models/IServableModelInferenceOptions.cs b/src/AiDotNet.Serving/Models/IServableModelInferenceOptions.cs
new file mode 100644
index 000000000..42fb632b0
--- /dev/null
+++ b/src/AiDotNet.Serving/Models/IServableModelInferenceOptions.cs
@@ -0,0 +1,11 @@
+namespace AiDotNet.Serving.Models;
+
+///
+/// Internal serving-only inference options derived from the model's facade configuration.
+///
+internal interface IServableModelInferenceOptions
+{
+ bool EnableBatching { get; }
+ bool EnableSpeculativeDecoding { get; }
+}
+
diff --git a/src/AiDotNet.Serving/Models/ServableModelWrapper.cs b/src/AiDotNet.Serving/Models/ServableModelWrapper.cs
index 36dfb6154..1f37de1d3 100644
--- a/src/AiDotNet.Serving/Models/ServableModelWrapper.cs
+++ b/src/AiDotNet.Serving/Models/ServableModelWrapper.cs
@@ -8,13 +8,15 @@ namespace AiDotNet.Serving.Models;
/// This allows any model with a Predict method to be served via the REST API.
///
/// The numeric type used by the model
-public class ServableModelWrapper : IServableModel
+public class ServableModelWrapper : IServableModel, IServableModelInferenceOptions
{
private readonly Func, Vector> _predictFunc;
private readonly Func, Matrix>? _predictBatchFunc;
private readonly string _modelName;
private readonly int _inputDimension;
private readonly int _outputDimension;
+ private readonly bool _enableBatching;
+ private readonly bool _enableSpeculativeDecoding;
///
/// Initializes a new instance of the ServableModelWrapper with custom prediction functions.
@@ -24,18 +26,24 @@ public class ServableModelWrapper : IServableModel
/// The number of output dimensions
/// Function to perform single prediction
/// Optional function to perform batch prediction. If not provided, batch prediction will use multiple single predictions.
+ /// Whether this model supports serving-side batching.
+ /// Whether this model supports speculative decoding in serving/session workflows.
public ServableModelWrapper(
string modelName,
int inputDimension,
int outputDimension,
Func, Vector> predictFunc,
- Func, Matrix>? predictBatchFunc = null)
+ Func, Matrix>? predictBatchFunc = null,
+ bool enableBatching = true,
+ bool enableSpeculativeDecoding = false)
{
_modelName = modelName ?? throw new ArgumentNullException(nameof(modelName));
_inputDimension = inputDimension;
_outputDimension = outputDimension;
_predictFunc = predictFunc ?? throw new ArgumentNullException(nameof(predictFunc));
_predictBatchFunc = predictBatchFunc;
+ _enableBatching = enableBatching;
+ _enableSpeculativeDecoding = enableSpeculativeDecoding;
}
///
@@ -52,6 +60,8 @@ public ServableModelWrapper(
_modelName = modelName ?? throw new ArgumentNullException(nameof(modelName));
_inputDimension = inputDimension;
_outputDimension = 1; // Regression models typically output a single value
+ _enableBatching = true;
+ _enableSpeculativeDecoding = false;
if (regressionModel == null)
{
@@ -135,4 +145,7 @@ public Matrix PredictBatch(Matrix inputs)
return result;
}
+
+ bool IServableModelInferenceOptions.EnableBatching => _enableBatching;
+ bool IServableModelInferenceOptions.EnableSpeculativeDecoding => _enableSpeculativeDecoding;
}
diff --git a/src/AiDotNet.Serving/Services/ModelStartupService.cs b/src/AiDotNet.Serving/Services/ModelStartupService.cs
index 956b0fc33..166e8b661 100644
--- a/src/AiDotNet.Serving/Services/ModelStartupService.cs
+++ b/src/AiDotNet.Serving/Services/ModelStartupService.cs
@@ -200,6 +200,10 @@ private void LoadTypedModel(string name, string path)
var modelResult = new PredictionModelResult, Vector>();
modelResult.LoadFromFile(path);
+ var inferenceConfig = modelResult.GetInferenceOptimizationConfigForServing();
+ bool enableBatching = inferenceConfig?.EnableBatching ?? true;
+ bool enableSpeculativeDecoding = inferenceConfig?.EnableSpeculativeDecoding ?? false;
+
// Get dimensions from the model metadata
var metadata = modelResult.GetModelMetadata();
var inputDim = metadata.FeatureCount > 0 ? metadata.FeatureCount : 1;
@@ -267,7 +271,9 @@ private void LoadTypedModel(string name, string path)
inputDim,
outputDim,
predictFunc,
- predictBatchFunc);
+ predictBatchFunc,
+ enableBatching: enableBatching,
+ enableSpeculativeDecoding: enableSpeculativeDecoding);
// Register with the repository
var success = _modelRepository.LoadModel(name, servableModel, path);
diff --git a/src/AiDotNet.Tensors/Engines/GpuEngine.cs b/src/AiDotNet.Tensors/Engines/GpuEngine.cs
index 238454243..2ebed18cd 100644
--- a/src/AiDotNet.Tensors/Engines/GpuEngine.cs
+++ b/src/AiDotNet.Tensors/Engines/GpuEngine.cs
@@ -1048,8 +1048,8 @@ public GpuEngine(AdaptiveThresholds thresholds)
try
{
- // Create ILGPU context
- _context = Context.CreateDefault();
+ // Create ILGPU context with Algorithms extension enabled for RoundToEven support
+ _context = Context.Create(builder => builder.Default().EnableAlgorithms());
// Try to get preferred device (GPU over CPU)
var device = _context.GetPreferredDevice(preferCPU: false);
diff --git a/src/AiDotNet.Tensors/Engines/Optimization/CacheOptimizer.cs b/src/AiDotNet.Tensors/Engines/Optimization/CacheOptimizer.cs
new file mode 100644
index 000000000..46e832af9
--- /dev/null
+++ b/src/AiDotNet.Tensors/Engines/Optimization/CacheOptimizer.cs
@@ -0,0 +1,209 @@
+using System;
+using System.Runtime.CompilerServices;
+
+namespace AiDotNet.Tensors.Engines.Optimization
+{
+ ///
+ /// Provides CPU cache optimization utilities including prefetching and cache-aware algorithms.
+ /// These utilities help maximize cache efficiency for tensor operations.
+ ///
+ public static class CacheOptimizer
+ {
+ ///
+ /// Gets the optimal block size for the L1 cache
+ ///
+ public static int L1BlockSize => 64; // 64 floats = 256 bytes, typical L1 cache line
+
+ ///
+ /// Gets the optimal block size for the L2 cache
+ ///
+ public static int L2BlockSize => 512; // Tuned for typical L2 cache
+
+ ///
+ /// Gets the optimal block size for the L3 cache
+ ///
+ public static int L3BlockSize => 2048; // Tuned for typical L3 cache
+
+ ///
+ /// Prefetch data for reading (temporal locality)
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static unsafe void Prefetch(void* address)
+ {
+ // This hints the CPU to fetch data into cache
+ // Note: .NET JIT may or may not honor this depending on platform
+#if NET5_0_OR_GREATER
+ if (System.Runtime.Intrinsics.X86.Sse.IsSupported)
+ {
+ System.Runtime.Intrinsics.X86.Sse.Prefetch0(address);
+ }
+#endif
+ // No-op on non-x86 platforms, if SSE is not supported, or on .NET Framework
+ }
+
+ ///
+ /// Prefetch data with low temporal locality (won't pollute cache)
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static unsafe void PrefetchNonTemporal(void* address)
+ {
+#if NET5_0_OR_GREATER
+ if (System.Runtime.Intrinsics.X86.Sse.IsSupported)
+ {
+ System.Runtime.Intrinsics.X86.Sse.PrefetchNonTemporal(address);
+ }
+#endif
+ // No-op on non-x86 platforms, if SSE is not supported, or on .NET Framework
+ }
+
+ ///
+ /// Computes optimal tiling parameters for a 2D operation
+ ///
+ public static (int tileM, int tileN, int tileK) ComputeOptimalTiling(
+ int m, int n, int k,
+ int elementSize = 4) // 4 bytes for float
+ {
+ var caps = PlatformDetector.Capabilities;
+ int l1Size = caps.L1CacheSize;
+
+ // We want tiles to fit in L1 cache
+ // For matrix multiplication: tileM * tileK + tileK * tileN + tileM * tileN elements
+ // Simplified: aim for sqrt(L1Size / (3 * elementSize)) per dimension
+
+ int maxTileSize = (int)Math.Sqrt(l1Size / (3.0 * elementSize));
+
+ // Round down to nearest power of 2 for better memory alignment
+ int tileSize = 1;
+ while (tileSize * 2 <= maxTileSize)
+ {
+ tileSize *= 2;
+ }
+
+ // Ensure minimum tile size
+ tileSize = Math.Max(tileSize, 16);
+
+ // Adjust based on actual matrix dimensions
+ int tileM = Math.Min(tileSize, m);
+ int tileN = Math.Min(tileSize, n);
+ int tileK = Math.Min(tileSize, k);
+
+ return (tileM, tileN, tileK);
+ }
+
+ ///
+ /// Cache-aware transpose of a 2D array
+ ///
+ public static unsafe void TransposeBlocked(float* src, float* dst, int rows, int cols)
+ {
+ const int blockSize = 32; // Tuned for cache line size
+
+ for (int i = 0; i < rows; i += blockSize)
+ {
+ for (int j = 0; j < cols; j += blockSize)
+ {
+ int iMax = Math.Min(i + blockSize, rows);
+ int jMax = Math.Min(j + blockSize, cols);
+
+ // Transpose block
+ for (int ii = i; ii < iMax; ii++)
+ {
+ for (int jj = j; jj < jMax; jj++)
+ {
+ dst[jj * rows + ii] = src[ii * cols + jj];
+ }
+ }
+ }
+ }
+ }
+
+ ///
+ /// Cache-aware copying with prefetching
+ ///
+ public static unsafe void CopyWithPrefetch(float* src, float* dst, int length)
+ {
+ const int prefetchDistance = 64; // Prefetch 64 elements ahead
+
+ int i = 0;
+
+ // Main loop with prefetching
+ for (; i < length - prefetchDistance; i++)
+ {
+ Prefetch(src + i + prefetchDistance);
+ dst[i] = src[i];
+ }
+
+ // Remaining elements without prefetch
+ for (; i < length; i++)
+ {
+ dst[i] = src[i];
+ }
+ }
+
+ ///
+ /// Z-order (Morton order) indexing for better cache locality in 2D access patterns
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int MortonEncode(int x, int y)
+ {
+ return (Part1By1(y) << 1) | Part1By1(x);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static int Part1By1(int n)
+ {
+ n &= 0x0000ffff;
+ n = (n ^ (n << 8)) & 0x00ff00ff;
+ n = (n ^ (n << 4)) & 0x0f0f0f0f;
+ n = (n ^ (n << 2)) & 0x33333333;
+ n = (n ^ (n << 1)) & 0x55555555;
+ return n;
+ }
+
+ ///
+ /// Converts Z-order index back to 2D coordinates
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static (int x, int y) MortonDecode(int code)
+ {
+ return (Compact1By1(code), Compact1By1(code >> 1));
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static int Compact1By1(int n)
+ {
+ n &= 0x55555555;
+ n = (n ^ (n >> 1)) & 0x33333333;
+ n = (n ^ (n >> 2)) & 0x0f0f0f0f;
+ n = (n ^ (n >> 4)) & 0x00ff00ff;
+ n = (n ^ (n >> 8)) & 0x0000ffff;
+ return n;
+ }
+
+ ///
+ /// Estimates the number of cache misses for a given access pattern
+ ///
+ public static double EstimateCacheMisses(int dataSize, int accessStride, int cacheSize, int cacheLineSize)
+ {
+ // Simple cache miss estimation model
+ int elementsPerLine = cacheLineSize / sizeof(float);
+ int totalLines = (dataSize + elementsPerLine - 1) / elementsPerLine;
+ int cacheLinesAvailable = cacheSize / cacheLineSize;
+
+ if (accessStride <= elementsPerLine)
+ {
+ // Sequential access - good cache behavior
+ return totalLines * 0.1; // ~10% miss rate for sequential
+ }
+ else if (totalLines <= cacheLinesAvailable)
+ {
+ // Data fits in cache
+ return totalLines * 0.05; // ~5% miss rate
+ }
+ else
+ {
+ // Poor cache behavior - strided access with cache thrashing
+ return totalLines * 0.8; // ~80% miss rate
+ }
+ }
+ }
+}
diff --git a/src/AiDotNet.Tensors/Engines/Optimization/LoopOptimizer.cs b/src/AiDotNet.Tensors/Engines/Optimization/LoopOptimizer.cs
new file mode 100644
index 000000000..32a83a9a9
--- /dev/null
+++ b/src/AiDotNet.Tensors/Engines/Optimization/LoopOptimizer.cs
@@ -0,0 +1,218 @@
+using System;
+using System.Runtime.CompilerServices;
+
+namespace AiDotNet.Tensors.Engines.Optimization
+{
+ ///
+ /// Provides loop optimization techniques including tiling and vectorization hints.
+ /// These utilities help maximize performance for tensor operations.
+ ///
+ public static class LoopOptimizer
+ {
+ ///
+ /// 2D loop tiling for matrix operations
+ ///
+ public static void Tile2D(
+ int rows, int cols,
+ int tileSize,
+ Action tileAction)
+ {
+ for (int i = 0; i < rows; i += tileSize)
+ {
+ int iEnd = Math.Min(i + tileSize, rows);
+
+ for (int j = 0; j < cols; j += tileSize)
+ {
+ int jEnd = Math.Min(j + tileSize, cols);
+
+ tileAction(i, iEnd, j, jEnd);
+ }
+ }
+ }
+
+ ///
+ /// 3D loop tiling for tensor operations
+ ///
+ public static void Tile3D(
+ int dim1, int dim2, int dim3,
+ int tileSize1, int tileSize2, int tileSize3,
+ Action tileAction)
+ {
+ for (int i = 0; i < dim1; i += tileSize1)
+ {
+ int iEnd = Math.Min(i + tileSize1, dim1);
+
+ for (int j = 0; j < dim2; j += tileSize2)
+ {
+ int jEnd = Math.Min(j + tileSize2, dim2);
+
+ for (int k = 0; k < dim3; k += tileSize3)
+ {
+ int kEnd = Math.Min(k + tileSize3, dim3);
+
+ tileAction(i, iEnd, j, jEnd, k, kEnd);
+ }
+ }
+ }
+ }
+
+ ///
+ /// Loop unrolling hint - processes elements in groups
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static void UnrollBy4(int length, Action action)
+ {
+ int i = 0;
+ int unrolledLength = length & ~3; // Round down to multiple of 4
+
+ // Unrolled loop
+ for (; i < unrolledLength; i += 4)
+ {
+ action(i);
+ action(i + 1);
+ action(i + 2);
+ action(i + 3);
+ }
+
+ // Remainder
+ for (; i < length; i++)
+ {
+ action(i);
+ }
+ }
+
+ ///
+ /// Loop unrolling by 8 for better SIMD utilization
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static void UnrollBy8(int length, Action action)
+ {
+ int i = 0;
+ int unrolledLength = length & ~7;
+
+ for (; i < unrolledLength; i += 8)
+ {
+ action(i);
+ action(i + 1);
+ action(i + 2);
+ action(i + 3);
+ action(i + 4);
+ action(i + 5);
+ action(i + 6);
+ action(i + 7);
+ }
+
+ for (; i < length; i++)
+ {
+ action(i);
+ }
+ }
+
+ ///
+ /// Strip mining - breaks loop into chunks for better cache utilization
+ ///
+ public static void StripMine(int totalSize, int stripSize, Action stripAction)
+ {
+ for (int start = 0; start < totalSize; start += stripSize)
+ {
+ int end = Math.Min(start + stripSize, totalSize);
+ stripAction(start, end);
+ }
+ }
+
+ ///
+ /// Loop fusion helper - executes multiple operations in a single pass
+ ///
+ public static void Fuse(int length, params Action[] actions)
+ {
+ for (int i = 0; i < length; i++)
+ {
+ foreach (var action in actions)
+ {
+ action(i);
+ }
+ }
+ }
+
+ ///
+ /// Loop interchange optimization for better cache locality
+ /// Automatically chooses better loop order based on access pattern
+ ///
+ public static void OptimalOrder2D(
+ int rows, int cols,
+ bool rowMajorAccess,
+ Action action)
+ {
+ if (rowMajorAccess)
+ {
+ // Standard order for row-major access
+ for (int i = 0; i < rows; i++)
+ {
+ for (int j = 0; j < cols; j++)
+ {
+ action(i, j);
+ }
+ }
+ }
+ else
+ {
+ // Interchanged order for column-major access
+ for (int j = 0; j < cols; j++)
+ {
+ for (int i = 0; i < rows; i++)
+ {
+ action(i, j);
+ }
+ }
+ }
+ }
+
+ ///
+ /// Parallel loop tiling with work stealing
+ ///
+ public static void ParallelTile2D(
+ int rows, int cols,
+ int tileSize,
+ Action tileAction)
+ {
+ int numTilesI = (rows + tileSize - 1) / tileSize;
+ int numTilesJ = (cols + tileSize - 1) / tileSize;
+ int totalTiles = numTilesI * numTilesJ;
+
+ System.Threading.Tasks.Parallel.For(0, totalTiles, tileIdx =>
+ {
+ int ti = tileIdx / numTilesJ;
+ int tj = tileIdx % numTilesJ;
+
+ int iStart = ti * tileSize;
+ int iEnd = Math.Min(iStart + tileSize, rows);
+
+ int jStart = tj * tileSize;
+ int jEnd = Math.Min(jStart + tileSize, cols);
+
+ tileAction(iStart, iEnd, jStart, jEnd);
+ });
+ }
+
+ ///
+ /// Automatically determines optimal tile size based on data dimensions and cache size
+ ///
+ public static int DetermineOptimalTileSize(int dimension, int elementSize = 4)
+ {
+ var caps = PlatformDetector.Capabilities;
+ int l1Size = caps.L1CacheSize;
+
+ // Aim to fit two tiles in L1 cache (one read, one write)
+ int maxElements = l1Size / (2 * elementSize);
+
+ // Find power of 2 that fits
+ int tileSize = 16; // Minimum tile size
+ while (tileSize * tileSize * 2 < maxElements && tileSize < dimension)
+ {
+ tileSize *= 2;
+ }
+
+ return Math.Min(tileSize, dimension);
+ }
+ }
+}
diff --git a/src/AiDotNet.Tensors/Engines/Optimization/PerformanceProfiler.cs b/src/AiDotNet.Tensors/Engines/Optimization/PerformanceProfiler.cs
new file mode 100644
index 000000000..23b6cca9e
--- /dev/null
+++ b/src/AiDotNet.Tensors/Engines/Optimization/PerformanceProfiler.cs
@@ -0,0 +1,203 @@
+using System;
+using System.Collections.Concurrent;
+using System.Diagnostics;
+using System.Linq;
+
+namespace AiDotNet.Tensors.Engines.Optimization
+{
+ ///
+ /// Thread-safe performance profiler for tracking operation timings and statistics.
+ /// Use this to measure and optimize tensor operations.
+ ///
+ public sealed class PerformanceProfiler
+ {
+ private static readonly Lazy _instance =
+ new Lazy(() => new PerformanceProfiler());
+
+ private readonly ConcurrentDictionary _stats;
+
+ ///
+ /// Gets the singleton instance of the profiler
+ ///
+ public static PerformanceProfiler Instance => _instance.Value;
+
+ ///
+ /// Enable or disable profiling (disabled by default for production)
+ ///
+ public bool Enabled { get; set; }
+
+ private PerformanceProfiler()
+ {
+ _stats = new ConcurrentDictionary();
+ Enabled = false;
+ }
+
+ ///
+ /// Starts profiling an operation
+ ///
+ public IDisposable Profile(string operationName)
+ {
+ if (!Enabled)
+ return DisposableHelper.Empty;
+
+ return new ProfileScope(this, operationName);
+ }
+
+ ///
+ /// Records a completed operation
+ ///
+ internal void RecordOperation(string operationName, long elapsedTicks, long memoryBytes = 0)
+ {
+ if (!Enabled)
+ return;
+
+ var updated = _stats.AddOrUpdate(
+ operationName,
+ _ => new OperationStats
+ {
+ OperationName = operationName,
+ CallCount = 1,
+ TotalTicks = elapsedTicks,
+ MinTicks = elapsedTicks,
+ MaxTicks = elapsedTicks,
+ TotalMemoryBytes = memoryBytes
+ },
+ (_, existing) =>
+ {
+ // Return new object to ensure thread-safety (avoid mutating existing object)
+ return new OperationStats
+ {
+ OperationName = existing.OperationName,
+ CallCount = existing.CallCount + 1,
+ TotalTicks = existing.TotalTicks + elapsedTicks,
+ MinTicks = Math.Min(existing.MinTicks, elapsedTicks),
+ MaxTicks = Math.Max(existing.MaxTicks, elapsedTicks),
+ TotalMemoryBytes = existing.TotalMemoryBytes + memoryBytes
+ };
+ });
+
+ _ = updated.CallCount;
+ }
+
+ ///
+ /// Gets statistics for a specific operation
+ ///
+ public OperationStats? GetStats(string operationName)
+ {
+ return _stats.TryGetValue(operationName, out var stats) ? stats : null;
+ }
+
+ ///
+ /// Gets all recorded statistics
+ ///
+ public OperationStats[] GetAllStats()
+ {
+ return _stats.Values.OrderByDescending(s => s.TotalMilliseconds).ToArray();
+ }
+
+ ///
+ /// Clears all statistics
+ ///
+ public void Clear()
+ {
+ _stats.Clear();
+ }
+
+ ///
+ /// Generates a performance report
+ ///
+ public string GenerateReport()
+ {
+ var stats = GetAllStats();
+ if (stats.Length == 0)
+ return "No profiling data available.";
+
+ var report = new System.Text.StringBuilder();
+ report.AppendLine("=== Performance Profile Report ===");
+ report.AppendLine();
+ report.AppendLine($"{"Operation",-40} {"Calls",10} {"Total (ms)",12} {"Avg (ms)",12} {"Min (ms)",12} {"Max (ms)",12} {"Memory (MB)",12}");
+ report.AppendLine(new string('-', 120));
+
+ foreach (var stat in stats)
+ {
+ report.AppendLine($"{stat.OperationName,-40} {stat.CallCount,10} {stat.TotalMilliseconds,12:F3} " +
+ $"{stat.AverageMilliseconds,12:F3} {stat.MinMilliseconds,12:F3} " +
+ $"{stat.MaxMilliseconds,12:F3} {stat.TotalMemoryMB,12:F2}");
+ }
+
+ report.AppendLine();
+ report.AppendLine($"Total operations: {stats.Length}");
+ report.AppendLine($"Total time: {stats.Sum(s => s.TotalMilliseconds):F3} ms");
+
+ return report.ToString();
+ }
+
+ private class ProfileScope : IDisposable
+ {
+ private readonly PerformanceProfiler _profiler;
+ private readonly string _operationName;
+ private readonly Stopwatch _stopwatch;
+ private readonly long _startMemory;
+
+ public ProfileScope(PerformanceProfiler profiler, string operationName)
+ {
+ _profiler = profiler;
+ _operationName = operationName;
+#if NET6_0_OR_GREATER
+ // Use per-thread allocation tracking for more accurate measurements
+ _startMemory = GC.GetAllocatedBytesForCurrentThread();
+#else
+ // Fallback for .NET Framework - less accurate but functional
+ _startMemory = GC.GetTotalMemory(false);
+#endif
+ _stopwatch = Stopwatch.StartNew();
+ }
+
+ public void Dispose()
+ {
+ _stopwatch.Stop();
+#if NET6_0_OR_GREATER
+ long endMemory = GC.GetAllocatedBytesForCurrentThread();
+#else
+ long endMemory = GC.GetTotalMemory(false);
+#endif
+ // Only report positive memory delta (allocation), ignore GC effects
+ long memoryDelta = Math.Max(0, endMemory - _startMemory);
+
+ _profiler.RecordOperation(_operationName, _stopwatch.ElapsedTicks, memoryDelta);
+ }
+ }
+
+ private static class DisposableHelper
+ {
+ public static readonly IDisposable Empty = new EmptyDisposable();
+
+ private class EmptyDisposable : IDisposable
+ {
+ public void Dispose() { }
+ }
+ }
+ }
+
+ ///
+ /// Statistics for a profiled operation
+ ///
+ public class OperationStats
+ {
+ public string OperationName { get; set; } = string.Empty;
+ public long CallCount { get; set; }
+ public long TotalTicks { get; set; }
+ public long MinTicks { get; set; }
+ public long MaxTicks { get; set; }
+ public long TotalMemoryBytes { get; set; }
+
+ public double TotalMilliseconds => TotalTicks * 1000.0 / Stopwatch.Frequency;
+ public double AverageMilliseconds => CallCount > 0 ? TotalMilliseconds / CallCount : 0;
+ public double MinMilliseconds => MinTicks * 1000.0 / Stopwatch.Frequency;
+ public double MaxMilliseconds => MaxTicks * 1000.0 / Stopwatch.Frequency;
+ public double TotalMemoryMB => TotalMemoryBytes / (1024.0 * 1024.0);
+ public double AverageMemoryMB => CallCount > 0 ? TotalMemoryMB / CallCount : 0;
+
+ public double ThroughputOpsPerSecond => TotalMilliseconds > 0 ? CallCount / (TotalMilliseconds / 1000.0) : 0;
+ }
+}
diff --git a/src/AiDotNet.Tensors/Engines/PlatformDetector.cs b/src/AiDotNet.Tensors/Engines/PlatformDetector.cs
new file mode 100644
index 000000000..07a607fc1
--- /dev/null
+++ b/src/AiDotNet.Tensors/Engines/PlatformDetector.cs
@@ -0,0 +1,284 @@
+using System;
+using System.Runtime.InteropServices;
+#if NET5_0_OR_GREATER
+using System.Runtime.Intrinsics.X86;
+using System.Runtime.Intrinsics.Arm;
+#endif
+
+namespace AiDotNet.Tensors.Engines
+{
+ ///
+ /// Provides platform and hardware capability detection for optimizing
+ /// tensor operations based on available SIMD instructions and cache sizes.
+ ///
+ public static class PlatformDetector
+ {
+ private static readonly Lazy _capabilities =
+ new Lazy(DetectCapabilities);
+
+ ///
+ /// Gets the detected platform capabilities
+ ///
+ public static PlatformCapabilities Capabilities => _capabilities.Value;
+
+ private static PlatformCapabilities DetectCapabilities()
+ {
+ var caps = new PlatformCapabilities
+ {
+ Architecture = RuntimeInformation.ProcessArchitecture,
+ OSDescription = RuntimeInformation.OSDescription,
+ FrameworkDescription = RuntimeInformation.FrameworkDescription,
+ ProcessorCount = Environment.ProcessorCount,
+ Is64BitProcess = Environment.Is64BitProcess,
+ Is64BitOperatingSystem = Environment.Is64BitOperatingSystem
+ };
+
+#if NET5_0_OR_GREATER
+ // Detect x86/x64 SIMD support
+ if (caps.Architecture == Architecture.X64 || caps.Architecture == Architecture.X86)
+ {
+ caps.HasSSE = Sse.IsSupported;
+ caps.HasSSE2 = Sse2.IsSupported;
+ caps.HasSSE3 = Sse3.IsSupported;
+ caps.HasSSSE3 = Ssse3.IsSupported;
+ caps.HasSSE41 = Sse41.IsSupported;
+ caps.HasSSE42 = Sse42.IsSupported;
+ caps.HasAVX = Avx.IsSupported;
+ caps.HasAVX2 = Avx2.IsSupported;
+ caps.HasFMA = Fma.IsSupported;
+ caps.HasAVX512F = Avx512F.IsSupported;
+ caps.HasAVX512BW = Avx512BW.IsSupported;
+ caps.HasAVX512DQ = Avx512DQ.IsSupported;
+ // AVX-512VL is implied when other AVX-512 extensions are supported
+ caps.HasAVX512VL = Avx512F.VL.IsSupported;
+ }
+
+ // Detect ARM SIMD support
+ if (caps.Architecture == Architecture.Arm64 || caps.Architecture == Architecture.Arm)
+ {
+ caps.HasNeon = AdvSimd.IsSupported;
+ caps.HasArmBase = ArmBase.IsSupported;
+ caps.HasArmAes = System.Runtime.Intrinsics.Arm.Aes.IsSupported;
+ caps.HasArmCrc32 = Crc32.IsSupported;
+ caps.HasArmDp = AdvSimd.Arm64.IsSupported;
+ }
+#endif
+
+ // Detect cache sizes (approximate based on typical values)
+ caps.L1CacheSize = EstimateL1CacheSize(caps.Architecture);
+ caps.L2CacheSize = EstimateL2CacheSize(caps.Architecture);
+ caps.L3CacheSize = EstimateL3CacheSize(caps.Architecture);
+
+ // Check for GPU support (requires additional libraries)
+ caps.HasCudaSupport = DetectCudaSupport();
+ caps.HasOpenCLSupport = DetectOpenCLSupport();
+
+ return caps;
+ }
+
+ private static int EstimateL1CacheSize(Architecture arch)
+ {
+ // Typical L1 cache size is 32KB per core
+ return 32 * 1024;
+ }
+
+ private static int EstimateL2CacheSize(Architecture arch)
+ {
+ // Typical L2 cache size is 256KB per core
+ return 256 * 1024;
+ }
+
+ private static int EstimateL3CacheSize(Architecture arch)
+ {
+ // Typical L3 cache size is 2-8MB shared
+ return 8 * 1024 * 1024;
+ }
+
+ ///
+ /// Checks whether CUDA driver support appears to be available on this machine.
+ ///
+ /// Notes:
+ /// - This attempts a lightweight runtime check for the CUDA driver library (not the toolkit).
+ /// - It is intentionally conservative: if we cannot verify CUDA driver presence, we return false.
+ /// - This does not guarantee that higher-level CUDA compute is usable (device selection, permissions, etc.).
+ ///
+ private static bool DetectCudaSupport()
+ {
+ if (!Environment.Is64BitProcess)
+ return false;
+
+#if NET5_0_OR_GREATER
+ // Prefer checking for the CUDA driver library:
+ // - Windows: nvcuda.dll
+ // - Linux: libcuda.so.1 (or libcuda.so)
+ try
+ {
+ if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
+ {
+ return TryLoadNativeLibrary("nvcuda.dll");
+ }
+
+ if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
+ {
+ return TryLoadNativeLibrary("libcuda.so.1") || TryLoadNativeLibrary("libcuda.so");
+ }
+
+ return false;
+ }
+ catch
+ {
+ return false;
+ }
+#else
+ // .NET Framework builds are conservative here; implement a native check if/when CUDA support is added for net471.
+ return false;
+#endif
+ }
+
+#if NET5_0_OR_GREATER
+ private static bool TryLoadNativeLibrary(string name)
+ {
+ if (string.IsNullOrWhiteSpace(name))
+ return false;
+
+ if (NativeLibrary.TryLoad(name, out var handle))
+ {
+ NativeLibrary.Free(handle);
+ return true;
+ }
+
+ return false;
+ }
+#endif
+
+ private static bool DetectOpenCLSupport()
+ {
+ // This would require OpenCL library calls
+ // For now, we'll return false (requires additional implementation)
+ return false;
+ }
+
+ ///
+ /// Gets a human-readable description of the platform capabilities
+ ///
+ public static string GetCapabilitiesDescription()
+ {
+ var caps = Capabilities;
+ var desc = new System.Text.StringBuilder();
+
+ desc.AppendLine($"Platform: {caps.OSDescription}");
+ desc.AppendLine($"Architecture: {caps.Architecture}");
+ desc.AppendLine($"Framework: {caps.FrameworkDescription}");
+ desc.AppendLine($"Processor Count: {caps.ProcessorCount}");
+ desc.AppendLine($"64-bit Process: {caps.Is64BitProcess}");
+ desc.AppendLine();
+
+ if (caps.Architecture == Architecture.X64 || caps.Architecture == Architecture.X86)
+ {
+ desc.AppendLine("x86/x64 SIMD Support:");
+ desc.AppendLine($" SSE: {caps.HasSSE}");
+ desc.AppendLine($" SSE2: {caps.HasSSE2}");
+ desc.AppendLine($" SSE3: {caps.HasSSE3}");
+ desc.AppendLine($" SSSE3: {caps.HasSSSE3}");
+ desc.AppendLine($" SSE4.1: {caps.HasSSE41}");
+ desc.AppendLine($" SSE4.2: {caps.HasSSE42}");
+ desc.AppendLine($" AVX: {caps.HasAVX}");
+ desc.AppendLine($" AVX2: {caps.HasAVX2}");
+ desc.AppendLine($" FMA: {caps.HasFMA}");
+ desc.AppendLine($" AVX-512F: {caps.HasAVX512F}");
+ desc.AppendLine($" AVX-512BW: {caps.HasAVX512BW}");
+ desc.AppendLine($" AVX-512DQ: {caps.HasAVX512DQ}");
+ desc.AppendLine($" AVX-512VL: {caps.HasAVX512VL}");
+ }
+
+ if (caps.Architecture == Architecture.Arm64 || caps.Architecture == Architecture.Arm)
+ {
+ desc.AppendLine("ARM SIMD Support:");
+ desc.AppendLine($" NEON: {caps.HasNeon}");
+ desc.AppendLine($" ARM Base: {caps.HasArmBase}");
+ desc.AppendLine($" AES: {caps.HasArmAes}");
+ desc.AppendLine($" CRC32: {caps.HasArmCrc32}");
+ desc.AppendLine($" Dot Product: {caps.HasArmDp}");
+ }
+
+ desc.AppendLine();
+ desc.AppendLine("GPU Support:");
+ desc.AppendLine($" CUDA: {caps.HasCudaSupport}");
+ desc.AppendLine($" OpenCL: {caps.HasOpenCLSupport}");
+
+ return desc.ToString();
+ }
+ }
+
+ ///
+ /// Represents detected platform capabilities including SIMD support,
+ /// cache sizes, and GPU availability.
+ ///
+ public class PlatformCapabilities
+ {
+ // Basic platform info
+ public Architecture Architecture { get; set; }
+ public string OSDescription { get; set; } = string.Empty;
+ public string FrameworkDescription { get; set; } = string.Empty;
+ public int ProcessorCount { get; set; }
+ public bool Is64BitProcess { get; set; }
+ public bool Is64BitOperatingSystem { get; set; }
+
+ // x86/x64 SIMD capabilities
+ public bool HasSSE { get; set; }
+ public bool HasSSE2 { get; set; }
+ public bool HasSSE3 { get; set; }
+ public bool HasSSSE3 { get; set; }
+ public bool HasSSE41 { get; set; }
+ public bool HasSSE42 { get; set; }
+ public bool HasAVX { get; set; }
+ public bool HasAVX2 { get; set; }
+ public bool HasFMA { get; set; }
+ public bool HasAVX512F { get; set; }
+ public bool HasAVX512BW { get; set; }
+ public bool HasAVX512DQ { get; set; }
+ public bool HasAVX512VL { get; set; }
+
+ // ARM SIMD capabilities
+ public bool HasNeon { get; set; }
+ public bool HasArmBase { get; set; }
+ public bool HasArmAes { get; set; }
+ public bool HasArmCrc32 { get; set; }
+ public bool HasArmDp { get; set; }
+
+ // Cache information
+ public int L1CacheSize { get; set; }
+ public int L2CacheSize { get; set; }
+ public int L3CacheSize { get; set; }
+
+ // GPU capabilities
+ public bool HasCudaSupport { get; set; }
+ public bool HasOpenCLSupport { get; set; }
+
+ ///
+ /// Returns the best available SIMD instruction set
+ ///
+ public string GetBestSimdSet()
+ {
+ if (Architecture == Architecture.X64 || Architecture == Architecture.X86)
+ {
+ if (HasAVX512F) return "AVX-512";
+ if (HasAVX2) return "AVX2";
+ if (HasAVX) return "AVX";
+ if (HasSSE42) return "SSE4.2";
+ if (HasSSE41) return "SSE4.1";
+ if (HasSSSE3) return "SSSE3";
+ if (HasSSE3) return "SSE3";
+ if (HasSSE2) return "SSE2";
+ if (HasSSE) return "SSE";
+ }
+ else if (Architecture == Architecture.Arm64 || Architecture == Architecture.Arm)
+ {
+ if (HasArmDp) return "NEON with Dot Product";
+ if (HasNeon) return "NEON";
+ }
+
+ return "None";
+ }
+ }
+}
diff --git a/src/AiDotNet.Tensors/Engines/Simd/SimdKernels.cs b/src/AiDotNet.Tensors/Engines/Simd/SimdKernels.cs
new file mode 100644
index 000000000..8b4e94788
--- /dev/null
+++ b/src/AiDotNet.Tensors/Engines/Simd/SimdKernels.cs
@@ -0,0 +1,424 @@
+using System;
+using System.Runtime.CompilerServices;
+#if NET5_0_OR_GREATER
+using System.Runtime.Intrinsics;
+using System.Runtime.Intrinsics.X86;
+using System.Runtime.Intrinsics.Arm;
+#endif
+
+namespace AiDotNet.Tensors.Engines.Simd
+{
+ ///
+ /// SIMD-optimized kernels for common operations.
+ /// Provides hardware-accelerated implementations using AVX2, SSE, and ARM NEON.
+ /// Falls back to scalar operations on .NET Framework.
+ ///
+ public static class SimdKernels
+ {
+ ///
+ /// SIMD-optimized vector addition for float arrays
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static unsafe void VectorAdd(float* a, float* b, float* result, int length)
+ {
+ int i = 0;
+
+#if NET5_0_OR_GREATER
+ // AVX2 path (8 floats at a time)
+ if (Avx2.IsSupported && length >= 8)
+ {
+ int simdLength = length & ~7; // Round down to multiple of 8
+ for (; i < simdLength; i += 8)
+ {
+ var va = Avx.LoadVector256(a + i);
+ var vb = Avx.LoadVector256(b + i);
+ var vr = Avx.Add(va, vb);
+ Avx.Store(result + i, vr);
+ }
+ }
+ // SSE path (4 floats at a time)
+ else if (Sse.IsSupported && length >= 4)
+ {
+ int simdLength = length & ~3; // Round down to multiple of 4
+ for (; i < simdLength; i += 4)
+ {
+ var va = Sse.LoadVector128(a + i);
+ var vb = Sse.LoadVector128(b + i);
+ var vr = Sse.Add(va, vb);
+ Sse.Store(result + i, vr);
+ }
+ }
+ // NEON path (4 floats at a time)
+ else if (AdvSimd.IsSupported && length >= 4)
+ {
+ int simdLength = length & ~3;
+ for (; i < simdLength; i += 4)
+ {
+ var va = AdvSimd.LoadVector128(a + i);
+ var vb = AdvSimd.LoadVector128(b + i);
+ var vr = AdvSimd.Add(va, vb);
+ AdvSimd.Store(result + i, vr);
+ }
+ }
+#endif
+
+ // Scalar fallback for remaining elements
+ for (; i < length; i++)
+ {
+ result[i] = a[i] + b[i];
+ }
+ }
+
+ ///
+ /// SIMD-optimized vector multiplication
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static unsafe void VectorMultiply(float* a, float* b, float* result, int length)
+ {
+ int i = 0;
+
+#if NET5_0_OR_GREATER
+ if (Avx2.IsSupported && length >= 8)
+ {
+ int simdLength = length & ~7;
+ for (; i < simdLength; i += 8)
+ {
+ var va = Avx.LoadVector256(a + i);
+ var vb = Avx.LoadVector256(b + i);
+ var vr = Avx.Multiply(va, vb);
+ Avx.Store(result + i, vr);
+ }
+ }
+ else if (Sse.IsSupported && length >= 4)
+ {
+ int simdLength = length & ~3;
+ for (; i < simdLength; i += 4)
+ {
+ var va = Sse.LoadVector128(a + i);
+ var vb = Sse.LoadVector128(b + i);
+ var vr = Sse.Multiply(va, vb);
+ Sse.Store(result + i, vr);
+ }
+ }
+ else if (AdvSimd.IsSupported && length >= 4)
+ {
+ int simdLength = length & ~3;
+ for (; i < simdLength; i += 4)
+ {
+ var va = AdvSimd.LoadVector128(a + i);
+ var vb = AdvSimd.LoadVector128(b + i);
+ var vr = AdvSimd.Multiply(va, vb);
+ AdvSimd.Store(result + i, vr);
+ }
+ }
+#endif
+
+ for (; i < length; i++)
+ {
+ result[i] = a[i] * b[i];
+ }
+ }
+
+ ///
+ /// SIMD-optimized dot product
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static unsafe float DotProduct(float* a, float* b, int length)
+ {
+ float sum = 0.0f;
+ int i = 0;
+
+#if NET5_0_OR_GREATER
+ if (Avx2.IsSupported && length >= 8)
+ {
+ var vsum = Vector256.Zero;
+ int simdLength = length & ~7;
+
+ for (; i < simdLength; i += 8)
+ {
+ var va = Avx.LoadVector256(a + i);
+ var vb = Avx.LoadVector256(b + i);
+ vsum = Fma.IsSupported
+ ? Fma.MultiplyAdd(va, vb, vsum)
+ : Avx.Add(vsum, Avx.Multiply(va, vb));
+ }
+
+ // Horizontal sum of vector
+ var high = Avx.ExtractVector128(vsum, 1);
+ var low = vsum.GetLower();
+ var sum128 = Sse.Add(high, low);
+
+ // Further reduce 4 floats to 1
+ var shuf = Sse.Shuffle(sum128, sum128, 0b_11_10_11_10);
+ sum128 = Sse.Add(sum128, shuf);
+ shuf = Sse.Shuffle(sum128, sum128, 0b_01_01_01_01);
+ sum128 = Sse.Add(sum128, shuf);
+ sum = sum128.ToScalar();
+ }
+ else if (Sse.IsSupported && length >= 4)
+ {
+ var vsum = Vector128.Zero;
+ int simdLength = length & ~3;
+
+ for (; i < simdLength; i += 4)
+ {
+ var va = Sse.LoadVector128(a + i);
+ var vb = Sse.LoadVector128(b + i);
+ vsum = Sse.Add(vsum, Sse.Multiply(va, vb));
+ }
+
+ // Horizontal sum
+ var shuf = Sse.Shuffle(vsum, vsum, 0b_11_10_11_10);
+ vsum = Sse.Add(vsum, shuf);
+ shuf = Sse.Shuffle(vsum, vsum, 0b_01_01_01_01);
+ vsum = Sse.Add(vsum, shuf);
+ sum = vsum.ToScalar();
+ }
+ else if (AdvSimd.IsSupported && length >= 4)
+ {
+ var vsum = Vector128.Zero;
+ int simdLength = length & ~3;
+
+ for (; i < simdLength; i += 4)
+ {
+ var va = AdvSimd.LoadVector128(a + i);
+ var vb = AdvSimd.LoadVector128(b + i);
+ vsum = AdvSimd.Add(vsum, AdvSimd.Multiply(va, vb));
+ }
+
+ // Horizontal sum for ARM - use AddPairwise on ARM64, manual fallback otherwise
+ if (AdvSimd.Arm64.IsSupported)
+ {
+ // AddPairwise reduces pairs: [a,b,c,d] -> [a+b, c+d, ?, ?] (lower 64 bits)
+ var pairSum = AdvSimd.Arm64.AddPairwise(vsum, vsum);
+ var finalSum = AdvSimd.Arm64.AddPairwiseScalar(pairSum.GetLower());
+ sum = finalSum.ToScalar();
+ }
+ else
+ {
+ sum = vsum.GetElement(0) + vsum.GetElement(1) + vsum.GetElement(2) + vsum.GetElement(3);
+ }
+ }
+#endif
+
+ // Scalar remainder
+ for (; i < length; i++)
+ {
+ sum += a[i] * b[i];
+ }
+
+ return sum;
+ }
+
+ ///
+ /// SIMD-optimized scalar multiply-add (result = a + b * scalar)
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static unsafe void ScalarMultiplyAdd(float* a, float* b, float scalar, float* result, int length)
+ {
+ int i = 0;
+
+#if NET5_0_OR_GREATER
+ if (Avx2.IsSupported && length >= 8)
+ {
+ var vscalar = Vector256.Create(scalar);
+ int simdLength = length & ~7;
+
+ for (; i < simdLength; i += 8)
+ {
+ var va = Avx.LoadVector256(a + i);
+ var vb = Avx.LoadVector256(b + i);
+ var vr = Fma.IsSupported
+ ? Fma.MultiplyAdd(vb, vscalar, va)
+ : Avx.Add(va, Avx.Multiply(vb, vscalar));
+ Avx.Store(result + i, vr);
+ }
+ }
+ else if (Sse.IsSupported && length >= 4)
+ {
+ var vscalar = Vector128.Create(scalar);
+ int simdLength = length & ~3;
+
+ for (; i < simdLength; i += 4)
+ {
+ var va = Sse.LoadVector128(a + i);
+ var vb = Sse.LoadVector128(b + i);
+ var vr = Sse.Add(va, Sse.Multiply(vb, vscalar));
+ Sse.Store(result + i, vr);
+ }
+ }
+ else if (AdvSimd.IsSupported && length >= 4)
+ {
+ var vscalar = Vector128.Create(scalar);
+ int simdLength = length & ~3;
+
+ for (; i < simdLength; i += 4)
+ {
+ var va = AdvSimd.LoadVector128(a + i);
+ var vb = AdvSimd.LoadVector128(b + i);
+ var vr = AdvSimd.Add(va, AdvSimd.Multiply(vb, vscalar));
+ AdvSimd.Store(result + i, vr);
+ }
+ }
+#endif
+
+ for (; i < length; i++)
+ {
+ result[i] = a[i] + b[i] * scalar;
+ }
+ }
+
+ ///
+ /// SIMD-optimized ReLU activation
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static unsafe void ReLU(float* input, float* output, int length)
+ {
+ int i = 0;
+
+#if NET5_0_OR_GREATER
+ if (Avx2.IsSupported && length >= 8)
+ {
+ var vzero = Vector256.Zero;
+ int simdLength = length & ~7;
+
+ for (; i < simdLength; i += 8)
+ {
+ var v = Avx.LoadVector256(input + i);
+ var vr = Avx.Max(v, vzero);
+ Avx.Store(output + i, vr);
+ }
+ }
+ else if (Sse.IsSupported && length >= 4)
+ {
+ var vzero = Vector128.Zero;
+ int simdLength = length & ~3;
+
+ for (; i < simdLength; i += 4)
+ {
+ var v = Sse.LoadVector128(input + i);
+ var vr = Sse.Max(v, vzero);
+ Sse.Store(output + i, vr);
+ }
+ }
+ else if (AdvSimd.IsSupported && length >= 4)
+ {
+ var vzero = Vector128.Zero;
+ int simdLength = length & ~3;
+
+ for (; i < simdLength; i += 4)
+ {
+ var v = AdvSimd.LoadVector128(input + i);
+ var vr = AdvSimd.Max(v, vzero);
+ AdvSimd.Store(output + i, vr);
+ }
+ }
+#endif
+
+ for (; i < length; i++)
+ {
+ output[i] = Math.Max(0.0f, input[i]);
+ }
+ }
+
+ ///
+ /// SIMD-optimized element-wise exponential
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static unsafe void Exp(float* input, float* output, int length)
+ {
+ // Note: True SIMD exp requires approximation algorithms
+ // This is a scalar fallback - can be optimized with SVML or custom approximations
+ for (int i = 0; i < length; i++)
+ {
+#if NET5_0_OR_GREATER
+ output[i] = MathF.Exp(input[i]);
+#else
+ output[i] = (float)Math.Exp(input[i]);
+#endif
+ }
+ }
+
+ ///
+ /// SIMD-optimized sum reduction
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static unsafe float Sum(float* data, int length)
+ {
+ float sum = 0.0f;
+ int i = 0;
+
+#if NET5_0_OR_GREATER
+ if (Avx2.IsSupported && length >= 8)
+ {
+ var vsum = Vector256.Zero;
+ int simdLength = length & ~7;
+
+ for (; i < simdLength; i += 8)
+ {
+ var v = Avx.LoadVector256(data + i);
+ vsum = Avx.Add(vsum, v);
+ }
+
+ var high = Avx.ExtractVector128(vsum, 1);
+ var low = vsum.GetLower();
+ var sum128 = Sse.Add(high, low);
+
+ var shuf = Sse.Shuffle(sum128, sum128, 0b_11_10_11_10);
+ sum128 = Sse.Add(sum128, shuf);
+ shuf = Sse.Shuffle(sum128, sum128, 0b_01_01_01_01);
+ sum128 = Sse.Add(sum128, shuf);
+ sum = sum128.ToScalar();
+ }
+ else if (Sse.IsSupported && length >= 4)
+ {
+ var vsum = Vector128.Zero;
+ int simdLength = length & ~3;
+
+ for (; i < simdLength; i += 4)
+ {
+ var v = Sse.LoadVector128(data + i);
+ vsum = Sse.Add(vsum, v);
+ }
+
+ var shuf = Sse.Shuffle(vsum, vsum, 0b_11_10_11_10);
+ vsum = Sse.Add(vsum, shuf);
+ shuf = Sse.Shuffle(vsum, vsum, 0b_01_01_01_01);
+ vsum = Sse.Add(vsum, shuf);
+ sum = vsum.ToScalar();
+ }
+ else if (AdvSimd.IsSupported && length >= 4)
+ {
+ var vsum = Vector128.Zero;
+ int simdLength = length & ~3;
+
+ for (; i < simdLength; i += 4)
+ {
+ var v = AdvSimd.LoadVector128(data + i);
+ vsum = AdvSimd.Add(vsum, v);
+ }
+
+ // Horizontal sum for ARM - use AddPairwise on ARM64, manual fallback otherwise
+ if (AdvSimd.Arm64.IsSupported)
+ {
+ // AddPairwise reduces pairs: [a,b,c,d] -> [a+b, c+d, ?, ?] (lower 64 bits)
+ var pairSum = AdvSimd.Arm64.AddPairwise(vsum, vsum);
+ var finalSum = AdvSimd.Arm64.AddPairwiseScalar(pairSum.GetLower());
+ sum = finalSum.ToScalar();
+ }
+ else
+ {
+ sum = vsum.GetElement(0) + vsum.GetElement(1) + vsum.GetElement(2) + vsum.GetElement(3);
+ }
+ }
+#endif
+
+ for (; i < length; i++)
+ {
+ sum += data[i];
+ }
+
+ return sum;
+ }
+ }
+}
diff --git a/src/AiDotNet.Tensors/LinearAlgebra/TensorBase.cs b/src/AiDotNet.Tensors/LinearAlgebra/TensorBase.cs
index 15f8438ec..cdf447e35 100644
--- a/src/AiDotNet.Tensors/LinearAlgebra/TensorBase.cs
+++ b/src/AiDotNet.Tensors/LinearAlgebra/TensorBase.cs
@@ -55,6 +55,16 @@ public abstract class TensorBase
///
public int Rank => Shape.Length;
+ ///
+ /// Gets direct access to the underlying data array for high-performance operations.
+ ///
+ ///
+ /// Warning: This property provides direct access to internal storage.
+ /// Modifications to this array will affect the tensor. Use with caution in
+ /// performance-critical code paths like SIMD operations.
+ ///
+ public T[] Data => _data.Data;
+
///
/// Initializes a new instance of the TensorBase class with the specified shape.
///
diff --git a/src/AiDotNet.Tensors/LinearAlgebra/VectorBase.cs b/src/AiDotNet.Tensors/LinearAlgebra/VectorBase.cs
index c777a263d..3096d76cd 100644
--- a/src/AiDotNet.Tensors/LinearAlgebra/VectorBase.cs
+++ b/src/AiDotNet.Tensors/LinearAlgebra/VectorBase.cs
@@ -74,6 +74,16 @@ protected VectorBase(IEnumerable values)
///
public int Length => _data.Length;
+ ///
+ /// Gets direct access to the underlying data array for high-performance operations.
+ ///
+ ///
+ /// Warning: This property provides direct access to internal storage.
+ /// Modifications to this array will affect the vector. Use with caution in
+ /// performance-critical code paths like SIMD operations.
+ ///
+ public T[] Data => _data;
+
///
/// Gets a value indicating whether the vector contains no elements.
///
diff --git a/src/AiDotNet.csproj b/src/AiDotNet.csproj
index b261eaf28..0c249075a 100644
--- a/src/AiDotNet.csproj
+++ b/src/AiDotNet.csproj
@@ -3,6 +3,7 @@
net8.0;net471
enable
enable
+ true
True
0.0.5-preview
Ai for .Net
diff --git a/src/Configuration/InferenceOptimizationConfig.cs b/src/Configuration/InferenceOptimizationConfig.cs
index 25dc47f1a..4e4221e2d 100644
--- a/src/Configuration/InferenceOptimizationConfig.cs
+++ b/src/Configuration/InferenceOptimizationConfig.cs
@@ -116,6 +116,97 @@ public class InferenceOptimizationConfig
/// Cache eviction policy (default: LRU).
public CacheEvictionPolicy KVCacheEvictionPolicy { get; set; } = CacheEvictionPolicy.LRU;
+ ///
+ /// Gets or sets whether to use a sliding window KV-cache for long contexts.
+ ///
+ ///
+ /// When enabled, only the most recent tokens are kept.
+ /// This is a common industry approach for long-context serving to cap memory usage.
+ ///
+ public bool UseSlidingWindowKVCache { get; set; } = false;
+
+ ///
+ /// Gets or sets the sliding window size in tokens when is enabled.
+ ///
+ /// Window size in tokens (default: 1024).
+ public int KVCacheWindowSize { get; set; } = 1024;
+
+ ///
+ /// Gets or sets the precision used for KV-cache storage.
+ ///
+ ///
+ ///
+ /// Industry-standard serving stores KV-cache in FP16 to halve memory usage and increase cache capacity.
+ /// The default selects FP16 when KV-cache is enabled and the numeric
+ /// type supports it.
+ ///
+ ///
+ /// For Beginners: This setting controls how much memory your model uses during autoregressive inference.
+ ///
+ /// - FP16: Uses about half the memory (recommended default)
+ /// - FP32: Uses more memory but can be slightly more numerically accurate
+ ///
+ /// Most production systems prefer FP16 KV-cache for capacity and throughput.
+ ///
+ ///
+ public KVCachePrecisionMode KVCachePrecision { get; set; } = KVCachePrecisionMode.Auto;
+
+ ///
+ /// Gets or sets the quantization mode used for KV-cache storage.
+ ///
+ ///
+ ///
+ /// KV-cache quantization can further reduce memory beyond FP16 by storing keys/values in int8 with scaling.
+ /// This is an opt-in advanced feature because it can introduce small numerical error.
+ ///
+ /// For Beginners:
+ /// - None (default): Store KV-cache in FP16/FP32 depending on .
+ /// - Int8: Store KV-cache in 8-bit integers to save memory (advanced).
+ ///
+ ///
+ public KVCacheQuantizationMode KVCacheQuantization { get; set; } = KVCacheQuantizationMode.None;
+
+ ///
+ /// Gets or sets whether to use a paged KV-cache backend (vLLM-style) for long-context / multi-sequence serving.
+ ///
+ ///
+ /// When enabled, the system may choose a paged cache implementation that allocates KV memory in fixed-size blocks.
+ /// This is the industry-standard approach for high-throughput serving where many sequences are active concurrently.
+ /// Users can disable this to force the traditional contiguous KV-cache.
+ ///
+ public bool EnablePagedKVCache { get; set; } = true;
+
+ ///
+ /// Gets or sets the block size (in tokens) for the paged KV-cache when enabled.
+ ///
+ ///
+ /// Common values are 16 or 32. Smaller blocks reduce internal fragmentation; larger blocks reduce table overhead.
+ ///
+ public int PagedKVCacheBlockSize { get; set; } = 16;
+
+ #endregion
+
+ #region Attention Settings
+
+ ///
+ /// Gets or sets whether Flash Attention is enabled (when applicable).
+ ///
+ ///
+ /// Flash Attention computes exact attention without materializing the full N×N attention matrix,
+ /// reducing memory bandwidth pressure and improving throughput for long sequences.
+ ///
+ public bool EnableFlashAttention { get; set; } = true;
+
+ ///
+ /// Gets or sets how attention masking should be applied for optimized attention implementations.
+ ///
+ ///
+ /// - Auto: Applies causal masking for known autoregressive models (e.g., text generation), otherwise no mask.
+ /// - Disabled: Never applies causal masking.
+ /// - Causal: Always applies causal masking (GPT-style).
+ ///
+ public AttentionMaskingMode AttentionMasking { get; set; } = AttentionMaskingMode.Auto;
+
#endregion
#region Batching Settings
@@ -251,6 +342,18 @@ public void Validate()
throw new InvalidOperationException(
$"SpeculationDepth must be non-negative. Got: {SpeculationDepth}");
}
+
+ if (UseSlidingWindowKVCache && KVCacheWindowSize <= 0)
+ {
+ throw new InvalidOperationException(
+ $"KVCacheWindowSize must be positive when UseSlidingWindowKVCache is enabled. Got: {KVCacheWindowSize}");
+ }
+
+ if (EnablePagedKVCache && PagedKVCacheBlockSize <= 0)
+ {
+ throw new InvalidOperationException(
+ $"PagedKVCacheBlockSize must be positive when EnablePagedKVCache is enabled. Got: {PagedKVCacheBlockSize}");
+ }
}
#endregion
@@ -294,9 +397,14 @@ public void Validate()
///
/// Options:
/// - NGram: Simple statistical model (fast, no GPU needed)
- /// - SmallNeural: Smaller version of the main model (more accurate drafts)
+ /// - SmallNeural: Smaller companion model (more accurate drafts)
///
/// NGram is usually sufficient and has near-zero overhead.
+ ///
+ ///
+ /// Note: Small neural draft models require an external companion model. In the MVP, the library
+ /// falls back to when a companion draft model is not available.
+ ///
///
///
public DraftModelType DraftModelType { get; set; } = DraftModelType.NGram;
@@ -331,9 +439,110 @@ public void Validate()
///
public bool UseTreeSpeculation { get; set; } = false;
+ ///
+ /// Gets or sets the policy for when speculative decoding should run.
+ ///
+ ///
+ /// Auto is recommended: it can back off speculative decoding under high load (e.g., large batches)
+ /// to avoid throughput regressions, while still enabling it for latency-sensitive scenarios.
+ ///
+ public SpeculationPolicy SpeculationPolicy { get; set; } = SpeculationPolicy.Auto;
+
+ ///
+ /// Gets or sets the speculative decoding method.
+ ///
+ ///
+ ///
+ /// The default currently selects .
+ ///
+ ///
+ /// For Beginners: This chooses the "style" of speculative decoding.
+ ///
+ ///
+ public SpeculativeMethod SpeculativeMethod { get; set; } = SpeculativeMethod.Auto;
+
+ #endregion
+
+ #region Inference Quantization (Advanced)
+
+ ///
+ /// Gets or sets whether weight-only INT8 quantization is enabled for inference.
+ ///
+ ///
+ ///
+ /// Weight-only quantization reduces memory bandwidth and improves cache locality by storing weights in int8
+ /// with per-output scaling. Activations remain in FP32/FP16, and accumulation is performed in float.
+ ///
+ ///
+ /// For Beginners: This makes your model weights smaller so the CPU/GPU can read them faster.
+ ///
+ ///
+ /// This is disabled by default until validated across more layer types and kernels. When enabled, the optimizer
+ /// will apply it opportunistically and fall back safely when unsupported.
+ ///
+ ///
+ public bool EnableWeightOnlyQuantization { get; set; } = false;
+
#endregion
}
+///
+/// Policies for enabling/disabling speculative decoding at runtime.
+///
+public enum SpeculationPolicy
+{
+ ///
+ /// Automatically decide based on runtime conditions (recommended).
+ ///
+ Auto,
+
+ ///
+ /// Always enable speculative decoding when configured.
+ ///
+ ForceOn,
+
+ ///
+ /// Always disable speculative decoding even if enabled in config.
+ ///
+ ForceOff,
+
+ ///
+ /// Prefer speculative decoding to reduce latency, even under moderate load.
+ ///
+ LatencyFirst,
+
+ ///
+ /// Prefer throughput and stability: use speculative decoding only when conditions are ideal.
+ ///
+ ThroughputFirst
+}
+
+///
+/// Selects the speculative decoding method.
+///
+public enum SpeculativeMethod
+{
+ ///
+ /// Automatically select the best available method (defaults to ClassicDraftModel today).
+ ///
+ Auto,
+
+ ///
+ /// Classic draft-model speculative decoding (standard).
+ ///
+ ClassicDraftModel,
+
+ ///
+ /// Medusa-style multi-head proposals (hook for future internal implementation).
+ ///
+ Medusa,
+
+ ///
+ /// EAGLE-style enhanced draft proposals (hook for future internal implementation).
+ ///
+ Eagle
+}
+
///
/// Cache eviction policies for KV cache management.
///
@@ -356,6 +565,65 @@ public enum DraftModelType
NGram,
/// Small neural network model (more accurate, uses GPU).
SmallNeural,
- /// Custom user-provided draft model.
+ /// Custom draft model (internal/serving integration).
Custom
}
+
+///
+/// Controls how attention masking is applied for optimized attention implementations.
+///
+public enum AttentionMaskingMode
+{
+ ///
+ /// Automatically select masking based on model/task heuristics.
+ ///
+ Auto,
+
+ ///
+ /// Do not apply causal masking.
+ ///
+ Disabled,
+
+ ///
+ /// Apply causal masking (autoregressive decoding).
+ ///
+ Causal
+}
+
+///
+/// Controls the numeric precision of KV-cache storage.
+///
+public enum KVCachePrecisionMode
+{
+ ///
+ /// Select an industry-standard default.
+ ///
+ ///
+ ///
+ /// Uses FP16 when KV-cache is enabled and the numeric type supports conversion; otherwise falls back to FP32.
+ ///
+ ///
+ Auto,
+
+ ///
+ /// Store KV-cache in FP16 (half precision) to reduce memory use.
+ ///
+ Float16,
+
+ ///
+ /// Store KV-cache in FP32 (single precision) for maximal numerical fidelity.
+ ///
+ Float32
+}
+
+///
+/// Controls optional KV-cache quantization for inference.
+///
+public enum KVCacheQuantizationMode
+{
+ /// No quantization (default).
+ None,
+
+ /// Signed int8 quantization with scaling (advanced, opt-in).
+ Int8
+}
diff --git a/src/Helpers/DeserializationHelper.cs b/src/Helpers/DeserializationHelper.cs
index 78777efa7..a8f633e6a 100644
--- a/src/Helpers/DeserializationHelper.cs
+++ b/src/Helpers/DeserializationHelper.cs
@@ -48,6 +48,13 @@ static DeserializationHelper()
///
public static ILayer CreateLayerFromType(string layerType, int[] inputShape, int[] outputShape, Dictionary? additionalParams = null)
{
+ // Allow layerType to contain serialized constructor metadata, e.g. "MultiHeadAttentionLayer;HeadCount=8".
+ if (TryParseLayerTypeIdentifier(layerType, out var parsedTypeName, out var parsedParams))
+ {
+ layerType = parsedTypeName;
+ additionalParams = MergeParams(additionalParams, parsedParams);
+ }
+
if (!LayerTypes.TryGetValue(layerType, out Type? openGenericType))
{
throw new NotSupportedException($"Layer type {layerType} is not supported for deserialization.");
@@ -88,15 +95,159 @@ public static ILayer CreateLayerFromType(string layerType, int[] inputShap
if (genericDef == typeof(DenseLayer<>))
{
- // DenseLayer(int inputSize, int outputSize, IActivationFunction? activationFunction = null)
- // Use specific constructor to avoid ambiguity with vector activation constructor
+ instance = CreateDenseLayer(type, inputShape, outputShape, additionalParams);
+ }
+ else if (genericDef == typeof(InputLayer<>))
+ {
+ // InputLayer(int inputSize)
+ var ctor = type.GetConstructor([typeof(int)]);
+ if (ctor is null)
+ {
+ throw new InvalidOperationException("Cannot find InputLayer constructor with (int).");
+ }
+
+ instance = ctor.Invoke([inputShape[0]]);
+ }
+ else if (genericDef == typeof(ReshapeLayer<>))
+ {
+ // ReshapeLayer(int[] inputShape, int[] outputShape)
+ var ctor = type.GetConstructor([typeof(int[]), typeof(int[])]);
+ if (ctor is null)
+ {
+ throw new InvalidOperationException("Cannot find ReshapeLayer constructor with (int[], int[]).");
+ }
+
+ instance = ctor.Invoke([inputShape, outputShape]);
+ }
+ else if (genericDef == typeof(EmbeddingLayer<>))
+ {
+ // EmbeddingLayer(int vocabularySize, int embeddingDimension)
+ int embeddingDim = outputShape[0];
+ int vocabSize = TryGetInt(additionalParams, "VocabularySize")
+ ?? TryGetInt(additionalParams, "VocabSize")
+ ?? throw new InvalidOperationException("EmbeddingLayer requires VocabularySize metadata for deserialization.");
+
+ var ctor = type.GetConstructor([typeof(int), typeof(int)]);
+ if (ctor is null)
+ {
+ throw new InvalidOperationException("Cannot find EmbeddingLayer constructor with (int, int).");
+ }
+ instance = ctor.Invoke([vocabSize, embeddingDim]);
+ }
+ else if (genericDef == typeof(PositionalEncodingLayer<>))
+ {
+ // PositionalEncodingLayer(int maxSequenceLength, int embeddingSize)
+ if (inputShape.Length < 2)
+ {
+ throw new InvalidOperationException("PositionalEncodingLayer requires input shape [maxSequenceLength, embeddingSize].");
+ }
+
+ int maxSeqLen = inputShape[0];
+ int embDim = inputShape[1];
+
+ var ctor = type.GetConstructor([typeof(int), typeof(int)]);
+ if (ctor is null)
+ {
+ throw new InvalidOperationException("Cannot find PositionalEncodingLayer constructor with (int, int).");
+ }
+ instance = ctor.Invoke([maxSeqLen, embDim]);
+ }
+ else if (genericDef == typeof(DropoutLayer<>))
+ {
+ // DropoutLayer(double dropoutRate = 0.5)
+ double rate = TryGetDouble(additionalParams, "DropoutRate") ?? 0.5;
+ var ctor = type.GetConstructor([typeof(double)]);
+ if (ctor is null)
+ {
+ throw new InvalidOperationException("Cannot find DropoutLayer constructor with (double).");
+ }
+ instance = ctor.Invoke([rate]);
+ }
+ else if (genericDef == typeof(LayerNormalizationLayer<>))
+ {
+ // LayerNormalizationLayer(int featureSize, double epsilon = ...)
+ int featureSize = inputShape[0];
+ double epsilon = TryGetDouble(additionalParams, "Epsilon") ?? NumericalStabilityHelper.LargeEpsilon;
+ var ctor = type.GetConstructor([typeof(int), typeof(double)]);
+ if (ctor is null)
+ {
+ throw new InvalidOperationException("Cannot find LayerNormalizationLayer constructor with (int, double).");
+ }
+ instance = ctor.Invoke([featureSize, epsilon]);
+ }
+ else if (genericDef == typeof(MultiHeadAttentionLayer<>))
+ {
+ instance = CreateMultiHeadAttentionLayer(type, inputShape, additionalParams);
+ }
+ else if (genericDef == typeof(SelfAttentionLayer<>))
+ {
+ // SelfAttentionLayer(int sequenceLength, int embeddingDimension, int headCount = 8, IActivationFunction? = null)
+ if (inputShape.Length < 2)
+ {
+ throw new InvalidOperationException("SelfAttentionLayer requires input shape [sequenceLength, embeddingDimension].");
+ }
+
+ int seqLen = inputShape[0];
+ int embDim = inputShape[1];
+ int headCount = TryGetInt(additionalParams, "HeadCount") ?? ResolveDefaultHeadCount(embDim);
+
+ var activationFuncType = typeof(IActivationFunction<>).MakeGenericType(typeof(T));
+ var ctor = type.GetConstructor([typeof(int), typeof(int), typeof(int), activationFuncType]);
+ if (ctor is null)
+ {
+ throw new InvalidOperationException("Cannot find SelfAttentionLayer constructor with (int, int, int, IActivationFunction).");
+ }
+ object? activation = TryCreateActivationInstance(additionalParams, "ScalarActivationType", activationFuncType);
+ instance = ctor.Invoke([seqLen, embDim, headCount, activation]);
+ }
+ else if (genericDef == typeof(AttentionLayer<>))
+ {
+ // AttentionLayer(int inputSize, int attentionSize, IActivationFunction? = null)
+ int inputSize = inputShape[0];
+ int attentionSize = outputShape[0];
+
var activationFuncType = typeof(IActivationFunction<>).MakeGenericType(typeof(T));
var ctor = type.GetConstructor([typeof(int), typeof(int), activationFuncType]);
if (ctor is null)
{
- throw new InvalidOperationException($"Cannot find DenseLayer constructor with (int, int, IActivationFunction).");
+ throw new InvalidOperationException("Cannot find AttentionLayer constructor with (int, int, IActivationFunction).");
}
- instance = ctor.Invoke([inputShape[0], outputShape[0], null]);
+ object? activation = TryCreateActivationInstance(additionalParams, "ScalarActivationType", activationFuncType);
+ instance = ctor.Invoke([inputSize, attentionSize, activation]);
+ }
+ else if (genericDef == typeof(GraphAttentionLayer<>))
+ {
+ // GraphAttentionLayer(int inputFeatures, int outputFeatures, int numHeads = 1, double alpha = 0.2, double dropoutRate = 0.0, IActivationFunction? = null)
+ int inputFeatures = inputShape[0];
+ int outputFeatures = outputShape[0];
+ int numHeads = TryGetInt(additionalParams, "NumHeads") ?? 1;
+ double alpha = TryGetDouble(additionalParams, "Alpha") ?? 0.2;
+ double dropout = TryGetDouble(additionalParams, "DropoutRate") ?? 0.0;
+
+ var activationFuncType = typeof(IActivationFunction<>).MakeGenericType(typeof(T));
+ var ctor = type.GetConstructor([typeof(int), typeof(int), typeof(int), typeof(double), typeof(double), activationFuncType]);
+ if (ctor is null)
+ {
+ throw new InvalidOperationException("Cannot find GraphAttentionLayer constructor with expected signature.");
+ }
+ object? activation = TryCreateActivationInstance(additionalParams, "ScalarActivationType", activationFuncType);
+ instance = ctor.Invoke([inputFeatures, outputFeatures, numHeads, alpha, dropout, activation]);
+ }
+ else if (genericDef == typeof(AiDotNet.NeuralNetworks.Attention.FlashAttentionLayer<>))
+ {
+ instance = CreateFlashAttentionLayer(type, inputShape, additionalParams);
+ }
+ else if (genericDef == typeof(AiDotNet.Inference.CachedMultiHeadAttention<>))
+ {
+ instance = CreateCachedMultiHeadAttention(type, inputShape, additionalParams);
+ }
+ else if (genericDef == typeof(AiDotNet.Inference.PagedCachedMultiHeadAttention<>))
+ {
+ instance = CreatePagedCachedMultiHeadAttention(type, inputShape, additionalParams);
+ }
+ else if (genericDef == typeof(AiDotNet.LoRA.Adapters.MultiLoRAAdapter<>))
+ {
+ instance = CreateMultiLoRAAdapter(type, inputShape, outputShape, additionalParams);
}
else if (genericDef == typeof(ConvolutionalLayer<>))
{
@@ -139,42 +290,442 @@ public static ILayer CreateLayerFromType(string layerType, int[] inputShap
}
else if (genericDef == typeof(ActivationLayer<>))
{
- // ActivationLayer(int[] inputShape, IActivationFunction activationFunction)
+ instance = CreateActivationLayer(type, inputShape, additionalParams);
+ }
+ else
+ {
+ // Default: pass inputShape as first parameter
+ var ctor = type.GetConstructor([typeof(int[])]);
+ if (ctor is null)
+ {
+ throw new NotSupportedException(
+ $"Layer type {layerType} is not supported for deserialization (no known constructor found).");
+ }
+
+ instance = ctor.Invoke([inputShape]);
+ }
+ if (instance == null)
+ {
+ throw new InvalidOperationException($"Failed to create instance of layer type {layerType}.");
+ }
+
+ return (ILayer)instance;
+ }
+
+ private static object CreateDenseLayer(Type type, int[] inputShape, int[] outputShape, Dictionary? additionalParams)
+ {
+ // DenseLayer(int inputSize, int outputSize, IActivationFunction? activationFunction = null)
+ // Use specific constructor to avoid ambiguity with vector activation constructor.
+ var activationFuncType = typeof(IActivationFunction<>).MakeGenericType(typeof(T));
+ var ctor = type.GetConstructor([typeof(int), typeof(int), activationFuncType]);
+ if (ctor is null)
+ {
+ throw new InvalidOperationException("Cannot find DenseLayer constructor with (int, int, IActivationFunction).");
+ }
+
+ object? activation = TryCreateActivationInstance(additionalParams, "ScalarActivationType", activationFuncType);
+ return ctor.Invoke([inputShape[0], outputShape[0], activation]);
+ }
+
+ private static object CreateMultiHeadAttentionLayer(Type type, int[] inputShape, Dictionary? additionalParams)
+ {
+ // MultiHeadAttentionLayer(int sequenceLength, int embeddingDimension, int headCount, IActivationFunction? activationFunction = null)
+ if (inputShape.Length < 2)
+ {
+ throw new InvalidOperationException("MultiHeadAttentionLayer requires input shape [sequenceLength, embeddingDimension].");
+ }
+
+ int seqLen = inputShape[0];
+ int embDim = inputShape[1];
+ int headCount = TryGetInt(additionalParams, "HeadCount") ?? ResolveDefaultHeadCount(embDim);
+
+ var activationFuncType = typeof(IActivationFunction<>).MakeGenericType(typeof(T));
+ var ctor = type.GetConstructor([typeof(int), typeof(int), typeof(int), activationFuncType]);
+ if (ctor is null)
+ {
+ throw new InvalidOperationException("Cannot find MultiHeadAttentionLayer constructor with (int, int, int, IActivationFunction).");
+ }
+
+ object? activation = TryCreateActivationInstance(additionalParams, "ScalarActivationType", activationFuncType);
+ return ctor.Invoke([seqLen, embDim, headCount, activation]);
+ }
+
+ private static object CreateFlashAttentionLayer(Type type, int[] inputShape, Dictionary? additionalParams)
+ {
+ // FlashAttentionLayer(int sequenceLength, int embeddingDimension, int headCount, FlashAttentionConfig config, IActivationFunction? activationFunction = null)
+ if (inputShape.Length < 2)
+ {
+ throw new InvalidOperationException("FlashAttentionLayer requires input shape [sequenceLength, embeddingDimension].");
+ }
+
+ int seqLen = inputShape[0];
+ int embDim = inputShape[1];
+ int headCount = TryGetInt(additionalParams, "HeadCount") ?? ResolveDefaultHeadCount(embDim);
+ bool useCausal = TryGetBool(additionalParams, "UseCausalMask") ?? false;
+
+ var flashConfig = AiDotNet.NeuralNetworks.Attention.FlashAttentionConfig.Default;
+ flashConfig.UseCausalMask = useCausal;
+
+ var activationFuncType = typeof(IActivationFunction<>).MakeGenericType(typeof(T));
+ var ctor = type.GetConstructor([typeof(int), typeof(int), typeof(int), typeof(AiDotNet.NeuralNetworks.Attention.FlashAttentionConfig), activationFuncType]);
+ if (ctor is null)
+ {
+ throw new InvalidOperationException("Cannot find FlashAttentionLayer constructor with expected signature.");
+ }
+
+ object? activation = TryCreateActivationInstance(additionalParams, "ScalarActivationType", activationFuncType);
+ return ctor.Invoke([seqLen, embDim, headCount, flashConfig, activation]);
+ }
+
+ private static object CreateCachedMultiHeadAttention(Type type, int[] inputShape, Dictionary? additionalParams)
+ {
+ // CachedMultiHeadAttention(int sequenceLength, int embeddingDimension, int headCount, bool useFlashAttention, int layerIndex, bool useCausalMask, IActivationFunction? activationFunction = null)
+ if (inputShape.Length < 2)
+ {
+ throw new InvalidOperationException("CachedMultiHeadAttention requires input shape [sequenceLength, embeddingDimension].");
+ }
+
+ int seqLen = inputShape[0];
+ int embDim = inputShape[1];
+ int headCount = TryGetInt(additionalParams, "HeadCount") ?? ResolveDefaultHeadCount(embDim);
+ bool useFlash = TryGetBool(additionalParams, "UseFlashAttention") ?? true;
+ bool useCausal = TryGetBool(additionalParams, "UseCausalMask") ?? true;
+
+ var activationFuncType = typeof(IActivationFunction<>).MakeGenericType(typeof(T));
+ var ctor = type.GetConstructor([typeof(int), typeof(int), typeof(int), typeof(bool), typeof(int), typeof(bool), activationFuncType]);
+ if (ctor is null)
+ {
+ throw new InvalidOperationException("Cannot find CachedMultiHeadAttention constructor with expected signature.");
+ }
+
+ object? activation = TryCreateActivationInstance(additionalParams, "ScalarActivationType", activationFuncType);
+ return ctor.Invoke([seqLen, embDim, headCount, useFlash, 0, useCausal, activation]);
+ }
+
+ private static object CreatePagedCachedMultiHeadAttention(Type type, int[] inputShape, Dictionary? additionalParams)
+ {
+ // PagedCachedMultiHeadAttention(int sequenceLength, int embeddingDimension, int headCount, bool useCausalMask, IActivationFunction? activationFunction = null)
+ if (inputShape.Length < 2)
+ {
+ throw new InvalidOperationException("PagedCachedMultiHeadAttention requires input shape [sequenceLength, embeddingDimension].");
+ }
+
+ int seqLen = inputShape[0];
+ int embDim = inputShape[1];
+ int headCount = TryGetInt(additionalParams, "HeadCount") ?? ResolveDefaultHeadCount(embDim);
+ bool useCausal = TryGetBool(additionalParams, "UseCausalMask") ?? true;
+
+ var activationFuncType = typeof(IActivationFunction<>).MakeGenericType(typeof(T));
+ var ctor = type.GetConstructor([typeof(int), typeof(int), typeof(int), typeof(bool), activationFuncType]);
+ if (ctor is null)
+ {
+ throw new InvalidOperationException("Cannot find PagedCachedMultiHeadAttention constructor with expected signature.");
+ }
+
+ object? activation = TryCreateActivationInstance(additionalParams, "ScalarActivationType", activationFuncType);
+ return ctor.Invoke([seqLen, embDim, headCount, useCausal, activation]);
+ }
+
+ private static object CreateMultiLoRAAdapter(Type type, int[] inputShape, int[] outputShape, Dictionary? additionalParams)
+ {
+ // MultiLoRAAdapter(ILayer baseLayer, string defaultTaskName, int defaultRank, double alpha, bool freezeBaseLayer)
+ bool freezeBaseLayer = TryGetBool(additionalParams, "FreezeBaseLayer") ?? true;
+
+ string? encodedBaseLayerId = additionalParams?.TryGetValue("BaseLayerTypeId", out var baseType) == true ? baseType as string : null;
+ string baseLayerIdentifier = !string.IsNullOrWhiteSpace(encodedBaseLayerId)
+ ? Uri.UnescapeDataString(encodedBaseLayerId)
+ : "DenseLayer`1";
+
+ var baseLayer = CreateLayerFromType(baseLayerIdentifier, inputShape, outputShape, null);
+
+ static string[] ParseList(string? raw)
+ {
+ if (string.IsNullOrWhiteSpace(raw)) return Array.Empty();
+ return raw!.Split(new[] { '|' }, StringSplitOptions.RemoveEmptyEntries);
+ }
+
+ static int[] ParseIntList(string? raw)
+ {
+ var parts = ParseList(raw);
+ var result = new int[parts.Length];
+ for (int i = 0; i < parts.Length; i++)
+ {
+ result[i] = int.TryParse(parts[i], System.Globalization.NumberStyles.Integer, System.Globalization.CultureInfo.InvariantCulture, out var v) ? v : 1;
+ }
+ return result;
+ }
+
+ static double[] ParseDoubleList(string? raw)
+ {
+ var parts = ParseList(raw);
+ var result = new double[parts.Length];
+ for (int i = 0; i < parts.Length; i++)
+ {
+ result[i] = double.TryParse(parts[i], System.Globalization.NumberStyles.Float, System.Globalization.CultureInfo.InvariantCulture, out var v) ? v : -1;
+ }
+ return result;
+ }
+
+ string? tasksRaw = additionalParams?.TryGetValue("Tasks", out var tasksObj) == true ? tasksObj as string : null;
+ var encodedTasks = ParseList(tasksRaw);
+ if (encodedTasks.Length == 0)
+ {
+ encodedTasks = ["default"];
+ }
+
+ var tasks = encodedTasks.Select(Uri.UnescapeDataString).ToArray();
+ var ranks = ParseIntList(additionalParams?.TryGetValue("TaskRanks", out var ranksObj) == true ? ranksObj as string : null);
+ var alphas = ParseDoubleList(additionalParams?.TryGetValue("TaskAlphas", out var alphasObj) == true ? alphasObj as string : null);
+
+ int defaultRank = ranks.Length > 0 ? ranks[0] : 1;
+ double defaultAlpha = alphas.Length > 0 ? alphas[0] : -1;
+
+ var iLayerType = typeof(ILayer<>).MakeGenericType(typeof(T));
+ var ctor = type.GetConstructor([iLayerType, typeof(string), typeof(int), typeof(double), typeof(bool)]);
+ if (ctor is null)
+ {
+ throw new InvalidOperationException("Cannot find MultiLoRAAdapter constructor with expected signature.");
+ }
+
+ var instance = ctor.Invoke([baseLayer, tasks[0], defaultRank, defaultAlpha, freezeBaseLayer]);
+ var multi = (AiDotNet.LoRA.Adapters.MultiLoRAAdapter)instance;
+
+ for (int taskIndex = 1; taskIndex < tasks.Length; taskIndex++)
+ {
+ int rank = taskIndex < ranks.Length ? ranks[taskIndex] : defaultRank;
+ double alpha = taskIndex < alphas.Length ? alphas[taskIndex] : -1;
+ multi.AddTask(tasks[taskIndex], rank, alpha);
+ }
+
+ if (additionalParams?.TryGetValue("CurrentTask", out var currentTaskObj) == true &&
+ currentTaskObj is string currentTaskEncoded)
+ {
+ string currentTask = Uri.UnescapeDataString(currentTaskEncoded);
+ if (!string.IsNullOrWhiteSpace(currentTask))
+ {
+ multi.SetCurrentTask(currentTask);
+ }
+ }
+
+ return instance;
+ }
+
+ private static object CreateActivationLayer(Type type, int[] inputShape, Dictionary? additionalParams)
+ {
+ // ActivationLayer(int[] inputShape, IActivationFunction activationFunction)
+ var scalarActivationType = typeof(IActivationFunction<>).MakeGenericType(typeof(T));
+ var vectorActivationType = typeof(IVectorActivationFunction<>).MakeGenericType(typeof(T));
+
+ object? vectorActivation = TryCreateActivationInstance(additionalParams, "VectorActivationType", vectorActivationType);
+ object? scalarActivation = TryCreateActivationInstance(additionalParams, "ScalarActivationType", scalarActivationType);
+
+ object? activationFunction = vectorActivation ?? scalarActivation;
+
+ if (activationFunction == null)
+ {
+ // Back-compat fallback: use enum if available, otherwise default ReLU.
ActivationFunction activationFunctionEnum = additionalParams?.TryGetValue("ActivationFunction", out var af) == true
? (ActivationFunction)af : ActivationFunction.ReLU;
- // Use ActivationFunctionFactory to create the IActivationFunction from enum
+
var factoryType = typeof(ActivationFunctionFactory<>).MakeGenericType(typeof(T));
var createMethod = factoryType.GetMethod("CreateActivationFunction", BindingFlags.Public | BindingFlags.Static);
if (createMethod is null)
{
throw new InvalidOperationException("Cannot find ActivationFunctionFactory.CreateActivationFunction method.");
}
- object? activationFunction = createMethod.Invoke(null, [activationFunctionEnum]);
- if (activationFunction is null)
+
+ activationFunction = createMethod.Invoke(null, [activationFunctionEnum]);
+ }
+
+ if (activationFunction == null)
+ {
+ throw new InvalidOperationException("Failed to create activation function for ActivationLayer.");
+ }
+
+ if (vectorActivationType.IsInstanceOfType(activationFunction))
+ {
+ var ctor = type.GetConstructor([typeof(int[]), vectorActivationType]);
+ if (ctor is null)
{
- throw new InvalidOperationException($"Failed to create activation function for {activationFunctionEnum}.");
+ throw new InvalidOperationException("Cannot find ActivationLayer constructor with (int[], IVectorActivationFunction).");
}
+ return ctor.Invoke([inputShape, activationFunction]);
+ }
- // Use specific constructor to avoid ambiguity with vector activation constructor
- var activationFuncType = typeof(IActivationFunction<>).MakeGenericType(typeof(T));
- var ctor = type.GetConstructor([typeof(int[]), activationFuncType]);
- if (ctor is null)
+ var scalarCtor = type.GetConstructor([typeof(int[]), scalarActivationType]);
+ if (scalarCtor is null)
+ {
+ throw new InvalidOperationException("Cannot find ActivationLayer constructor with (int[], IActivationFunction).");
+ }
+ return scalarCtor.Invoke([inputShape, activationFunction]);
+ }
+
+ private static bool TryParseLayerTypeIdentifier(
+ string identifier,
+ out string typeName,
+ out Dictionary parameters)
+ {
+ typeName = identifier;
+ parameters = new Dictionary(StringComparer.Ordinal);
+
+ int sep = identifier.IndexOf(';');
+ if (sep < 0)
+ {
+ return false;
+ }
+
+ typeName = identifier.Substring(0, sep);
+ var parts = identifier.Substring(sep + 1).Split(new[] { ';' }, StringSplitOptions.RemoveEmptyEntries);
+ foreach (var part in parts)
+ {
+ int eq = part.IndexOf('=');
+ if (eq <= 0 || eq == part.Length - 1)
+ {
+ continue;
+ }
+
+ string key = part.Substring(0, eq);
+ string value = part.Substring(eq + 1);
+
+ if (int.TryParse(value, System.Globalization.NumberStyles.Integer, System.Globalization.CultureInfo.InvariantCulture, out int i))
+ {
+ parameters[key] = i;
+ }
+ else if (long.TryParse(value, System.Globalization.NumberStyles.Integer, System.Globalization.CultureInfo.InvariantCulture, out long l))
{
- throw new InvalidOperationException($"Cannot find ActivationLayer constructor with (int[], IActivationFunction).");
+ parameters[key] = l;
+ }
+ else if (double.TryParse(value, System.Globalization.NumberStyles.Float, System.Globalization.CultureInfo.InvariantCulture, out double d))
+ {
+ parameters[key] = d;
+ }
+ else if (bool.TryParse(value, out bool b))
+ {
+ parameters[key] = b;
+ }
+ else
+ {
+ parameters[key] = value;
}
- instance = ctor.Invoke([inputShape, activationFunction]);
}
- else
+
+ return true;
+ }
+
+ private static Dictionary MergeParams(
+ Dictionary? original,
+ Dictionary parsed)
+ {
+ if (original == null || original.Count == 0)
{
- // Default: pass inputShape as first parameter
- instance = Activator.CreateInstance(type, [inputShape]);
+ return parsed;
}
- if (instance == null)
+
+ foreach (var kvp in parsed)
{
- throw new InvalidOperationException($"Failed to create instance of layer type {layerType}.");
+ original[kvp.Key] = kvp.Value;
}
- return (ILayer)instance;
+ return original;
+ }
+
+ private static int? TryGetInt(Dictionary? parameters, string key)
+ {
+ if (parameters != null && parameters.TryGetValue(key, out var value) && value != null)
+ {
+ if (value is int i)
+ return i;
+ if (value is long l && l >= int.MinValue && l <= int.MaxValue)
+ return (int)l;
+ if (int.TryParse(value.ToString() ?? string.Empty, out int parsed))
+ return parsed;
+ }
+ return null;
+ }
+
+ private static double? TryGetDouble(Dictionary? parameters, string key)
+ {
+ if (parameters != null && parameters.TryGetValue(key, out var value) && value != null)
+ {
+ if (value is double d)
+ return d;
+ if (double.TryParse(value.ToString() ?? string.Empty, System.Globalization.NumberStyles.Float, System.Globalization.CultureInfo.InvariantCulture, out double parsed))
+ return parsed;
+ }
+ return null;
+ }
+
+ private static bool? TryGetBool(Dictionary? parameters, string key)
+ {
+ if (parameters != null && parameters.TryGetValue(key, out var value) && value != null)
+ {
+ if (value is bool b)
+ return b;
+ if (bool.TryParse(value.ToString() ?? string.Empty, out bool parsed))
+ return parsed;
+ }
+ return null;
+ }
+
+ private static object? TryCreateActivationInstance(
+ Dictionary? parameters,
+ string key,
+ Type expectedInterface)
+ {
+ if (parameters == null || !parameters.TryGetValue(key, out var value) || value == null)
+ {
+ return null;
+ }
+
+ string? typeName = value as string ?? value.ToString() ?? string.Empty;
+ if (string.IsNullOrWhiteSpace(typeName))
+ {
+ return null;
+ }
+
+ var type = Type.GetType(typeName, throwOnError: false);
+ if (type == null)
+ {
+ return null;
+ }
+
+ try
+ {
+ var instance = Activator.CreateInstance(type);
+ if (instance == null)
+ {
+ return null;
+ }
+
+ return expectedInterface.IsInstanceOfType(instance) ? instance : null;
+ }
+ catch (MissingMethodException)
+ {
+ return null;
+ }
+ catch (TargetInvocationException ex) when (ex.InnerException is MissingMethodException)
+ {
+ return null;
+ }
+ catch (Exception ex)
+ {
+ // Best-effort: deserialization should not throw if an optional activation cannot be created.
+ System.Diagnostics.Debug.WriteLine($"Unexpected error deserializing activation {typeName}: {ex.Message}");
+ return null;
+ }
+ }
+
+ private static int ResolveDefaultHeadCount(int embeddingDimension)
+ {
+ // Conservative but practical default: prefer common head counts if divisible, otherwise fall back to 1.
+ foreach (var candidate in new[] { 8, 4, 16, 12, 6, 2, 1 })
+ {
+ if (candidate > 0 && embeddingDimension % candidate == 0)
+ {
+ return candidate;
+ }
+ }
+ return 1;
}
///
@@ -203,7 +754,24 @@ public static ILayer CreateLayerFromType(string layerType, int[] inputShap
throw new InvalidOperationException($"Type {typeName} does not implement interface {typeof(TInterface).Name}");
}
- return (TInterface?)Activator.CreateInstance(type)
- ?? throw new InvalidOperationException($"Failed to create instance of type {typeName}");
+ try
+ {
+ return (TInterface?)Activator.CreateInstance(type);
+ }
+ catch (MissingMethodException)
+ {
+ // Some implementations (e.g., optimizers) require constructor arguments.
+ // Treat them as optional on deserialization and let callers provide sensible defaults.
+ return null;
+ }
+ catch (TargetInvocationException ex) when (ex.InnerException is MissingMethodException)
+ {
+ // Same as above: no parameterless ctor available.
+ return null;
+ }
+ catch (Exception ex)
+ {
+ throw new InvalidOperationException($"Failed to instantiate type {typeName}", ex);
+ }
}
-}
\ No newline at end of file
+}
diff --git a/src/Helpers/InferenceDiagnostics.cs b/src/Helpers/InferenceDiagnostics.cs
new file mode 100644
index 000000000..11b501a2a
--- /dev/null
+++ b/src/Helpers/InferenceDiagnostics.cs
@@ -0,0 +1,89 @@
+using System;
+using System.Collections.Concurrent;
+
+namespace AiDotNet.Helpers;
+
+///
+/// Internal diagnostics for inference decisions (non-user-facing).
+/// Enable by setting env var AIDOTNET_DIAGNOSTICS=1.
+///
+internal static class InferenceDiagnostics
+{
+ private const int MaxEntries = 1024;
+
+ private static readonly ConcurrentQueue Entries = new();
+
+ private static bool IsEnabled()
+ {
+ var value = Environment.GetEnvironmentVariable("AIDOTNET_DIAGNOSTICS");
+ return string.Equals(value, "1", StringComparison.OrdinalIgnoreCase) ||
+ string.Equals(value, "true", StringComparison.OrdinalIgnoreCase);
+ }
+
+ internal static void RecordDecision(string area, string feature, bool enabled, string reason)
+ {
+ if (!IsEnabled())
+ return;
+
+ Entries.Enqueue(new InferenceDiagnosticEntry(
+ TimestampUtc: DateTime.UtcNow,
+ Area: area ?? string.Empty,
+ Feature: feature ?? string.Empty,
+ Enabled: enabled,
+ Reason: reason ?? string.Empty,
+ ExceptionType: null,
+ ExceptionMessage: null));
+
+ TrimIfNeeded();
+ }
+
+ internal static void RecordException(string area, string feature, Exception ex, string reason)
+ {
+ if (!IsEnabled())
+ return;
+
+ Entries.Enqueue(new InferenceDiagnosticEntry(
+ TimestampUtc: DateTime.UtcNow,
+ Area: area ?? string.Empty,
+ Feature: feature ?? string.Empty,
+ Enabled: false,
+ Reason: reason ?? string.Empty,
+ ExceptionType: ex.GetType().FullName ?? ex.GetType().Name,
+ ExceptionMessage: ex.Message));
+
+ TrimIfNeeded();
+ }
+
+ // Intentionally internal-only: serving can use InternalsVisibleTo to read these if needed later.
+ internal static InferenceDiagnosticEntry[] Snapshot()
+ {
+ if (!IsEnabled())
+ return Array.Empty();
+
+ return Entries.ToArray();
+ }
+
+ internal static void Clear()
+ {
+ while (Entries.TryDequeue(out _))
+ {
+ }
+ }
+
+ private static void TrimIfNeeded()
+ {
+ // Best-effort: bound memory use when diagnostics are enabled.
+ while (Entries.Count > MaxEntries && Entries.TryDequeue(out _))
+ {
+ }
+ }
+
+ internal readonly record struct InferenceDiagnosticEntry(
+ DateTime TimestampUtc,
+ string Area,
+ string Feature,
+ bool Enabled,
+ string Reason,
+ string? ExceptionType,
+ string? ExceptionMessage);
+}
diff --git a/src/Inference/CachedMultiHeadAttention.cs b/src/Inference/CachedMultiHeadAttention.cs
index 59042aab2..3526b6e06 100644
--- a/src/Inference/CachedMultiHeadAttention.cs
+++ b/src/Inference/CachedMultiHeadAttention.cs
@@ -34,12 +34,13 @@ namespace AiDotNet.Inference;
///
///
/// The numeric type for computations.
-public class CachedMultiHeadAttention : LayerBase
+internal class CachedMultiHeadAttention : LayerBase, ILayerSerializationMetadata
{
private readonly int _headCount;
private readonly int _headDimension;
private readonly int _embeddingDimension;
private readonly bool _useFlashAttention;
+ private readonly bool _useCausalMask;
// Projection weights
private Matrix _queryWeights;
@@ -92,6 +93,15 @@ public class CachedMultiHeadAttention : LayerBase
///
public bool UsesFlashAttention => _useFlashAttention;
+ ///
+ /// Gets whether causal masking is enabled for attention.
+ ///
+ ///
+ /// Causal masking is required for autoregressive decoding (GPT-style), where each token may only attend
+ /// to itself and previous tokens. Disable for bidirectional attention (BERT-style) and most encoders.
+ ///
+ public bool UsesCausalMask => _useCausalMask;
+
///
/// Gets or sets the KV-Cache. Must be set before inference.
///
@@ -118,15 +128,20 @@ public int LayerIndex
/// Number of attention heads.
/// Whether to use Flash Attention algorithm.
/// Index of this layer in the transformer (for cache access).
+ /// Whether to apply causal masking (required for autoregressive decoding).
+ /// Optional activation function (defaults to identity).
public CachedMultiHeadAttention(
int sequenceLength,
int embeddingDimension,
int headCount,
bool useFlashAttention = true,
- int layerIndex = 0)
+ int layerIndex = 0,
+ bool useCausalMask = true,
+ IActivationFunction? activationFunction = null)
: base(
[sequenceLength, embeddingDimension],
- [sequenceLength, embeddingDimension])
+ [sequenceLength, embeddingDimension],
+ activationFunction ?? new IdentityActivation())
{
if (embeddingDimension % headCount != 0)
{
@@ -139,6 +154,7 @@ public CachedMultiHeadAttention(
_embeddingDimension = embeddingDimension;
_useFlashAttention = useFlashAttention;
_layerIndex = layerIndex;
+ _useCausalMask = useCausalMask;
// Initialize projection weights
_queryWeights = new Matrix(embeddingDimension, embeddingDimension);
@@ -230,13 +246,18 @@ private Tensor ForwardWithCache(Tensor input)
Tensor attentionOutput;
if (_useFlashAttention)
{
- var config = new FlashAttentionConfig { UseCausalMask = true };
- var (flashOutput, _) = FlashAttention.Forward(queries, keys, values, config);
+ var config = FlashAttentionConfig.Default;
+ config.UseCausalMask = _useCausalMask;
+
+ int seqLenKV = keys.Shape[2];
+ int seqLenQ = queries.Shape[2];
+ int queryOffset = Math.Max(0, seqLenKV - seqLenQ);
+ var (flashOutput, _) = FlashAttention.Forward(queries, keys, values, config, queryOffset: queryOffset);
attentionOutput = flashOutput;
}
else
{
- attentionOutput = StandardAttention(queries, keys, values, useCausalMask: true);
+ attentionOutput = StandardAttention(queries, keys, values, useCausalMask: _useCausalMask);
}
// Reshape back to [batch, seq, embDim]
@@ -244,9 +265,9 @@ private Tensor ForwardWithCache(Tensor input)
// Output projection
var output = attentionOutput.Multiply(_outputWeights).Add(_outputBias);
- _lastOutput = output;
+ _lastOutput = ApplyActivation(output);
- return output;
+ return _lastOutput;
}
///
@@ -272,12 +293,13 @@ private Tensor ForwardStandard(Tensor input)
if (_useFlashAttention)
{
var config = FlashAttentionConfig.Default;
+ config.UseCausalMask = _useCausalMask;
var (flashOutput, _) = FlashAttention.Forward(queries, keys, values, config);
attentionOutput = flashOutput;
}
else
{
- attentionOutput = StandardAttention(queries, keys, values, useCausalMask: false);
+ attentionOutput = StandardAttention(queries, keys, values, useCausalMask: _useCausalMask);
}
// Reshape back
@@ -285,9 +307,9 @@ private Tensor ForwardStandard(Tensor input)
// Output projection
var output = attentionOutput.Multiply(_outputWeights).Add(_outputBias);
- _lastOutput = output;
+ _lastOutput = ApplyActivation(output);
- return output;
+ return _lastOutput;
}
///
@@ -387,6 +409,8 @@ public override Tensor Backward(Tensor outputGradient)
throw new InvalidOperationException("Forward pass must be called before backward pass.");
}
+ var activationGradient = ApplyActivationDerivative(_lastOutput, outputGradient);
+
// Standard backward pass (no cache during training)
// Implementation similar to MultiHeadAttentionLayer
var inputGradient = new Tensor(_lastInput.Shape);
@@ -397,7 +421,7 @@ public override Tensor Backward(Tensor outputGradient)
_keyWeightsGradient = new Matrix(_keyWeights.Rows, _keyWeights.Columns);
_valueWeightsGradient = new Matrix(_valueWeights.Rows, _valueWeights.Columns);
_outputWeightsGradient = new Matrix(_outputWeights.Rows, _outputWeights.Columns);
- _outputBiasGradient = outputGradient.Sum([0, 1]).ToVector();
+ _outputBiasGradient = activationGradient.Sum([0, 1]).ToVector();
return inputGradient;
}
@@ -512,6 +536,7 @@ public override Dictionary GetDiagnostics()
diagnostics["HeadDimension"] = _headDimension.ToString();
diagnostics["InferenceMode"] = InferenceMode.ToString();
diagnostics["UsesFlashAttention"] = _useFlashAttention.ToString();
+ diagnostics["UsesCausalMask"] = _useCausalMask.ToString();
diagnostics["LayerIndex"] = _layerIndex.ToString();
diagnostics["CacheAttached"] = (_cache != null).ToString();
@@ -580,4 +605,14 @@ private Tensor MatrixToTensor(Matrix matrix)
}
return tensor;
}
+
+ Dictionary AiDotNet.NeuralNetworks.Layers.ILayerSerializationMetadata.GetSerializationMetadata()
+ {
+ return new Dictionary
+ {
+ ["HeadCount"] = _headCount.ToString(),
+ ["UseFlashAttention"] = _useFlashAttention.ToString(),
+ ["UseCausalMask"] = _useCausalMask.ToString()
+ };
+ }
}
diff --git a/src/Inference/InferenceOptimizer.cs b/src/Inference/InferenceOptimizer.cs
index 0095356be..a14af2ec6 100644
--- a/src/Inference/InferenceOptimizer.cs
+++ b/src/Inference/InferenceOptimizer.cs
@@ -1,8 +1,14 @@
using AiDotNet.Configuration;
using AiDotNet.NeuralNetworks;
+using AiDotNet.NeuralNetworks.Attention;
using AiDotNet.NeuralNetworks.Layers;
using AiDotNet.Inference.SpeculativeDecoding;
+using AiDotNet.Inference.PagedAttention;
+using AiDotNet.Inference.Quantization;
+using AiDotNet.Helpers;
+using AiDotNet.Tensors.Helpers;
using AiDotNet.Tensors.LinearAlgebra;
+using System.Threading;
namespace AiDotNet.Inference;
@@ -31,10 +37,15 @@ namespace AiDotNet.Inference;
///
///
/// The numeric type for computations.
-public class InferenceOptimizer
+internal class InferenceOptimizer
{
private readonly InferenceOptimizationConfig _config;
private KVCache? _kvCache;
+ private PagedKVCache? _pagedKVCache;
+ private PagedAttentionKernel? _pagedKernel;
+ private long? _pagedSequenceId;
+ private List>? _pagedAttentionLayers;
+ private static long s_nextPagedSequenceId = DateTime.UtcNow.Ticks;
private IDraftModel? _draftModel;
private SpeculativeDecoder? _speculativeDecoder;
private bool _isInitialized;
@@ -71,6 +82,57 @@ public InferenceOptimizer()
{
}
+ ///
+ /// Creates an inference-optimized model instance based on the current configuration.
+ ///
+ /// The neural network to optimize.
+ /// Whether to clone the model before applying layer-level rewrites.
+ /// The optimized model and whether any optimizations were applied.
+ ///
+ /// This method can apply stateless layer rewrites (e.g., MultiHeadAttention -> FlashAttentionLayer)
+ /// and then initialize stateful inference features (e.g., KV-cache) on the resulting model.
+ ///
+ public (NeuralNetworkBase OptimizedModel, bool AnyOptimizationsApplied) OptimizeForInference(
+ NeuralNetworkBase model,
+ bool cloneModel = true)
+ {
+ if (model == null)
+ throw new ArgumentNullException(nameof(model));
+
+ _config.Validate();
+
+ // Clone only when we might rewrite layers; otherwise keep original reference.
+ bool mayRewriteAttention = _config.EnableFlashAttention || _config.EnableKVCache || _config.EnableWeightOnlyQuantization;
+ var workingModel = model;
+ if (cloneModel && mayRewriteAttention && HasOptimizableAttentionLayers(model))
+ {
+ try
+ {
+ // NeuralNetworkBase.Clone performs a deep copy via serialization.
+ workingModel = (NeuralNetworkBase)model.Clone();
+ }
+ catch (Exception ex)
+ {
+ // Some layer types may not yet support serialization-based cloning.
+ // Do not mutate the user's original model; just skip optimizations.
+ Console.WriteLine($"Warning: model cloning failed for inference optimizations: {ex.Message}. Skipping inference optimizations for this model instance.");
+ InferenceDiagnostics.RecordException(
+ area: "InferenceOptimizer",
+ feature: "CloneForRewrite",
+ ex: ex,
+ reason: "Clone failed; skipping all inference optimizations to avoid mutating user model.");
+ return (model, false);
+ }
+ }
+
+ bool anyApplied = ApplyAttentionOptimizations(workingModel);
+ InferenceDiagnostics.RecordDecision("InferenceOptimizer", "AttentionRewrites", enabled: anyApplied, reason: anyApplied ? "Applied" : "NoApplicableLayersOrDisabled");
+ anyApplied |= ApplyWeightOnlyQuantization(workingModel);
+ anyApplied |= Initialize(workingModel);
+
+ return (workingModel, anyApplied);
+ }
+
///
/// Initializes inference optimizations for a neural network model.
///
@@ -91,12 +153,20 @@ public bool Initialize(NeuralNetworkBase model)
if (model == null)
throw new ArgumentNullException(nameof(model));
+ _config.Validate();
+
bool anyOptimizationsApplied = false;
// Find and configure attention layers for KV caching
if (_config.EnableKVCache)
{
- anyOptimizationsApplied |= InitializeKVCache(model);
+ anyOptimizationsApplied |= _config.EnablePagedKVCache
+ ? InitializePagedKVCache(model)
+ : InitializeKVCache(model);
+ }
+ else
+ {
+ InferenceDiagnostics.RecordDecision("InferenceOptimizer", "KVCache", enabled: false, reason: "DisabledByConfig");
}
// Initialize speculative decoding if enabled
@@ -104,6 +174,10 @@ public bool Initialize(NeuralNetworkBase model)
{
anyOptimizationsApplied |= InitializeSpeculativeDecoding(model);
}
+ else
+ {
+ InferenceDiagnostics.RecordDecision("InferenceOptimizer", "SpeculativeDecoding", enabled: false, reason: "DisabledByConfig");
+ }
_isInitialized = true;
return anyOptimizationsApplied;
@@ -138,7 +212,8 @@ private bool InitializeKVCache(NeuralNetworkBase model)
var firstLayer = attentionLayers[0];
int numHeads = firstLayer.HeadCount;
int headDim = firstLayer.HeadDimension;
- int maxSeqLen = EstimateMaxSequenceLength();
+ int numLayers = attentionLayers.Count;
+ int maxSeqLen = EstimateMaxSequenceLength(numLayers, numHeads, headDim);
// Create KV cache configuration
var cacheConfig = new KVCacheConfig
@@ -148,7 +223,12 @@ private bool InitializeKVCache(NeuralNetworkBase model)
HeadDimension = headDim,
MaxSequenceLength = maxSeqLen,
MaxBatchSize = _config.MaxBatchSize,
- PreAllocate = true
+ PreAllocate = true,
+ UseSlidingWindow = _config.UseSlidingWindowKVCache,
+ WindowSize = _config.UseSlidingWindowKVCache
+ ? Math.Min(_config.KVCacheWindowSize, maxSeqLen)
+ : 1024,
+ DataType = ResolveKVCacheDataType()
};
// Create and attach KV cache
@@ -164,21 +244,482 @@ private bool InitializeKVCache(NeuralNetworkBase model)
return true;
}
+ private CacheDataType ResolveKVCacheDataType()
+ {
+ bool fp16Capable = typeof(T) == typeof(float) || typeof(T) == typeof(double) || typeof(T) == typeof(Half);
+ bool int8Capable = fp16Capable;
+
+ CacheDataType resolved;
+ if (_config.KVCacheQuantization == KVCacheQuantizationMode.Int8 && int8Capable)
+ {
+ resolved = CacheDataType.Int8;
+ }
+ else
+ {
+ resolved = _config.KVCachePrecision switch
+ {
+ KVCachePrecisionMode.Float32 => CacheDataType.Float32,
+ KVCachePrecisionMode.Float16 => fp16Capable ? CacheDataType.Float16 : CacheDataType.Float32,
+ _ => fp16Capable ? CacheDataType.Float16 : CacheDataType.Float32
+ };
+ }
+
+ InferenceDiagnostics.RecordDecision(
+ area: "InferenceOptimizer",
+ feature: "KVCachePrecision",
+ enabled: resolved == CacheDataType.Float16 || resolved == CacheDataType.Int8,
+ reason: $"Precision={_config.KVCachePrecision};Quant={_config.KVCacheQuantization};Resolved={resolved};Type={typeof(T).Name}");
+
+ return resolved;
+ }
+
+ private bool InitializePagedKVCache(NeuralNetworkBase model)
+ {
+ var attentionLayers = new List>();
+ int layerIndex = 0;
+
+ foreach (var layer in model.Layers)
+ {
+ if (layer is PagedCachedMultiHeadAttention pagedAttention)
+ {
+ pagedAttention.LayerIndex = layerIndex;
+ attentionLayers.Add(pagedAttention);
+ layerIndex++;
+ }
+ }
+
+ if (attentionLayers.Count == 0)
+ {
+ // No paged attention layers present; fall back to contiguous cache if applicable.
+ return InitializeKVCache(model);
+ }
+
+ var firstLayer = attentionLayers[0];
+ int numHeads = firstLayer.HeadCount;
+ int headDim = firstLayer.HeadDimension;
+ int numLayers = attentionLayers.Count;
+
+ long availableBytes = (long)_config.KVCacheMaxSizeMB * 1024 * 1024;
+ int blockSize = _config.PagedKVCacheBlockSize;
+
+ _pagedKVCache = PagedKVCache.FromMemorySize(availableBytes, numLayers, numHeads, headDim, blockSize);
+ _pagedKernel = new PagedAttentionKernel(_pagedKVCache, new PagedAttentionConfig
+ {
+ NumHeads = numHeads,
+ HeadDimension = headDim,
+ BlockSize = blockSize,
+ MaxBatchSize = _config.MaxBatchSize
+ });
+
+ // Allocate a fresh sequence ID for this optimized model instance (one model == one sequence).
+ if (!TryAllocatePagedSequenceId(_pagedKVCache, initialTokens: 0, out long sequenceId))
+ {
+ InferenceDiagnostics.RecordDecision(
+ area: "InferenceOptimizer",
+ feature: "PagedKVCache",
+ enabled: false,
+ reason: "AllocateSequenceFailed(OutOfMemoryOrExhausted)");
+ _pagedKVCache = null;
+ _pagedKernel = null;
+ _pagedAttentionLayers = null;
+ _pagedSequenceId = null;
+ return false;
+ }
+
+ _pagedSequenceId = sequenceId;
+ _pagedAttentionLayers = attentionLayers;
+
+ foreach (var layer in attentionLayers)
+ {
+ layer.Kernel = _pagedKernel;
+ layer.SequenceId = sequenceId;
+ layer.InferenceMode = true;
+ }
+
+ return true;
+ }
+
+ private static bool TryAllocatePagedSequenceId(PagedKVCache cache, int initialTokens, out long sequenceId)
+ {
+ const int maxAttempts = 1024;
+ var spin = new SpinWait();
+
+ for (int attempt = 0; attempt < maxAttempts; attempt++)
+ {
+ sequenceId = Interlocked.Increment(ref s_nextPagedSequenceId);
+ if (cache.AllocateSequence(sequenceId, initialTokens))
+ {
+ return true;
+ }
+
+ spin.SpinOnce();
+ }
+
+ sequenceId = 0;
+ return false;
+ }
+
+ private static bool TryAllocatePagedSequenceId(PagedKVCache cache, long preferredId, int initialTokens, out long sequenceId)
+ {
+ if (cache.AllocateSequence(preferredId, initialTokens))
+ {
+ sequenceId = preferredId;
+ return true;
+ }
+
+ return TryAllocatePagedSequenceId(cache, initialTokens, out sequenceId);
+ }
+
+ private bool HasOptimizableAttentionLayers(NeuralNetworkBase model)
+ {
+ foreach (var layer in model.Layers)
+ {
+ if (layer is MultiHeadAttentionLayer || layer is FlashAttentionLayer || layer is SelfAttentionLayer)
+ return true;
+
+ if (_config.EnableWeightOnlyQuantization &&
+ typeof(T) == typeof(float) &&
+ layer is DenseLayer)
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ private bool ApplyAttentionOptimizations(NeuralNetworkBase model)
+ {
+ bool useCausalMask = ResolveCausalMask(model);
+ InferenceDiagnostics.RecordDecision(
+ area: "InferenceOptimizer",
+ feature: "CausalMask",
+ enabled: useCausalMask,
+ reason: _config.AttentionMasking == AttentionMaskingMode.Auto ? "Auto" : _config.AttentionMasking.ToString());
+
+ // KV-cache is only beneficial for incremental decoding patterns; default to enabling it only when causal masking applies.
+ bool enableKVCache = _config.EnableKVCache && useCausalMask;
+ bool enablePagedKVCache = enableKVCache && _config.EnablePagedKVCache;
+ bool enableFlashAttention = _config.EnableFlashAttention;
+
+ bool anyRewritten = false;
+
+ for (int i = 0; i < model.Layers.Count; i++)
+ {
+ var layer = model.Layers[i];
+
+ if (layer is SelfAttentionLayer selfAttention && (enableKVCache || enableFlashAttention))
+ {
+ var converted = TryConvertSelfAttentionToMultiHead(selfAttention);
+ if (converted != null)
+ {
+ model.Layers[i] = converted;
+ anyRewritten = true;
+
+ // Re-process this index under MultiHeadAttention rules.
+ i--;
+ continue;
+ }
+
+ InferenceDiagnostics.RecordDecision(
+ area: "InferenceOptimizer",
+ feature: "SelfAttentionRewrite",
+ enabled: false,
+ reason: "UnsupportedSelfAttentionLayer(HeadCountOrShape)");
+ continue;
+ }
+
+ if (layer is MultiHeadAttentionLayer mha)
+ {
+ var inputShape = mha.GetInputShape();
+ if (inputShape.Length < 2)
+ {
+ continue;
+ }
+
+ int seqLen = inputShape[0];
+ int embDim = inputShape[1];
+ int headCount = mha.HeadCount;
+ var activation = mha.ScalarActivation;
+
+ if (enableKVCache)
+ {
+ if (enablePagedKVCache)
+ {
+ var paged = new PagedCachedMultiHeadAttention(
+ sequenceLength: seqLen,
+ embeddingDimension: embDim,
+ headCount: headCount,
+ useCausalMask: useCausalMask,
+ activationFunction: activation);
+ paged.EnableWeightOnlyQuantization = _config.EnableWeightOnlyQuantization;
+ paged.SetParameters(mha.GetParameters());
+ model.Layers[i] = paged;
+ }
+ else
+ {
+ var cached = new CachedMultiHeadAttention(
+ sequenceLength: seqLen,
+ embeddingDimension: embDim,
+ headCount: headCount,
+ useFlashAttention: enableFlashAttention,
+ layerIndex: 0,
+ useCausalMask: useCausalMask,
+ activationFunction: activation);
+ cached.SetParameters(mha.GetParameters());
+ model.Layers[i] = cached;
+ }
+ anyRewritten = true;
+ continue;
+ }
+
+ if (enableFlashAttention)
+ {
+ var flashConfig = FlashAttentionConfig.Default;
+ flashConfig.UseCausalMask = useCausalMask;
+
+ var flashLayer = new FlashAttentionLayer(
+ sequenceLength: seqLen,
+ embeddingDimension: embDim,
+ headCount: headCount,
+ config: flashConfig,
+ activationFunction: activation);
+ flashLayer.SetParameters(mha.GetParameters());
+ model.Layers[i] = flashLayer;
+ anyRewritten = true;
+ }
+
+ continue;
+ }
+
+ if (layer is FlashAttentionLayer flash && enableKVCache)
+ {
+ var inputShape = flash.GetInputShape();
+ if (inputShape.Length < 2)
+ {
+ continue;
+ }
+
+ int seqLen = inputShape[0];
+ int embDim = inputShape[1];
+ int headCount = flash.HeadCount;
+ var activation = flash.ScalarActivation;
+
+ if (enablePagedKVCache)
+ {
+ var paged = new PagedCachedMultiHeadAttention(
+ sequenceLength: seqLen,
+ embeddingDimension: embDim,
+ headCount: headCount,
+ useCausalMask: useCausalMask,
+ activationFunction: activation);
+ paged.EnableWeightOnlyQuantization = _config.EnableWeightOnlyQuantization;
+ paged.SetParameters(flash.GetParameters());
+ model.Layers[i] = paged;
+ }
+ else
+ {
+ var cached = new CachedMultiHeadAttention(
+ sequenceLength: seqLen,
+ embeddingDimension: embDim,
+ headCount: headCount,
+ useFlashAttention: enableFlashAttention,
+ layerIndex: 0,
+ useCausalMask: useCausalMask,
+ activationFunction: activation);
+ cached.SetParameters(flash.GetParameters());
+ model.Layers[i] = cached;
+ }
+ anyRewritten = true;
+ }
+ }
+
+ return anyRewritten;
+ }
+
+ private bool ApplyWeightOnlyQuantization(NeuralNetworkBase model)
+ {
+ if (!_config.EnableWeightOnlyQuantization)
+ {
+ InferenceDiagnostics.RecordDecision("InferenceOptimizer", "WeightOnlyQuantization", enabled: false, reason: "DisabledByConfig");
+ return false;
+ }
+
+ if (typeof(T) != typeof(float))
+ {
+ InferenceDiagnostics.RecordDecision("InferenceOptimizer", "WeightOnlyQuantization", enabled: false, reason: $"UnsupportedType({typeof(T).Name})");
+ return false;
+ }
+
+ bool any = false;
+ for (int i = 0; i < model.Layers.Count; i++)
+ {
+ if (model.Layers[i] is DenseLayer dense)
+ {
+ try
+ {
+ var replacement = dense.VectorActivation != null
+ ? new QuantizedDenseLayer(dense, dense.VectorActivation)
+ : new QuantizedDenseLayer(dense);
+
+ model.Layers[i] = (AiDotNet.Interfaces.ILayer)(object)replacement;
+ any = true;
+ }
+ catch (Exception ex)
+ {
+ InferenceDiagnostics.RecordException("InferenceOptimizer", "WeightOnlyQuantization", ex, "DenseLayerQuantizationFailed;FallbackToFP");
+ }
+ }
+ }
+
+ InferenceDiagnostics.RecordDecision("InferenceOptimizer", "WeightOnlyQuantization", enabled: any, reason: any ? "Applied(DenseLayer)" : "NoApplicableLayers");
+ return any;
+ }
+
+ private MultiHeadAttentionLayer? TryConvertSelfAttentionToMultiHead(SelfAttentionLayer layer)
+ {
+ var inputShape = layer.GetInputShape();
+ if (inputShape.Length < 2)
+ return null;
+
+ int seqLen = inputShape[0];
+ int embDim = inputShape[1];
+ if (seqLen <= 0 || embDim <= 0)
+ return null;
+
+ int headCount = TryGetHeadCountFromMetadata(layer) ?? 0;
+ if (headCount <= 0)
+ return null;
+
+ if (embDim % headCount != 0)
+ return null;
+
+ // SelfAttentionLayer has Q/K/V projections plus bias, but no output projection.
+ // We convert it into a MultiHeadAttentionLayer with an identity output projection so that
+ // downstream inference rewrites (FlashAttention / KV-cache) can be applied consistently.
+ var activation = layer.ScalarActivation;
+ var mha = new MultiHeadAttentionLayer(seqLen, embDim, headCount, activationFunction: activation);
+
+ var selfParams = layer.GetParameters();
+ int projSize = embDim * embDim;
+ int expectedSelf = (3 * projSize) + embDim;
+ if (selfParams.Length != expectedSelf)
+ return null;
+
+ var numOps = MathHelper.GetNumericOperations();
+
+ // MultiHead params: Q, K, V, O, bias
+ var combined = new Vector((4 * projSize) + embDim);
+ int idx = 0;
+
+ // Copy Q/K/V (3 * projSize)
+ for (int i = 0; i < 3 * projSize; i++)
+ combined[idx++] = selfParams[i];
+
+ // Output weights: identity matrix (embDim x embDim) flattened row-major
+ for (int r = 0; r < embDim; r++)
+ {
+ for (int c = 0; c < embDim; c++)
+ {
+ combined[idx++] = r == c ? numOps.One : numOps.Zero;
+ }
+ }
+
+ // Output bias (embDim)
+ for (int i = 0; i < embDim; i++)
+ combined[idx++] = selfParams[(3 * projSize) + i];
+
+ mha.SetParameters(combined);
+ return mha;
+ }
+
+ private static int? TryGetHeadCountFromMetadata(ILayer layer)
+ {
+ if (layer is not ILayerSerializationMetadata meta)
+ return null;
+
+ if (!meta.GetSerializationMetadata().TryGetValue("HeadCount", out var raw) || string.IsNullOrWhiteSpace(raw))
+ return null;
+
+ return int.TryParse(raw, out var parsed) ? parsed : null;
+ }
+
+ private bool ResolveCausalMask(NeuralNetworkBase model)
+ {
+ return _config.AttentionMasking switch
+ {
+ AttentionMaskingMode.Causal => true,
+ AttentionMaskingMode.Disabled => false,
+ _ => InferCausalFromModel(model)
+ };
+ }
+
+ private bool InferCausalFromModel(NeuralNetworkBase model)
+ {
+ // Default to causal when the user enables generation-oriented inference features.
+ // This matches industry-standard expectations for autoregressive decoding and avoids
+ // relying on users to set TaskType explicitly.
+ if (_config.EnableKVCache || _config.EnableSpeculativeDecoding)
+ return true;
+
+ // Otherwise, keep heuristics conservative to avoid changing semantics for non-generative models.
+ return model.Architecture.TaskType == NeuralNetworkTaskType.TextGeneration;
+ }
+
///
/// Estimates the maximum sequence length based on config and memory constraints.
///
- private int EstimateMaxSequenceLength()
+ /// Number of attention layers in the model.
+ /// Number of attention heads per layer.
+ /// Dimension of each attention head.
+ /// Maximum sequence length that fits within the configured memory budget.
+ private int EstimateMaxSequenceLength(int numLayers, int numHeads, int headDim)
{
- // Calculate based on available memory
- // Formula: maxSeqLen = (maxMemoryMB * 1024 * 1024) / (numLayers * numHeads * headDim * 2 * bytesPerElement)
- // Using a simplified estimate
+ // KV cache memory per token = numLayers * numHeads * headDim * 2 (K and V) * bytesPerElement
+ // For batch size, multiply by maxBatchSize
+ // Total: maxSeqLen * numLayers * numHeads * headDim * 2 * bytesPerElement * batchSize <= maxMemoryBytes
+
long maxMemoryBytes = (long)_config.KVCacheMaxSizeMB * 1024 * 1024;
- // Default reasonable sequence length
- const int defaultMaxSeqLen = 2048;
+ // Estimate bytes per element based on type T
+ int bytesPerElement = EstimateBytesPerElement();
+
+ // Memory per token per batch item = numLayers * numHeads * headDim * 2 * bytesPerElement
+ long memoryPerToken = (long)numLayers * numHeads * headDim * 2 * bytesPerElement;
+
+ // Account for batch size
+ long memoryPerTokenWithBatch = memoryPerToken * _config.MaxBatchSize;
+
+ // Prevent division by zero
+ if (memoryPerTokenWithBatch <= 0)
+ {
+ return 2048; // Reasonable default
+ }
+
+ // Calculate maximum sequence length
+ long calculatedMaxSeqLen = maxMemoryBytes / memoryPerTokenWithBatch;
- // Cap at reasonable maximum
- return Math.Min(defaultMaxSeqLen, 8192);
+ // Apply reasonable bounds using MathHelper.Clamp for net471 compatibility
+ const long minSeqLen = 128;
+ const long maxSeqLen = 32768; // Reasonable upper bound
+
+ return (int)MathHelper.Clamp(calculatedMaxSeqLen, minSeqLen, maxSeqLen);
+ }
+
+ ///
+ /// Estimates bytes per element based on the generic type T.
+ ///
+ private static int EstimateBytesPerElement()
+ {
+ // Common numeric types used in neural networks
+ var type = typeof(T);
+ if (type == typeof(float)) return 4;
+ if (type == typeof(double)) return 8;
+ if (type == typeof(Half)) return 2;
+ if (type == typeof(decimal)) return 16;
+
+ // Default to float size if unknown
+ return 4;
}
///
@@ -186,23 +727,56 @@ private int EstimateMaxSequenceLength()
///
private bool InitializeSpeculativeDecoding(NeuralNetworkBase model)
{
- // Create draft model based on configuration
- IDraftModel? draftModel = _config.DraftModelType switch
+ // Facade-friendly behavior: speculative decoding configuration must never crash inference.
+ // If a requested draft model is unavailable, fall back to an N-gram draft model and record diagnostics.
+ try
{
- DraftModelType.NGram => CreateNGramDraftModel(),
- DraftModelType.SmallNeural => CreateNeuralDraftModel(model),
- _ => null
- };
+ // For Custom draft models, an internal caller can provide one via SetCustomDraftModel().
+ if (_config.DraftModelType == DraftModelType.Custom)
+ {
+ if (_draftModel != null)
+ {
+ InferenceDiagnostics.RecordDecision("InferenceOptimizer", "SpeculativeDraftModel", enabled: true, reason: "CustomProvided");
+ return true;
+ }
+
+ _draftModel = CreateNGramDraftModel();
+ InferenceDiagnostics.RecordDecision("InferenceOptimizer", "SpeculativeDraftModel", enabled: _draftModel != null, reason: "CustomNotProvided_FallbackToNGram");
+ return _draftModel != null;
+ }
- if (draftModel == null)
+ IDraftModel? draftModel = _config.DraftModelType switch
+ {
+ DraftModelType.NGram => CreateNGramDraftModel(),
+ DraftModelType.SmallNeural => CreateNeuralDraftModel(model),
+ _ => CreateNGramDraftModel()
+ };
+
+ _draftModel = draftModel ?? CreateNGramDraftModel();
+ InferenceDiagnostics.RecordDecision(
+ "InferenceOptimizer",
+ "SpeculativeDraftModel",
+ enabled: _draftModel != null,
+ reason: draftModel != null ? _config.DraftModelType.ToString() : $"Unavailable({_config.DraftModelType})_FallbackToNGram");
+
+ return _draftModel != null;
+ }
+ catch (Exception ex)
{
- return false;
+ InferenceDiagnostics.RecordException("InferenceOptimizer", "SpeculativeDecoding", ex, "Draft model init failed; falling back to NGram.");
+ try
+ {
+ _draftModel = CreateNGramDraftModel();
+ InferenceDiagnostics.RecordDecision("InferenceOptimizer", "SpeculativeDraftModel", enabled: _draftModel != null, reason: "ExceptionFallbackToNGram");
+ return _draftModel != null;
+ }
+ catch
+ {
+ InferenceDiagnostics.RecordDecision("InferenceOptimizer", "SpeculativeDraftModel", enabled: false, reason: "FallbackFailed");
+ _draftModel = null;
+ return false;
+ }
}
-
- // Note: SpeculativeDecoder requires a target forward function
- // This will be set when actually doing inference via CreateSpeculativeDecoder
- _draftModel = draftModel;
- return true;
}
///
@@ -217,10 +791,19 @@ private bool InitializeSpeculativeDecoding(NeuralNetworkBase model)
///
/// Creates a small neural network draft model.
///
+ ///
+ /// SmallNeural draft models require a pre-trained companion model that is smaller
+ /// and faster than the target model but trained on similar data. This cannot be
+ /// automatically generated from the target model.
+ ///
+ ///
+ /// Always thrown because SmallNeural draft models require external pre-trained models.
+ ///
private IDraftModel? CreateNeuralDraftModel(NeuralNetworkBase model)
{
- // For neural draft models, we would need a pre-trained smaller model
- // This is a placeholder - in production, this would load a companion model
+ // SmallNeural draft models require a separate pre-trained smaller model. We do not expose
+ // draft model wiring via the public facade in the MVP, so treat this as unavailable.
+ InferenceDiagnostics.RecordDecision("InferenceOptimizer", "SpeculativeDraftModel", enabled: false, reason: "SmallNeuralUnavailable_FallbackToNGram");
return null;
}
@@ -258,6 +841,10 @@ public void DisableInferenceMode(NeuralNetworkBase model)
{
cachedAttention.InferenceMode = false;
}
+ else if (layer is PagedCachedMultiHeadAttention pagedAttention)
+ {
+ pagedAttention.InferenceMode = false;
+ }
}
}
@@ -267,6 +854,54 @@ public void DisableInferenceMode(NeuralNetworkBase model)
public void ClearCache()
{
_kvCache?.Clear();
+ if (_pagedKVCache != null && _pagedSequenceId.HasValue)
+ {
+ try
+ {
+ _pagedKVCache.FreeSequence(_pagedSequenceId.Value);
+ }
+ catch
+ {
+ // Best-effort cleanup.
+ }
+
+ // Re-allocate with the same ID if possible; otherwise allocate a new one.
+ if (!TryAllocatePagedSequenceId(_pagedKVCache, _pagedSequenceId.Value, initialTokens: 0, out long allocated))
+ {
+ InferenceDiagnostics.RecordDecision(
+ area: "InferenceOptimizer",
+ feature: "PagedKVCache",
+ enabled: false,
+ reason: "ClearCacheAllocateSequenceFailed(OutOfMemoryOrExhausted)");
+
+ // Safe fallback: disable paged inference mode on layers and keep session alive.
+ if (_pagedAttentionLayers != null)
+ {
+ foreach (var layer in _pagedAttentionLayers)
+ {
+ layer.InferenceMode = false;
+ layer.Kernel = null;
+ layer.ResetState();
+ }
+ }
+
+ _pagedSequenceId = null;
+ return;
+ }
+
+ _pagedSequenceId = allocated;
+
+ if (_pagedAttentionLayers != null && _pagedSequenceId.HasValue)
+ {
+ foreach (var layer in _pagedAttentionLayers)
+ {
+ layer.SequenceId = _pagedSequenceId.Value;
+ layer.ResetState();
+ layer.InferenceMode = true;
+ layer.Kernel ??= _pagedKernel;
+ }
+ }
+ }
}
///
@@ -280,7 +915,10 @@ public Dictionary GetStatistics()
["IsInitialized"] = _isInitialized,
["KVCacheEnabled"] = _config.EnableKVCache,
["SpeculativeDecodingEnabled"] = _config.EnableSpeculativeDecoding,
- ["BatchingEnabled"] = _config.EnableBatching
+ ["BatchingEnabled"] = _config.EnableBatching,
+ ["PagedKVCacheInitialized"] = _pagedKVCache != null,
+ ["PagedAttentionLayerCount"] = _pagedAttentionLayers?.Count ?? 0,
+ ["PagedAttentionWeightOnlyQuantizationEnabled"] = _pagedAttentionLayers?.Any(l => l.EnableWeightOnlyQuantization) ?? false
};
if (_kvCache != null)
@@ -310,6 +948,34 @@ public Dictionary GetStatistics()
///
public IDraftModel? DraftModel => _draftModel;
+ ///
+ /// Sets a custom draft model for speculative decoding.
+ ///
+ /// The custom draft model implementation.
+ ///
+ /// For Beginners: Use this method when you have your own draft model implementation.
+ ///
+ /// This is required when using DraftModelType.Custom or when you want to replace the
+ /// default NGram draft model with a more sophisticated model.
+ ///
+ /// Your custom draft model must implement IDraftModel<T> and provide:
+ /// - Draft token generation
+ /// - Probability estimation for speculative decoding verification
+ ///
+ /// Example:
+ ///
+ /// var optimizer = new InferenceOptimizer<float>(config);
+ /// optimizer.SetCustomDraftModel(myCustomDraftModel);
+ /// optimizer.Initialize(mainModel);
+ ///
+ ///
+ ///
+ /// Thrown when draftModel is null.
+ public void SetCustomDraftModel(IDraftModel draftModel)
+ {
+ _draftModel = draftModel ?? throw new ArgumentNullException(nameof(draftModel));
+ }
+
///
/// Creates a speculative decoder with the given target forward function.
///
@@ -339,7 +1005,13 @@ public Dictionary GetStatistics()
var speculativeConfig = new SpeculativeDecodingConfig
{
NumDraftTokens = _config.SpeculationDepth,
- UseTreeSpeculation = _config.UseTreeSpeculation
+ UseTreeSpeculation = _config.UseTreeSpeculation ||
+ _config.SpeculativeMethod == SpeculativeMethod.Medusa ||
+ _config.SpeculativeMethod == SpeculativeMethod.Eagle,
+ AdaptiveDraftLength = _config.SpeculationPolicy == SpeculationPolicy.Auto,
+ TreeBranchFactor = _config.SpeculativeMethod == SpeculativeMethod.Medusa ? 4 : 2,
+ MaxTreeDepth = Math.Max(1, _config.SpeculationDepth),
+ MinAcceptanceRate = MathHelper.GetNumericOperations().FromDouble(0.5)
};
_speculativeDecoder = new SpeculativeDecoder