Commit db1c3eeb authored by Torge Berckmann's avatar Torge Berckmann
Browse files

Added realistic test case

parent a2d9baaa
......@@ -182,7 +182,7 @@ if False:
print("new mean u", corrobj.mu_u1)
print("new mean v", corrobj.mu_v1)
class TestStringMethods(unittest.TestCase):
class TestIncrCorrMethods(unittest.TestCase):
@staticmethod
def _generate_rand_seqs(dims):
......@@ -341,6 +341,52 @@ class TestStringMethods(unittest.TestCase):
self.assertAlmostEqual(npcorr, mycorr, places=5)
# Example use case
def test_ex_use_case(self):
# Comparison
d1 = 46
d2 = 4
embed_dim = 512
rst01inner = torch.rand(d1,d2,embed_dim)
rst11inner = torch.rand(d1,d2,embed_dim)
d1, d2, embed_dim = rst01inner.shape
components_x = [list() for _ in range(embed_dim)]
components_y = [list() for _ in range(embed_dim)]
res_corr = []
for x in range(d1):
for y in range(d2):
for comp in range(embed_dim):
components_x[comp].append(float(rst01inner[x,y,comp]))
components_y[comp].append(float(rst11inner[x,y,comp]))
for comp in range(embed_dim):
m = np.concatenate((
np.expand_dims(components_x[comp], 0),
np.expand_dims(components_y[comp], 0)))
outmat = np.corrcoef(m)
npcorr = outmat[0][1]
res_corr.append(npcorr)
np_total_corr = sum(res_corr) / embed_dim
# Similar to calculation in code
run_corr = IncrCorr((1,embed_dim), 0)
tx_tens = torch.reshape(rst01inner, (-1, embed_dim))
ty_tens = torch.reshape(rst11inner, (-1, embed_dim))
run_corr.update(tx_tens, ty_tens)
cur_corr = run_corr.retrieve()
total_incr_corr = float(cur_corr.sum() / embed_dim)
self.assertAlmostEqual(total_incr_corr, np_total_corr, places=5)
if __name__ == '__main__':
unittest.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment