xenia-canary/third_party/crunch/crnlib/crn_clusterizer.h

765 lines
22 KiB
C++

// File: crn_clusterizer.h
// See Copyright Notice and license at the end of inc/crnlib.h
#pragma once
#include "crn_matrix.h"
namespace crnlib
{
template<typename VectorType>
class clusterizer
{
public:
clusterizer() :
m_overall_variance(0.0f),
m_split_index(0),
m_heap_size(0),
m_quick(false)
{
}
void clear()
{
m_training_vecs.clear();
m_codebook.clear();
m_nodes.clear();
m_overall_variance = 0.0f;
m_split_index = 0;
m_heap_size = 0;
m_quick = false;
}
void reserve_training_vecs(uint num_expected)
{
m_training_vecs.reserve(num_expected);
}
void add_training_vec(const VectorType& v, uint weight)
{
m_training_vecs.push_back( std::make_pair(v, weight) );
}
typedef bool (*progress_callback_func_ptr)(uint percentage_completed, void* pData);
bool generate_codebook(uint max_size, progress_callback_func_ptr pProgress_callback = NULL, void* pProgress_data = NULL, bool quick = false)
{
if (m_training_vecs.empty())
return false;
m_quick = quick;
double ttsum = 0.0f;
vq_node root;
root.m_vectors.reserve(m_training_vecs.size());
for (uint i = 0; i < m_training_vecs.size(); i++)
{
const VectorType& v = m_training_vecs[i].first;
const uint weight = m_training_vecs[i].second;
root.m_centroid += (v * (float)weight);
root.m_total_weight += weight;
root.m_vectors.push_back(i);
ttsum += v.dot(v) * weight;
}
root.m_variance = (float)(ttsum - (root.m_centroid.dot(root.m_centroid) / root.m_total_weight));
root.m_centroid *= (1.0f / root.m_total_weight);
m_nodes.clear();
m_nodes.reserve(max_size * 2 + 1);
m_nodes.push_back(root);
m_heap.resize(max_size + 1);
m_heap[1] = 0;
m_heap_size = 1;
m_split_index = 0;
uint total_leaves = 1;
m_left_children.reserve(m_training_vecs.size() + 1);
m_right_children.reserve(m_training_vecs.size() + 1);
int prev_percentage = -1;
while ((total_leaves < max_size) && (m_heap_size))
{
int worst_node_index = m_heap[1];
m_heap[1] = m_heap[m_heap_size];
m_heap_size--;
if (m_heap_size)
down_heap(1);
split_node(worst_node_index);
total_leaves++;
if ((pProgress_callback) && ((total_leaves & 63) == 0) && (max_size))
{
int cur_percentage = (total_leaves * 100U + (max_size / 2U)) / max_size;
if (cur_percentage != prev_percentage)
{
if (!(*pProgress_callback)(cur_percentage, pProgress_data))
return false;
prev_percentage = cur_percentage;
}
}
}
m_codebook.clear();
m_overall_variance = 0.0f;
for (uint i = 0; i < m_nodes.size(); i++)
{
vq_node& node = m_nodes[i];
if (node.m_left != -1)
{
CRNLIB_ASSERT(node.m_right != -1);
continue;
}
CRNLIB_ASSERT((node.m_left == -1) && (node.m_right == -1));
node.m_codebook_index = m_codebook.size();
m_codebook.push_back(node.m_centroid);
m_overall_variance += node.m_variance;
}
m_heap.clear();
m_left_children.clear();
m_right_children.clear();
return true;
}
inline uint get_num_training_vecs() const { return m_training_vecs.size(); }
const VectorType& get_training_vec(uint index) const { return m_training_vecs[index].first; }
uint get_training_vec_weight(uint index) const { return m_training_vecs[index].second; }
typedef crnlib::vector< std::pair<VectorType, uint> > training_vec_array;
const training_vec_array& get_training_vecs() const { return m_training_vecs; }
training_vec_array& get_training_vecs() { return m_training_vecs; }
inline float get_overall_variance() const { return m_overall_variance; }
inline uint get_codebook_size() const
{
return m_codebook.size();
}
inline const VectorType& get_codebook_entry(uint index) const
{
return m_codebook[index];
}
VectorType& get_codebook_entry(uint index)
{
return m_codebook[index];
}
typedef crnlib::vector<VectorType> vector_vec_type;
inline const vector_vec_type& get_codebook() const
{
return m_codebook;
}
uint find_best_codebook_entry(const VectorType& v) const
{
uint cur_node_index = 0;
for ( ; ; )
{
const vq_node& cur_node = m_nodes[cur_node_index];
if (cur_node.m_left == -1)
return cur_node.m_codebook_index;
const vq_node& left_node = m_nodes[cur_node.m_left];
const vq_node& right_node = m_nodes[cur_node.m_right];
float left_dist = left_node.m_centroid.squared_distance(v);
float right_dist = right_node.m_centroid.squared_distance(v);
if (left_dist < right_dist)
cur_node_index = cur_node.m_left;
else
cur_node_index = cur_node.m_right;
}
}
const VectorType& find_best_codebook_entry(const VectorType& v, uint max_codebook_size) const
{
uint cur_node_index = 0;
for ( ; ; )
{
const vq_node& cur_node = m_nodes[cur_node_index];
if ((cur_node.m_left == -1) || ((cur_node.m_codebook_index + 1) >= (int)max_codebook_size))
return cur_node.m_centroid;
const vq_node& left_node = m_nodes[cur_node.m_left];
const vq_node& right_node = m_nodes[cur_node.m_right];
float left_dist = left_node.m_centroid.squared_distance(v);
float right_dist = right_node.m_centroid.squared_distance(v);
if (left_dist < right_dist)
cur_node_index = cur_node.m_left;
else
cur_node_index = cur_node.m_right;
}
}
uint find_best_codebook_entry_fs(const VectorType& v) const
{
float best_dist = math::cNearlyInfinite;
uint best_index = 0;
for (uint i = 0; i < m_codebook.size(); i++)
{
float dist = m_codebook[i].squared_distance(v);
if (dist < best_dist)
{
best_dist = dist;
best_index = i;
if (best_dist == 0.0f)
break;
}
}
return best_index;
}
void retrieve_clusters(uint max_clusters, crnlib::vector< crnlib::vector<uint> >& clusters) const
{
clusters.resize(0);
clusters.reserve(max_clusters);
crnlib::vector<uint> stack;
stack.reserve(512);
uint cur_node_index = 0;
for ( ; ; )
{
const vq_node& cur_node = m_nodes[cur_node_index];
if ( (cur_node.is_leaf()) || ((cur_node.m_codebook_index + 2) > (int)max_clusters) )
{
clusters.resize(clusters.size() + 1);
clusters.back() = cur_node.m_vectors;
if (stack.empty())
break;
cur_node_index = stack.back();
stack.pop_back();
continue;
}
cur_node_index = cur_node.m_left;
stack.push_back(cur_node.m_right);
}
}
private:
training_vec_array m_training_vecs;
struct vq_node
{
vq_node() : m_centroid(cClear), m_total_weight(0), m_left(-1), m_right(-1), m_codebook_index(-1), m_unsplittable(false) { }
VectorType m_centroid;
uint64 m_total_weight;
float m_variance;
crnlib::vector<uint> m_vectors;
int m_left;
int m_right;
int m_codebook_index;
bool m_unsplittable;
bool is_leaf() const { return m_left < 0; }
};
typedef crnlib::vector<vq_node> node_vec_type;
node_vec_type m_nodes;
vector_vec_type m_codebook;
float m_overall_variance;
uint m_split_index;
crnlib::vector<uint> m_heap;
uint m_heap_size;
bool m_quick;
void insert_heap(uint node_index)
{
const float variance = m_nodes[node_index].m_variance;
uint pos = ++m_heap_size;
if (m_heap_size >= m_heap.size())
m_heap.resize(m_heap_size + 1);
for ( ; ; )
{
uint parent = pos >> 1;
if (!parent)
break;
float parent_variance = m_nodes[m_heap[parent]].m_variance;
if (parent_variance > variance)
break;
m_heap[pos] = m_heap[parent];
pos = parent;
}
m_heap[pos] = node_index;
}
void down_heap(uint pos)
{
uint child;
uint orig = m_heap[pos];
const float orig_variance = m_nodes[orig].m_variance;
while ((child = (pos << 1)) <= m_heap_size)
{
if (child < m_heap_size)
{
if (m_nodes[m_heap[child]].m_variance < m_nodes[m_heap[child + 1]].m_variance)
child++;
}
if (orig_variance > m_nodes[m_heap[child]].m_variance)
break;
m_heap[pos] = m_heap[child];
pos = child;
}
m_heap[pos] = orig;
}
void compute_split_estimate(VectorType& left_child_res, VectorType& right_child_res, const vq_node& parent_node)
{
VectorType furthest(0);
double furthest_dist = -1.0f;
for (uint i = 0; i < parent_node.m_vectors.size(); i++)
{
const VectorType& v = m_training_vecs[parent_node.m_vectors[i]].first;
double dist = v.squared_distance(parent_node.m_centroid);
if (dist > furthest_dist)
{
furthest_dist = dist;
furthest = v;
}
}
VectorType opposite(0);
double opposite_dist = -1.0f;
for (uint i = 0; i < parent_node.m_vectors.size(); i++)
{
const VectorType& v = m_training_vecs[parent_node.m_vectors[i]].first;
double dist = v.squared_distance(furthest);
if (dist > opposite_dist)
{
opposite_dist = dist;
opposite = v;
}
}
left_child_res = (furthest + parent_node.m_centroid) * .5f;
right_child_res = (opposite + parent_node.m_centroid) * .5f;
}
void compute_split_pca(VectorType& left_child_res, VectorType& right_child_res, const vq_node& parent_node)
{
if (parent_node.m_vectors.size() == 2)
{
left_child_res = m_training_vecs[parent_node.m_vectors[0]].first;
right_child_res = m_training_vecs[parent_node.m_vectors[1]].first;
return;
}
const uint N = VectorType::num_elements;
matrix<N, N, float> covar;
covar.clear();
for (uint i = 0; i < parent_node.m_vectors.size(); i++)
{
const VectorType v(m_training_vecs[parent_node.m_vectors[i]].first - parent_node.m_centroid);
const VectorType w(v * (float)m_training_vecs[parent_node.m_vectors[i]].second);
for (uint x = 0; x < N; x++)
for (uint y = x; y < N; y++)
covar[x][y] = covar[x][y] + v[x] * w[y];
}
float one_over_total_weight = 1.0f / parent_node.m_total_weight;
for (uint x = 0; x < N; x++)
for (uint y = x; y < N; y++)
covar[x][y] *= one_over_total_weight;
for (uint x = 0; x < (N - 1); x++)
for (uint y = x + 1; y < N; y++)
covar[y][x] = covar[x][y];
VectorType axis;//(1.0f);
if (N == 1)
axis.set(1.0f);
else
{
for (uint i = 0; i < N; i++)
axis[i] = math::lerp(.75f, 1.25f, i * (1.0f / math::maximum<int>(N - 1, 1)));
}
VectorType prev_axis(axis);
for (uint iter = 0; iter < 10; iter++)
{
VectorType x;
double max_sum = 0;
for (uint i = 0; i < N; i++)
{
double sum = 0;
for (uint j = 0; j < N; j++)
sum += axis[j] * covar[i][j];
x[i] = static_cast<float>(sum);
max_sum = math::maximum(max_sum, fabs(sum));
}
if (max_sum != 0.0f)
x *= static_cast<float>(1.0f / max_sum);
VectorType delta_axis(prev_axis - x);
prev_axis = axis;
axis = x;
if (delta_axis.norm() < .0025f)
break;
}
axis.normalize();
VectorType left_child(0.0f);
VectorType right_child(0.0f);
double left_weight = 0.0f;
double right_weight = 0.0f;
for (uint i = 0; i < parent_node.m_vectors.size(); i++)
{
const float weight = (float)m_training_vecs[parent_node.m_vectors[i]].second;
const VectorType& v = m_training_vecs[parent_node.m_vectors[i]].first;
double t = (v - parent_node.m_centroid) * axis;
if (t < 0.0f)
{
left_child += v * weight;
left_weight += weight;
}
else
{
right_child += v * weight;
right_weight += weight;
}
}
if ((left_weight > 0.0f) && (right_weight > 0.0f))
{
left_child_res = left_child * (float)(1.0f / left_weight);
right_child_res = right_child * (float)(1.0f / right_weight);
}
else
{
compute_split_estimate(left_child_res, right_child_res, parent_node);
}
}
#if 0
void compute_split_pca2(VectorType& left_child_res, VectorType& right_child_res, const vq_node& parent_node)
{
if (parent_node.m_vectors.size() == 2)
{
left_child_res = m_training_vecs[parent_node.m_vectors[0]].first;
right_child_res = m_training_vecs[parent_node.m_vectors[1]].first;
return;
}
const uint N = VectorType::num_elements;
VectorType furthest;
double furthest_dist = -1.0f;
for (uint i = 0; i < parent_node.m_vectors.size(); i++)
{
const VectorType& v = m_training_vecs[parent_node.m_vectors[i]].first;
double dist = v.squared_distance(parent_node.m_centroid);
if (dist > furthest_dist)
{
furthest_dist = dist;
furthest = v;
}
}
VectorType opposite;
double opposite_dist = -1.0f;
for (uint i = 0; i < parent_node.m_vectors.size(); i++)
{
const VectorType& v = m_training_vecs[parent_node.m_vectors[i]].first;
double dist = v.squared_distance(furthest);
if (dist > opposite_dist)
{
opposite_dist = dist;
opposite = v;
}
}
VectorType axis(opposite - furthest);
if (axis.normalize() < .000125f)
{
left_child_res = (furthest + parent_node.m_centroid) * .5f;
right_child_res = (opposite + parent_node.m_centroid) * .5f;
return;
}
for (uint iter = 0; iter < 2; iter++)
{
double next_axis[N];
utils::zero_object(next_axis);
for (uint i = 0; i < parent_node.m_vectors.size(); i++)
{
const double weight = m_training_vecs[parent_node.m_vectors[i]].second;
VectorType v(m_training_vecs[parent_node.m_vectors[i]].first - parent_node.m_centroid);
double dot = (v * axis) * weight;
for (uint j = 0; j < N; j++)
next_axis[j] += dot * v[j];
}
double w = 0.0f;
for (uint j = 0; j < N; j++)
w += next_axis[j] * next_axis[j];
if (w > 0.0f)
{
w = 1.0f / sqrt(w);
for (uint j = 0; j < N; j++)
axis[j] = static_cast<float>(next_axis[j] * w);
}
else
break;
}
VectorType left_child(0.0f);
VectorType right_child(0.0f);
double left_weight = 0.0f;
double right_weight = 0.0f;
for (uint i = 0; i < parent_node.m_vectors.size(); i++)
{
const float weight = (float)m_training_vecs[parent_node.m_vectors[i]].second;
const VectorType& v = m_training_vecs[parent_node.m_vectors[i]].first;
double t = (v - parent_node.m_centroid) * axis;
if (t < 0.0f)
{
left_child += v * weight;
left_weight += weight;
}
else
{
right_child += v * weight;
right_weight += weight;
}
}
if ((left_weight > 0.0f) && (right_weight > 0.0f))
{
left_child_res = left_child * (float)(1.0f / left_weight);
right_child_res = right_child * (float)(1.0f / right_weight);
}
else
{
left_child_res = (furthest + parent_node.m_centroid) * .5f;
right_child_res = (opposite + parent_node.m_centroid) * .5f;
}
}
#endif
// thread safety warning: shared state!
crnlib::vector<uint> m_left_children;
crnlib::vector<uint> m_right_children;
void split_node(uint index)
{
vq_node& parent_node = m_nodes[index];
if (parent_node.m_vectors.size() == 1)
return;
VectorType left_child, right_child;
if (m_quick)
compute_split_estimate(left_child, right_child, parent_node);
else
compute_split_pca(left_child, right_child, parent_node);
uint64 left_weight = 0;
uint64 right_weight = 0;
float prev_total_variance = 1e+10f;
float left_variance = 0.0f;
float right_variance = 0.0f;
const uint cMaxLoops = m_quick ? 2 : 8;
for (uint total_loops = 0; total_loops < cMaxLoops; total_loops++)
{
m_left_children.resize(0);
m_right_children.resize(0);
VectorType new_left_child(cClear);
VectorType new_right_child(cClear);
double left_ttsum = 0.0f;
double right_ttsum = 0.0f;
left_weight = 0;
right_weight = 0;
for (uint i = 0; i < parent_node.m_vectors.size(); i++)
{
const VectorType& v = m_training_vecs[parent_node.m_vectors[i]].first;
const uint weight = m_training_vecs[parent_node.m_vectors[i]].second;
double left_dist2 = left_child.squared_distance(v);
double right_dist2 = right_child.squared_distance(v);
if (left_dist2 < right_dist2)
{
m_left_children.push_back(parent_node.m_vectors[i]);
new_left_child += (v * (float)weight);
left_weight += weight;
left_ttsum += v.dot(v) * weight;
}
else
{
m_right_children.push_back(parent_node.m_vectors[i]);
new_right_child += (v * (float)weight);
right_weight += weight;
right_ttsum += v.dot(v) * weight;
}
}
if ((!left_weight) || (!right_weight))
{
parent_node.m_unsplittable = true;
return;
}
left_variance = (float)(left_ttsum - (new_left_child.dot(new_left_child) / left_weight));
right_variance = (float)(right_ttsum - (new_right_child.dot(new_right_child) / right_weight));
new_left_child *= (1.0f / left_weight);
new_right_child *= (1.0f / right_weight);
left_child = new_left_child;
left_weight = left_weight;
right_child = new_right_child;
right_weight = right_weight;
float total_variance = left_variance + right_variance;
if (total_variance < .00001f)
break;
//const float variance_delta_thresh = .00001f;
const float variance_delta_thresh = .00125f;
if (((prev_total_variance - total_variance) / total_variance) < variance_delta_thresh)
break;
prev_total_variance = total_variance;
}
const uint left_child_index = m_nodes.size();
const uint right_child_index = m_nodes.size() + 1;
parent_node.m_left = m_nodes.size();
parent_node.m_right = m_nodes.size() + 1;
parent_node.m_codebook_index = m_split_index;
m_split_index++;
m_nodes.resize(m_nodes.size() + 2);
// parent_node is invalid now, because m_nodes has been changed
vq_node& left_child_node = m_nodes[left_child_index];
vq_node& right_child_node = m_nodes[right_child_index];
left_child_node.m_centroid = left_child;
left_child_node.m_total_weight = left_weight;
left_child_node.m_vectors.swap(m_left_children);
left_child_node.m_variance = left_variance;
if ((left_child_node.m_vectors.size() > 1) && (left_child_node.m_variance > 0.0f))
insert_heap(left_child_index);
right_child_node.m_centroid = right_child;
right_child_node.m_total_weight = right_weight;
right_child_node.m_vectors.swap(m_right_children);
right_child_node.m_variance = right_variance;
if ((right_child_node.m_vectors.size() > 1) && (right_child_node.m_variance > 0.0f))
insert_heap(right_child_index);
}
};
} // namespace crnlib