# Source code for polytopax.core.utils

"""Core utility functions for PolytopAX."""

import math
import warnings
from typing import Literal, TypeAlias

import jax
import jax.numpy as jnp
from jax import Array

# Type aliases
PointCloud: TypeAlias = Array  # shape: (..., n_points, dimension)
HullVertices: TypeAlias = Array  # shape: (n_vertices, dimension)
DirectionVectors: TypeAlias = Array  # shape: (n_directions, dimension)
SamplingMethod = Literal["uniform", "icosphere", "adaptive"]


def validate_point_cloud(points: Array) -> Array:
    """Validate point cloud shape and numerical validity.

    Structural checks (type, rank, non-empty point and coordinate axes) always
    run. The NaN/inf check needs concrete values, so it is skipped when
    ``points`` is a tracer inside a JAX transformation such as ``jax.jit`` or
    ``jax.vmap``.

    Args:
        points: Input point cloud with shape (..., n_points, dim).

    Returns:
        ``points`` itself, unchanged.

    Raises:
        TypeError: ``points`` is not a JAX Array.
        ValueError: Invalid shape, or (for concrete arrays) NaN or infinite
            values. NaN is reported in preference to infinity when both occur.
    """
    if not isinstance(points, Array):
        raise TypeError(f"Expected JAX Array, got {type(points)}")
    if points.ndim < 2:
        raise ValueError(f"Point cloud must have at least 2 dimensions, got {points.ndim}")
    if points.shape[-1] < 1:
        raise ValueError(f"Point dimension must be at least 1, got {points.shape[-1]}")
    if points.shape[-2] < 1:
        raise ValueError(f"Must have at least 1 point, got {points.shape[-2]}")

    # One isfinite pass covers both NaN and inf on the common (clean) path.
    # Only the bool conversion is guarded: it is the one step that fails for
    # traced arrays (TracerBoolConversionError subclasses ConcretizationTypeError).
    try:
        all_finite = bool(jnp.all(jnp.isfinite(points)))
    except jax.errors.ConcretizationTypeError:
        # Inside jit/vmap the values are abstract; skip numerical validation.
        return points

    if not all_finite:
        # Second pass only on the failure path, to choose the error message.
        if bool(jnp.any(jnp.isnan(points))):
            raise ValueError("Point cloud contains NaN values")
        raise ValueError("Point cloud contains infinite values")
    return points
