1 Session概述
Session是TensorFlow前後端連接配接的橋梁。使用者利用session使得client能夠與master的執行引擎建立連接配接,并通過session.run()來觸發一次計算。它建立了一套上下文環境,封裝了operation計算以及tensor求值的環境。
session建立時,系統會配置設定一些資源,比如graph引用、要連接配接的計算引擎的名稱等。故計算完畢後,需要使用session.close()關閉session,避免引起記憶體洩漏,特别是graph無法釋放的問題。可以顯式調用session.close(),或利用with上下文管理器,或者直接使用InteractiveSession。
session之間采用共享graph的方式來提高運作效率。一個session隻能運作一個graph執行個體,但一個graph可以運作在多個session中。一般情況下,建立session時如果不指定Graph執行個體,則會使用系統預設Graph。常見情況下,我們都是使用一個graph,即預設graph。當session建立時,不會重新建立graph執行個體,而是預設graph引用計數加1。當session close時,引用計數減1。隻有引用計數為0時,graph才會被回收。這種graph共享的方式,大大減少了graph建立和回收的資源消耗,優化了TensorFlow運作效率。
2 預設session
op運算和tensor求值時,如果沒有指定運作在哪個session中,則會運作在預設session中。通過session.as_default()可以将自己設定為預設session。但個人建議最好還是通過session.run(operator)和session.run(tensor)來進行op運算和tensor求值。
operation.run()
operation.run()等價于tf.get_default_session().run(operation)
@tf_export("Operation")
class Operation(object):
# 通過operation.run()調用,進行operation計算
def run(self, feed_dict=None, session=None):
_run_using_default_session(self, feed_dict, self.graph, session)
def _run_using_default_session(operation, feed_dict, graph, session=None):
# 沒有指定session,則擷取預設session
if session is None:
session = get_default_session()
# 最終還是通過session.run()進行運作的。tf中任何運算,都是通過session來run的。
# 通過session來建立client和master的連接配接,并将graph發送給master,master再進行執行
session.run(operation, feed_dict)
tensor.eval()
tensor.eval()等價于tf.get_default_session().run(tensor), 如下
@tf_export("Tensor")
class Tensor(_TensorLike):
# 通過tensor.eval()調用,進行tensor運算
def eval(self, feed_dict=None, session=None):
return _eval_using_default_session(self, feed_dict, self.graph, session)
def _eval_using_default_session(tensors, feed_dict, graph, session=None):
# 如果沒有指定session,則擷取預設session
if session is None:
session = get_default_session()
return session.run(tensors, feed_dict)
預設session的管理
tf通過運作時維護的session本地線程棧,來管理預設session。故不同的線程會有不同的預設session,預設session是線程作用域的。
# session棧
_default_session_stack = _DefaultStack()
# 擷取預設session的接口
@tf_export("get_default_session")
def get_default_session():
return _default_session_stack.get_default()
# _DefaultStack預設session棧是線程相關的
class _DefaultStack(threading.local):
# 預設session棧的建立,其實就是一個list
def __init__(self):
super(_DefaultStack, self).__init__()
self._enforce_nesting = True
self.stack = []
# 擷取預設session
def get_default(self):
return self.stack[-1] if len(self.stack) >= 1 else None
3 前端Session類型
session類圖
會話Session的UML類圖如下
分為兩種類型,普通Session和互動式InteractiveSession。InteractiveSession和Session基本相同,差別在于
- InteractiveSession建立後,會将自己替換為預設session。使得之後operation.run()和tensor.eval()的執行通過這個預設session來進行。特别适合Python互動式環境。
- InteractiveSession自帶with上下文管理器。它在建立時和關閉時會調用上下文管理器的enter和exit方法,進而進行資源的申請和釋放,避免記憶體洩漏問題。這同樣很适合Python互動式環境。
Session和InteractiveSession的代碼邏輯不多,主要邏輯均在其父類BaseSession中。主要代碼如下
@tf_export('Session')
class Session(BaseSession):
def __init__(self, target='', graph=None, config=None):
# session建立的主要邏輯都在其父類BaseSession中
super(Session, self).__init__(target, graph, config=config)
self._default_graph_context_manager = None
self._default_session_context_manager = None
@tf_export('InteractiveSession')
class InteractiveSession(BaseSession):
def __init__(self, target='', graph=None, config=None):
self._explicitly_closed = False
# 将自己設定為default session
self._default_session = self.as_default()
self._default_session.enforce_nesting = False
# 自動調用上下文管理器的__enter__()方法
self._default_session.__enter__()
self._explicit_graph = graph
def close(self):
super(InteractiveSession, self).close()
## 省略無關代碼
## 自動調用上下文管理器的__exit__()方法,避免記憶體洩漏
self._default_session.__exit__(None, None, None)
self._default_session = None
BaseSession
BaseSession基本包含了所有的會話實作邏輯。包括會話的整個生命周期,也就是建立 執行 關閉和銷毀四個階段。生命周期後面詳細分析。BaseSession包含的主要成員變量有graph引用,序列化的graph_def, 要連接配接的tf引擎target,session配置資訊config等。
4 後端Session類型
在後端master中,根據前端client調用tf.Session(target='', graph=None, config=None)時指定的target,來建立不同的Session。target為要連接配接的tf後端執行引擎,預設為空字元串。Session建立采用了抽象工廠模式,如果為空字元串,則建立本地DirectSession,如果以grpc://開頭,則建立分布式GrpcSession。類圖如下
DirectSession隻能利用本地裝置,将任務建立到本地的CPU GPU上。而GrpcSession則可以利用遠端分布式裝置,将任務建立到其他機器的CPU GPU上,然後通過grpc協定進行通信。grpc協定是谷歌發明并開源的遠端通信協定。
5 Session生命周期
Session作為前後端連接配接的橋梁,以及上下文運作環境,其生命周期尤其關鍵。大緻分為4個階段
- 建立:通過tf.Session()建立session執行個體,進行系統資源配置設定,特别是graph引用計數加1
- 運作:通過session.run()觸發計算的執行,client會将整圖graph傳遞給master,由master進行執行
- 關閉:通過session.close()來關閉,會進行系統資源的回收,特别是graph引用計數減1.
- 銷毀:Python垃圾回收器進行GC時,調用
進行回收。session.__del__()
生命周期方法入口基本都在前端Python的BaseSession中,它會通過swig自動生成的函數符号映射關系,調用C層的實作。
5.1 建立
先從BaseSession類的init方法看起,隻保留了主要代碼。
def __init__(self, target='', graph=None, config=None):
# graph表示建構的圖。TensorFlow的一個session會對應一個圖。這個圖包含了所有涉及到的算子
# graph如果沒有設定(通常都不會設定),則使用預設graph
if graph is None:
self._graph = ops.get_default_graph()
else:
self._graph = graph
self._opened = False
self._closed = False
self._current_version = 0
self._extend_lock = threading.Lock()
# target為要連接配接的tf執行引擎
if target is not None:
self._target = compat.as_bytes(target)
else:
self._target = None
self._delete_lock = threading.Lock()
self._dead_handles = []
# config為session的配置資訊
if config is not None:
self._config = config
self._add_shapes = config.graph_options.infer_shapes
else:
self._config = None
self._add_shapes = False
self._created_with_new_api = ops._USE_C_API
# 調用C層來建立session
self._session = None
opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
self._session = tf_session.TF_NewSession(self._graph._c_graph, opts, status)
BaseSession先進行成員變量的指派,然後調用TF_NewSession來建立session。TF_NewSession()方法由swig自動生成,在bazel-bin/tensorflow/python/pywrap_tensorflow_internal.py中
def TF_NewSession(graph, opts, status):
return _pywrap_tensorflow_internal.TF_NewSession(graph, opts, status)
_pywrap_tensorflow_internal包含了C層函數的符号表。在swig子產品import時,會加載pywrap_tensorflow_internal.so動态連結庫,進而得到符号表。在pywrap_tensorflow_internal.cc中,注冊了供Python調用的函數的符号表,進而實作Python到C的函數映射和調用。
// c++函數調用的符号表,Python通過它可以調用到C層代碼。符号表和動态連結庫由swig自動生成
static PyMethodDef SwigMethods[] = {
// .. 省略其他函數定義
// TF_NewSession的符号表,通過這個映射,Python中就可以調用到C層代碼了。
{ (char *)"TF_NewSession", _wrap_TF_NewSession, METH_VARARGS, NULL},
// ... 省略其他函數定義
}
最終調用到c_api.c中的TF_NewSession()
// TF_NewSession建立session的新實作,在C層後端代碼中
TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
TF_Status* status) {
Session* session;
// 建立session
status->status = NewSession(opt->options, &session);
if (status->status.ok()) {
TF_Session* new_session = new TF_Session(session, graph);
if (graph != nullptr) {
// 采用了引用計數方式,多個session共享一個圖執行個體,效率更高。
// session建立時,引用計數加1。session close時引用計數減1。引用計數為0時,graph才會被回收。
mutex_lock l(graph->mu);
graph->sessions[new_session] = Status::OK();
}
return new_session;
} else {
DCHECK_EQ(nullptr, session);
return nullptr;
}
}
session建立時,并建立graph,而是采用共享方式,隻是引用計數加1了。這種方式減少了session建立和關閉時的資源消耗,提高了運作效率。NewSession()根據前端傳遞的target,使用sessionFactory建立對應的TensorFlow::Session執行個體。
Status NewSession(const SessionOptions& options, Session** out_session) {
SessionFactory* factory;
const Status s = SessionFactory::GetFactory(options, &factory);
// 通過sessionFactory建立多态的Session。本地session為DirectSession,分布式為GRPCSession
*out_session = factory->NewSession(options);
if (!*out_session) {
return errors::Internal("Failed to create session.");
}
return Status::OK();
}
建立session采用了抽象工廠模式。根據client傳遞的target,來建立不同的session。如果target為空字元串,則建立本地DirectSession。如果以grpc://開頭,則建立分布式GrpcSession。TensorFlow包含本地運作時和分布式運作時兩種運作模式。
下面來看DirectSessionFactory的NewSession()方法
class DirectSessionFactory : public SessionFactory {
public:
Session* NewSession(const SessionOptions& options) override {
std::vector<Device*> devices;
// job在本地執行
const Status s = DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices);
if (!s.ok()) {
LOG(ERROR) << s;
return nullptr;
}
DirectSession* session =
new DirectSession(options, new DeviceMgr(devices), this);
{
mutex_lock l(sessions_lock_);
sessions_.push_back(session);
}
return session;
}
GrpcSessionFactory的NewSession()方法就不詳細分析了,它會将job任務建立在分布式裝置上,各job通過grpc協定通信。
5.2 運作
通過session.run()可以啟動graph的執行。入口在BaseSession的run()方法中, 同樣隻列出關鍵代碼
class BaseSession(SessionInterface):
def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
# fetches可以為單個變量,或者數組,或者元組。它是圖的一部分,可以是操作operation,也可以是資料tensor,或者他們的名字String
# feed_dict為對應placeholder的實際訓練資料,它的類型為字典
result = self._run(None, fetches, feed_dict, options_ptr,run_metadata_ptr)
return result
def _run(self, handle, fetches, feed_dict, options, run_metadata):
# 建立fetch處理器fetch_handler
fetch_handler = _FetchHandler(
self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
# 經過不同類型的fetch_handler處理,得到最終的fetches和targets
# targets為要執行的operation,fetches為要執行的tensor
_ = self._update_with_movers(feed_dict_tensor, feed_map)
final_fetches = fetch_handler.fetches()
final_targets = fetch_handler.targets()
# 開始運作
if final_fetches or final_targets or (handle and feed_dict_tensor):
results = self._do_run(handle, final_targets, final_fetches,
feed_dict_tensor, options, run_metadata)
else:
results = []
# 輸出結果到results中
return fetch_handler.build_results(self, results)
def _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata):
# 将要運作的operation添加到graph中
self._extend_graph()
# 執行一次運作run,會調用底層C來實作
return tf_session.TF_SessionPRunSetup_wrapper(
session, feed_list, fetch_list, target_list, status)
# 将要運作的operation添加到graph中
def _extend_graph(self):
with self._extend_lock:
if self._graph.version > self._current_version:
# 生成graph_def對象,它是graph的序列化表示
graph_def, self._current_version = self._graph._as_graph_def(
from_version=self._current_version, add_shapes=self._add_shapes)
# 通過TF_ExtendGraph将序列化後的graph,也就是graph_def傳遞給後端
with errors.raise_exception_on_not_ok_status() as status:
tf_session.TF_ExtendGraph(self._session,
graph_def.SerializeToString(), status)
self._opened = True
邏輯還是十分複雜的,主要有一下幾步
- 入參處理,建立fetch處理器fetch_handler,得到最終要執行的operation和tensor
- 對graph進行序列化,生成graph_def對象
- 将序列化後的grap_def對象傳遞給後端master。
- 通過後端master來run。
我們分别來看extend和run。
5.2.1 extend添加節點到graph中
TF_ExtendGraph()會調用到c_api中,這個邏輯同樣通過swig工具自動生成。下面看c_api.cc中的TF_ExtendGraph()方法
// 增加節點到graph中,proto為序列化後的graph
void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
size_t proto_len, TF_Status* status) {
GraphDef g;
// 先将proto反序列化,得到client傳遞的graph,放入g中
if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
status->status = InvalidArgument("Invalid GraphDef");
return;
}
// 再調用session的extend方法。根據建立的不同session類型,多态調用不同方法。
status->status = s->session->Extend(g);
}
後端系統根據生成的Session類型,多态的調用Extend方法。如果是本地session,則調用DirectSession的Extend()方法。如果是分布式session,則調用GrpcSession的相關方法。下面來看GrpcSession的Extend方法。
Status GrpcSession::Extend(const GraphDef& graph) {
CallOptions call_options;
call_options.SetTimeout(options_.config.operation_timeout_in_ms());
return ExtendImpl(&call_options, graph);
}
Status GrpcSession::ExtendImpl(CallOptions* call_options,
const GraphDef& graph) {
bool handle_is_empty;
{
mutex_lock l(mu_);
handle_is_empty = handle_.empty();
}
if (handle_is_empty) {
// 如果graph句柄為空,則表明graph還沒有建立好,此時extend就等同于create
return Create(graph);
}
mutex_lock l(mu_);
ExtendSessionRequest req;
req.set_session_handle(handle_);
*req.mutable_graph_def() = graph;
req.set_current_graph_version(current_graph_version_);
ExtendSessionResponse resp;
// 調用底層實作,來添加節點到graph中
Status s = master_->ExtendSession(call_options, &req, &resp);
if (s.ok()) {
current_graph_version_ = resp.new_graph_version();
}
return s;
}
Extend()方法中要注意的一點是,如果是首次執行Extend(), 則要先調用Create()方法進行graph的注冊。否則才是執行添加節點到graph中。
5.2.2 run執行圖的計算
同樣,Python通過swig自動生成的代碼,來實作對C API的調用。C層實作在c_api.cc的TF_Run()中。
// session.run()的C層實作
void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
// Input tensors,輸入的資料tensor
const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
// Output tensors,運作計算後輸出的資料tensor
const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
// Target nodes,要運作的節點
const char** c_target_oper_names, int ntargets,
TF_Buffer* run_metadata, TF_Status* status) {
// 省略一段代碼
TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names,
c_outputs, target_oper_names, run_metadata, status);
}
// 真正的實作了session.run()
static void TF_Run_Helper() {
RunMetadata run_metadata_proto;
// 調用不同的session實作類的run方法,來執行
result = session->Run(run_options_proto, input_pairs, output_tensor_names,
target_oper_names, &outputs, &run_metadata_proto);
// 省略代碼
}
最終會調用建立的session來執行run方法。DirectSession和GrpcSession的Run()方法會有所不同。後面很複雜,就不接着分析了。
5.3 關閉session
通過session.close()來關閉session,釋放相關資源,防止記憶體洩漏。
class BaseSession(SessionInterface):
def close(self):
tf_session.TF_CloseSession(self._session, status)
會調用到C API的TF_CloseSession()方法。
void TF_CloseSession(TF_Session* s, TF_Status* status) {
status->status = s->session->Close();
}
最終根據建立的session,多态的調用其Close()方法。同樣分為DirectSession和GrpcSession兩種。
::tensorflow::Status DirectSession::Close() {
cancellation_manager_->StartCancel();
{
mutex_lock l(closed_lock_);
if (closed_) return ::tensorflow::Status::OK();
closed_ = true;
}
// 登出session
if (factory_ != nullptr) factory_->Deregister(this);
return ::tensorflow::Status::OK();
}
DirectSessionFactory中的Deregister()方法如下
void Deregister(const DirectSession* session) {
mutex_lock l(sessions_lock_);
// 釋放相關資源
sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
sessions_.end());
}
5.4 銷毀session
session的銷毀是由Python的GC自動執行的。python通過引用計數方法來判斷是否回收對象。當對象的引用計數為0,且虛拟機觸發了GC時,會調用對象的
__del__()
方法來銷毀對象。引用計數法有個很緻命的問題,就是無法解決循環引用問題,故會存在記憶體洩漏。Java虛拟機采用了調用鍊分析的方式來決定哪些對象會被回收。
class BaseSession(SessionInterface):
def __del__(self):
# 先close,防止使用者沒有調用close()
try:
self.close()
# 再調用c api的TF_DeleteSession來銷毀session
if self._session is not None:
try:
status = c_api_util.ScopedTFStatus()
if self._created_with_new_api:
tf_session.TF_DeleteSession(self._session, status)
c_api.cc中的相關邏輯如下
void TF_DeleteSession(TF_Session* s, TF_Status* status) {
status->status = Status::OK();
TF_Graph* const graph = s->graph;
if (graph != nullptr) {
graph->mu.lock();
graph->sessions.erase(s);
// 如果graph的引用計數為0,也就是graph沒有被任何session持有,則考慮銷毀graph對象
const bool del = graph->delete_requested && graph->sessions.empty();
graph->mu.unlock();
// 銷毀graph對象
if (del) delete graph;
}
// 銷毀session和TF_Session
delete s->session;
delete s;
}
TF_DeleteSession()會判斷graph的引用計數是否為0,如果為0,則會銷毀graph。然後銷毀session和TF_Session對象。通過Session實作類的析構函數,來銷毀session,釋放線程池Executor,資料總管ResourceManager等資源。
DirectSession::~DirectSession() {
for (auto& it : partial_runs_) {
it.second.reset(nullptr);
}
// 釋放線程池Executor
for (auto& it : executors_) {
it.second.reset();
}
for (auto d : device_mgr_->ListDevices()) {
d->op_segment()->RemoveHold(session_handle_);
}
// 釋放ResourceManager
for (auto d : device_mgr_->ListDevices()) {
d->ClearResourceMgr();
}
// 釋放CancellationManager執行個體
functions_.clear();
delete cancellation_manager_;
// 釋放ThreadPool
for (const auto& p_and_owned : thread_pools_) {
if (p_and_owned.second) delete p_and_owned.first;
}
execution_state_.reset(nullptr);
flib_def_.reset(nullptr);
}
6 總結
Session是TensorFlow的client和master連接配接的橋梁,client任何運算也是通過session來run。它是client端最重要的對象。在Python層和C++層,均有不同的session實作。session生命周期會經曆四個階段,create run close和del。四個階段均由Python前端開始,最終調用到C層後端實作。由此也可以看到,TensorFlow架構的前後端分離和子產品化設計是多麼的精巧。