Sie sind hier: Fortran > Fortran Programme > Strassen Matrixmultiplikation (Fortran)

Strassen Matrixmultiplikation (Fortran)

Freitag 21. Juli 2006 von
Simon Praetorius
Es soll eine schnelle Matrizenmultiplikation nach dem Algorithmus von Volker Strassen programmiert werden und dann mit der vorimplementierten und einer einfachen Matrixmultiplikation verglichen werden. Weitere Informationen zum Algorithmus finden sich unter Wikipedia

fortran Code
  • module strassen
  • implicit none
  •  
  • type t
  • real :: value
  • end type t
  •  
  • interface operator(.mal.)
  • module procedure strassen_matprod
  • end interface operator(.mal.)
  • interface operator(.x.)
  • module procedure simple_matprod
  • end interface operator(.x.)
  • interface operator(*)
  • module procedure mul_t
  • end interface operator(*)
  •  
  • contains
  •  
  • function mul_t(x,y) result(erg)
  • type(t),intent(in) :: x,y
  • type(t) :: erg
  •  
  • erg%value = x%value * y%value
  •  
  • end function mul_t
  •  
  • recursive function strassen_matprod(a,b) result(erg)
  • real,dimension(:,:),intent(in) :: a,b
  • real,dimension(size(a,1),size(a,1)) :: erg
  • real,dimension(size(a,1)/2,size(a,1)/2) :: m1,m2,m3,m4,m5,m6,m7, &
  • & a11,a12,a21,a22,b11,b12,b21,b22
  • integer :: s
  •  
  • s = size(a,1)/2
  • if(size(a,1) <= 64) then
  • erg = simple_matprod4(a,b)
  • else
  • a11 = a(1:s,1:s)
  • a12 = a(1:s,s+1:2*s)
  • a21 = a(s+1:2*s,1:s)
  • a22 = a(s+1:2*s,s+1:2*s)
  •  
  • b11 = b(1:s,1:s)
  • b12 = b(1:s,s+1:2*s)
  • b21 = b(s+1:2*s,1:s)
  • b22 = b(s+1:2*s,s+1:2*s)
  •  
  • m1 = (a12-a22) .mal. (b21+b22)
  • m2 = (a11+a22) .mal. (b11+b22)
  • m3 = (a11-a21) .mal. (b11+b12)
  • m4 = (a11+a12) .mal. b22
  • m5 = a11 .mal. (b12-b22)
  • m6 = a22 .mal. (b21-b11)
  • m7 = (a21+a22) .mal. b11
  •  
  • erg(1:s,1:s) = m1+m2-m4+m6
  • erg(1:s,s+1:2*s) = m4+m5
  • erg(s+1:2*s,1:s) = m6+m7
  • erg(s+1:2*s,s+1:2*s) = m2-m3+m5-m7
  •  
  • end if
  •  
  • end function strassen_matprod
  •  
  • function simple_matprod(a,b) result(erg)
  • real,dimension(:,:),intent(in) :: a,b
  • real,dimension(size(a,1),size(a,2)) :: erg
  • real :: summe
  • integer :: i,j,k
  •  
  • erg=0
  • do i=1,size(a,1)
  • do j=1,size(a,1)
  • do k=1,size(a,1)
  • erg(i,j) = erg(i,j) + a(i,k)*b(k,j)
  • end do
  • end do
  • end do
  • end function simple_matprod
  •  
  • function simple_matprod4(a,b) result(erg)
  • real,dimension(:,:),intent(in) :: a,b
  • real,dimension(size(a,1),size(a,2)) :: erg
  • integer :: i,j,k,n
  •  
  • erg=0
  • n = size(a,1)
  • do j=1,n
  • do k=1,n
  • do i=1,n
  • erg(i,j) = erg(i,j) + a(i,k)*b(k,j)
  • end do
  • end do
  • end do
  • end function simple_matprod4
  •  
  • end module strassen
  •  
  • program test
  • use strassen
  • implicit none
  •  
  • integer :: k,n,i,j,fehler
  • real,dimension(:,:),allocatable :: a,b,c1,c2,c3
  • real :: z
  •  
  • open(30,file='daten.dat',status='old')
  • write(*,*) 'k:'
  • read(*,*) k
  • write(30,*) k
  • n = 2**k
  • call random_seed()
  • do i=1,2*n*n
  • call random_number(z)
  • write(30,*) int(z*20)
  • end do
  • rewind(30)
  •  
  • read(30,*) k
  • n = 2**k
  •  
  • allocate(a(n,n),b(n,n),c1(n,n),c2(n,n),c3(n,n))
  •  
  • aussen: do i=1,n
  • do j=1,n
  • read(30,*,iostat=fehler) a(i,j)
  • if(fehler /= 0) exit aussen
  • end do
  • end do aussen
  • aussen2: do i=1,n
  • do j=1,n
  • read(30,*,iostat=fehler) b(i,j)
  • if(fehler /= 0) exit aussen2
  • end do
  • end do aussen2
  •  
  • write(*,*) dtime()
  • c1 = a .mal. b
  • write(*,*) 'strassen:',dtime()
  • c2 = simple_matprod4(a,b)
  • write(*,*) 'simple4:',dtime()
  • c3 = matmul(a,b)
  • write(*,*) 'matmul:',dtime()
  •  
  • write(*,*) all(c1 == c2 .and. c2 == c3)
  •  
  •  
  • end program test
In der Funktion simple_matprod wurde auch noch eine Optimierung vorgenommen. Durch geschicktes Vertauschen der Schleifenindizes kann der Algorithmus beschleunigt werden. Dies liegt unter anderem daran, wie Fortran die Matrizen im Speicher ablegt. durch vertauschen der Indizes wird erzeugt, dass bei der Addition zweiter Elemente aus den Matrizen im Speicher nebeneinander stehende Werte gelesen werden können. Da beim Lesen aus dem Speicher nicht nur jeweils eine Zahl gelesen wird, sondern ein ganzer Block von vielleicht 16 oder 64 Zahlen, reicht für die Addition von 16 bzw. 64 Elementen ein Lesezugriff aus. Durch diese Einsparung erhöht sich die Geschwindigkeit einer einfachen Matrix-Multiplikation enorm. In anderen Progrmmiersprachen (z.B. C) werden Matrizen andersherum im Speicher abgelegt, so dass dort eine entsprechend andere Optimierung vorgenommen werden muss.

Eine weitere Implementierung des Algorithmus habe ich in Scilab geschrieben.
Besucher: 16088 | Permalink | Kategorie: Fortran
Tags: , , ,

Kommentar hinzufügen

Dieses Feld bitten nicht ausfüllen: