
Problem statement

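In an n-dimensional space, given a 2D subspace spanned by two basis vectors $u$ and $v$, and a target direction (unit vector) $t$ in that subspace, how do we rotate an arbitrary vector $x$ within the subspace such that the projection of $x$ onto the subspace is aligned with $t$?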


Naive approach

A naive approach would be to follow "Rotate any vector to a target angle in high dimensional space", which involves the following steps:

  1. Project $x$ onto the 2D subspace to get its projection $Px$
  2. Find the angle between $Px$ and the target direction $t$
  3. Calculate the rotation matrix using
    Transclude of Rotate-from-one-vector-to-another-vector-in-high-dimensional-space#^e160cf
  4. Apply the rotation matrix to $x$
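
The four steps above can be sketched as follows. This is only a minimal illustration under stated assumptions: `rotate_naive` and its arguments are hypothetical names, `u` and `v` are assumed to be orthonormal, `target` is assumed to be a unit vector inside the subspace, and the rotation matrix uses the standard plane-rotation form (identity outside the subspace, 2D rotation inside it), which may differ in presentation from the transcluded equation.

```python
import numpy as np

def rotate_naive(x, target, u, v):
    """Naive sketch: build a full n x n plane rotation and apply it to x.

    Assumes u, v are orthonormal 1-D arrays spanning the plane and target is
    a unit vector lying in that plane. All names here are hypothetical.
    """
    n = u.shape[0]
    uv = np.column_stack([u, v])  # n x 2 basis matrix
    # Step 1: project x (and the target) onto the plane, in (u, v) coordinates.
    px = uv.T @ x
    tc = uv.T @ target
    # Step 2: signed in-plane angle from the projection to the target.
    phi = np.arctan2(tc[1], tc[0]) - np.arctan2(px[1], px[0])
    # Step 3: rotation by phi inside span(u, v), identity on the complement.
    R2 = np.array([[np.cos(phi), -np.sin(phi)],
                   [np.sin(phi),  np.cos(phi)]])
    R = np.eye(n) - uv @ uv.T + uv @ R2 @ uv.T
    # Step 4: apply the rotation to x.
    return R @ x
```

The sketch forms an explicit n x n matrix for every input vector, which is the cost the rest of this note tries to avoid.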

Each of these steps involves one matrix multiplication, except for step 3 (if using eq. (2) with precomputed values), resulting in 3 matrix multiplications in total. This is inefficient, and we can do better.

Rotation with only 1 matrix multiplication

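Looking at eq. (1) from above, the term $\begin{bmatrix} u & v \end{bmatrix} R_\theta \begin{bmatrix} u & v \end{bmatrix}^\top x$ computes the to-be-rotated component of $x$ by

  • map $x$ down to $\mathbb{R}^2$ with $\begin{bmatrix} u & v \end{bmatrix}^\top$
  • do the rotation by $R_\theta$
  • then map back up with $\begin{bmatrix} u & v \end{bmatrix}$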

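Since this transformation preserves the norm, we can instead precompute it for the unit vector $\begin{bmatrix} 1 & 0 \end{bmatrix}^\top$, then scale the result by $\lVert Px \rVert$, the norm of the projection of $x$ onto the subspace.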
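
As a worked step, assuming $u$ and $v$ are the orthonormal basis vectors of the subspace and $R_\theta$ is the counter-clockwise 2D rotation by $\theta$ (the same convention as the code at the end of this note), the precomputed vector expands to:

```latex
\begin{bmatrix} u & v \end{bmatrix}
\begin{bmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{bmatrix}
\begin{bmatrix} 1 \\ 0 \end{bmatrix}
= \cos\theta \, u + \sin\theta \, v,
\qquad
\lVert \cos\theta \, u + \sin\theta \, v \rVert = 1 .
```

The result is a unit vector because $u$ and $v$ are orthonormal, so multiplying it by a scalar sets the norm of the in-plane component directly.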

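Thus, the transformation becomes:

$$
x' = x - Px + \lVert Px \rVert \, \begin{bmatrix} u & v \end{bmatrix} R_\theta \begin{bmatrix} 1 \\ 0 \end{bmatrix}, \qquad P = u u^\top + v v^\top
$$

with $\theta$ being the angle between the target direction and the unit vector $u$.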

As a result, only one matrix multiplication (the projection $Px$) is needed to compute the rotated vector; the other multiplications can be precomputed.


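import numpy as np


def rotate_to_target(x, target_degree, basis1, basis2):
    """Rotate x so that its projection onto span(basis1, basis2) points at
    target_degree (counter-clockwise from basis1), keeping the norm of x."""
    assert len(basis1.shape) == 1
    assert len(basis2.shape) == 1
    assert basis1.shape == basis2.shape
    # ensure bases are orthonormal (Gram-Schmidt)
    u = basis1 / np.linalg.norm(basis1)
    v = basis2 - (basis2 @ u) * u
    v /= np.linalg.norm(v)
    theta = np.deg2rad(target_degree)
    cos_theta = np.cos(theta)
    sin_theta = np.sin(theta)
    # projection matrix onto the 2D subspace
    P = np.outer(u, u) + np.outer(v, v)
    # rotate counter-clockwise
    R_theta = np.array([
        [cos_theta, -sin_theta],
        [sin_theta, cos_theta],
    ])
    uv = np.column_stack([u, v])
    # unit vector at angle theta from u inside the subspace (precomputable)
    rotated_component = uv @ R_theta @ np.array([1, 0])
    # P is symmetric, so x @ P equals P @ x and also handles a batch of rows
    Px = x @ P
    scale = np.linalg.norm(Px, axis=-1, keepdims=True)
    # keep the out-of-plane part, replace the in-plane part
    result = x - Px + scale * rotated_component
    return result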
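
The function above rebuilds the projection matrix and the rotated unit vector on every call. One possible way to hoist that work out, so that each call costs a single matrix multiplication, is sketched below. `make_rotator` is a hypothetical helper name introduced only for this illustration, and it follows the same conventions as the function above (angle in degrees, measured counter-clockwise from the first basis vector).

```python
import numpy as np

def make_rotator(target_degree, basis1, basis2):
    """Precompute P and the unit target vector; return a function of x.

    Hypothetical helper: a sketch of the precomputation idea, not a
    definitive implementation.
    """
    # Orthonormalize the bases (Gram-Schmidt).
    u = basis1 / np.linalg.norm(basis1)
    v = basis2 - (basis2 @ u) * u
    v = v / np.linalg.norm(v)
    theta = np.deg2rad(target_degree)
    P = np.outer(u, u) + np.outer(v, v)
    # Equal to [u v] @ R_theta @ [1, 0]: a unit vector at angle theta from u.
    unit_target = np.cos(theta) * u + np.sin(theta) * v

    def rotate(x):
        Px = x @ P  # the only matrix multiplication per call
        scale = np.linalg.norm(Px, axis=-1, keepdims=True)
        return x - Px + scale * unit_target

    return rotate

# Usage: build once, then apply to a batch of row vectors.
rng = np.random.default_rng(0)
rotate = make_rotator(90.0, rng.normal(size=8), rng.normal(size=8))
batch = rng.normal(size=(4, 8))
out = rotate(batch)
```

Because the out-of-plane part `x - Px` is left untouched and the in-plane part keeps its length, the norm of each row of `out` should match the norm of the corresponding row of `batch`.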