天天看點

theano入門教程1.4

使用共享變量

# -*- coding: utf-8 -*-

"""

Created on Wed Jun  4 23:28:21 2014

@author: wencc

from theano import shared

from theano import function

import theano.tensor as T

if __name__ == ‘__main__‘:

state = shared(0)

inc = T.iscalar(‘inc‘)

accumulator = function([inc], state, updates=[(state, state+inc)])

print state.get_value()

print accumulator(1)

print accumulator(20)

fn_of_state = state*2 + inc

foo = T.scalar(dtype=state.dtype)

skip_shared = function([inc, foo], fn_of_state, givens=[(state, foo)])

skip_shared(1, 3)

state.get_value()

shared函數構造共享變量,共享變量的get_value,set_value函數用來檢視和設定共享變量的值

function函數中的updates參數用來更新共享變量,它是一個list,list中的每一項用map(共享變量,共享變量的新值表達式)的形式來表示。