/*
 ==============================================================================
 This file is part of the IEM plug-in suite.
 Authors: Felix Holzmüller, based on the SceneRotator code by Daniel Rudrich
 Copyright (c) 2024 - Institute of Electronic Music and Acoustics (IEM)
 https://iem.at

 The IEM plug-in suite is free software: you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
 the Free Software Foundation, either version 3 of the License, or
 (at your option) any later version.

 The IEM plug-in suite is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 GNU General Public License for more details.

 You should have received a copy of the GNU General Public License
 along with this software.  If not, see <https://www.gnu.org/licenses/>.
 ==============================================================================
 */
#include <JuceHeader.h>

#include "Conversions.h"
#include "Quaternion.h"
#include "ambisonicTools.h"

template <int maxOrder = 7>
class AmbisonicRotator
{
public:
    AmbisonicRotator()
    {
        orderMatrices.add (new juce::dsp::Matrix<float> (0, 0)); // 0th
        orderMatricesCopy.add (new juce::dsp::Matrix<float> (0, 0)); // 0th

        for (int l = 1; l <= maxOrder; ++l)
        {
            const int nCh = (2 * l + 1);
            auto elem = orderMatrices.add (new juce::dsp::Matrix<float> (nCh, nCh));
            elem->clear();
            auto elemCopy = orderMatricesCopy.add (new juce::dsp::Matrix<float> (nCh, nCh));
            elemCopy->clear();
        }

        copyBuffer.setSize (maxChannels, maxChannels);
    };

    ~AmbisonicRotator() {};

    void process (juce::AudioBuffer<float>* bufferToRotate)
    {
        // Get samples per block and actual number of channels
        const int samples = bufferToRotate->getNumSamples();
        const int bufferChannels = bufferToRotate->getNumChannels();

        const int workingOrder =
            juce::jmin (isqrt (bufferToRotate->getNumChannels()) - 1, orderSetting, maxOrder);

        const int actualChannels = squares[workingOrder + 1];

        // Resize copyBuffer if necessary
        if ((copyBuffer.getNumChannels() != actualChannels)
            || (copyBuffer.getNumSamples() != samples))
            copyBuffer.setSize (actualChannels, samples);

        // Calculate new rotation matrix if necessary
        if (workingOrder != orderSetting)
        {
            orderSetting = workingOrder;
            rotationParamsHaveChanged = true;
        }

        bool newRotationMatrix = false;
        if (rotationParamsHaveChanged.get())
        {
            newRotationMatrix = true;
            calcRotationMatrix (workingOrder);
        }

        // make copy of input
        for (int ch = 0; ch < actualChannels; ++ch)
            copyBuffer.copyFrom (ch, 0, bufferToRotate->getReadPointer (ch), samples);

        // clear all channels except first
        for (int ch = 1; ch < bufferToRotate->getNumChannels(); ++ch)
            bufferToRotate->clear (ch, 0, samples);

        // rotate buffer
        for (int l = 1; l <= workingOrder; ++l)
        {
            const int offset = l * l;
            const int nCh = 2 * l + 1;
            auto R = orderMatrices[l];
            auto Rcopy = orderMatricesCopy[l];
            for (int o = 0; o < nCh; ++o)
            {
                const int chOut = offset + o;
                for (int p = 0; p < nCh; ++p)
                {
                    bufferToRotate->addFromWithRamp (chOut,
                                                     0,
                                                     copyBuffer.getReadPointer (offset + p),
                                                     samples,
                                                     Rcopy->operator() (o, p),
                                                     R->operator() (o, p));
                }
            }
        }

        // make copies for fading between old and new matrices
        if (newRotationMatrix)
            for (int l = 1; l <= workingOrder; ++l)
                *orderMatricesCopy[l] = *orderMatrices[l];
    }

    void updateParams (float yaw, float pitch, float roll, float order)
    {
        yawRadians = Conversions<float>::degreesToRadians (yaw);
        pitchRadians = Conversions<float>::degreesToRadians (pitch);
        rollRadians = Conversions<float>::degreesToRadians (roll);
        orderSetting = order;

        rotationParamsHaveChanged = true;
    }

    const int getOrder() { return orderSetting; }

private:
    double
        P (int i, int l, int a, int b, juce::dsp::Matrix<float>& R1, juce::dsp::Matrix<float>& Rlm1)
    {
        double ri1 = R1 (i + 1, 2);
        double rim1 = R1 (i + 1, 0);
        double ri0 = R1 (i + 1, 1);

        if (b == -l)
            return ri1 * Rlm1 (a + l - 1, 0) + rim1 * Rlm1 (a + l - 1, 2 * l - 2);
        else if (b == l)
            return ri1 * Rlm1 (a + l - 1, 2 * l - 2) - rim1 * Rlm1 (a + l - 1, 0);
        else
            return ri0 * Rlm1 (a + l - 1, b + l - 1);
    };

    double U (int l, int m, int n, juce::dsp::Matrix<float>& Rone, juce::dsp::Matrix<float>& Rlm1)
    {
        return P (0, l, m, n, Rone, Rlm1);
    }

    double V (int l, int m, int n, juce::dsp::Matrix<float>& Rone, juce::dsp::Matrix<float>& Rlm1)
    {
        if (m == 0)
        {
            auto p0 = P (1, l, 1, n, Rone, Rlm1);
            auto p1 = P (-1, l, -1, n, Rone, Rlm1);
            return p0 + p1;
        }
        else if (m > 0)
        {
            auto p0 = P (1, l, m - 1, n, Rone, Rlm1);
            if (m == 1) // d = 1;
                return p0 * sqrt (2);
            else // d = 0;
                return p0 - P (-1, l, 1 - m, n, Rone, Rlm1);
        }
        else
        {
            auto p1 = P (-1, l, -m - 1, n, Rone, Rlm1);
            if (m == -1) // d = 1;
                return p1 * sqrt (2);
            else // d = 0;
                return p1 + P (1, l, m + 1, n, Rone, Rlm1);
        }
    }

    double W (int l, int m, int n, juce::dsp::Matrix<float>& Rone, juce::dsp::Matrix<float>& Rlm1)
    {
        if (m > 0)
        {
            auto p0 = P (1, l, m + 1, n, Rone, Rlm1);
            auto p1 = P (-1, l, -m - 1, n, Rone, Rlm1);
            return p0 + p1;
        }
        else if (m < 0)
        {
            auto p0 = P (1, l, m - 1, n, Rone, Rlm1);
            auto p1 = P (-1, l, 1 - m, n, Rone, Rlm1);
            return p0 - p1;
        }

        return 0.0;
    }

    void calcRotationMatrix (const int order)
    {
        auto ca = std::cos (yawRadians);
        auto cb = std::cos (pitchRadians);
        auto cy = std::cos (rollRadians);

        auto sa = std::sin (yawRadians);
        auto sb = std::sin (pitchRadians);
        auto sy = std::sin (rollRadians);

        juce::dsp::Matrix<float> rotMat (3, 3);

        // if (*rotationSequence >= 0.5f) // roll -> pitch -> yaw (extrinsic rotations)
        // {
        //     rotMat (0, 0) = ca * cb;
        //     rotMat (1, 0) = sa * cb;
        //     rotMat (2, 0) = -sb;

        //     rotMat (0, 1) = ca * sb * sy - sa * cy;
        //     rotMat (1, 1) = sa * sb * sy + ca * cy;
        //     rotMat (2, 1) = cb * sy;

        //     rotMat (0, 2) = ca * sb * cy + sa * sy;
        //     rotMat (1, 2) = sa * sb * cy - ca * sy;
        //     rotMat (2, 2) = cb * cy;
        // }
        // else // yaw -> pitch -> roll (extrinsic rotations)
        // {
        rotMat (0, 0) = ca * cb;
        rotMat (1, 0) = sa * cy + ca * sb * sy;
        rotMat (2, 0) = sa * sy - ca * sb * cy;

        rotMat (0, 1) = -sa * cb;
        rotMat (1, 1) = ca * cy - sa * sb * sy;
        rotMat (2, 1) = ca * sy + sa * sb * cy;

        rotMat (0, 2) = sb;
        rotMat (1, 2) = -cb * sy;
        rotMat (2, 2) = cb * cy;
        // }

        auto Rl = orderMatrices[1];

        Rl->operator() (0, 0) = rotMat (1, 1);
        Rl->operator() (0, 1) = rotMat (1, 2);
        Rl->operator() (0, 2) = rotMat (1, 0);
        Rl->operator() (1, 0) = rotMat (2, 1);
        Rl->operator() (1, 1) = rotMat (2, 2);
        Rl->operator() (1, 2) = rotMat (2, 0);
        Rl->operator() (2, 0) = rotMat (0, 1);
        Rl->operator() (2, 1) = rotMat (0, 2);
        Rl->operator() (2, 2) = rotMat (0, 0);

        for (int l = 2; l <= order; ++l)
        {
            auto Rone = orderMatrices[1];
            auto Rlm1 = orderMatrices[l - 1];
            auto Rl = orderMatrices[l];
            for (int m = -l; m <= l; ++m)
            {
                for (int n = -l; n <= l; ++n)
                {
                    const int d = (m == 0) ? 1 : 0;
                    double denom;
                    if (abs (n) == l)
                        denom = (2 * l) * (2 * l - 1);
                    else
                        denom = l * l - n * n;

                    double u = sqrt ((l * l - m * m) / denom);
                    double v = sqrt ((1.0 + d) * (l + abs (m) - 1.0) * (l + abs (m)) / denom)
                               * (1.0 - 2.0 * d) * 0.5;
                    double w =
                        sqrt ((l - abs (m) - 1.0) * (l - abs (m)) / denom) * (1.0 - d) * (-0.5);

                    if (u != 0.0)
                        u *= U (l, m, n, *Rone, *Rlm1);
                    if (v != 0.0)
                        v *= V (l, m, n, *Rone, *Rlm1);
                    if (w != 0.0)
                        w *= W (l, m, n, *Rone, *Rlm1);

                    Rl->operator() (m + l, n + l) = u + v + w;
                }
            }
        }

        rotationParamsHaveChanged = false;
    }

    static constexpr int maxChannels = (maxOrder + 1) * (maxOrder + 1);

    float yawRadians { 0.0f };
    float pitchRadians { 0.0f };
    float rollRadians { 0.0f };
    int orderSetting { 0 };

    juce::AudioBuffer<float> copyBuffer;

    juce::OwnedArray<juce::dsp::Matrix<float>> orderMatrices;
    juce::OwnedArray<juce::dsp::Matrix<float>> orderMatricesCopy;
    juce::Atomic<bool> rotationParamsHaveChanged { true };
};