"""Geometric predicates for convex hull operations."""
import warnings
import jax
import jax.numpy as jnp
import numpy as np
from jax import Array
from ..core.utils import HullVertices, compute_simplex_volume, validate_point_cloud
[docs]
def point_in_convex_hull(
    point: Array, hull_vertices: HullVertices, tolerance: float = 1e-8, method: str = "halfspace"
) -> Array:
    """Report whether ``point`` belongs to the convex hull of ``hull_vertices``.

    The inputs are normalised and checked for matching coordinate counts,
    then the query is handed to the strategy selected by ``method``.

    Args:
        point: Query point; the last axis holds its coordinates.
        hull_vertices: Vertices spanning the hull; the last axis holds
            coordinates and the second-to-last axis enumerates vertices.
        tolerance: Numerical slack forwarded unchanged to the strategy.
        method: ``"halfspace"``, ``"linear_programming"`` or ``"barycentric"``.

    Returns:
        Boolean array (True = inside or on the boundary).

    Raises:
        ValueError: If ``point`` is 0-dimensional, if the point and hull have
            different coordinate counts, or if ``method`` is not recognised.
    """
    point = jnp.asarray(point)
    hull_vertices = validate_point_cloud(hull_vertices)

    # A scalar has no coordinate axis to compare against the hull.
    if point.ndim == 0:
        raise ValueError("Point must have at least 1 dimension")

    point_dim = point.shape[-1]
    hull_dim = hull_vertices.shape[-1]
    if point_dim != hull_dim:
        raise ValueError(f"Point dimension ({point_dim}) must match hull dimension ({hull_dim})")

    strategies = (
        ("linear_programming", _point_in_hull_lp),
        ("barycentric", _point_in_hull_barycentric),
        ("halfspace", _point_in_hull_halfspace),
    )
    for name, strategy in strategies:
        if method == name:
            return strategy(point, hull_vertices, tolerance)
    raise ValueError(f"Unknown method: {method}")
def _point_in_hull_lp(point: Array, hull_vertices: HullVertices, tolerance: float) -> Array:
    """Heuristic stand-in for a linear-programming containment test.

    Hulls with at most ``dim + 1`` vertices are delegated to the barycentric
    test. Larger hulls are screened with a tolerance-padded bounding box and
    then tested against the single simplex built from the ``dim + 1`` vertices
    closest to the vertex mean. No LP is solved, so the answer can be wrong
    for hulls that this one simplex does not represent.

    Args:
        point: Query point.
        hull_vertices: Hull vertices, indexed as ``(n_vertices, dim)`` on the
            last two axes.
        tolerance: Slack for the bounding box and the simplex test.

    Returns:
        Boolean array with the containment verdict.
    """
    dim = hull_vertices.shape[-1]
    if hull_vertices.shape[-2] <= dim + 1:
        return _point_in_hull_barycentric(point, hull_vertices, tolerance)

    # Axis-aligned bounding box, widened by the tolerance on each side.
    lower = jnp.min(hull_vertices, axis=-2) - tolerance
    upper = jnp.max(hull_vertices, axis=-2) + tolerance
    in_bbox = jnp.all((point >= lower) & (point <= upper), axis=-1)

    def _central_simplex_test():
        # Pick the dim + 1 vertices nearest the mean and test that simplex only.
        centre = jnp.mean(hull_vertices, axis=-2)
        spread = jnp.linalg.norm(hull_vertices - centre, axis=-1)
        nearest = jnp.argsort(spread)[: dim + 1]
        return _point_in_simplex(point, hull_vertices[nearest], tolerance)

    def _reject():
        return jnp.array(False, dtype=bool)

    # The simplex test only runs when the bounding-box screen passed.
    refined = jax.lax.cond(jnp.any(in_bbox), _central_simplex_test, _reject)
    return jnp.logical_and(in_bbox, refined)
def _point_in_hull_barycentric(point: Array, hull_vertices: HullVertices, tolerance: float) -> Array:
    """Test containment using the simplex of the first ``dim + 1`` vertices.

    Args:
        point: Query point.
        hull_vertices: Hull vertices, ``(n_vertices, dim)`` on the last two axes.
        tolerance: Slack applied to the barycentric coordinates.

    Returns:
        False when there are fewer than ``dim + 1`` vertices; otherwise the
        result of the simplex test on the leading ``dim + 1`` vertices. Any
        further vertices are ignored, so for larger hulls this is only an
        approximation.
    """
    simplex_size = hull_vertices.shape[-1] + 1
    if hull_vertices.shape[-2] < simplex_size:
        # Too few vertices to span a full-dimensional simplex.
        return jnp.array(False, dtype=bool)

    # With exactly dim + 1 vertices this slice is the whole hull.
    leading_simplex = hull_vertices[..., :simplex_size, :]
    return _point_in_simplex(point, leading_simplex, tolerance)
