// @ Luk Gajdoech 2019

#include "Octree.h"
#include "OctreeNode.h"
#include "PointCloud.h"

#include "Engine.h"
#include "DrawDebugHelpers.h"

#include "Data/PointCloudScan.h"

/**
 * Builds the octree for the given scan.
 *
 * The root node is a cube centred on the midpoint of the scan's min/max extent
 * whose size is the extent range. Every point index of the scan is assigned to
 * the root, which is then subdivided. Finally the sphere radius is derived
 * from OptimalRadius().
 *
 * @param scan Scan to index. When null the octree is left uninitialised
 *             (no root, bInitialized == false) instead of dereferencing it.
 */
void UOctree::Initialize(UPointCloudScan *scan)
{
  bDraw = false;

  // Without a scan there is nothing to index; leave the tree in a state that
  // the guarded accessors report as "not initialised".
  if (scan == nullptr)
  {
    bInitialized = false;
    parentScan = nullptr;
    root = nullptr;
    return;
  }

  bInitialized = true;
  parentScan = scan;

  const float minExtent = scan->ScanMinExtent();
  const float maxExtent = scan->ScanMaxExtent();
  const float center = (minExtent + maxExtent) / 2;

  root = NewObject<UOctreeNode>();
  root->SetParentOctree(this);
  root->SetPosition(FVector(center, center, center));
  root->SetSize(maxExtent - minExtent);

  // The root initially holds every point of the scan.
  const int32 n = scan->NumberOfPoints();
  if (n > 0)
  {
    // Single allocation instead of repeated growth while filling.
    root->pointsIndices.Reserve(n);
  }
  for (int32 i = 0; i < n; i++)
  {
    root->pointsIndices.Add(i);
  }

  // Subdivision threshold: 1% of the point count, but never less than 32.
  const int32 maxPoints = FMath::Max(32, n / 100);
  root->Subdivide(maxPoints);

  sphereRadius = OptimalRadius(0.25f) * 2;
}

float UOctree::OptimalRadius(float neighborDistance)
{
  int32 n = parentScan->NumberOfPoints();
  int32 pointIndex = n / 2;
  if (pointIndex >= n)
  {
    return 0;
  }
  TSet<int32> neighborhood = parentScan->PointNeighborhood(pointIndex, neighborDistance);
  if (neighborhood.Num() <= 1)
  {
    return OptimalRadius(neighborDistance * 2);
  }
  float total = 0;
  FVector myLocation = parentScan->GetPointCloud()->GetPointCloudData()[pointIndex].Location;
  for (int32 index : neighborhood)
  {
    FVector otherLocation = parentScan->GetPointCloud()->GetPointCloudData()[index].Location;
    total += FVector::Dist(myLocation, otherLocation);
  }
  float result = total / (neighborhood.Num() - 1);
  return result;
}

/**
 * Draws the octree by delegating to the root node.
 *
 * Does nothing unless drawing is enabled (bDraw), the octree has been
 * initialised and a root node exists.
 */
void UOctree::Draw() const
{
  // Same root guard as NodeCount()/GatherIntersectedNodes(): the root may be
  // missing even when the initialised flag is set.
  if (bDraw && bInitialized && root != nullptr)
  {
    root->Draw();
  }
}

/**
 * Reports the node count of the tree.
 *
 * @return The value of the root's NodeCount(), or 0 when the octree is not
 *         initialised or has no root.
 */
int32 UOctree::NodeCount() const
{
  const bool bHasTree = bInitialized && (root != nullptr);
  return bHasTree ? root->NodeCount() : 0;
}

/**
 * Finds the node hit by a ray whose position is closest to the ray origin.
 *
 * @param rayOrigin    Start of the ray.
 * @param rayDirection Direction of the ray.
 * @return The closest intersected node, or nullptr when no node was gathered
 *         (including when the octree is not initialised).
 */
UOctreeNode *UOctree::ClosestIntersectedNode(FVector rayOrigin, FVector rayDirection)
{
  TArray<UOctreeNode *> nodes;
  // Go through the guarded member rather than dereferencing root directly, so
  // an uninitialised tree yields an empty list instead of a null dereference.
  GatherIntersectedNodes(rayOrigin, rayDirection, nodes);
  return FindClosestNode(rayOrigin, nodes);
}

/**
 * Picks, from a list of nodes, the one whose position is nearest to rayOrigin.
 *
 * Null entries are skipped. On equal distances the earlier entry wins.
 *
 * @param rayOrigin Reference point the node positions are compared against.
 * @param nodes     Candidate nodes; not modified.
 * @return The nearest node, or nullptr when the list holds no valid node.
 */
UOctreeNode *UOctree::FindClosestNode(FVector rayOrigin, TArray<UOctreeNode *> &nodes)
{
  UOctreeNode *closestNode = nullptr;
  // Squared distance of closestNode; only meaningful once closestNode is set.
  // Squared distances order the same way as distances and avoid the sqrt, and
  // caching the best value avoids recomputing it for every candidate.
  float closestDistSq = 0.0f;

  for (UOctreeNode *node : nodes)
  {
    if (node == nullptr)
    {
      continue;
    }

    const float distSq = FVector::DistSquared(node->GetPosition(), rayOrigin);
    if (closestNode == nullptr || distSq < closestDistSq)
    {
      closestNode = node;
      closestDistSq = distSq;
    }
  }
  return closestNode;
}

/**
 * Collects the nodes intersected by a ray by delegating to the root node.
 *
 * @param rayOrigin    Start of the ray.
 * @param rayDirection Direction of the ray.
 * @param nodes        Output list handed to the root's GatherIntersectedNodes();
 *                     left untouched when the octree is not initialised or has
 *                     no root.
 */
void UOctree::GatherIntersectedNodes(FVector rayOrigin, FVector rayDirection, TArray<UOctreeNode *> &nodes)
{
  if (!bInitialized || root == nullptr)
  {
    return;
  }

  root->GatherIntersectedNodes(rayOrigin, rayDirection, nodes);
}

/**
 * Looks up a node for a location by delegating to the root node.
 *
 * @param location Location forwarded to the root's FindNode().
 * @return Whatever the root's FindNode() returns, or nullptr when the octree
 *         is not initialised or has no root.
 */
UOctreeNode *UOctree::FindNode(FVector location)
{
  // Same guard as the other accessors; avoids dereferencing a missing root.
  if (!bInitialized || root == nullptr)
  {
    return nullptr;
  }
  return root->FindNode(location);
}

/**
 * Accessor for the root node.
 *
 * @return The stored root pointer, returned as-is with no checks.
 */
UOctreeNode *UOctree::GetRoot()
{
  return this->root;
}


