User Guide
This guide explains how to use PolytopAX effectively in your projects.
Overview
PolytopAX is designed to make differentiable convex hull computation accessible and efficient for machine learning and scientific computing applications. This guide covers:
Basic concepts and terminology
Common usage patterns and best practices
Advanced features and customization options
Performance optimization techniques
Integration with existing workflows
Key Concepts
Differentiable Convex Hulls
Traditional convex hull algorithms use discrete operations that break differentiability. PolytopAX uses:
Direction vector sampling on the unit sphere
Soft selection with temperature-controlled operations
Continuous approximations instead of discrete choices
This enables gradient-based optimization while maintaining computational efficiency.
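The three ideas above can be sketched in plain JAX. This is a minimal illustration, not PolytopAX's actual implementation: the helper names `unit_directions_2d`, `soft_support_points`, and `mean_extent` are hypothetical, and a softmax stands in for whatever soft selection rule the library applies internally.

```python
import jax
import jax.numpy as jnp


def unit_directions_2d(n_directions):
    """Evenly spaced unit vectors on the circle, shape (n_directions, 2)."""
    angles = jnp.linspace(0.0, 2.0 * jnp.pi, n_directions, endpoint=False)
    return jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=-1)


def soft_support_points(points, directions, temperature):
    """Softmax-weighted stand-in for the farthest point along each direction."""
    scores = directions @ points.T                    # (n_directions, n_points)
    weights = jax.nn.softmax(scores / temperature, axis=-1)
    return weights @ points                           # (n_directions, dim)


# Three triangle corners plus one interior point.
points = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.25, 0.25]])
directions = unit_directions_2d(8)

# A low temperature approaches the hard argmax; a higher one is smoother.
soft = soft_support_points(points, directions, temperature=0.01)


def mean_extent(points):
    """Average extent of the soft support points along their directions."""
    pts = soft_support_points(points, directions, temperature=0.1)
    return jnp.mean(jnp.sum(pts * directions, axis=-1))


# Because nothing above makes a discrete choice, jax.grad works directly.
grads = jax.grad(mean_extent)(points)
```

A hard `argmax` in place of the softmax would pick one point per direction and give no gradient signal about how that choice changes, which is the non-differentiability the text refers to.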
JAX Integration
PolytopAX is built on JAX, providing:
JIT compilation for performance
Automatic differentiation for gradients
Vectorization for batch operations
Device acceleration on GPU/TPU
All PolytopAX operations are compatible with JAX transformations.
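As a rough illustration of what compatibility with JAX transformations means, the sketch below composes `jax.jit`, `jax.vmap`, and `jax.grad` on a small smooth function. `soft_width` is a hypothetical stand-in written in pure JAX, not a PolytopAX function. A library operation that is compatible with these transformations could be dropped into the same pattern.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def soft_width(points, direction, temperature=0.05):
    """Smooth extent of a point set along `direction` (log-sum-exp replaces max/min)."""
    proj = points @ direction
    soft_max = temperature * logsumexp(proj / temperature)
    soft_min = -temperature * logsumexp(-proj / temperature)
    return soft_max - soft_min


direction = jnp.array([1.0, 0.0])
batch = jnp.stack([
    jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]),
    jnp.array([[0.0, 0.0], [2.0, 0.0], [0.0, 2.0]]),
])

# vmap maps over the leading batch axis; the direction is shared (in_axes=None).
widths = jax.jit(jax.vmap(soft_width, in_axes=(0, None)))(batch, direction)

# jit(vmap(grad(f))): per-set gradients, batched and compiled in one call.
batched_grad = jax.jit(jax.vmap(jax.grad(soft_width), in_axes=(0, None)))
grads = batched_grad(batch, direction)   # one gradient array per point set
```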
Approximation vs Exactness
PolytopAX computes approximate convex hulls that:
Maintain differentiability throughout computation
Provide controllable accuracy via algorithm parameters
Scale efficiently to large point sets
Support batch operations natively
The approximation quality can be tuned based on application requirements.
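To make the accuracy tradeoff concrete, the following sketch measures how far a softmax-based area estimate falls from the exact area of a unit square as the temperature is lowered. It is an assumption-laden toy in pure JAX: `soft_hull_area` is a hypothetical helper, and the temperature and direction count here only mirror the kind of parameters the library exposes, not their actual names or defaults.

```python
import jax
import jax.numpy as jnp


def soft_hull_area(points, n_directions, temperature):
    """Shoelace area of the polygon through softmax-selected support points."""
    angles = jnp.linspace(0.0, 2.0 * jnp.pi, n_directions, endpoint=False)
    directions = jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=-1)
    weights = jax.nn.softmax(directions @ points.T / temperature, axis=-1)
    verts = weights @ points                 # vertices ordered by direction angle
    x, y = verts[:, 0], verts[:, 1]
    return 0.5 * jnp.abs(jnp.sum(x * jnp.roll(y, -1) - y * jnp.roll(x, -1)))


square = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])  # exact area is 1
temperatures = [1.0, 0.1, 0.01]
errors = [abs(1.0 - float(soft_hull_area(square, 16, t))) for t in temperatures]
```

At a high temperature the soft vertices collapse toward the centroid and the area is badly underestimated. Lowering the temperature sharpens the selection and shrinks the error, at the cost of gradients that concentrate on fewer points.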
Common Workflows
1. Basic Convex Hull Computation
import polytopax as ptx
import jax.numpy as jnp
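# Compute hull (float coordinates, so the same array also works with jax.grad)
points = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
hull = ptx.ConvexHull.from_points(points)
# Access properties (for 2D points, volume is the area and surface area is the perimeter)
area = hull.volume()
perimeter = hull.surface_area()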
2. Gradient-Based Optimization
import jax
def objective(points):
    hull = ptx.ConvexHull.from_points(points)
    return hull.volume()  # 2D area; ascend its gradient to maximize it
# Compute gradients with respect to the input points (must be floating point)
grad_fn = jax.grad(objective)
gradients = grad_fn(points)
3. Batch Processing
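# Process multiple point sets (all sets must share one shape to be stacked)
points1 = points
points2 = 2.0 * points
points3 = points + 1.0
batch_points = jnp.stack([points1, points2, points3])  # shape (3, n_points, dim)
batch_hulls = jax.vmap(ptx.ConvexHull.from_points)(batch_points)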
4. Performance Optimization
# JIT compile for speed
@jax.jit
def fast_hull_computation(points):
    return ptx.ConvexHull.from_points(points, n_directions=20)
Best Practices
Choose appropriate parameters for your accuracy/speed tradeoff
Use JIT compilation for repeated computations
Batch operations when processing multiple point sets
Profile your code to identify bottlenecks
Validate results for critical applications
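When profiling jitted code, two JAX behaviors are easy to overlook: the first call includes compilation time, and calls return before the computation finishes because dispatch is asynchronous. The sketch below shows one way to time a function with both in mind. `pairwise_extent` and `time_call` are hypothetical helpers in pure JAX; any jitted hull routine could take the place of `pairwise_extent`.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def pairwise_extent(points):
    """Largest pairwise distance; a stand-in for any jitted hull routine."""
    diffs = points[:, None, :] - points[None, :, :]
    return jnp.sqrt(jnp.max(jnp.sum(diffs**2, axis=-1)))


def time_call(fn, *args, repeats=10):
    """Mean wall-clock seconds per call, excluding the first (compiling) call."""
    jax.block_until_ready(fn(*args))          # warm-up: triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        jax.block_until_ready(fn(*args))      # wait for async dispatch to finish
    return (time.perf_counter() - start) / repeats


points = jax.random.normal(jax.random.PRNGKey(0), (256, 2))
seconds = time_call(pairwise_extent, points)

# Validation against a case with a known answer: a 3-4-5 segment.
check = pairwise_extent(jnp.array([[0.0, 0.0], [3.0, 4.0]]))
```

Checking a routine against a small input with a known answer, as in the last line, is a cheap way to validate results before trusting them on real data.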
Getting Help
Documentation: Browse this documentation site
Examples: Check the examples directory for practical code
API Reference: See detailed function documentation
Community: Join discussions on GitHub
Issues: Report bugs or request features
Let’s get started with Basic Usage!