def _point_in_simplex(point: Array, simplex_vertices: Array, tolerance: float) -> Array:
    """Test if point is inside simplex using barycentric coordinates.

    Solves ``point - v0 = sum(lambda_i * (v_i - v0))`` in the least-squares
    sense and accepts the point when every barycentric coordinate is
    non-negative AND the coordinates actually reproduce the point. The second
    condition matters for degenerate simplices: least squares always returns
    some answer, even when the point lies outside the simplex's affine span.

    Args:
        point: Query point of shape ``(dim,)``.
        simplex_vertices: Simplex vertices of shape ``(dim + 1, dim)``.
        tolerance: Slack on the non-negativity test; also a lower bound for
            the relative reconstruction tolerance.

    Returns:
        Boolean scalar array.

    Raises:
        ValueError: If the vertex count is not ``dim + 1``.
    """
    n_vertices = simplex_vertices.shape[-2]
    dim = simplex_vertices.shape[-1]
    if n_vertices != dim + 1:
        raise ValueError(f"Simplex must have {dim + 1} vertices, got {n_vertices}")

    v0 = simplex_vertices[..., 0, :]
    edge_vectors = simplex_vertices[..., 1:, :] - v0[..., None, :]
    point_offset = point - v0

    # Columns of edge_matrix are the simplex edges emanating from v0.
    edge_matrix = edge_vectors.T
    lambdas_rest, _, _, _ = jnp.linalg.lstsq(edge_matrix, point_offset, rcond=None)
    lambda0 = 1.0 - jnp.sum(lambdas_rest)
    lambdas = jnp.concatenate([lambda0[None], lambdas_rest])

    # jnp.linalg does not raise on singular input, so verify the solution
    # instead: the residual must vanish up to a precision-aware tolerance.
    residual = jnp.linalg.norm(edge_matrix @ lambdas_rest - point_offset)
    scale = 1.0 + jnp.linalg.norm(point_offset) + jnp.linalg.norm(edge_vectors)
    relative_slack = jnp.maximum(tolerance, float(jnp.finfo(lambdas.dtype).eps) ** 0.5)
    reproduces_point = residual <= relative_slack * scale

    non_negative = jnp.all(lambdas >= -tolerance)
    return jnp.logical_and(non_negative, reproduces_point)
[docs]
def convex_hull_volume(vertices: HullVertices, method: str = "simplex_decomposition") -> Array:
    """Compute the d-dimensional measure of a convex hull.

    Args:
        vertices: Hull vertices; the last axis holds coordinates and the
            second-to-last axis enumerates vertices.
        method: Name of the estimator to run:
            - "simplex_decomposition": sum of simplex volumes
            - "shoelace": shoelace formula (2D only)
            - "divergence_theorem": currently falls back to simplex decomposition
            - "monte_carlo": sampling-based estimate
            - "multi_method": combination of several estimators

    Returns:
        The hull's measure (area in 2D, volume in 3D, and so on).

    Raises:
        ValueError: If ``method`` is not one of the names above.
    """
    vertices = validate_point_cloud(vertices)

    estimators = (
        ("simplex_decomposition", _volume_simplex_decomposition),
        ("shoelace", _volume_shoelace_formula),
        ("divergence_theorem", _volume_divergence_theorem),
        ("monte_carlo", _volume_monte_carlo),
        ("multi_method", _volume_multi_method_consensus),
    )
    for name, estimator in estimators:
        if method == name:
            return estimator(vertices)
    raise ValueError(f"Unknown volume method: {method}")
def _volume_simplex_decomposition(vertices: HullVertices) -> Array:
    """Estimate hull volume as a sum of simplex volumes anchored at vertex 0.

    Behaviour by case:
        - fewer than ``dim + 1`` vertices: 0.0
        - exactly ``dim + 1`` vertices: the simplex volume itself
        - 2D: fan of triangles ``(v0, v_i, v_{i+1})``, which relies on the
          vertices being listed in boundary order
        - 3D: tetrahedra ``(v0, v_i, v_j, v_{j+1})`` over all ``i < j``; this
          is a simplified scheme, not a true tetrahedralisation
        - higher dimensions: volume of the first ``dim + 1`` vertices scaled
          by ``n_vertices / (dim + 1)``, with a ``UserWarning``

    Args:
        vertices: Hull vertices, ``(n_vertices, dim)`` on the last two axes.

    Returns:
        Absolute value of the accumulated volume.
    """
    dim = vertices.shape[-1]
    n_vertices = vertices.shape[-2]
    simplex_size = dim + 1

    if n_vertices < simplex_size:
        return jnp.array(0.0)
    if n_vertices == simplex_size:
        return compute_simplex_volume(vertices)

    def vertex(index):
        return vertices[..., index, :]

    apex = vertex(0)

    if dim == 2:
        pieces = [
            compute_simplex_volume(jnp.stack([apex, vertex(i), vertex(i + 1)], axis=-2))
            for i in range(1, n_vertices - 1)
        ]
        # sum() adds left to right from the start value, like a running total.
        total_volume = sum(pieces, jnp.array(0.0))
    elif dim == 3:
        pieces = [
            compute_simplex_volume(jnp.stack([apex, vertex(i), vertex(j), vertex(j + 1)], axis=-2))
            for i in range(1, n_vertices - 2)
            for j in range(i + 1, n_vertices - 1)
        ]
        total_volume = sum(pieces, jnp.array(0.0))
    else:
        warnings.warn(f"Simplex decomposition for dimension {dim} is approximate", UserWarning, stacklevel=2)
        # Rough estimate: one sample simplex scaled by the vertex count.
        sample_volume = compute_simplex_volume(vertices[..., :simplex_size, :])
        total_volume = sample_volume * jnp.array(n_vertices / simplex_size)

    return jnp.abs(total_volume)
