
/* osgEarth
* Copyright 2008-2016 Pelican Mapping
* MIT License
*/
#pragma once

#include <osgEarth/Common>
#include <osgEarth/Units>
#include <osgEarth/GeoData>
#include <osgEarth/TileKey>
#include <osgEarth/Math>
#include <osg/Texture2D>

namespace osgEarth
{
    // NOTE(review): presumably the width/height, in samples, of an elevation
    // tile -- this header does not use the constant itself; confirm against callers.
    constexpr int ELEVATION_TILE_SIZE = 257;

    // Forward declarations: this header only refers to these types through pointers.
    class Map;
    class ProgressCallback;

    // Declared here; the definition is not part of this header.
    // NOTE(review): a raw osg::Texture pointer is returned -- presumably the caller
    // is expected to take a reference to it (e.g. via osg::ref_ptr); confirm.
    extern OSGEARTH_EXPORT osg::Texture* createEmptyElevationTile();

    // Declared here; the definition is not part of this header.
    // NOTE(review): same ownership question as createEmptyElevationTile() -- confirm.
    extern OSGEARTH_EXPORT osg::Texture* createEmptyNormalMapTexture();

    /**
     * Outcome of an elevation sampling query: the sampled elevation together
     * with the resolution associated with that sample.
     */
    class OSGEARTH_EXPORT ElevationSample
    {
    public:
        //! Creates an empty sample: the elevation is NO_DATA_VALUE and the
        //! resolution is zero, both expressed in meters.
        ElevationSample() :
            _e(NO_DATA_VALUE, Units::METERS),
            _r(0.0, Units::METERS)
        {
        }

        //! Creates a sample from an elevation "e" and a resolution "r".
        ElevationSample(const Distance& e, const Distance& r) :
            _e(e),
            _r(r)
        {
        }

        //! Sampled elevation
        const Distance& elevation() const { return _e; }

        //! Resolution associated with the sample
        const Distance& resolution() const { return _r; }

        //! True when the elevation is not NO_DATA_VALUE and the resolution is
        //! strictly positive.
        bool hasData() const
        {
            if (_e.getValue() != NO_DATA_VALUE)
                return _r.getValue() > 0.0f;

            return false;
        }

    private:
        Distance _e;
        Distance _r;
    };

    /**
     * Elevation grid as a texture with an optional associated normal map.
     *
     * Elevation values are stored encoded (see Encoding) and are decoded on
     * access with decodeElevation(), using the tile's min/max heights.
     */
    class OSGEARTH_EXPORT ElevationTile : public osg::Referenced
    {
    public:
        //! Constructs a tile from a tile key, a heightfield, and a list of
        //! sample resolutions (taken by rvalue reference).
        ElevationTile(
            const TileKey& key,
            const GeoHeightField& hf,
            std::vector<float>&& resolutions);

        //! Gets the elevation at the map coordinates. These coordinates must
        //! be in the SRS used to create the texture.
        ElevationSample getElevation(double x, double y) const;

        //! Extent of the texture
        const GeoExtent& getExtent() const { return _extent; }

        //! Min (first) and Max (second) heights in the texture. These are the
        //! "lo" and "hi" bounds used when decoding elevation values.
        const std::pair<float, float> getMaxima() const { return _maxima; }

        //! Whether the data in this tile represents all native resolution data
        //! (versus a combination of native and fallback data)
        bool allHeightsAtNativeResolution() const { return _allHeightsAtNativeResolution; }

        //! TileKey used to create this elevation grid
        const TileKey& getTileKey() const { return _tilekey; }

        //! Elevation data texture
        osg::Texture2D* getElevationTile() const {
            return _elevationTex.get();
        }

        //! Normal map associated with the elevation data, if available.
        osg::Texture2D* getNormalMapTexture() const {
            return _normalTex.get();
        }

        //! The normal at map coordinates, or (0,0,1) if the tile does not
        //! contain a normal map.
        osg::Vec3 getNormal(double x, double y) const;

        //! Generates a normal map for this object.
        void generateNormalMap(const Map* map, unsigned size, void* workingSet, ProgressCallback* progress);

        //! Direct access to the pixel reader
        const ImageUtils::PixelReader& reader() const { return _read; }

        //! Matching list of sample resolutions
        const std::vector<float>& getResolutions() const { return _resolutions; }

        //! Get the resolution at column s, row t. The list is indexed row-major
        //! using the pixel reader's width. No bounds checking is performed, so
        //! (s, t) must address a valid entry of the resolutions list.
        inline float getResolution(int s, int t) const {
            return _resolutions[t * _read.s() + s];
        }

