Comment utiliser la fonction NumPy argmax() en Python

Dans ce didacticiel, vous apprendrez à utiliser la fonction NumPy argmax() pour trouver l’index de l’élément maximum dans les tableaux.

NumPy est une bibliothèque puissante pour le calcul scientifique en Python ; il fournit des tableaux à N dimensions plus performants que les listes Python. L’une des opérations courantes que vous effectuerez lorsque vous travaillerez avec des tableaux NumPy consiste à trouver la valeur maximale dans le tableau. Cependant, vous souhaiterez peut-être parfois rechercher l’index auquel la valeur maximale se produit.

La fonction argmax() vous aide à trouver l’indice du maximum dans les tableaux unidimensionnels et multidimensionnels. Continuons pour apprendre comment cela fonctionne.

Comment trouver l’index de l’élément maximum dans un tableau NumPy

Pour suivre ce tutoriel, vous devez avoir installé Python et NumPy. Vous pouvez coder en démarrant un REPL Python ou en lançant un notebook Jupyter.

Tout d’abord, importons NumPy sous l’alias habituel np.

import numpy as np

Vous pouvez utiliser la fonction NumPy max() pour obtenir la valeur maximale dans un tableau (éventuellement le long d’un axe spécifique).

array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.max(array_1))

# Output
10

Dans ce cas, np.max(array_1) renvoie 10, ce qui est correct.