def _volume_divergence_theorem(vertices: HullVertices) -> Array:
    """Placeholder for a divergence-theorem volume (3D only).

    The divergence-theorem computation is not implemented. Every call emits a
    ``UserWarning`` (the wording depends on whether the input is 3D) and
    returns the simplex-decomposition volume.

    Args:
        vertices: Hull vertices, ``(n_vertices, dim)`` on the last two axes.

    Returns:
        Result of ``_volume_simplex_decomposition(vertices)``.
    """
    if vertices.shape[-1] == 3:
        # TODO: Implement proper divergence theorem volume calculation.
        # This requires computing the surface mesh and applying the theorem.
        notice = "Divergence theorem not yet implemented, using simplex decomposition"
    else:
        notice = "Divergence theorem method only works for 3D, falling back to simplex decomposition"

    warnings.warn(notice, UserWarning, stacklevel=2)
    return _volume_simplex_decomposition(vertices)
def _volume_monte_carlo(vertices: HullVertices, n_samples: int = 10000, random_key: Array | None = None) -> Array:
    """Estimate hull volume by rejection sampling inside the bounding box.

    Args:
        vertices: Hull vertices, ``(n_vertices, dim)`` on the last two axes.
        n_samples: Number of uniform samples drawn in the bounding box.
        random_key: JAX PRNG key; ``PRNGKey(42)`` is used when omitted, so the
            default estimate is deterministic.

    Returns:
        Bounding-box volume multiplied by the fraction of samples that
        ``point_in_convex_hull`` reports as inside.
    """
    key = jax.random.PRNGKey(42) if random_key is None else random_key

    lower = jnp.min(vertices, axis=-2)
    upper = jnp.max(vertices, axis=-2)
    bbox_volume = jnp.prod(upper - lower)

    samples = jax.random.uniform(key, (n_samples, vertices.shape[-1]), minval=lower, maxval=upper)

    # Each sample is tested one at a time in Python.
    hits = sum(1 for index in range(n_samples) if point_in_convex_hull(samples[index], vertices))

    return bbox_volume * (hits / n_samples)
[docs]
def convex_hull_surface_area(vertices: HullVertices, faces: Array | None = None) -> Array:
    """Compute surface area of convex hull.

    Args:
        vertices: Hull vertices with shape (..., n_vertices, dim)
        faces: Face vertex indices with shape (n_faces, vertices_per_face).
            Only consulted for 3D hulls; when None, faces are generated with
            ``_compute_hull_faces``.

    Returns:
        Perimeter for 2D input, summed triangle area for 3D input, and an
        approximate boundary measure (with a ``UserWarning``) otherwise.
    """
    vertices = validate_point_cloud(vertices)
    dim = vertices.shape[-1]

    if dim == 2:
        # The perimeter needs only the vertex ring, so no faces are built.
        return _compute_2d_perimeter(vertices)

    if dim == 3:
        # Face enumeration is cubic in the vertex count, so it is done only
        # on the one path that actually uses the result.
        if faces is None:
            faces = _compute_hull_faces(vertices)
        return _compute_3d_surface_area(vertices, faces)

    warnings.warn(f"Surface area computation for dimension {dim} is approximate", UserWarning, stacklevel=2)
    return _compute_nd_boundary_measure(vertices)
def _compute_2d_perimeter(vertices: HullVertices) -> Array:
    """Compute perimeter of 2D convex hull.

    Sums the lengths of the edges joining each vertex to the next one along
    the vertex axis, wrapping from the last vertex back to the first. The
    vertices are used in the order given.

    Args:
        vertices: Polygon vertices, ``(n_vertices, 2)`` on the last two axes.

    Returns:
        Sum of edge lengths, reduced over the vertex axis.
    """
    # roll(-1) pairs every vertex with its successor (last pairs with first).
    edge_vectors = jnp.roll(vertices, -1, axis=-2) - vertices
    edge_lengths = jnp.linalg.norm(edge_vectors, axis=-1)
    return jnp.sum(edge_lengths, axis=-1)
def _compute_3d_surface_area(vertices: HullVertices, faces: Array) -> Array:
    """Compute surface area of 3D convex hull.

    Every face contributes the area of the triangle formed by its first three
    vertex indices. All faces are evaluated in one vectorised pass, and
    leading batch axes of ``vertices`` are reduced independently.

    Args:
        vertices: Vertices, ``(n_vertices, 3)`` on the last two axes.
        faces: Integer index array of shape ``(n_faces, k)``. Faces with fewer
            than three indices, or an empty array, contribute nothing.

    Returns:
        Total triangle area with the batch shape of ``vertices``.
    """
    faces = jnp.asarray(faces)
    if faces.ndim < 2 or faces.shape[0] == 0 or faces.shape[-1] < 3:
        # No usable triangle: zero area for every batch entry.
        return jnp.zeros(vertices.shape[:-2])

    # corners has shape (..., n_faces, 3, dim).
    corners = vertices[..., faces[..., :3], :]
    edge_a = corners[..., 1, :] - corners[..., 0, :]
    edge_b = corners[..., 2, :] - corners[..., 0, :]

    # Half the cross-product norm is the triangle area. The norm is taken per
    # face (axis=-1) so that batch entries are never mixed together.
    areas = 0.5 * jnp.linalg.norm(jnp.cross(edge_a, edge_b), axis=-1)
    return jnp.sum(areas, axis=-1)
def _compute_nd_boundary_measure(vertices: HullVertices) -> Array:
    """Return a crude stand-in for the boundary measure of an n-D hull.

    The value is ``mean(|v - centroid|) * n_vertices * sqrt(dim)``, where the
    centroid is the mean of the vertices. This is a heuristic and not a true
    surface measure.

    Args:
        vertices: Hull vertices, ``(n_vertices, dim)`` on the last two axes.

    Returns:
        The heuristic boundary measure.
    """
    dim = vertices.shape[-1]
    count = vertices.shape[-2]

    # Distance of every vertex from the vertex mean.
    centroid = jnp.mean(vertices, axis=-2, keepdims=True)
    radii = jnp.linalg.norm(vertices - centroid, axis=-1)
    mean_radius = jnp.mean(radii)

    return mean_radius * count * jnp.sqrt(dim)
def _compute_hull_faces(vertices: HullVertices) -> Array:
    """Enumerate candidate faces from vertex indices alone.

    No hull algorithm is run; only the vertex count and dimension are used.

    Args:
        vertices: Hull vertices, ``(n_vertices, dim)`` on the last two axes.

    Returns:
        - 2D: index pairs ``[i, (i + 1) % n]`` joining consecutive vertices
        - 3D: every index triple ``i < j < k``, which is not guaranteed to be
          the true set of hull faces
        - otherwise: an empty array
    """
    dim = vertices.shape[-1]
    count = vertices.shape[-2]

    if dim == 2:
        return jnp.array([[start, (start + 1) % count] for start in range(count)])

    if dim == 3:
        return jnp.array(
            [
                [a, b, c]
                for a in range(count - 2)
                for b in range(a + 1, count - 1)
                for c in range(b + 1, count)
            ]
        )

    return jnp.array([])
[docs]
def distance_to_convex_hull(point: Array, hull_vertices: HullVertices) -> Array:
    """Approximate signed distance from a point to a convex hull.

    The magnitude is the distance to the nearest hull *vertex*, not to the
    nearest point on the hull surface, so it is an approximation. The sign
    comes from ``point_in_convex_hull`` called with its default settings.

    Args:
        point: Point with shape (..., dim)
        hull_vertices: Hull vertices with shape (..., n_vertices, dim)

    Returns:
        Nearest-vertex distance, negated when the point is reported inside.
    """
    inside = point_in_convex_hull(point, hull_vertices)

    # Broadcast the point against the vertex axis.
    offsets = hull_vertices - point[..., None, :]
    nearest = jnp.min(jnp.linalg.norm(offsets, axis=-1), axis=-1)

    return jnp.where(inside, -nearest, nearest)
[docs]
def hausdorff_distance(hull1_vertices: HullVertices, hull2_vertices: HullVertices) -> Array:
    """Compute a vertex-based Hausdorff distance between two hulls.

    For each direction, the largest ``|distance_to_convex_hull(v, other)|``
    over the vertices ``v`` of one hull is taken; the result is the larger of
    the two directed values.

    Args:
        hull1_vertices: First hull vertices; iterated along the first axis.
        hull2_vertices: Second hull vertices; iterated along the first axis.

    Returns:
        Maximum of the two directed distances.
    """

    def _directed(source, target):
        # Largest absolute distance from any source vertex to the target hull.
        gaps = jnp.array([jnp.abs(distance_to_convex_hull(vertex, target)) for vertex in source])
        return jnp.max(gaps)

    forward = _directed(hull1_vertices, hull2_vertices)
    backward = _directed(hull2_vertices, hull1_vertices)
    return jnp.maximum(forward, backward)