        //! Get the resolution at normalized coordinates. u and v are clamped
        //! to [0..1] (as getElevationUV does) so that out-of-range input
        //! resolves to the nearest edge sample instead of indexing outside
        //! the resolutions list.
        inline float getResolutionUV(double u, double v) const {
            u = osgEarth::clamp(u, 0.0, 1.0);
            v = osgEarth::clamp(v, 0.0, 1.0);
            return getResolution(
                static_cast<int>(u * static_cast<double>(_read.s() - 1)),
                static_cast<int>(v * static_cast<double>(_read.t() - 1)));
        }

        //! The heightfield that was used to populate this object
        const osg::HeightField* getHeightField() const {
            return _heightField.get();
        }

        //! Supported elevation encodings
        enum class Encoding { R16, RG8, R32F };

        //! Which encoding matches the given GL internal format.
        //! Unrecognized formats fall back to R16.
        static inline Encoding encodingFor(GLenum internalFormat) {
            return internalFormat == GL_R16 ? Encoding::R16 :
                internalFormat == GL_RG8 ? Encoding::RG8 :
                internalFormat == GL_R32F ? Encoding::R32F :
                Encoding::R16;
        }

        //! Decodes the elevation from a sampled pixel value.
        //! R16 and RG8 values are scaled into the [lo, hi] range; any other
        //! encoding returns the red channel unchanged.
        //! The constant 1.5259021e-5 is 1/65535.
        //! NOTE(review): the R16 branch applies the same 1/65535 factor as the
        //! RG8 branch, which first weights r and g by 65280 and 255. Confirm
        //! that the pixel reader yields R16 values in [0..65535] rather than
        //! [0..1]; this header alone cannot establish that.
        static inline float decodeElevation(const osg::Vec4& encoded, Encoding encoding, float lo, float hi) {
            return
                encoding == Encoding::R16 ? encoded.r() * 1.5259021e-5 * (hi - lo) + lo :
                encoding == Encoding::RG8 ? (encoded.r() * 65280.f + encoded.g() * 255.0) * 1.5259021e-5 * (hi - lo) + lo :
                encoded.r();
        }

        //! Encoding used in this tile
        inline Encoding encoding() const {
            return _encoding;
        }

        //! Gets the raw elevation floating point value at the normalized location.
        //! u and v are passed to the pixel reader as given (no clamping here).
        inline float getRawElevationUV(double u, double v) const
        {
            return decodeElevation(_read(u, v), _encoding, _maxima.first, _maxima.second);
        }

        //! Gets the elevation using normalized [0..1] coordinates; values
        //! outside that range are clamped. The resolution reported in the
        //! sample is the tile-wide _resolution value.
        inline ElevationSample getElevationUV(double u, double v) const
        {
            u = osgEarth::clamp(u, 0.0, 1.0);
            v = osgEarth::clamp(v, 0.0, 1.0);
            auto h = decodeElevation(_read(u, v, 0, 0), _encoding, _maxima.first, _maxima.second);
            return ElevationSample(Distance(h, Units::METERS), _resolution);
        }

    private:
        TileKey _tilekey;
        GeoExtent _extent;
        // (min, max) heights; used as the (lo, hi) decode range
        std::pair<float, float> _maxima = { 0.0f, 0.0f };
        // resolution reported by getElevationUV()
        Distance _resolution;
        // pixel reader used for elevation lookups
        ImageUtils::PixelReader _read;
        // NOTE(review): presumably the reader for the normal map texture -- not used in this header
        ImageUtils::PixelReader _readNormal;
        osg::ref_ptr<osg::Texture2D> _elevationTex;
        osg::ref_ptr<osg::Texture2D> _normalTex;
        osg::ref_ptr<const osg::HeightField> _heightField;
        // resolution list, indexed row-major (see getResolution)
        std::vector<float> _resolutions;
        // NOTE(review): not used in this header; presumably guards state in the implementation file
        std::mutex _mutex;
        bool _allHeightsAtNativeResolution = true;

        Encoding _encoding = Encoding::R16;
    };

    /**
     * Helper that builds a normal map texture for a given tile key, and that
     * provides the routines used to pack and unpack the normals.
     */
    class OSGEARTH_EXPORT NormalMapGenerator
    {
    public:
        //! Creates the normal map texture for a tile key.
        osg::Texture2D* createNormalMap(
            const TileKey& key,
            const class Map* map,
            unsigned int tileSize,
            bool assumeNativeResolution,
            void* workingSet,
            ProgressCallback* progress);

