# Source code for pyoperant.queues

import random
from pyoperant.utils import rand_from_log_shape_dist
import cPickle as pickle
import numpy as np

def random_queue(conditions, tr_max=100, weights=None):
    """ generator which randomly samples conditions

    Args:
        conditions (list): The conditions to sample from.
        weights (list of ints): Weights of each condition; condition i is
            repeated weights[i] times in the sampling pool, so it is drawn with
            probability weights[i] / sum(weights). Ignored if falsy.

    Kwargs:
        tr_max (int): Maximum number of trial conditions to generate. (default: 100)

    Returns:
        whatever the elements of 'conditions' are
    """
    if weights:
        # Repeat each condition `w` times so random.choice draws it in
        # proportion to its weight. The condition object itself is repeated
        # (not its contents), so string/list/dict conditions stay intact.
        conditions_weighted = []
        for cond, w in zip(conditions, weights):
            conditions_weighted.extend([cond] * w)
        conditions = conditions_weighted
    tr_num = 0
    while tr_num < tr_max:
        yield random.choice(conditions)
        tr_num += 1
def block_queue(conditions, reps=1, shuffle=False):
    """Yield a block containing every condition ``reps`` times.

    Args:
        conditions (list): The conditions making up the block.

    Kwargs:
        reps (int): How many copies of ``conditions`` go into the block
            (default: 1).
        shuffle (bool): If True, the whole block is shuffled before any item
            is yielded (default: False).

    Yields:
        the elements of ``conditions``; the caller's list is never reordered,
        because shuffling is applied to a freshly built copy.
    """
    block = []
    for _ in range(reps):
        block.extend(conditions)
    if shuffle:
        random.shuffle(block)
    for item in block:
        yield item
class AdaptiveBase(object):
    """Abstract base for adaptive procedures, such as a staircase.

    Any object inheriting this needs to define the `update()` and `next()`
    methods, calling the versions here via ``super`` so the bookkeeping below
    keeps working. The object is its own iterator; every ``next()`` must be
    followed by an ``update()`` before the following ``next()``.
    """

    def __init__(self, **kwargs):
        # kwargs are accepted and ignored so cooperative subclasses/mixins can
        # forward their keyword arguments through super().__init__.
        self.updated = True  # for first trial, no update needed
        self.update_error_str = "queue hasn't been updated since last trial"

    def __iter__(self):
        return self

    def __next__(self):
        """Python 3 iterator protocol.

        Delegates to ``self.next()`` (rather than aliasing) so that subclass
        overrides of ``next`` are used by ``for`` loops and builtin ``next()``.
        """
        return self.next()

    def update(self, correct, no_resp):
        """Record the outcome of the last trial.

        Args:
            correct: whether the last trial was answered correctly (unused here;
                subclasses use it).
            no_resp: if truthy, ``no_response()`` is called.
        """
        self.updated = True
        if no_resp:
            self.no_response()

    def next(self):
        """Mark the start of a new trial.

        Raises:
            Exception: if ``update()`` was not called since the previous ``next()``.
        """
        if not self.updated:  # hasn't been updated since last trial
            raise Exception(self.update_error_msg())  # TODO: find what causes bug
        self.updated = False

    def no_response(self):
        """Hook called for a no-response trial; does nothing here."""
        pass

    def on_load(self):
        """Restore a usable state after unpickling.

        Chains to any ``on_load`` further along the MRO, then marks the queue as
        updated and calls ``no_response()`` to discard any pending trial state.
        """
        try:
            super(AdaptiveBase, self).on_load()
        except AttributeError:
            # nothing further along the MRO defines on_load (e.g. plain object)
            pass
        self.updated = True
        self.no_response()

    def update_error_msg(self):
        """Return the message raised when ``next()`` is called without ``update()``."""
        return self.update_error_str
