Dans ce guide, vous découvrirez comment utiliser la fonction argmax()
de NumPy pour identifier l’index de l’élément le plus élevé dans un tableau.
NumPy est une librairie essentielle pour le calcul scientifique en Python. Elle offre des tableaux multidimensionnels qui surclassent les listes Python en termes de performance. Une tâche courante lors de la manipulation de tableaux NumPy est de trouver la valeur maximale. Il est cependant parfois nécessaire d’obtenir l’index où cette valeur maximale se situe.
La fonction argmax()
est idéale pour cela, que ce soit dans des tableaux à une ou plusieurs dimensions. Découvrons ensemble son fonctionnement.
Comment localiser l’index de la valeur maximale dans un tableau NumPy
Pour ce tutoriel, Python et NumPy doivent être installés. Vous pouvez démarrer un interpréteur Python REPL ou un notebook Jupyter.
Commençons par importer NumPy en utilisant son alias habituel np
.
import numpy as np
La fonction max()
de NumPy permet d’obtenir la valeur la plus élevée d’un tableau, potentiellement le long d’un axe spécifié.
array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.max(array_1))
# Résultat
10
Ici, np.max(array_1)
retourne 10, ce qui est correct.
Imaginons que vous souhaitiez savoir à quel index se trouve cette valeur maximale. Une approche possible serait :
- Identifier l’élément maximum.
- Trouver l’index de cet élément.
Dans array_1
, la valeur maximale de 10 se trouve à l’index 4 (en commençant l’indexation à zéro). Le premier élément est à l’index 0, le deuxième à l’index 1, et ainsi de suite.
Pour repérer l’index de cette valeur maximale, la fonction where()
de NumPy est utile. np.where(condition)
renvoie un tableau des index où la condition est vraie.
Il faut accéder au premier élément du tableau résultant. Pour identifier l’emplacement de la valeur maximale, la condition est array_1==10
(puisque 10 est la valeur maximale de array_1
).
print(int(np.where(array_1==10)[0]))
# Résultat
4
Bien que np.where()
fonctionne avec la condition seule, ce n’est pas son utilisation la plus courante.
📑 Note : Fonction where()
de NumPy :np.where(condition,x,y)
retourne :
- Les éléments de
x
où la condition est vraie, et - Les éléments de
y
où la condition est fausse.
Ainsi, en combinant np.max()
et np.where()
, on pourrait trouver la valeur maximale et son index.
Toutefois, la fonction argmax()
de NumPy permet d’obtenir directement l’index de l’élément maximal.
Syntaxe de la fonction argmax()
de NumPy
La syntaxe générale de argmax()
est :
np.argmax(array,axis,out)
# numpy est importé avec l'alias np
Dans cette syntaxe :
array
est un tableau NumPy valide.axis
(optionnel) permet de chercher l’index du maximum selon un axe spécifique dans les tableaux multidimensionnels.out
(optionnel) permet de stocker le résultat deargmax()
dans un tableau NumPy.
Remarque : depuis la version 1.22.0 de NumPy, un paramètre keepdims
est disponible. Quand axis
est spécifié, le tableau est réduit le long de cet axe. Si keepdims
est mis à True
, la sortie conserve la forme du tableau d’entrée.
Utilisation de argmax()
pour trouver l’index du maximum
#1. Utilisons argmax()
pour identifier l’index du maximum dans array_1
.
array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.argmax(array_1))
# Résultat
4
argmax()
retourne 4, ce qui est juste ! ✅
#2. Si 10 apparaît deux fois dans array_1
, argmax()
retournera l’index de la première occurrence.
array_1 = np.array([1,5,7,2,10,10,8,4])
print(np.argmax(array_1))
# Résultat
4
Pour la suite, nous utiliserons les données de array_1
de l’exemple #1.
argmax()
dans un tableau 2D
Redimensionnons array_1
en un tableau 2D (2 lignes et 4 colonnes).
array_2 = array_1.reshape(2,4)
print(array_2)
# Résultat
[[ 1 5 7 2]
[10 9 8 4]]
Dans un tableau 2D, l’axe 0 correspond aux lignes et l’axe 1 aux colonnes. Les index commencent à zéro. Donc, les index de array_2
sont :
Appliquons maintenant argmax()
à array_2
.
print(np.argmax(array_2))
# Résultat
4
Même appliqué à un tableau 2D, argmax()
retourne 4, comme pour array_1
. Pourquoi ? 🤔
Car le paramètre axis
n’a pas été spécifié. Par défaut, argmax()
retourne l’index de l’élément maximal sur le tableau « aplati ».
Qu’est-ce qu’un tableau aplati ? Si un tableau N-dimensionnel a pour forme d1 x d2 x … x dN, le tableau aplati est un long tableau unidimensionnel de taille d1 * d2 * … * dN.
Pour voir le tableau aplati de array_2
, utilisez la méthode flatten()
:
array_2.flatten()
# Résultat
array([ 1, 5, 7, 2, 10, 9, 8, 4])
Index du maximum le long des lignes (axis = 0
)
Recherchons l’index du maximum le long des lignes (axis = 0
).
np.argmax(array_2,axis=0)
# Résultat
array([1, 1, 1, 1])
Cette sortie peut paraître déroutante, décryptons-la.
axis = 0
indique que nous cherchons l’index de la ligne où l’élément maximal apparaît, pour chaque colonne.
Visualisons cela pour une meilleure compréhension.
D’après le schéma et le résultat de argmax()
:
- Dans la première colonne (index 0), le maximum (10) est sur la deuxième ligne (index 1).
- Dans la deuxième colonne (index 1), le maximum (9) est aussi sur la deuxième ligne (index 1).
- Dans les troisième et quatrième colonnes (index 2 et 3), les maximums (8 et 4) sont également sur la deuxième ligne (index 1).
D’où le résultat ([1, 1, 1, 1])
, car le maximum le long des lignes est toujours sur la deuxième ligne (pour chaque colonne).
Index du maximum le long des colonnes (axis = 1
)
Cherchons maintenant l’index du maximum le long des colonnes en utilisant argmax()
.
Exécutez ce code et observez le résultat.
np.argmax(array_2,axis=1)
array([2, 0])
Pouvez-vous interpréter cette sortie ?
axis = 1
signifie que nous calculons l’index de l’élément maximum pour chaque ligne.
argmax()
retourne le numéro de la colonne contenant la valeur maximale, pour chaque ligne.
Voici une explication visuelle :
D’après le schéma et le résultat de argmax()
:
- Dans la première ligne (index 0), le maximum (7) est sur la troisième colonne (index 2).
- Dans la deuxième ligne (index 1), le maximum (10) est sur la première colonne (index 0).
Le résultat array([2, 0])
est donc compréhensible.
Utilisation du paramètre optionnel out
de argmax()
Le paramètre out
permet de stocker le résultat de argmax()
dans un tableau NumPy.
Initialisons un tableau de zéros pour stocker le résultat de l’exemple précédent (index du maximum le long des colonnes, axis = 1
).
out_arr = np.zeros((2,))
print(out_arr)
[0. 0.]
Réutilisons l’exemple de l’index du maximum le long des colonnes (axis = 1
) et définissons out
sur out_arr
.
np.argmax(array_2,axis=1,out=out_arr)
Python affiche une TypeError
car out_arr
est initialisé comme un tableau de flottants.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
56 try:
---> 57 return bound(*args, **kwds)
58 except TypeError:
TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'
Il est crucial que le tableau out
ait la forme et le type de données appropriés. Puisque les index sont toujours des entiers, il faut définir dtype
à int
lors de la création de out_arr
.
out_arr = np.zeros((2,),dtype=int)
print(out_arr)
# Résultat
[0 0]
Appelons maintenant argmax()
avec axis
et out
. Cette fois, il s’exécute sans erreur.
np.argmax(array_2,axis=1,out=out_arr)
La sortie de argmax()
est maintenant stockée dans out_arr
.
print(out_arr)
# Résultat
[2 0]
Conclusion
Ce tutoriel vous a montré comment utiliser argmax()
de NumPy. Vous pouvez tester ces exemples dans un notebook Jupyter.
Récapitulons les points clés :
argmax()
retourne l’index du maximum dans un tableau. Si le maximum apparaît plusieurs fois,np.argmax(a)
retourne l’index de la première occurrence.- Dans les tableaux multidimensionnels,
axis
permet d’obtenir l’index du maximum le long d’un axe spécifique. Par exemple, dans un tableau 2D,axis = 0
donne l’index le long des lignes, etaxis = 1
le long des colonnes. - Le paramètre optionnel
out
permet de stocker le résultat dans un autre tableau, qui doit être de la forme et du type appropriés.
Pour aller plus loin, consultez notre guide détaillé sur les ensembles en Python.