UserPreferences

OfflineGovernor


  1 
  2 
  3 
  4 
  5 
  6 
  7 
  8 
  9 
 10 
 11 
 12 
 13 
 14 
 15 
 16 
 17 
 18 
 19 
 20 
 21 
 22 
 23 
 24 
 25 
 26 
 27 
 28 
 29 
 30 
 31 
 32 
 33 
 34 
 35 
 36 
 37 
 38 
 39 
 40 
 41 
 42 
 43 
 44 
 45 
 46 
 47 
 48 
 49 
 50 
 51 
 52 
 53 
 54 
 55 
 56 
 57 
 58 
 59 
 60 
 61 
 62 
 63 
 64 

from pyrobot.brain.conx import *
from pyrobot.brain.ravq import *

# --- Network: simple recurrent network trained on input_target.dat ----------
n = SRN()
n.setSequenceType("ordered-continuous")
n.addLayers(16, 2, 2)  # presumably input / hidden / output sizes -- see conx SRN
n.loadDataFromFile('input_target.dat')

# Learning parameters for the network.
n.setEpsilon(0.2)
n.setMomentum(0.9)
n.setTolerance(0.05)
n.setLearning(1)

# --- ARAVQ instance that is later fed the concatenated input+target vectors --
# Constructor arguments are passed positionally; see pyrobot.brain.ravq.ARAVQ
# for what each one controls.
ravq = ARAVQ(3, .2, 1.6, .05)
ravq.setAddModels(1)
ravq.setHistory(1)
# Per-component mask: the first 16 components get 1, the final 2 get 8.
ravq.setMask([1] * 16 + [8] * 2)

# --- Output log of every pattern trained on, plus loop state ----------------
fp = open('balanced.dat', 'w')

counter = bufferIndex = 0
buffer = []

# 1 -> train by cycling through the buffer of new-winner patterns;
# 0 -> train by cycling through the RAVQ history.
method = 1

def saveListToFile(ls, file):
    """Write the values of *ls* to *file* as one space-separated line.

    Every value is converted with ``str`` and followed by a single space, so
    the line keeps a trailing space before the newline (the existing
    balanced.dat format).  An empty *ls* writes just a newline.

    ls   -- sequence of values to write
    file -- writable text file object (the name shadows the Python 2 ``file``
            builtin but is kept so keyword-argument callers still work)
    """
    # Build the whole line once instead of indexing with range(len(ls)) and
    # issuing one write() per element.
    file.write("".join(str(value) + " " for value in ls) + "\n")

# Number of patterns to draw from n.loadOrder before stopping.  Iterations are
# counted explicitly (in the module-level `counter`) instead of comparing the
# pattern index x: x indexes the data rather than counting steps, and
# `x > 50000` ran two extra iterations for a 0-based order.
MAX_STEPS = 50000


def _train_and_record(array):
    """Take one n.step() on *array* and append it to the balanced-data log.

    The first 16 values are passed as the network input and the remaining
    values as the target output (the network is built with 16 input units).
    """
    n.step(input=array[:16], output=array[16:])
    saveListToFile(array, fp)


with fp:  # close balanced.dat even if training raises
    for x in n.loadOrder:
        inputs = n.inputs[x]
        targets = n.targets[x]
        ravq.input(inputs + targets)
        if method:
            # Buffer method: keep at most the 100 most recent patterns that
            # produced a new RAVQ winner, and cycle through them for training.
            if ravq.getNewWinner(): # is 1 if the winner is a new winner, 0 otherwise
                buffer.append(inputs + targets)
                if len(buffer) > 100:
                    del buffer[0]  # drop the oldest so the window stays at 100
            if buffer: # cycle through current buffer
                array = buffer[bufferIndex]
                bufferIndex = (bufferIndex + 1) % len(buffer)
                _train_and_record(array)
        else:
            # History method: cycle through the RAVQ's stored history instead.
            if ravq.getHistoryLength() > 0:
                array = ravq.getHistory(bufferIndex)
                bufferIndex = (bufferIndex + 1) % ravq.getHistoryLength()
                _train_and_record(array)
        counter += 1
        if counter >= MAX_STEPS: # train for MAX_STEPS patterns
            break

    # Single-argument print(...) produces the same text under Python 2 and 3;
    # `counter` is always defined, unlike `x` when loadOrder is empty.
    print(" Count:  %s" % counter)
    print(" Steps:  %s" % n.count)
    print(" Number of model vectors:  %s" % len(ravq.models))

    n.saveWeightsToFile('network.wts')