class PersistentBase(object):
    """ A mixin that allows for the creation of an obj through a load command
    that first checks for a pickled file to load an object before generating
    a new one.

    The object pickles itself to ``self.filename`` on construction and
    whenever ``save()`` is called.
    """

    def __init__(self, filename=None, **kwargs):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if filename is None:
            raise ValueError("PersistentBase requires a filename")
        super(PersistentBase, self).__init__(**kwargs)
        self.filename = filename
        self.save()

    @classmethod
    def load(cls, filename, *args, **kwargs):
        """Load a pickled object from ``filename``, or build a new one.

        If ``filename`` cannot be opened, ``cls(*args, filename=filename,
        **kwargs)`` is returned (and saved by ``__init__``). Otherwise the file
        is unpickled and the object's ``on_load()`` is called.

        Only the ``open`` is guarded: an error while reading an existing file
        (or raised by ``on_load``) propagates instead of silently replacing the
        saved state with a fresh object that would overwrite the file.

        NOTE: unpickling can execute arbitrary code; only load trusted files.
        """
        try:
            handle = open(filename, 'rb')
        except IOError:
            return cls(*args, filename=filename, **kwargs)
        with handle:
            ab = pickle.load(handle)
        ab.on_load()
        return ab

    def on_load(self):
        """Chain to any ``on_load`` further along the MRO, if one exists."""
        try:
            super(PersistentBase, self).on_load()
        except AttributeError:
            # nothing further along the MRO defines on_load
            pass

    def save(self):
        """Pickle this object to ``self.filename``, replacing its contents."""
        with open(self.filename, 'wb') as handle:
            pickle.dump(self, handle)
class KaernbachStaircase(AdaptiveBase):
    """ generates values for a staircase procedure from Kaernbach 1991

    This procedure returns values for each trial and assumes that larger values
    are easier. Thus, after a correct trial, the next value returned will be
    smaller and after incorrect trials, the next value returned will be larger.
    The magnitudes of these changes are stepsize_dn and stepsize_up,
    respectively.

    Args:
        start_val (float/int): the starting value of the procedure (default: 100)

    Kwargs:
        stepsize_up (int): amount added to the value after an incorrect trial (default: 3)
        stepsize_dn (int): amount subtracted from the value after a correct trial (default: 1)
        min_val (float): minimum parameter value to allow, or None for no floor (default: 0)
        max_val (float): maximum parameter value to allow, or None for no ceiling (default: 100)
        crit (int): stopping criterion; next() raises StopIteration once the
            counter exceeds crit (default: 100)
        crit_method (str): what the counter counts: 'trials' (each value
            returned by next()) or 'reversals' (each change of direction of the
            staircase) (default: 'trials')

    NOTE(review): because the check is `counter > crit`, the procedure yields
    crit + 1 values ('trials') or runs until crit + 1 reversals ('reversals').
    Confirm this is intended.

    Returns:
        float
    """

    def __init__(self,
                 start_val=100,
                 stepsize_up=3,
                 stepsize_dn=1,
                 min_val=0,
                 max_val=100,
                 crit=100,
                 crit_method='trials'
                 ):
        super(KaernbachStaircase, self).__init__()
        self.val = start_val
        self.stepsize_up = stepsize_up
        self.stepsize_dn = stepsize_dn
        self.min_val = min_val
        self.max_val = max_val
        self.crit = crit
        self.crit_method = crit_method
        self.counter = 0
        self.going_up = False

    def update(self, correct, no_resp):
        """Step the value down after a correct trial, up after an incorrect one."""
        super(KaernbachStaircase, self).update(correct, no_resp)
        if correct:
            self.val -= self.stepsize_dn
        else:
            self.val += self.stepsize_up
        if self.crit_method == 'reversals':
            # A reversal: this trial moves the value against the current
            # trend (correct -> down while going up, incorrect -> up while going down).
            if correct == self.going_up:
                self.counter += 1
                self.going_up = not self.going_up

        # stop at max/min if we hit the rails
        if self.max_val is not None and self.val > self.max_val:
            self.val = self.max_val
        elif self.min_val is not None and self.val < self.min_val:
            self.val = self.min_val

    def next(self):
        """Return the current value, or raise StopIteration once the criterion is exceeded."""
        super(KaernbachStaircase, self).next()
        if self.counter > self.crit:
            raise StopIteration
        if self.crit_method == 'trials':
            self.counter += 1
        return self.val