# =============================================================================
# PHASE 2: IMPROVED VOLUME COMPUTATION METHODS
# =============================================================================
def _volume_shoelace_formula(vertices: HullVertices) -> Array:
    """Compute 2D polygon area using shoelace formula.

    The shoelace formula: Area = 0.5 * |sum(x_i * y_{i+1} - x_{i+1} * y_i)|

    Vertices are first sorted by angle around their mean so that the input
    order does not matter for convex polygons. Leading batch axes are
    supported: each polygon is sorted and reduced on its own.

    Args:
        vertices: Polygon vertices with shape (..., n_vertices, 2).

    Returns:
        Polygon area with shape ``vertices.shape[:-2]``; zero when there are
        fewer than three vertices.

    Raises:
        ValueError: If the last axis is not of length 2.
    """
    if vertices.shape[-1] != 2:
        raise ValueError("Shoelace formula only works for 2D polygons")

    if vertices.shape[-2] < 3:
        return jnp.zeros(vertices.shape[:-2])

    # keepdims keeps the vertex axis so the subtraction broadcasts per polygon.
    centroid = jnp.mean(vertices, axis=-2, keepdims=True)
    centered = vertices - centroid
    angles = jnp.arctan2(centered[..., 1], centered[..., 0])

    # Reorder along the vertex axis only; plain fancy indexing would index the
    # leading (batch) axis instead.
    order = jnp.argsort(angles, axis=-1)
    ring = jnp.take_along_axis(vertices, order[..., None], axis=-2)

    x = ring[..., 0]
    y = ring[..., 1]
    cross_products = x * jnp.roll(y, -1, axis=-1) - jnp.roll(x, -1, axis=-1) * y
    return 0.5 * jnp.abs(jnp.sum(cross_products, axis=-1))
def _volume_multi_method_consensus(vertices: HullVertices) -> Array:
    """Compute volume using multiple methods and return consensus.

    - 2D: simplex decomposition and the shoelace formula are both evaluated.
      When they agree to within 10 % (relative to the simplex value) their
      average is returned, otherwise the shoelace value.
    - 3D: simplex decomposition, with Monte Carlo (1000 samples) as a
      best-effort fallback if it raises.
    - other dimensions: simplex decomposition.

    The 2D selection uses ``jnp.where`` rather than a Python ``if`` so that it
    also works on traced values under ``jax.jit``.

    Args:
        vertices: Hull vertices, ``(n_vertices, dim)`` on the last two axes.

    Returns:
        The consensus volume.
    """
    dim = vertices.shape[-1]

    if dim == 2:
        try:
            volume_simplex = _volume_simplex_decomposition(vertices)
            volume_shoelace = _volume_shoelace_formula(vertices)
            relative_diff = jnp.abs(volume_simplex - volume_shoelace) / jnp.maximum(volume_simplex, 1e-10)
            averaged = 0.5 * (volume_simplex + volume_shoelace)
            # A NaN difference compares False and therefore selects shoelace.
            return jnp.where(relative_diff < 0.1, averaged, volume_shoelace)
        except Exception:
            # Deliberate best-effort: any failure falls back to one estimator.
            return _volume_simplex_decomposition(vertices)

    if dim == 3:
        try:
            return _volume_simplex_decomposition(vertices)
        except Exception:
            # Deliberate best-effort fallback.
            return _volume_monte_carlo(vertices, n_samples=1000)

    return _volume_simplex_decomposition(vertices)
def _volume_determinant_method(vertices: HullVertices) -> Array:
    """Alternative volume computation using determinant method.

    For an exact simplex (``dim + 1`` vertices) the volume is
    ``|det(edges)| / dim!``, where the edges run from the first vertex to each
    of the others. Any dimension is supported, and leading batch axes are
    handled by the batched determinant. All other inputs are forwarded to
    ``_volume_simplex_decomposition``.

    Args:
        vertices: Vertices with shape (..., n_vertices, dim).

    Returns:
        Simplex volume, or the simplex-decomposition result for non-simplices.
    """
    import math  # local import keeps this block self-contained

    n_vertices, dim = vertices.shape[-2], vertices.shape[-1]

    if n_vertices == dim + 1:
        # Rows are v_i - v_0 for i >= 1, giving a square (dim, dim) matrix.
        edge_vectors = vertices[..., 1:, :] - vertices[..., :1, :]
        det = jnp.linalg.det(edge_vectors)
        # float() avoids integer overflow when dim! exceeds the integer range.
        return jnp.abs(det) / float(math.factorial(dim))

    # For non-simplex cases, fall back to simplex decomposition
    return _volume_simplex_decomposition(vertices)
