from numpy import random
class BaseSchedule(object):
    """Maintains logic for deciding whether to consequate trials.

    This base class provides the most basic reinforcement schedule: every
    response is consequated.

    Methods:
    consequate(trial) -- returns a boolean value based on whether the trial
        should be consequated. Always returns True.

    """

    def __init__(self):
        super(BaseSchedule, self).__init__()

    def consequate(self, trial):
        """Decide whether ``trial`` should be consequated.

        ``trial`` must expose a ``correct`` attribute holding a ``bool``;
        otherwise the sanity check below raises AssertionError (when
        assertions are enabled).

        Returns True for correct and incorrect trials alike.
        """
        assert hasattr(trial, 'correct') and isinstance(trial.correct, bool)
        # Both outcomes are consequated, so there is nothing to branch on.
        return True
class ContinuousReinforcement(BaseSchedule):
    """Maintains logic for deciding whether to consequate trials.

    This class implements continuous reinforcement: every response, correct
    or incorrect, is consequated.

    Methods:
    consequate(trial) -- returns a boolean value based on whether the trial
        should be consequated. Always returns True.

    """

    def __init__(self):
        super(ContinuousReinforcement, self).__init__()

    def consequate(self, trial):
        """Decide whether ``trial`` should be consequated.

        ``trial`` must expose a ``correct`` attribute holding a ``bool``;
        otherwise the sanity check below raises AssertionError (when
        assertions are enabled).

        Returns True for correct and incorrect trials alike.
        """
        assert hasattr(trial, 'correct') and isinstance(trial.correct, bool)
        # Both outcomes are consequated, so there is nothing to branch on.
        return True
class FixedRatioSchedule(BaseSchedule):
    """Maintains logic for deciding whether to consequate trials.

    This class implements a fixed ratio schedule, where a reward reinforcement
    is provided after every nth consecutive correct response, where 'n' is the
    'ratio'. An incorrect response resets the count of correct responses.

    Incorrect trials are always reinforced.

    Methods:
    consequate(trial) -- returns a boolean value based on whether the trial
        should be consequated.

    """

    def __init__(self, ratio=1):
        super(FixedRatioSchedule, self).__init__()
        # Ratios below 1 are clamped to 1 (i.e. reinforce every correct trial).
        self.ratio = max(ratio, 1)
        self._update()

    def _update(self):
        """Reset the correct-response counter and set the next threshold."""
        self.cumulative_correct = 0
        self.threshold = self.ratio

    def consequate(self, trial):
        """Decide whether ``trial`` should be consequated.

        ``trial`` must expose a ``correct`` attribute holding a ``bool``;
        otherwise the sanity check below raises AssertionError (when
        assertions are enabled).

        Returns True for every incorrect trial, and for the correct trial
        that brings the running count of correct responses up to the current
        threshold. Returns False for correct trials below the threshold.
        """
        assert hasattr(trial, 'correct') and isinstance(trial.correct, bool)
        if trial.correct:
            self.cumulative_correct += 1
            if self.cumulative_correct < self.threshold:
                return False
            # Threshold reached: reinforce and start counting afresh.
            self._update()
            return True
        # An incorrect response breaks the run and is always consequated.
        self.cumulative_correct = 0
        return True

    def __unicode__(self):
        return "FR%i" % self.ratio

    def __str__(self):
        # __unicode__ is never invoked on Python 3; expose the same label there.
        return self.__unicode__()
class VariableRatioSchedule(FixedRatioSchedule):
    """Maintains logic for deciding whether to consequate trials.

    This class implements a variable ratio schedule, where a reward
    reinforcement is provided after a number of consecutive correct
    responses. On average, the number of consecutive responses necessary is the
    'ratio'. After a reinforcement is provided, the number of consecutive
    correct trials needed for the next reinforcement is selected by sampling
    randomly from the interval [1,2*ratio-1]. e.g. a ratio of '3' will require
    consecutive correct trials of 1, 2, 3, 4, & 5, randomly.

    Incorrect trials are always reinforced.

    Methods:
    consequate(trial) -- returns a boolean value based on whether the trial
        should be consequated.

    """

    def __init__(self, ratio=1):
        super(VariableRatioSchedule, self).__init__(ratio=ratio)

    def _update(self):
        """Reset the counter and draw a new threshold from [1, 2*ratio)."""
        self.cumulative_correct = 0
        # numpy.random.randint excludes the upper bound, so the largest
        # possible threshold is 2*ratio - 1.
        self.threshold = random.randint(1, 2 * self.ratio)

    def __unicode__(self):
        return "VR%i" % self.ratio

    def __str__(self):
        # __unicode__ is never invoked on Python 3; expose the same label there.
        return self.__unicode__()
class PercentReinforcement(BaseSchedule):
    """Maintains logic for deciding whether to consequate trials.

    This class implements a probabilistic reinforcement, where a reward
    reinforcement is provided for a correct trial with probability 'prob'
    (a fraction compared against a random draw from [0, 1)).

    Incorrect trials are always reinforced.

    Methods:
    consequate(trial) -- returns a boolean value based on whether the trial
        should be consequated.

    """

    def __init__(self, prob=1):
        super(PercentReinforcement, self).__init__()
        self.prob = prob

    def consequate(self, trial):
        """Decide whether ``trial`` should be consequated.

        ``trial`` must expose a ``correct`` attribute holding a ``bool``;
        otherwise the sanity check below raises AssertionError (when
        assertions are enabled).

        Returns True for every incorrect trial. For a correct trial, returns
        True when a fresh random draw from [0, 1) falls below ``self.prob``.
        """
        assert hasattr(trial, 'correct') and isinstance(trial.correct, bool)
        if not trial.correct:
            return True
        return random.random() < self.prob

    def __unicode__(self):
        # %g keeps fractional probabilities ("PR0.5"); the previous %i
        # truncated them to "PR0". Whole values still render as "PR1"/"PR0".
        return "PR%g" % self.prob

    def __str__(self):
        # __unicode__ is never invoked on Python 3; expose the same label there.
        return self.__unicode__()