UserPreferences

NNTrainRobotProgram


  1 
  2 
  3 
  4 
  5 
  6 
  7 
  8 
  9 
 10 
 11 
 12 
 13 
 14 
 15 
 16 
 17 
 18 
 19 
 20 
 21 
 22 
 23 
 24 
 25 
 26 
 27 
 28 
 29 
 30 
 31 
 32 
 33 
 34 
 35 
 36 
 37 
 38 
 39 
 40 
 41 
 42 
 43 
 44 
 45 
 46 
 47 
 48 
 49 
 50 
 51 
 52 
 53 
 54 
 55 
 56 
 57 
 58 
 59 
 60 
 61 

# Train a network offline
# Inputs: two scaled front sensor readings 
# Outputs: one translate reading (unscaled) 

from pyrobot.brain.conx import *
from pyrobot.system.log import *

class DataFormatError(ValueError):
   """Raised by setFromFile when a row's column count differs from the previous row's."""
   pass


def setFromFile(filename, cols = None, delim = ' '):
   """Read a delimited text file of numbers into a list of rows.

   Each non-blank line is stripped, split on ``delim``, and every field is
   converted with ``float()``. Blank (whitespace-only) lines are skipped.

   filename -- path of the text file to read
   cols     -- optional sequence of column indices to keep, in that order;
               None (the default) keeps every column
   delim    -- field separator passed to ``str.split``; the default is a
               single space, so a run of spaces yields empty fields, which
               ``float()`` rejects

   Returns a list of rows, each a list of floats. Prints the number of rows
   read.

   Raises DataFormatError (a ValueError subclass) when a row's length differs
   from the previous row's; ValueError when a field is not a number;
   IndexError when ``cols`` names a column the line does not have.
   """
   data = []
   lastLength = None
   fp = open(filename, "r")
   try:
      for index, line in enumerate(fp):
         lineno = index + 1  # 1-based physical line number for error messages
         text = line.strip()
         if not text:
            # Blank lines (e.g. a trailing newline at EOF) carry no data.
            continue
         linedata = [float(x) for x in text.split(delim)]
         if cols is None:
            newdata = linedata
         else:
            newdata = [linedata[i] for i in cols]
         if lastLength is not None and len(newdata) != lastLength:
            raise DataFormatError("line = %d" % lineno)
         data.append(newdata)
         lastLength = len(newdata)
   finally:
      # Close the file even if parsing raises part-way through.
      fp.close()
   print("length of data array is %d" % len(data))
   return data

def saveListToFile(ls, file):
   """Write the items of ``ls`` to ``file`` as one space-separated line.

   Each item is converted with ``str()`` and followed by a single space, so
   the line ends with a trailing space before the newline.

   ls   -- any iterable of items to write
   file -- an open, writable file-like object (anything with ``write``)
   """
   file.write("".join([str(item) + " " for item in ls]) + "\n")

# Create the network: addLayers(2, 1, 1) builds a 2-1-1 network, matching the
# two sensor inputs and one target output described in the header.
n = Network()
n.addLayers(2,1,1)
# Set learning parameters
n.setEpsilon(0.3)
n.setMomentum(0.0)
n.setTolerance(0.05)
# Set inputs and targets (from the collected data set)
n.setInputs(setFromFile('frontsensors.dat'))
n.setTargets(setFromFile('translatetargets.dat'))
# Logging
log = Log(name = 'E05M01.txt')
best = 0
for i in range(1000):
   # One training epoch over the whole data set.
   tssError, totalCorrect, totalCount, totalPCorrect = n.sweep()
   # Fraction (0.0-1.0) of items judged correct this epoch. float() forces
   # true division under Python 2; an empty sweep counts as 0 instead of
   # raising ZeroDivisionError.
   if totalCount:
      correctpercent = float(totalCorrect) / totalCount
   else:
      correctpercent = 0.0
   log.writeln("Epoch # %s TSS ERROR: %s Correct: %s Total Count: %s"
               " %%correct = %s"
               % (i, tssError, totalCorrect, totalCount, correctpercent))
   # Keep only the weights from the best-scoring epoch seen so far.
   if best < correctpercent:
      n.saveWeightsToFile("E05M01.wts")
      best = correctpercent
print("done")