        //! Encodes a 3-component normal into RG (octahedral compression)
        static inline osg::Vec4 pack(const osg::Vec3& normal);

        //! Decodes an RG-packed normal back into a 3-component vector.
        static inline osg::Vec3 unpack(const osg::Vec4& packed);
    };

    //! Internal key type for revisioned elevation lookups: a tile key
    //! combined with a hash value.
    namespace Internal
    {
        struct RevElevationKey
        {
            TileKey _tilekey;
            std::size_t _hash;

            //! Orders by tile key first; the hash breaks the tie when neither
            //! tile key orders before the other.
            inline bool operator < (const RevElevationKey& rhs) const {
                return
                    (_tilekey < rhs._tilekey) ||
                    (!(rhs._tilekey < _tilekey) && _hash < rhs._hash);
            }

            //! Equal when both the tile key and the hash match.
            inline bool operator == (const RevElevationKey& rhs) const {
                if (!(_tilekey == rhs._tilekey))
                    return false;
                return _hash == rhs._hash;
            }

            //! Different when either the tile key or the hash differs.
            inline bool operator != (const RevElevationKey& rhs) const {
                if (_tilekey != rhs._tilekey)
                    return true;
                return _hash != rhs._hash;
            }

            //! Combines the tile key's hash with the stored hash value.
            inline std::size_t hash() const {
                return osgEarth::hash_value_unsigned(_tilekey.hash(), _hash);
            }
        };
    }


    //! Packs a normal into the x/y components of a Vec4 using octahedral
    //! encoding, remapped from [-1..1] to [0..1]. The remaining components
    //! of the returned vector are left as default-constructed.
    //!
    //! The input does not need to be unit length; it is rescaled by the sum
    //! of the absolute values of its components. A zero-length (or NaN)
    //! input is encoded as the +Z normal rather than dividing by zero.
    inline osg::Vec4 NormalMapGenerator::pack(const osg::Vec3& n)
    {
        osg::Vec4 p;

        // Project onto the octahedron |x| + |y| + |z| = 1.
        const auto sum = fabs(n.x()) + fabs(n.y()) + fabs(n.z());
        if (!(sum > 0.0))
        {
            // Degenerate input: (0.5, 0.5) is what unpack() decodes to (0,0,1).
            p.x() = 0.5f;
            p.y() = 0.5f;
            return p;
        }

        const float d = 1.0 / sum;
        float x = n.x() * d;
        float y = n.y() * d;

        if (n.z() < 0.0)
        {
            // Fold the lower hemisphere over the diagonals. Both folded values
            // must be computed from the UNFOLDED x and y; overwriting x first
            // and then deriving y from it corrupts the encoding.
            const float foldedX = (1.0 - fabs(y)) * (x >= 0.0 ? 1.0 : -1.0);
            const float foldedY = (1.0 - fabs(x)) * (y >= 0.0 ? 1.0 : -1.0);
            x = foldedX;
            y = foldedY;
        }

        // Remap [-1..1] to [0..1].
        p.x() = 0.5f * (x + 1.0f);
        p.y() = 0.5f * (y + 1.0f);
        return p;
    }

    //! Decodes a normal from the x/y components of "packed", which hold
    //! values remapped to [0..1]. The result is normalized before it is
    //! returned.
    inline osg::Vec3 NormalMapGenerator::unpack(const osg::Vec4& packed)
    {
        // Remap the two stored components back to [-1..1].
        float nx = packed.x() * 2.0 - 1.0;
        float ny = packed.y() * 2.0 - 1.0;

        // z is whatever remains of the |x|+|y|+|z| = 1 budget; it is negative
        // for normals that were folded when packing.
        const float nz = 1.0 - fabs(nx) - fabs(ny);

        // Unfold: "amount" is zero unless z is negative.
        const float amount = clamp(-nz, 0.0f, 1.0f);
        nx += (nx > 0) ? -amount : amount;
        ny += (ny > 0) ? -amount : amount;

        osg::Vec3 result(nx, ny, nz);
        result.normalize();
        return result;
    }
}

namespace std {
    // std::hash specialization for RevElevationKey; forwards to the key's
    // own hash() member.
    template<>
    struct hash<osgEarth::Internal::RevElevationKey>
    {
        size_t operator()(const osgEarth::Internal::RevElevationKey& key) const
        {
            return key.hash();
        }
    };
}