Supposons que vous souhaitiez trouver l’index auquel la valeur maximale se produit dans le tableau. Vous pouvez adopter l’approche en deux étapes suivante :

  • Trouvez l’élément maximum.
  • Trouver l’indice de l’élément maximum.
  • Dans array_1, la valeur maximale de 10 se produit à l’index 4, après l’indexation zéro. Le premier élément est à l’index 0 ; le deuxième élément est à l’indice 1, et ainsi de suite.

    Pour trouver l’index auquel le maximum se produit, vous pouvez utiliser la fonction NumPy where(). np.where(condition) renvoie un tableau de tous les indices où la condition est vraie.

    Vous devrez puiser dans le tableau et accéder à l’élément au premier index. Pour trouver où se trouve la valeur maximale, nous définissons la condition sur array_1==10 ; rappelez-vous que 10 est la valeur maximale dans array_1.

    print(int(np.where(array_1==10)[0]))
    
    # Output
    4

    Nous avons utilisé np.where() avec uniquement la condition, mais ce n’est pas la méthode recommandée pour utiliser cette fonction.

    📑 Remarque : Fonction NumPy where() :
    np.where(condition,x,y) renvoie :

    – Éléments de x lorsque la condition est vraie, et
    – Éléments de y lorsque la condition est fausse.

    Par conséquent, en enchaînant les fonctions np.max() et np.where(), nous pouvons trouver l’élément maximum, suivi de l’indice auquel il se produit.

    Au lieu du processus en deux étapes ci-dessus, vous pouvez utiliser la fonction NumPy argmax() pour obtenir l’index de l’élément maximum dans le tableau.

    Syntaxe de la fonction NumPy argmax()

    La syntaxe générale pour utiliser la fonction NumPy argmax() est la suivante :

    np.argmax(array,axis,out)
    # we've imported numpy under the alias np

    Dans la syntaxe ci-dessus :

    • array est n’importe quel tableau NumPy valide.
    • l’axe est un paramètre facultatif. Lorsque vous travaillez avec des tableaux multidimensionnels, vous pouvez utiliser le paramètre d’axe pour trouver l’indice du maximum le long d’un axe spécifique.
    • out est un autre paramètre facultatif. Vous pouvez définir le paramètre out sur un tableau NumPy pour stocker la sortie de la fonction argmax().

    Remarque : à partir de la version 1.22.0 de NumPy, il existe un paramètre keepdims supplémentaire. Lorsque nous spécifions le paramètre axis dans l’appel de la fonction argmax(), le tableau est réduit le long de cet axe. Mais définir le paramètre keepdims sur True garantit que la sortie renvoyée a la même forme que le tableau d’entrée.

    Utilisation de NumPy argmax() pour trouver l’index de l’élément maximum

    #1. Utilisons la fonction NumPy argmax() pour trouver l’index de l’élément maximum dans array_1.

    array_1 = np.array([1,5,7,2,10,9,8,4])
    print(np.argmax(array_1))
    
    # Output
    4

    La fonction argmax() renvoie 4, ce qui est correct ! ✅

    #2. Si nous redéfinissons array_1 de sorte que 10 se produise deux fois, la fonction argmax() renvoie uniquement l’indice de la première occurrence.

    array_1 = np.array([1,5,7,2,10,10,8,4])
    print(np.argmax(array_1))
    
    # Output
    4

    Pour le reste des exemples, nous utiliserons les éléments de array_1 que nous avons définis dans l’exemple #1.

    Utilisation de NumPy argmax() pour trouver l’index de l’élément maximum dans un tableau 2D

    Remodelons le tableau NumPy array_1 en un tableau à deux dimensions avec deux lignes et quatre colonnes.

    array_2 = array_1.reshape(2,4)
    print(array_2)
    
    # Output
    [[ 1  5  7  2]
     [10  9  8  4]]

    Pour un tableau à deux dimensions, l’axe 0 désigne les lignes et l’axe 1 désigne les colonnes. Les tableaux NumPy suivent l’indexation zéro. Ainsi, les indices des lignes et des colonnes du tableau NumPy array_2 sont les suivants :

    Maintenant, appelons la fonction argmax() sur le tableau à deux dimensions, array_2.

    print(np.argmax(array_2))
    
    # Output
    4

    Même si nous avons appelé argmax() sur le tableau à deux dimensions, il renvoie toujours 4. Ceci est identique à la sortie du tableau à une dimension, array_1 de la section précédente.

    Pourquoi cela arrive-t-il? 🤔

    C’est parce que nous n’avons spécifié aucune valeur pour le paramètre d’axe. Lorsque ce paramètre d’axe n’est pas défini, par défaut, la fonction argmax() renvoie l’index de l’élément maximum le long du tableau aplati.

    Qu’est-ce qu’un tableau aplati ? S’il existe un tableau à N dimensions de forme d1 x d2 x … x dN, où d1, d2, jusqu’à dN sont les tailles du tableau le long des N dimensions, alors le tableau aplati est un long tableau unidimensionnel de taille d1 * d2 * … * dN.

    Pour vérifier à quoi ressemble le tableau aplati pour array_2, vous pouvez appeler la méthode flatten(), comme indiqué ci-dessous :

    array_2.flatten()
    
    # Output
    array([ 1,  5,  7,  2, 10,  9,  8,  4])

    Index de l’élément maximum le long des lignes (axe = 0)

    Continuons à trouver l’indice de l’élément maximum le long des lignes (axe = 0).

    np.argmax(array_2,axis=0)
    
    # Output
    array([1, 1, 1, 1])

    Cette sortie peut être un peu difficile à comprendre, mais nous comprendrons comment cela fonctionne.

    Nous avons défini le paramètre d’axe sur zéro (axe = 0), car nous aimerions trouver l’indice de l’élément maximum le long des lignes. Par conséquent, la fonction argmax() renvoie le numéro de la ligne dans laquelle l’élément maximal apparaît, pour chacune des trois colonnes.

    Visualisons cela pour mieux comprendre.

    D’après le diagramme ci-dessus et la sortie argmax(), nous avons ce qui suit :

    • Pour la première colonne à l’index 0, la valeur maximale 10 apparaît dans la deuxième ligne, à l’index = 1.
    • Pour la deuxième colonne à l’index 1, la valeur maximale 9 apparaît dans la deuxième ligne, à l’index = 1.
    • Pour les troisième et quatrième colonnes aux index 2 et 3, les valeurs maximales 8 et 4 apparaissent toutes les deux dans la deuxième ligne, à l’index = 1.

    C’est précisément pourquoi nous avons le tableau de sortie ([1, 1, 1, 1]) car l’élément maximal le long des lignes se produit dans la deuxième ligne (pour toutes les colonnes).

    Index de l’élément maximum le long des colonnes (axe = 1)

    Ensuite, utilisons la fonction argmax() pour trouver l’index de l’élément maximum le long des colonnes.

    Exécutez l’extrait de code suivant et observez la sortie.

    np.argmax(array_2,axis=1)
    array([2, 0])

    Pouvez-vous analyser la sortie ?

    Nous avons défini axe = 1 pour calculer l’indice de l’élément maximum le long des colonnes.

    La fonction argmax() renvoie, pour chaque ligne, le numéro de colonne dans laquelle se trouve la valeur maximale.

    Voici une explication visuelle :

    D’après le diagramme ci-dessus et la sortie argmax(), nous avons ce qui suit :

    • Pour la première ligne à l’index 0, la valeur maximale 7 apparaît dans la troisième colonne, à l’index = 2.
    • Pour la deuxième ligne à l’index 1, la valeur maximale 10 apparaît dans la première colonne, à l’index = 0.

    J’espère que vous comprenez maintenant ce que la sortie, array([2, 0]) moyens.

    Utilisation du paramètre optionnel out dans NumPy argmax()

    Vous pouvez utiliser le paramètre optionnel out the dans la fonction NumPy argmax() pour stocker la sortie dans un tableau NumPy.

    Initialisons un tableau de zéros pour stocker la sortie de l’appel de fonction argmax() précédent – pour trouver l’indice du maximum le long des colonnes (axe = 1).

    out_arr = np.zeros((2,))
    print(out_arr)
    [0. 0.]

    Maintenant, reprenons l’exemple de trouver l’indice de l’élément maximum le long des colonnes (axe = 1) et définissons out sur out_arr que nous avons défini ci-dessus.

    np.argmax(array_2,axis=1,out=out_arr)

    Nous voyons que l’interpréteur Python génère une TypeError, car le out_arr a été initialisé à un tableau de flottants par défaut.

    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    /usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
         56     try:
    ---> 57         return bound(*args, **kwds)
         58     except TypeError:
    
    TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'

    Par conséquent, lors de la définition du paramètre out dans le tableau de sortie, il est important de s’assurer que le tableau de sortie a la forme et le type de données corrects. Comme les indices de tableau sont toujours des entiers, nous devons définir le paramètre dtype sur int lors de la définition du tableau de sortie.

    out_arr = np.zeros((2,),dtype=int)
    print(out_arr)
    
    # Output
    [0 0]

    Nous pouvons maintenant continuer et appeler la fonction argmax() avec les paramètres axis et out, et cette fois, elle s’exécute sans erreur.

    np.argmax(array_2,axis=1,out=out_arr)

    La sortie de la fonction argmax() est maintenant accessible dans le tableau out_arr.

    print(out_arr)
    # Output
    [2 0]

    Conclusion

    J’espère que ce tutoriel vous a aidé à comprendre comment utiliser la fonction NumPy argmax(). Vous pouvez exécuter les exemples de code dans un notebook Jupyter.

    Passons en revue ce que nous avons appris.

    • La fonction NumPy argmax() renvoie l’index de l’élément maximum dans un tableau. Si l’élément maximum apparaît plus d’une fois dans un tableau a, alors np.argmax(a) renvoie l’indice de la première occurrence de l’élément.
    • Lorsque vous travaillez avec des tableaux multidimensionnels, vous pouvez utiliser le paramètre d’axe facultatif pour obtenir l’indice de l’élément maximal le long d’un axe particulier. Par exemple, dans un tableau à deux dimensions : en définissant axe = 0 et axe = 1, vous pouvez obtenir l’indice de l’élément maximum le long des lignes et des colonnes, respectivement.
    • Si vous souhaitez stocker la valeur renvoyée dans un autre tableau, vous pouvez définir le paramètre optionnel out sur le tableau de sortie. Cependant, le tableau de sortie doit avoir une forme et un type de données compatibles.

    Ensuite, consultez le guide détaillé sur les ensembles Python.