Linux 6.10-rc4
[linux.git] / mm / mempolicy.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Simple NUMA memory policy for the Linux kernel.
4  *
5  * Copyright 2003,2004 Andi Kleen, SuSE Labs.
6  * (C) Copyright 2005 Christoph Lameter, Silicon Graphics, Inc.
7  *
8  * NUMA policy allows the user to give hints in which node(s) memory should
9  * be allocated.
10  *
11  * Support four policies per VMA and per process:
12  *
13  * The VMA policy has priority over the process policy for a page fault.
14  *
15  * interleave     Allocate memory interleaved over a set of nodes,
16  *                with normal fallback if it fails.
17  *                For VMA based allocations this interleaves based on the
18  *                offset into the backing object or offset into the mapping
19  *                for anonymous memory. For process policy an process counter
20  *                is used.
21  *
22  * weighted interleave
23  *                Allocate memory interleaved over a set of nodes based on
24  *                a set of weights (per-node), with normal fallback if it
25  *                fails.  Otherwise operates the same as interleave.
26  *                Example: nodeset(0,1) & weights (2,1) - 2 pages allocated
27  *                on node 0 for every 1 page allocated on node 1.
28  *
29  * bind           Only allocate memory on a specific set of nodes,
30  *                no fallback.
31  *                FIXME: memory is allocated starting with the first node
32  *                to the last. It would be better if bind would truly restrict
33  *                the allocation to memory nodes instead
34  *
35  * preferred      Try a specific node first before normal fallback.
36  *                As a special case NUMA_NO_NODE here means do the allocation
37  *                on the local CPU. This is normally identical to default,
38  *                but useful to set in a VMA when you have a non default
39  *                process policy.
40  *
41  * preferred many Try a set of nodes first before normal fallback. This is
42  *                similar to preferred without the special case.
43  *
44  * default        Allocate on the local node first, or when on a VMA
45  *                use the process policy. This is what Linux always did
46  *                in a NUMA aware kernel and still does by, ahem, default.
47  *
48  * The process policy is applied for most non interrupt memory allocations
49  * in that process' context. Interrupts ignore the policies and always
50  * try to allocate on the local CPU. The VMA policy is only applied for memory
51  * allocations for a VMA in the VM.
52  *
53  * Currently there are a few corner cases in swapping where the policy
54  * is not applied, but the majority should be handled. When process policy
55  * is used it is not remembered over swap outs/swap ins.
56  *
57  * Only the highest zone in the zone hierarchy gets policied. Allocations
58  * requesting a lower zone just use default policy. This implies that
59  * on systems with highmem kernel lowmem allocation don't get policied.
60  * Same with GFP_DMA allocations.
61  *
62  * For shmem/tmpfs shared memory the policy is shared between
63  * all users and remembered even when nobody has memory mapped.
64  */
65
66 /* Notebook:
67    fix mmap readahead to honour policy and enable policy for any page cache
68    object
69    statistics for bigpages
70    global policy for page cache? currently it uses process policy. Requires
71    first item above.
72    handle mremap for shared memory (currently ignored for the policy)
73    grows down?
74    make bind policy root only? It can trigger oom much faster and the
75    kernel is not always grateful with that.
76 */
77
78 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
79
80 #include <linux/mempolicy.h>
81 #include <linux/pagewalk.h>
82 #include <linux/highmem.h>
83 #include <linux/hugetlb.h>
84 #include <linux/kernel.h>
85 #include <linux/sched.h>
86 #include <linux/sched/mm.h>
87 #include <linux/sched/numa_balancing.h>
88 #include <linux/sched/task.h>
89 #include <linux/nodemask.h>
90 #include <linux/cpuset.h>
91 #include <linux/slab.h>
92 #include <linux/string.h>
93 #include <linux/export.h>
94 #include <linux/nsproxy.h>
95 #include <linux/interrupt.h>
96 #include <linux/init.h>
97 #include <linux/compat.h>
98 #include <linux/ptrace.h>
99 #include <linux/swap.h>
100 #include <linux/seq_file.h>
101 #include <linux/proc_fs.h>
102 #include <linux/migrate.h>
103 #include <linux/ksm.h>
104 #include <linux/rmap.h>
105 #include <linux/security.h>
106 #include <linux/syscalls.h>
107 #include <linux/ctype.h>
108 #include <linux/mm_inline.h>
109 #include <linux/mmu_notifier.h>
110 #include <linux/printk.h>
111 #include <linux/swapops.h>
112
113 #include <asm/tlbflush.h>
114 #include <asm/tlb.h>
115 #include <linux/uaccess.h>
116
117 #include "internal.h"
118
119 /* Internal flags */
120 #define MPOL_MF_DISCONTIG_OK (MPOL_MF_INTERNAL << 0)    /* Skip checks for continuous vmas */
121 #define MPOL_MF_INVERT       (MPOL_MF_INTERNAL << 1)    /* Invert check for nodemask */
122 #define MPOL_MF_WRLOCK       (MPOL_MF_INTERNAL << 2)    /* Write-lock walked vmas */
123
124 static struct kmem_cache *policy_cache;
125 static struct kmem_cache *sn_cache;
126
127 /* Highest zone. An specific allocation for a zone below that is not
128    policied. */
129 enum zone_type policy_zone = 0;
130
131 /*
132  * run-time system-wide default policy => local allocation
133  */
134 static struct mempolicy default_policy = {
135         .refcnt = ATOMIC_INIT(1), /* never free it */
136         .mode = MPOL_LOCAL,
137 };
138
139 static struct mempolicy preferred_node_policy[MAX_NUMNODES];
140
141 /*
142  * iw_table is the sysfs-set interleave weight table, a value of 0 denotes
143  * system-default value should be used. A NULL iw_table also denotes that
144  * system-default values should be used. Until the system-default table
145  * is implemented, the system-default is always 1.
146  *
147  * iw_table is RCU protected
148  */
149 static u8 __rcu *iw_table;
150 static DEFINE_MUTEX(iw_table_lock);
151
152 static u8 get_il_weight(int node)
153 {
154         u8 *table;
155         u8 weight;
156
157         rcu_read_lock();
158         table = rcu_dereference(iw_table);
159         /* if no iw_table, use system default */
160         weight = table ? table[node] : 1;
161         /* if value in iw_table is 0, use system default */
162         weight = weight ? weight : 1;
163         rcu_read_unlock();
164         return weight;
165 }
166
167 /**
168  * numa_nearest_node - Find nearest node by state
169  * @node: Node id to start the search
170  * @state: State to filter the search
171  *
172  * Lookup the closest node by distance if @nid is not in state.
173  *
174  * Return: this @node if it is in state, otherwise the closest node by distance
175  */
176 int numa_nearest_node(int node, unsigned int state)
177 {
178         int min_dist = INT_MAX, dist, n, min_node;
179
180         if (state >= NR_NODE_STATES)
181                 return -EINVAL;
182
183         if (node == NUMA_NO_NODE || node_state(node, state))
184                 return node;
185
186         min_node = node;
187         for_each_node_state(n, state) {
188                 dist = node_distance(node, n);
189                 if (dist < min_dist) {
190                         min_dist = dist;
191                         min_node = n;
192                 }
193         }
194
195         return min_node;
196 }
197 EXPORT_SYMBOL_GPL(numa_nearest_node);
198
199 struct mempolicy *get_task_policy(struct task_struct *p)
200 {
201         struct mempolicy *pol = p->mempolicy;
202         int node;
203
204         if (pol)
205                 return pol;
206
207         node = numa_node_id();
208         if (node != NUMA_NO_NODE) {
209                 pol = &preferred_node_policy[node];
210                 /* preferred_node_policy is not initialised early in boot */
211                 if (pol->mode)
212                         return pol;
213         }
214
215         return &default_policy;
216 }
217
218 static const struct mempolicy_operations {
219         int (*create)(struct mempolicy *pol, const nodemask_t *nodes);
220         void (*rebind)(struct mempolicy *pol, const nodemask_t *nodes);
221 } mpol_ops[MPOL_MAX];
222
223 static inline int mpol_store_user_nodemask(const struct mempolicy *pol)
224 {
225         return pol->flags & MPOL_MODE_FLAGS;
226 }
227
228 static void mpol_relative_nodemask(nodemask_t *ret, const nodemask_t *orig,
229                                    const nodemask_t *rel)
230 {
231         nodemask_t tmp;
232         nodes_fold(tmp, *orig, nodes_weight(*rel));
233         nodes_onto(*ret, tmp, *rel);
234 }
235
236 static int mpol_new_nodemask(struct mempolicy *pol, const nodemask_t *nodes)
237 {
238         if (nodes_empty(*nodes))
239                 return -EINVAL;
240         pol->nodes = *nodes;
241         return 0;
242 }
243
244 static int mpol_new_preferred(struct mempolicy *pol, const nodemask_t *nodes)
245 {
246         if (nodes_empty(*nodes))
247                 return -EINVAL;
248
249         nodes_clear(pol->nodes);
250         node_set(first_node(*nodes), pol->nodes);
251         return 0;
252 }
253
254 /*
255  * mpol_set_nodemask is called after mpol_new() to set up the nodemask, if
256  * any, for the new policy.  mpol_new() has already validated the nodes
257  * parameter with respect to the policy mode and flags.
258  *
259  * Must be called holding task's alloc_lock to protect task's mems_allowed
260  * and mempolicy.  May also be called holding the mmap_lock for write.
261  */
262 static int mpol_set_nodemask(struct mempolicy *pol,
263                      const nodemask_t *nodes, struct nodemask_scratch *nsc)
264 {
265         int ret;
266
267         /*
268          * Default (pol==NULL) resp. local memory policies are not a
269          * subject of any remapping. They also do not need any special
270          * constructor.
271          */
272         if (!pol || pol->mode == MPOL_LOCAL)
273                 return 0;
274
275         /* Check N_MEMORY */
276         nodes_and(nsc->mask1,
277                   cpuset_current_mems_allowed, node_states[N_MEMORY]);
278
279         VM_BUG_ON(!nodes);
280
281         if (pol->flags & MPOL_F_RELATIVE_NODES)
282                 mpol_relative_nodemask(&nsc->mask2, nodes, &nsc->mask1);
283         else
284                 nodes_and(nsc->mask2, *nodes, nsc->mask1);
285
286         if (mpol_store_user_nodemask(pol))
287                 pol->w.user_nodemask = *nodes;
288         else
289                 pol->w.cpuset_mems_allowed = cpuset_current_mems_allowed;
290
291         ret = mpol_ops[pol->mode].create(pol, &nsc->mask2);
292         return ret;
293 }
294
295 /*
296  * This function just creates a new policy, does some check and simple
297  * initialization. You must invoke mpol_set_nodemask() to set nodes.
298  */
299 static struct mempolicy *mpol_new(unsigned short mode, unsigned short flags,
300                                   nodemask_t *nodes)
301 {
302         struct mempolicy *policy;
303
304         if (mode == MPOL_DEFAULT) {
305                 if (nodes && !nodes_empty(*nodes))
306                         return ERR_PTR(-EINVAL);
307                 return NULL;
308         }
309         VM_BUG_ON(!nodes);
310
311         /*
312          * MPOL_PREFERRED cannot be used with MPOL_F_STATIC_NODES or
313          * MPOL_F_RELATIVE_NODES if the nodemask is empty (local allocation).
314          * All other modes require a valid pointer to a non-empty nodemask.
315          */
316         if (mode == MPOL_PREFERRED) {
317                 if (nodes_empty(*nodes)) {
318                         if (((flags & MPOL_F_STATIC_NODES) ||
319                              (flags & MPOL_F_RELATIVE_NODES)))
320                                 return ERR_PTR(-EINVAL);
321
322                         mode = MPOL_LOCAL;
323                 }
324         } else if (mode == MPOL_LOCAL) {
325                 if (!nodes_empty(*nodes) ||
326                     (flags & MPOL_F_STATIC_NODES) ||
327                     (flags & MPOL_F_RELATIVE_NODES))
328                         return ERR_PTR(-EINVAL);
329         } else if (nodes_empty(*nodes))
330                 return ERR_PTR(-EINVAL);
331
332         policy = kmem_cache_alloc(policy_cache, GFP_KERNEL);
333         if (!policy)
334                 return ERR_PTR(-ENOMEM);
335         atomic_set(&policy->refcnt, 1);
336         policy->mode = mode;
337         policy->flags = flags;
338         policy->home_node = NUMA_NO_NODE;
339
340         return policy;
341 }
342
343 /* Slow path of a mpol destructor. */
344 void __mpol_put(struct mempolicy *pol)
345 {
346         if (!atomic_dec_and_test(&pol->refcnt))
347                 return;
348         kmem_cache_free(policy_cache, pol);
349 }
350
351 static void mpol_rebind_default(struct mempolicy *pol, const nodemask_t *nodes)
352 {
353 }
354
355 static void mpol_rebind_nodemask(struct mempolicy *pol, const nodemask_t *nodes)
356 {
357         nodemask_t tmp;
358
359         if (pol->flags & MPOL_F_STATIC_NODES)
360                 nodes_and(tmp, pol->w.user_nodemask, *nodes);
361         else if (pol->flags & MPOL_F_RELATIVE_NODES)
362                 mpol_relative_nodemask(&tmp, &pol->w.user_nodemask, nodes);
363         else {
364                 nodes_remap(tmp, pol->nodes, pol->w.cpuset_mems_allowed,
365                                                                 *nodes);
366                 pol->w.cpuset_mems_allowed = *nodes;
367         }
368
369         if (nodes_empty(tmp))
370                 tmp = *nodes;
371
372         pol->nodes = tmp;
373 }
374
375 static void mpol_rebind_preferred(struct mempolicy *pol,
376                                                 const nodemask_t *nodes)
377 {
378         pol->w.cpuset_mems_allowed = *nodes;
379 }
380
381 /*
382  * mpol_rebind_policy - Migrate a policy to a different set of nodes
383  *
384  * Per-vma policies are protected by mmap_lock. Allocations using per-task
385  * policies are protected by task->mems_allowed_seq to prevent a premature
386  * OOM/allocation failure due to parallel nodemask modification.
387  */
388 static void mpol_rebind_policy(struct mempolicy *pol, const nodemask_t *newmask)
389 {
390         if (!pol || pol->mode == MPOL_LOCAL)
391                 return;
392         if (!mpol_store_user_nodemask(pol) &&
393             nodes_equal(pol->w.cpuset_mems_allowed, *newmask))
394                 return;
395
396         mpol_ops[pol->mode].rebind(pol, newmask);
397 }
398
399 /*
400  * Wrapper for mpol_rebind_policy() that just requires task
401  * pointer, and updates task mempolicy.
402  *
403  * Called with task's alloc_lock held.
404  */
405 void mpol_rebind_task(struct task_struct *tsk, const nodemask_t *new)
406 {
407         mpol_rebind_policy(tsk->mempolicy, new);
408 }
409
410 /*
411  * Rebind each vma in mm to new nodemask.
412  *
413  * Call holding a reference to mm.  Takes mm->mmap_lock during call.
414  */
415 void mpol_rebind_mm(struct mm_struct *mm, nodemask_t *new)
416 {
417         struct vm_area_struct *vma;
418         VMA_ITERATOR(vmi, mm, 0);
419
420         mmap_write_lock(mm);
421         for_each_vma(vmi, vma) {
422                 vma_start_write(vma);
423                 mpol_rebind_policy(vma->vm_policy, new);
424         }
425         mmap_write_unlock(mm);
426 }
427
428 static const struct mempolicy_operations mpol_ops[MPOL_MAX] = {
429         [MPOL_DEFAULT] = {
430                 .rebind = mpol_rebind_default,
431         },
432         [MPOL_INTERLEAVE] = {
433                 .create = mpol_new_nodemask,
434                 .rebind = mpol_rebind_nodemask,
435         },
436         [MPOL_PREFERRED] = {
437                 .create = mpol_new_preferred,
438                 .rebind = mpol_rebind_preferred,
439         },
440         [MPOL_BIND] = {
441                 .create = mpol_new_nodemask,
442                 .rebind = mpol_rebind_nodemask,
443         },
444         [MPOL_LOCAL] = {
445                 .rebind = mpol_rebind_default,
446         },
447         [MPOL_PREFERRED_MANY] = {
448                 .create = mpol_new_nodemask,
449                 .rebind = mpol_rebind_preferred,
450         },
451         [MPOL_WEIGHTED_INTERLEAVE] = {
452                 .create = mpol_new_nodemask,
453                 .rebind = mpol_rebind_nodemask,
454         },
455 };
456
457 static bool migrate_folio_add(struct folio *folio, struct list_head *foliolist,
458                                 unsigned long flags);
459 static nodemask_t *policy_nodemask(gfp_t gfp, struct mempolicy *pol,
460                                 pgoff_t ilx, int *nid);
461
462 static bool strictly_unmovable(unsigned long flags)
463 {
464         /*
465          * STRICT without MOVE flags lets do_mbind() fail immediately with -EIO
466          * if any misplaced page is found.
467          */
468         return (flags & (MPOL_MF_STRICT | MPOL_MF_MOVE | MPOL_MF_MOVE_ALL)) ==
469                          MPOL_MF_STRICT;
470 }
471
472 struct migration_mpol {         /* for alloc_migration_target_by_mpol() */
473         struct mempolicy *pol;
474         pgoff_t ilx;
475 };
476
477 struct queue_pages {
478         struct list_head *pagelist;
479         unsigned long flags;
480         nodemask_t *nmask;
481         unsigned long start;
482         unsigned long end;
483         struct vm_area_struct *first;
484         struct folio *large;            /* note last large folio encountered */
485         long nr_failed;                 /* could not be isolated at this time */
486 };
487
488 /*
489  * Check if the folio's nid is in qp->nmask.
490  *
491  * If MPOL_MF_INVERT is set in qp->flags, check if the nid is
492  * in the invert of qp->nmask.
493  */
494 static inline bool queue_folio_required(struct folio *folio,
495                                         struct queue_pages *qp)
496 {
497         int nid = folio_nid(folio);
498         unsigned long flags = qp->flags;
499
500         return node_isset(nid, *qp->nmask) == !(flags & MPOL_MF_INVERT);
501 }
502
503 static void queue_folios_pmd(pmd_t *pmd, struct mm_walk *walk)
504 {
505         struct folio *folio;
506         struct queue_pages *qp = walk->private;
507
508         if (unlikely(is_pmd_migration_entry(*pmd))) {
509                 qp->nr_failed++;
510                 return;
511         }
512         folio = pmd_folio(*pmd);
513         if (is_huge_zero_folio(folio)) {
514                 walk->action = ACTION_CONTINUE;
515                 return;
516         }
517         if (!queue_folio_required(folio, qp))
518                 return;
519         if (!(qp->flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL)) ||
520             !vma_migratable(walk->vma) ||
521             !migrate_folio_add(folio, qp->pagelist, qp->flags))
522                 qp->nr_failed++;
523 }
524
525 /*
526  * Scan through folios, checking if they satisfy the required conditions,
527  * moving them from LRU to local pagelist for migration if they do (or not).
528  *
529  * queue_folios_pte_range() has two possible return values:
530  * 0 - continue walking to scan for more, even if an existing folio on the
531  *     wrong node could not be isolated and queued for migration.
532  * -EIO - only MPOL_MF_STRICT was specified, without MPOL_MF_MOVE or ..._ALL,
533  *        and an existing folio was on a node that does not follow the policy.
534  */
535 static int queue_folios_pte_range(pmd_t *pmd, unsigned long addr,
536                         unsigned long end, struct mm_walk *walk)
537 {
538         struct vm_area_struct *vma = walk->vma;
539         struct folio *folio;
540         struct queue_pages *qp = walk->private;
541         unsigned long flags = qp->flags;
542         pte_t *pte, *mapped_pte;
543         pte_t ptent;
544         spinlock_t *ptl;
545
546         ptl = pmd_trans_huge_lock(pmd, vma);
547         if (ptl) {
548                 queue_folios_pmd(pmd, walk);
549                 spin_unlock(ptl);
550                 goto out;
551         }
552
553         mapped_pte = pte = pte_offset_map_lock(walk->mm, pmd, addr, &ptl);
554         if (!pte) {
555                 walk->action = ACTION_AGAIN;
556                 return 0;
557         }
558         for (; addr != end; pte++, addr += PAGE_SIZE) {
559                 ptent = ptep_get(pte);
560                 if (pte_none(ptent))
561                         continue;
562                 if (!pte_present(ptent)) {
563                         if (is_migration_entry(pte_to_swp_entry(ptent)))
564                                 qp->nr_failed++;
565                         continue;
566                 }
567                 folio = vm_normal_folio(vma, addr, ptent);
568                 if (!folio || folio_is_zone_device(folio))
569                         continue;
570                 /*
571                  * vm_normal_folio() filters out zero pages, but there might
572                  * still be reserved folios to skip, perhaps in a VDSO.
573                  */
574                 if (folio_test_reserved(folio))
575                         continue;
576                 if (!queue_folio_required(folio, qp))
577                         continue;
578                 if (folio_test_large(folio)) {
579                         /*
580                          * A large folio can only be isolated from LRU once,
581                          * but may be mapped by many PTEs (and Copy-On-Write may
582                          * intersperse PTEs of other, order 0, folios).  This is
583                          * a common case, so don't mistake it for failure (but
584                          * there can be other cases of multi-mapped pages which
585                          * this quick check does not help to filter out - and a
586                          * search of the pagelist might grow to be prohibitive).
587                          *
588                          * migrate_pages(&pagelist) returns nr_failed folios, so
589                          * check "large" now so that queue_pages_range() returns
590                          * a comparable nr_failed folios.  This does imply that
591                          * if folio could not be isolated for some racy reason
592                          * at its first PTE, later PTEs will not give it another
593                          * chance of isolation; but keeps the accounting simple.
594                          */
595                         if (folio == qp->large)
596                                 continue;
597                         qp->large = folio;
598                 }
599                 if (!(flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL)) ||
600                     !vma_migratable(vma) ||
601                     !migrate_folio_add(folio, qp->pagelist, flags)) {
602                         qp->nr_failed++;
603                         if (strictly_unmovable(flags))
604                                 break;
605                 }
606         }
607         pte_unmap_unlock(mapped_pte, ptl);
608         cond_resched();
609 out:
610         if (qp->nr_failed && strictly_unmovable(flags))
611                 return -EIO;
612         return 0;
613 }
614
615 static int queue_folios_hugetlb(pte_t *pte, unsigned long hmask,
616                                unsigned long addr, unsigned long end,
617                                struct mm_walk *walk)
618 {
619 #ifdef CONFIG_HUGETLB_PAGE
620         struct queue_pages *qp = walk->private;
621         unsigned long flags = qp->flags;
622         struct folio *folio;
623         spinlock_t *ptl;
624         pte_t entry;
625
626         ptl = huge_pte_lock(hstate_vma(walk->vma), walk->mm, pte);
627         entry = huge_ptep_get(pte);
628         if (!pte_present(entry)) {
629                 if (unlikely(is_hugetlb_entry_migration(entry)))
630                         qp->nr_failed++;
631                 goto unlock;
632         }
633         folio = pfn_folio(pte_pfn(entry));
634         if (!queue_folio_required(folio, qp))
635                 goto unlock;
636         if (!(flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL)) ||
637             !vma_migratable(walk->vma)) {
638                 qp->nr_failed++;
639                 goto unlock;
640         }
641         /*
642          * Unless MPOL_MF_MOVE_ALL, we try to avoid migrating a shared folio.
643          * Choosing not to migrate a shared folio is not counted as a failure.
644          *
645          * See folio_likely_mapped_shared() on possible imprecision when we
646          * cannot easily detect if a folio is shared.
647          */
648         if ((flags & MPOL_MF_MOVE_ALL) ||
649             (!folio_likely_mapped_shared(folio) && !hugetlb_pmd_shared(pte)))
650                 if (!isolate_hugetlb(folio, qp->pagelist))
651                         qp->nr_failed++;
652 unlock:
653         spin_unlock(ptl);
654         if (qp->nr_failed && strictly_unmovable(flags))
655                 return -EIO;
656 #endif
657         return 0;
658 }
659
660 #ifdef CONFIG_NUMA_BALANCING
661 /*
662  * This is used to mark a range of virtual addresses to be inaccessible.
663  * These are later cleared by a NUMA hinting fault. Depending on these
664  * faults, pages may be migrated for better NUMA placement.
665  *
666  * This is assuming that NUMA faults are handled using PROT_NONE. If
667  * an architecture makes a different choice, it will need further
668  * changes to the core.
669  */
670 unsigned long change_prot_numa(struct vm_area_struct *vma,
671                         unsigned long addr, unsigned long end)
672 {
673         struct mmu_gather tlb;
674         long nr_updated;
675
676         tlb_gather_mmu(&tlb, vma->vm_mm);
677
678         nr_updated = change_protection(&tlb, vma, addr, end, MM_CP_PROT_NUMA);
679         if (nr_updated > 0)
680                 count_vm_numa_events(NUMA_PTE_UPDATES, nr_updated);
681
682         tlb_finish_mmu(&tlb);
683
684         return nr_updated;
685 }
686 #endif /* CONFIG_NUMA_BALANCING */
687
688 static int queue_pages_test_walk(unsigned long start, unsigned long end,
689                                 struct mm_walk *walk)
690 {
691         struct vm_area_struct *next, *vma = walk->vma;
692         struct queue_pages *qp = walk->private;
693         unsigned long flags = qp->flags;
694
695         /* range check first */
696         VM_BUG_ON_VMA(!range_in_vma(vma, start, end), vma);
697
698         if (!qp->first) {
699                 qp->first = vma;
700                 if (!(flags & MPOL_MF_DISCONTIG_OK) &&
701                         (qp->start < vma->vm_start))
702                         /* hole at head side of range */
703                         return -EFAULT;
704         }
705         next = find_vma(vma->vm_mm, vma->vm_end);
706         if (!(flags & MPOL_MF_DISCONTIG_OK) &&
707                 ((vma->vm_end < qp->end) &&
708                 (!next || vma->vm_end < next->vm_start)))
709                 /* hole at middle or tail of range */
710                 return -EFAULT;
711
712         /*
713          * Need check MPOL_MF_STRICT to return -EIO if possible
714          * regardless of vma_migratable
715          */
716         if (!vma_migratable(vma) &&
717             !(flags & MPOL_MF_STRICT))
718                 return 1;
719
720         /*
721          * Check page nodes, and queue pages to move, in the current vma.
722          * But if no moving, and no strict checking, the scan can be skipped.
723          */
724         if (flags & (MPOL_MF_STRICT | MPOL_MF_MOVE | MPOL_MF_MOVE_ALL))
725                 return 0;
726         return 1;
727 }
728
729 static const struct mm_walk_ops queue_pages_walk_ops = {
730         .hugetlb_entry          = queue_folios_hugetlb,
731         .pmd_entry              = queue_folios_pte_range,
732         .test_walk              = queue_pages_test_walk,
733         .walk_lock              = PGWALK_RDLOCK,
734 };
735
736 static const struct mm_walk_ops queue_pages_lock_vma_walk_ops = {
737         .hugetlb_entry          = queue_folios_hugetlb,
738         .pmd_entry              = queue_folios_pte_range,
739         .test_walk              = queue_pages_test_walk,
740         .walk_lock              = PGWALK_WRLOCK,
741 };
742
743 /*
744  * Walk through page tables and collect pages to be migrated.
745  *
746  * If pages found in a given range are not on the required set of @nodes,
747  * and migration is allowed, they are isolated and queued to @pagelist.
748  *
749  * queue_pages_range() may return:
750  * 0 - all pages already on the right node, or successfully queued for moving
751  *     (or neither strict checking nor moving requested: only range checking).
752  * >0 - this number of misplaced folios could not be queued for moving
753  *      (a hugetlbfs page or a transparent huge page being counted as 1).
754  * -EIO - a misplaced page found, when MPOL_MF_STRICT specified without MOVEs.
755  * -EFAULT - a hole in the memory range, when MPOL_MF_DISCONTIG_OK unspecified.
756  */
757 static long
758 queue_pages_range(struct mm_struct *mm, unsigned long start, unsigned long end,
759                 nodemask_t *nodes, unsigned long flags,
760                 struct list_head *pagelist)
761 {
762         int err;
763         struct queue_pages qp = {
764                 .pagelist = pagelist,
765                 .flags = flags,
766                 .nmask = nodes,
767                 .start = start,
768                 .end = end,
769                 .first = NULL,
770         };
771         const struct mm_walk_ops *ops = (flags & MPOL_MF_WRLOCK) ?
772                         &queue_pages_lock_vma_walk_ops : &queue_pages_walk_ops;
773
774         err = walk_page_range(mm, start, end, ops, &qp);
775
776         if (!qp.first)
777                 /* whole range in hole */
778                 err = -EFAULT;
779
780         return err ? : qp.nr_failed;
781 }
782
783 /*
784  * Apply policy to a single VMA
785  * This must be called with the mmap_lock held for writing.
786  */
787 static int vma_replace_policy(struct vm_area_struct *vma,
788                                 struct mempolicy *pol)
789 {
790         int err;
791         struct mempolicy *old;
792         struct mempolicy *new;
793
794         vma_assert_write_locked(vma);
795
796         new = mpol_dup(pol);
797         if (IS_ERR(new))
798                 return PTR_ERR(new);
799
800         if (vma->vm_ops && vma->vm_ops->set_policy) {
801                 err = vma->vm_ops->set_policy(vma, new);
802                 if (err)
803                         goto err_out;
804         }
805
806         old = vma->vm_policy;
807         vma->vm_policy = new; /* protected by mmap_lock */
808         mpol_put(old);
809
810         return 0;
811  err_out:
812         mpol_put(new);
813         return err;
814 }
815
816 /* Split or merge the VMA (if required) and apply the new policy */
817 static int mbind_range(struct vma_iterator *vmi, struct vm_area_struct *vma,
818                 struct vm_area_struct **prev, unsigned long start,
819                 unsigned long end, struct mempolicy *new_pol)
820 {
821         unsigned long vmstart, vmend;
822
823         vmend = min(end, vma->vm_end);
824         if (start > vma->vm_start) {
825                 *prev = vma;
826                 vmstart = start;
827         } else {
828                 vmstart = vma->vm_start;
829         }
830
831         if (mpol_equal(vma->vm_policy, new_pol)) {
832                 *prev = vma;
833                 return 0;
834         }
835
836         vma =  vma_modify_policy(vmi, *prev, vma, vmstart, vmend, new_pol);
837         if (IS_ERR(vma))
838                 return PTR_ERR(vma);
839
840         *prev = vma;
841         return vma_replace_policy(vma, new_pol);
842 }
843
844 /* Set the process memory policy */
845 static long do_set_mempolicy(unsigned short mode, unsigned short flags,
846                              nodemask_t *nodes)
847 {
848         struct mempolicy *new, *old;
849         NODEMASK_SCRATCH(scratch);
850         int ret;
851
852         if (!scratch)
853                 return -ENOMEM;
854
855         new = mpol_new(mode, flags, nodes);
856         if (IS_ERR(new)) {
857                 ret = PTR_ERR(new);
858                 goto out;
859         }
860
861         task_lock(current);
862         ret = mpol_set_nodemask(new, nodes, scratch);
863         if (ret) {
864                 task_unlock(current);
865                 mpol_put(new);
866                 goto out;
867         }
868
869         old = current->mempolicy;
870         current->mempolicy = new;
871         if (new && (new->mode == MPOL_INTERLEAVE ||
872                     new->mode == MPOL_WEIGHTED_INTERLEAVE)) {
873                 current->il_prev = MAX_NUMNODES-1;
874                 current->il_weight = 0;
875         }
876         task_unlock(current);
877         mpol_put(old);
878         ret = 0;
879 out:
880         NODEMASK_SCRATCH_FREE(scratch);
881         return ret;
882 }
883
884 /*
885  * Return nodemask for policy for get_mempolicy() query
886  *
887  * Called with task's alloc_lock held
888  */
889 static void get_policy_nodemask(struct mempolicy *pol, nodemask_t *nodes)
890 {
891         nodes_clear(*nodes);
892         if (pol == &default_policy)
893                 return;
894
895         switch (pol->mode) {
896         case MPOL_BIND:
897         case MPOL_INTERLEAVE:
898         case MPOL_PREFERRED:
899         case MPOL_PREFERRED_MANY:
900         case MPOL_WEIGHTED_INTERLEAVE:
901                 *nodes = pol->nodes;
902                 break;
903         case MPOL_LOCAL:
904                 /* return empty node mask for local allocation */
905                 break;
906         default:
907                 BUG();
908         }
909 }
910
911 static int lookup_node(struct mm_struct *mm, unsigned long addr)
912 {
913         struct page *p = NULL;
914         int ret;
915
916         ret = get_user_pages_fast(addr & PAGE_MASK, 1, 0, &p);
917         if (ret > 0) {
918                 ret = page_to_nid(p);
919                 put_page(p);
920         }
921         return ret;
922 }
923
924 /* Retrieve NUMA policy */
925 static long do_get_mempolicy(int *policy, nodemask_t *nmask,
926                              unsigned long addr, unsigned long flags)
927 {
928         int err;
929         struct mm_struct *mm = current->mm;
930         struct vm_area_struct *vma = NULL;
931         struct mempolicy *pol = current->mempolicy, *pol_refcount = NULL;
932
933         if (flags &
934                 ~(unsigned long)(MPOL_F_NODE|MPOL_F_ADDR|MPOL_F_MEMS_ALLOWED))
935                 return -EINVAL;
936
937         if (flags & MPOL_F_MEMS_ALLOWED) {
938                 if (flags & (MPOL_F_NODE|MPOL_F_ADDR))
939                         return -EINVAL;
940                 *policy = 0;    /* just so it's initialized */
941                 task_lock(current);
942                 *nmask  = cpuset_current_mems_allowed;
943                 task_unlock(current);
944                 return 0;
945         }
946
947         if (flags & MPOL_F_ADDR) {
948                 pgoff_t ilx;            /* ignored here */
949                 /*
950                  * Do NOT fall back to task policy if the
951                  * vma/shared policy at addr is NULL.  We
952                  * want to return MPOL_DEFAULT in this case.
953                  */
954                 mmap_read_lock(mm);
955                 vma = vma_lookup(mm, addr);
956                 if (!vma) {
957                         mmap_read_unlock(mm);
958                         return -EFAULT;
959                 }
960                 pol = __get_vma_policy(vma, addr, &ilx);
961         } else if (addr)
962                 return -EINVAL;
963
964         if (!pol)
965                 pol = &default_policy;  /* indicates default behavior */
966
967         if (flags & MPOL_F_NODE) {
968                 if (flags & MPOL_F_ADDR) {
969                         /*
970                          * Take a refcount on the mpol, because we are about to
971                          * drop the mmap_lock, after which only "pol" remains
972                          * valid, "vma" is stale.
973                          */
974                         pol_refcount = pol;
975                         vma = NULL;
976                         mpol_get(pol);
977                         mmap_read_unlock(mm);
978                         err = lookup_node(mm, addr);
979                         if (err < 0)
980                                 goto out;
981                         *policy = err;
982                 } else if (pol == current->mempolicy &&
983                                 pol->mode == MPOL_INTERLEAVE) {
984                         *policy = next_node_in(current->il_prev, pol->nodes);
985                 } else if (pol == current->mempolicy &&
986                                 pol->mode == MPOL_WEIGHTED_INTERLEAVE) {
987                         if (current->il_weight)
988                                 *policy = current->il_prev;
989                         else
990                                 *policy = next_node_in(current->il_prev,
991                                                        pol->nodes);
992                 } else {
993                         err = -EINVAL;
994                         goto out;
995                 }
996         } else {
997                 *policy = pol == &default_policy ? MPOL_DEFAULT :
998                                                 pol->mode;
999                 /*
1000                  * Internal mempolicy flags must be masked off before exposing
1001                  * the policy to userspace.
1002                  */
1003                 *policy |= (pol->flags & MPOL_MODE_FLAGS);
1004         }
1005
1006         err = 0;
1007         if (nmask) {
1008                 if (mpol_store_user_nodemask(pol)) {
1009                         *nmask = pol->w.user_nodemask;
1010                 } else {
1011                         task_lock(current);
1012                         get_policy_nodemask(pol, nmask);
1013                         task_unlock(current);
1014                 }
1015         }
1016
1017  out:
1018         mpol_cond_put(pol);
1019         if (vma)
1020                 mmap_read_unlock(mm);
1021         if (pol_refcount)
1022                 mpol_put(pol_refcount);
1023         return err;
1024 }
1025
1026 #ifdef CONFIG_MIGRATION
1027 static bool migrate_folio_add(struct folio *folio, struct list_head *foliolist,
1028                                 unsigned long flags)
1029 {
1030         /*
1031          * Unless MPOL_MF_MOVE_ALL, we try to avoid migrating a shared folio.
1032          * Choosing not to migrate a shared folio is not counted as a failure.
1033          *
1034          * See folio_likely_mapped_shared() on possible imprecision when we
1035          * cannot easily detect if a folio is shared.
1036          */
1037         if ((flags & MPOL_MF_MOVE_ALL) || !folio_likely_mapped_shared(folio)) {
1038                 if (folio_isolate_lru(folio)) {
1039                         list_add_tail(&folio->lru, foliolist);
1040                         node_stat_mod_folio(folio,
1041                                 NR_ISOLATED_ANON + folio_is_file_lru(folio),
1042                                 folio_nr_pages(folio));
1043                 } else {
1044                         /*
1045                          * Non-movable folio may reach here.  And, there may be
1046                          * temporary off LRU folios or non-LRU movable folios.
1047                          * Treat them as unmovable folios since they can't be
1048                          * isolated, so they can't be moved at the moment.
1049                          */
1050                         return false;
1051                 }
1052         }
1053         return true;
1054 }
1055
1056 /*
1057  * Migrate pages from one node to a target node.
1058  * Returns error or the number of pages not migrated.
1059  */
1060 static long migrate_to_node(struct mm_struct *mm, int source, int dest,
1061                             int flags)
1062 {
1063         nodemask_t nmask;
1064         struct vm_area_struct *vma;
1065         LIST_HEAD(pagelist);
1066         long nr_failed;
1067         long err = 0;
1068         struct migration_target_control mtc = {
1069                 .nid = dest,
1070                 .gfp_mask = GFP_HIGHUSER_MOVABLE | __GFP_THISNODE,
1071                 .reason = MR_SYSCALL,
1072         };
1073
1074         nodes_clear(nmask);
1075         node_set(source, nmask);
1076
1077         VM_BUG_ON(!(flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL)));
1078
1079         mmap_read_lock(mm);
1080         vma = find_vma(mm, 0);
1081
1082         /*
1083          * This does not migrate the range, but isolates all pages that
1084          * need migration.  Between passing in the full user address
1085          * space range and MPOL_MF_DISCONTIG_OK, this call cannot fail,
1086          * but passes back the count of pages which could not be isolated.
1087          */
1088         nr_failed = queue_pages_range(mm, vma->vm_start, mm->task_size, &nmask,
1089                                       flags | MPOL_MF_DISCONTIG_OK, &pagelist);
1090         mmap_read_unlock(mm);
1091
1092         if (!list_empty(&pagelist)) {
1093                 err = migrate_pages(&pagelist, alloc_migration_target, NULL,
1094                         (unsigned long)&mtc, MIGRATE_SYNC, MR_SYSCALL, NULL);
1095                 if (err)
1096                         putback_movable_pages(&pagelist);
1097         }
1098
1099         if (err >= 0)
1100                 err += nr_failed;
1101         return err;
1102 }
1103
1104 /*
1105  * Move pages between the two nodesets so as to preserve the physical
1106  * layout as much as possible.
1107  *
1108  * Returns the number of page that could not be moved.
1109  */
1110 int do_migrate_pages(struct mm_struct *mm, const nodemask_t *from,
1111                      const nodemask_t *to, int flags)
1112 {
1113         long nr_failed = 0;
1114         long err = 0;
1115         nodemask_t tmp;
1116
1117         lru_cache_disable();
1118
1119         /*
1120          * Find a 'source' bit set in 'tmp' whose corresponding 'dest'
1121          * bit in 'to' is not also set in 'tmp'.  Clear the found 'source'
1122          * bit in 'tmp', and return that <source, dest> pair for migration.
1123          * The pair of nodemasks 'to' and 'from' define the map.
1124          *
1125          * If no pair of bits is found that way, fallback to picking some
1126          * pair of 'source' and 'dest' bits that are not the same.  If the
1127          * 'source' and 'dest' bits are the same, this represents a node
1128          * that will be migrating to itself, so no pages need move.
1129          *
1130          * If no bits are left in 'tmp', or if all remaining bits left
1131          * in 'tmp' correspond to the same bit in 'to', return false
1132          * (nothing left to migrate).
1133          *
1134          * This lets us pick a pair of nodes to migrate between, such that
1135          * if possible the dest node is not already occupied by some other
1136          * source node, minimizing the risk of overloading the memory on a
1137          * node that would happen if we migrated incoming memory to a node
1138          * before migrating outgoing memory source that same node.
1139          *
1140          * A single scan of tmp is sufficient.  As we go, we remember the
1141          * most recent <s, d> pair that moved (s != d).  If we find a pair
1142          * that not only moved, but what's better, moved to an empty slot
1143          * (d is not set in tmp), then we break out then, with that pair.
1144          * Otherwise when we finish scanning from_tmp, we at least have the
1145          * most recent <s, d> pair that moved.  If we get all the way through
1146          * the scan of tmp without finding any node that moved, much less
1147          * moved to an empty node, then there is nothing left worth migrating.
1148          */
1149
1150         tmp = *from;
1151         while (!nodes_empty(tmp)) {
1152                 int s, d;
1153                 int source = NUMA_NO_NODE;
1154                 int dest = 0;
1155
1156                 for_each_node_mask(s, tmp) {
1157
1158                         /*
1159                          * do_migrate_pages() tries to maintain the relative
1160                          * node relationship of the pages established between
1161                          * threads and memory areas.
1162                          *
1163                          * However if the number of source nodes is not equal to
1164                          * the number of destination nodes we can not preserve
1165                          * this node relative relationship.  In that case, skip
1166                          * copying memory from a node that is in the destination
1167                          * mask.
1168                          *
1169                          * Example: [2,3,4] -> [3,4,5] moves everything.
1170                          *          [0-7] - > [3,4,5] moves only 0,1,2,6,7.
1171                          */
1172
1173                         if ((nodes_weight(*from) != nodes_weight(*to)) &&
1174                                                 (node_isset(s, *to)))
1175                                 continue;
1176
1177                         d = node_remap(s, *from, *to);
1178                         if (s == d)
1179                                 continue;
1180
1181                         source = s;     /* Node moved. Memorize */
1182                         dest = d;
1183
1184                         /* dest not in remaining from nodes? */
1185                         if (!node_isset(dest, tmp))
1186                                 break;
1187                 }
1188                 if (source == NUMA_NO_NODE)
1189                         break;
1190
1191                 node_clear(source, tmp);
1192                 err = migrate_to_node(mm, source, dest, flags);
1193                 if (err > 0)
1194                         nr_failed += err;
1195                 if (err < 0)
1196                         break;
1197         }
1198
1199         lru_cache_enable();
1200         if (err < 0)
1201                 return err;
1202         return (nr_failed < INT_MAX) ? nr_failed : INT_MAX;
1203 }
1204
1205 /*
1206  * Allocate a new folio for page migration, according to NUMA mempolicy.
1207  */
1208 static struct folio *alloc_migration_target_by_mpol(struct folio *src,
1209                                                     unsigned long private)
1210 {
1211         struct migration_mpol *mmpol = (struct migration_mpol *)private;
1212         struct mempolicy *pol = mmpol->pol;
1213         pgoff_t ilx = mmpol->ilx;
1214         struct page *page;
1215         unsigned int order;
1216         int nid = numa_node_id();
1217         gfp_t gfp;
1218
1219         order = folio_order(src);
1220         ilx += src->index >> order;
1221
1222         if (folio_test_hugetlb(src)) {
1223                 nodemask_t *nodemask;
1224                 struct hstate *h;
1225
1226                 h = folio_hstate(src);
1227                 gfp = htlb_alloc_mask(h);
1228                 nodemask = policy_nodemask(gfp, pol, ilx, &nid);
1229                 return alloc_hugetlb_folio_nodemask(h, nid, nodemask, gfp,
1230                                 htlb_allow_alloc_fallback(MR_MEMPOLICY_MBIND));
1231         }
1232
1233         if (folio_test_large(src))
1234                 gfp = GFP_TRANSHUGE;
1235         else
1236                 gfp = GFP_HIGHUSER_MOVABLE | __GFP_RETRY_MAYFAIL | __GFP_COMP;
1237
1238         page = alloc_pages_mpol(gfp, order, pol, ilx, nid);
1239         return page_rmappable_folio(page);
1240 }
1241 #else
1242
1243 static bool migrate_folio_add(struct folio *folio, struct list_head *foliolist,
1244                                 unsigned long flags)
1245 {
1246         return false;
1247 }
1248
1249 int do_migrate_pages(struct mm_struct *mm, const nodemask_t *from,
1250                      const nodemask_t *to, int flags)
1251 {
1252         return -ENOSYS;
1253 }
1254
1255 static struct folio *alloc_migration_target_by_mpol(struct folio *src,
1256                                                     unsigned long private)
1257 {
1258         return NULL;
1259 }
1260 #endif
1261
1262 static long do_mbind(unsigned long start, unsigned long len,
1263                      unsigned short mode, unsigned short mode_flags,
1264                      nodemask_t *nmask, unsigned long flags)
1265 {
1266         struct mm_struct *mm = current->mm;
1267         struct vm_area_struct *vma, *prev;
1268         struct vma_iterator vmi;
1269         struct migration_mpol mmpol;
1270         struct mempolicy *new;
1271         unsigned long end;
1272         long err;
1273         long nr_failed;
1274         LIST_HEAD(pagelist);
1275
1276         if (flags & ~(unsigned long)MPOL_MF_VALID)
1277                 return -EINVAL;
1278         if ((flags & MPOL_MF_MOVE_ALL) && !capable(CAP_SYS_NICE))
1279                 return -EPERM;
1280
1281         if (start & ~PAGE_MASK)
1282                 return -EINVAL;
1283
1284         if (mode == MPOL_DEFAULT)
1285                 flags &= ~MPOL_MF_STRICT;
1286
1287         len = PAGE_ALIGN(len);
1288         end = start + len;
1289
1290         if (end < start)
1291                 return -EINVAL;
1292         if (end == start)
1293                 return 0;
1294
1295         new = mpol_new(mode, mode_flags, nmask);
1296         if (IS_ERR(new))
1297                 return PTR_ERR(new);
1298
1299         /*
1300          * If we are using the default policy then operation
1301          * on discontinuous address spaces is okay after all
1302          */
1303         if (!new)
1304                 flags |= MPOL_MF_DISCONTIG_OK;
1305
1306         if (flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL))
1307                 lru_cache_disable();
1308         {
1309                 NODEMASK_SCRATCH(scratch);
1310                 if (scratch) {
1311                         mmap_write_lock(mm);
1312                         err = mpol_set_nodemask(new, nmask, scratch);
1313                         if (err)
1314                                 mmap_write_unlock(mm);
1315                 } else
1316                         err = -ENOMEM;
1317                 NODEMASK_SCRATCH_FREE(scratch);
1318         }
1319         if (err)
1320                 goto mpol_out;
1321
1322         /*
1323          * Lock the VMAs before scanning for pages to migrate,
1324          * to ensure we don't miss a concurrently inserted page.
1325          */
1326         nr_failed = queue_pages_range(mm, start, end, nmask,
1327                         flags | MPOL_MF_INVERT | MPOL_MF_WRLOCK, &pagelist);
1328
1329         if (nr_failed < 0) {
1330                 err = nr_failed;
1331                 nr_failed = 0;
1332         } else {
1333                 vma_iter_init(&vmi, mm, start);
1334                 prev = vma_prev(&vmi);
1335                 for_each_vma_range(vmi, vma, end) {
1336                         err = mbind_range(&vmi, vma, &prev, start, end, new);
1337                         if (err)
1338                                 break;
1339                 }
1340         }
1341
1342         if (!err && !list_empty(&pagelist)) {
1343                 /* Convert MPOL_DEFAULT's NULL to task or default policy */
1344                 if (!new) {
1345                         new = get_task_policy(current);
1346                         mpol_get(new);
1347                 }
1348                 mmpol.pol = new;
1349                 mmpol.ilx = 0;
1350
1351                 /*
1352                  * In the interleaved case, attempt to allocate on exactly the
1353                  * targeted nodes, for the first VMA to be migrated; for later
1354                  * VMAs, the nodes will still be interleaved from the targeted
1355                  * nodemask, but one by one may be selected differently.
1356                  */
1357                 if (new->mode == MPOL_INTERLEAVE ||
1358                     new->mode == MPOL_WEIGHTED_INTERLEAVE) {
1359                         struct folio *folio;
1360                         unsigned int order;
1361                         unsigned long addr = -EFAULT;
1362
1363                         list_for_each_entry(folio, &pagelist, lru) {
1364                                 if (!folio_test_ksm(folio))
1365                                         break;
1366                         }
1367                         if (!list_entry_is_head(folio, &pagelist, lru)) {
1368                                 vma_iter_init(&vmi, mm, start);
1369                                 for_each_vma_range(vmi, vma, end) {
1370                                         addr = page_address_in_vma(
1371                                                 folio_page(folio, 0), vma);
1372                                         if (addr != -EFAULT)
1373                                                 break;
1374                                 }
1375                         }
1376                         if (addr != -EFAULT) {
1377                                 order = folio_order(folio);
1378                                 /* We already know the pol, but not the ilx */
1379                                 mpol_cond_put(get_vma_policy(vma, addr, order,
1380                                                              &mmpol.ilx));
1381                                 /* Set base from which to increment by index */
1382                                 mmpol.ilx -= folio->index >> order;
1383                         }
1384                 }
1385         }
1386
1387         mmap_write_unlock(mm);
1388
1389         if (!err && !list_empty(&pagelist)) {
1390                 nr_failed |= migrate_pages(&pagelist,
1391                                 alloc_migration_target_by_mpol, NULL,
1392                                 (unsigned long)&mmpol, MIGRATE_SYNC,
1393                                 MR_MEMPOLICY_MBIND, NULL);
1394         }
1395
1396         if (nr_failed && (flags & MPOL_MF_STRICT))
1397                 err = -EIO;
1398         if (!list_empty(&pagelist))
1399                 putback_movable_pages(&pagelist);
1400 mpol_out:
1401         mpol_put(new);
1402         if (flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL))
1403                 lru_cache_enable();
1404         return err;
1405 }
1406
1407 /*
1408  * User space interface with variable sized bitmaps for nodelists.
1409  */
1410 static int get_bitmap(unsigned long *mask, const unsigned long __user *nmask,
1411                       unsigned long maxnode)
1412 {
1413         unsigned long nlongs = BITS_TO_LONGS(maxnode);
1414         int ret;
1415
1416         if (in_compat_syscall())
1417                 ret = compat_get_bitmap(mask,
1418                                         (const compat_ulong_t __user *)nmask,
1419                                         maxnode);
1420         else
1421                 ret = copy_from_user(mask, nmask,
1422                                      nlongs * sizeof(unsigned long));
1423
1424         if (ret)
1425                 return -EFAULT;
1426
1427         if (maxnode % BITS_PER_LONG)
1428                 mask[nlongs - 1] &= (1UL << (maxnode % BITS_PER_LONG)) - 1;
1429
1430         return 0;
1431 }
1432
1433 /* Copy a node mask from user space. */
1434 static int get_nodes(nodemask_t *nodes, const unsigned long __user *nmask,
1435                      unsigned long maxnode)
1436 {
1437         --maxnode;
1438         nodes_clear(*nodes);
1439         if (maxnode == 0 || !nmask)
1440                 return 0;
1441         if (maxnode > PAGE_SIZE*BITS_PER_BYTE)
1442                 return -EINVAL;
1443
1444         /*
1445          * When the user specified more nodes than supported just check
1446          * if the non supported part is all zero, one word at a time,
1447          * starting at the end.
1448          */
1449         while (maxnode > MAX_NUMNODES) {
1450                 unsigned long bits = min_t(unsigned long, maxnode, BITS_PER_LONG);
1451                 unsigned long t;
1452
1453                 if (get_bitmap(&t, &nmask[(maxnode - 1) / BITS_PER_LONG], bits))
1454                         return -EFAULT;
1455
1456                 if (maxnode - bits >= MAX_NUMNODES) {
1457                         maxnode -= bits;
1458                 } else {
1459                         maxnode = MAX_NUMNODES;
1460                         t &= ~((1UL << (MAX_NUMNODES % BITS_PER_LONG)) - 1);
1461                 }
1462                 if (t)
1463                         return -EINVAL;
1464         }
1465
1466         return get_bitmap(nodes_addr(*nodes), nmask, maxnode);
1467 }
1468
1469 /* Copy a kernel node mask to user space */
1470 static int copy_nodes_to_user(unsigned long __user *mask, unsigned long maxnode,
1471                               nodemask_t *nodes)
1472 {
1473         unsigned long copy = ALIGN(maxnode-1, 64) / 8;
1474         unsigned int nbytes = BITS_TO_LONGS(nr_node_ids) * sizeof(long);
1475         bool compat = in_compat_syscall();
1476
1477         if (compat)
1478                 nbytes = BITS_TO_COMPAT_LONGS(nr_node_ids) * sizeof(compat_long_t);
1479
1480         if (copy > nbytes) {
1481                 if (copy > PAGE_SIZE)
1482                         return -EINVAL;
1483                 if (clear_user((char __user *)mask + nbytes, copy - nbytes))
1484                         return -EFAULT;
1485                 copy = nbytes;
1486                 maxnode = nr_node_ids;
1487         }
1488
1489         if (compat)
1490                 return compat_put_bitmap((compat_ulong_t __user *)mask,
1491                                          nodes_addr(*nodes), maxnode);
1492
1493         return copy_to_user(mask, nodes_addr(*nodes), copy) ? -EFAULT : 0;
1494 }
1495
1496 /* Basic parameter sanity check used by both mbind() and set_mempolicy() */
1497 static inline int sanitize_mpol_flags(int *mode, unsigned short *flags)
1498 {
1499         *flags = *mode & MPOL_MODE_FLAGS;
1500         *mode &= ~MPOL_MODE_FLAGS;
1501
1502         if ((unsigned int)(*mode) >=  MPOL_MAX)
1503                 return -EINVAL;
1504         if ((*flags & MPOL_F_STATIC_NODES) && (*flags & MPOL_F_RELATIVE_NODES))
1505                 return -EINVAL;
1506         if (*flags & MPOL_F_NUMA_BALANCING) {
1507                 if (*mode == MPOL_BIND || *mode == MPOL_PREFERRED_MANY)
1508                         *flags |= (MPOL_F_MOF | MPOL_F_MORON);
1509                 else
1510                         return -EINVAL;
1511         }
1512         return 0;
1513 }
1514
1515 static long kernel_mbind(unsigned long start, unsigned long len,
1516                          unsigned long mode, const unsigned long __user *nmask,
1517                          unsigned long maxnode, unsigned int flags)
1518 {
1519         unsigned short mode_flags;
1520         nodemask_t nodes;
1521         int lmode = mode;
1522         int err;
1523
1524         start = untagged_addr(start);
1525         err = sanitize_mpol_flags(&lmode, &mode_flags);
1526         if (err)
1527                 return err;
1528
1529         err = get_nodes(&nodes, nmask, maxnode);
1530         if (err)
1531                 return err;
1532
1533         return do_mbind(start, len, lmode, mode_flags, &nodes, flags);
1534 }
1535
1536 SYSCALL_DEFINE4(set_mempolicy_home_node, unsigned long, start, unsigned long, len,
1537                 unsigned long, home_node, unsigned long, flags)
1538 {
1539         struct mm_struct *mm = current->mm;
1540         struct vm_area_struct *vma, *prev;
1541         struct mempolicy *new, *old;
1542         unsigned long end;
1543         int err = -ENOENT;
1544         VMA_ITERATOR(vmi, mm, start);
1545
1546         start = untagged_addr(start);
1547         if (start & ~PAGE_MASK)
1548                 return -EINVAL;
1549         /*
1550          * flags is used for future extension if any.
1551          */
1552         if (flags != 0)
1553                 return -EINVAL;
1554
1555         /*
1556          * Check home_node is online to avoid accessing uninitialized
1557          * NODE_DATA.
1558          */
1559         if (home_node >= MAX_NUMNODES || !node_online(home_node))
1560                 return -EINVAL;
1561
1562         len = PAGE_ALIGN(len);
1563         end = start + len;
1564
1565         if (end < start)
1566                 return -EINVAL;
1567         if (end == start)
1568                 return 0;
1569         mmap_write_lock(mm);
1570         prev = vma_prev(&vmi);
1571         for_each_vma_range(vmi, vma, end) {
1572                 /*
1573                  * If any vma in the range got policy other than MPOL_BIND
1574                  * or MPOL_PREFERRED_MANY we return error. We don't reset
1575                  * the home node for vmas we already updated before.
1576                  */
1577                 old = vma_policy(vma);
1578                 if (!old) {
1579                         prev = vma;
1580                         continue;
1581                 }
1582                 if (old->mode != MPOL_BIND && old->mode != MPOL_PREFERRED_MANY) {
1583                         err = -EOPNOTSUPP;
1584                         break;
1585                 }
1586                 new = mpol_dup(old);
1587                 if (IS_ERR(new)) {
1588                         err = PTR_ERR(new);
1589                         break;
1590                 }
1591
1592                 vma_start_write(vma);
1593                 new->home_node = home_node;
1594                 err = mbind_range(&vmi, vma, &prev, start, end, new);
1595                 mpol_put(new);
1596                 if (err)
1597                         break;
1598         }
1599         mmap_write_unlock(mm);
1600         return err;
1601 }
1602
1603 SYSCALL_DEFINE6(mbind, unsigned long, start, unsigned long, len,
1604                 unsigned long, mode, const unsigned long __user *, nmask,
1605                 unsigned long, maxnode, unsigned int, flags)
1606 {
1607         return kernel_mbind(start, len, mode, nmask, maxnode, flags);
1608 }
1609
1610 /* Set the process memory policy */
1611 static long kernel_set_mempolicy(int mode, const unsigned long __user *nmask,
1612                                  unsigned long maxnode)
1613 {
1614         unsigned short mode_flags;
1615         nodemask_t nodes;
1616         int lmode = mode;
1617         int err;
1618
1619         err = sanitize_mpol_flags(&lmode, &mode_flags);
1620         if (err)
1621                 return err;
1622
1623         err = get_nodes(&nodes, nmask, maxnode);
1624         if (err)
1625                 return err;
1626
1627         return do_set_mempolicy(lmode, mode_flags, &nodes);
1628 }
1629
1630 SYSCALL_DEFINE3(set_mempolicy, int, mode, const unsigned long __user *, nmask,
1631                 unsigned long, maxnode)
1632 {
1633         return kernel_set_mempolicy(mode, nmask, maxnode);
1634 }
1635
1636 static int kernel_migrate_pages(pid_t pid, unsigned long maxnode,
1637                                 const unsigned long __user *old_nodes,
1638                                 const unsigned long __user *new_nodes)
1639 {
1640         struct mm_struct *mm = NULL;
1641         struct task_struct *task;
1642         nodemask_t task_nodes;
1643         int err;
1644         nodemask_t *old;
1645         nodemask_t *new;
1646         NODEMASK_SCRATCH(scratch);
1647
1648         if (!scratch)
1649                 return -ENOMEM;
1650
1651         old = &scratch->mask1;
1652         new = &scratch->mask2;
1653
1654         err = get_nodes(old, old_nodes, maxnode);
1655         if (err)
1656                 goto out;
1657
1658         err = get_nodes(new, new_nodes, maxnode);
1659         if (err)
1660                 goto out;
1661
1662         /* Find the mm_struct */
1663         rcu_read_lock();
1664         task = pid ? find_task_by_vpid(pid) : current;
1665         if (!task) {
1666                 rcu_read_unlock();
1667                 err = -ESRCH;
1668                 goto out;
1669         }
1670         get_task_struct(task);
1671
1672         err = -EINVAL;
1673
1674         /*
1675          * Check if this process has the right to modify the specified process.
1676          * Use the regular "ptrace_may_access()" checks.
1677          */
1678         if (!ptrace_may_access(task, PTRACE_MODE_READ_REALCREDS)) {
1679                 rcu_read_unlock();
1680                 err = -EPERM;
1681                 goto out_put;
1682         }
1683         rcu_read_unlock();
1684
1685         task_nodes = cpuset_mems_allowed(task);
1686         /* Is the user allowed to access the target nodes? */
1687         if (!nodes_subset(*new, task_nodes) && !capable(CAP_SYS_NICE)) {
1688                 err = -EPERM;
1689                 goto out_put;
1690         }
1691
1692         task_nodes = cpuset_mems_allowed(current);
1693         nodes_and(*new, *new, task_nodes);
1694         if (nodes_empty(*new))
1695                 goto out_put;
1696
1697         err = security_task_movememory(task);
1698         if (err)
1699                 goto out_put;
1700
1701         mm = get_task_mm(task);
1702         put_task_struct(task);
1703
1704         if (!mm) {
1705                 err = -EINVAL;
1706                 goto out;
1707         }
1708
1709         err = do_migrate_pages(mm, old, new,
1710                 capable(CAP_SYS_NICE) ? MPOL_MF_MOVE_ALL : MPOL_MF_MOVE);
1711
1712         mmput(mm);
1713 out:
1714         NODEMASK_SCRATCH_FREE(scratch);
1715
1716         return err;
1717
1718 out_put:
1719         put_task_struct(task);
1720         goto out;
1721 }
1722
1723 SYSCALL_DEFINE4(migrate_pages, pid_t, pid, unsigned long, maxnode,
1724                 const unsigned long __user *, old_nodes,
1725                 const unsigned long __user *, new_nodes)
1726 {
1727         return kernel_migrate_pages(pid, maxnode, old_nodes, new_nodes);
1728 }
1729
1730 /* Retrieve NUMA policy */
1731 static int kernel_get_mempolicy(int __user *policy,
1732                                 unsigned long __user *nmask,
1733                                 unsigned long maxnode,
1734                                 unsigned long addr,
1735                                 unsigned long flags)
1736 {
1737         int err;
1738         int pval;
1739         nodemask_t nodes;
1740
1741         if (nmask != NULL && maxnode < nr_node_ids)
1742                 return -EINVAL;
1743
1744         addr = untagged_addr(addr);
1745
1746         err = do_get_mempolicy(&pval, &nodes, addr, flags);
1747
1748         if (err)
1749                 return err;
1750
1751         if (policy && put_user(pval, policy))
1752                 return -EFAULT;
1753
1754         if (nmask)
1755                 err = copy_nodes_to_user(nmask, maxnode, &nodes);
1756
1757         return err;
1758 }
1759
1760 SYSCALL_DEFINE5(get_mempolicy, int __user *, policy,
1761                 unsigned long __user *, nmask, unsigned long, maxnode,
1762                 unsigned long, addr, unsigned long, flags)
1763 {
1764         return kernel_get_mempolicy(policy, nmask, maxnode, addr, flags);
1765 }
1766
1767 bool vma_migratable(struct vm_area_struct *vma)
1768 {
1769         if (vma->vm_flags & (VM_IO | VM_PFNMAP))
1770                 return false;
1771
1772         /*
1773          * DAX device mappings require predictable access latency, so avoid
1774          * incurring periodic faults.
1775          */
1776         if (vma_is_dax(vma))
1777                 return false;
1778
1779         if (is_vm_hugetlb_page(vma) &&
1780                 !hugepage_migration_supported(hstate_vma(vma)))
1781                 return false;
1782
1783         /*
1784          * Migration allocates pages in the highest zone. If we cannot
1785          * do so then migration (at least from node to node) is not
1786          * possible.
1787          */
1788         if (vma->vm_file &&
1789                 gfp_zone(mapping_gfp_mask(vma->vm_file->f_mapping))
1790                         < policy_zone)
1791                 return false;
1792         return true;
1793 }
1794
1795 struct mempolicy *__get_vma_policy(struct vm_area_struct *vma,
1796                                    unsigned long addr, pgoff_t *ilx)
1797 {
1798         *ilx = 0;
1799         return (vma->vm_ops && vma->vm_ops->get_policy) ?
1800                 vma->vm_ops->get_policy(vma, addr, ilx) : vma->vm_policy;
1801 }
1802
1803 /*
1804  * get_vma_policy(@vma, @addr, @order, @ilx)
1805  * @vma: virtual memory area whose policy is sought
1806  * @addr: address in @vma for shared policy lookup
1807  * @order: 0, or appropriate huge_page_order for interleaving
1808  * @ilx: interleave index (output), for use only when MPOL_INTERLEAVE or
1809  *       MPOL_WEIGHTED_INTERLEAVE
1810  *
1811  * Returns effective policy for a VMA at specified address.
1812  * Falls back to current->mempolicy or system default policy, as necessary.
1813  * Shared policies [those marked as MPOL_F_SHARED] require an extra reference
1814  * count--added by the get_policy() vm_op, as appropriate--to protect against
1815  * freeing by another task.  It is the caller's responsibility to free the
1816  * extra reference for shared policies.
1817  */
1818 struct mempolicy *get_vma_policy(struct vm_area_struct *vma,
1819                                  unsigned long addr, int order, pgoff_t *ilx)
1820 {
1821         struct mempolicy *pol;
1822
1823         pol = __get_vma_policy(vma, addr, ilx);
1824         if (!pol)
1825                 pol = get_task_policy(current);
1826         if (pol->mode == MPOL_INTERLEAVE ||
1827             pol->mode == MPOL_WEIGHTED_INTERLEAVE) {
1828                 *ilx += vma->vm_pgoff >> order;
1829                 *ilx += (addr - vma->vm_start) >> (PAGE_SHIFT + order);
1830         }
1831         return pol;
1832 }
1833
1834 bool vma_policy_mof(struct vm_area_struct *vma)
1835 {
1836         struct mempolicy *pol;
1837
1838         if (vma->vm_ops && vma->vm_ops->get_policy) {
1839                 bool ret = false;
1840                 pgoff_t ilx;            /* ignored here */
1841
1842                 pol = vma->vm_ops->get_policy(vma, vma->vm_start, &ilx);
1843                 if (pol && (pol->flags & MPOL_F_MOF))
1844                         ret = true;
1845                 mpol_cond_put(pol);
1846
1847                 return ret;
1848         }
1849
1850         pol = vma->vm_policy;
1851         if (!pol)
1852                 pol = get_task_policy(current);
1853
1854         return pol->flags & MPOL_F_MOF;
1855 }
1856
1857 bool apply_policy_zone(struct mempolicy *policy, enum zone_type zone)
1858 {
1859         enum zone_type dynamic_policy_zone = policy_zone;
1860
1861         BUG_ON(dynamic_policy_zone == ZONE_MOVABLE);
1862
1863         /*
1864          * if policy->nodes has movable memory only,
1865          * we apply policy when gfp_zone(gfp) = ZONE_MOVABLE only.
1866          *
1867          * policy->nodes is intersect with node_states[N_MEMORY].
1868          * so if the following test fails, it implies
1869          * policy->nodes has movable memory only.
1870          */
1871         if (!nodes_intersects(policy->nodes, node_states[N_HIGH_MEMORY]))
1872                 dynamic_policy_zone = ZONE_MOVABLE;
1873
1874         return zone >= dynamic_policy_zone;
1875 }
1876
1877 static unsigned int weighted_interleave_nodes(struct mempolicy *policy)
1878 {
1879         unsigned int node;
1880         unsigned int cpuset_mems_cookie;
1881
1882 retry:
1883         /* to prevent miscount use tsk->mems_allowed_seq to detect rebind */
1884         cpuset_mems_cookie = read_mems_allowed_begin();
1885         node = current->il_prev;
1886         if (!current->il_weight || !node_isset(node, policy->nodes)) {
1887                 node = next_node_in(node, policy->nodes);
1888                 if (read_mems_allowed_retry(cpuset_mems_cookie))
1889                         goto retry;
1890                 if (node == MAX_NUMNODES)
1891                         return node;
1892                 current->il_prev = node;
1893                 current->il_weight = get_il_weight(node);
1894         }
1895         current->il_weight--;
1896         return node;
1897 }
1898
1899 /* Do dynamic interleaving for a process */
1900 static unsigned int interleave_nodes(struct mempolicy *policy)
1901 {
1902         unsigned int nid;
1903         unsigned int cpuset_mems_cookie;
1904
1905         /* to prevent miscount, use tsk->mems_allowed_seq to detect rebind */
1906         do {
1907                 cpuset_mems_cookie = read_mems_allowed_begin();
1908                 nid = next_node_in(current->il_prev, policy->nodes);
1909         } while (read_mems_allowed_retry(cpuset_mems_cookie));
1910
1911         if (nid < MAX_NUMNODES)
1912                 current->il_prev = nid;
1913         return nid;
1914 }
1915
1916 /*
1917  * Depending on the memory policy provide a node from which to allocate the
1918  * next slab entry.
1919  */
1920 unsigned int mempolicy_slab_node(void)
1921 {
1922         struct mempolicy *policy;
1923         int node = numa_mem_id();
1924
1925         if (!in_task())
1926                 return node;
1927
1928         policy = current->mempolicy;
1929         if (!policy)
1930                 return node;
1931
1932         switch (policy->mode) {
1933         case MPOL_PREFERRED:
1934                 return first_node(policy->nodes);
1935
1936         case MPOL_INTERLEAVE:
1937                 return interleave_nodes(policy);
1938
1939         case MPOL_WEIGHTED_INTERLEAVE:
1940                 return weighted_interleave_nodes(policy);
1941
1942         case MPOL_BIND:
1943         case MPOL_PREFERRED_MANY:
1944         {
1945                 struct zoneref *z;
1946
1947                 /*
1948                  * Follow bind policy behavior and start allocation at the
1949                  * first node.
1950                  */
1951                 struct zonelist *zonelist;
1952                 enum zone_type highest_zoneidx = gfp_zone(GFP_KERNEL);
1953                 zonelist = &NODE_DATA(node)->node_zonelists[ZONELIST_FALLBACK];
1954                 z = first_zones_zonelist(zonelist, highest_zoneidx,
1955                                                         &policy->nodes);
1956                 return z->zone ? zone_to_nid(z->zone) : node;
1957         }
1958         case MPOL_LOCAL:
1959                 return node;
1960
1961         default:
1962                 BUG();
1963         }
1964 }
1965
1966 static unsigned int read_once_policy_nodemask(struct mempolicy *pol,
1967                                               nodemask_t *mask)
1968 {
1969         /*
1970          * barrier stabilizes the nodemask locally so that it can be iterated
1971          * over safely without concern for changes. Allocators validate node
1972          * selection does not violate mems_allowed, so this is safe.
1973          */
1974         barrier();
1975         memcpy(mask, &pol->nodes, sizeof(nodemask_t));
1976         barrier();
1977         return nodes_weight(*mask);
1978 }
1979
1980 static unsigned int weighted_interleave_nid(struct mempolicy *pol, pgoff_t ilx)
1981 {
1982         nodemask_t nodemask;
1983         unsigned int target, nr_nodes;
1984         u8 *table;
1985         unsigned int weight_total = 0;
1986         u8 weight;
1987         int nid;
1988
1989         nr_nodes = read_once_policy_nodemask(pol, &nodemask);
1990         if (!nr_nodes)
1991                 return numa_node_id();
1992
1993         rcu_read_lock();
1994         table = rcu_dereference(iw_table);
1995         /* calculate the total weight */
1996         for_each_node_mask(nid, nodemask) {
1997                 /* detect system default usage */
1998                 weight = table ? table[nid] : 1;
1999                 weight = weight ? weight : 1;
2000                 weight_total += weight;
2001         }
2002
2003         /* Calculate the node offset based on totals */
2004         target = ilx % weight_total;
2005         nid = first_node(nodemask);
2006         while (target) {
2007                 /* detect system default usage */
2008                 weight = table ? table[nid] : 1;
2009                 weight = weight ? weight : 1;
2010                 if (target < weight)
2011                         break;
2012                 target -= weight;
2013                 nid = next_node_in(nid, nodemask);
2014         }
2015         rcu_read_unlock();
2016         return nid;
2017 }
2018
2019 /*
2020  * Do static interleaving for interleave index @ilx.  Returns the ilx'th
2021  * node in pol->nodes (starting from ilx=0), wrapping around if ilx
2022  * exceeds the number of present nodes.
2023  */
2024 static unsigned int interleave_nid(struct mempolicy *pol, pgoff_t ilx)
2025 {
2026         nodemask_t nodemask;
2027         unsigned int target, nnodes;
2028         int i;
2029         int nid;
2030
2031         nnodes = read_once_policy_nodemask(pol, &nodemask);
2032         if (!nnodes)
2033                 return numa_node_id();
2034         target = ilx % nnodes;
2035         nid = first_node(nodemask);
2036         for (i = 0; i < target; i++)
2037                 nid = next_node(nid, nodemask);
2038         return nid;
2039 }
2040
2041 /*
2042  * Return a nodemask representing a mempolicy for filtering nodes for
2043  * page allocation, together with preferred node id (or the input node id).
2044  */
2045 static nodemask_t *policy_nodemask(gfp_t gfp, struct mempolicy *pol,
2046                                    pgoff_t ilx, int *nid)
2047 {
2048         nodemask_t *nodemask = NULL;
2049
2050         switch (pol->mode) {
2051         case MPOL_PREFERRED:
2052                 /* Override input node id */
2053                 *nid = first_node(pol->nodes);
2054                 break;
2055         case MPOL_PREFERRED_MANY:
2056                 nodemask = &pol->nodes;
2057                 if (pol->home_node != NUMA_NO_NODE)
2058                         *nid = pol->home_node;
2059                 break;
2060         case MPOL_BIND:
2061                 /* Restrict to nodemask (but not on lower zones) */
2062                 if (apply_policy_zone(pol, gfp_zone(gfp)) &&
2063                     cpuset_nodemask_valid_mems_allowed(&pol->nodes))
2064                         nodemask = &pol->nodes;
2065                 if (pol->home_node != NUMA_NO_NODE)
2066                         *nid = pol->home_node;
2067                 /*
2068                  * __GFP_THISNODE shouldn't even be used with the bind policy
2069                  * because we might easily break the expectation to stay on the
2070                  * requested node and not break the policy.
2071                  */
2072                 WARN_ON_ONCE(gfp & __GFP_THISNODE);
2073                 break;
2074         case MPOL_INTERLEAVE:
2075                 /* Override input node id */
2076                 *nid = (ilx == NO_INTERLEAVE_INDEX) ?
2077                         interleave_nodes(pol) : interleave_nid(pol, ilx);
2078                 break;
2079         case MPOL_WEIGHTED_INTERLEAVE:
2080                 *nid = (ilx == NO_INTERLEAVE_INDEX) ?
2081                         weighted_interleave_nodes(pol) :
2082                         weighted_interleave_nid(pol, ilx);
2083                 break;
2084         }
2085
2086         return nodemask;
2087 }
2088
2089 #ifdef CONFIG_HUGETLBFS
2090 /*
2091  * huge_node(@vma, @addr, @gfp_flags, @mpol)
2092  * @vma: virtual memory area whose policy is sought
2093  * @addr: address in @vma for shared policy lookup and interleave policy
2094  * @gfp_flags: for requested zone
2095  * @mpol: pointer to mempolicy pointer for reference counted mempolicy
2096  * @nodemask: pointer to nodemask pointer for 'bind' and 'prefer-many' policy
2097  *
2098  * Returns a nid suitable for a huge page allocation and a pointer
2099  * to the struct mempolicy for conditional unref after allocation.
2100  * If the effective policy is 'bind' or 'prefer-many', returns a pointer
2101  * to the mempolicy's @nodemask for filtering the zonelist.
2102  */
2103 int huge_node(struct vm_area_struct *vma, unsigned long addr, gfp_t gfp_flags,
2104                 struct mempolicy **mpol, nodemask_t **nodemask)
2105 {
2106         pgoff_t ilx;
2107         int nid;
2108
2109         nid = numa_node_id();
2110         *mpol = get_vma_policy(vma, addr, hstate_vma(vma)->order, &ilx);
2111         *nodemask = policy_nodemask(gfp_flags, *mpol, ilx, &nid);
2112         return nid;
2113 }
2114
2115 /*
2116  * init_nodemask_of_mempolicy
2117  *
2118  * If the current task's mempolicy is "default" [NULL], return 'false'
2119  * to indicate default policy.  Otherwise, extract the policy nodemask
2120  * for 'bind' or 'interleave' policy into the argument nodemask, or
2121  * initialize the argument nodemask to contain the single node for
2122  * 'preferred' or 'local' policy and return 'true' to indicate presence
2123  * of non-default mempolicy.
2124  *
2125  * We don't bother with reference counting the mempolicy [mpol_get/put]
2126  * because the current task is examining it's own mempolicy and a task's
2127  * mempolicy is only ever changed by the task itself.
2128  *
2129  * N.B., it is the caller's responsibility to free a returned nodemask.
2130  */
2131 bool init_nodemask_of_mempolicy(nodemask_t *mask)
2132 {
2133         struct mempolicy *mempolicy;
2134
2135         if (!(mask && current->mempolicy))
2136                 return false;
2137
2138         task_lock(current);
2139         mempolicy = current->mempolicy;
2140         switch (mempolicy->mode) {
2141         case MPOL_PREFERRED:
2142         case MPOL_PREFERRED_MANY:
2143         case MPOL_BIND:
2144         case MPOL_INTERLEAVE:
2145         case MPOL_WEIGHTED_INTERLEAVE:
2146                 *mask = mempolicy->nodes;
2147                 break;
2148
2149         case MPOL_LOCAL:
2150                 init_nodemask_of_node(mask, numa_node_id());
2151                 break;
2152
2153         default:
2154                 BUG();
2155         }
2156         task_unlock(current);
2157
2158         return true;
2159 }
2160 #endif
2161
2162 /*
2163  * mempolicy_in_oom_domain
2164  *
2165  * If tsk's mempolicy is "bind", check for intersection between mask and
2166  * the policy nodemask. Otherwise, return true for all other policies
2167  * including "interleave", as a tsk with "interleave" policy may have
2168  * memory allocated from all nodes in system.
2169  *
2170  * Takes task_lock(tsk) to prevent freeing of its mempolicy.
2171  */
2172 bool mempolicy_in_oom_domain(struct task_struct *tsk,
2173                                         const nodemask_t *mask)
2174 {
2175         struct mempolicy *mempolicy;
2176         bool ret = true;
2177
2178         if (!mask)
2179                 return ret;
2180
2181         task_lock(tsk);
2182         mempolicy = tsk->mempolicy;
2183         if (mempolicy && mempolicy->mode == MPOL_BIND)
2184                 ret = nodes_intersects(mempolicy->nodes, *mask);
2185         task_unlock(tsk);
2186
2187         return ret;
2188 }
2189
2190 static struct page *alloc_pages_preferred_many(gfp_t gfp, unsigned int order,
2191                                                 int nid, nodemask_t *nodemask)
2192 {
2193         struct page *page;
2194         gfp_t preferred_gfp;
2195
2196         /*
2197          * This is a two pass approach. The first pass will only try the
2198          * preferred nodes but skip the direct reclaim and allow the
2199          * allocation to fail, while the second pass will try all the
2200          * nodes in system.
2201          */
2202         preferred_gfp = gfp | __GFP_NOWARN;
2203         preferred_gfp &= ~(__GFP_DIRECT_RECLAIM | __GFP_NOFAIL);
2204         page = __alloc_pages_noprof(preferred_gfp, order, nid, nodemask);
2205         if (!page)
2206                 page = __alloc_pages_noprof(gfp, order, nid, NULL);
2207
2208         return page;
2209 }
2210
2211 /**
2212  * alloc_pages_mpol - Allocate pages according to NUMA mempolicy.
2213  * @gfp: GFP flags.
2214  * @order: Order of the page allocation.
2215  * @pol: Pointer to the NUMA mempolicy.
2216  * @ilx: Index for interleave mempolicy (also distinguishes alloc_pages()).
2217  * @nid: Preferred node (usually numa_node_id() but @mpol may override it).
2218  *
2219  * Return: The page on success or NULL if allocation fails.
2220  */
2221 struct page *alloc_pages_mpol_noprof(gfp_t gfp, unsigned int order,
2222                 struct mempolicy *pol, pgoff_t ilx, int nid)
2223 {
2224         nodemask_t *nodemask;
2225         struct page *page;
2226
2227         nodemask = policy_nodemask(gfp, pol, ilx, &nid);
2228
2229         if (pol->mode == MPOL_PREFERRED_MANY)
2230                 return alloc_pages_preferred_many(gfp, order, nid, nodemask);
2231
2232         if (IS_ENABLED(CONFIG_TRANSPARENT_HUGEPAGE) &&
2233             /* filter "hugepage" allocation, unless from alloc_pages() */
2234             order == HPAGE_PMD_ORDER && ilx != NO_INTERLEAVE_INDEX) {
2235                 /*
2236                  * For hugepage allocation and non-interleave policy which
2237                  * allows the current node (or other explicitly preferred
2238                  * node) we only try to allocate from the current/preferred
2239                  * node and don't fall back to other nodes, as the cost of
2240                  * remote accesses would likely offset THP benefits.
2241                  *
2242                  * If the policy is interleave or does not allow the current
2243                  * node in its nodemask, we allocate the standard way.
2244                  */
2245                 if (pol->mode != MPOL_INTERLEAVE &&
2246                     pol->mode != MPOL_WEIGHTED_INTERLEAVE &&
2247                     (!nodemask || node_isset(nid, *nodemask))) {
2248                         /*
2249                          * First, try to allocate THP only on local node, but
2250                          * don't reclaim unnecessarily, just compact.
2251                          */
2252                         page = __alloc_pages_node_noprof(nid,
2253                                 gfp | __GFP_THISNODE | __GFP_NORETRY, order);
2254                         if (page || !(gfp & __GFP_DIRECT_RECLAIM))
2255                                 return page;
2256                         /*
2257                          * If hugepage allocations are configured to always
2258                          * synchronous compact or the vma has been madvised
2259                          * to prefer hugepage backing, retry allowing remote
2260                          * memory with both reclaim and compact as well.
2261                          */
2262                 }
2263         }
2264
2265         page = __alloc_pages_noprof(gfp, order, nid, nodemask);
2266
2267         if (unlikely(pol->mode == MPOL_INTERLEAVE) && page) {
2268                 /* skip NUMA_INTERLEAVE_HIT update if numa stats is disabled */
2269                 if (static_branch_likely(&vm_numa_stat_key) &&
2270                     page_to_nid(page) == nid) {
2271                         preempt_disable();
2272                         __count_numa_event(page_zone(page), NUMA_INTERLEAVE_HIT);
2273                         preempt_enable();
2274                 }
2275         }
2276
2277         return page;
2278 }
2279
2280 /**
2281  * vma_alloc_folio - Allocate a folio for a VMA.
2282  * @gfp: GFP flags.
2283  * @order: Order of the folio.
2284  * @vma: Pointer to VMA.
2285  * @addr: Virtual address of the allocation.  Must be inside @vma.
2286  * @hugepage: Unused (was: For hugepages try only preferred node if possible).
2287  *
2288  * Allocate a folio for a specific address in @vma, using the appropriate
2289  * NUMA policy.  The caller must hold the mmap_lock of the mm_struct of the
2290  * VMA to prevent it from going away.  Should be used for all allocations
2291  * for folios that will be mapped into user space, excepting hugetlbfs, and
2292  * excepting where direct use of alloc_pages_mpol() is more appropriate.
2293  *
2294  * Return: The folio on success or NULL if allocation fails.
2295  */
2296 struct folio *vma_alloc_folio_noprof(gfp_t gfp, int order, struct vm_area_struct *vma,
2297                 unsigned long addr, bool hugepage)
2298 {
2299         struct mempolicy *pol;
2300         pgoff_t ilx;
2301         struct page *page;
2302
2303         pol = get_vma_policy(vma, addr, order, &ilx);
2304         page = alloc_pages_mpol_noprof(gfp | __GFP_COMP, order,
2305                                        pol, ilx, numa_node_id());
2306         mpol_cond_put(pol);
2307         return page_rmappable_folio(page);
2308 }
2309 EXPORT_SYMBOL(vma_alloc_folio_noprof);
2310
2311 /**
2312  * alloc_pages - Allocate pages.
2313  * @gfp: GFP flags.
2314  * @order: Power of two of number of pages to allocate.
2315  *
2316  * Allocate 1 << @order contiguous pages.  The physical address of the
2317  * first page is naturally aligned (eg an order-3 allocation will be aligned
2318  * to a multiple of 8 * PAGE_SIZE bytes).  The NUMA policy of the current
2319  * process is honoured when in process context.
2320  *
2321  * Context: Can be called from any context, providing the appropriate GFP
2322  * flags are used.
2323  * Return: The page on success or NULL if allocation fails.
2324  */
2325 struct page *alloc_pages_noprof(gfp_t gfp, unsigned int order)
2326 {
2327         struct mempolicy *pol = &default_policy;
2328
2329         /*
2330          * No reference counting needed for current->mempolicy
2331          * nor system default_policy
2332          */
2333         if (!in_interrupt() && !(gfp & __GFP_THISNODE))
2334                 pol = get_task_policy(current);
2335
2336         return alloc_pages_mpol_noprof(gfp, order, pol, NO_INTERLEAVE_INDEX,
2337                                        numa_node_id());
2338 }
2339 EXPORT_SYMBOL(alloc_pages_noprof);
2340
2341 struct folio *folio_alloc_noprof(gfp_t gfp, unsigned int order)
2342 {
2343         return page_rmappable_folio(alloc_pages_noprof(gfp | __GFP_COMP, order));
2344 }
2345 EXPORT_SYMBOL(folio_alloc_noprof);
2346
2347 static unsigned long alloc_pages_bulk_array_interleave(gfp_t gfp,
2348                 struct mempolicy *pol, unsigned long nr_pages,
2349                 struct page **page_array)
2350 {
2351         int nodes;
2352         unsigned long nr_pages_per_node;
2353         int delta;
2354         int i;
2355         unsigned long nr_allocated;
2356         unsigned long total_allocated = 0;
2357
2358         nodes = nodes_weight(pol->nodes);
2359         nr_pages_per_node = nr_pages / nodes;
2360         delta = nr_pages - nodes * nr_pages_per_node;
2361
2362         for (i = 0; i < nodes; i++) {
2363                 if (delta) {
2364                         nr_allocated = alloc_pages_bulk_noprof(gfp,
2365                                         interleave_nodes(pol), NULL,
2366                                         nr_pages_per_node + 1, NULL,
2367                                         page_array);
2368                         delta--;
2369                 } else {
2370                         nr_allocated = alloc_pages_bulk_noprof(gfp,
2371                                         interleave_nodes(pol), NULL,
2372                                         nr_pages_per_node, NULL, page_array);
2373                 }
2374
2375                 page_array += nr_allocated;
2376                 total_allocated += nr_allocated;
2377         }
2378
2379         return total_allocated;
2380 }
2381
2382 static unsigned long alloc_pages_bulk_array_weighted_interleave(gfp_t gfp,
2383                 struct mempolicy *pol, unsigned long nr_pages,
2384                 struct page **page_array)
2385 {
2386         struct task_struct *me = current;
2387         unsigned int cpuset_mems_cookie;
2388         unsigned long total_allocated = 0;
2389         unsigned long nr_allocated = 0;
2390         unsigned long rounds;
2391         unsigned long node_pages, delta;
2392         u8 *table, *weights, weight;
2393         unsigned int weight_total = 0;
2394         unsigned long rem_pages = nr_pages;
2395         nodemask_t nodes;
2396         int nnodes, node;
2397         int resume_node = MAX_NUMNODES - 1;
2398         u8 resume_weight = 0;
2399         int prev_node;
2400         int i;
2401
2402         if (!nr_pages)
2403                 return 0;
2404
2405         /* read the nodes onto the stack, retry if done during rebind */
2406         do {
2407                 cpuset_mems_cookie = read_mems_allowed_begin();
2408                 nnodes = read_once_policy_nodemask(pol, &nodes);
2409         } while (read_mems_allowed_retry(cpuset_mems_cookie));
2410
2411         /* if the nodemask has become invalid, we cannot do anything */
2412         if (!nnodes)
2413                 return 0;
2414
2415         /* Continue allocating from most recent node and adjust the nr_pages */
2416         node = me->il_prev;
2417         weight = me->il_weight;
2418         if (weight && node_isset(node, nodes)) {
2419                 node_pages = min(rem_pages, weight);
2420                 nr_allocated = __alloc_pages_bulk(gfp, node, NULL, node_pages,
2421                                                   NULL, page_array);
2422                 page_array += nr_allocated;
2423                 total_allocated += nr_allocated;
2424                 /* if that's all the pages, no need to interleave */
2425                 if (rem_pages <= weight) {
2426                         me->il_weight -= rem_pages;
2427                         return total_allocated;
2428                 }
2429                 /* Otherwise we adjust remaining pages, continue from there */
2430                 rem_pages -= weight;
2431         }
2432         /* clear active weight in case of an allocation failure */
2433         me->il_weight = 0;
2434         prev_node = node;
2435
2436         /* create a local copy of node weights to operate on outside rcu */
2437         weights = kzalloc(nr_node_ids, GFP_KERNEL);
2438         if (!weights)
2439                 return total_allocated;
2440
2441         rcu_read_lock();
2442         table = rcu_dereference(iw_table);
2443         if (table)
2444                 memcpy(weights, table, nr_node_ids);
2445         rcu_read_unlock();
2446
2447         /* calculate total, detect system default usage */
2448         for_each_node_mask(node, nodes) {
2449                 if (!weights[node])
2450                         weights[node] = 1;
2451                 weight_total += weights[node];
2452         }
2453
2454         /*
2455          * Calculate rounds/partial rounds to minimize __alloc_pages_bulk calls.
2456          * Track which node weighted interleave should resume from.
2457          *
2458          * if (rounds > 0) and (delta == 0), resume_node will always be
2459          * the node following prev_node and its weight.
2460          */
2461         rounds = rem_pages / weight_total;
2462         delta = rem_pages % weight_total;
2463         resume_node = next_node_in(prev_node, nodes);
2464         resume_weight = weights[resume_node];
2465         for (i = 0; i < nnodes; i++) {
2466                 node = next_node_in(prev_node, nodes);
2467                 weight = weights[node];
2468                 node_pages = weight * rounds;
2469                 /* If a delta exists, add this node's portion of the delta */
2470                 if (delta > weight) {
2471                         node_pages += weight;
2472                         delta -= weight;
2473                 } else if (delta) {
2474                         /* when delta is depleted, resume from that node */
2475                         node_pages += delta;
2476                         resume_node = node;
2477                         resume_weight = weight - delta;
2478                         delta = 0;
2479                 }
2480                 /* node_pages can be 0 if an allocation fails and rounds == 0 */
2481                 if (!node_pages)
2482                         break;
2483                 nr_allocated = __alloc_pages_bulk(gfp, node, NULL, node_pages,
2484                                                   NULL, page_array);
2485                 page_array += nr_allocated;
2486                 total_allocated += nr_allocated;
2487                 if (total_allocated == nr_pages)
2488                         break;
2489                 prev_node = node;
2490         }
2491         me->il_prev = resume_node;
2492         me->il_weight = resume_weight;
2493         kfree(weights);
2494         return total_allocated;
2495 }
2496
2497 static unsigned long alloc_pages_bulk_array_preferred_many(gfp_t gfp, int nid,
2498                 struct mempolicy *pol, unsigned long nr_pages,
2499                 struct page **page_array)
2500 {
2501         gfp_t preferred_gfp;
2502         unsigned long nr_allocated = 0;
2503
2504         preferred_gfp = gfp | __GFP_NOWARN;
2505         preferred_gfp &= ~(__GFP_DIRECT_RECLAIM | __GFP_NOFAIL);
2506
2507         nr_allocated  = alloc_pages_bulk_noprof(preferred_gfp, nid, &pol->nodes,
2508                                            nr_pages, NULL, page_array);
2509
2510         if (nr_allocated < nr_pages)
2511                 nr_allocated += alloc_pages_bulk_noprof(gfp, numa_node_id(), NULL,
2512                                 nr_pages - nr_allocated, NULL,
2513                                 page_array + nr_allocated);
2514         return nr_allocated;
2515 }
2516
2517 /* alloc pages bulk and mempolicy should be considered at the
2518  * same time in some situation such as vmalloc.
2519  *
2520  * It can accelerate memory allocation especially interleaving
2521  * allocate memory.
2522  */
2523 unsigned long alloc_pages_bulk_array_mempolicy_noprof(gfp_t gfp,
2524                 unsigned long nr_pages, struct page **page_array)
2525 {
2526         struct mempolicy *pol = &default_policy;
2527         nodemask_t *nodemask;
2528         int nid;
2529
2530         if (!in_interrupt() && !(gfp & __GFP_THISNODE))
2531                 pol = get_task_policy(current);
2532
2533         if (pol->mode == MPOL_INTERLEAVE)
2534                 return alloc_pages_bulk_array_interleave(gfp, pol,
2535                                                          nr_pages, page_array);
2536
2537         if (pol->mode == MPOL_WEIGHTED_INTERLEAVE)
2538                 return alloc_pages_bulk_array_weighted_interleave(
2539                                   gfp, pol, nr_pages, page_array);
2540
2541         if (pol->mode == MPOL_PREFERRED_MANY)
2542                 return alloc_pages_bulk_array_preferred_many(gfp,
2543                                 numa_node_id(), pol, nr_pages, page_array);
2544
2545         nid = numa_node_id();
2546         nodemask = policy_nodemask(gfp, pol, NO_INTERLEAVE_INDEX, &nid);
2547         return alloc_pages_bulk_noprof(gfp, nid, nodemask,
2548                                        nr_pages, NULL, page_array);
2549 }
2550
2551 int vma_dup_policy(struct vm_area_struct *src, struct vm_area_struct *dst)
2552 {
2553         struct mempolicy *pol = mpol_dup(src->vm_policy);
2554
2555         if (IS_ERR(pol))
2556                 return PTR_ERR(pol);
2557         dst->vm_policy = pol;
2558         return 0;
2559 }
2560
2561 /*
2562  * If mpol_dup() sees current->cpuset == cpuset_being_rebound, then it
2563  * rebinds the mempolicy its copying by calling mpol_rebind_policy()
2564  * with the mems_allowed returned by cpuset_mems_allowed().  This
2565  * keeps mempolicies cpuset relative after its cpuset moves.  See
2566  * further kernel/cpuset.c update_nodemask().
2567  *
2568  * current's mempolicy may be rebinded by the other task(the task that changes
2569  * cpuset's mems), so we needn't do rebind work for current task.
2570  */
2571
2572 /* Slow path of a mempolicy duplicate */
2573 struct mempolicy *__mpol_dup(struct mempolicy *old)
2574 {
2575         struct mempolicy *new = kmem_cache_alloc(policy_cache, GFP_KERNEL);
2576
2577         if (!new)
2578                 return ERR_PTR(-ENOMEM);
2579
2580         /* task's mempolicy is protected by alloc_lock */
2581         if (old == current->mempolicy) {
2582                 task_lock(current);
2583                 *new = *old;
2584                 task_unlock(current);
2585         } else
2586                 *new = *old;
2587
2588         if (current_cpuset_is_being_rebound()) {
2589                 nodemask_t mems = cpuset_mems_allowed(current);
2590                 mpol_rebind_policy(new, &mems);
2591         }
2592         atomic_set(&new->refcnt, 1);
2593         return new;
2594 }
2595
2596 /* Slow path of a mempolicy comparison */
2597 bool __mpol_equal(struct mempolicy *a, struct mempolicy *b)
2598 {
2599         if (!a || !b)
2600                 return false;
2601         if (a->mode != b->mode)
2602                 return false;
2603         if (a->flags != b->flags)
2604                 return false;
2605         if (a->home_node != b->home_node)
2606                 return false;
2607         if (mpol_store_user_nodemask(a))
2608                 if (!nodes_equal(a->w.user_nodemask, b->w.user_nodemask))
2609                         return false;
2610
2611         switch (a->mode) {
2612         case MPOL_BIND:
2613         case MPOL_INTERLEAVE:
2614         case MPOL_PREFERRED:
2615         case MPOL_PREFERRED_MANY:
2616         case MPOL_WEIGHTED_INTERLEAVE:
2617                 return !!nodes_equal(a->nodes, b->nodes);
2618         case MPOL_LOCAL:
2619                 return true;
2620         default:
2621                 BUG();
2622                 return false;
2623         }
2624 }
2625
2626 /*
2627  * Shared memory backing store policy support.
2628  *
2629  * Remember policies even when nobody has shared memory mapped.
2630  * The policies are kept in Red-Black tree linked from the inode.
2631  * They are protected by the sp->lock rwlock, which should be held
2632  * for any accesses to the tree.
2633  */
2634
2635 /*
2636  * lookup first element intersecting start-end.  Caller holds sp->lock for
2637  * reading or for writing
2638  */
2639 static struct sp_node *sp_lookup(struct shared_policy *sp,
2640                                         pgoff_t start, pgoff_t end)
2641 {
2642         struct rb_node *n = sp->root.rb_node;
2643
2644         while (n) {
2645                 struct sp_node *p = rb_entry(n, struct sp_node, nd);
2646
2647                 if (start >= p->end)
2648                         n = n->rb_right;
2649                 else if (end <= p->start)
2650                         n = n->rb_left;
2651                 else
2652                         break;
2653         }
2654         if (!n)
2655                 return NULL;
2656         for (;;) {
2657                 struct sp_node *w = NULL;
2658                 struct rb_node *prev = rb_prev(n);
2659                 if (!prev)
2660                         break;
2661                 w = rb_entry(prev, struct sp_node, nd);
2662                 if (w->end <= start)
2663                         break;
2664                 n = prev;
2665         }
2666         return rb_entry(n, struct sp_node, nd);
2667 }
2668
2669 /*
2670  * Insert a new shared policy into the list.  Caller holds sp->lock for
2671  * writing.
2672  */
2673 static void sp_insert(struct shared_policy *sp, struct sp_node *new)
2674 {
2675         struct rb_node **p = &sp->root.rb_node;
2676         struct rb_node *parent = NULL;
2677         struct sp_node *nd;
2678
2679         while (*p) {
2680                 parent = *p;
2681                 nd = rb_entry(parent, struct sp_node, nd);
2682                 if (new->start < nd->start)
2683                         p = &(*p)->rb_left;
2684                 else if (new->end > nd->end)
2685                         p = &(*p)->rb_right;
2686                 else
2687                         BUG();
2688         }
2689         rb_link_node(&new->nd, parent, p);
2690         rb_insert_color(&new->nd, &sp->root);
2691 }
2692
2693 /* Find shared policy intersecting idx */
2694 struct mempolicy *mpol_shared_policy_lookup(struct shared_policy *sp,
2695                                                 pgoff_t idx)
2696 {
2697         struct mempolicy *pol = NULL;
2698         struct sp_node *sn;
2699
2700         if (!sp->root.rb_node)
2701                 return NULL;
2702         read_lock(&sp->lock);
2703         sn = sp_lookup(sp, idx, idx+1);
2704         if (sn) {
2705                 mpol_get(sn->policy);
2706                 pol = sn->policy;
2707         }
2708         read_unlock(&sp->lock);
2709         return pol;
2710 }
2711
2712 static void sp_free(struct sp_node *n)
2713 {
2714         mpol_put(n->policy);
2715         kmem_cache_free(sn_cache, n);
2716 }
2717
2718 /**
2719  * mpol_misplaced - check whether current folio node is valid in policy
2720  *
2721  * @folio: folio to be checked
2722  * @vmf: structure describing the fault
2723  * @addr: virtual address in @vma for shared policy lookup and interleave policy
2724  *
2725  * Lookup current policy node id for vma,addr and "compare to" folio's
2726  * node id.  Policy determination "mimics" alloc_page_vma().
2727  * Called from fault path where we know the vma and faulting address.
2728  *
2729  * Return: NUMA_NO_NODE if the page is in a node that is valid for this
2730  * policy, or a suitable node ID to allocate a replacement folio from.
2731  */
2732 int mpol_misplaced(struct folio *folio, struct vm_fault *vmf,
2733                    unsigned long addr)
2734 {
2735         struct mempolicy *pol;
2736         pgoff_t ilx;
2737         struct zoneref *z;
2738         int curnid = folio_nid(folio);
2739         struct vm_area_struct *vma = vmf->vma;
2740         int thiscpu = raw_smp_processor_id();
2741         int thisnid = numa_node_id();
2742         int polnid = NUMA_NO_NODE;
2743         int ret = NUMA_NO_NODE;
2744
2745         /*
2746          * Make sure ptl is held so that we don't preempt and we
2747          * have a stable smp processor id
2748          */
2749         lockdep_assert_held(vmf->ptl);
2750         pol = get_vma_policy(vma, addr, folio_order(folio), &ilx);
2751         if (!(pol->flags & MPOL_F_MOF))
2752                 goto out;
2753
2754         switch (pol->mode) {
2755         case MPOL_INTERLEAVE:
2756                 polnid = interleave_nid(pol, ilx);
2757                 break;
2758
2759         case MPOL_WEIGHTED_INTERLEAVE:
2760                 polnid = weighted_interleave_nid(pol, ilx);
2761                 break;
2762
2763         case MPOL_PREFERRED:
2764                 if (node_isset(curnid, pol->nodes))
2765                         goto out;
2766                 polnid = first_node(pol->nodes);
2767                 break;
2768
2769         case MPOL_LOCAL:
2770                 polnid = numa_node_id();
2771                 break;
2772
2773         case MPOL_BIND:
2774         case MPOL_PREFERRED_MANY:
2775                 /*
2776                  * Even though MPOL_PREFERRED_MANY can allocate pages outside
2777                  * policy nodemask we don't allow numa migration to nodes
2778                  * outside policy nodemask for now. This is done so that if we
2779                  * want demotion to slow memory to happen, before allocating
2780                  * from some DRAM node say 'x', we will end up using a
2781                  * MPOL_PREFERRED_MANY mask excluding node 'x'. In such scenario
2782                  * we should not promote to node 'x' from slow memory node.
2783                  */
2784                 if (pol->flags & MPOL_F_MORON) {
2785                         /*
2786                          * Optimize placement among multiple nodes
2787                          * via NUMA balancing
2788                          */
2789                         if (node_isset(thisnid, pol->nodes))
2790                                 break;
2791                         goto out;
2792                 }
2793
2794                 /*
2795                  * use current page if in policy nodemask,
2796                  * else select nearest allowed node, if any.
2797                  * If no allowed nodes, use current [!misplaced].
2798                  */
2799                 if (node_isset(curnid, pol->nodes))
2800                         goto out;
2801                 z = first_zones_zonelist(
2802                                 node_zonelist(thisnid, GFP_HIGHUSER),
2803                                 gfp_zone(GFP_HIGHUSER),
2804                                 &pol->nodes);
2805                 polnid = zone_to_nid(z->zone);
2806                 break;
2807
2808         default:
2809                 BUG();
2810         }
2811
2812         /* Migrate the folio towards the node whose CPU is referencing it */
2813         if (pol->flags & MPOL_F_MORON) {
2814                 polnid = thisnid;
2815
2816                 if (!should_numa_migrate_memory(current, folio, curnid,
2817                                                 thiscpu))
2818                         goto out;
2819         }
2820
2821         if (curnid != polnid)
2822                 ret = polnid;
2823 out:
2824         mpol_cond_put(pol);
2825
2826         return ret;
2827 }
2828
2829 /*
2830  * Drop the (possibly final) reference to task->mempolicy.  It needs to be
2831  * dropped after task->mempolicy is set to NULL so that any allocation done as
2832  * part of its kmem_cache_free(), such as by KASAN, doesn't reference a freed
2833  * policy.
2834  */
2835 void mpol_put_task_policy(struct task_struct *task)
2836 {
2837         struct mempolicy *pol;
2838
2839         task_lock(task);
2840         pol = task->mempolicy;
2841         task->mempolicy = NULL;
2842         task_unlock(task);
2843         mpol_put(pol);
2844 }
2845
2846 static void sp_delete(struct shared_policy *sp, struct sp_node *n)
2847 {
2848         rb_erase(&n->nd, &sp->root);
2849         sp_free(n);
2850 }
2851
2852 static void sp_node_init(struct sp_node *node, unsigned long start,
2853                         unsigned long end, struct mempolicy *pol)
2854 {
2855         node->start = start;
2856         node->end = end;
2857         node->policy = pol;
2858 }
2859
2860 static struct sp_node *sp_alloc(unsigned long start, unsigned long end,
2861                                 struct mempolicy *pol)
2862 {
2863         struct sp_node *n;
2864         struct mempolicy *newpol;
2865
2866         n = kmem_cache_alloc(sn_cache, GFP_KERNEL);
2867         if (!n)
2868                 return NULL;
2869
2870         newpol = mpol_dup(pol);
2871         if (IS_ERR(newpol)) {
2872                 kmem_cache_free(sn_cache, n);
2873                 return NULL;
2874         }
2875         newpol->flags |= MPOL_F_SHARED;
2876         sp_node_init(n, start, end, newpol);
2877
2878         return n;
2879 }
2880
2881 /* Replace a policy range. */
2882 static int shared_policy_replace(struct shared_policy *sp, pgoff_t start,
2883                                  pgoff_t end, struct sp_node *new)
2884 {
2885         struct sp_node *n;
2886         struct sp_node *n_new = NULL;
2887         struct mempolicy *mpol_new = NULL;
2888         int ret = 0;
2889
2890 restart:
2891         write_lock(&sp->lock);
2892         n = sp_lookup(sp, start, end);
2893         /* Take care of old policies in the same range. */
2894         while (n && n->start < end) {
2895                 struct rb_node *next = rb_next(&n->nd);
2896                 if (n->start >= start) {
2897                         if (n->end <= end)
2898                                 sp_delete(sp, n);
2899                         else
2900                                 n->start = end;
2901                 } else {
2902                         /* Old policy spanning whole new range. */
2903                         if (n->end > end) {
2904                                 if (!n_new)
2905                                         goto alloc_new;
2906
2907                                 *mpol_new = *n->policy;
2908                                 atomic_set(&mpol_new->refcnt, 1);
2909                                 sp_node_init(n_new, end, n->end, mpol_new);
2910                                 n->end = start;
2911                                 sp_insert(sp, n_new);
2912                                 n_new = NULL;
2913                                 mpol_new = NULL;
2914                                 break;
2915                         } else
2916                                 n->end = start;
2917                 }
2918                 if (!next)
2919                         break;
2920                 n = rb_entry(next, struct sp_node, nd);
2921         }
2922         if (new)
2923                 sp_insert(sp, new);
2924         write_unlock(&sp->lock);
2925         ret = 0;
2926
2927 err_out:
2928         if (mpol_new)
2929                 mpol_put(mpol_new);
2930         if (n_new)
2931                 kmem_cache_free(sn_cache, n_new);
2932
2933         return ret;
2934
2935 alloc_new:
2936         write_unlock(&sp->lock);
2937         ret = -ENOMEM;
2938         n_new = kmem_cache_alloc(sn_cache, GFP_KERNEL);
2939         if (!n_new)
2940                 goto err_out;
2941         mpol_new = kmem_cache_alloc(policy_cache, GFP_KERNEL);
2942         if (!mpol_new)
2943                 goto err_out;
2944         atomic_set(&mpol_new->refcnt, 1);
2945         goto restart;
2946 }
2947
2948 /**
2949  * mpol_shared_policy_init - initialize shared policy for inode
2950  * @sp: pointer to inode shared policy
2951  * @mpol:  struct mempolicy to install
2952  *
2953  * Install non-NULL @mpol in inode's shared policy rb-tree.
2954  * On entry, the current task has a reference on a non-NULL @mpol.
2955  * This must be released on exit.
2956  * This is called at get_inode() calls and we can use GFP_KERNEL.
2957  */
2958 void mpol_shared_policy_init(struct shared_policy *sp, struct mempolicy *mpol)
2959 {
2960         int ret;
2961
2962         sp->root = RB_ROOT;             /* empty tree == default mempolicy */
2963         rwlock_init(&sp->lock);
2964
2965         if (mpol) {
2966                 struct sp_node *sn;
2967                 struct mempolicy *npol;
2968                 NODEMASK_SCRATCH(scratch);
2969
2970                 if (!scratch)
2971                         goto put_mpol;
2972
2973                 /* contextualize the tmpfs mount point mempolicy to this file */
2974                 npol = mpol_new(mpol->mode, mpol->flags, &mpol->w.user_nodemask);
2975                 if (IS_ERR(npol))
2976                         goto free_scratch; /* no valid nodemask intersection */
2977
2978                 task_lock(current);
2979                 ret = mpol_set_nodemask(npol, &mpol->w.user_nodemask, scratch);
2980                 task_unlock(current);
2981                 if (ret)
2982                         goto put_npol;
2983
2984                 /* alloc node covering entire file; adds ref to file's npol */
2985                 sn = sp_alloc(0, MAX_LFS_FILESIZE >> PAGE_SHIFT, npol);
2986                 if (sn)
2987                         sp_insert(sp, sn);
2988 put_npol:
2989                 mpol_put(npol); /* drop initial ref on file's npol */
2990 free_scratch:
2991                 NODEMASK_SCRATCH_FREE(scratch);
2992 put_mpol:
2993                 mpol_put(mpol); /* drop our incoming ref on sb mpol */
2994         }
2995 }
2996
2997 int mpol_set_shared_policy(struct shared_policy *sp,
2998                         struct vm_area_struct *vma, struct mempolicy *pol)
2999 {
3000         int err;
3001         struct sp_node *new = NULL;
3002         unsigned long sz = vma_pages(vma);
3003
3004         if (pol) {
3005                 new = sp_alloc(vma->vm_pgoff, vma->vm_pgoff + sz, pol);
3006                 if (!new)
3007                         return -ENOMEM;
3008         }
3009         err = shared_policy_replace(sp, vma->vm_pgoff, vma->vm_pgoff + sz, new);
3010         if (err && new)
3011                 sp_free(new);
3012         return err;
3013 }
3014
3015 /* Free a backing policy store on inode delete. */
3016 void mpol_free_shared_policy(struct shared_policy *sp)
3017 {
3018         struct sp_node *n;
3019         struct rb_node *next;
3020
3021         if (!sp->root.rb_node)
3022                 return;
3023         write_lock(&sp->lock);
3024         next = rb_first(&sp->root);
3025         while (next) {
3026                 n = rb_entry(next, struct sp_node, nd);
3027                 next = rb_next(&n->nd);
3028                 sp_delete(sp, n);
3029         }
3030         write_unlock(&sp->lock);
3031 }
3032
3033 #ifdef CONFIG_NUMA_BALANCING
3034 static int __initdata numabalancing_override;
3035
3036 static void __init check_numabalancing_enable(void)
3037 {
3038         bool numabalancing_default = false;
3039
3040         if (IS_ENABLED(CONFIG_NUMA_BALANCING_DEFAULT_ENABLED))
3041                 numabalancing_default = true;
3042
3043         /* Parsed by setup_numabalancing. override == 1 enables, -1 disables */
3044         if (numabalancing_override)
3045                 set_numabalancing_state(numabalancing_override == 1);
3046
3047         if (num_online_nodes() > 1 && !numabalancing_override) {
3048                 pr_info("%s automatic NUMA balancing. Configure with numa_balancing= or the kernel.numa_balancing sysctl\n",
3049                         numabalancing_default ? "Enabling" : "Disabling");
3050                 set_numabalancing_state(numabalancing_default);
3051         }
3052 }
3053
3054 static int __init setup_numabalancing(char *str)
3055 {
3056         int ret = 0;
3057         if (!str)
3058                 goto out;
3059
3060         if (!strcmp(str, "enable")) {
3061                 numabalancing_override = 1;
3062                 ret = 1;
3063         } else if (!strcmp(str, "disable")) {
3064                 numabalancing_override = -1;
3065                 ret = 1;
3066         }
3067 out:
3068         if (!ret)
3069                 pr_warn("Unable to parse numa_balancing=\n");
3070
3071         return ret;
3072 }
3073 __setup("numa_balancing=", setup_numabalancing);
3074 #else
3075 static inline void __init check_numabalancing_enable(void)
3076 {
3077 }
3078 #endif /* CONFIG_NUMA_BALANCING */
3079
3080 void __init numa_policy_init(void)
3081 {
3082         nodemask_t interleave_nodes;
3083         unsigned long largest = 0;
3084         int nid, prefer = 0;
3085
3086         policy_cache = kmem_cache_create("numa_policy",
3087                                          sizeof(struct mempolicy),
3088                                          0, SLAB_PANIC, NULL);
3089
3090         sn_cache = kmem_cache_create("shared_policy_node",
3091                                      sizeof(struct sp_node),
3092                                      0, SLAB_PANIC, NULL);
3093
3094         for_each_node(nid) {
3095                 preferred_node_policy[nid] = (struct mempolicy) {
3096                         .refcnt = ATOMIC_INIT(1),
3097                         .mode = MPOL_PREFERRED,
3098                         .flags = MPOL_F_MOF | MPOL_F_MORON,
3099                         .nodes = nodemask_of_node(nid),
3100                 };
3101         }
3102
3103         /*
3104          * Set interleaving policy for system init. Interleaving is only
3105          * enabled across suitably sized nodes (default is >= 16MB), or
3106          * fall back to the largest node if they're all smaller.
3107          */
3108         nodes_clear(interleave_nodes);
3109         for_each_node_state(nid, N_MEMORY) {
3110                 unsigned long total_pages = node_present_pages(nid);
3111
3112                 /* Preserve the largest node */
3113                 if (largest < total_pages) {
3114                         largest = total_pages;
3115                         prefer = nid;
3116                 }
3117
3118                 /* Interleave this node? */
3119                 if ((total_pages << PAGE_SHIFT) >= (16 << 20))
3120                         node_set(nid, interleave_nodes);
3121         }
3122
3123         /* All too small, use the largest */
3124         if (unlikely(nodes_empty(interleave_nodes)))
3125                 node_set(prefer, interleave_nodes);
3126
3127         if (do_set_mempolicy(MPOL_INTERLEAVE, 0, &interleave_nodes))
3128                 pr_err("%s: interleaving failed\n", __func__);
3129
3130         check_numabalancing_enable();
3131 }
3132
3133 /* Reset policy of current process to default */
3134 void numa_default_policy(void)
3135 {
3136         do_set_mempolicy(MPOL_DEFAULT, 0, NULL);
3137 }
3138
3139 /*
3140  * Parse and format mempolicy from/to strings
3141  */
3142 static const char * const policy_modes[] =
3143 {
3144         [MPOL_DEFAULT]    = "default",
3145         [MPOL_PREFERRED]  = "prefer",
3146         [MPOL_BIND]       = "bind",
3147         [MPOL_INTERLEAVE] = "interleave",
3148         [MPOL_WEIGHTED_INTERLEAVE] = "weighted interleave",
3149         [MPOL_LOCAL]      = "local",
3150         [MPOL_PREFERRED_MANY]  = "prefer (many)",
3151 };
3152
3153 #ifdef CONFIG_TMPFS
3154 /**
3155  * mpol_parse_str - parse string to mempolicy, for tmpfs mpol mount option.
3156  * @str:  string containing mempolicy to parse
3157  * @mpol:  pointer to struct mempolicy pointer, returned on success.
3158  *
3159  * Format of input:
3160  *      <mode>[=<flags>][:<nodelist>]
3161  *
3162  * Return: %0 on success, else %1
3163  */
3164 int mpol_parse_str(char *str, struct mempolicy **mpol)
3165 {
3166         struct mempolicy *new = NULL;
3167         unsigned short mode_flags;
3168         nodemask_t nodes;
3169         char *nodelist = strchr(str, ':');
3170         char *flags = strchr(str, '=');
3171         int err = 1, mode;
3172
3173         if (flags)
3174                 *flags++ = '\0';        /* terminate mode string */
3175
3176         if (nodelist) {
3177                 /* NUL-terminate mode or flags string */
3178                 *nodelist++ = '\0';
3179                 if (nodelist_parse(nodelist, nodes))
3180                         goto out;
3181                 if (!nodes_subset(nodes, node_states[N_MEMORY]))
3182                         goto out;
3183         } else
3184                 nodes_clear(nodes);
3185
3186         mode = match_string(policy_modes, MPOL_MAX, str);
3187         if (mode < 0)
3188                 goto out;
3189
3190         switch (mode) {
3191         case MPOL_PREFERRED:
3192                 /*
3193                  * Insist on a nodelist of one node only, although later
3194                  * we use first_node(nodes) to grab a single node, so here
3195                  * nodelist (or nodes) cannot be empty.
3196                  */
3197                 if (nodelist) {
3198                         char *rest = nodelist;
3199                         while (isdigit(*rest))
3200                                 rest++;
3201                         if (*rest)
3202                                 goto out;
3203                         if (nodes_empty(nodes))
3204                                 goto out;
3205                 }
3206                 break;
3207         case MPOL_INTERLEAVE:
3208         case MPOL_WEIGHTED_INTERLEAVE:
3209                 /*
3210                  * Default to online nodes with memory if no nodelist
3211                  */
3212                 if (!nodelist)
3213                         nodes = node_states[N_MEMORY];
3214                 break;
3215         case MPOL_LOCAL:
3216                 /*
3217                  * Don't allow a nodelist;  mpol_new() checks flags
3218                  */
3219                 if (nodelist)
3220                         goto out;
3221                 break;
3222         case MPOL_DEFAULT:
3223                 /*
3224                  * Insist on a empty nodelist
3225                  */
3226                 if (!nodelist)
3227                         err = 0;
3228                 goto out;
3229         case MPOL_PREFERRED_MANY:
3230         case MPOL_BIND:
3231                 /*
3232                  * Insist on a nodelist
3233                  */
3234                 if (!nodelist)
3235                         goto out;
3236         }
3237
3238         mode_flags = 0;
3239         if (flags) {
3240                 /*
3241                  * Currently, we only support two mutually exclusive
3242                  * mode flags.
3243                  */
3244                 if (!strcmp(flags, "static"))
3245                         mode_flags |= MPOL_F_STATIC_NODES;
3246                 else if (!strcmp(flags, "relative"))
3247                         mode_flags |= MPOL_F_RELATIVE_NODES;
3248                 else
3249                         goto out;
3250         }
3251
3252         new = mpol_new(mode, mode_flags, &nodes);
3253         if (IS_ERR(new))
3254                 goto out;
3255
3256         /*
3257          * Save nodes for mpol_to_str() to show the tmpfs mount options
3258          * for /proc/mounts, /proc/pid/mounts and /proc/pid/mountinfo.
3259          */
3260         if (mode != MPOL_PREFERRED) {
3261                 new->nodes = nodes;
3262         } else if (nodelist) {
3263                 nodes_clear(new->nodes);
3264                 node_set(first_node(nodes), new->nodes);
3265         } else {
3266                 new->mode = MPOL_LOCAL;
3267         }
3268
3269         /*
3270          * Save nodes for contextualization: this will be used to "clone"
3271          * the mempolicy in a specific context [cpuset] at a later time.
3272          */
3273         new->w.user_nodemask = nodes;
3274
3275         err = 0;
3276
3277 out:
3278         /* Restore string for error message */
3279         if (nodelist)
3280                 *--nodelist = ':';
3281         if (flags)
3282                 *--flags = '=';
3283         if (!err)
3284                 *mpol = new;
3285         return err;
3286 }
3287 #endif /* CONFIG_TMPFS */
3288
3289 /**
3290  * mpol_to_str - format a mempolicy structure for printing
3291  * @buffer:  to contain formatted mempolicy string
3292  * @maxlen:  length of @buffer
3293  * @pol:  pointer to mempolicy to be formatted
3294  *
3295  * Convert @pol into a string.  If @buffer is too short, truncate the string.
3296  * Recommend a @maxlen of at least 32 for the longest mode, "interleave", the
3297  * longest flag, "relative", and to display at least a few node ids.
3298  */
3299 void mpol_to_str(char *buffer, int maxlen, struct mempolicy *pol)
3300 {
3301         char *p = buffer;
3302         nodemask_t nodes = NODE_MASK_NONE;
3303         unsigned short mode = MPOL_DEFAULT;
3304         unsigned short flags = 0;
3305
3306         if (pol && pol != &default_policy && !(pol->flags & MPOL_F_MORON)) {
3307                 mode = pol->mode;
3308                 flags = pol->flags;
3309         }
3310
3311         switch (mode) {
3312         case MPOL_DEFAULT:
3313         case MPOL_LOCAL:
3314                 break;
3315         case MPOL_PREFERRED:
3316         case MPOL_PREFERRED_MANY:
3317         case MPOL_BIND:
3318         case MPOL_INTERLEAVE:
3319         case MPOL_WEIGHTED_INTERLEAVE:
3320                 nodes = pol->nodes;
3321                 break;
3322         default:
3323                 WARN_ON_ONCE(1);
3324                 snprintf(p, maxlen, "unknown");
3325                 return;
3326         }
3327
3328         p += snprintf(p, maxlen, "%s", policy_modes[mode]);
3329
3330         if (flags & MPOL_MODE_FLAGS) {
3331                 p += snprintf(p, buffer + maxlen - p, "=");
3332
3333                 /*
3334                  * Currently, the only defined flags are mutually exclusive
3335                  */
3336                 if (flags & MPOL_F_STATIC_NODES)
3337                         p += snprintf(p, buffer + maxlen - p, "static");
3338                 else if (flags & MPOL_F_RELATIVE_NODES)
3339                         p += snprintf(p, buffer + maxlen - p, "relative");
3340         }
3341
3342         if (!nodes_empty(nodes))
3343                 p += scnprintf(p, buffer + maxlen - p, ":%*pbl",
3344                                nodemask_pr_args(&nodes));
3345 }
3346
3347 #ifdef CONFIG_SYSFS
3348 struct iw_node_attr {
3349         struct kobj_attribute kobj_attr;
3350         int nid;
3351 };
3352
3353 static ssize_t node_show(struct kobject *kobj, struct kobj_attribute *attr,
3354                          char *buf)
3355 {
3356         struct iw_node_attr *node_attr;
3357         u8 weight;
3358
3359         node_attr = container_of(attr, struct iw_node_attr, kobj_attr);
3360         weight = get_il_weight(node_attr->nid);
3361         return sysfs_emit(buf, "%d\n", weight);
3362 }
3363
3364 static ssize_t node_store(struct kobject *kobj, struct kobj_attribute *attr,
3365                           const char *buf, size_t count)
3366 {
3367         struct iw_node_attr *node_attr;
3368         u8 *new;
3369         u8 *old;
3370         u8 weight = 0;
3371
3372         node_attr = container_of(attr, struct iw_node_attr, kobj_attr);
3373         if (count == 0 || sysfs_streq(buf, ""))
3374                 weight = 0;
3375         else if (kstrtou8(buf, 0, &weight))
3376                 return -EINVAL;
3377
3378         new = kzalloc(nr_node_ids, GFP_KERNEL);
3379         if (!new)
3380                 return -ENOMEM;
3381
3382         mutex_lock(&iw_table_lock);
3383         old = rcu_dereference_protected(iw_table,
3384                                         lockdep_is_held(&iw_table_lock));
3385         if (old)
3386                 memcpy(new, old, nr_node_ids);
3387         new[node_attr->nid] = weight;
3388         rcu_assign_pointer(iw_table, new);
3389         mutex_unlock(&iw_table_lock);
3390         synchronize_rcu();
3391         kfree(old);
3392         return count;
3393 }
3394
3395 static struct iw_node_attr **node_attrs;
3396
3397 static void sysfs_wi_node_release(struct iw_node_attr *node_attr,
3398                                   struct kobject *parent)
3399 {
3400         if (!node_attr)
3401                 return;
3402         sysfs_remove_file(parent, &node_attr->kobj_attr.attr);
3403         kfree(node_attr->kobj_attr.attr.name);
3404         kfree(node_attr);
3405 }
3406
3407 static void sysfs_wi_release(struct kobject *wi_kobj)
3408 {
3409         int i;
3410
3411         for (i = 0; i < nr_node_ids; i++)
3412                 sysfs_wi_node_release(node_attrs[i], wi_kobj);
3413         kobject_put(wi_kobj);
3414 }
3415
3416 static const struct kobj_type wi_ktype = {
3417         .sysfs_ops = &kobj_sysfs_ops,
3418         .release = sysfs_wi_release,
3419 };
3420
3421 static int add_weight_node(int nid, struct kobject *wi_kobj)
3422 {
3423         struct iw_node_attr *node_attr;
3424         char *name;
3425
3426         node_attr = kzalloc(sizeof(*node_attr), GFP_KERNEL);
3427         if (!node_attr)
3428                 return -ENOMEM;
3429
3430         name = kasprintf(GFP_KERNEL, "node%d", nid);
3431         if (!name) {
3432                 kfree(node_attr);
3433                 return -ENOMEM;
3434         }
3435
3436         sysfs_attr_init(&node_attr->kobj_attr.attr);
3437         node_attr->kobj_attr.attr.name = name;
3438         node_attr->kobj_attr.attr.mode = 0644;
3439         node_attr->kobj_attr.show = node_show;
3440         node_attr->kobj_attr.store = node_store;
3441         node_attr->nid = nid;
3442
3443         if (sysfs_create_file(wi_kobj, &node_attr->kobj_attr.attr)) {
3444                 kfree(node_attr->kobj_attr.attr.name);
3445                 kfree(node_attr);
3446                 pr_err("failed to add attribute to weighted_interleave\n");
3447                 return -ENOMEM;
3448         }
3449
3450         node_attrs[nid] = node_attr;
3451         return 0;
3452 }
3453
3454 static int add_weighted_interleave_group(struct kobject *root_kobj)
3455 {
3456         struct kobject *wi_kobj;
3457         int nid, err;
3458
3459         wi_kobj = kzalloc(sizeof(struct kobject), GFP_KERNEL);
3460         if (!wi_kobj)
3461                 return -ENOMEM;
3462
3463         err = kobject_init_and_add(wi_kobj, &wi_ktype, root_kobj,
3464                                    "weighted_interleave");
3465         if (err) {
3466                 kfree(wi_kobj);
3467                 return err;
3468         }
3469
3470         for_each_node_state(nid, N_POSSIBLE) {
3471                 err = add_weight_node(nid, wi_kobj);
3472                 if (err) {
3473                         pr_err("failed to add sysfs [node%d]\n", nid);
3474                         break;
3475                 }
3476         }
3477         if (err)
3478                 kobject_put(wi_kobj);
3479         return 0;
3480 }
3481
3482 static void mempolicy_kobj_release(struct kobject *kobj)
3483 {
3484         u8 *old;
3485
3486         mutex_lock(&iw_table_lock);
3487         old = rcu_dereference_protected(iw_table,
3488                                         lockdep_is_held(&iw_table_lock));
3489         rcu_assign_pointer(iw_table, NULL);
3490         mutex_unlock(&iw_table_lock);
3491         synchronize_rcu();
3492         kfree(old);
3493         kfree(node_attrs);
3494         kfree(kobj);
3495 }
3496
3497 static const struct kobj_type mempolicy_ktype = {
3498         .release = mempolicy_kobj_release
3499 };
3500
3501 static int __init mempolicy_sysfs_init(void)
3502 {
3503         int err;
3504         static struct kobject *mempolicy_kobj;
3505
3506         mempolicy_kobj = kzalloc(sizeof(*mempolicy_kobj), GFP_KERNEL);
3507         if (!mempolicy_kobj) {
3508                 err = -ENOMEM;
3509                 goto err_out;
3510         }
3511
3512         node_attrs = kcalloc(nr_node_ids, sizeof(struct iw_node_attr *),
3513                              GFP_KERNEL);
3514         if (!node_attrs) {
3515                 err = -ENOMEM;
3516                 goto mempol_out;
3517         }
3518
3519         err = kobject_init_and_add(mempolicy_kobj, &mempolicy_ktype, mm_kobj,
3520                                    "mempolicy");
3521         if (err)
3522                 goto node_out;
3523
3524         err = add_weighted_interleave_group(mempolicy_kobj);
3525         if (err) {
3526                 pr_err("mempolicy sysfs structure failed to initialize\n");
3527                 kobject_put(mempolicy_kobj);
3528                 return err;
3529         }
3530
3531         return err;
3532 node_out:
3533         kfree(node_attrs);
3534 mempol_out:
3535         kfree(mempolicy_kobj);
3536 err_out:
3537         pr_err("failed to add mempolicy kobject to the system\n");
3538         return err;
3539 }
3540
3541 late_initcall(mempolicy_sysfs_init);
3542 #endif /* CONFIG_SYSFS */