def compute_volume_accuracy_metrics(
    vertices: HullVertices, exact_volume: float | None = None
) -> dict[str, float | dict[str, float | None] | bool]:
    """Compute accuracy metrics for volume computation methods.

    Runs the simplex, shoelace (2D only) and multi-method estimators, records
    ``None`` for any estimator that raises, and summarises how well the
    successful ones agree.

    Args:
        vertices: Hull vertices, ``(n_vertices, dim)`` on the last two axes.
        exact_volume: Known exact volume for comparison (if available).

    Returns:
        Dictionary with keys ``volumes``, ``mean_volume``, ``std_volume``,
        ``coefficient_of_variation`` and ``method_consistency``. When
        ``exact_volume`` is given it also holds ``exact_volume`` plus
        ``<method>_relative_error`` and ``<method>_accurate`` (error below
        5 %) for each successful method. Relative error is measured against
        ``abs(exact_volume)``; for a zero reference it is 0.0 on an exact
        match and infinity otherwise.
    """
    dim = vertices.shape[-1]

    def _try_volume(estimator):
        # Best-effort: a failing estimator is reported as None, not raised.
        try:
            return float(estimator(vertices))
        except Exception:
            return None

    volumes: dict[str, float | None] = {"simplex": _try_volume(_volume_simplex_decomposition)}
    if dim == 2:
        volumes["shoelace"] = _try_volume(_volume_shoelace_formula)
    volumes["multi_method"] = _try_volume(_volume_multi_method_consensus)

    # Consistency metrics over the estimators that produced a value.
    valid_volumes = [v for v in volumes.values() if v is not None]
    if len(valid_volumes) >= 2:
        mean_volume = float(jnp.mean(jnp.array(valid_volumes)))
        std_volume = float(jnp.std(jnp.array(valid_volumes)))
        coefficient_of_variation = std_volume / mean_volume if mean_volume > 0 else float("inf")
    else:
        mean_volume = valid_volumes[0] if valid_volumes else 0.0
        std_volume = 0.0
        coefficient_of_variation = 0.0

    metrics: dict[str, float | dict[str, float | None] | bool] = {
        "volumes": volumes,
        "mean_volume": float(mean_volume),
        "std_volume": float(std_volume),
        "coefficient_of_variation": float(coefficient_of_variation),
        "method_consistency": float(1.0 - coefficient_of_variation) if coefficient_of_variation < 1 else 0.0,
    }

    if exact_volume is not None:
        metrics["exact_volume"] = exact_volume
        # abs() keeps the error non-negative for a negative reference.
        reference = abs(exact_volume)
        for method, volume in volumes.items():
            if volume is None:
                continue
            absolute_error = abs(volume - exact_volume)
            if reference > 0:
                relative_error = absolute_error / reference
            else:
                # Zero reference: avoid dividing by zero.
                relative_error = 0.0 if absolute_error == 0 else float("inf")
            metrics[f"{method}_relative_error"] = float(relative_error)
            metrics[f"{method}_accurate"] = relative_error < 0.05  # 5% threshold

    return metrics
# =============================================================================
# PHASE 2: IMPROVED POINT CONTAINMENT METHODS
# =============================================================================
def _point_in_hull_halfspace(point: Array, hull_vertices: HullVertices, tolerance: float) -> Array:
    """Route a containment query to the test matching the hull dimension.

    Args:
        point: Query point.
        hull_vertices: Hull vertices, ``(n_vertices, dim)`` on the last two axes.
        tolerance: Numerical slack forwarded to the chosen test.

    Returns:
        Boolean array. With fewer than ``dim + 1`` vertices the point counts
        as inside only if it lies within ``tolerance`` of some vertex.
        Otherwise the 2D, 3D or generic barycentric test decides.
    """
    dim = hull_vertices.shape[-1]

    if hull_vertices.shape[-2] < dim + 1:
        # Not full-dimensional: fall back to proximity to a vertex.
        gaps = jnp.linalg.norm(hull_vertices - point, axis=-1)
        return jnp.min(gaps) <= tolerance

    if dim == 2:
        return _point_in_hull_2d_robust(point, hull_vertices, tolerance)
    if dim == 3:
        return _point_in_hull_3d_robust(point, hull_vertices, tolerance)
    return _point_in_hull_barycentric_robust(point, hull_vertices, tolerance)
def _point_in_hull_2d_robust(point: Array, hull_vertices: HullVertices, tolerance: float) -> Array:
    """Robust 2D point-in-polygon test using winding number.

    The vertices are sorted by angle around their mean to form a ring. The
    point is inside when it lies within ``tolerance`` of a non-degenerate ring
    edge, or when the angles subtended by the edges sum to more than pi in
    absolute value. All edges are processed with array operations and no
    Python branch depends on array values, so the function can be traced by
    ``jax.jit``.

    Args:
        point: Query point of shape ``(2,)``.
        hull_vertices: Polygon vertices of shape ``(n_vertices, 2)``.
        tolerance: Distance below which the point counts as on an edge.

    Returns:
        Boolean scalar array.
    """
    n_vertices = hull_vertices.shape[-2]

    if n_vertices < 3:
        # Degenerate case: only proximity to a vertex counts.
        distances = jnp.linalg.norm(hull_vertices - point, axis=-1)
        return jnp.min(distances) <= tolerance

    # Sort vertices by angle to obtain a consistent ring order.
    centroid = jnp.mean(hull_vertices, axis=-2)
    centered = hull_vertices - centroid
    order = jnp.argsort(jnp.arctan2(centered[:, 1], centered[:, 0]))
    ring = hull_vertices[order]

    # Edge endpoints expressed relative to the query point.
    start = ring - point
    end = jnp.roll(ring, -1, axis=0) - point
    edge = end - start

    # Distance from the point (now the origin) to each edge segment.
    edge_is_valid = jnp.linalg.norm(edge, axis=-1) > 1e-12
    edge_sq = jnp.sum(edge * edge, axis=-1)
    safe_sq = jnp.where(edge_is_valid, edge_sq, 1.0)  # avoid 0/0 on collapsed edges
    t = jnp.clip(jnp.sum(-start * edge, axis=-1) / safe_sq, 0.0, 1.0)
    closest = start + t[:, None] * edge
    edge_distance = jnp.linalg.norm(closest, axis=-1)
    on_boundary = jnp.any(edge_is_valid & (edge_distance <= tolerance))

    # Winding number: signed angle subtended by every edge, summed.
    cross = start[:, 0] * end[:, 1] - start[:, 1] * end[:, 0]
    dot = jnp.sum(start * end, axis=-1)
    winding_number = jnp.sum(jnp.arctan2(cross, dot))

    # An interior point yields about +/- 2*pi, an exterior point about 0.
    return jnp.logical_or(on_boundary, jnp.abs(winding_number) > jnp.pi)
