fork(1) download
  1. // tmp::Sort — C++14 TMP Introsort
  2. // Phase 1: IntroSort (QuickSort + HeapSort fallback), skip small segments
  3. // Phase 2: Single-pass Insertion Sort over whole array (guarded only)
  4. // Correctness first, then optimization
  5.  
  6. #include <utility>
  7. #include <type_traits>
  8. #include <cstdio>
  9. #include <cstdlib>
  10. #include <algorithm>
  11.  
  12. namespace tmp {
  13.  
  14. // ============================================================
  15. // § 1 TMP 基礎工具
  16. // ============================================================
  17. template<int V> struct Int { static constexpr int value = V; };
  18. template<bool V> struct Bool { static constexpr bool value = V; };
  19.  
  20. template<int N> struct Log2 : Int<1 + Log2<N/2>::value> {};
  21. template<> struct Log2<1> : Int<0> {};
  22. template<> struct Log2<0> : Int<0> {};
  23.  
  24. template<bool C, typename T, typename F> struct If { using type = F; };
  25. template<typename T, typename F> struct If<true,T,F>{ using type = T; };
  26.  
  27. // ============================================================
  28. // § 2 Comparator 策略
  29. // ============================================================
  30. template<typename T>
  31. struct Less { bool operator()(const T& a, const T& b) const { return a < b; } };
  32. template<typename T>
  33. struct Greater { bool operator()(const T& a, const T& b) const { return a > b; } };
  34.  
  35. // ============================================================
  36. // § 3 Median-of-3
  37. // 排好 arr[lo] <= arr[mid] <= arr[hi],pivot(=arr[mid])放到 arr[hi-1]
  38. // ============================================================
  39. template<typename Comp>
  40. struct MedianOf3 {
  41. template<typename T>
  42. static void apply(T* a, int lo, int mid, int hi, Comp cmp) {
  43. if (cmp(a[mid], a[lo])) std::swap(a[mid], a[lo]);
  44. if (cmp(a[hi], a[lo])) std::swap(a[hi], a[lo]);
  45. if (cmp(a[hi], a[mid])) std::swap(a[hi], a[mid]);
  46. std::swap(a[mid], a[hi-1]); // pivot → a[hi-1]
  47. }
  48. };
  49.  
  50. // ============================================================
  51. // § 4 3-way Partition
  52. // pivot 已在 a[hi-1];哨兵 a[lo]<=pivot, a[hi]>=pivot
  53. // 結果:a[lt..gt]==pivot
  54. // ============================================================
  55. template<typename Comp>
  56. struct Partition3Way {
  57. template<typename T>
  58. static std::pair<int,int> run(T* a, int lo, int hi, Comp cmp) {
  59. T pivot = a[hi-1];
  60. int lt = lo, gt = hi-1, i = lo;
  61. while (i <= gt) {
  62. if (cmp(a[i], pivot)) std::swap(a[lt++], a[i++]);
  63. else if (cmp(pivot, a[i])) std::swap(a[i], a[gt--]);
  64. else ++i;
  65. }
  66. return {lt, gt};
  67. }
  68. };
  69.  
  70. // ============================================================
  71. // § 5 Heap Sort(depth 超標 fallback)
  72. // ============================================================
  73. template<typename Comp>
  74. struct HeapSort {
  75. template<typename T>
  76. static void siftDown(T* a, int i, int n, int base, Comp cmp) {
  77. T tmp = std::move(a[base+i]);
  78. for (;;) {
  79. int c = 2*i+1;
  80. if (c >= n) break;
  81. if (c+1 < n && cmp(a[base+c], a[base+c+1])) ++c;
  82. if (!cmp(tmp, a[base+c])) break;
  83. a[base+i] = std::move(a[base+c]);
  84. i = c;
  85. }
  86. a[base+i] = std::move(tmp);
  87. }
  88. template<typename T>
  89. static void sort(T* a, int lo, int hi, Comp cmp) {
  90. int n = hi-lo+1;
  91. for (int i = n/2-1; i >= 0; --i) siftDown(a, i, n, lo, cmp);
  92. for (int i = n-1; i > 0; --i) {
  93. std::swap(a[lo], a[lo+i]);
  94. siftDown(a, 0, i, lo, cmp);
  95. }
  96. }
  97. };
  98.  
  99. // ============================================================
  100. // § 6 Insertion Sort(guarded,有邊界保護)
  101. // ============================================================
  102. template<typename Comp>
  103. struct InsertionSort {
  104. template<typename T>
  105. static void sort(T* a, int lo, int hi, Comp cmp) {
  106. for (int i = lo+1; i <= hi; ++i) {
  107. T key = std::move(a[i]);
  108. int j = i-1;
  109. while (j >= lo && cmp(key, a[j])) {
  110. a[j+1] = std::move(a[j]); --j;
  111. }
  112. a[j+1] = std::move(key);
  113. }
  114. }
  115. };
  116.  
  117. // ============================================================
  118. // § 7 IntroSort 核心
  119. // ============================================================
  120. template<typename Comp, int THRESHOLD>
  121. struct IntroCore {
  122. template<typename T>
  123. static void sort(T* a, int lo, int hi, int depth, Comp cmp) {
  124. while (hi - lo + 1 > THRESHOLD) {
  125. if (depth == 0) {
  126. HeapSort<Comp>::sort(a, lo, hi, cmp);
  127. return;
  128. }
  129. --depth;
  130. int mid = lo + (hi-lo)/2;
  131. MedianOf3<Comp>::apply(a, lo, mid, hi, cmp);
  132. auto p = Partition3Way<Comp>::run(a, lo, hi, cmp);
  133. // 尾遞迴:遞迴較小側
  134. if (p.first - lo < hi - p.second) {
  135. sort(a, lo, p.first-1, depth, cmp);
  136. lo = p.second + 1;
  137. } else {
  138. sort(a, p.second+1, hi, depth, cmp);
  139. hi = p.first - 1;
  140. }
  141. }
  142. // 小段落下給 Phase 2
  143. }
  144. };
  145.  
  146. // ============================================================
  147. // § 8 對外介面:tmp::Sort<THRESHOLD>
  148. // ============================================================
  149. template<int THRESHOLD = 16>
  150. struct Sort {
  151. static int depthLimit(int n) {
  152. int d = 0; while (n > 1) { n >>= 1; ++d; } return d * 2;
  153. }
  154.  
  155. // --- sort(arr, n) ---
  156. template<typename T>
  157. static void sort(T* a, int n) { sort(a, n, Less<T>{}); }
  158.  
  159. // --- sort(arr, n, comp) ---
  160. template<typename T, typename Comp>
  161. static void sort(T* a, int n, Comp cmp) {
  162. if (n <= 1) return;
  163. using Core = IntroCore<Comp, THRESHOLD>;
  164. // Phase 1:QuickSort,跳過小片段
  165. Core::sort(a, 0, n-1, depthLimit(n), cmp);
  166. // Phase 2:一次 Insertion Sort 收尾(guarded,正確且簡單)
  167. InsertionSort<Comp>::sort(a, 0, n-1, cmp);
  168. }
  169.  
  170. // --- sort(first, last) ---
  171. template<typename T>
  172. static void sort(T* first, T* last) {
  173. sort(first, static_cast<int>(last - first));
  174. }
  175.  
  176. // --- sort(first, last, comp) ---
  177. template<typename T, typename Comp>
  178. static void sort(T* first, T* last, Comp cmp) {
  179. sort(first, static_cast<int>(last - first), cmp);
  180. }
  181. };
  182.  
  183. } // namespace tmp
  184.  
  185. // ============================================================
  186. // § 9 Demo
  187. // ============================================================
  188. static bool verify(const int* a, const int* b, int n) {
  189. for (int i = 0; i < n; ++i) if (a[i] != b[i]) return false;
  190. return true;
  191. }
  192. static void test(const char* name, int* a, int* r, int n) {
  193. std::sort(r, r+n);
  194. tmp::Sort<>::sort(a, n);
  195. printf("[%s] n=%-8d %s\n", verify(a,r,n)?"PASS":"FAIL", n, name);
  196. }
  197. static void show(const char* lbl, const int* a, int n) {
  198. printf(" %-30s", lbl);
  199. int s = n<20?n:20;
  200. for(int i=0;i<s;i++) printf("%d%c",a[i]," \n"[i==s-1]);
  201. if(n>20) printf(" ...(n=%d)\n",n);
  202. }
  203.  
  204. int main() {
  205. srand(42);
  206. printf("=== tmp::Sort<> — Introsort TMP ===\n\n");
  207.  
  208. // 1. 基本
  209. { int a[]={5,3,8,1,9,2,7,4,6,0},r[]={5,3,8,1,9,2,7,4,6,0};
  210. test("random 10",a,r,10); show("→",a,10); }
  211.  
  212. // 2. 邊界:n=0,1,2
  213. { int a1[]={}, r1[]={}; test("n=0",a1,r1,0); }
  214. { int a2[]={7}, r2[]={7}; test("n=1",a2,r2,1); }
  215. { int a3[]={5,2}, r3[]={5,2}; test("n=2",a3,r3,2); }
  216.  
  217. // 3. 已排序 升/降
  218. { static int a[10000],r[10000];
  219. for(int i=0;i<10000;i++) a[i]=r[i]=i;
  220. test("sorted asc 10000",a,r,10000);
  221. for(int i=0;i<10000;i++) a[i]=r[i]=9999-i;
  222. test("sorted desc 10000",a,r,10000); }
  223.  
  224. // 4. 全部相同
  225. { static int a[100000],r[100000];
  226. for(int i=0;i<100000;i++) a[i]=r[i]=42;
  227. test("all equal 100000",a,r,100000); }
  228.  
  229. // 5. 大量重複(3-way partition 優勢)
  230. { static int a[100000],r[100000];
  231. for(int i=0;i<100000;i++) a[i]=r[i]=rand()%100;
  232. test("100 vals in 100000",a,r,100000); }
  233.  
  234. // 6. 大陣列隨機
  235. { static int a[1000000],r[1000000];
  236. for(int i=0;i<1000000;i++) a[i]=r[i]=rand();
  237. test("random 1000000",a,r,1000000); }
  238.  
  239. // 7. 自訂 Comp(降序)
  240. printf("\n--- Descending ---\n");
  241. { int a[]={3,1,4,1,5,9,2,6,5,3};
  242. tmp::Sort<>::sort(a,10,tmp::Greater<int>{});
  243. show("desc:",a,10); }
  244.  
  245. // 8. 指標範圍介面
  246. printf("--- Pointer range ---\n");
  247. { int a[]={9,1,8,2,7,3,6,4,5,0};
  248. tmp::Sort<>::sort(a,a+10);
  249. show("range:",a,10); }
  250.  
  251. // 9. 自訂 struct + lambda
  252. printf("--- Struct + lambda ---\n");
  253. { struct Pt { int x,y; };
  254. Pt pts[]={{3,1},{1,5},{2,2},{1,1},{3,0}};
  255. tmp::Sort<>::sort(pts,5,[](const Pt&a,const Pt&b){
  256. return a.x!=b.x?a.x<b.x:a.y<b.y; });
  257. for(int i=0;i<5;i++) printf("(%d,%d)%c",pts[i].x,pts[i].y," \n"[i==4]); }
  258.  
  259. // 10. 不同 THRESHOLD
  260. printf("--- THRESHOLD=4 ---\n");
  261. { static int a[50000],r[50000];
  262. for(int i=0;i<50000;i++) a[i]=r[i]=rand();
  263. std::sort(r,r+50000);
  264. tmp::Sort<4>::sort(a,50000);
  265. printf("[%s] THRESHOLD=4 n=50000\n",verify(a,r,50000)?"PASS":"FAIL"); }
  266.  
  267. // 11. 互動
  268. printf("\n--- Interactive ---\n");
  269. { int n; scanf("%d",&n);
  270. static int a[1000000];
  271. for(int i=0;i<n;i++) scanf("%d",&a[i]);
  272. tmp::Sort<>::sort(a,n);
  273. for(int i=0;i<n;i++) printf("%d%c",a[i]," \n"[i==n-1]); }
  274.  
  275. return 0;
  276. }
Success #stdin #stdout 0.16s 12884KB
stdin
Standard input is empty
stdout
=== tmp::Sort<> — Introsort TMP ===

[PASS] n=10        random 10
  →                           0 1 2 3 4 5 6 7 8 9
[PASS] n=0         n=0
[PASS] n=1         n=1
[PASS] n=2         n=2
[PASS] n=10000     sorted asc 10000
[PASS] n=10000     sorted desc 10000
[PASS] n=100000    all equal 100000
[PASS] n=100000    100 vals in 100000
[PASS] n=1000000   random 1000000

--- Descending ---
  desc:                         9 6 5 5 4 3 3 2 1 1
--- Pointer range ---
  range:                        0 1 2 3 4 5 6 7 8 9
--- Struct + lambda ---
(1,1) (1,5) (2,2) (3,0) (3,1)
--- THRESHOLD=4 ---
[PASS] THRESHOLD=4 n=50000

--- Interactive ---