class DoubleStaircase(AdaptiveBase):
    """Double staircase for estimating the 50% point of a psychometric curve.

    Generates conditions from a list of stims that monotonically vary from most
    easily left to most easily right, i.e. left is low and right is high.
    Each trial probes either the low (left) or the high (right) side at random.
    If the response is correct, the probed index becomes the new low/high
    bound, moving the indices closer to each other until they are adjacent,
    at which point iteration stops.

    stims: an array of stimuli names ordered from most easily left to most
        easily right
    rate_constant: the step size is ceil(rate_constant * (high_idx - low_idx))
    """

    def __init__(self, stims, rate_constant=.05, **kwargs):
        super(DoubleStaircase, self).__init__(**kwargs)
        self.stims = stims
        self.rate_constant = rate_constant
        self.low_idx = 0
        self.high_idx = len(self.stims) - 1
        # {'low': bool, 'value': int} for the pending trial, {} when none
        self.trial = {}
        self.update_error_str = "double staircase queue %s hasn't been updated since last trial" % (self.stims[0])

    def update(self, correct, no_resp):
        """After a correct trial, move the probed bound to the probed index."""
        super(DoubleStaircase, self).update(correct, no_resp)
        # A no-response trial (or a reload) clears self.trial via no_response();
        # there is then no probed index to move, so nothing is updated.
        if correct and self.trial:
            if self.trial['low']:
                self.low_idx = self.trial['value']
            else:
                self.high_idx = self.trial['value']
        self.trial = {}

    def next(self):
        """Return {'class': 'L'|'R', 'stim_name': ...} for the next probe.

        Raises:
            StopIteration: once low_idx and high_idx are adjacent (or crossed).
        """
        super(DoubleStaircase, self).next()
        if self.high_idx - self.low_idx <= 1:
            raise StopIteration

        delta = int(np.ceil((self.high_idx - self.low_idx) * self.rate_constant))
        if random.random() < .5:  # probe low side
            self.trial['low'] = True
            self.trial['value'] = self.low_idx + delta
            return {'class': 'L', 'stim_name': self.stims[self.trial['value']]}
        else:
            self.trial['low'] = False
            self.trial['value'] = self.high_idx - delta
            return {'class': 'R', 'stim_name': self.stims[self.trial['value']]}

    def no_response(self):
        """Discard the pending trial so it does not move the staircase."""
        super(DoubleStaircase, self).no_response()
        self.trial = {}

    def update_error_msg(self):
        """Return the base error message plus the staircase state."""
        sup = super(DoubleStaircase, self).update_error_msg()
        # self.trial can be {} here (e.g. after next() raised StopIteration), so
        # use .get() rather than indexing, which would mask the real error
        # with a KeyError.
        state = "self.trial.low=%s self.trial.value=%s self.low_idx=%d self.high_idx=%d" % (
            self.trial.get('low'), self.trial.get('value'), self.low_idx, self.high_idx)
        return "\n".join([sup, state])