[docs] def generate_direction_vectors( dimension: int, n_directions: int, method: SamplingMethod = "uniform", random_key: Array | None = None ) -> DirectionVectors: """Generate direction vectors for sampling. Args: dimension: Spatial dimension n_directions: Number of directions to generate method: Sampling strategy - "uniform": Uniform distribution on sphere - "icosphere": Icosahedral subdivision (3D only) - "adaptive": Locally adaptive density sampling random_key: JAX random key (required for "uniform" and "adaptive") Returns: Normalized direction vector set with shape (n_directions, dimension) Raises: ValueError: Invalid parameters or unsupported combinations """ if dimension < 1: raise ValueError(f"Dimension must be at least 1, got {dimension}") if n_directions < 1: raise ValueError(f"Number of directions must be at least 1, got {n_directions}") if method == "uniform": if random_key is None: random_key = jax.random.PRNGKey(0) # Generate random vectors from standard normal distribution directions = jax.random.normal(random_key, (n_directions, dimension)) # Normalize to unit sphere norms = jnp.linalg.norm(directions, axis=1, keepdims=True) # Avoid division by zero norms = jnp.where(norms < 1e-12, 1.0, norms) directions = directions / norms return directions # type: ignore[no-any-return] elif method == "icosphere": if dimension != 3: raise ValueError("Icosphere method is only supported for 3D (dimension=3)") return _generate_icosphere_directions(n_directions) elif method == "adaptive": if random_key is None: random_key = jax.random.PRNGKey(0) # For now, use uniform sampling as placeholder # TODO: Implement proper adaptive sampling in future versions warnings.warn( "Adaptive sampling not yet implemented, falling back to uniform sampling", UserWarning, stacklevel=2 ) return generate_direction_vectors(dimension, n_directions, "uniform", random_key) else: raise ValueError(f"Unknown sampling method: {method}")
def _generate_icosphere_directions(n_directions: int) -> DirectionVectors: """Generate direction vectors using icosahedral subdivision. Args: n_directions: Target number of directions Returns: Direction vectors approximating uniform distribution on sphere """ # Base icosahedron vertices (12 vertices) phi = (1.0 + jnp.sqrt(5.0)) / 2.0 # Golden ratio # Icosahedron vertices vertices = jnp.array( [ [-1, phi, 0], [1, phi, 0], [-1, -phi, 0], [1, -phi, 0], [0, -1, phi], [0, 1, phi], [0, -1, -phi], [0, 1, -phi], [phi, 0, -1], [phi, 0, 1], [-phi, 0, -1], [-phi, 0, 1], ], dtype=jnp.float32, ) # Normalize vertices vertices = vertices / jnp.linalg.norm(vertices, axis=1, keepdims=True) # If we need more directions, we can add face centers and edge midpoints directions = vertices if n_directions > 12: # Add face centers (20 faces for icosahedron) # For simplicity, we'll just repeat and perturb vertices # TODO: Implement proper subdivision in future versions n_extra = n_directions - 12 if n_extra > 0: # Generate additional directions by slight perturbations key = jax.random.PRNGKey(42) perturbations = jax.random.normal(key, (n_extra, 3)) * 0.1 # Repeat vertices cyclically to match the number of extra directions needed extra_base = jnp.tile(vertices, (n_extra // 12 + 1, 1))[:n_extra] extra_vertices = extra_base + perturbations extra_vertices = extra_vertices / jnp.linalg.norm(extra_vertices, axis=1, keepdims=True) directions = jnp.concatenate([directions, extra_vertices], axis=0) # Truncate to exact number if we have too many directions = directions[:n_directions] return directions # type: ignore[no-any-return]
def robust_orientation_test(points: Array, tolerance: float = 1e-12) -> Array:
    """Check whether the leading points of each point set form a non-degenerate simplex.

    For point sets with ``dim`` coordinates, the first ``dim + 1`` points are
    taken as simplex vertices ``p0 .. p_dim``. The determinant of the edge
    matrix ``[p1 - p0, ..., p_dim - p0]`` is computed. Its sign would give the
    orientation; this function reports only whether its magnitude exceeds
    ``tolerance``, i.e. whether the points are not (near-)collinear,
    coplanar, etc.

    Inspired by Shewchuk (1997) robust predicates, but this is a plain
    floating-point determinant without adaptive-precision arithmetic.

    Args:
        points: Points with shape (..., n_points, dim). Only the first
            ``dim + 1`` points along the point axis are used.
        tolerance: Absolute threshold on ``|det|``. The determinant scales with
            the ``dim``-th power of the coordinate magnitudes, so choose the
            threshold relative to the data's scale.

    Returns:
        Boolean array of shape ``points.shape[:-2]``. It is True where the
        simplex is non-degenerate. It is False where it is degenerate, or
        everywhere when fewer than ``dim + 1`` points are given.

    Raises:
        TypeError, ValueError: Propagated from ``validate_point_cloud``.
    """
    validate_point_cloud(points)

    n_points = points.shape[-2]
    dim = points.shape[-1]

    # Fewer than dim + 1 points cannot span a full-dimensional simplex.
    if n_points < dim + 1:
        return jnp.zeros(points.shape[:-2], dtype=bool)

    if dim == 2:
        # Closed-form 2x2 determinant (2D cross product of the edge vectors).
        p0, p1, p2 = points[..., 0, :], points[..., 1, :], points[..., 2, :]
        det = (p1[..., 0] - p0[..., 0]) * (p2[..., 1] - p0[..., 1]) - (
            p1[..., 1] - p0[..., 1]
        ) * (p2[..., 0] - p0[..., 0])
    else:
        # Any other dimension (1D, 3D and higher): the rows are the edge
        # vectors p_i - p0, so 1D and >3D inputs are actually tested instead
        # of being assumed non-degenerate.
        edges = points[..., 1 : dim + 1, :] - points[..., :1, :]
        det = jnp.linalg.det(edges)

    return jnp.abs(det) > tolerance
def compute_simplex_volume(vertices: Array) -> Array:
    """Compute the volume of the simplex defined by ``vertices``.

    Uses ``|det(v1 - v0, ..., vd - v0)| / d!`` for a ``d``-dimensional simplex.

    Args:
        vertices: Simplex vertices with shape (..., dim + 1, dim).

    Returns:
        Non-negative volume with shape ``vertices.shape[:-2]``.

    Raises:
        ValueError: The number of vertices is not ``dim + 1``, or the input
            fails ``validate_point_cloud``.
        TypeError: Propagated from ``validate_point_cloud``.
    """
    # Validation also guarantees dim >= 1, so no 0-dimensional special case.
    validate_point_cloud(vertices)

    n_vertices = vertices.shape[-2]
    dim = vertices.shape[-1]
    if n_vertices != dim + 1:
        raise ValueError(f"Expected {dim + 1} vertices for {dim}D simplex, got {n_vertices}")

    v0 = vertices[..., 0, :]
    edge_vectors = vertices[..., 1:, :] - v0[..., None, :]
    det = jnp.linalg.det(edge_vectors)

    # math.factorial works for any dimension (the previous lookup table
    # stopped at 9! and raised IndexError beyond). float() keeps JAX from
    # treating a large Python int as an int32.
    return jnp.abs(det) / float(math.factorial(dim))
def project_to_simplex(point: Array, vertices: Array) -> tuple[Array, Array]:
    """Approximate the projection of ``point`` onto a simplex by its nearest vertex.

    NOTE: placeholder. This is not the Euclidean projection onto the convex
    hull of ``vertices``. It returns the vertex closest to ``point`` (Euclidean
    distance; ties go to the first minimum, as in ``jnp.argmin``), together
    with one-hot barycentric coordinates selecting that vertex.
    TODO: Implement proper simplex projection algorithm.

    Args:
        point: Point(s) with shape (..., dim).
        vertices: Simplex vertices with shape (..., n_vertices, dim). Leading
            batch dimensions broadcast against those of ``point``.

    Returns:
        Tuple ``(projected_point, barycentric)``:
        - ``projected_point`` has shape (..., dim);
        - ``barycentric`` has shape (..., n_vertices).

    Raises:
        TypeError, ValueError: Propagated from ``validate_point_cloud``.
    """
    validate_point_cloud(vertices)

    n_vertices = vertices.shape[-2]
    distances = jnp.linalg.norm(vertices - point[..., None, :], axis=-1)
    closest_idx = jnp.argmin(distances, axis=-1)

    # one_hot selects one vertex per batch element. The previous
    # `.at[..., closest_idx].set(1.0)` scattered every batch element's index
    # into every row when inputs were batched.
    barycentric = jax.nn.one_hot(closest_idx, n_vertices)
    projected_point = jnp.sum(barycentric[..., :, None] * vertices, axis=-2)
    return projected_point, barycentric
def remove_duplicate_points(points: Array, tolerance: float = 1e-10) -> tuple[Array, Array]:
    """Placeholder for near-duplicate removal; currently returns the input unchanged.

    NOTE: no deduplication is performed yet. ``tolerance`` is accepted so the
    signature stays stable, but it is not used.

    Args:
        points: Point cloud with shape (..., n_points, dim).
        tolerance: Intended distance below which points count as duplicates
            (currently ignored).

    Returns:
        Tuple ``(points, jnp.arange(n_points))``: the original points and the
        identity index array over the point axis.

    Raises:
        TypeError, ValueError: Propagated from ``validate_point_cloud``.
    """
    validate_point_cloud(points)

    # TODO: implement real deduplication. Boolean-mask indexing yields
    # data-dependent output shapes, which jax.jit cannot trace. A jit-friendly
    # version would need a fixed-size output (e.g. padding plus a validity
    # mask) or sorting-based grouping.
    n_points = points.shape[-2]
    unique_indices = jnp.arange(n_points)
    return points, unique_indices
def scale_to_unit_ball(points: Array) -> tuple[Array, tuple[Array, float | Array]]:
    """Center a point cloud on its centroid and scale it into the unit ball.

    Args:
        points: Point cloud with shape (..., n_points, dim).

    Returns:
        Tuple ``(scaled_points, (center, scale_factor))`` where:
        - ``center`` is the centroid with shape (..., dim);
        - ``scale_factor`` is the largest distance from the centroid, or 1.0
          when that distance is <= 1e-12, to avoid dividing a degenerate cloud
          by ~0. It is a Python float for concrete input. Inside ``jax.jit`` /
          ``jax.vmap``, where it cannot be concretised, it is a scalar JAX array.

        For batched input the maximum is taken over all batch elements, so a
        single scale factor is shared by the whole batch.

    Raises:
        TypeError, ValueError: Propagated from ``validate_point_cloud``.
    """
    validate_point_cloud(points)

    center = jnp.mean(points, axis=-2, keepdims=True)
    centered_points = points - center
    # No axis argument: reduces over every axis, including batch axes.
    max_distance = jnp.max(jnp.linalg.norm(centered_points, axis=-1))
    scale_factor = jnp.where(max_distance > 1e-12, max_distance, 1.0)
    scaled_points = centered_points / scale_factor

    # float() on a traced value raises, which made this function unusable
    # under jit. Keep the array in that case.
    try:
        scale_out: float | Array = float(scale_factor)
    except jax.errors.ConcretizationTypeError:
        scale_out = scale_factor

    return scaled_points, (center.squeeze(-2), scale_out)
[docs] def unscale_from_unit_ball(points: Array, transform_params: tuple[Array, float]) -> Array: """Reverse scaling from unit ball. Args: points: Scaled point cloud transform_params: (center, scale_factor) from scale_to_unit_ball Returns: Original scale point cloud """ center, scale_factor = transform_params return points * scale_factor + center