def _point_in_hull_3d_robust(point: Array, hull_vertices: HullVertices, tolerance: float) -> Array:
    """Robust 3D point-in-hull test.

    Forms a tetrahedron from the vertex mean and every vertex triple
    ``i < j < k``, and reports the point as inside when any tetrahedron
    contains it (all barycentric coordinates >= ``-tolerance``). All
    tetrahedra are solved in one batched linear solve, so no Python branch
    depends on array values and the function can be traced by ``jax.jit``.
    The number of triples grows cubically with the vertex count.

    Args:
        point: Query point of shape ``(3,)``.
        hull_vertices: Hull vertices of shape ``(n_vertices, 3)``.
        tolerance: Slack on the barycentric coordinates.

    Returns:
        Boolean scalar array.
    """
    n_vertices = hull_vertices.shape[-2]

    if n_vertices < 4:
        # Degenerate 3D case: only proximity to a vertex counts.
        distances = jnp.linalg.norm(hull_vertices - point, axis=-1)
        return jnp.min(distances) <= tolerance

    centroid = jnp.mean(hull_vertices, axis=-2)

    triples = jnp.array(
        [
            (i, j, k)
            for i in range(n_vertices - 2)
            for j in range(i + 1, n_vertices - 1)
            for k in range(j + 1, n_vertices)
        ]
    )

    # Columns of each 3x3 matrix are the edges centroid -> v_i, v_j, v_k.
    edge_matrices = jnp.swapaxes(hull_vertices[triples] - centroid, -1, -2)
    rhs = jnp.broadcast_to(point - centroid, (triples.shape[0], 3))[..., None]
    lambdas = jnp.linalg.solve(edge_matrices, rhs)[..., 0]
    lambda_0 = 1.0 - jnp.sum(lambdas, axis=-1)

    # Singular (flat) tetrahedra give inf/NaN coordinates and are rejected.
    finite = jnp.isfinite(lambda_0) & jnp.all(jnp.isfinite(lambdas), axis=-1)
    non_negative = (lambda_0 >= -tolerance) & jnp.all(lambdas >= -tolerance, axis=-1)
    return jnp.any(finite & non_negative)
def _point_in_tetrahedron(point: Array, tetrahedron_vertices: Array, tolerance: float) -> Array:
    """Test if point is inside tetrahedron using barycentric coordinates.

    Solves ``point = l0*v0 + l1*v1 + l2*v2 + l3*v3`` with ``sum(l_i) = 1`` and
    accepts the point when every coordinate is finite and at least
    ``-tolerance``. ``jnp.linalg.solve`` signals a singular system through
    non-finite values rather than an exception, so degeneracy is detected
    with an explicit finiteness check.

    Args:
        point: Query point of shape ``(3,)``.
        tetrahedron_vertices: Vertices of shape ``(4, 3)``.
        tolerance: Slack on the barycentric coordinates.

    Returns:
        Boolean scalar array.
    """
    v0 = tetrahedron_vertices[0]
    edge_matrix = tetrahedron_vertices[1:] - v0  # rows are v_i - v0
    point_vec = point - v0

    # Transposed so that the columns are the edges.
    lambdas_123 = jnp.linalg.solve(edge_matrix.T, point_vec)
    lambda_0 = 1.0 - jnp.sum(lambdas_123)
    all_lambdas = jnp.concatenate([lambda_0[None], lambdas_123])

    well_defined = jnp.all(jnp.isfinite(all_lambdas))
    non_negative = jnp.all(all_lambdas >= -tolerance)
    return jnp.logical_and(well_defined, non_negative)
def _point_in_hull_barycentric_robust(point: Array, hull_vertices: HullVertices, tolerance: float) -> Array:
    """Pick a barycentric containment test based on the vertex count.

    Args:
        point: Query point.
        hull_vertices: Hull vertices, ``(n_vertices, dim)`` on the last two axes.
        tolerance: Numerical slack forwarded to the chosen test.

    Returns:
        Boolean array: the exact simplex test for ``dim + 1`` vertices, a
        vertex-proximity test for fewer, and the simplex-decomposition test
        for more.
    """
    simplex_size = hull_vertices.shape[-1] + 1
    n_vertices = hull_vertices.shape[-2]

    if n_vertices > simplex_size:
        return _point_in_hull_simplex_decomposition(point, hull_vertices, tolerance)

    if n_vertices < simplex_size:
        # Too few vertices to span the space: inside only if near a vertex.
        gaps = jnp.linalg.norm(hull_vertices - point, axis=-1)
        return jnp.min(gaps) <= tolerance

    return _point_in_simplex_exact(point, hull_vertices, tolerance)
