"""Convex hull computation functions - unified interface."""
import jax
from jax import Array
from ..operations.predicates import convex_hull_surface_area as hull_surface_area
from ..operations.predicates import convex_hull_volume as hull_volume
from ..operations.predicates import distance_to_convex_hull as distance_to_hull
from ..operations.predicates import point_in_convex_hull as point_in_hull
from .utils import HullVertices, PointCloud, validate_point_cloud
[docs]
def convex_hull(points: PointCloud, algorithm: str = "approximate", **kwargs) -> HullVertices:
"""Compute convex hull of a set of points.
This is the main unified interface for convex hull computation.
It supports multiple algorithms and provides a consistent API.
Args:
points: Input points array with shape (..., n_points, dimension)
algorithm: Algorithm to use:
- "approximate": Differentiable approximate hull (default)
- "quickhull": Exact Quickhull algorithm (Phase 2)
- "graham_scan": 2D Graham scan algorithm (Phase 2)
**kwargs: Algorithm-specific parameters passed to the underlying function
Returns:
Array of convex hull vertices
Example:
>>> import jax.numpy as jnp
>>> points = jnp.array([[0, 0], [1, 0], [0, 1], [1, 1]])
>>> hull_vertices = convex_hull(points, algorithm="approximate")
>>> print(hull_vertices.shape) # (n_hull_vertices, 2)
Algorithm-specific parameters:
For algorithm="approximate":
- n_directions (int): Number of sampling directions (default: 100)
- method (str): Sampling method ("uniform", "icosphere", "adaptive")
- temperature (float): Softmax temperature for differentiability
- random_key (Array): JAX random key
"""
points = validate_point_cloud(points)
if algorithm == "approximate":
from ..algorithms.approximation import approximate_convex_hull as _approximate_convex_hull
hull_vertices, _ = _approximate_convex_hull(points, **kwargs)
return hull_vertices
elif algorithm == "quickhull":
raise NotImplementedError(
"Quickhull algorithm will be implemented in Phase 2. Use algorithm='approximate' for now."
)
elif algorithm == "graham_scan":
raise NotImplementedError(
"Graham scan algorithm will be implemented in Phase 2. Use algorithm='approximate' for now."
)
else:
raise ValueError(f"Unknown algorithm: {algorithm}")
[docs]
def approximate_convex_hull(
points: Array, n_directions: int = 100, method: str = "uniform", random_seed: int = 0
) -> tuple[Array, Array]:
"""Differentiable approximate convex hull computation.
This function maintains backward compatibility with the original API
while forwarding to the new implementation.
Args:
points: Point cloud with shape [..., n_points, dim]
n_directions: Number of sampling directions
method: Sampling strategy ('uniform', 'adaptive', 'icosphere')
random_seed: Random seed
Returns:
Tuple of (hull_points, hull_indices)
Note:
This function is maintained for backward compatibility.
For new code, consider using the unified convex_hull() function
or the algorithms.approximation module directly.
"""
# Convert to new API parameters
random_key = jax.random.PRNGKey(random_seed) if random_seed else None
from typing import cast
from ..algorithms.approximation import approximate_convex_hull as _approximate_convex_hull
from ..core.utils import SamplingMethod
return _approximate_convex_hull(
points, n_directions=n_directions, method=cast(SamplingMethod, method), random_key=random_key
)
# Re-export key functions for convenience
__all__ = [
"approximate_convex_hull",
"convex_hull",
"distance_to_hull",
"hull_surface_area",
"hull_volume",
"point_in_hull",
]
# JIT-compiled versions for performance
convex_hull_jit = jax.jit(convex_hull, static_argnames=["algorithm"])
approximate_convex_hull_jit = jax.jit(approximate_convex_hull, static_argnames=["method"])