天天看點

C# ConcurrentBag的實現

ConcurrentBag可以理解為一個線程安全的無序集合，它的API比我們常用的List<T>要弱一些（例如不支援按索引訪問）。下面我們來看看它的實現：

/// <summary>
/// A thread-safe, unordered collection. Each thread that adds items gets its own ThreadLocalList
/// (kept in <c>m_locals</c>); all of those lists are also chained from <c>m_headList</c> through
/// <c>ThreadLocalList.m_nextList</c> so that a thread whose own list is empty can steal items from
/// other threads' lists.
/// </summary>
/// <remarks>
/// NOTE(review): this appears to be an abridged excerpt of the .NET Framework reference source.
/// It references members that are not defined here (<c>ThreadLocalList.Count</c>, <c>GlobalListsLock</c>,
/// <c>ListOperation</c>, <c>SR</c>) and omits the remaining IProducerConsumerCollection&lt;T&gt; /
/// IReadOnlyCollection&lt;T&gt; members, so it does not compile as shown.
/// </remarks>
public class ConcurrentBag<T> : IProducerConsumerCollection<T>, IReadOnlyCollection<T>
    {
        // Per-thread storage: each thread that has added to this bag owns one ThreadLocalList here.
        ThreadLocal<ThreadLocalList> m_locals;

        // These head and tail pointers point to the first and last thread-local lists, to allow enumeration of
        // (and stealing from) every thread's list. New lists are appended at m_tailList under GlobalListsLock
        // (see GetThreadList).
        volatile ThreadLocalList m_headList, m_tailList;
        
        // When true, the owner-thread paths in AddInternal and TryTakeOrPeek always take the list's monitor.
        // Per the comments in those methods, it means another thread needs to freeze the bag. Nothing in this
        // excerpt sets it (presumably done by omitted members — TODO confirm).
        bool m_needSync;
        
        /// <summary>Creates an empty bag.</summary>
        public ConcurrentBag() { Initialize(null);}

        /// <summary>Creates a bag containing the items of <paramref name="collection"/>.</summary>
        /// <param name="collection">Items to copy into the bag; all of them go into the calling thread's list.</param>
        /// <exception cref="ArgumentNullException"><paramref name="collection"/> is null.</exception>
        public ConcurrentBag(IEnumerable<T> collection)
        {
            if (collection == null)
            {
                throw new ArgumentNullException("collection", SR.GetString(SR.ConcurrentBag_Ctor_ArgumentNullException));
            }
            Initialize(collection);
        }
        
        /// <summary>
        /// Creates the per-thread storage. When <paramref name="collection"/> is non-null, also copies its items
        /// into the calling thread's list.
        /// </summary>
        private void Initialize(IEnumerable<T> collection)
        {
            m_locals = new ThreadLocal<ThreadLocalList>();

            // Copy the collection to the bag. Items are added without taking the list lock and with
            // updateCount == false. This is presumably safe because the bag is still being constructed and is
            // not yet visible to other threads.
            if (collection != null)
            {
                ThreadLocalList list = GetThreadList(true);
                foreach (T item in collection)
                {
                    list.Add(item, false);
                }
            }
        }
        
        /// <summary>Adds <paramref name="item"/> to the calling thread's own list.</summary>
        /// <param name="item">The item to add (may be null for reference types).</param>
        public void Add(T item)
        {
            // Get the local list for this thread. If the thread has none yet (first call to Add), create a new
            // list or reclaim an unowned one.
            ThreadLocalList list = GetThreadList(true);
            AddInternal(list, item);
        }
        
       /// <summary>
       /// Adds <paramref name="item"/> at the head of <paramref name="list"/>, the calling thread's own list.
       /// The list's monitor is taken only when a stealing thread could conflict or m_needSync is set.
       /// </summary>
       private void AddInternal(ThreadLocalList list, T item)
        {
            bool lockTaken = false;
            try
            {
                // Publish that the owner thread is in the middle of an Add. CanSteal spins while m_currentOp
                // is not None.
                Interlocked.Exchange(ref list.m_currentOp, (int)ListOperation.Add);
                //Synchronization cases:
                // - If the list count is less than two, lock to avoid conflict with any stealing thread. With fewer
                //   than two items, Add may set m_tail or touch the same node a stealer is unlinking from the tail.
                // - If m_needSync is set, this means there is a thread that needs to freeze the bag.
                if (list.Count < 2 || m_needSync)
                {
                    // Reset it back to None before blocking. A stealer holding the list's monitor spins in CanSteal
                    // until m_currentOp is None, so keeping it set while waiting here would deadlock with it.
                    list.m_currentOp = (int)ListOperation.None;
                    Monitor.Enter(list, ref lockTaken);
                }
                // updateCount is true only when the lock is held. Stealers increment m_stealCount under that same
                // lock (TrySteal), so folding it into m_count here cannot race with them.
                list.Add(item, lockTaken);
            }
            finally
            {
                // Always clear the in-progress marker, and release the monitor if it was acquired.
                list.m_currentOp = (int)ListOperation.None;
                if (lockTaken)
                {
                    Monitor.Exit(list);
                }
            }
        }
        
        /// <summary>
        /// Returns the calling thread's list. If the thread has none and <paramref name="forceCreate"/> is true,
        /// this method, under GlobalListsLock:
        /// - reuses an unowned list, or creates a new one and appends it to the global chain;
        /// - stores the result in <c>m_locals</c>.
        /// </summary>
        /// <param name="forceCreate">Whether to obtain a list when the calling thread does not have one yet.</param>
        /// <returns>The thread's list, or null when it has none and <paramref name="forceCreate"/> is false.</returns>
        private ThreadLocalList GetThreadList(bool forceCreate)
        {
            ThreadLocalList list = m_locals.Value;
            if (list != null)
            {
                return list;
            }
            else if (forceCreate)
            {
                // Acquire the lock to update the m_headList/m_tailList pointers and to scan for unowned lists
                lock (GlobalListsLock)
                {
                    if (m_headList == null)
                    {
                        // First list ever created for this bag: it is both the head and the tail of the chain.
                        list = new ThreadLocalList(Thread.CurrentThread);
                        m_headList = list;
                        m_tailList = list;
                    }
                    else
                    {
                        // Prefer reclaiming a list whose owner thread has stopped; otherwise append a new one.
                        list = GetUnownedList();
                        if (list == null)
                        {
                            list = new ThreadLocalList(Thread.CurrentThread);
                            m_tailList.m_nextList = list;
                            m_tailList = list;
                        }
                    }
                    m_locals.Value = list;
                }
            }
            else
            {
                return null;
            }
            Debug.Assert(list != null);
            return list;
        }
        
        /// <summary>Attempts to remove an item from the bag.</summary>
        /// <param name="result">The removed item, or default(T) if none was found.</param>
        /// <returns>true if an item was removed.</returns>
        public bool TryTake(out T result)
        {
            return TryTakeOrPeek(out result, true);
        }
        
        /// <summary>Attempts to read an item from the bag without removing it.</summary>
        /// <param name="result">The item read, or default(T) if none was found.</param>
        /// <returns>true if an item was found.</returns>
        public bool TryPeek(out T result)
        {
            return TryTakeOrPeek(out result, false);
        }
        
        /// <summary>
        /// Shared implementation of TryTake and TryPeek. It works on the head (most recently added item) of the
        /// calling thread's own list, and falls back to <c>Steal</c> when that list is missing or empty.
        /// </summary>
        /// <param name="result">The removed/peeked item, or default(T) if no item was found.</param>
        /// <param name="take">true to remove the item, false to only peek.</param>
        /// <returns>true if an item was found.</returns>
        private bool TryTakeOrPeek(out T result, bool take)
        {
            // Get the calling thread's local list. It is null if this thread has never added to the bag.
            ThreadLocalList list = GetThreadList(false);
            if (list == null || list.Count == 0)
            {
                return Steal(out result, take);
            }
            bool lockTaken = false;
            try
            {
                if (take) // Take operation
                {
                    // Publish that the owner thread is in the middle of a Take (read by CanSteal).
                    Interlocked.Exchange(ref list.m_currentOp, (int)ListOperation.Take);
                    //Synchronization cases:
                    // - If the list count is less than or equal to two, lock to avoid conflict with any stealing thread.
                    // - If m_needSync is set, this means there is a thread that needs to freeze the bag.
                    if (list.Count <= 2 || m_needSync)
                    {
                        // Reset it back to None before blocking, to avoid deadlock with a stealing thread that
                        // holds the monitor and spins in CanSteal until m_currentOp is None.
                        list.m_currentOp = (int)ListOperation.None;
                        Monitor.Enter(list, ref lockTaken);

                        // Double check the count and steal if it became empty
                        // (a stealer may have emptied the list before we got the lock)
                        if (list.Count == 0)
                        {
                            // Release the lock before stealing, so we never hold our own list's lock while locking others
                            if (lockTaken)
                            {
                                // Doing the work in an empty try's finally presumably keeps these two statements from
                                // being separated by a thread abort (finally blocks are not interrupted by
                                // Thread.Abort on .NET Framework).
                                try { }
                                finally
                                {
                                    lockTaken = false; // reset lockTaken to avoid calling Monitor.Exit again in the finally block
                                    Monitor.Exit(list);
                                }
                            }
                            return Steal(out result, true);
                        }
                    }
                    list.Remove(out result);
                }
                else
                {
                    // Peek reads m_head without locking and without setting m_currentOp.
                    if (!list.Peek(out result))
                    {
                        return Steal(out result, false);
                    }
                }
            }
            finally
            {
                list.m_currentOp = (int)ListOperation.None;
                if (lockTaken)
                {
                    Monitor.Exit(list);
                }
            }
            return true;
        }

   /// <summary>
   /// Takes or peeks an item from some thread's list, starting at m_headList. It works in two passes:
   /// - Pass 1 records each list's m_version and tries every non-empty list.
   /// - If nothing was found, pass 2 re-checks the same lists. Any list whose version changed (m_version is
   ///   incremented when a list goes from empty to non-empty) is retried, and the whole procedure repeats.
   /// </summary>
   /// <param name="result">The item found, or default(T) when a full pass saw no version changes.</param>
   /// <param name="take">true to remove the item, false to only peek.</param>
   /// <returns>true if an item was found.</returns>
   private bool Steal(out T result, bool take)
        {
            bool loop;
            List<int> versionsList = new List<int>(); // save the lists version
            do
            {
                versionsList.Clear(); //clear the versions recorded by the previous iteration
                loop = false;
              
                // Pass 1: snapshot every list's version and try to steal from each non-empty list.
                ThreadLocalList currentList = m_headList;
                while (currentList != null)
                {
                    versionsList.Add(currentList.m_version);
                    if (currentList.m_head != null && TrySteal(currentList, out result, take))
                    {
                        return true;
                    }
                    currentList = currentList.m_nextList;
                }

                // Pass 2: verify versioning. If items were added to a list since we last visited it, retry it and
                // loop again. Lists are never unlinked in this excerpt, so walking from m_headList revisits the
                // same lists in the same order as pass 1.
                currentList = m_headList;
                foreach (int version in versionsList)
                {
                    if (version != currentList.m_version) //oops, state changed since pass 1
                    {
                        loop = true;
                        if (currentList.m_head != null && TrySteal(currentList, out result, take))
                            return true;
                    }
                    currentList = currentList.m_nextList;
                }
            } while (loop);

            result = default(T);
            return false;
        }
        
        /// <summary>
        /// Under the list's monitor, steals (or peeks) the tail item of <paramref name="list"/> if CanSteal allows it.
        /// </summary>
        /// <returns>true if an item was obtained; otherwise false and <paramref name="result"/> is default(T).</returns>
        private bool TrySteal(ThreadLocalList list, out T result, bool take)
        {
            lock (list)
            {
                if (CanSteal(list))
                {
                    list.Steal(out result, take);
                    return true;
                }
                result = default(T);
                return false;
            }
        }
        
        /// <summary>
        /// Decides whether an item can be stolen from <paramref name="list"/>. Called by TrySteal with the list's
        /// monitor held. When the list has at most two items and its owner is mid-operation
        /// (m_currentOp != None), this spins until the owner finishes.
        /// </summary>
        /// <returns>true if the list is non-empty after any wait.</returns>
        private bool CanSteal(ThreadLocalList list)
        {
            if (list.Count <= 2 && list.m_currentOp != (int)ListOperation.None)
            {
                SpinWait spinner = new SpinWait();
                while (list.m_currentOp != (int)ListOperation.None)
                {
                    spinner.SpinOnce();
                }
            }
            if (list.Count > 0)
            {
                return true;
            }
            return false;
        }
        /// <summary>
        /// Try to reuse an unowned list if one exists.
        /// Unowned lists are lists whose owner thread has been aborted or terminated (ThreadState == Stopped).
        /// Reusing them is a workaround to avoid leaking one list per dead thread.
        /// NOTE(review): ThreadState is a flags enum. An exact equality test against Stopped may not match a stopped
        /// thread whose state also carries other bits (e.g. Background) — confirm.
        /// </summary>
        /// <returns>The list object, now owned by the calling thread; null if all lists are owned</returns>
        private ThreadLocalList GetUnownedList()
        {
            //the global lock must be held at this point
            Contract.Assert(Monitor.IsEntered(GlobalListsLock));

            ThreadLocalList currentList = m_headList;
            while (currentList != null)
            {
                if (currentList.m_ownerThread.ThreadState == System.Threading.ThreadState.Stopped)
                {
                    currentList.m_ownerThread = Thread.CurrentThread; // the caller should acquire a lock to make this line thread safe
                    return currentList;
                }
                currentList = currentList.m_nextList;
            }
            return null;
        }
          /// <summary>
          /// One thread's portion of the bag: a doubly-linked list of Nodes.
          /// - The owner thread adds and takes at m_head.
          /// - Stealing threads take at m_tail under the list's monitor.
          /// Lists are chained together through m_nextList.
          /// NOTE(review): the Count property used throughout is not shown in this excerpt; presumably it is
          /// m_count - m_stealCount — confirm.
          /// </summary>
          internal class ThreadLocalList
        {
          
            // Most recently added node; the owner thread adds and removes here.
            internal volatile Node m_head;
            // Oldest node; stealers remove here.
            private volatile Node m_tail;
            // The ListOperation (None/Add/Take) the owner thread is currently performing; read by CanSteal.
            internal volatile int m_currentOp;
            // Incremented by Add and decremented by Remove; items taken by stealers are counted in m_stealCount instead.
            // Presumably kept separate so the owner's unlocked writes and the stealers' locked writes never hit the same field.
            private int m_count;
            // Items removed by stealers since the last time Add folded this value into m_count.
            internal int m_stealCount;
            // Next list in the bag's global chain (appended under GlobalListsLock).
            internal volatile ThreadLocalList m_nextList;
            // Not used anywhere in this excerpt.
            internal bool m_lockTaken;
            // Thread that currently owns this list; reassigned by GetUnownedList when the owner has stopped.
            internal Thread m_ownerThread;
            // Incremented when the list goes from empty to non-empty; Steal compares it to detect missed additions.
            internal volatile int m_version;
            internal ThreadLocalList(Thread ownerThread)
            {
                m_ownerThread = ownerThread;
            }
            /// <summary>
            /// Inserts <paramref name="item"/> at the head of the list.
            /// </summary>
            /// <param name="item">The item to add.</param>
            /// <param name="updateCount">
            /// true when the caller holds the list's monitor. m_stealCount is then folded into m_count and reset,
            /// keeping m_count from growing without bound.
            /// </param>
            /// <exception cref="OverflowException">m_count would exceed int.MaxValue (checked increment).</exception>
            internal void Add(T item, bool updateCount)
            {
                checked
                {
                    m_count++;
                }
                Node node = new Node(item);
                if (m_head == null)
                {
                    Debug.Assert(m_tail == null);
                    m_head = node;
                    m_tail = node;
                    m_version++; // changing from empty state to non empty state
                }
                else
                {
                    node.m_next = m_head;
                    m_head.m_prev = node;
                    m_head = node;
                }
                if (updateCount) // update the count to avoid overflow if this add is synchronized
                {
                    m_count = m_count - m_stealCount;
                    m_stealCount = 0;
                }
            }

            /// <summary>
            /// Remove an item from the head of the list (owner thread only). The list must be non-empty.
            /// </summary>
            /// <param name="result">The removed item</param>
            internal void Remove(out T result)
            {
                Debug.Assert(m_head != null);
                Node head = m_head;
                m_head = m_head.m_next;
                if (m_head != null)
                {
                    m_head.m_prev = null;
                }
                else
                {
                    m_tail = null;
                }
                m_count--;
                result = head.m_value;

            }

            /// <summary>
            /// Peek an item from the head of the list (no locking).
            /// </summary>
            /// <param name="result">the peeked item, or default(T) if the list is empty</param>
            /// <returns>True if succeeded, false otherwise</returns>
            internal bool Peek(out T result)
            {
                Node head = m_head;
                if (head != null)
                {
                    result = head.m_value;
                    return true;
                }
                result = default(T);
                return false;
            }

            /// <summary>
            /// Reads the tail (oldest) item. When <paramref name="remove"/> is true, also unlinks the tail and
            /// increments m_stealCount. The list must be non-empty; callers hold the list's monitor (TrySteal).
            /// </summary>
            /// <param name="result">The tail item.</param>
            /// <param name="remove">true for a take, false for a peek.</param>
            internal void Steal(out T result, bool remove)
            {
                Node tail = m_tail;
                Debug.Assert(tail != null);
                if (remove) // Take operation
                {
                    m_tail = m_tail.m_prev;
                    if (m_tail != null)
                    {
                        m_tail.m_next = null;
                    }
                    else
                    {
                        m_head = null;
                    }
                    // Increment the steal count
                    m_stealCount++;
                }
                result = tail.m_value;
            }

        }
       /// <summary>A doubly-linked node holding one item of a ThreadLocalList.</summary>
       internal class Node
        {
            public Node(T value)
            {
                m_value = value;
            }
            public readonly T m_value;
            // Toward the tail (older items).
            public Node m_next;
            // Toward the head (newer items).
            public Node m_prev;
        }
    }      

首先我們需要知道裡面有2個內部類Node和ThreadLocalList，它們都和鏈表有關。Node是雙向鏈表的節點，因為它有m_next和m_prev屬性。每個ThreadLocalList內部是由Node組成的雙向鏈表，它的m_head和m_tail屬性分別指向頭、尾Node實例。多個ThreadLocalList之間則通過m_nextList串成一個單向鏈表。現在我們來看ConcurrentBag的幾個變量：ThreadLocal<ThreadLocalList> m_locals表示當前線程的list，所以從這裡我們可以猜測線程安全是採用ThreadLocal來實現的。volatile ThreadLocalList m_headList, m_tailList;這2個變量分別指向由所有線程的list組成的鏈表的頭和尾，用來遍歷所有線程的list，以及在末尾追加新的list。

無論是初始化Initialize方法（當collection不為null時）還是添加元素的Add方法，我們首先都要調用GetThreadList方法來獲取當前線程的list。GetThreadList方法首先檢查當前線程的m_locals.Value是否存在，有則直接返回。否則在GlobalListsLock鎖內檢查當前線程是否是第一個需要為該ConcurrentBag創建ThreadLocalList的線程【m_headList == null】，如果是，則創建新的ThreadLocalList，並同時作為m_headList和m_tailList。如果不是，則調用GetUnownedList方法檢查是否有孤立的ThreadLocalList可以復用【ThreadLocalList的所屬線程已經停止，但是該ThreadLocalList實例卻仍然存在】：如果有則返回該ThreadLocalList，否則只能創建新的ThreadLocalList實例並把它追加到m_tailList之後。最後把得到的list保存到m_locals.Value中。

現在看看AddInternal方法的實現。首先把ThreadLocalList的m_currentOp標記為添加元素【Interlocked.Exchange(ref list.m_currentOp, (int)ListOperation.Add)】，然後再添加元素【list.Add(item, lockTaken);】。如果該list需要加鎖（list.Count < 2，或者m_needSync為true），那麼在添加元素前，我們還需要先把m_currentOp重置為None（避免和正在自旋等待的Steal線程死鎖），再加鎖【Monitor.Enter(list, ref lockTaken)】，添加後在finally中解鎖【Monitor.Exit(list)】。之所以在元素少於2個時加鎖，是因為此時鏈表頭（Add操作的位置）和鏈表尾（其他線程Steal的位置）可能是同一個節點。ThreadLocalList的Add方法非常簡單，就是把新節點放到鏈表頭部【node.m_next = m_head; m_head.m_prev = node; m_head = node;】。

添加元素時是添加到各個線程自己的ThreadLocalList，所以讀取就比較麻煩了：我們可能需要讀取各個線程ThreadLocalList的數據，也就是說需要借助m_headList（以及每個list的m_nextList）來遍歷所有的ThreadLocalList。

如果當前線程存在ThreadLocalList實例並且其中有數據，那麼直接從該ThreadLocalList的頭部拿取數據（即該線程最近添加的元素）。如果需要加鎖（Take操作時list.Count <= 2，或者m_needSync為true），那麼我們就加鎖【Monitor.Enter(list, ref lockTaken)】和解鎖【Monitor.Exit(list)】，鎖的都是當前線程的list。加鎖後如果發現list已經被其他線程偷空了，就先解鎖再去Steal。

如果當前線程的ThreadLocalList不存在，或者沒有數據，我們就需要從其他線程的ThreadLocalList獲取數據。Steal方法首先會從m_headList開始，依次遍歷每一個ThreadLocalList，記錄它的版本號m_version並嘗試從中獲取數據。如果一個都獲取不到，那麼就再遍歷一次剛才記錄過的ThreadLocalList，檢查哪些ThreadLocalList的版本m_version在這兩次遍歷過程中發生了變化（m_version會在list從空變為非空時遞增）。對這些list再嘗試獲取一次，並且只要有版本變化就重新執行整個過程。代碼如下：

// Excerpt of the retry loop in ConcurrentBag<T>.Steal (shown in full in the class listing):
// - Pass 1 records each thread list's m_version and tries to steal from every non-empty list.
// - Pass 2 re-walks the same lists. For any list whose version changed (it went from empty to non-empty
//   in the meantime), it tries again and schedules another full pass.
do
            {
                versionsList.Clear(); //clear the versions recorded by the previous iteration
                loop = false;
              

                // Pass 1: snapshot every list's version and try to steal from each non-empty list.
                ThreadLocalList currentList = m_headList;
                while (currentList != null)
                {
                    versionsList.Add(currentList.m_version);
                    if (currentList.m_head != null && TrySteal(currentList, out result, take))
                    {
                        return true;
                    }
                    currentList = currentList.m_nextList;
                }

                // Pass 2: verify versioning. If items were added to a list since we last visited it, retry that
                // list and loop again.
                currentList = m_headList;
                foreach (int version in versionsList)
                {
                    if (version != currentList.m_version) //oops, state changed since pass 1
                    {
                        loop = true;
                        if (currentList.m_head != null && TrySteal(currentList, out result, take))
                            return true;
                    }
                    currentList = currentList.m_nextList;
                }
            } while (loop);      

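TrySteal方法的實現就非常簡單了：先鎖住目標list【lock (list)】，再檢查該list是否可以獲取數據【CanSteal(list)】。CanSteal裡面也用了自旋來實現：如果該list的元素不多於2個，並且其所屬線程正在執行Add或Take操作（m_currentOp不為None），就自旋等待該操作完成【if (list.Count <= 2 && list.m_currentOp != (int)ListOperation.None){ SpinWait spinner = new SpinWait(); while (list.m_currentOp != (int)ListOperation.None) {spinner.SpinOnce(); } }】，然後在list.Count > 0時返回true。真正的Steal實現是由ThreadLocalList來做的，比較簡單：它從鏈表尾部（最早添加的元素）取數據，Take時還會移除尾節點並遞增m_stealCount。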