def _point_in_simplex_exact(point: Array, simplex_vertices: Array, tolerance: float) -> Array:
    """Exact point-in-simplex test using barycentric coordinates.

    Solves for the barycentric coordinates of ``point`` and accepts it when
    all of them are at least ``-tolerance``. If the linear system is singular
    (degenerate simplex) the solve yields non-finite values instead of
    raising; in that case the result falls back to whether the point lies
    within ``tolerance`` of one of the vertices.

    Args:
        point: Query point of shape ``(dim,)``.
        simplex_vertices: Vertices of shape ``(dim + 1, dim)``.
        tolerance: Slack on the coordinates and on the fallback distance.

    Returns:
        Boolean scalar array.

    Raises:
        ValueError: If the vertex count is not ``dim + 1``.
    """
    n_vertices, dim = simplex_vertices.shape[-2], simplex_vertices.shape[-1]
    if n_vertices != dim + 1:
        raise ValueError(f"Simplex in {dim}D should have {dim + 1} vertices, got {n_vertices}")

    v0 = simplex_vertices[0]
    edge_matrix = simplex_vertices[1:] - v0  # rows are v_i - v0
    point_vec = point - v0

    lambdas_rest = jnp.linalg.solve(edge_matrix.T, point_vec)
    lambda_0 = 1.0 - jnp.sum(lambdas_rest)
    all_lambdas = jnp.concatenate([lambda_0[None], lambdas_rest])
    inside_simplex = jnp.all(all_lambdas >= -tolerance)

    # Degenerate-simplex fallback: proximity to any vertex.
    vertex_distances = jnp.linalg.norm(simplex_vertices - point, axis=-1)
    near_vertex = jnp.min(vertex_distances) <= tolerance

    # Selected with where() so the choice also works on traced values.
    solvable = jnp.all(jnp.isfinite(all_lambdas))
    return jnp.where(solvable, inside_simplex, near_vertex)
def _point_in_hull_simplex_decomposition(point: Array, hull_vertices: HullVertices, tolerance: float) -> Array:
    """Test point containment by decomposing hull into simplices.

    For every subset of ``dim`` vertices a simplex is formed together with the
    vertex mean. The point is inside when any such simplex contains it (all
    barycentric coordinates finite and >= ``-tolerance``). This works in any
    dimension and does not depend on the vertex order. All simplices are
    solved in one batched linear solve, so the function can be traced by
    ``jax.jit``.

    Note:
        The number of subsets is ``C(n_vertices, dim)``, which grows quickly
        with the vertex count and dimension.

    Args:
        point: Query point of shape ``(dim,)``.
        hull_vertices: Hull vertices of shape ``(n_vertices, dim)``.
        tolerance: Slack on the barycentric coordinates.

    Returns:
        Boolean scalar array.
    """
    import itertools  # local import keeps this block self-contained

    n_vertices, dim = hull_vertices.shape[-2], hull_vertices.shape[-1]

    subsets = list(itertools.combinations(range(n_vertices), dim)) if dim >= 1 else []
    if not subsets:
        # Fewer than dim vertices (or dim == 0): no simplex can be formed.
        return jnp.array(False)

    centroid = jnp.mean(hull_vertices, axis=-2)
    index = jnp.array(subsets)  # (n_subsets, dim)

    # Columns of each (dim, dim) matrix are the edges centroid -> vertex.
    edge_matrices = jnp.swapaxes(hull_vertices[index] - centroid, -1, -2)
    rhs = jnp.broadcast_to(point - centroid, (len(subsets), dim))[..., None]
    lambdas = jnp.linalg.solve(edge_matrices, rhs)[..., 0]
    lambda_0 = 1.0 - jnp.sum(lambdas, axis=-1)

    # Singular simplices give inf/NaN coordinates and are rejected.
    finite = jnp.isfinite(lambda_0) & jnp.all(jnp.isfinite(lambdas), axis=-1)
    non_negative = (lambda_0 >= -tolerance) & jnp.all(lambdas >= -tolerance, axis=-1)
    return jnp.any(finite & non_negative)
# JIT-compiled entry points. ``method`` is declared static because the wrapped
# functions compare it against string literals to choose a Python-level branch.
point_in_convex_hull_jit = jax.jit(point_in_convex_hull, static_argnames=("method",))
convex_hull_volume_jit = jax.jit(convex_hull_volume, static_argnames=("method",))
convex_hull_surface_area_jit = jax.jit(convex_hull_surface_area)
distance_to_convex_hull_jit = jax.jit(distance_to_convex_hull)
hausdorff_distance_jit = jax.jit(hausdorff_distance)