class DoubleStaircaseReinforced(AdaptiveBase):
    """Double staircase interleaved with easier, already-known trials.

    A proportion ``probe_rate`` of trials are probes drawn from an internal
    DoubleStaircase. The remaining 1 - probe_rate trials are easier trials,
    to reduce frustration. They are drawn from [0, low_idx] (left) or
    [high_idx, len(stims) - 1] (right). With ``sample_log`` they are sampled
    from a log-shaped distribution, so that more trials come from the edges
    than from near the indices. Once the staircase finishes, probing stops
    (probe_rate is set to 0).

    stims: an array of stimuli names ordered from most easily left to most
        easily right
    rate_constant: the step size is rate_constant * (high_idx - low_idx)
    probe_rate: proportion of trials that are staircase probes
    sample_log: sample the easier trials with rand_from_log_shape_dist()
        instead of uniformly
    """

    def __init__(self, stims, rate_constant=.05, probe_rate=.1, sample_log=False, **kwargs):
        super(DoubleStaircaseReinforced, self).__init__(**kwargs)
        self.dblstaircase = DoubleStaircase(stims, rate_constant)
        self.stims = stims
        self.probe_rate = probe_rate
        self.sample_log = sample_log
        self.last_probe = False  # whether the last trial came from the staircase
        self.update_error_str = "reinforced double staircase queue %s hasn't been updated since last trial" % (self.stims[0])

    def update(self, correct, no_resp):
        """Forward the outcome to the staircase (probe trials only), then record it."""
        if self.last_probe:
            self.dblstaircase.update(correct, no_resp)
        super(DoubleStaircaseReinforced, self).update(correct, no_resp)

    def next(self):
        """Return {'class': 'L'|'R', 'stim_name': ...} for the next trial."""
        super(DoubleStaircaseReinforced, self).next()

        if random.random() < self.probe_rate:
            try:
                probe = self.dblstaircase.next()
            except StopIteration:
                # staircase is finished: stop probing and retry as a reinforced trial
                self.probe_rate = 0
                self.last_probe = False
                self.updated = True
                return self.next()
            self.last_probe = True
            return probe

        self.last_probe = False
        dbs = self.dblstaircase
        if random.random() < .5:
            # easier left trial; uniform case draws an index in [0, low_idx]
            if self.sample_log:
                idx = int((1 - rand_from_log_shape_dist()) * dbs.low_idx)
            else:
                idx = random.randrange(dbs.low_idx + 1)
            side = 'L'
        else:
            # easier right trial; uniform case draws an index in [high_idx, len(stims) - 1]
            span = len(self.stims) - dbs.high_idx
            if self.sample_log:
                idx = dbs.high_idx + int(rand_from_log_shape_dist() * span)
            else:
                idx = dbs.high_idx + random.randrange(span)
            side = 'R'
        return {'class': side, 'stim_name': self.stims[idx]}

    def no_response(self):
        """No extra handling beyond the base class."""
        super(DoubleStaircaseReinforced, self).no_response()

    def on_load(self):
        """Reset this queue and its inner staircase after unpickling."""
        super(DoubleStaircaseReinforced, self).on_load()
        self.dblstaircase.on_load()

    def update_error_msg(self):
        """Return the base error message plus this queue's and the staircase's state."""
        dbs = self.dblstaircase
        parts = [
            super(DoubleStaircaseReinforced, self).update_error_msg(),
            "self.last_probe=%s" % (self.last_probe,),
            "self.dbs.trial=%s self.dbs.low_idx=%d self.dbs.high_idx=%d" % (dbs.trial, dbs.low_idx, dbs.high_idx),
        ]
        return "\n".join(parts)
class MixedAdaptiveQueue(PersistentBase, AdaptiveBase):
    """Generates conditions from multiple adaptive sub queues.

    Use the classmethod MixedAdaptiveQueue.load(filename, sub_queues) to load a
    previously saved MixedAdaptiveQueue, or to create (and save) a new one if
    the pkl file cannot be opened.

    sub_queues: a list of adaptive queues
    probabilities: a list of weights with which to sample from sub_queues;
        should be same length as sub_queues. NotImplemented: next() raises
        NotImplementedError unless this is None.
    filename: filename of pickle to save itself
    """

    def __init__(self, sub_queues, probabilities=None, **kwargs):
        # PersistentBase.__init__ (first in the MRO) takes `filename` from
        # kwargs and saves once already; the final save() below stores the
        # fully initialised object.
        super(MixedAdaptiveQueue, self).__init__(**kwargs)
        self.sub_queues = sub_queues
        self.probabilities = probabilities
        self.sub_queue_idx = -1  # index of the sub queue that served the last trial
        self.update_error_str = "MixedAdaptiveQueue hasn't been updated since last trial"
        self.save()

    def update(self, correct, no_resp):
        """Record the outcome, forward it to the last-used sub queue, then save."""
        super(MixedAdaptiveQueue, self).update(correct, no_resp)
        self.sub_queues[self.sub_queue_idx].update(correct, no_resp)
        self.save()

    def next(self):
        """Return the next condition from a uniformly chosen sub queue.

        Raises:
            NotImplementedError: if probabilities is set, or if the chosen sub
                queue is exhausted (StopIteration).
        """
        super(MixedAdaptiveQueue, self).next()
        if self.probabilities is not None:
            # TODO: support variable probabilities for each sub_queue
            raise NotImplementedError
        self.sub_queue_idx = random.randrange(len(self.sub_queues))
        try:
            return self.sub_queues[self.sub_queue_idx].next()
        except StopIteration:
            # TODO: deal with subqueue finished, and possibility of all subqueues finishing
            raise NotImplementedError

    def on_load(self):
        """Reset this queue and every sub queue that has an on_load hook."""
        super(MixedAdaptiveQueue, self).on_load()
        for sub_queue in self.sub_queues:
            # Skip sub queues without the hook, but let errors raised *inside*
            # a hook propagate instead of swallowing every AttributeError.
            on_load = getattr(sub_queue, 'on_load', None)
            if on_load is not None:
                on_load()